Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 77 additions & 62 deletions agentstack/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional
import os, sys
import os
import sys
import time
from datetime import datetime
from pathlib import Path

import json
import shutil
Expand All @@ -26,8 +26,9 @@
from agentstack import inputs
from agentstack.agents import get_all_agents
from agentstack.tasks import get_all_tasks
from agentstack.utils import open_json_file, term_color, is_snake_case, get_framework
from agentstack.utils import open_json_file, term_color, is_snake_case, get_framework, validator_not_empty
from agentstack.proj_templates import TemplateConfig
from agentstack.exceptions import ValidationError


PREFERRED_MODELS = [
Expand Down Expand Up @@ -184,6 +185,75 @@ def ask_framework() -> str:
return framework


def get_validated_input(
message: str,
validate_func=None,
min_length: int = 0,
snake_case: bool = False,
) -> str:
"""Helper function to get validated input from user.

Args:
message: The prompt message to display
validate_func: Optional custom validation function
min_length: Minimum length requirement (0 for no requirement)
snake_case: Whether to enforce snake_case naming
"""
while True:
try:
value = inquirer.text(
message=message,
validate=validate_func or validator_not_empty(min_length) if min_length else None,
)
if snake_case and not is_snake_case(value):
raise ValidationError("Input must be in snake_case")
return value
except ValidationError as e:
print(term_color(f"Error: {str(e)}", 'red'))


def ask_agent_details():
agent = {}

agent['name'] = get_validated_input(
"What's the name of this agent? (snake_case)", min_length=3, snake_case=True
)

agent['role'] = get_validated_input("What role does this agent have?", min_length=3)

agent['goal'] = get_validated_input("What is the goal of the agent?", min_length=10)

agent['backstory'] = get_validated_input("Give your agent a backstory", min_length=10)

agent['model'] = inquirer.list_input(
message="What LLM should this agent use?", choices=PREFERRED_MODELS, default=PREFERRED_MODELS[0]
)

return agent


def ask_task_details(agents: list[dict]) -> dict:
task = {}

task['name'] = get_validated_input(
"What's the name of this task? (snake_case)", min_length=3, snake_case=True
)

task['description'] = get_validated_input("Describe the task in more detail", min_length=10)

task['expected_output'] = get_validated_input(
"What do you expect the result to look like? (ex: A 5 bullet point summary of the email)",
min_length=10,
)

task['agent'] = inquirer.list_input(
message="Which agent should be assigned this task?",
choices=[a['name'] for a in agents],
)

return task


def ask_design() -> dict:
use_wizard = inquirer.confirm(
message="Would you like to use the CLI wizard to set up agents and tasks?",
Expand All @@ -208,39 +278,10 @@ def ask_design() -> dict:
while make_agent:
print('---')
print(f"Agent #{len(agents)+1}")

agent_incomplete = True
agent = None
while agent_incomplete:
agent = inquirer.prompt(
[
inquirer.Text("name", message="What's the name of this agent? (snake_case)"),
inquirer.Text("role", message="What role does this agent have?"),
inquirer.Text("goal", message="What is the goal of the agent?"),
inquirer.Text("backstory", message="Give your agent a backstory"),
# TODO: make a list - #2
inquirer.Text(
'model',
message="What LLM should this agent use? (any LiteLLM provider)",
default="openai/gpt-4",
),
# inquirer.List("model", message="What LLM should this agent use? (any LiteLLM provider)", choices=[
# 'mixtral_llm',
# 'mixtral_llm',
# ]),
]
)

if not agent['name'] or agent['name'] == '':
print(term_color("Error: Agent name is required - Try again", 'red'))
agent_incomplete = True
elif not is_snake_case(agent['name']):
print(term_color("Error: Agent name must be snake case - Try again", 'red'))
else:
agent_incomplete = False

make_agent = inquirer.confirm(message="Create another agent?")
agent = ask_agent_details()
agents.append(agent)
make_agent = inquirer.confirm(message="Create another agent?")

print('')
for x in range(3):
Expand All @@ -257,35 +298,9 @@ def ask_design() -> dict:
while make_task:
print('---')
print(f"Task #{len(tasks) + 1}")

task_incomplete = True
task = None
while task_incomplete:
task = inquirer.prompt(
[
inquirer.Text("name", message="What's the name of this task? (snake_case)"),
inquirer.Text("description", message="Describe the task in more detail"),
inquirer.Text(
"expected_output",
message="What do you expect the result to look like? (ex: A 5 bullet point summary of the email)",
),
inquirer.List(
"agent",
message="Which agent should be assigned this task?",
choices=[a['name'] for a in agents], # type: ignore
),
]
)

if not task['name'] or task['name'] == '':
print(term_color("Error: Task name is required - Try again", 'red'))
elif not is_snake_case(task['name']):
print(term_color("Error: Task name must be snake case - Try again", 'red'))
else:
task_incomplete = False

make_task = inquirer.confirm(message="Create another task?")
task = ask_task_details(agents)
tasks.append(task)
make_task = inquirer.confirm(message="Create another task?")

print('')
for x in range(3):
Expand Down
16 changes: 14 additions & 2 deletions agentstack/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional, Union
import os, sys
import os
import sys
import json
from ruamel.yaml import YAML
import re
from importlib.metadata import version
from pathlib import Path
import importlib.resources
from agentstack import conf
from inquirer import errors as inquirer_errors


def get_version(package: str = 'agentstack'):
Expand Down Expand Up @@ -108,3 +109,14 @@ def term_color(text: str, color: str) -> str:

def is_snake_case(string: str):
return bool(re.match('^[a-z0-9_]+$', string))


def validator_not_empty(min_length=1):
def validator(_, answer):
if len(answer) < min_length:
raise inquirer_errors.ValidationError(
'', reason=f"This field must be at least {min_length} characters long."
)
return True

return validator
42 changes: 42 additions & 0 deletions tests/cli_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os, sys
import subprocess

CLI_ENTRY = [
sys.executable,
"-m",
"agentstack.main",
]

def run_cli(*args):
"""Helper method to run the CLI with arguments. Cross-platform."""
try:
# Use shell=True on Windows to handle path issues
if sys.platform == 'win32':
# Add PYTHONIOENCODING to the environment
env = os.environ.copy()
env['PYTHONIOENCODING'] = 'utf-8'
result = subprocess.run(
" ".join(str(arg) for arg in CLI_ENTRY + list(args)),
capture_output=True,
text=True,
shell=True,
env=env,
encoding='utf-8'
)
else:
result = subprocess.run(
[*CLI_ENTRY, *args],
capture_output=True,
text=True,
encoding='utf-8'
)

if result.returncode != 0:
print(f"Command failed with code {result.returncode}")
print(f"STDOUT: {result.stdout}")
print(f"STDERR: {result.stderr}")

return result
except Exception as e:
print(f"Exception running command: {e}")
raise
21 changes: 7 additions & 14 deletions tests/test_cli_init.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,26 @@
import subprocess
import os, sys
import unittest
from parameterized import parameterized
from pathlib import Path
import shutil
from cli_test_utils import run_cli

BASE_PATH = Path(__file__).parent
CLI_ENTRY = [
sys.executable,
"-m",
"agentstack.main",
]


class CLIInitTest(unittest.TestCase):
def setUp(self):
self.project_dir = Path(BASE_PATH / 'tmp/cli_init')
os.makedirs(self.project_dir)
os.chdir(BASE_PATH) # Change to parent directory first
os.makedirs(self.project_dir, exist_ok=True)
os.chdir(self.project_dir)
# Force UTF-8 encoding for the test environment
os.environ['PYTHONIOENCODING'] = 'utf-8'

def tearDown(self):
shutil.rmtree(self.project_dir)

def _run_cli(self, *args):
"""Helper method to run the CLI with arguments."""
return subprocess.run([*CLI_ENTRY, *args], capture_output=True, text=True)
shutil.rmtree(self.project_dir, ignore_errors=True)

def test_init_command(self):
"""Test the 'init' command to create a project directory."""
result = self._run_cli('init', 'test_project')
result = run_cli('init', 'test_project')
self.assertEqual(result.returncode, 0)
self.assertTrue((self.project_dir / 'test_project').exists())
21 changes: 6 additions & 15 deletions tests/test_cli_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,16 @@
import unittest
from pathlib import Path
import shutil
from cli_test_utils import run_cli

BASE_PATH = Path(__file__).parent


class TestAgentStackCLI(unittest.TestCase):
CLI_ENTRY = [
sys.executable,
"-m",
"agentstack.main",
]

def run_cli(self, *args):
"""Helper method to run the CLI with arguments."""
result = subprocess.run([*self.CLI_ENTRY, *args], capture_output=True, text=True)
return result

def test_version(self):
"""Test the --version command."""
result = self.run_cli("--version")
result = run_cli("--version")
print(result.stdout)
print(result.stderr)
print(result.returncode)
Expand All @@ -30,27 +21,27 @@ def test_version(self):

def test_invalid_command(self):
"""Test an invalid command gracefully exits."""
result = self.run_cli("invalid_command")
result = run_cli("invalid_command")
self.assertNotEqual(result.returncode, 0)
self.assertIn("usage:", result.stderr)

def test_run_command_invalid_project(self):
"""Test the 'run' command on an invalid project."""
test_dir = Path(BASE_PATH / 'tmp/test_project')
if test_dir.exists():
shutil.rmtree(test_dir)
shutil.rmtree(test_dir, ignore_errors=True)
os.makedirs(test_dir)

# Write a basic agentstack.json file
with (test_dir / 'agentstack.json').open('w') as f:
f.write(open(BASE_PATH / 'fixtures/agentstack.json', 'r').read())

os.chdir(test_dir)
result = self.run_cli('run')
result = run_cli('run')
self.assertNotEqual(result.returncode, 0)
self.assertIn("Project validation failed", result.stdout)

shutil.rmtree(test_dir)
shutil.rmtree(test_dir, ignore_errors=True)


if __name__ == "__main__":
Expand Down
Loading
Loading