From 09f1787f0e7040d8dd32e11430e30a7e2005341d Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Mon, 2 Dec 2024 14:28:26 -0800 Subject: [PATCH 01/14] Validate framework entrypoint file before execution. (#60) Start to build frameworks abstraction module. --- agentstack/cli/__init__.py | 2 +- agentstack/cli/cli.py | 21 +++++++++ agentstack/frameworks/__init__.py | 41 ++++++++++++++++ agentstack/frameworks/crewai.py | 59 ++++++++++++++++++++++++ agentstack/generation/gen_utils.py | 13 ++---- agentstack/generation/tool_generation.py | 22 ++------- agentstack/main.py | 6 +-- tests/fixtures/agentstack.json | 2 +- tests/test_cli_loads.py | 26 +++++++++-- tests/test_generation_files.py | 2 +- 10 files changed, 158 insertions(+), 36 deletions(-) create mode 100644 agentstack/frameworks/__init__.py create mode 100644 agentstack/frameworks/crewai.py diff --git a/agentstack/cli/__init__.py b/agentstack/cli/__init__.py index 3c35ec37..e283f5d1 100644 --- a/agentstack/cli/__init__.py +++ b/agentstack/cli/__init__.py @@ -1 +1 @@ -from .cli import init_project_builder, list_tools +from .cli import init_project_builder, list_tools, run_project diff --git a/agentstack/cli/cli.py b/agentstack/cli/cli.py index 9b560d16..5bd0370d 100644 --- a/agentstack/cli/cli.py +++ b/agentstack/cli/cli.py @@ -4,6 +4,7 @@ import time from datetime import datetime from typing import Optional +from pathlib import Path import requests import itertools @@ -17,6 +18,8 @@ from agentstack.logger import log from agentstack.utils import get_package_path from agentstack.generation.tool_generation import get_all_tools +from agentstack import frameworks +import agentstack.frameworks.crewai from .. import generation from ..utils import open_json_file, term_color, is_snake_case @@ -114,6 +117,24 @@ def welcome_message(): print(border) +def run_project(framework: str, path: str = ''): + """Validate that the project is ready to run and then run it.""" + if not framework in frameworks.SUPPORTED_FRAMEWORKS: + print(term_color(f"Framework {framework} is not supported by agentstack.", 'red')) + sys.exit(1) + + try: + frameworks.validate_project(framework, path) + except frameworks.ValidationError as e: + print(term_color("Project validation failed:", 'red')) + print(e) + sys.exit(1) + + path = Path(path) + entrypoint = path/frameworks.get_entrypoint_path(framework) + os.system(f'python {entrypoint}') + + def ask_framework() -> str: framework = "CrewAI" # framework = inquirer.list_input( diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py new file mode 100644 index 00000000..8c4779b4 --- /dev/null +++ b/agentstack/frameworks/__init__.py @@ -0,0 +1,41 @@ +""" +Methods for interacting with framework-specific features. + +Each framework should have a module in the `frameworks` package which defines the following methods: + +- `ENTRYPOINT`: Path: Relative path to the entrypoint file for the framework +- `validate_project(path: Optional[Path] = None) -> None`: Validate that a project is ready to run. + Raises a `ValidationError` if the project is not valid. +""" +from typing import Optional +from importlib import import_module +from pathlib import Path + + +CREWAI = 'crewai' +SUPPORTED_FRAMEWORKS = [CREWAI, ] + +def get_framework_module(framework: str) -> import_module: + """ + Get the module for a framework. + """ + if framework == CREWAI: + from . import crewai + return crewai + else: + raise ValueError(f"Framework {framework} not supported") + +def get_entrypoint_path(framework: str) -> Path: + """ + Get the path to the entrypoint file for a framework. + """ + return get_framework_module(framework).ENTRYPOINT + +class ValidationError(Exception): pass + +def validate_project(framework: str, path: Optional[Path] = None) -> None: + """ + Run the framework specific project validation. + """ + return get_framework_module(framework).validate_project(path) + diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py new file mode 100644 index 00000000..9a8515af --- /dev/null +++ b/agentstack/frameworks/crewai.py @@ -0,0 +1,59 @@ +from typing import Optional +from pathlib import Path +import ast +from . import SUPPORTED_FRAMEWORKS, ValidationError + + +ENTRYPOINT: Path = Path('src/crew.py') + +def validate_project(path: Optional[Path] = None) -> None: + """ + Validate that a CrewAI project is ready to run. + Raises a frameworks.VaidationError if the project is not valid. + """ + try: + if path is None: path = Path() + with open(path/ENTRYPOINT, 'r') as f: + tree = ast.parse(f.read()) + except (FileNotFoundError, SyntaxError) as e: + raise ValidationError(f"Failed to parse {ENTRYPOINT}\n {e}") + + # A valid project must have a class in the crew.py file decorated with `@CrewBase` + try: + class_node = _find_class_with_decorator(tree, 'CrewBase')[0] + except IndexError: + raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}") + + # The Crew class must have one or more methods decorated with `@agent` + if len(_find_decorated_method_in_class(class_node, 'task')) < 1: + raise ValidationError(f"`@task` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}") + + # The Crew class must have one or more methods decorated with `@agent` + if len(_find_decorated_method_in_class(class_node, 'agent')) < 1: + raise ValidationError(f"`@agent` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}") + + # The Crew class must have one method decorated with `@crew` + if len(_find_decorated_method_in_class(class_node, 'crew')) < 1: + raise ValidationError(f"`@crew` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}") + +# TODO move these to a shared AST utility module +def _find_class_with_decorator(tree: ast.AST, decorator_name: str) -> list[ast.ClassDef]: + """Find a class definition that is marked by a decorator in an AST.""" + nodes = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == decorator_name: + nodes.append(node) + return nodes + +def _find_decorated_method_in_class(classdef: ast.ClassDef, decorator_name: str) -> list[ast.FunctionDef]: + """Find all method definitions in a class definition which are decorated with a specific decorator.""" + nodes = [] + for node in ast.iter_child_nodes(classdef): + if isinstance(node, ast.FunctionDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == decorator_name: + nodes.append(node) + return nodes + diff --git a/agentstack/generation/gen_utils.py b/agentstack/generation/gen_utils.py index f9ac0e5f..dc9ac38c 100644 --- a/agentstack/generation/gen_utils.py +++ b/agentstack/generation/gen_utils.py @@ -2,8 +2,10 @@ import sys from enum import Enum from typing import Optional, Union, List +from pathlib import Path from agentstack.utils import term_color +from agentstack import frameworks def insert_code_after_tag(file_path, tag, code_to_insert, next_line=False): @@ -72,14 +74,6 @@ def string_in_file(file_path: str, str_to_match: str) -> bool: return str_to_match in file_content -def _framework_filename(framework: str, path: str = ''): - if framework == 'crewai': - return f'{path}src/crew.py' - - print(term_color(f'Unknown framework: {framework}', 'red')) - sys.exit(1) - - class CrewComponent(str, Enum): AGENT = "agent" TASK = "task" @@ -103,7 +97,8 @@ def get_crew_components( Returns: Dictionary with 'agents' and 'tasks' keys containing lists of names """ - filename = _framework_filename(framework, path) + path = Path(path) + filename = path/frameworks.get_entrypoint_path(framework) # Convert single component type to list for consistent handling if isinstance(component_type, CrewComponent): diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index c4e19fc0..68a2472f 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Union from . import get_agent_names -from .gen_utils import insert_code_after_tag, string_in_file, _framework_filename +from .gen_utils import insert_code_after_tag, string_in_file from ..utils import open_json_file, get_framework, term_color import os import shutil @@ -18,27 +18,13 @@ from agentstack.utils import get_package_path from agentstack.generation.files import ConfigFile, EnvFile +from agentstack import frameworks from .gen_utils import insert_code_after_tag, string_in_file from ..utils import open_json_file, get_framework, term_color TOOL_INIT_FILENAME = "src/tools/__init__.py" -FRAMEWORK_FILENAMES: dict[str, str] = { - 'crewai': 'src/crew.py', -} - -def get_framework_filename(framework: str, path: str = ''): - if path: - path = path.endswith('/') and path or path + '/' - else: - path = './' - try: - return f"{path}{FRAMEWORK_FILENAMES[framework]}" - except KeyError: - print(term_color(f'Unknown framework: {framework}', 'red')) - sys.exit(1) - class ToolConfig(BaseModel): name: str category: str @@ -106,7 +92,6 @@ def add_tool(tool_name: str, path: Optional[str] = None, agents: Optional[List[s tool_data = ToolConfig.from_tool_name(tool_name) tool_file_path = tool_data.get_impl_file_path(framework) - if tool_data.packages: os.system(f"poetry add {' '.join(tool_data.packages)}") # Install packages shutil.copy(tool_file_path, f'{path}src/tools/{tool_name}_tool.py') # Move tool from package to project @@ -376,7 +361,8 @@ def modify_agent_tools( print(term_color(f"Agent '{agent}' not found in the project.", 'red')) sys.exit(1) - filename = _framework_filename(framework, path) + path = Path(path) + filename = path/frameworks.get_entrypoint_path(framework) with open(filename, 'r') as f: source = f.read() diff --git a/agentstack/main.py b/agentstack/main.py index 14a448cf..7a855a6a 100644 --- a/agentstack/main.py +++ b/agentstack/main.py @@ -2,10 +2,11 @@ import os import sys -from agentstack.cli import init_project_builder, list_tools +from agentstack.cli import init_project_builder, list_tools, run_project from agentstack.telemetry import track_cli_command from agentstack.utils import get_version, get_framework import agentstack.generation as generation +from agentstack import frameworks import webbrowser @@ -98,8 +99,7 @@ def main(): init_project_builder(args.slug_name, args.template, args.wizard) elif args.command in ['run', 'r']: framework = get_framework() - if framework == "crewai": - os.system('python src/main.py') + run_project(framework) elif args.command in ['generate', 'g']: if args.generate_command in ['agent', 'a']: generation.generate_agent(args.name, args.role, args.goal, args.backstory, args.llm) diff --git a/tests/fixtures/agentstack.json b/tests/fixtures/agentstack.json index 4ca18a10..f39237b1 100644 --- a/tests/fixtures/agentstack.json +++ b/tests/fixtures/agentstack.json @@ -1,4 +1,4 @@ { "framework": "crewai", - "tools": ["tool1", "tool2"] + "tools": [] } \ No newline at end of file diff --git a/tests/test_cli_loads.py b/tests/test_cli_loads.py index 49bb15cd..768ee0b9 100644 --- a/tests/test_cli_loads.py +++ b/tests/test_cli_loads.py @@ -1,9 +1,10 @@ import subprocess -import sys +import os, sys import unittest from pathlib import Path import shutil +BASE_PATH = Path(__file__).parent class TestAgentStackCLI(unittest.TestCase): CLI_ENTRY = [sys.executable, "-m", "agentstack.main"] # Replace with your actual CLI entry point if different @@ -31,12 +32,14 @@ def test_invalid_command(self): def test_init_command(self): """Test the 'init' command to create a project directory.""" - test_dir = Path("test_project") + test_dir = Path(BASE_PATH/'tmp/test_project') # Ensure the directory doesn't exist from previous runs if test_dir.exists(): shutil.rmtree(test_dir) - + os.mkdir(test_dir) + + os.chdir(test_dir) result = self.run_cli("init", str(test_dir)) self.assertEqual(result.returncode, 0) self.assertTrue(test_dir.exists()) @@ -44,6 +47,23 @@ def test_init_command(self): # Clean up shutil.rmtree(test_dir) + def test_run_command_invalid_project(self): + """Test the 'run' command on an invalid project.""" + test_dir = Path(BASE_PATH/'tmp/test_project') + if test_dir.exists(): + shutil.rmtree(test_dir) + os.mkdir(test_dir) + + # Write a basic agentstack.json file + with (test_dir/'agentstack.json').open('w') as f: + f.write(open(BASE_PATH/'fixtures/agentstack.json', 'r').read()) + + os.chdir(test_dir) + result = self.run_cli('run') + self.assertNotEqual(result.returncode, 0) + self.assertIn("Project validation failed", result.stdout) + + shutil.rmtree(test_dir) if __name__ == "__main__": unittest.main() diff --git a/tests/test_generation_files.py b/tests/test_generation_files.py index 8f8549e3..6e58113b 100644 --- a/tests/test_generation_files.py +++ b/tests/test_generation_files.py @@ -12,7 +12,7 @@ class GenerationFilesTest(unittest.TestCase): def test_read_config(self): config = ConfigFile(BASE_PATH / "fixtures") # + agentstack.json assert config.framework == "crewai" - assert config.tools == ["tool1", "tool2"] + assert config.tools == [] assert config.telemetry_opt_out is None def test_write_config(self): From 8175251d85d06c567bed9e12282ae545c580216b Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Mon, 2 Dec 2024 15:05:02 -0800 Subject: [PATCH 02/14] Recursive tmp directory creation in CLI tests. --- tests/test_cli_loads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_cli_loads.py b/tests/test_cli_loads.py index 768ee0b9..13597c62 100644 --- a/tests/test_cli_loads.py +++ b/tests/test_cli_loads.py @@ -37,7 +37,7 @@ def test_init_command(self): # Ensure the directory doesn't exist from previous runs if test_dir.exists(): shutil.rmtree(test_dir) - os.mkdir(test_dir) + os.makedirs(test_dir) os.chdir(test_dir) result = self.run_cli("init", str(test_dir)) @@ -52,7 +52,7 @@ def test_run_command_invalid_project(self): test_dir = Path(BASE_PATH/'tmp/test_project') if test_dir.exists(): shutil.rmtree(test_dir) - os.mkdir(test_dir) + os.makedirs(test_dir) # Write a basic agentstack.json file with (test_dir/'agentstack.json').open('w') as f: From 3da6973845e79e5fdf9657d45a07c10b3fceb536 Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Tue, 3 Dec 2024 10:40:12 -0800 Subject: [PATCH 03/14] Add better error messaging to crew validation. --- agentstack/frameworks/crewai.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index 9a8515af..249caca2 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -26,11 +26,15 @@ def validate_project(path: Optional[Path] = None) -> None: # The Crew class must have one or more methods decorated with `@agent` if len(_find_decorated_method_in_class(class_node, 'task')) < 1: - raise ValidationError(f"`@task` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}") + raise ValidationError( + f"`@task` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" + "Create a new task using `agentstack generate task `.") # The Crew class must have one or more methods decorated with `@agent` if len(_find_decorated_method_in_class(class_node, 'agent')) < 1: - raise ValidationError(f"`@agent` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}") + raise ValidationError( + f"`@agent` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" + "Create a new agent using `agentstack generate agent `.") # The Crew class must have one method decorated with `@crew` if len(_find_decorated_method_in_class(class_node, 'crew')) < 1: From 96f0adac0ab712bac2458604cccc89fc268ed685 Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Tue, 3 Dec 2024 14:45:41 -0800 Subject: [PATCH 04/14] Frameworks AST progress (amend this) --- agentstack/frameworks/__init__.py | 55 +++++++++++- agentstack/frameworks/crewai.py | 107 +++++++++++++++-------- agentstack/generation/astools.py | 34 +++++++ agentstack/generation/tool_generation.py | 52 ++--------- 4 files changed, 161 insertions(+), 87 deletions(-) create mode 100644 agentstack/generation/astools.py diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index 8c4779b4..211669eb 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -4,17 +4,27 @@ Each framework should have a module in the `frameworks` package which defines the following methods: - `ENTRYPOINT`: Path: Relative path to the entrypoint file for the framework -- `validate_project(path: Optional[Path] = None) -> None`: Validate that a project is ready to run. +- `validate_project(framework: str, path: Optional[Path] = None) -> None`: Validate that a project is ready to run. Raises a `ValidationError` if the project is not valid. +- `add_tool(framework: str, path: Optional[Path] = None) -> None`: Add a tool to the framework. +- `remove_tool(framework: str, path: Optional[Path] = None) -> None`: Remove a tool from the framework. +- `add_agent(framework: str, path: Optional[Path] = None) -> None`: Add an agent to the framework. +- `remove_agent(framework: str, path: Optional[Path] = None) -> None`: Remove an agent from the framework. +- `add_input(framework: str, path: Optional[Path] = None) -> None`: Add an input to the framework. +- `remove_input(framework: str, path: Optional[Path] = None) -> None`: Remove an input from the framework. """ -from typing import Optional +from typing import TYPE_CHECKING, Optional from importlib import import_module from pathlib import Path +if TYPE_CHECKING: + from agentstack.generation.tool_generation import ToolData CREWAI = 'crewai' SUPPORTED_FRAMEWORKS = [CREWAI, ] +class ValidationError(Exception): pass + def get_framework_module(framework: str) -> import_module: """ Get the module for a framework. @@ -31,11 +41,48 @@ def get_entrypoint_path(framework: str) -> Path: """ return get_framework_module(framework).ENTRYPOINT -class ValidationError(Exception): pass - def validate_project(framework: str, path: Optional[Path] = None) -> None: """ Run the framework specific project validation. """ return get_framework_module(framework).validate_project(path) +def add_tool(framework: str, tool: 'ToolData', path: Optional[Path] = None) -> None: + """ + Add a tool to the framework. + + The tool will have aready been installed in the user's application and have + all dependencies installed. We're just handling code generation here. + """ + return get_framework_module(framework).add_tool(tool, path) + +def remove_tool(framework: str, tool: 'ToolData', path: Optional[Path] = None) -> None: + """ + Remove a tool from the framework. + """ + return get_framework_module(framework).remove_tool(tool, path) + +def add_agent(framework: str, path: Optional[Path] = None) -> None: + """ + Add an agent to the framework. + """ + return get_framework_module(framework).add_agent(path) + +def remove_agent(framework: str, path: Optional[Path] = None) -> None: + """ + Remove an agent from the framework. + """ + return get_framework_module(framework).remove_agent(path) + +def add_input(framework: str, path: Optional[Path] = None) -> None: + """ + Add an input to the framework. + """ + return get_framework_module(framework).add_input(path) + +def remove_input(framework: str, path: Optional[Path] = None) -> None: + """ + Remove an input from the framework. + """ + return get_framework_module(framework).remove_input(path) + diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index 249caca2..f7d4aa3a 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -1,63 +1,96 @@ from typing import Optional from pathlib import Path import ast +from agentstack.generation import astools from . import SUPPORTED_FRAMEWORKS, ValidationError ENTRYPOINT: Path = Path('src/crew.py') +class CrewFile: + """ + Parses and manipulates the CrewAI entrypoint file. + """ + tree: ast.AST + _base_class: ast.ClassDef + + def __init__(self, path: Optional[Path] = None): + if path is None: path = Path() + try: + with open(path/ENTRYPOINT, 'r') as f: + self.tree = ast.parse(f.read()) + except (FileNotFoundError, SyntaxError) as e: + raise ValidationError(f"Failed to parse {ENTRYPOINT}\n{e}") + + def get_base_class(self) -> ast.ClassDef: + """A base class is a class decorated with `@CrewBase`.""" + if self._base_class is None: # Gets cached to save repeat iteration + try: + self._base_class = astools.find_class_with_decorator(self.tree, 'CrewBase')[0] + except IndexError: + raise ValidationError(f"`@CrewBase` decorated class not found in {self.ENTRYPOINT}") + return self._base_class + + def get_crew_method(self) -> ast.FunctionDef: + """A `crew` method is a method decorated with `@crew`.""" + try: + base_class = self.get_base_class() + return astools.find_decorated_method_in_class(base_class, 'crew')[0] + except IndexError: + raise ValidationError(f"`@crew` decorated method not found in `{base_class.name}` class in {self.ENTRYPOINT}") + + def get_task_methods(self) -> list[ast.FunctionDef]: + """A `task` method is a method decorated with `@task`.""" + return astools.find_decorated_method_in_class(self.get_base_class(), 'task') + + def get_agent_methods(self) -> list[ast.FunctionDef]: + """An `agent` method is a method decorated with `@agent`.""" + return astools.find_decorated_method_in_class(self.get_base_class(), 'agent') + def validate_project(path: Optional[Path] = None) -> None: """ Validate that a CrewAI project is ready to run. Raises a frameworks.VaidationError if the project is not valid. """ - try: - if path is None: path = Path() - with open(path/ENTRYPOINT, 'r') as f: - tree = ast.parse(f.read()) - except (FileNotFoundError, SyntaxError) as e: - raise ValidationError(f"Failed to parse {ENTRYPOINT}\n {e}") - + crew_file = CrewFile(path) # raises ValidationError # A valid project must have a class in the crew.py file decorated with `@CrewBase` - try: - class_node = _find_class_with_decorator(tree, 'CrewBase')[0] - except IndexError: - raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}") + class_node = crew_file.get_base_class() # raises ValidationError + # The Crew class must have one method decorated with `@crew` + crew_file.get_crew_method() # raises ValidationError # The Crew class must have one or more methods decorated with `@agent` - if len(_find_decorated_method_in_class(class_node, 'task')) < 1: + if len(crew_file.get_task_methods()) < 1: raise ValidationError( f"`@task` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" "Create a new task using `agentstack generate task `.") # The Crew class must have one or more methods decorated with `@agent` - if len(_find_decorated_method_in_class(class_node, 'agent')) < 1: + if len(crew_file.get_agent_methods()) < 1: raise ValidationError( f"`@agent` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" "Create a new agent using `agentstack generate agent `.") - # The Crew class must have one method decorated with `@crew` - if len(_find_decorated_method_in_class(class_node, 'crew')) < 1: - raise ValidationError(f"`@crew` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}") - -# TODO move these to a shared AST utility module -def _find_class_with_decorator(tree: ast.AST, decorator_name: str) -> list[ast.ClassDef]: - """Find a class definition that is marked by a decorator in an AST.""" - nodes = [] - for node in ast.iter_child_nodes(tree): - if isinstance(node, ast.ClassDef): - for decorator in node.decorator_list: - if isinstance(decorator, ast.Name) and decorator.id == decorator_name: - nodes.append(node) - return nodes - -def _find_decorated_method_in_class(classdef: ast.ClassDef, decorator_name: str) -> list[ast.FunctionDef]: - """Find all method definitions in a class definition which are decorated with a specific decorator.""" - nodes = [] - for node in ast.iter_child_nodes(classdef): - if isinstance(node, ast.FunctionDef): - for decorator in node.decorator_list: - if isinstance(decorator, ast.Name) and decorator.id == decorator_name: - nodes.append(node) - return nodes +def add_tool(path: Optional[Path] = None) -> None: + """ + Add a tool to the CrewAI framework. + + Creates the tool's method in the Crew class in the entrypoint file and + imports the tool's methods from the tool module. + """ + pass + +def remove_tool(path: Optional[Path] = None) -> None: + pass + +def add_agent(path: Optional[Path] = None) -> None: + pass + +def remove_agent(path: Optional[Path] = None) -> None: + pass + +def add_input(path: Optional[Path] = None) -> None: + pass + +def remove_input(path: Optional[Path] = None) -> None: + pass diff --git a/agentstack/generation/astools.py b/agentstack/generation/astools.py new file mode 100644 index 00000000..2183c3ba --- /dev/null +++ b/agentstack/generation/astools.py @@ -0,0 +1,34 @@ +""" +Shotcuts for working with ASTs. +""" +import ast + + +def find_class_with_decorator(tree: ast.AST, decorator_name: str) -> list[ast.ClassDef]: + """Find a class definition that is marked by a decorator in an AST.""" + nodes = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == decorator_name: + nodes.append(node) + return nodes + +def find_decorated_method_in_class(classdef: ast.ClassDef, decorator_name: str) -> list[ast.FunctionDef]: + """Find all method definitions in a class definition which are decorated with a specific decorator.""" + nodes = [] + for node in ast.iter_child_nodes(classdef): + if isinstance(node, ast.FunctionDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == decorator_name: + nodes.append(node) + return nodes + +def create_attribute(attr_name: str, base_name: str) -> ast.Attribute: + """Create an AST node for an attribute""" + return ast.Attribute( + value=ast.Name(id=base_name, ctx=ast.Load()), + attr=attr_name, + ctx=ast.Load() + ) + diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index 68a2472f..8a5b48df 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -17,6 +17,7 @@ from pydantic import BaseModel, ValidationError from agentstack.utils import get_package_path +from agentstack.generation import astools from agentstack.generation.files import ConfigFile, EnvFile from agentstack import frameworks from .gen_utils import insert_code_after_tag, string_in_file @@ -183,47 +184,6 @@ def remove_tool_from_agent_definition(framework: str, tool_data: ToolConfig, pat modify_agent_tools(framework=framework, tool_data=tool_data, operation='remove', agents=None, path=path, base_name='tools') -def _create_tool_attribute(tool_name: str, base_name: str = 'tools') -> ast.Attribute: - """Create an AST node for a tool attribute""" - return ast.Attribute( - value=ast.Name(id=base_name, ctx=ast.Load()), - attr=tool_name, - ctx=ast.Load() - ) - -def _create_starred_tool(tool_name: str, base_name: str = 'tools') -> ast.Starred: - """Create an AST node for a starred tool expression""" - return ast.Starred( - value=ast.Attribute( - value=ast.Name(id=base_name, ctx=ast.Load()), - attr=tool_name, - ctx=ast.Load() - ), - ctx=ast.Load() - ) - - -def _create_tool_attributes( - tool_names: List[str], - base_name: str = 'tools' -) -> List[ast.Attribute]: - """Create AST nodes for multiple tool attributes""" - return [_create_tool_attribute(name, base_name) for name in tool_names] - - -def _create_tool_nodes( - tool_names: List[str], - is_bundled: bool = False, - base_name: str = 'tools' -) -> List[Union[ast.Attribute, ast.Starred]]: - """Create AST nodes for multiple tool attributes""" - return [ - _create_starred_tool(name, base_name) if is_bundled - else _create_tool_attribute(name, base_name) - for name in tool_names - ] - - def _is_tool_node_match(node: ast.AST, tool_name: str, base_name: str = 'tools') -> bool: """ Check if an AST node matches a tool reference, regardless of whether it's starred @@ -268,11 +228,11 @@ def _process_tools_list( if operation == 'add': new_tools = current_tools.copy() # Add new tools with bundling if specified - new_tools.extend(_create_tool_nodes( - tool_data.tools, - tool_data.tools_bundled, - base_name - )) + new_tools.extend([ + ast.Starred(astools.create_attribute(name, base_name)) if tool_data.tools_bundled, + else astools.create_attribute(name, base_name) + for name in tool_data.tools + ]) return new_tools elif operation == 'remove': From 14a0cf0b5f1b2d60b8cbc40aa95b9120b63783f4 Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Wed, 4 Dec 2024 15:20:18 -0800 Subject: [PATCH 05/14] Multi-framework support via `agentstack.frameworks` module. Move AST methods to generic tools file, implement CrewAI-specific file modifications in crew frameworks. Make `tools` a first-class module in the package. Add tests for framework generation and tool generation. Loosen tox version requirement to avoid conflict with AgentOps. --- agentstack/__init__.py | 8 + agentstack/cli/cli.py | 2 +- agentstack/frameworks/__init__.py | 35 +- agentstack/frameworks/crewai.py | 145 +++++-- agentstack/generation/astools.py | 118 +++++- agentstack/generation/tool_generation.py | 354 +++++------------- agentstack/tools.py | 63 ++++ pyproject.toml | 3 +- .../frameworks/crewai/entrypoint_max.py | 29 ++ .../frameworks/crewai/entrypoint_min.py | 16 + tests/test_frameworks.py | 119 ++++++ tests/test_tool_config.py | 2 +- tests/test_tool_generation_init.py | 78 ++++ tox.ini | 1 + 14 files changed, 661 insertions(+), 312 deletions(-) create mode 100644 agentstack/tools.py create mode 100644 tests/fixtures/frameworks/crewai/entrypoint_max.py create mode 100644 tests/fixtures/frameworks/crewai/entrypoint_min.py create mode 100644 tests/test_frameworks.py create mode 100644 tests/test_tool_generation_init.py diff --git a/agentstack/__init__.py b/agentstack/__init__.py index e69de29b..e645be5c 100644 --- a/agentstack/__init__.py +++ b/agentstack/__init__.py @@ -0,0 +1,8 @@ + + +class ValidationError(Exception): + """ + Raised when a validation error occurs ie. a file does not meet the required + format or a syntax error is found. + """ + pass \ No newline at end of file diff --git a/agentstack/cli/cli.py b/agentstack/cli/cli.py index 05b69c4f..a62df66e 100644 --- a/agentstack/cli/cli.py +++ b/agentstack/cli/cli.py @@ -17,8 +17,8 @@ from .agentstack_data import FrameworkData, ProjectMetadata, ProjectStructure, CookiecutterData from agentstack.logger import log from agentstack.utils import get_package_path +from agentstack.tools import get_all_tools from agentstack.generation.files import ConfigFile -from agentstack.generation.tool_generation import get_all_tools from agentstack import frameworks from agentstack import packaging from agentstack import generation diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index 211669eb..e9359f51 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -16,15 +16,13 @@ from typing import TYPE_CHECKING, Optional from importlib import import_module from pathlib import Path -if TYPE_CHECKING: - from agentstack.generation.tool_generation import ToolData +from agentstack import ValidationError +from agentstack.tools import ToolConfig CREWAI = 'crewai' SUPPORTED_FRAMEWORKS = [CREWAI, ] -class ValidationError(Exception): pass - def get_framework_module(framework: str) -> import_module: """ Get the module for a framework. @@ -35,52 +33,59 @@ def get_framework_module(framework: str) -> import_module: else: raise ValueError(f"Framework {framework} not supported") -def get_entrypoint_path(framework: str) -> Path: +def get_entrypoint_path(framework: str, path: Optional[Path] = None) -> Path: """ Get the path to the entrypoint file for a framework. """ - return get_framework_module(framework).ENTRYPOINT + if not path: path = Path() + return path/get_framework_module(framework).ENTRYPOINT -def validate_project(framework: str, path: Optional[Path] = None) -> None: +def validate_project(framework: str, path: Optional[Path] = None): """ Run the framework specific project validation. """ return get_framework_module(framework).validate_project(path) -def add_tool(framework: str, tool: 'ToolData', path: Optional[Path] = None) -> None: +def add_tool(framework: str, tool: ToolConfig, agent: str, path: Optional[Path] = None): """ Add a tool to the framework. The tool will have aready been installed in the user's application and have all dependencies installed. We're just handling code generation here. """ - return get_framework_module(framework).add_tool(tool, path) + return get_framework_module(framework).add_tool(tool, agent, path) -def remove_tool(framework: str, tool: 'ToolData', path: Optional[Path] = None) -> None: +def remove_tool(framework: str, tool: ToolConfig, agent: str, path: Optional[Path] = None): """ Remove a tool from the framework. """ - return get_framework_module(framework).remove_tool(tool, path) + return get_framework_module(framework).remove_tool(tool, agent, path) + +def get_agent_names(framework: str, path: Optional[Path] = None) -> list[str]: + """ + Get a list of agent names from the framework. + """ + return get_framework_module(framework).get_agent_names(path) -def add_agent(framework: str, path: Optional[Path] = None) -> None: +def add_agent(framework: str, path: Optional[Path] = None): """ Add an agent to the framework. """ return get_framework_module(framework).add_agent(path) -def remove_agent(framework: str, path: Optional[Path] = None) -> None: +def remove_agent(framework: str, path: Optional[Path] = None): """ Remove an agent from the framework. """ return get_framework_module(framework).remove_agent(path) -def add_input(framework: str, path: Optional[Path] = None) -> None: +def add_input(framework: str, path: Optional[Path] = None): """ Add an input to the framework. """ return get_framework_module(framework).add_input(path) -def remove_input(framework: str, path: Optional[Path] = None) -> None: +def remove_input(framework: str, path: Optional[Path] = None): """ Remove an input from the framework. """ diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index f7d4aa3a..02df063d 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -1,26 +1,20 @@ from typing import Optional from pathlib import Path import ast +from agentstack import ValidationError +from agentstack.tools import ToolConfig from agentstack.generation import astools -from . import SUPPORTED_FRAMEWORKS, ValidationError +from . import SUPPORTED_FRAMEWORKS ENTRYPOINT: Path = Path('src/crew.py') -class CrewFile: +class CrewFile(astools.File): """ Parses and manipulates the CrewAI entrypoint file. + All AST interactions should happen within the methods of this class. """ - tree: ast.AST - _base_class: ast.ClassDef - - def __init__(self, path: Optional[Path] = None): - if path is None: path = Path() - try: - with open(path/ENTRYPOINT, 'r') as f: - self.tree = ast.parse(f.read()) - except (FileNotFoundError, SyntaxError) as e: - raise ValidationError(f"Failed to parse {ENTRYPOINT}\n{e}") + _base_class: ast.ClassDef = None def get_base_class(self) -> ast.ClassDef: """A base class is a class decorated with `@CrewBase`.""" @@ -28,7 +22,7 @@ def get_base_class(self) -> ast.ClassDef: try: self._base_class = astools.find_class_with_decorator(self.tree, 'CrewBase')[0] except IndexError: - raise ValidationError(f"`@CrewBase` decorated class not found in {self.ENTRYPOINT}") + raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}") return self._base_class def get_crew_method(self) -> ast.FunctionDef: @@ -37,7 +31,7 @@ def get_crew_method(self) -> ast.FunctionDef: base_class = self.get_base_class() return astools.find_decorated_method_in_class(base_class, 'crew')[0] except IndexError: - raise ValidationError(f"`@crew` decorated method not found in `{base_class.name}` class in {self.ENTRYPOINT}") + raise ValidationError(f"`@crew` decorated method not found in `{base_class.name}` class in {ENTRYPOINT}") def get_task_methods(self) -> list[ast.FunctionDef]: """A `task` method is a method decorated with `@task`.""" @@ -47,16 +41,98 @@ def get_agent_methods(self) -> list[ast.FunctionDef]: """An `agent` method is a method decorated with `@agent`.""" return astools.find_decorated_method_in_class(self.get_base_class(), 'agent') + def get_agent_tools(self, agent_name: str) -> ast.List: + """ + Get the tools used by an agent as AST nodes. + + Tool definitons are inside of the methods marked with an `@agent` decorator. + The method returns a new class instance with the tools as a list of callables + under the kwarg `tools`. + """ + method = astools.find_method(self.get_agent_methods(), agent_name) + if method is None: + raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") + + agent_class = astools.find_class_instantiation(method, 'Agent') + if agent_class is None: + raise ValidationError(f"`@agent` method `{agent_name}` does not have an `Agent` class instantiation in {ENTRYPOINT}") + + tools_kwarg = astools.find_kwarg_in_method_call(agent_class, 'tools') + if not tools_kwarg: + raise ValidationError(f"`@agent` method `{agent_name}` does not have a keyword argument `tools` in {ENTRYPOINT}") + + return tools_kwarg.value + + def add_agent_tools(self, agent_name: str, tool: ToolConfig): + """ + Add new tools to be used by an agent. + + Tool definitons are inside of the methods marked with an `@agent` decorator. + The method returns a new class instance with the tools as a list of callables + under the kwarg `tools`. + """ + method = astools.find_method(self.get_agent_methods(), agent_name) + if method is None: + raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") + + new_tool_nodes = [] + for tool_name in tool.tools: + # This prefixes the tool name with the 'tools' module + node = astools.create_attribute('tools', tool_name) + if tool.tools_bundled: # Splat the variable if it's bundled + node = ast.Starred(value=node, ctx=ast.Load()) + new_tool_nodes.append(node) + + existing_node: ast.List = self.get_agent_tools(agent_name) + new_node = ast.List( + elts=set(existing_node.elts + new_tool_nodes), + ctx=ast.Load() + ) + start, end = self.get_node_range(existing_node) + self.edit_node_range(start, end, new_node) + + def remove_agent_tools(self, agent_name: str, tool: ToolConfig): + """ + Remove tools from an agent belonging to `tool`. + """ + existing_node: ast.List = self.get_agent_tools(agent_name) + start, end = self.get_node_range(existing_node) + + # modify the existing node to remove any matching tools + for tool_name in tool.tools: + for node in existing_node.elts: + if isinstance(node, ast.Starred): + attr_name = node.value.attr + else: + attr_name = node.attr + if attr_name == tool_name: + existing_node.elts.remove(node) + + self.edit_node_range(start, end, existing_node) + + def validate_project(path: Optional[Path] = None) -> None: """ Validate that a CrewAI project is ready to run. Raises a frameworks.VaidationError if the project is not valid. """ - crew_file = CrewFile(path) # raises ValidationError + if path is None: path = Path() + try: + crew_file = CrewFile(path/ENTRYPOINT) + except ValidationError as e: + raise e + # A valid project must have a class in the crew.py file decorated with `@CrewBase` - class_node = crew_file.get_base_class() # raises ValidationError + try: + class_node = crew_file.get_base_class() + except ValidationError as e: + raise e + # The Crew class must have one method decorated with `@crew` - crew_file.get_crew_method() # raises ValidationError + try: + crew_file.get_crew_method() + except ValidationError as e: + raise e # The Crew class must have one or more methods decorated with `@agent` if len(crew_file.get_task_methods()) < 1: @@ -70,27 +146,40 @@ def validate_project(path: Optional[Path] = None) -> None: f"`@agent` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" "Create a new agent using `agentstack generate agent `.") -def add_tool(path: Optional[Path] = None) -> None: +def add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): + """ + Add a tool to the CrewAI framework for the specified agent. + + The agent should already exist in the crew class and have a keyword argument `tools`. """ - Add a tool to the CrewAI framework. + if path is None: path = Path() + with CrewFile(path/ENTRYPOINT) as crew_file: + crew_file.add_agent_tools(agent_name, tool) - Creates the tool's method in the Crew class in the entrypoint file and - imports the tool's methods from the tool module. +def remove_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): + """ + Remove a tool from the CrewAI framework for the specified agent. """ - pass + if path is None: path = Path() + with CrewFile(path/ENTRYPOINT) as crew_file: + crew_file.remove_agent_tools(agent_name, tool) -def remove_tool(path: Optional[Path] = None) -> None: - pass +def get_agent_names(path: Optional[Path] = None) -> list[str]: + """ + Get a list of agent names (methods with an @agent decorator). + """ + crew_file = CrewFile(path/ENTRYPOINT) + return [method.name for method in crew_file.get_agent_methods()] def add_agent(path: Optional[Path] = None) -> None: - pass + raise NotImplementedError def remove_agent(path: Optional[Path] = None) -> None: - pass + raise NotImplementedError def add_input(path: Optional[Path] = None) -> None: - pass + raise NotImplementedError def remove_input(path: Optional[Path] = None) -> None: - pass + raise NotImplementedError diff --git a/agentstack/generation/astools.py b/agentstack/generation/astools.py index 2183c3ba..7eb84f53 100644 --- a/agentstack/generation/astools.py +++ b/agentstack/generation/astools.py @@ -1,8 +1,122 @@ """ -Shotcuts for working with ASTs. +Tools for working with ASTs. + +We include convenience functions here based on real needs inside the codebase, +such as finding a method definition in a class, or finding a method by its decorator. + +It's not optimal to have a fully-featured set of functions as this would be +unwieldy, but since our use-cases are well-defined, we can provide a set of +functions that are useful for the specific tasks we need to accomplish. """ +from typing import Optional, Union +from pathlib import Path import ast +import astor +import asttokens +from agentstack import ValidationError + + +class File: + """ + Parses and manipulates a Python source file with an AST. + + Use it as a context manager to make and save edits: + ```python + with File(filename) as f: + f.edit_node_range(start, end, new_node) + ``` + + Lookups are done using the built-in `ast` module, which we only use to find + and read nodes in the tree. + + Edits are done using string indexing on the source code, which preserves a + majority of the original formatting and prevents comments from being lost. + + In cases where we are constructing new AST nodes, we use `ast.unparse`. + """ + filename: Path = None + source: str = None + atok: asttokens.ASTTokens = None + tree: ast.AST = None + + def __init__(self, filename: Path): + self.filename = filename + self.read() + + def read(self): + try: + with open(self.filename, 'r') as f: + self.source = f.read() + self.atok = asttokens.ASTTokens(self.source, parse=True) + self.tree = self.atok.tree + except (FileNotFoundError, SyntaxError) as e: + raise ValidationError(f"Failed to parse {self.filename}\n{e}") + + def write(self): + with open(self.filename, 'w', encoding='utf-8') as f: + f.write(self.source) + + def get_node_range(self, node: ast.AST) -> tuple[int, int]: + """Get the string start and end indexes for a node in the source code.""" + return self.atok.get_text_range(node) + + def edit_node_range(self, start: int, end: int, node: Union[str, ast.AST]): + """Splice a new node or string into the source code at the given range.""" + if isinstance(node, ast.AST): + module = ast.Module( + body=[ast.Expr(value=node)], + type_ignores=[] + ) + node = astor.to_source(module).strip() + self.source = self.source[:start] + node + self.source[end:] + # In order to continue accurately modifying the AST, we need to re-parse the source. + self.atok = asttokens.ASTTokens(self.source, parse=True) + self.tree = self.atok.tree + + def __enter__(self) -> 'File': return self + def __exit__(self, *args): self.write() + + +def get_all_imports(tree: ast.AST) -> list[Union[ast.Import, ast.ImportFrom]]: + """Find all import statements in an AST.""" + imports = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom): + imports.append(node) + return imports + +def find_method(tree: Union[list[ast.AST], ast.AST], method_name: str) -> Optional[ast.FunctionDef]: + """Find a method definition in an AST.""" + if not isinstance(tree, list): + tree: generator = ast.iter_child_nodes(tree) + for node in tree: + if isinstance(node, ast.FunctionDef) and node.name == method_name: + return node + return None + +def find_kwarg_in_method_call(node: ast.Call, kwarg_name: str) -> Optional[ast.keyword]: + """Find a keyword argument in a method call or class instantiation.""" + for arg in node.keywords: + if isinstance(arg, ast.keyword) and arg.arg == kwarg_name: + return arg + return None +def find_class_instantiation(tree: Union[list[ast.AST], ast.AST], class_name: str) -> Optional[ast.Call]: + """ + Find a class instantiation statement in an AST by the class name. + This can either be an assignment to a variable or a return statement. + """ + if not isinstance(tree, list): + tree: generator = ast.iter_child_nodes(tree) + for node in tree: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == class_name: + return node.value + elif isinstance(node, ast.Return): + if isinstance(node.value, ast.Call) and node.value.func.id == class_name: + return node.value + return None def find_class_with_decorator(tree: ast.AST, decorator_name: str) -> list[ast.ClassDef]: """Find a class definition that is marked by a decorator in an AST.""" @@ -24,7 +138,7 @@ def find_decorated_method_in_class(classdef: ast.ClassDef, decorator_name: str) nodes.append(node) return nodes -def create_attribute(attr_name: str, base_name: str) -> ast.Attribute: +def create_attribute(base_name: str, attr_name: str) -> ast.Attribute: """Create an AST node for an attribute""" return ast.Attribute( value=ast.Name(id=base_name, ctx=ast.Load()), diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index 6c5ff878..26b4fde0 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -12,12 +12,12 @@ import os import shutil import fileinput -import astor import ast -from pydantic import BaseModel, ValidationError from agentstack import packaging +from agentstack import ValidationError from agentstack.utils import get_package_path +from agentstack.tools import ToolConfig from agentstack.generation import astools from agentstack.generation.files import ConfigFile, EnvFile from agentstack import frameworks @@ -25,60 +25,67 @@ from ..utils import open_json_file, get_framework, term_color -TOOL_INIT_FILENAME = "src/tools/__init__.py" - -class ToolConfig(BaseModel): - name: str - category: str - tools: list[str] - url: Optional[str] = None - tools_bundled: bool = False - cta: Optional[str] = None - env: Optional[dict] = None - packages: Optional[List[str]] = None - post_install: Optional[str] = None - post_remove: Optional[str] = None - - @classmethod - def from_tool_name(cls, name: str) -> 'ToolConfig': - path = get_package_path() / f'tools/{name}.json' - if not os.path.exists(path): - print(term_color(f'No known agentstack tool: {name}', 'red')) - sys.exit(1) - return cls.from_json(path) - - @classmethod - def from_json(cls, path: Path) -> 'ToolConfig': - data = open_json_file(path) +# This is the filename of the location of tool imports in the user's project. +TOOLS_INIT_FILENAME: Path = Path("src/tools/__init__.py") + +class ToolsInitFile(astools.File): + """ + Modifiable AST representation of the tools init file. + + Use it as a context manager to make and save edits: + ```python + with ToolsInitFile(filename) as tools_init: + tools_init.add_import_for_tool(...) + ``` + """ + def get_import_for_tool(self, tool: ToolConfig) -> ast.Import: + """ + Get the import statement for a tool. + raises a ValidationError if the tool is imported multiple times. + """ + all_imports = astools.get_all_imports(self.tree) + tool_imports = [i for i in all_imports if tool.name in i.names[0].name] + + if len(tool_imports) > 1: + raise ValidationError(f"Multiple imports for tool {tool.name} found in {self.filename}") + + try: + return tool_imports[0] + except IndexError: + return None + + def add_import_for_tool(self, framework: str, tool: ToolConfig): + """ + Add an import for a tool. + raises a ValidationError if the tool is already imported. + """ + tool_import = self.get_import_for_tool(tool) + if tool_import: + raise ValidationError(f"Tool {tool.name} already imported in {self.filename}") + try: - return cls(**data) - except ValidationError as e: - print(term_color(f"Error validating tool config JSON: \n{path}", 'red')) - for error in e.errors(): - print(f"{' '.join(error['loc'])}: {error['msg']}") - sys.exit(1) - - def get_import_statement(self) -> str: - return f"from .{self.name}_tool import {', '.join(self.tools)}" - - def get_impl_file_path(self, framework: str) -> Path: - return get_package_path() / f'templates/{framework}/tools/{self.name}_tool.py' - -def get_all_tool_paths() -> list[Path]: - paths = [] - tools_dir = get_package_path() / 'tools' - for file in tools_dir.iterdir(): - if file.is_file() and file.suffix == '.json': - paths.append(file) - return paths - -def get_all_tool_names() -> list[str]: - return [path.stem for path in get_all_tool_paths()] - -def get_all_tools() -> list[ToolConfig]: - return [ToolConfig.from_json(path) for path in get_all_tool_paths()] - -def add_tool(tool_name: str, path: Optional[str] = None, agents: Optional[List[str]] = []): + last_import = astools.get_all_imports(self.tree)[-1] + start, end = self.get_node_range(last_import) + except IndexError: + start, end = 0, 0 # No imports in the file + + import_statement = tool.get_import_statement(framework) + self.edit_node_range(end, end, f"\n{import_statement}") + + def remove_import_for_tool(self, framework: str, tool: ToolConfig): + """ + Remove an import for a tool. + raises a ValidationError if the tool is not imported. + """ + tool_import = self.get_imports_for_tool(tool) + if not tool_import: + raise ValidationError(f"Tool {tool.name} not imported in {self.filename}") + + start, end = self.get_node_range(tool_import) + self.edit_node_range(start, end, "") + + +def add_tool(tool_name: str, agents: Optional[List[str]] = [], path: Optional[str] = None): if path: path = path.endswith('/') and path or path + '/' else: @@ -97,8 +104,19 @@ def add_tool(tool_name: str, path: Optional[str] = None, agents: Optional[List[s if tool_data.packages: packaging.install(' '.join(tool_data.packages)) shutil.copy(tool_file_path, f'{path}src/tools/{tool_name}_tool.py') # Move tool from package to project - add_tool_to_tools_init(tool_data, path) # Export tool from tools dir - add_tool_to_agent_definition(framework=framework, tool_data=tool_data, path=path, agents=agents) # Add tool to agent definition + + try: # Edit the user's project tool init file to include the tool + with ToolsInitFile(path/TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(tool_data) + except ValidationError as e: + print(term_color(f"Error adding tool:\n{e}", 'red')) + sys.exit(1) + + # Edit the framework entrypoint file to include the tool in the agent definition + if not len(agents): # If no agents are specified, add the tool to all agents + agents = frameworks.get_agent_names(framework, path) + for agent_name in agents: + frameworks.add_tool(framework, tool_data, agent_name, path) if tool_data.env: # add environment variables which don't exist with EnvFile(path) as env: @@ -119,7 +137,7 @@ def add_tool(tool_name: str, path: Optional[str] = None, agents: Optional[List[s print(term_color(f'đŸĒŠ {tool_data.cta}', 'blue')) -def remove_tool(tool_name: str, path: Optional[str] = None): +def remove_tool(tool_name: str, agents: Optional[List[str]] = [], path: Optional[str] = None): if path: path = path.endswith('/') and path or path + '/' else: @@ -139,8 +157,20 @@ def remove_tool(tool_name: str, path: Optional[str] = None): os.remove(f'{path}src/tools/{tool_name}_tool.py') except FileNotFoundError: print(f'"src/tools/{tool_name}_tool.py" not found') - remove_tool_from_tools_init(tool_data, path) - remove_tool_from_agent_definition(framework, tool_data, path) + + try: # Edit the user's project tool init file to exclude the tool + with ToolsInitFile(path/TOOLS_INIT_FILENAME) as tools_init: + tools_init.remove_import_for_tool(tool_data) + except ValidationError as e: + print(term_color(f"Error removing tool:\n{e}", 'red')) + sys.exit(1) + + # Edit the framework entrypoint file to exclude the tool in the agent definition + if not len(agents): # If no agents are specified, remove the tool from all agents + agents = frameworks.get_agent_names(framework, path) + for agent_name in agents: + frameworks.remove_tool(framework, tool_data, agent_name, path) + if tool_data.post_remove: os.system(tool_data.post_remove) # We don't remove the .env variables to preserve user data. @@ -150,207 +180,3 @@ def remove_tool(tool_name: str, path: Optional[str] = None): print(term_color(f'🔨 Tool {tool_name}', 'green'), term_color('removed', 'red'), term_color('from agentstack project successfully', 'green')) - -def add_tool_to_tools_init(tool_data: ToolConfig, path: str = ''): - file_path = f'{path}{TOOL_INIT_FILENAME}' - tag = '# tool import' - code_to_insert = [tool_data.get_import_statement(), ] - insert_code_after_tag(file_path, tag, code_to_insert, next_line=True) - - -def remove_tool_from_tools_init(tool_data: ToolConfig, path: str = ''): - """Search for the import statement in the init and remove it.""" - file_path = f'{path}{TOOL_INIT_FILENAME}' - import_statement = tool_data.get_import_statement() - with fileinput.input(files=file_path, inplace=True) as f: - for line in f: - if line.strip() != import_statement: - print(line, end='') - - -def add_tool_to_agent_definition(framework: str, tool_data: ToolConfig, path: str = '', agents: list[str] = []): - """ - Add tools to specific agent definitions using AST transformation. - - Args: - framework: Name of the framework - tool_data: ToolConfig - agents: Optional list of agent names to modify. If None, modifies all agents. - path: Optional path to the framework file - """ - modify_agent_tools(framework=framework, tool_data=tool_data, operation='add', agents=agents, path=path, base_name='tools') - - -def remove_tool_from_agent_definition(framework: str, tool_data: ToolConfig, path: str = ''): - modify_agent_tools(framework=framework, tool_data=tool_data, operation='remove', agents=None, path=path, base_name='tools') - - -def _is_tool_node_match(node: ast.AST, tool_name: str, base_name: str = 'tools') -> bool: - """ - Check if an AST node matches a tool reference, regardless of whether it's starred - - Args: - node: AST node to check (can be Attribute or Starred) - tool_name: Name of the tool to match - base_name: Base module name (default: 'tools') - - Returns: - bool: True if the node matches the tool reference - """ - # If it's a Starred node, check its value - if isinstance(node, ast.Starred): - node = node.value - - # Extract the attribute name and base regardless of node type - if isinstance(node, ast.Attribute): - is_base_match = (isinstance(node.value, ast.Name) and - node.value.id == base_name) - is_name_match = node.attr == tool_name - return is_base_match and is_name_match - - return False - - -def _process_tools_list( - current_tools: List[ast.AST], - tool_data: ToolConfig, - operation: str, - base_name: str = 'tools' -) -> List[ast.AST]: - """ - Process a tools list according to the specified operation. - - Args: - current_tools: Current list of tool nodes - tool_data: Tool configuration - operation: Operation to perform ('add' or 'remove') - base_name: Base module name for tools - """ - if operation == 'add': - new_tools = current_tools.copy() - # Add new tools with bundling if specified - new_tools.extend([ - ast.Starred(astools.create_attribute(name, base_name)) if tool_data.tools_bundled, - else astools.create_attribute(name, base_name) - for name in tool_data.tools - ]) - return new_tools - - elif operation == 'remove': - # Filter out tools that match any in the removal list - return [ - tool for tool in current_tools - if not any(_is_tool_node_match(tool, name, base_name) - for name in tool_data.tools) - ] - - raise ValueError(f"Unsupported operation: {operation}") - - -def _modify_agent_tools( - node: ast.FunctionDef, - tool_data: ToolConfig, - operation: str, - agents: Optional[List[str]] = None, - base_name: str = 'tools' -) -> ast.FunctionDef: - """ - Modify the tools list in an agent definition. - - Args: - node: AST node of the function to modify - tool_data: Tool configuration - operation: Operation to perform ('add' or 'remove') - agents: Optional list of agent names to modify - base_name: Base module name for tools - """ - # Skip if not in specified agents list - if agents is not None and agents != []: - if node.name not in agents: - return node - - # Check if this is an agent-decorated function - if not any(isinstance(d, ast.Name) and d.id == 'agent' - for d in node.decorator_list): - return node - - # Find the Return statement and modify tools - for item in node.body: - if isinstance(item, ast.Return): - agent_call = item.value - if isinstance(agent_call, ast.Call): - for kw in agent_call.keywords: - if kw.arg == 'tools': - if isinstance(kw.value, ast.List): - # Process the tools list - new_tools = _process_tools_list( - kw.value.elts, - tool_data, - operation, - base_name - ) - - # Replace with new list - kw.value = ast.List(elts=new_tools, ctx=ast.Load()) - - return node - - -def modify_agent_tools( - framework: str, - tool_data: ToolConfig, - operation: str, - agents: Optional[List[str]] = None, - path: str = '', - base_name: str = 'tools' -) -> None: - """ - Modify tools in agent definitions using AST transformation. - - Args: - framework: Name of the framework - tool_data: ToolConfig - operation: Operation to perform ('add' or 'remove') - agents: Optional list of agent names to modify - path: Optional path to the framework file - base_name: Base module name for tools (default: 'tools') - """ - if agents is not None: - valid_agents = get_agent_names(path=path) - for agent in agents: - if agent not in valid_agents: - print(term_color(f"Agent '{agent}' not found in the project.", 'red')) - sys.exit(1) - - path = Path(path) - filename = path/frameworks.get_entrypoint_path(framework) - - with open(filename, 'r', encoding='utf-8') as f: - source_lines = f.readlines() - - # Create a map of line numbers to comments - comments = {} - for i, line in enumerate(source_lines): - stripped = line.strip() - if stripped.startswith('#'): - comments[i + 1] = line - - tree = ast.parse(''.join(source_lines)) - - class ModifierTransformer(ast.NodeTransformer): - def visit_FunctionDef(self, node): - return _modify_agent_tools(node, tool_data, operation, agents, base_name) - - modified_tree = ModifierTransformer().visit(tree) - modified_source = astor.to_source(modified_tree) - modified_lines = modified_source.splitlines() - - # Reinsert comments - final_lines = [] - for i, line in enumerate(modified_lines, 1): - if i in comments: - final_lines.append(comments[i]) - final_lines.append(line + '\n') - - with open(filename, 'w', encoding='utf-8') as f: - f.write(''.join(final_lines)) \ No newline at end of file diff --git a/agentstack/tools.py b/agentstack/tools.py new file mode 100644 index 00000000..6e24b663 --- /dev/null +++ b/agentstack/tools.py @@ -0,0 +1,63 @@ +from typing import Optional +import os, sys +from pathlib import Path +import pydantic +from agentstack.utils import get_package_path, open_json_file, term_color + + +class ToolConfig(pydantic.BaseModel): + """ + This represents the configuration data for a tool. + It parses and validates the `config.json` file for a tool. + """ + name: str + category: str + tools: list[str] + url: Optional[str] = None + tools_bundled: bool = False + cta: Optional[str] = None + env: Optional[dict] = None + packages: Optional[list[str]] = None + post_install: Optional[str] = None + post_remove: Optional[str] = None + + @classmethod + def from_tool_name(cls, name: str) -> 'ToolConfig': + path = get_package_path() / f'tools/{name}.json' + if not os.path.exists(path): # TODO raise exceptions and handle message/exit in cli + print(term_color(f'No known agentstack tool: {name}', 'red')) + sys.exit(1) + return cls.from_json(path) + + @classmethod + def from_json(cls, path: Path) -> 'ToolConfig': + data = open_json_file(path) + try: + return cls(**data) + except pydantic.ValidationError as e: + # TODO raise exceptions and handle message/exit in cli + print(term_color(f"Error validating tool config JSON: \n{path}", 'red')) + for error in e.errors(): + print(f"{' '.join(error['loc'])}: {error['msg']}") + sys.exit(1) + + def get_import_statement(self, framework: str) -> str: + return f"from .{self.name}_tool import {', '.join(self.tools)}" + + def get_impl_file_path(self, framework: str) -> Path: + return get_package_path()/f'templates/{framework}/tools/{self.name}_tool.py' + +def get_all_tool_paths() -> list[Path]: + paths = [] + tools_dir = get_package_path()/'tools' + for file in tools_dir.iterdir(): + if file.is_file() and file.suffix == '.json': + paths.append(file) + return paths + +def get_all_tool_names() -> list[str]: + return [path.stem for path in get_all_tool_paths()] + +def get_all_tools() -> list[ToolConfig]: + return [ToolConfig.from_json(path) for path in get_all_tool_paths()] + diff --git a/pyproject.toml b/pyproject.toml index 4017a8f4..59190512 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "cookiecutter==2.6.0", "psutil==5.9.8", "astor==0.8.1", + "asttokens", "pydantic>=2.10", "packaging>=23.2", "requests>=2.32", @@ -31,7 +32,7 @@ dependencies = [ [project.optional-dependencies] test = [ - "tox>=4.23.2", + "tox>=4", ] crewai = [ "crewai>=0.83.0", diff --git a/tests/fixtures/frameworks/crewai/entrypoint_max.py b/tests/fixtures/frameworks/crewai/entrypoint_max.py new file mode 100644 index 00000000..b873d377 --- /dev/null +++ b/tests/fixtures/frameworks/crewai/entrypoint_max.py @@ -0,0 +1,29 @@ +from crewai import Agent, Crew, Process, Task +from crewai.project import CrewBase, agent, crew, task +import tools + +@CrewBase +class TestCrew(): + @agent + def test_agent(self) -> Agent: + return Agent( + config=self.agents_config['test_agent'], + tools=[], + verbose=True + ) + + @task + def test_task(self) -> Task: + return Task( + config=self.tasks_config['test_task'], + ) + + @crew + def crew(self) -> Crew: + return Crew( + agents=self.agents, + tasks=self.tasks, + process=Process.sequential, + verbose=True, + ) + diff --git a/tests/fixtures/frameworks/crewai/entrypoint_min.py b/tests/fixtures/frameworks/crewai/entrypoint_min.py new file mode 100644 index 00000000..7fabf7d3 --- /dev/null +++ b/tests/fixtures/frameworks/crewai/entrypoint_min.py @@ -0,0 +1,16 @@ +from crewai import Agent, Crew, Process, Task +from crewai.project import CrewBase, agent, crew, task +import tools + +@CrewBase +class TestCrew(): + + @crew + def crew(self) -> Crew: + return Crew( + agents=self.agents, + tasks=self.tasks, + process=Process.sequential, + verbose=True, + ) + diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py new file mode 100644 index 00000000..756c14e5 --- /dev/null +++ b/tests/test_frameworks.py @@ -0,0 +1,119 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class +from agentstack import ValidationError +from agentstack import frameworks +from agentstack.tools import ToolConfig + +BASE_PATH = Path(__file__).parent + + +@parameterized_class([ + {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS +]) +class TestFrameworks(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH/'tmp'/self.framework + + os.makedirs(self.project_dir) + os.makedirs(self.project_dir/'src') + os.makedirs(self.project_dir/'src'/'tools') + + (self.project_dir/'src'/'__init__.py').touch() + (self.project_dir/'src'/'tools'/'__init__.py').touch() + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def _populate_min_entrypoint(self): + """This entrypoint does not have any tools or agents.""" + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_min.py", entrypoint_path) + + def _populate_max_entrypoint(self): + """This entrypoint has tools and agents.""" + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + + def _get_test_tool(self) -> ToolConfig: + return ToolConfig(name='test_tool', category='test', tools=['test_tool']) + + def _get_test_tool_starred(self) -> ToolConfig: + return ToolConfig(name='test_tool_star', category='test', tools=['test_tool_star'], tools_bundled=True) + + def test_get_framework_module(self): + module = frameworks.get_framework_module(self.framework) + assert module.__name__ == f"agentstack.frameworks.{self.framework}" + + def test_get_framework_module_invalid(self): + with self.assertRaises(ValueError) as context: + frameworks.get_framework_module('invalid') + + def test_validate_project(self): + self._populate_max_entrypoint() + frameworks.validate_project(self.framework, self.project_dir) + + def test_validate_project_invalid(self): + self._populate_min_entrypoint() + with self.assertRaises(ValidationError) as context: + frameworks.validate_project(self.framework, self.project_dir) + + def test_add_tool(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[tools.test_tool' in entrypoint_src + + def test_add_tool_starred(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[*tools.test_tool_star' in entrypoint_src + + def test_add_tool_invalid(self): + self._populate_min_entrypoint() + with self.assertRaises(ValidationError) as context: + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + + def test_remove_tool(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.remove_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[tools.test_tool' not in entrypoint_src + + def test_remove_tool_starred(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + frameworks.remove_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[*tools.test_tool_star' not in entrypoint_src + + def test_add_multiple_tools(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert ( # ordering is not guaranteed + 'tools=[tools.test_tool, *tools.test_tool_star' in entrypoint_src + or + 'tools=[*tools.test_tool_star, tools.test_tool' in entrypoint_src + ) + + def test_remove_one_tool_of_multiple(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + frameworks.remove_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[tools.test_tool' not in entrypoint_src + assert 'tools=[*tools.test_tool_star' in entrypoint_src + diff --git a/tests/test_tool_config.py b/tests/test_tool_config.py index 90931c7b..894406af 100644 --- a/tests/test_tool_config.py +++ b/tests/test_tool_config.py @@ -3,7 +3,7 @@ import unittest import importlib.resources from pathlib import Path -from agentstack.generation.tool_generation import get_all_tool_paths, get_all_tool_names, ToolConfig +from agentstack.tools import ToolConfig, get_all_tool_paths, get_all_tool_names BASE_PATH = Path(__file__).parent diff --git a/tests/test_tool_generation_init.py b/tests/test_tool_generation_init.py new file mode 100644 index 00000000..88c0a9e3 --- /dev/null +++ b/tests/test_tool_generation_init.py @@ -0,0 +1,78 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class +from agentstack import ValidationError +from agentstack import frameworks +from agentstack.tools import ToolConfig +from agentstack.generation.files import ConfigFile +from agentstack.generation.tool_generation import ToolsInitFile, TOOLS_INIT_FILENAME + + +BASE_PATH = Path(__file__).parent + +@parameterized_class([ + {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS +]) +class TestToolGenerationInit(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH/'tmp'/'tool_generation' + os.makedirs(self.project_dir) + os.makedirs(self.project_dir/'src') + os.makedirs(self.project_dir/'src'/'tools') + (self.project_dir/'src'/'__init__.py').touch() + (self.project_dir/'src'/'tools'/'__init__.py').touch() + shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json') + # set the framework in agentstack.json + with ConfigFile(self.project_dir) as config: + config.framework = self.framework + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def _get_test_tool(self) -> ToolConfig: + return ToolConfig(name='test_tool', category='test', tools=['test_tool']) + + def _get_test_tool_alt(self) -> ToolConfig: + return ToolConfig(name='test_tool_alt', category='test', tools=['test_tool_alt']) + + def test_tools_init_file(self): + tools_init = ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) + # file is empty + assert tools_init.get_import_for_tool(self._get_test_tool()) == None + + def test_tools_init_file_missing(self): + with self.assertRaises(ValidationError) as context: + tools_init = ToolsInitFile(self.project_dir/'missing') + + def test_tools_init_file_add_import(self): + tool = self._get_test_tool() + with ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(self.framework, tool) + + tool_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() + assert tool.get_import_statement(self.framework) in tool_init_src + + def test_tools_init_file_add_import_multiple(self): + tool = self._get_test_tool() + tool_alt = self._get_test_tool_alt() + with ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(self.framework, tool) + + with ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(self.framework, tool_alt) + + # Should not be able to re-add a tool import + with self.assertRaises(ValidationError) as context: + with ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(self.framework, tool) + + tool_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() + assert tool.get_import_statement(self.framework) in tool_init_src + assert tool_alt.get_import_statement(self.framework) in tool_init_src + # TODO this might be a little too strict + assert tool_init_src == """ +from .test_tool_tool import test_tool +from .test_tool_alt_tool import test_tool_alt""" + diff --git a/tox.ini b/tox.ini index 6ab5f4a5..6733c3ab 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ envlist = py310,py311,py312 [testenv] deps = pytest + parameterized mypy: mypy commands = pytest -v From 5294e3ef80cfa32f10f0401a8e4cdc6c2d086c6d Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Wed, 4 Dec 2024 18:20:40 -0800 Subject: [PATCH 06/14] Clenaup `tool_generation`, add ToolConfig.module_name, add tests for tool_generation frontend --- agentstack/generation/tool_generation.py | 105 ++++++++++------------- agentstack/tools.py | 8 +- tests/test_frameworks.py | 1 + tests/test_tool_generation.py | 69 +++++++++++++++ tests/test_tool_generation_init.py | 3 +- 5 files changed, 122 insertions(+), 64 deletions(-) create mode 100644 tests/test_tool_generation.py diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index 26b4fde0..4bde2c24 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -1,28 +1,17 @@ import os, sys -from typing import Optional, Any, List -import importlib.resources +from typing import Optional, Union, Any from pathlib import Path -import json -import sys -from typing import Optional, List, Dict, Union - -from . import get_agent_names -from .gen_utils import insert_code_after_tag, string_in_file -from ..utils import open_json_file, get_framework, term_color -import os import shutil import fileinput import ast +from agentstack import frameworks from agentstack import packaging from agentstack import ValidationError -from agentstack.utils import get_package_path +from agentstack.utils import term_color from agentstack.tools import ToolConfig from agentstack.generation import astools from agentstack.generation.files import ConfigFile, EnvFile -from agentstack import frameworks -from .gen_utils import insert_code_after_tag, string_in_file -from ..utils import open_json_file, get_framework, term_color # This is the filename of the location of tool imports in the user's project. @@ -44,7 +33,7 @@ def get_import_for_tool(self, tool: ToolConfig) -> ast.Import: raises a ValidationError if the tool is imported multiple times. """ all_imports = astools.get_all_imports(self.tree) - tool_imports = [i for i in all_imports if tool.name in i.names[0].name] + tool_imports = [i for i in all_imports if tool.module_name == i.module] if len(tool_imports) > 1: raise ValidationError(f"Multiple imports for tool {tool.name} found in {self.filename}") @@ -77,7 +66,7 @@ def remove_import_for_tool(self, framework: str, tool: ToolConfig): Remove an import for a tool. raises a ValidationError if the tool is not imported. """ - tool_import = self.get_imports_for_tool(tool) + tool_import = self.get_import_for_tool(tool) if not tool_import: raise ValidationError(f"Tool {tool.name} not imported in {self.filename}") @@ -85,98 +74,92 @@ def remove_import_for_tool(self, framework: str, tool: ToolConfig): self.edit_node_range(start, end, "") -def add_tool(tool_name: str, agents: Optional[List[str]] = [], path: Optional[str] = None): - if path: - path = path.endswith('/') and path or path + '/' - else: - path = './' - - framework = get_framework(path) +def add_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None): + if path is None: path = Path() agentstack_config = ConfigFile(path) + framework = agentstack_config.framework if tool_name in agentstack_config.tools: print(term_color(f'Tool {tool_name} is already installed', 'red')) sys.exit(1) - tool_data = ToolConfig.from_tool_name(tool_name) - tool_file_path = tool_data.get_impl_file_path(framework) + tool = ToolConfig.from_tool_name(tool_name) + tool_file_path = tool.get_impl_file_path(framework) - if tool_data.packages: - packaging.install(' '.join(tool_data.packages)) - shutil.copy(tool_file_path, f'{path}src/tools/{tool_name}_tool.py') # Move tool from package to project + if tool.packages: + packaging.install(' '.join(tool.packages)) + + # Move tool from package to project + shutil.copy(tool_file_path, path/f'src/tools/{tool.module_name}.py') try: # Edit the user's project tool init file to include the tool with ToolsInitFile(path/TOOLS_INIT_FILENAME) as tools_init: - tools_init.add_import_for_tool(tool_data) + tools_init.add_import_for_tool(framework, tool) except ValidationError as e: print(term_color(f"Error adding tool:\n{e}", 'red')) - sys.exit(1) # Edit the framework entrypoint file to include the tool in the agent definition - if not len(agents): # If no agents are specified, add the tool to all agents + if not agents: # If no agents are specified, add the tool to all agents agents = frameworks.get_agent_names(framework, path) for agent_name in agents: - frameworks.add_tool(framework, tool_data, agent_name, path) + frameworks.add_tool(framework, tool, agent_name, path) - if tool_data.env: # add environment variables which don't exist + if tool.env: # add environment variables which don't exist with EnvFile(path) as env: - for var, value in tool_data.env.items(): + for var, value in tool.env.items(): env.append_if_new(var, value) with EnvFile(path, filename=".env.example") as env: - for var, value in tool_data.env.items(): + for var, value in tool.env.items(): env.append_if_new(var, value) - if tool_data.post_install: - os.system(tool_data.post_install) + if tool.post_install: + os.system(tool.post_install) with agentstack_config as config: - config.tools.append(tool_name) - - print(term_color(f'🔨 Tool {tool_name} added to agentstack project successfully', 'green')) - if tool_data.cta: - print(term_color(f'đŸĒŠ {tool_data.cta}', 'blue')) + config.tools.append(tool.name) + print(term_color(f'🔨 Tool {tool.name} added to agentstack project successfully', 'green')) + if tool.cta: + print(term_color(f'đŸĒŠ {tool.cta}', 'blue')) -def remove_tool(tool_name: str, agents: Optional[List[str]] = [], path: Optional[str] = None): - if path: - path = path.endswith('/') and path or path + '/' - else: - path = './' - framework = get_framework() +def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None): + if path is None: path = Path() agentstack_config = ConfigFile(path) + framework = agentstack_config.framework if not tool_name in agentstack_config.tools: print(term_color(f'Tool {tool_name} is not installed', 'red')) sys.exit(1) - tool_data = ToolConfig.from_tool_name(tool_name) - if tool_data.packages: - packaging.remove(' '.join(tool_data.packages)) + tool = ToolConfig.from_tool_name(tool_name) + if tool.packages: + packaging.remove(' '.join(tool.packages)) + try: - os.remove(f'{path}src/tools/{tool_name}_tool.py') + os.remove(path/f'src/tools/{tool.module_name}.py') except FileNotFoundError: - print(f'"src/tools/{tool_name}_tool.py" not found') + print(f'"src/tools/{tool.module_name}.py" not found') try: # Edit the user's project tool init file to exclude the tool with ToolsInitFile(path/TOOLS_INIT_FILENAME) as tools_init: - tools_init.remove_import_for_tool(tool_data) + tools_init.remove_import_for_tool(framework, tool) except ValidationError as e: print(term_color(f"Error removing tool:\n{e}", 'red')) - sys.exit(1) # Edit the framework entrypoint file to exclude the tool in the agent definition - if not len(agents): # If no agents are specified, remove the tool from all agents + if not agents: # If no agents are specified, remove the tool from all agents agents = frameworks.get_agent_names(framework, path) for agent_name in agents: - frameworks.remove_tool(framework, tool_data, agent_name, path) + frameworks.remove_tool(framework, tool, agent_name, path) - if tool_data.post_remove: - os.system(tool_data.post_remove) + if tool.post_remove: + os.system(tool.post_remove) # We don't remove the .env variables to preserve user data. with agentstack_config as config: - config.tools.remove(tool_name) + config.tools.remove(tool.name) - print(term_color(f'🔨 Tool {tool_name}', 'green'), term_color('removed', 'red'), term_color('from agentstack project successfully', 'green')) + print(term_color(f'🔨 Tool {tool_name}', 'green'), term_color('removed', 'red'), + term_color('from agentstack project successfully', 'green')) diff --git a/agentstack/tools.py b/agentstack/tools.py index 6e24b663..209264f7 100644 --- a/agentstack/tools.py +++ b/agentstack/tools.py @@ -41,11 +41,15 @@ def from_json(cls, path: Path) -> 'ToolConfig': print(f"{' '.join(error['loc'])}: {error['msg']}") sys.exit(1) + @property + def module_name(self) -> str: + return f"{self.name}_tool" + def get_import_statement(self, framework: str) -> str: - return f"from .{self.name}_tool import {', '.join(self.tools)}" + return f"from .{self.module_name} import {', '.join(self.tools)}" def get_impl_file_path(self, framework: str) -> Path: - return get_package_path()/f'templates/{framework}/tools/{self.name}_tool.py' + return get_package_path()/f'templates/{framework}/tools/{self.module_name}.py' def get_all_tool_paths() -> list[Path]: paths = [] diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 756c14e5..94a90c91 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -3,6 +3,7 @@ import shutil import unittest from parameterized import parameterized_class + from agentstack import ValidationError from agentstack import frameworks from agentstack.tools import ToolConfig diff --git a/tests/test_tool_generation.py b/tests/test_tool_generation.py new file mode 100644 index 00000000..a19b92cf --- /dev/null +++ b/tests/test_tool_generation.py @@ -0,0 +1,69 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class + +from agentstack import frameworks +from agentstack.tools import get_all_tools, ToolConfig +from agentstack.generation.files import ConfigFile +from agentstack.generation.tool_generation import add_tool, remove_tool, TOOLS_INIT_FILENAME + + +BASE_PATH = Path(__file__).parent + +# TODO parameterize all tools +@parameterized_class([ + {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS +]) +class TestToolGeneration(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH/'tmp'/'tool_generation' + + os.makedirs(self.project_dir) + os.makedirs(self.project_dir/'src') + os.makedirs(self.project_dir/'src'/'tools') + (self.project_dir/'src'/'__init__.py').touch() + (self.project_dir/TOOLS_INIT_FILENAME).touch() + + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + + # set the framework in agentstack.json + shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json') + with ConfigFile(self.project_dir) as config: + config.framework = self.framework + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_add_tool(self): + tool_conf = ToolConfig.from_tool_name('agent_connect') + add_tool('agent_connect', path=self.project_dir) + + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_src = open(entrypoint_path).read() + tools_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() + + # TODO verify tool is added to all agents (this is covered in test_frameworks.py) + #assert 'agent_connect' in entrypoint_src + assert f'from .{tool_conf.module_name} import' in tools_init_src + assert (self.project_dir/'src'/'tools'/f'{tool_conf.module_name}.py').exists() + assert 'agent_connect' in open(self.project_dir/'agentstack.json').read() + + def test_remove_tool(self): + tool_conf = ToolConfig.from_tool_name('agent_connect') + add_tool('agent_connect', path=self.project_dir) + remove_tool('agent_connect', path=self.project_dir) + + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_src = open(entrypoint_path).read() + tools_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() + + # TODO verify tool is removed from all agents (this is covered in test_frameworks.py) + #assert 'agent_connect' not in entrypoint_src + assert f'from .{tool_conf.module_name} import' not in tools_init_src + assert not (self.project_dir/'src'/'tools'/f'{tool_conf.module_name}.py').exists() + assert 'agent_connect' not in open(self.project_dir/'agentstack.json').read() + diff --git a/tests/test_tool_generation_init.py b/tests/test_tool_generation_init.py index 88c0a9e3..6df0ccae 100644 --- a/tests/test_tool_generation_init.py +++ b/tests/test_tool_generation_init.py @@ -3,6 +3,7 @@ import shutil import unittest from parameterized import parameterized_class + from agentstack import ValidationError from agentstack import frameworks from agentstack.tools import ToolConfig @@ -17,7 +18,7 @@ ]) class TestToolGenerationInit(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH/'tmp'/'tool_generation' + self.project_dir = BASE_PATH/'tmp'/'tool_generation_init' os.makedirs(self.project_dir) os.makedirs(self.project_dir/'src') os.makedirs(self.project_dir/'src'/'tools') From 98a47dda97162faf819b7a48a0a3a5b36c32462e Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Thu, 5 Dec 2024 08:41:40 -0800 Subject: [PATCH 07/14] Comments, imports, naming, typing cleanup. --- agentstack/frameworks/__init__.py | 80 ++++++++++++------- agentstack/frameworks/crewai.py | 29 ++++--- .../generation/{astools.py => asttools.py} | 3 +- agentstack/generation/tool_generation.py | 11 +-- tests/test_frameworks.py | 1 + 5 files changed, 74 insertions(+), 50 deletions(-) rename agentstack/generation/{astools.py => asttools.py} (99%) diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index e9359f51..8ce04c42 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -1,19 +1,39 @@ """ Methods for interacting with framework-specific features. -Each framework should have a module in the `frameworks` package which defines the following methods: - -- `ENTRYPOINT`: Path: Relative path to the entrypoint file for the framework -- `validate_project(framework: str, path: Optional[Path] = None) -> None`: Validate that a project is ready to run. - Raises a `ValidationError` if the project is not valid. -- `add_tool(framework: str, path: Optional[Path] = None) -> None`: Add a tool to the framework. -- `remove_tool(framework: str, path: Optional[Path] = None) -> None`: Remove a tool from the framework. -- `add_agent(framework: str, path: Optional[Path] = None) -> None`: Add an agent to the framework. -- `remove_agent(framework: str, path: Optional[Path] = None) -> None`: Remove an agent from the framework. -- `add_input(framework: str, path: Optional[Path] = None) -> None`: Add an input to the framework. -- `remove_input(framework: str, path: Optional[Path] = None) -> None`: Remove an input from the framework. +Each framework should have a module in the `frameworks` package which defines the +following methods: + +`ENTRYPOINT`: Path: + Relative path to the entrypoint file for the framework in the user's project. + ie. `src/crewai.py` + +`validate_project(path: Optional[Path] = None) -> None`: + Validate that a user's project is ready to run. + Raises a `ValidationError` if the project is not valid. + +`add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None`: + Add a tool to an agent in the user's project. + +`remove_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None`: + Remove a tool from an agent in user's project. + +`get_agent_names(path: Optional[Path] = None) -> list[str]`: + Get a list of agent names in the user's project. + +`add_agent(path: Optional[Path] = None) -> None`: + Add an agent to the user's project. + +`remove_agent(path: Optional[Path] = None) -> None`: + Remove an agent from the user's project. + +`add_input(path: Optional[Path] = None) -> None`: + Add an input to the user's project. + +`remove_input(path: Optional[Path] = None) -> None`: + Remove an input from the user's project. """ -from typing import TYPE_CHECKING, Optional +from typing import Optional from importlib import import_module from pathlib import Path from agentstack import ValidationError @@ -27,12 +47,15 @@ def get_framework_module(framework: str) -> import_module: """ Get the module for a framework. """ - if framework == CREWAI: - from . import crewai - return crewai - else: + if not framework in SUPPORTED_FRAMEWORKS: raise ValueError(f"Framework {framework} not supported") + try: + # repeated calls hit the `sys.modules` cache + return import_module(f".{framework}", __package__) + except ImportError: + raise ValueError(f"Framework {framework} could not be imported.") + def get_entrypoint_path(framework: str, path: Optional[Path] = None) -> Path: """ Get the path to the entrypoint file for a framework. @@ -42,52 +65,51 @@ def get_entrypoint_path(framework: str, path: Optional[Path] = None) -> Path: def validate_project(framework: str, path: Optional[Path] = None): """ - Run the framework specific project validation. + Validate that the user's project is ready to run. """ return get_framework_module(framework).validate_project(path) -def add_tool(framework: str, tool: ToolConfig, agent: str, path: Optional[Path] = None): +def add_tool(framework: str, tool: ToolConfig, agent_name: str, path: Optional[Path] = None): """ - Add a tool to the framework. - + Add a tool to the user's project. The tool will have aready been installed in the user's application and have all dependencies installed. We're just handling code generation here. """ - return get_framework_module(framework).add_tool(tool, agent, path) + return get_framework_module(framework).add_tool(tool, agent_name, path) -def remove_tool(framework: str, tool: ToolConfig, agent: str, path: Optional[Path] = None): +def remove_tool(framework: str, tool: ToolConfig, agent_name: str, path: Optional[Path] = None): """ - Remove a tool from the framework. + Remove a tool from the user's project. """ - return get_framework_module(framework).remove_tool(tool, agent, path) + return get_framework_module(framework).remove_tool(tool, agent_name, path) def get_agent_names(framework: str, path: Optional[Path] = None) -> list[str]: """ - Get a list of agent names from the framework. + Get a list of agent names in the user's project. """ return get_framework_module(framework).get_agent_names(path) def add_agent(framework: str, path: Optional[Path] = None): """ - Add an agent to the framework. + Add an agent to the user's project. """ return get_framework_module(framework).add_agent(path) def remove_agent(framework: str, path: Optional[Path] = None): """ - Remove an agent from the framework. + Remove an agent from the user's project. """ return get_framework_module(framework).remove_agent(path) def add_input(framework: str, path: Optional[Path] = None): """ - Add an input to the framework. + Add an input to the user's project. """ return get_framework_module(framework).add_input(path) def remove_input(framework: str, path: Optional[Path] = None): """ - Remove an input from the framework. + Remove an input from the user's project. """ return get_framework_module(framework).remove_input(path) diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index 02df063d..e0db44e7 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -3,13 +3,13 @@ import ast from agentstack import ValidationError from agentstack.tools import ToolConfig -from agentstack.generation import astools +from agentstack.generation import asttools from . import SUPPORTED_FRAMEWORKS ENTRYPOINT: Path = Path('src/crew.py') -class CrewFile(astools.File): +class CrewFile(asttools.File): """ Parses and manipulates the CrewAI entrypoint file. All AST interactions should happen within the methods of this class. @@ -20,7 +20,7 @@ def get_base_class(self) -> ast.ClassDef: """A base class is a class decorated with `@CrewBase`.""" if self._base_class is None: # Gets cached to save repeat iteration try: - self._base_class = astools.find_class_with_decorator(self.tree, 'CrewBase')[0] + self._base_class = asttools.find_class_with_decorator(self.tree, 'CrewBase')[0] except IndexError: raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}") return self._base_class @@ -29,17 +29,17 @@ def get_crew_method(self) -> ast.FunctionDef: """A `crew` method is a method decorated with `@crew`.""" try: base_class = self.get_base_class() - return astools.find_decorated_method_in_class(base_class, 'crew')[0] + return asttools.find_decorated_method_in_class(base_class, 'crew')[0] except IndexError: raise ValidationError(f"`@crew` decorated method not found in `{base_class.name}` class in {ENTRYPOINT}") def get_task_methods(self) -> list[ast.FunctionDef]: """A `task` method is a method decorated with `@task`.""" - return astools.find_decorated_method_in_class(self.get_base_class(), 'task') + return asttools.find_decorated_method_in_class(self.get_base_class(), 'task') def get_agent_methods(self) -> list[ast.FunctionDef]: """An `agent` method is a method decorated with `@agent`.""" - return astools.find_decorated_method_in_class(self.get_base_class(), 'agent') + return asttools.find_decorated_method_in_class(self.get_base_class(), 'agent') def get_agent_tools(self, agent_name: str) -> ast.List: """ @@ -49,15 +49,15 @@ def get_agent_tools(self, agent_name: str) -> ast.List: The method returns a new class instance with the tools as a list of callables under the kwarg `tools`. """ - method = astools.find_method(self.get_agent_methods(), agent_name) + method = asttools.find_method(self.get_agent_methods(), agent_name) if method is None: raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") - agent_class = astools.find_class_instantiation(method, 'Agent') + agent_class = asttools.find_class_instantiation(method, 'Agent') if agent_class is None: raise ValidationError(f"`@agent` method `{agent_name}` does not have an `Agent` class instantiation in {ENTRYPOINT}") - tools_kwarg = astools.find_kwarg_in_method_call(agent_class, 'tools') + tools_kwarg = asttools.find_kwarg_in_method_call(agent_class, 'tools') if not tools_kwarg: raise ValidationError(f"`@agent` method `{agent_name}` does not have a keyword argument `tools` in {ENTRYPOINT}") @@ -71,14 +71,14 @@ def add_agent_tools(self, agent_name: str, tool: ToolConfig): The method returns a new class instance with the tools as a list of callables under the kwarg `tools`. """ - method = astools.find_method(self.get_agent_methods(), agent_name) + method = asttools.find_method(self.get_agent_methods(), agent_name) if method is None: raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") new_tool_nodes = [] for tool_name in tool.tools: # This prefixes the tool name with the 'tools' module - node = astools.create_attribute('tools', tool_name) + node = asttools.create_attribute('tools', tool_name) if tool.tools_bundled: # Splat the variable if it's bundled node = ast.Starred(value=node, ctx=ast.Load()) new_tool_nodes.append(node) @@ -114,7 +114,7 @@ def remove_agent_tools(self, agent_name: str, tool: ToolConfig): def validate_project(path: Optional[Path] = None) -> None: """ Validate that a CrewAI project is ready to run. - Raises a frameworks.VaidationError if the project is not valid. + Raises an `agentstack.VaidationError` if the project is not valid. """ if path is None: path = Path() try: @@ -134,7 +134,7 @@ def validate_project(path: Optional[Path] = None) -> None: except ValidationError as e: raise e - # The Crew class must have one or more methods decorated with `@agent` + # The Crew class must have one or more methods decorated with `@task` if len(crew_file.get_task_methods()) < 1: raise ValidationError( f"`@task` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" @@ -148,8 +148,7 @@ def validate_project(path: Optional[Path] = None) -> None: def add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): """ - Add a tool to the CrewAI framework for the specified agent. - + Add a tool to the CrewAI entrypoint for the specified agent. The agent should already exist in the crew class and have a keyword argument `tools`. """ if path is None: path = Path() diff --git a/agentstack/generation/astools.py b/agentstack/generation/asttools.py similarity index 99% rename from agentstack/generation/astools.py rename to agentstack/generation/asttools.py index 7eb84f53..a89366b5 100644 --- a/agentstack/generation/astools.py +++ b/agentstack/generation/asttools.py @@ -32,7 +32,8 @@ class File: Edits are done using string indexing on the source code, which preserves a majority of the original formatting and prevents comments from being lost. - In cases where we are constructing new AST nodes, we use `ast.unparse`. + In cases where we are constructing new AST nodes, we use `astor` to render + the node as source code. """ filename: Path = None source: str = None diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index 4bde2c24..3cdf23a5 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -10,14 +10,14 @@ from agentstack import ValidationError from agentstack.utils import term_color from agentstack.tools import ToolConfig -from agentstack.generation import astools +from agentstack.generation import asttools from agentstack.generation.files import ConfigFile, EnvFile # This is the filename of the location of tool imports in the user's project. TOOLS_INIT_FILENAME: Path = Path("src/tools/__init__.py") -class ToolsInitFile(astools.File): +class ToolsInitFile(asttools.File): """ Modifiable AST representation of the tools init file. @@ -27,12 +27,12 @@ class ToolsInitFile(astools.File): tools_init.add_import_for_tool(...) ``` """ - def get_import_for_tool(self, tool: ToolConfig) -> ast.Import: + def get_import_for_tool(self, tool: ToolConfig) -> Union[ast.Import, ast.ImportFrom]: """ Get the import statement for a tool. raises a ValidationError if the tool is imported multiple times. """ - all_imports = astools.get_all_imports(self.tree) + all_imports = asttools.get_all_imports(self.tree) tool_imports = [i for i in all_imports if tool.module_name == i.module] if len(tool_imports) > 1: @@ -53,7 +53,7 @@ def add_import_for_tool(self, framework: str, tool: ToolConfig): raise ValidationError(f"Tool {tool.name} already imported in {self.filename}") try: - last_import = astools.get_all_imports(self.tree)[-1] + last_import = asttools.get_all_imports(self.tree)[-1] start, end = self.get_node_range(last_import) except IndexError: start, end = 0, 0 # No imports in the file @@ -136,6 +136,7 @@ def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional if tool.packages: packaging.remove(' '.join(tool.packages)) + # TODO ensure that other agents in the project are not using the tool. try: os.remove(path/f'src/tools/{tool.module_name}.py') except FileNotFoundError: diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 94a90c91..325be57d 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -66,6 +66,7 @@ def test_add_tool(self): frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + # TODO these asserts are not framework agnostic assert 'tools=[tools.test_tool' in entrypoint_src def test_add_tool_starred(self): From 5d11d32a8cba8a651817bb1e559e12617138685b Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Thu, 5 Dec 2024 15:59:49 -0800 Subject: [PATCH 08/14] Migrate agent generation to frameworks. Add agent generation tests. --- agentstack/agents.py | 89 ++++++++++++ agentstack/frameworks/__init__.py | 18 +-- agentstack/frameworks/crewai.py | 37 ++++- agentstack/generation/__init__.py | 2 +- agentstack/generation/agent_generation.py | 137 +++++------------- agentstack/main.py | 5 +- agentstack/utils.py | 10 ++ tests/fixtures/agents_max.yaml | 16 ++ tests/fixtures/agents_min.yaml | 2 + .../frameworks/crewai/entrypoint_min.py | 2 +- tests/test_agents_config.py | 77 ++++++++++ tests/test_frameworks.py | 2 +- tests/test_generation_agent.py | 60 ++++++++ ..._generation.py => test_generation_tool.py} | 2 +- 14 files changed, 338 insertions(+), 121 deletions(-) create mode 100644 agentstack/agents.py create mode 100644 tests/fixtures/agents_max.yaml create mode 100644 tests/fixtures/agents_min.yaml create mode 100644 tests/test_agents_config.py create mode 100644 tests/test_generation_agent.py rename tests/{test_tool_generation.py => test_generation_tool.py} (98%) diff --git a/agentstack/agents.py b/agentstack/agents.py new file mode 100644 index 00000000..4ebb27d0 --- /dev/null +++ b/agentstack/agents.py @@ -0,0 +1,89 @@ +from typing import Optional +import os, sys +from pathlib import Path +import pydantic +from ruamel.yaml import YAML, YAMLError +from ruamel.yaml.scalarstring import FoldedScalarString +from agentstack import ValidationError + + +AGENTS_FILENAME: Path = Path("src/config/agents.yaml") + +yaml = YAML() +yaml.preserve_quotes = True # Preserve quotes in existing data + +class AgentConfig(pydantic.BaseModel): + """ + Interface for interacting with an agent configuration. + + Multiple agents are stored in a single YAML file, so we always look up the + requested agent by `name`. + + Use it as a context manager to make and save edits: + ```python + with AgentConfig('agent_name') as config: + config.llm = "openai/gpt-4o" + + Config Schema + ------------- + name: str + The name of the agent; used for lookup. + role: Optional[str] + The role of the agent. + goal: Optional[str] + The goal of the agent. + backstory: Optional[str] + The backstory of the agent. + llm: Optional[str] + The model this agent should use. + Adheres to the format set by the framework. + """ + name: str + role: Optional[str] = "" + goal: Optional[str] = "" + backstory: Optional[str] = "" + llm: Optional[str] = "" + + def __init__(self, name: str, path: Optional[Path] = None): + if not path: path = Path() + + if not os.path.exists(path/AGENTS_FILENAME): + os.makedirs((path/AGENTS_FILENAME).parent, exist_ok=True) + (path/AGENTS_FILENAME).touch() + + try: + with open(path/AGENTS_FILENAME, 'r') as f: + data = yaml.load(f) or {} + data = data.get(name, {}) or {} + super().__init__(**{**{'name': name}, **data}) + except YAMLError as e: + # TODO format MarkedYAMLError lines/messages + raise ValidationError(f"Error parsing agents file: {filename}\n{e}") + except pydantic.ValidationError as e: + error_str = "Error validating agent config:\n" + for error in e.errors(): + error_str += f"{' '.join(error['loc'])}: {error['msg']}\n" + raise ValidationError(f"Error loading agent {name} from {filename}.\n{error_str}") + + # store the path *after* loading data + self._path = path + + def model_dump(self, *args, **kwargs) -> dict: + dump = super().model_dump(*args, **kwargs) + dump.pop('name') # name is the key, so keep it out of the data + # format these as FoldedScalarStrings + for key in ('role', 'goal', 'backstory'): + dump[key] = FoldedScalarString(dump.get(key) or "") + return {self.name: dump} + + def write(self): + with open(self._path/AGENTS_FILENAME, 'r') as f: + data = yaml.load(f) or {} + + data.update(self.model_dump()) + + with open(self._path/AGENTS_FILENAME, 'w') as f: + yaml.dump(data, f) + + def __enter__(self) -> 'AgentConfig': return self + def __exit__(self, *args): self.write() diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index 8ce04c42..f8d9b635 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -38,6 +38,7 @@ from pathlib import Path from agentstack import ValidationError from agentstack.tools import ToolConfig +from agentstack.agents import AgentConfig CREWAI = 'crewai' @@ -47,20 +48,15 @@ def get_framework_module(framework: str) -> import_module: """ Get the module for a framework. """ - if not framework in SUPPORTED_FRAMEWORKS: - raise ValueError(f"Framework {framework} not supported") - try: - # repeated calls hit the `sys.modules` cache - return import_module(f".{framework}", __package__) + return import_module(f".{framework}", package=__package__) except ImportError: - raise ValueError(f"Framework {framework} could not be imported.") + raise Exception(f"Framework {framework} could not be imported.") def get_entrypoint_path(framework: str, path: Optional[Path] = None) -> Path: """ Get the path to the entrypoint file for a framework. """ - if not path: path = Path() return path/get_framework_module(framework).ENTRYPOINT def validate_project(framework: str, path: Optional[Path] = None): @@ -89,17 +85,17 @@ def get_agent_names(framework: str, path: Optional[Path] = None) -> list[str]: """ return get_framework_module(framework).get_agent_names(path) -def add_agent(framework: str, path: Optional[Path] = None): +def add_agent(framework: str, agent: AgentConfig, path: Optional[Path] = None): """ Add an agent to the user's project. """ - return get_framework_module(framework).add_agent(path) + return get_framework_module(framework).add_agent(agent, path) -def remove_agent(framework: str, path: Optional[Path] = None): +def remove_agent(framework: str, agent: AgentConfig, path: Optional[Path] = None): """ Remove an agent from the user's project. """ - return get_framework_module(framework).remove_agent(path) + return get_framework_module(framework).remove_agent(agent, path) def add_input(framework: str, path: Optional[Path] = None): """ diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index e0db44e7..9af30764 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -3,6 +3,7 @@ import ast from agentstack import ValidationError from agentstack.tools import ToolConfig +from agentstack.agents import AgentConfig from agentstack.generation import asttools from . import SUPPORTED_FRAMEWORKS @@ -41,6 +42,31 @@ def get_agent_methods(self) -> list[ast.FunctionDef]: """An `agent` method is a method decorated with `@agent`.""" return asttools.find_decorated_method_in_class(self.get_base_class(), 'agent') + def add_agent_method(self, agent: AgentConfig): + """Add a new agent method to the CrewAI entrypoint.""" + # TODO do we want to pre-populate any tools? + agent_methods = self.get_agent_methods() + if agent.name in [method.name for method in agent_methods]: + raise ValidationError(f"Agent `{agent.name}` already exists in {ENTRYPOINT}") + if agent_methods: + # Add after the existing agent methods + _, pos = self.get_node_range(agent_methods[-1]) + else: + # Add before the `crew` method + crew_method = self.get_crew_method() + pos, _ = self.get_node_range(crew_method) + + code = f""" + + @agent + def {agent.name}(self) -> Agent: + return Agent( + config=self.agents_config['{agent.name}'], + tools=[], # add tools here or use `agentstack tools add + verbose=True, + )""" + self.edit_node_range(pos, pos, code) + def get_agent_tools(self, agent_name: str) -> ast.List: """ Get the tools used by an agent as AST nodes. @@ -170,10 +196,15 @@ def get_agent_names(path: Optional[Path] = None) -> list[str]: crew_file = CrewFile(path/ENTRYPOINT) return [method.name for method in crew_file.get_agent_methods()] -def add_agent(path: Optional[Path] = None) -> None: - raise NotImplementedError +def add_agent(agent: AgentConfig, path: Optional[Path] = None) -> None: + """ + Add an agent method to the CrewAI entrypoint. + """ + if path is None: path = Path() + with CrewFile(path/ENTRYPOINT) as crew_file: + crew_file.add_agent_method(agent) -def remove_agent(path: Optional[Path] = None) -> None: +def remove_agent(agent: AgentConfig, path: Optional[Path] = None) -> None: raise NotImplementedError def add_input(path: Optional[Path] = None) -> None: diff --git a/agentstack/generation/__init__.py b/agentstack/generation/__init__.py index 49d62c82..0685338c 100644 --- a/agentstack/generation/__init__.py +++ b/agentstack/generation/__init__.py @@ -1,4 +1,4 @@ -from .agent_generation import generate_agent, get_agent_names +from .agent_generation import add_agent from .task_generation import generate_task, get_task_names from .tool_generation import add_tool, remove_tool from .files import ConfigFile, EnvFile, CONFIG_FILENAME \ No newline at end of file diff --git a/agentstack/generation/agent_generation.py b/agentstack/generation/agent_generation.py index bf64dd2e..a47af60a 100644 --- a/agentstack/generation/agent_generation.py +++ b/agentstack/generation/agent_generation.py @@ -1,105 +1,42 @@ +import os, sys from typing import Optional, List - -from .gen_utils import insert_code_after_tag, get_crew_components, CrewComponent -from agentstack.utils import verify_agentstack_project, get_framework +from pathlib import Path +from agentstack import ValidationError +from agentstack import frameworks +from agentstack.utils import verify_agentstack_project +from agentstack.agents import AgentConfig, AGENTS_FILENAME from agentstack.generation.files import ConfigFile -import os -from ruamel.yaml import YAML -from ruamel.yaml.scalarstring import FoldedScalarString - - -def generate_agent( - name, - role: Optional[str], - goal: Optional[str], - backstory: Optional[str], - llm: Optional[str] -): - agentstack_config = ConfigFile() # TODO path - if not role: - role = 'Add your role here' - if not goal: - goal = 'Add your goal here' - if not backstory: - backstory = 'Add your backstory here' - if not llm: - llm = agentstack_config.default_model - - verify_agentstack_project() - - framework = get_framework() - - if framework == 'crewai': - generate_crew_agent(name, role, goal, backstory, llm) - print(" > Added to src/config/agents.yaml") - else: - print(f"This function is not yet implemented for {framework}") - return - - print(f"Added agent \"{name}\" to your AgentStack project successfully!") - - -def generate_crew_agent( - name, - role: Optional[str] = 'Add your role here', - goal: Optional[str] = 'Add your goal here', - backstory: Optional[str] = 'Add your backstory here', - llm: Optional[str] = 'openai/gpt-4o' -): - config_path = os.path.join('src', 'config', 'agents.yaml') - - # Ensure the directory exists - os.makedirs(os.path.dirname(config_path), exist_ok=True) - - yaml = YAML() - yaml.preserve_quotes = True # Preserve quotes in existing data - - # Read existing data - if os.path.exists(config_path): - with open(config_path, 'r') as file: - try: - data = yaml.load(file) or {} - except Exception as exc: - print(f"Error parsing YAML file: {exc}") - data = {} - else: - data = {} - - # Handle None values - role_str = FoldedScalarString(role) if role else FoldedScalarString('') - goals_str = FoldedScalarString(goal) if goal else FoldedScalarString('') - backstory_str = FoldedScalarString(backstory) if backstory else FoldedScalarString('') - model_str = llm if llm else '' - - # Add new agent details - data[name] = { - 'role': role_str, - 'goal': goals_str, - 'backstory': backstory_str, - 'llm': model_str - } - - # Write back to the file without altering existing content - with open(config_path, 'w') as file: - yaml.dump(data, file) - - # Now lets add the agent to crew.py - file_path = 'src/crew.py' - tag = '# Agent definitions' - code_to_insert = [ - "@agent", - f"def {name}(self) -> Agent:", - " return Agent(", - f" config=self.agents_config['{name}'],", - " tools=[], # add tools here or use `agentstack tools add ", # TODO: Add any tools in agentstack.json - " verbose=True", - " )", - "" - ] - insert_code_after_tag(file_path, tag, code_to_insert) +def add_agent( + agent_name: str, + role: str = 'Add your role here', + goal: str = 'Add your goal here', + backstory: str = 'Add your backstory here', + llm: Optional[str] = None, + path: Optional[Path] = None): + + if path is None: path = Path() + verify_agentstack_project(path) + agentstack_config = ConfigFile(path) + framework = agentstack_config.framework + + agent = AgentConfig(agent_name, path) + with agent as config: + config.role = role + config.goal = goal + config.backstory = backstory + if llm: + config.llm = llm + else: + config.llm = agentstack_config.default_model + + try: + frameworks.add_agent(framework, agent, path) + print(f" > Added to {AGENTS_FILENAME}") + except ValidationError as e: + print(f"Error adding agent to project:\n{e}") + sys.exit(1) + + print(f"Added agent \"{agent_name}\" to your AgentStack project successfully!") -def get_agent_names(framework: str = 'crewai', path: str = '') -> List[str]: - """Get only agent names from the crew file""" - return get_crew_components(framework, CrewComponent.AGENT, path)['agents'] \ No newline at end of file diff --git a/agentstack/main.py b/agentstack/main.py index ab1066d0..768e3072 100644 --- a/agentstack/main.py +++ b/agentstack/main.py @@ -5,8 +5,7 @@ from agentstack.cli import init_project_builder, list_tools, configure_default_model, run_project from agentstack.telemetry import track_cli_command from agentstack.utils import get_version, get_framework -import agentstack.generation as generation -from agentstack import frameworks +from agentstack import generation from agentstack.update import check_for_updates import webbrowser @@ -108,7 +107,7 @@ def main(): if args.generate_command in ['agent', 'a']: if not args.llm: configure_default_model() - generation.generate_agent(args.name, args.role, args.goal, args.backstory, args.llm) + generation.add_agent(args.name, args.role, args.goal, args.backstory, args.llm) elif args.generate_command in ['task', 't']: generation.generate_task(args.name, args.description, args.expected_output, args.agent) else: diff --git a/agentstack/utils.py b/agentstack/utils.py index 0d1bcea9..75c347a1 100644 --- a/agentstack/utils.py +++ b/agentstack/utils.py @@ -3,6 +3,7 @@ import os import sys import json +from ruamel.yaml import YAML import re from importlib.metadata import version from pathlib import Path @@ -73,6 +74,15 @@ def open_json_file(path) -> dict: return data +def open_yaml_file(path) -> dict: + yaml = YAML() + yaml.preserve_quotes = True # Preserve quotes in existing data + + with open(path, 'r') as f: + data = yaml.load(f) + return data + + def clean_input(input_string): special_char_pattern = re.compile(r'[^a-zA-Z0-9\s_]') return re.sub(special_char_pattern, '', input_string).lower().replace(' ', '_').replace('-', '_') diff --git a/tests/fixtures/agents_max.yaml b/tests/fixtures/agents_max.yaml new file mode 100644 index 00000000..d532bbf4 --- /dev/null +++ b/tests/fixtures/agents_max.yaml @@ -0,0 +1,16 @@ +agent_name: + role: >- + role + goal: >- + this is a goal + backstory: >- + backstory + llm: provider/model +second_agent_name: + role: >- + role + goal: >- + this is a goal + backstory: >- + this is a backstory + llm: provider/model \ No newline at end of file diff --git a/tests/fixtures/agents_min.yaml b/tests/fixtures/agents_min.yaml new file mode 100644 index 00000000..1cbea78a --- /dev/null +++ b/tests/fixtures/agents_min.yaml @@ -0,0 +1,2 @@ +agent_name: + \ No newline at end of file diff --git a/tests/fixtures/frameworks/crewai/entrypoint_min.py b/tests/fixtures/frameworks/crewai/entrypoint_min.py index 7fabf7d3..20693b65 100644 --- a/tests/fixtures/frameworks/crewai/entrypoint_min.py +++ b/tests/fixtures/frameworks/crewai/entrypoint_min.py @@ -3,7 +3,7 @@ import tools @CrewBase -class TestCrew(): +class TestCrew: @crew def crew(self) -> Crew: diff --git a/tests/test_agents_config.py b/tests/test_agents_config.py new file mode 100644 index 00000000..0a46bdb9 --- /dev/null +++ b/tests/test_agents_config.py @@ -0,0 +1,77 @@ +import json +import os, sys +import shutil +import unittest +import importlib.resources +from pathlib import Path +from agentstack.agents import AgentConfig, AGENTS_FILENAME + +BASE_PATH = Path(__file__).parent + +class AgentConfigTest(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH/'tmp/agent_config' + os.makedirs(self.project_dir/'src/config') + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_empty_file(self): + config = AgentConfig("agent_name", self.project_dir) + assert config.name == "agent_name" + assert config.role is "" + assert config.goal is "" + assert config.backstory is "" + assert config.llm is "" + + def test_read_minimal_yaml(self): + shutil.copy(BASE_PATH/"fixtures/agents_min.yaml", self.project_dir/AGENTS_FILENAME) + config = AgentConfig("agent_name", self.project_dir) + assert config.name == "agent_name" + assert config.role == "" + assert config.goal == "" + assert config.backstory == "" + assert config.llm == "" + + def test_read_maximal_yaml(self): + shutil.copy(BASE_PATH/"fixtures/agents_max.yaml", self.project_dir/AGENTS_FILENAME) + config = AgentConfig("agent_name", self.project_dir) + print(config.model_dump()) + assert config.name == "agent_name" + assert config.role == "role" + assert config.goal == "this is a goal" + assert config.backstory == "backstory" + assert config.llm == "provider/model" + + def test_write_yaml(self): + with AgentConfig("agent_name", self.project_dir) as config: + config.role = "role" + config.goal = "this is a goal" + config.backstory = "backstory" + config.llm = "provider/model" + + yaml_src = open(self.project_dir/AGENTS_FILENAME).read() + assert yaml_src == """agent_name: + role: >- + role + goal: >- + this is a goal + backstory: >- + backstory + llm: provider/model +""" + + def test_write_none_values(self): + with AgentConfig("agent_name", self.project_dir) as config: + config.role = None + config.goal = None + config.backstory = None + config.llm = None + + yaml_src = open(self.project_dir/AGENTS_FILENAME).read() + assert yaml_src == """agent_name: + role: > + goal: > + backstory: > + llm: +""" \ No newline at end of file diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 325be57d..1a891ff2 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -49,7 +49,7 @@ def test_get_framework_module(self): assert module.__name__ == f"agentstack.frameworks.{self.framework}" def test_get_framework_module_invalid(self): - with self.assertRaises(ValueError) as context: + with self.assertRaises(Exception) as context: frameworks.get_framework_module('invalid') def test_validate_project(self): diff --git a/tests/test_generation_agent.py b/tests/test_generation_agent.py new file mode 100644 index 00000000..d74db555 --- /dev/null +++ b/tests/test_generation_agent.py @@ -0,0 +1,60 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class + +from agentstack import frameworks, ValidationError +from agentstack.tools import get_all_tools, ToolConfig +from agentstack.generation.files import ConfigFile +from agentstack.generation.agent_generation import add_agent + +BASE_PATH = Path(__file__).parent + +# TODO parameterize all tools +@parameterized_class([ + {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS +]) +class TestGenerationAgent(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH/'tmp'/'agent_generation' + + os.makedirs(self.project_dir) + os.makedirs(self.project_dir/'src') + os.makedirs(self.project_dir/'src'/'config') + (self.project_dir/'src'/'__init__.py').touch() + + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + + # set the framework in agentstack.json + shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json') + with ConfigFile(self.project_dir) as config: + config.framework = self.framework + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_add_agent(self): + add_agent('test_agent_two', + role='role', + goal='goal', + backstory='backstory', + llm='llm', + path=self.project_dir) + + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_src = open(entrypoint_path).read() + # agents.yaml is covered in test_agents_config.py + # TODO framework-specific validation for code structure + assert 'def test_agent_two' in entrypoint_src + + def test_add_agent_exists(self): + with self.assertRaises(SystemExit) as context: + add_agent('test_agent', + role='role', + goal='goal', + backstory='backstory', + llm='llm', + path=self.project_dir) \ No newline at end of file diff --git a/tests/test_tool_generation.py b/tests/test_generation_tool.py similarity index 98% rename from tests/test_tool_generation.py rename to tests/test_generation_tool.py index a19b92cf..4a0c1031 100644 --- a/tests/test_tool_generation.py +++ b/tests/test_generation_tool.py @@ -16,7 +16,7 @@ @parameterized_class([ {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS ]) -class TestToolGeneration(unittest.TestCase): +class TestGenerationTool(unittest.TestCase): def setUp(self): self.project_dir = BASE_PATH/'tmp'/'tool_generation' From 8abde6b28c5bbd63626f027acf15f3435f6bf2cb Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Thu, 5 Dec 2024 17:41:12 -0800 Subject: [PATCH 09/14] Migrate task generation to frameworks. Add task generation tests. --- agentstack/frameworks/__init__.py | 31 ++---- agentstack/frameworks/crewai.py | 74 +++++++++---- agentstack/generation/__init__.py | 2 +- agentstack/generation/agent_generation.py | 21 ++-- agentstack/generation/asttools.py | 1 + agentstack/generation/task_generation.py | 128 ++++++---------------- agentstack/main.py | 2 +- agentstack/tasks.py | 85 ++++++++++++++ foo.yaml | 0 tests/fixtures/tasks_max.yaml | 14 +++ tests/fixtures/tasks_min.yaml | 1 + tests/test_agents_config.py | 1 - tests/test_generation_agent.py | 6 +- tests/test_generation_tasks.py | 60 ++++++++++ tests/test_generation_tool.py | 3 + tests/test_tasks_config.py | 69 ++++++++++++ 16 files changed, 345 insertions(+), 153 deletions(-) create mode 100644 agentstack/tasks.py create mode 100644 foo.yaml create mode 100644 tests/fixtures/tasks_max.yaml create mode 100644 tests/fixtures/tasks_min.yaml create mode 100644 tests/test_generation_tasks.py create mode 100644 tests/test_tasks_config.py diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index f8d9b635..dab450f9 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -21,17 +21,11 @@ `get_agent_names(path: Optional[Path] = None) -> list[str]`: Get a list of agent names in the user's project. -`add_agent(path: Optional[Path] = None) -> None`: +`add_agent(agent: AgentConfig, path: Optional[Path] = None) -> None`: Add an agent to the user's project. -`remove_agent(path: Optional[Path] = None) -> None`: - Remove an agent from the user's project. - -`add_input(path: Optional[Path] = None) -> None`: - Add an input to the user's project. - -`remove_input(path: Optional[Path] = None) -> None`: - Remove an input from the user's project. +`add_task(task: TaskConfig, path: Optional[Path] = None) -> None`: + Add a task to the user's project. """ from typing import Optional from importlib import import_module @@ -39,6 +33,7 @@ from agentstack import ValidationError from agentstack.tools import ToolConfig from agentstack.agents import AgentConfig +from agentstack.tasks import TaskConfig CREWAI = 'crewai' @@ -91,21 +86,9 @@ def add_agent(framework: str, agent: AgentConfig, path: Optional[Path] = None): """ return get_framework_module(framework).add_agent(agent, path) -def remove_agent(framework: str, agent: AgentConfig, path: Optional[Path] = None): - """ - Remove an agent from the user's project. - """ - return get_framework_module(framework).remove_agent(agent, path) - -def add_input(framework: str, path: Optional[Path] = None): - """ - Add an input to the user's project. - """ - return get_framework_module(framework).add_input(path) - -def remove_input(framework: str, path: Optional[Path] = None): +def add_task(framework: str, task: TaskConfig, path: Optional[Path] = None): """ - Remove an input from the user's project. + Add a task to the user's project. """ - return get_framework_module(framework).remove_input(path) + return get_framework_module(framework).add_task(task, path) diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index 9af30764..c189312f 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -3,9 +3,9 @@ import ast from agentstack import ValidationError from agentstack.tools import ToolConfig +from agentstack.tasks import TaskConfig from agentstack.agents import AgentConfig from agentstack.generation import asttools -from . import SUPPORTED_FRAMEWORKS ENTRYPOINT: Path = Path('src/crew.py') @@ -38,6 +38,31 @@ def get_task_methods(self) -> list[ast.FunctionDef]: """A `task` method is a method decorated with `@task`.""" return asttools.find_decorated_method_in_class(self.get_base_class(), 'task') + def add_task_method(self, task: TaskConfig): + """Add a new task method to the CrewAI entrypoint.""" + task_methods = self.get_task_methods() + if task.name in [method.name for method in task_methods]: + # TODO this should check all methods in the class for duplicates + raise ValidationError(f"Task `{task.name}` already exists in {ENTRYPOINT}") + if task_methods: + # Add after the existing task methods + _, pos = self.get_node_range(task_methods[-1]) + else: + # Add before the `crew` method + crew_method = self.get_crew_method() + pos, _ = self.get_node_range(crew_method) + + code = f""" @task + def {task.name}(self) -> Task: + return Task( + config=self.tasks_config['{task.name}'], + )""" + if not self.source[:pos].endswith('\n'): + code = '\n\n' + code + if not self.source[pos:].startswith('\n'): + code += '\n\n' + self.edit_node_range(pos, pos, code) + def get_agent_methods(self) -> list[ast.FunctionDef]: """An `agent` method is a method decorated with `@agent`.""" return asttools.find_decorated_method_in_class(self.get_base_class(), 'agent') @@ -47,6 +72,7 @@ def add_agent_method(self, agent: AgentConfig): # TODO do we want to pre-populate any tools? agent_methods = self.get_agent_methods() if agent.name in [method.name for method in agent_methods]: + # TODO this should check all methods in the class for duplicates raise ValidationError(f"Agent `{agent.name}` already exists in {ENTRYPOINT}") if agent_methods: # Add after the existing agent methods @@ -56,15 +82,17 @@ def add_agent_method(self, agent: AgentConfig): crew_method = self.get_crew_method() pos, _ = self.get_node_range(crew_method) - code = f""" - - @agent + code = f""" @agent def {agent.name}(self) -> Agent: return Agent( config=self.agents_config['{agent.name}'], tools=[], # add tools here or use `agentstack tools add verbose=True, )""" + if not self.source[:pos].endswith('\n'): + code = '\n\n' + code + if not self.source[pos:].startswith('\n'): + code += '\n\n' self.edit_node_range(pos, pos, code) def get_agent_tools(self, agent_name: str) -> ast.List: @@ -172,27 +200,27 @@ def validate_project(path: Optional[Path] = None) -> None: f"`@agent` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" "Create a new agent using `agentstack generate agent `.") -def add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): +def get_task_names(path: Optional[Path] = None) -> list[str]: """ - Add a tool to the CrewAI entrypoint for the specified agent. - The agent should already exist in the crew class and have a keyword argument `tools`. + Get a list of task names (methods with an @task decorator). """ if path is None: path = Path() - with CrewFile(path/ENTRYPOINT) as crew_file: - crew_file.add_agent_tools(agent_name, tool) + crew_file = CrewFile(path/ENTRYPOINT) + return [method.name for method in crew_file.get_task_methods()] -def remove_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): +def add_task(task: TaskConfig, path: Optional[Path] = None) -> None: """ - Remove a tool from the CrewAI framework for the specified agent. + Add a task method to the CrewAI entrypoint. """ if path is None: path = Path() with CrewFile(path/ENTRYPOINT) as crew_file: - crew_file.remove_agent_tools(agent_name, tool) + crew_file.add_task_method(task) def get_agent_names(path: Optional[Path] = None) -> list[str]: """ Get a list of agent names (methods with an @agent decorator). """ + if path is None: path = Path() crew_file = CrewFile(path/ENTRYPOINT) return [method.name for method in crew_file.get_agent_methods()] @@ -204,12 +232,20 @@ def add_agent(agent: AgentConfig, path: Optional[Path] = None) -> None: with CrewFile(path/ENTRYPOINT) as crew_file: crew_file.add_agent_method(agent) -def remove_agent(agent: AgentConfig, path: Optional[Path] = None) -> None: - raise NotImplementedError - -def add_input(path: Optional[Path] = None) -> None: - raise NotImplementedError +def add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): + """ + Add a tool to the CrewAI entrypoint for the specified agent. + The agent should already exist in the crew class and have a keyword argument `tools`. + """ + if path is None: path = Path() + with CrewFile(path/ENTRYPOINT) as crew_file: + crew_file.add_agent_tools(agent_name, tool) -def remove_input(path: Optional[Path] = None) -> None: - raise NotImplementedError +def remove_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): + """ + Remove a tool from the CrewAI framework for the specified agent. + """ + if path is None: path = Path() + with CrewFile(path/ENTRYPOINT) as crew_file: + crew_file.remove_agent_tools(agent_name, tool) diff --git a/agentstack/generation/__init__.py b/agentstack/generation/__init__.py index 0685338c..82e2eb56 100644 --- a/agentstack/generation/__init__.py +++ b/agentstack/generation/__init__.py @@ -1,4 +1,4 @@ from .agent_generation import add_agent -from .task_generation import generate_task, get_task_names +from .task_generation import add_task from .tool_generation import add_tool, remove_tool from .files import ConfigFile, EnvFile, CONFIG_FILENAME \ No newline at end of file diff --git a/agentstack/generation/agent_generation.py b/agentstack/generation/agent_generation.py index a47af60a..e330e2d9 100644 --- a/agentstack/generation/agent_generation.py +++ b/agentstack/generation/agent_generation.py @@ -1,5 +1,5 @@ -import os, sys -from typing import Optional, List +import sys +from typing import Optional from pathlib import Path from agentstack import ValidationError from agentstack import frameworks @@ -10,9 +10,9 @@ def add_agent( agent_name: str, - role: str = 'Add your role here', - goal: str = 'Add your goal here', - backstory: str = 'Add your backstory here', + role: Optional[str] = None, + goal: Optional[str] = None, + backstory: Optional[str] = None, llm: Optional[str] = None, path: Optional[Path] = None): @@ -23,13 +23,10 @@ def add_agent( agent = AgentConfig(agent_name, path) with agent as config: - config.role = role - config.goal = goal - config.backstory = backstory - if llm: - config.llm = llm - else: - config.llm = agentstack_config.default_model + config.role = role or "Add your role here" + config.goal = goal or "Add your goal here" + config.backstory = backstory or "Add your backstory here" + config.llm = llm or agentstack_config.default_model try: frameworks.add_agent(framework, agent, path) diff --git a/agentstack/generation/asttools.py b/agentstack/generation/asttools.py index a89366b5..9599d132 100644 --- a/agentstack/generation/asttools.py +++ b/agentstack/generation/asttools.py @@ -69,6 +69,7 @@ def edit_node_range(self, start: int, end: int, node: Union[str, ast.AST]): type_ignores=[] ) node = astor.to_source(module).strip() + self.source = self.source[:start] + node + self.source[end:] # In order to continue accurately modifying the AST, we need to re-parse the source. self.atok = asttokens.ASTTokens(self.source, parse=True) diff --git a/agentstack/generation/task_generation.py b/agentstack/generation/task_generation.py index d2fc6ebc..0b5113dd 100644 --- a/agentstack/generation/task_generation.py +++ b/agentstack/generation/task_generation.py @@ -1,94 +1,36 @@ -from typing import Optional, List +import sys +from typing import Optional +from pathlib import Path +from agentstack import ValidationError +from agentstack import frameworks +from agentstack.utils import verify_agentstack_project +from agentstack.tasks import TaskConfig, TASKS_FILENAME +from agentstack.generation.files import ConfigFile + + +def add_task( + task_name: str, + description: Optional[str] = None, + expected_output: Optional[str] = None, + agent: Optional[str] = None, + path: Optional[Path] = None): + + if path is None: path = Path() + verify_agentstack_project(path) + agentstack_config = ConfigFile(path) + framework = agentstack_config.framework + + task = TaskConfig(task_name, path) + with task as config: + config.description = description or "Add your description here" + config.expected_output = expected_output or "Add your expected_output here" + config.agent = agent or "agent_name" + + try: + frameworks.add_task(framework, task, path) + print(f" > Added to {TASKS_FILENAME}") + except ValidationError as e: + print(f"Error adding task to project:\n{e}") + sys.exit(1) + print(f"Added task \"{task_name}\" to your AgentStack project successfully!") -from .gen_utils import insert_after_tasks, get_crew_components, CrewComponent -from ..utils import verify_agentstack_project, get_framework -import os -from ruamel.yaml import YAML -from ruamel.yaml.scalarstring import FoldedScalarString - - -def generate_task( - name, - description: Optional[str], - expected_output: Optional[str], - agent: Optional[str] -): - if not description: - description = 'Add your description here' - if not expected_output: - expected_output = 'Add your expected_output here' - if not agent: - agent = 'default_agent' - - verify_agentstack_project() - - framework = get_framework() - - if framework == 'crewai': - generate_crew_task(name, description, expected_output, agent) - print(" > Added to src/config/tasks.yaml") - else: - print(f"This function is not yet implemented for {framework}") - return - - print(f"Added task \"{name}\" to your AgentStack project successfully!") - - -def generate_crew_task( - name, - description: Optional[str], - expected_output: Optional[str], - agent: Optional[str] -): - config_path = os.path.join('src', 'config', 'tasks.yaml') - - # Ensure the directory exists - os.makedirs(os.path.dirname(config_path), exist_ok=True) - - yaml = YAML() - yaml.preserve_quotes = True # Preserve quotes in existing data - - # Read existing data - if os.path.exists(config_path): - with open(config_path, 'r') as file: - try: - data = yaml.load(file) or {} - except Exception as exc: - print(f"Error parsing YAML file: {exc}") - data = {} - else: - data = {} - - # Handle None values - description_str = FoldedScalarString(description) if description else FoldedScalarString('') - expected_output_str = FoldedScalarString(expected_output) if expected_output else FoldedScalarString('') - agent_str = FoldedScalarString(agent) if agent else FoldedScalarString('') - - # Add new agent details - data[name] = { - 'description': description_str, - 'expected_output': expected_output_str, - 'agent': agent_str, - } - - # Write back to the file without altering existing content - with open(config_path, 'w') as file: - yaml.dump(data, file) - - # Add task to crew.py - file_path = 'src/crew.py' - code_to_insert = [ - "@task", - f"def {name}(self) -> Task:", - " return Task(", - f" config=self.tasks_config['{name}'],", - " )", - "" - ] - - insert_after_tasks(file_path, code_to_insert) - - -def get_task_names(framework: str, path: str = '') -> List[str]: - """Get only task names from the crew file""" - return get_crew_components(framework, CrewComponent.TASK, path)['tasks'] \ No newline at end of file diff --git a/agentstack/main.py b/agentstack/main.py index 768e3072..15344c66 100644 --- a/agentstack/main.py +++ b/agentstack/main.py @@ -109,7 +109,7 @@ def main(): configure_default_model() generation.add_agent(args.name, args.role, args.goal, args.backstory, args.llm) elif args.generate_command in ['task', 't']: - generation.generate_task(args.name, args.description, args.expected_output, args.agent) + generation.add_task(args.name, args.description, args.expected_output, args.agent) else: generate_parser.print_help() elif args.command in ['tools', 't']: diff --git a/agentstack/tasks.py b/agentstack/tasks.py new file mode 100644 index 00000000..fbc75cdb --- /dev/null +++ b/agentstack/tasks.py @@ -0,0 +1,85 @@ +from typing import Optional +import os, sys +from pathlib import Path +import pydantic +from ruamel.yaml import YAML, YAMLError +from ruamel.yaml.scalarstring import FoldedScalarString +from agentstack import ValidationError + + +TASKS_FILENAME: Path = Path("src/config/tasks.yaml") + +yaml = YAML() +yaml.preserve_quotes = True # Preserve quotes in existing data + +class TaskConfig(pydantic.BaseModel): + """ + Interface for interacting with a task configuration. + + Multiple tasks are stored in a single YAML file, so we always look up the + requested task by `name`. + + Use it as a context manager to make and save edits: + ```python + with TaskConfig('task_name') as config: + config.description = "foo" + + Config Schema + ------------- + name: str + The name of the agent; used for lookup. + description: Optional[str] + The description of the task. + expected_output: Optional[str] + The expected output of the task. + agent: Optional[str] + The agent to use for the task. + """ + name: str + description: Optional[str] = "" + expected_output: Optional[str] = "" + agent: Optional[str] = "" + + def __init__(self, name: str, path: Optional[Path] = None): + if not path: path = Path() + + if not os.path.exists(path/TASKS_FILENAME): + os.makedirs((path/TASKS_FILENAME).parent, exist_ok=True) + (path/TASKS_FILENAME).touch() + + try: + with open(path/TASKS_FILENAME, 'r') as f: + data = yaml.load(f) or {} + data = data.get(name, {}) or {} + super().__init__(**{**{'name': name}, **data}) + except YAMLError as e: + # TODO format MarkedYAMLError lines/messages + raise ValidationError(f"Error parsing tasks file: {filename}\n{e}") + except pydantic.ValidationError as e: + error_str = "Error validating tasks config:\n" + for error in e.errors(): + error_str += f"{' '.join(error['loc'])}: {error['msg']}\n" + raise ValidationError(f"Error loading task {name} from {filename}.\n{error_str}") + + # store the path *after* loading data + self._path = path + + def model_dump(self, *args, **kwargs) -> dict: + dump = super().model_dump(*args, **kwargs) + dump.pop('name') # name is the key, so keep it out of the data + # format these as FoldedScalarStrings + for key in ('description', 'expected_output', 'agent'): + dump[key] = FoldedScalarString(dump.get(key) or "") + return {self.name: dump} + + def write(self): + with open(self._path/TASKS_FILENAME, 'r') as f: + data = yaml.load(f) or {} + + data.update(self.model_dump()) + + with open(self._path/TASKS_FILENAME, 'w') as f: + yaml.dump(data, f) + + def __enter__(self) -> 'AgentConfig': return self + def __exit__(self, *args): self.write() diff --git a/foo.yaml b/foo.yaml new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/tasks_max.yaml b/tests/fixtures/tasks_max.yaml new file mode 100644 index 00000000..355ebd98 --- /dev/null +++ b/tests/fixtures/tasks_max.yaml @@ -0,0 +1,14 @@ +task_name: + description: >- + Add your description here + expected_output: >- + Add your expected output here + agent: >- + default_agent +task_name_two: + description: >- + Add your description here + expected_output: >- + Add your expected output here + agent: >- + default_agent diff --git a/tests/fixtures/tasks_min.yaml b/tests/fixtures/tasks_min.yaml new file mode 100644 index 00000000..1fd435c0 --- /dev/null +++ b/tests/fixtures/tasks_min.yaml @@ -0,0 +1 @@ +task_name: diff --git a/tests/test_agents_config.py b/tests/test_agents_config.py index 0a46bdb9..43dfea0c 100644 --- a/tests/test_agents_config.py +++ b/tests/test_agents_config.py @@ -36,7 +36,6 @@ def test_read_minimal_yaml(self): def test_read_maximal_yaml(self): shutil.copy(BASE_PATH/"fixtures/agents_max.yaml", self.project_dir/AGENTS_FILENAME) config = AgentConfig("agent_name", self.project_dir) - print(config.model_dump()) assert config.name == "agent_name" assert config.role == "role" assert config.goal == "this is a goal" diff --git a/tests/test_generation_agent.py b/tests/test_generation_agent.py index d74db555..c88b0a16 100644 --- a/tests/test_generation_agent.py +++ b/tests/test_generation_agent.py @@ -3,15 +3,15 @@ import shutil import unittest from parameterized import parameterized_class +import ast from agentstack import frameworks, ValidationError -from agentstack.tools import get_all_tools, ToolConfig from agentstack.generation.files import ConfigFile from agentstack.generation.agent_generation import add_agent BASE_PATH = Path(__file__).parent -# TODO parameterize all tools + @parameterized_class([ {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS ]) @@ -49,6 +49,8 @@ def test_add_agent(self): # agents.yaml is covered in test_agents_config.py # TODO framework-specific validation for code structure assert 'def test_agent_two' in entrypoint_src + # verify that the file's syntax is valid with ast + ast.parse(entrypoint_src) def test_add_agent_exists(self): with self.assertRaises(SystemExit) as context: diff --git a/tests/test_generation_tasks.py b/tests/test_generation_tasks.py new file mode 100644 index 00000000..036af6a7 --- /dev/null +++ b/tests/test_generation_tasks.py @@ -0,0 +1,60 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class +import ast + +from agentstack import frameworks, ValidationError +from agentstack.generation.files import ConfigFile +from agentstack.generation.task_generation import add_task + +BASE_PATH = Path(__file__).parent + + +@parameterized_class([ + {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS +]) +class TestGenerationAgent(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH/'tmp'/'agent_generation' + + os.makedirs(self.project_dir) + os.makedirs(self.project_dir/'src') + os.makedirs(self.project_dir/'src'/'config') + (self.project_dir/'src'/'__init__.py').touch() + + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + + # set the framework in agentstack.json + shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json') + with ConfigFile(self.project_dir) as config: + config.framework = self.framework + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_add_task(self): + add_task('task_test_two', + description='description', + expected_output='expected_output', + agent='agent', + path=self.project_dir) + + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_src = open(entrypoint_path).read() + # agents.yaml is covered in test_agents_config.py + # TODO framework-specific validation for code structure + assert 'def task_test_two' in entrypoint_src + # verify that the file's syntax is valid with ast + ast.parse(entrypoint_src) + + def test_add_agent_exists(self): + with self.assertRaises(SystemExit) as context: + add_task('test_task', + description='description', + expected_output='expected_output', + agent='agent', + path=self.project_dir) \ No newline at end of file diff --git a/tests/test_generation_tool.py b/tests/test_generation_tool.py index 4a0c1031..a44d68a2 100644 --- a/tests/test_generation_tool.py +++ b/tests/test_generation_tool.py @@ -3,6 +3,7 @@ import shutil import unittest from parameterized import parameterized_class +import ast from agentstack import frameworks from agentstack.tools import get_all_tools, ToolConfig @@ -44,6 +45,7 @@ def test_add_tool(self): entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) entrypoint_src = open(entrypoint_path).read() + ast.parse(entrypoint_src) tools_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() # TODO verify tool is added to all agents (this is covered in test_frameworks.py) @@ -59,6 +61,7 @@ def test_remove_tool(self): entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) entrypoint_src = open(entrypoint_path).read() + ast.parse(entrypoint_src) tools_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() # TODO verify tool is removed from all agents (this is covered in test_frameworks.py) diff --git a/tests/test_tasks_config.py b/tests/test_tasks_config.py new file mode 100644 index 00000000..26cafa20 --- /dev/null +++ b/tests/test_tasks_config.py @@ -0,0 +1,69 @@ +import json +import os, sys +import shutil +import unittest +import importlib.resources +from pathlib import Path +from agentstack.tasks import TaskConfig, TASKS_FILENAME + +BASE_PATH = Path(__file__).parent + +class AgentConfigTest(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH/'tmp/task_config' + os.makedirs(self.project_dir/'src/config') + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_empty_file(self): + config = TaskConfig("task_name", self.project_dir) + assert config.name == "task_name" + assert config.description is "" + assert config.expected_output is "" + assert config.agent is "" + + def test_read_minimal_yaml(self): + shutil.copy(BASE_PATH/"fixtures/tasks_min.yaml", self.project_dir/TASKS_FILENAME) + config = TaskConfig("task_name", self.project_dir) + assert config.name == "task_name" + assert config.description is "" + assert config.expected_output is "" + assert config.agent is "" + + def test_read_maximal_yaml(self): + shutil.copy(BASE_PATH/"fixtures/tasks_max.yaml", self.project_dir/TASKS_FILENAME) + config = TaskConfig("task_name", self.project_dir) + assert config.name == "task_name" + assert config.description == "Add your description here" + assert config.expected_output == "Add your expected output here" + assert config.agent == "default_agent" + + def test_write_yaml(self): + with TaskConfig("task_name", self.project_dir) as config: + config.description = "Add your description here" + config.expected_output = "Add your expected output here" + config.agent = "default_agent" + + yaml_src = open(self.project_dir/TASKS_FILENAME).read() + assert yaml_src == """task_name: + description: >- + Add your description here + expected_output: >- + Add your expected output here + agent: >- + default_agent +""" + + def test_write_none_values(self): + with TaskConfig("task_name", self.project_dir) as config: + config.description = None + config.expected_output = None + config.agent = None + + yaml_src = open(self.project_dir/TASKS_FILENAME).read() + assert yaml_src == """task_name: + description: > + expected_output: > + agent: > +""" \ No newline at end of file From 3b1c76477bffa1318ef9a94dea3e5bcf8c14a25a Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Fri, 6 Dec 2024 09:52:26 -0800 Subject: [PATCH 10/14] ruff format tests --- .../frameworks/crewai/entrypoint_max.py | 10 +-- .../frameworks/crewai/entrypoint_min.py | 3 +- tests/test_agents_config.py | 37 ++++++---- tests/test_cli_loads.py | 31 ++++---- tests/test_frameworks.py | 70 +++++++++---------- tests/test_generation_agent.py | 54 +++++++------- tests/test_generation_files.py | 53 +++++++------- tests/test_generation_tasks.py | 50 ++++++------- tests/test_generation_tool.py | 56 +++++++-------- tests/test_tasks_config.py | 37 ++++++---- tests/test_tool_config.py | 9 ++- tests/test_tool_generation_init.py | 59 ++++++++-------- 12 files changed, 243 insertions(+), 226 deletions(-) diff --git a/tests/fixtures/frameworks/crewai/entrypoint_max.py b/tests/fixtures/frameworks/crewai/entrypoint_max.py index b873d377..6ccb9af2 100644 --- a/tests/fixtures/frameworks/crewai/entrypoint_max.py +++ b/tests/fixtures/frameworks/crewai/entrypoint_max.py @@ -2,15 +2,12 @@ from crewai.project import CrewBase, agent, crew, task import tools + @CrewBase -class TestCrew(): +class TestCrew: @agent def test_agent(self) -> Agent: - return Agent( - config=self.agents_config['test_agent'], - tools=[], - verbose=True - ) + return Agent(config=self.agents_config['test_agent'], tools=[], verbose=True) @task def test_task(self) -> Task: @@ -26,4 +23,3 @@ def crew(self) -> Crew: process=Process.sequential, verbose=True, ) - diff --git a/tests/fixtures/frameworks/crewai/entrypoint_min.py b/tests/fixtures/frameworks/crewai/entrypoint_min.py index 20693b65..f423807f 100644 --- a/tests/fixtures/frameworks/crewai/entrypoint_min.py +++ b/tests/fixtures/frameworks/crewai/entrypoint_min.py @@ -2,9 +2,9 @@ from crewai.project import CrewBase, agent, crew, task import tools + @CrewBase class TestCrew: - @crew def crew(self) -> Crew: return Crew( @@ -13,4 +13,3 @@ def crew(self) -> Crew: process=Process.sequential, verbose=True, ) - diff --git a/tests/test_agents_config.py b/tests/test_agents_config.py index 43dfea0c..657d9318 100644 --- a/tests/test_agents_config.py +++ b/tests/test_agents_config.py @@ -8,14 +8,15 @@ BASE_PATH = Path(__file__).parent + class AgentConfigTest(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH/'tmp/agent_config' - os.makedirs(self.project_dir/'src/config') - + self.project_dir = BASE_PATH / 'tmp/agent_config' + os.makedirs(self.project_dir / 'src/config') + def tearDown(self): shutil.rmtree(self.project_dir) - + def test_empty_file(self): config = AgentConfig("agent_name", self.project_dir) assert config.name == "agent_name" @@ -23,18 +24,18 @@ def test_empty_file(self): assert config.goal is "" assert config.backstory is "" assert config.llm is "" - + def test_read_minimal_yaml(self): - shutil.copy(BASE_PATH/"fixtures/agents_min.yaml", self.project_dir/AGENTS_FILENAME) + shutil.copy(BASE_PATH / "fixtures/agents_min.yaml", self.project_dir / AGENTS_FILENAME) config = AgentConfig("agent_name", self.project_dir) assert config.name == "agent_name" assert config.role == "" assert config.goal == "" assert config.backstory == "" assert config.llm == "" - + def test_read_maximal_yaml(self): - shutil.copy(BASE_PATH/"fixtures/agents_max.yaml", self.project_dir/AGENTS_FILENAME) + shutil.copy(BASE_PATH / "fixtures/agents_max.yaml", self.project_dir / AGENTS_FILENAME) config = AgentConfig("agent_name", self.project_dir) assert config.name == "agent_name" assert config.role == "role" @@ -48,9 +49,11 @@ def test_write_yaml(self): config.goal = "this is a goal" config.backstory = "backstory" config.llm = "provider/model" - - yaml_src = open(self.project_dir/AGENTS_FILENAME).read() - assert yaml_src == """agent_name: + + yaml_src = open(self.project_dir / AGENTS_FILENAME).read() + assert ( + yaml_src + == """agent_name: role: >- role goal: >- @@ -59,6 +62,7 @@ def test_write_yaml(self): backstory llm: provider/model """ + ) def test_write_none_values(self): with AgentConfig("agent_name", self.project_dir) as config: @@ -66,11 +70,14 @@ def test_write_none_values(self): config.goal = None config.backstory = None config.llm = None - - yaml_src = open(self.project_dir/AGENTS_FILENAME).read() - assert yaml_src == """agent_name: + + yaml_src = open(self.project_dir / AGENTS_FILENAME).read() + assert ( + yaml_src + == """agent_name: role: > goal: > backstory: > llm: -""" \ No newline at end of file +""" + ) diff --git a/tests/test_cli_loads.py b/tests/test_cli_loads.py index 13597c62..7dfaf0bf 100644 --- a/tests/test_cli_loads.py +++ b/tests/test_cli_loads.py @@ -6,16 +6,18 @@ BASE_PATH = Path(__file__).parent + class TestAgentStackCLI(unittest.TestCase): - CLI_ENTRY = [sys.executable, "-m", "agentstack.main"] # Replace with your actual CLI entry point if different + # Replace with your actual CLI entry point if different + CLI_ENTRY = [ + sys.executable, + "-m", + "agentstack.main", + ] def run_cli(self, *args): """Helper method to run the CLI with arguments.""" - result = subprocess.run( - [*self.CLI_ENTRY, *args], - capture_output=True, - text=True - ) + result = subprocess.run([*self.CLI_ENTRY, *args], capture_output=True, text=True) return result def test_version(self): @@ -32,13 +34,13 @@ def test_invalid_command(self): def test_init_command(self): """Test the 'init' command to create a project directory.""" - test_dir = Path(BASE_PATH/'tmp/test_project') + test_dir = Path(BASE_PATH / 'tmp/test_project') # Ensure the directory doesn't exist from previous runs if test_dir.exists(): shutil.rmtree(test_dir) os.makedirs(test_dir) - + os.chdir(test_dir) result = self.run_cli("init", str(test_dir)) self.assertEqual(result.returncode, 0) @@ -49,21 +51,22 @@ def test_init_command(self): def test_run_command_invalid_project(self): """Test the 'run' command on an invalid project.""" - test_dir = Path(BASE_PATH/'tmp/test_project') + test_dir = Path(BASE_PATH / 'tmp/test_project') if test_dir.exists(): shutil.rmtree(test_dir) os.makedirs(test_dir) - + # Write a basic agentstack.json file - with (test_dir/'agentstack.json').open('w') as f: - f.write(open(BASE_PATH/'fixtures/agentstack.json', 'r').read()) - + with (test_dir / 'agentstack.json').open('w') as f: + f.write(open(BASE_PATH / 'fixtures/agentstack.json', 'r').read()) + os.chdir(test_dir) result = self.run_cli('run') self.assertNotEqual(result.returncode, 0) self.assertIn("Project validation failed", result.stdout) - + shutil.rmtree(test_dir) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 1a891ff2..8b27b84a 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -11,71 +11,71 @@ BASE_PATH = Path(__file__).parent -@parameterized_class([ - {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS -]) +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) class TestFrameworks(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH/'tmp'/self.framework - + self.project_dir = BASE_PATH / 'tmp' / self.framework + os.makedirs(self.project_dir) - os.makedirs(self.project_dir/'src') - os.makedirs(self.project_dir/'src'/'tools') - - (self.project_dir/'src'/'__init__.py').touch() - (self.project_dir/'src'/'tools'/'__init__.py').touch() - + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'tools') + + (self.project_dir / 'src' / '__init__.py').touch() + (self.project_dir / 'src' / 'tools' / '__init__.py').touch() + def tearDown(self): shutil.rmtree(self.project_dir) - + def _populate_min_entrypoint(self): """This entrypoint does not have any tools or agents.""" entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) - shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_min.py", entrypoint_path) - + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_min.py", entrypoint_path) + def _populate_max_entrypoint(self): """This entrypoint has tools and agents.""" entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) - shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) - + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + def _get_test_tool(self) -> ToolConfig: return ToolConfig(name='test_tool', category='test', tools=['test_tool']) - + def _get_test_tool_starred(self) -> ToolConfig: - return ToolConfig(name='test_tool_star', category='test', tools=['test_tool_star'], tools_bundled=True) - + return ToolConfig( + name='test_tool_star', category='test', tools=['test_tool_star'], tools_bundled=True + ) + def test_get_framework_module(self): module = frameworks.get_framework_module(self.framework) assert module.__name__ == f"agentstack.frameworks.{self.framework}" - + def test_get_framework_module_invalid(self): with self.assertRaises(Exception) as context: frameworks.get_framework_module('invalid') - + def test_validate_project(self): self._populate_max_entrypoint() frameworks.validate_project(self.framework, self.project_dir) - + def test_validate_project_invalid(self): self._populate_min_entrypoint() with self.assertRaises(ValidationError) as context: frameworks.validate_project(self.framework, self.project_dir) - + def test_add_tool(self): self._populate_max_entrypoint() frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) - + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() # TODO these asserts are not framework agnostic assert 'tools=[tools.test_tool' in entrypoint_src - + def test_add_tool_starred(self): self._populate_max_entrypoint() frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) - + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() assert 'tools=[*tools.test_tool_star' in entrypoint_src - + def test_add_tool_invalid(self): self._populate_min_entrypoint() with self.assertRaises(ValidationError) as context: @@ -85,7 +85,7 @@ def test_remove_tool(self): self._populate_max_entrypoint() frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) frameworks.remove_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) - + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() assert 'tools=[tools.test_tool' not in entrypoint_src @@ -93,7 +93,7 @@ def test_remove_tool_starred(self): self._populate_max_entrypoint() frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) frameworks.remove_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) - + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() assert 'tools=[*tools.test_tool_star' not in entrypoint_src @@ -101,21 +101,19 @@ def test_add_multiple_tools(self): self._populate_max_entrypoint() frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) - + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() - assert ( # ordering is not guaranteed + assert ( # ordering is not guaranteed 'tools=[tools.test_tool, *tools.test_tool_star' in entrypoint_src - or - 'tools=[*tools.test_tool_star, tools.test_tool' in entrypoint_src - ) + or 'tools=[*tools.test_tool_star, tools.test_tool' in entrypoint_src + ) def test_remove_one_tool_of_multiple(self): self._populate_max_entrypoint() frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) frameworks.remove_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) - + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() assert 'tools=[tools.test_tool' not in entrypoint_src assert 'tools=[*tools.test_tool_star' in entrypoint_src - diff --git a/tests/test_generation_agent.py b/tests/test_generation_agent.py index c88b0a16..2f836e5e 100644 --- a/tests/test_generation_agent.py +++ b/tests/test_generation_agent.py @@ -12,37 +12,37 @@ BASE_PATH = Path(__file__).parent -@parameterized_class([ - {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS -]) +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) class TestGenerationAgent(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH/'tmp'/'agent_generation' - + self.project_dir = BASE_PATH / 'tmp' / 'agent_generation' + os.makedirs(self.project_dir) - os.makedirs(self.project_dir/'src') - os.makedirs(self.project_dir/'src'/'config') - (self.project_dir/'src'/'__init__.py').touch() - + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'config') + (self.project_dir / 'src' / '__init__.py').touch() + # populate the entrypoint entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) - shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) - + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + # set the framework in agentstack.json - shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json') + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') with ConfigFile(self.project_dir) as config: config.framework = self.framework - + def tearDown(self): shutil.rmtree(self.project_dir) - + def test_add_agent(self): - add_agent('test_agent_two', - role='role', - goal='goal', - backstory='backstory', - llm='llm', - path=self.project_dir) + add_agent( + 'test_agent_two', + role='role', + goal='goal', + backstory='backstory', + llm='llm', + path=self.project_dir, + ) entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) entrypoint_src = open(entrypoint_path).read() @@ -54,9 +54,11 @@ def test_add_agent(self): def test_add_agent_exists(self): with self.assertRaises(SystemExit) as context: - add_agent('test_agent', - role='role', - goal='goal', - backstory='backstory', - llm='llm', - path=self.project_dir) \ No newline at end of file + add_agent( + 'test_agent', + role='role', + goal='goal', + backstory='backstory', + llm='llm', + path=self.project_dir, + ) diff --git a/tests/test_generation_files.py b/tests/test_generation_files.py index 1b0ec56d..31d08831 100644 --- a/tests/test_generation_files.py +++ b/tests/test_generation_files.py @@ -8,28 +8,30 @@ BASE_PATH = Path(__file__).parent + class GenerationFilesTest(unittest.TestCase): def test_read_config(self): - config = ConfigFile(BASE_PATH / "fixtures") # + agentstack.json + config = ConfigFile(BASE_PATH / "fixtures") # + agentstack.json assert config.framework == "crewai" assert config.tools == [] assert config.telemetry_opt_out is None assert config.default_model is None - + def test_write_config(self): try: - os.makedirs(BASE_PATH/"tmp", exist_ok=True) - shutil.copy(BASE_PATH/"fixtures/agentstack.json", - BASE_PATH/"tmp/agentstack.json") - - with ConfigFile(BASE_PATH/"tmp") as config: + os.makedirs(BASE_PATH / "tmp", exist_ok=True) + shutil.copy(BASE_PATH / "fixtures/agentstack.json", BASE_PATH / "tmp/agentstack.json") + + with ConfigFile(BASE_PATH / "tmp") as config: config.framework = "crewai" config.tools = ["tool1", "tool2"] config.telemetry_opt_out = True config.default_model = "openai/gpt-4o" - - tmp_data = open(BASE_PATH/"tmp/agentstack.json").read() - assert tmp_data == """{ + + tmp_data = open(BASE_PATH / "tmp/agentstack.json").read() + assert ( + tmp_data + == """{ "framework": "crewai", "tools": [ "tool1", @@ -38,11 +40,12 @@ def test_write_config(self): "telemetry_opt_out": true, "default_model": "openai/gpt-4o" }""" + ) except Exception as e: raise e finally: os.remove(BASE_PATH / "tmp/agentstack.json") - #os.rmdir(BASE_PATH / "tmp") + # os.rmdir(BASE_PATH / "tmp") def test_read_missing_config(self): with self.assertRaises(FileNotFoundError) as context: @@ -54,17 +57,17 @@ def test_verify_agentstack_project_valid(self): def test_verify_agentstack_project_invalid(self): with self.assertRaises(SystemExit) as context: verify_agentstack_project(BASE_PATH / "missing") - + def test_get_framework(self): assert get_framework(BASE_PATH / "fixtures") == "crewai" with self.assertRaises(SystemExit) as context: get_framework(BASE_PATH / "missing") - + def test_get_telemetry_opt_out(self): assert get_telemetry_opt_out(BASE_PATH / "fixtures") is False with self.assertRaises(SystemExit) as context: get_telemetry_opt_out(BASE_PATH / "missing") - + def test_read_env(self): env = EnvFile(BASE_PATH / "fixtures") assert env.variables == {"ENV_VAR1": "value1", "ENV_VAR2": "value2"} @@ -72,22 +75,20 @@ def test_read_env(self): assert env["ENV_VAR2"] == "value2" with self.assertRaises(KeyError) as context: env["ENV_VAR3"] - + def test_write_env(self): try: - os.makedirs(BASE_PATH/"tmp", exist_ok=True) - shutil.copy(BASE_PATH/"fixtures/.env", - BASE_PATH/"tmp/.env") - - with EnvFile(BASE_PATH/"tmp") as env: - env.append_if_new("ENV_VAR1", "value100") # Should not be updated - env.append_if_new("ENV_VAR100", "value2") # Should be added - - tmp_data = open(BASE_PATH/"tmp/.env").read() + os.makedirs(BASE_PATH / "tmp", exist_ok=True) + shutil.copy(BASE_PATH / "fixtures/.env", BASE_PATH / "tmp/.env") + + with EnvFile(BASE_PATH / "tmp") as env: + env.append_if_new("ENV_VAR1", "value100") # Should not be updated + env.append_if_new("ENV_VAR100", "value2") # Should be added + + tmp_data = open(BASE_PATH / "tmp/.env").read() assert tmp_data == """\nENV_VAR1=value1\nENV_VAR2=value2\nENV_VAR100=value2""" except Exception as e: raise e finally: os.remove(BASE_PATH / "tmp/.env") - #os.rmdir(BASE_PATH / "tmp") - + # os.rmdir(BASE_PATH / "tmp") diff --git a/tests/test_generation_tasks.py b/tests/test_generation_tasks.py index 036af6a7..430a3695 100644 --- a/tests/test_generation_tasks.py +++ b/tests/test_generation_tasks.py @@ -12,36 +12,36 @@ BASE_PATH = Path(__file__).parent -@parameterized_class([ - {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS -]) +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) class TestGenerationAgent(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH/'tmp'/'agent_generation' - + self.project_dir = BASE_PATH / 'tmp' / 'agent_generation' + os.makedirs(self.project_dir) - os.makedirs(self.project_dir/'src') - os.makedirs(self.project_dir/'src'/'config') - (self.project_dir/'src'/'__init__.py').touch() - + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'config') + (self.project_dir / 'src' / '__init__.py').touch() + # populate the entrypoint entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) - shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) - + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + # set the framework in agentstack.json - shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json') + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') with ConfigFile(self.project_dir) as config: config.framework = self.framework - + def tearDown(self): shutil.rmtree(self.project_dir) - + def test_add_task(self): - add_task('task_test_two', - description='description', - expected_output='expected_output', - agent='agent', - path=self.project_dir) + add_task( + 'task_test_two', + description='description', + expected_output='expected_output', + agent='agent', + path=self.project_dir, + ) entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) entrypoint_src = open(entrypoint_path).read() @@ -53,8 +53,10 @@ def test_add_task(self): def test_add_agent_exists(self): with self.assertRaises(SystemExit) as context: - add_task('test_task', - description='description', - expected_output='expected_output', - agent='agent', - path=self.project_dir) \ No newline at end of file + add_task( + 'test_task', + description='description', + expected_output='expected_output', + agent='agent', + path=self.project_dir, + ) diff --git a/tests/test_generation_tool.py b/tests/test_generation_tool.py index a44d68a2..d2122689 100644 --- a/tests/test_generation_tool.py +++ b/tests/test_generation_tool.py @@ -13,60 +13,58 @@ BASE_PATH = Path(__file__).parent + # TODO parameterize all tools -@parameterized_class([ - {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS -]) +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) class TestGenerationTool(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH/'tmp'/'tool_generation' - + self.project_dir = BASE_PATH / 'tmp' / 'tool_generation' + os.makedirs(self.project_dir) - os.makedirs(self.project_dir/'src') - os.makedirs(self.project_dir/'src'/'tools') - (self.project_dir/'src'/'__init__.py').touch() - (self.project_dir/TOOLS_INIT_FILENAME).touch() - + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'tools') + (self.project_dir / 'src' / '__init__.py').touch() + (self.project_dir / TOOLS_INIT_FILENAME).touch() + # populate the entrypoint entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) - shutil.copy(BASE_PATH/f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) - + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + # set the framework in agentstack.json - shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json') + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') with ConfigFile(self.project_dir) as config: config.framework = self.framework - + def tearDown(self): shutil.rmtree(self.project_dir) - + def test_add_tool(self): tool_conf = ToolConfig.from_tool_name('agent_connect') add_tool('agent_connect', path=self.project_dir) - + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) entrypoint_src = open(entrypoint_path).read() ast.parse(entrypoint_src) - tools_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() - + tools_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() + # TODO verify tool is added to all agents (this is covered in test_frameworks.py) - #assert 'agent_connect' in entrypoint_src + # assert 'agent_connect' in entrypoint_src assert f'from .{tool_conf.module_name} import' in tools_init_src - assert (self.project_dir/'src'/'tools'/f'{tool_conf.module_name}.py').exists() - assert 'agent_connect' in open(self.project_dir/'agentstack.json').read() - + assert (self.project_dir / 'src' / 'tools' / f'{tool_conf.module_name}.py').exists() + assert 'agent_connect' in open(self.project_dir / 'agentstack.json').read() + def test_remove_tool(self): tool_conf = ToolConfig.from_tool_name('agent_connect') add_tool('agent_connect', path=self.project_dir) remove_tool('agent_connect', path=self.project_dir) - + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) entrypoint_src = open(entrypoint_path).read() ast.parse(entrypoint_src) - tools_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() - + tools_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() + # TODO verify tool is removed from all agents (this is covered in test_frameworks.py) - #assert 'agent_connect' not in entrypoint_src + # assert 'agent_connect' not in entrypoint_src assert f'from .{tool_conf.module_name} import' not in tools_init_src - assert not (self.project_dir/'src'/'tools'/f'{tool_conf.module_name}.py').exists() - assert 'agent_connect' not in open(self.project_dir/'agentstack.json').read() - + assert not (self.project_dir / 'src' / 'tools' / f'{tool_conf.module_name}.py').exists() + assert 'agent_connect' not in open(self.project_dir / 'agentstack.json').read() diff --git a/tests/test_tasks_config.py b/tests/test_tasks_config.py index 26cafa20..c95665bc 100644 --- a/tests/test_tasks_config.py +++ b/tests/test_tasks_config.py @@ -8,31 +8,32 @@ BASE_PATH = Path(__file__).parent + class AgentConfigTest(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH/'tmp/task_config' - os.makedirs(self.project_dir/'src/config') - + self.project_dir = BASE_PATH / 'tmp/task_config' + os.makedirs(self.project_dir / 'src/config') + def tearDown(self): shutil.rmtree(self.project_dir) - + def test_empty_file(self): config = TaskConfig("task_name", self.project_dir) assert config.name == "task_name" assert config.description is "" assert config.expected_output is "" assert config.agent is "" - + def test_read_minimal_yaml(self): - shutil.copy(BASE_PATH/"fixtures/tasks_min.yaml", self.project_dir/TASKS_FILENAME) + shutil.copy(BASE_PATH / "fixtures/tasks_min.yaml", self.project_dir / TASKS_FILENAME) config = TaskConfig("task_name", self.project_dir) assert config.name == "task_name" assert config.description is "" assert config.expected_output is "" assert config.agent is "" - + def test_read_maximal_yaml(self): - shutil.copy(BASE_PATH/"fixtures/tasks_max.yaml", self.project_dir/TASKS_FILENAME) + shutil.copy(BASE_PATH / "fixtures/tasks_max.yaml", self.project_dir / TASKS_FILENAME) config = TaskConfig("task_name", self.project_dir) assert config.name == "task_name" assert config.description == "Add your description here" @@ -44,9 +45,11 @@ def test_write_yaml(self): config.description = "Add your description here" config.expected_output = "Add your expected output here" config.agent = "default_agent" - - yaml_src = open(self.project_dir/TASKS_FILENAME).read() - assert yaml_src == """task_name: + + yaml_src = open(self.project_dir / TASKS_FILENAME).read() + assert ( + yaml_src + == """task_name: description: >- Add your description here expected_output: >- @@ -54,16 +57,20 @@ def test_write_yaml(self): agent: >- default_agent """ + ) def test_write_none_values(self): with TaskConfig("task_name", self.project_dir) as config: config.description = None config.expected_output = None config.agent = None - - yaml_src = open(self.project_dir/TASKS_FILENAME).read() - assert yaml_src == """task_name: + + yaml_src = open(self.project_dir / TASKS_FILENAME).read() + assert ( + yaml_src + == """task_name: description: > expected_output: > agent: > -""" \ No newline at end of file +""" + ) diff --git a/tests/test_tool_config.py b/tests/test_tool_config.py index 894406af..c56d3b30 100644 --- a/tests/test_tool_config.py +++ b/tests/test_tool_config.py @@ -7,6 +7,7 @@ BASE_PATH = Path(__file__).parent + class ToolConfigTest(unittest.TestCase): def test_minimal_json(self): config = ToolConfig.from_json(BASE_PATH / "fixtures/tool_config_min.json") @@ -20,7 +21,7 @@ def test_minimal_json(self): assert config.packages is None assert config.post_install is None assert config.post_remove is None - + def test_maximal_json(self): config = ToolConfig.from_json(BASE_PATH / "fixtures/tool_config_max.json") assert config.name == "tool_name" @@ -33,7 +34,7 @@ def test_maximal_json(self): assert config.packages == ["package1", "package2"] assert config.post_install == "install.sh" assert config.post_remove == "remove.sh" - + def test_all_json_configs_from_tool_name(self): for tool_name in get_all_tool_names(): config = ToolConfig.from_tool_name(tool_name) @@ -45,7 +46,9 @@ def test_all_json_configs_from_tool_path(self): try: config = ToolConfig.from_json(path) except json.decoder.JSONDecodeError as e: - raise Exception(f"Failed to decode tool json at {path}. Does your tool config fit the required formatting? https://github.com/AgentOps-AI/AgentStack/blob/main/agentstack/tools/~README.md") + raise Exception( + f"Failed to decode tool json at {path}. Does your tool config fit the required formatting? https://github.com/AgentOps-AI/AgentStack/blob/main/agentstack/tools/~README.md" + ) assert config.name == path.stem # We can assume that pydantic validation caught any other issues diff --git a/tests/test_tool_generation_init.py b/tests/test_tool_generation_init.py index 6df0ccae..7bb79580 100644 --- a/tests/test_tool_generation_init.py +++ b/tests/test_tool_generation_init.py @@ -13,67 +13,68 @@ BASE_PATH = Path(__file__).parent -@parameterized_class([ - {"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS -]) + +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) class TestToolGenerationInit(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH/'tmp'/'tool_generation_init' + self.project_dir = BASE_PATH / 'tmp' / 'tool_generation_init' os.makedirs(self.project_dir) - os.makedirs(self.project_dir/'src') - os.makedirs(self.project_dir/'src'/'tools') - (self.project_dir/'src'/'__init__.py').touch() - (self.project_dir/'src'/'tools'/'__init__.py').touch() - shutil.copy(BASE_PATH/'fixtures'/'agentstack.json', self.project_dir/'agentstack.json') + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'tools') + (self.project_dir / 'src' / '__init__.py').touch() + (self.project_dir / 'src' / 'tools' / '__init__.py').touch() + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') # set the framework in agentstack.json with ConfigFile(self.project_dir) as config: config.framework = self.framework - + def tearDown(self): shutil.rmtree(self.project_dir) - + def _get_test_tool(self) -> ToolConfig: return ToolConfig(name='test_tool', category='test', tools=['test_tool']) - + def _get_test_tool_alt(self) -> ToolConfig: return ToolConfig(name='test_tool_alt', category='test', tools=['test_tool_alt']) - + def test_tools_init_file(self): - tools_init = ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) + tools_init = ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) # file is empty assert tools_init.get_import_for_tool(self._get_test_tool()) == None - + def test_tools_init_file_missing(self): with self.assertRaises(ValidationError) as context: - tools_init = ToolsInitFile(self.project_dir/'missing') + tools_init = ToolsInitFile(self.project_dir / 'missing') def test_tools_init_file_add_import(self): tool = self._get_test_tool() - with ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) as tools_init: + with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: tools_init.add_import_for_tool(self.framework, tool) - - tool_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() + + tool_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() assert tool.get_import_statement(self.framework) in tool_init_src - + def test_tools_init_file_add_import_multiple(self): tool = self._get_test_tool() tool_alt = self._get_test_tool_alt() - with ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) as tools_init: + with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: tools_init.add_import_for_tool(self.framework, tool) - - with ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) as tools_init: + + with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: tools_init.add_import_for_tool(self.framework, tool_alt) - + # Should not be able to re-add a tool import with self.assertRaises(ValidationError) as context: - with ToolsInitFile(self.project_dir/TOOLS_INIT_FILENAME) as tools_init: + with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: tools_init.add_import_for_tool(self.framework, tool) - - tool_init_src = open(self.project_dir/TOOLS_INIT_FILENAME).read() + + tool_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() assert tool.get_import_statement(self.framework) in tool_init_src assert tool_alt.get_import_statement(self.framework) in tool_init_src # TODO this might be a little too strict - assert tool_init_src == """ + assert ( + tool_init_src + == """ from .test_tool_tool import test_tool from .test_tool_alt_tool import test_tool_alt""" - + ) From 406d90ffd9d1e2b134a095c26034df97a77ea525 Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Fri, 6 Dec 2024 09:55:49 -0800 Subject: [PATCH 11/14] ruff format frameworks branch --- agentstack/agents.py | 46 +++--- agentstack/cli/agentstack_data.py | 45 +++--- agentstack/cli/cli.py | 164 ++++++++++++---------- agentstack/frameworks/crewai.py | 93 +++++++----- agentstack/generation/agent_generation.py | 22 +-- agentstack/generation/asttools.py | 47 ++++--- agentstack/generation/files.py | 52 ++++--- agentstack/generation/gen_utils.py | 37 ++--- agentstack/generation/task_generation.py | 20 +-- agentstack/generation/tool_generation.py | 46 +++--- agentstack/logger.py | 4 +- agentstack/main.py | 11 +- agentstack/packaging.py | 3 + agentstack/tasks.py | 46 +++--- agentstack/telemetry.py | 17 ++- agentstack/tools.py | 11 +- agentstack/update.py | 22 ++- agentstack/utils.py | 19 ++- 18 files changed, 404 insertions(+), 301 deletions(-) diff --git a/agentstack/agents.py b/agentstack/agents.py index 4ebb27d0..5d9a8bc3 100644 --- a/agentstack/agents.py +++ b/agentstack/agents.py @@ -12,18 +12,19 @@ yaml = YAML() yaml.preserve_quotes = True # Preserve quotes in existing data + class AgentConfig(pydantic.BaseModel): """ Interface for interacting with an agent configuration. - - Multiple agents are stored in a single YAML file, so we always look up the + + Multiple agents are stored in a single YAML file, so we always look up the requested agent by `name`. - + Use it as a context manager to make and save edits: ```python with AgentConfig('agent_name') as config: config.llm = "openai/gpt-4o" - + Config Schema ------------- name: str @@ -38,6 +39,7 @@ class AgentConfig(pydantic.BaseModel): The model this agent should use. Adheres to the format set by the framework. """ + name: str role: Optional[str] = "" goal: Optional[str] = "" @@ -45,14 +47,15 @@ class AgentConfig(pydantic.BaseModel): llm: Optional[str] = "" def __init__(self, name: str, path: Optional[Path] = None): - if not path: path = Path() - - if not os.path.exists(path/AGENTS_FILENAME): - os.makedirs((path/AGENTS_FILENAME).parent, exist_ok=True) - (path/AGENTS_FILENAME).touch() - + if not path: + path = Path() + + if not os.path.exists(path / AGENTS_FILENAME): + os.makedirs((path / AGENTS_FILENAME).parent, exist_ok=True) + (path / AGENTS_FILENAME).touch() + try: - with open(path/AGENTS_FILENAME, 'r') as f: + with open(path / AGENTS_FILENAME, 'r') as f: data = yaml.load(f) or {} data = data.get(name, {}) or {} super().__init__(**{**{'name': name}, **data}) @@ -64,26 +67,29 @@ def __init__(self, name: str, path: Optional[Path] = None): for error in e.errors(): error_str += f"{' '.join(error['loc'])}: {error['msg']}\n" raise ValidationError(f"Error loading agent {name} from {filename}.\n{error_str}") - + # store the path *after* loading data self._path = path def model_dump(self, *args, **kwargs) -> dict: dump = super().model_dump(*args, **kwargs) - dump.pop('name') # name is the key, so keep it out of the data + dump.pop('name') # name is the key, so keep it out of the data # format these as FoldedScalarStrings for key in ('role', 'goal', 'backstory'): dump[key] = FoldedScalarString(dump.get(key) or "") return {self.name: dump} def write(self): - with open(self._path/AGENTS_FILENAME, 'r') as f: + with open(self._path / AGENTS_FILENAME, 'r') as f: data = yaml.load(f) or {} - + data.update(self.model_dump()) - - with open(self._path/AGENTS_FILENAME, 'w') as f: + + with open(self._path / AGENTS_FILENAME, 'w') as f: yaml.dump(data, f) - - def __enter__(self) -> 'AgentConfig': return self - def __exit__(self, *args): self.write() + + def __enter__(self) -> 'AgentConfig': + return self + + def __exit__(self, *args): + self.write() diff --git a/agentstack/cli/agentstack_data.py b/agentstack/cli/agentstack_data.py index 4acd39c3..2eda10aa 100644 --- a/agentstack/cli/agentstack_data.py +++ b/agentstack/cli/agentstack_data.py @@ -7,17 +7,18 @@ class ProjectMetadata: - def __init__(self, - project_name: str = None, - project_slug: str = None, - description: str = "", - author_name: str = "", - version: str = "", - license: str = "", - year: int = datetime.now().year, - template: str = "none", - template_version: str = "0", - ): + def __init__( + self, + project_name: str = None, + project_slug: str = None, + description: str = "", + author_name: str = "", + version: str = "", + license: str = "", + year: int = datetime.now().year, + template: str = "none", + template_version: str = "0", + ): self.project_name = clean_input(project_name) if project_name else "myagent" self.project_slug = clean_input(project_slug) if project_slug else self.project_name self.description = description @@ -76,10 +77,11 @@ def to_json(self): class FrameworkData: - def __init__(self, - # name: Optional[Literal["crewai"]] = None - name: str = None # TODO: better framework handling, Literal or Enum - ): + def __init__( + self, + # name: Optional[Literal["crewai"]] = None + name: str = None, # TODO: better framework handling, Literal or Enum + ): self.name = name def to_dict(self): @@ -92,12 +94,13 @@ def to_json(self): class CookiecutterData: - def __init__(self, - project_metadata: ProjectMetadata, - structure: ProjectStructure, - # framework: Literal["crewai"], - framework: str, - ): + def __init__( + self, + project_metadata: ProjectMetadata, + structure: ProjectStructure, + # framework: Literal["crewai"], + framework: str, + ): self.project_metadata = project_metadata self.framework = framework self.structure = structure diff --git a/agentstack/cli/cli.py b/agentstack/cli/cli.py index a62df66e..c05d0239 100644 --- a/agentstack/cli/cli.py +++ b/agentstack/cli/cli.py @@ -34,7 +34,10 @@ 'anthropic/claude-3-opus', ] -def init_project_builder(slug_name: Optional[str] = None, template: Optional[str] = None, use_wizard: bool = False): + +def init_project_builder( + slug_name: Optional[str] = None, template: Optional[str] = None, use_wizard: bool = False +): if slug_name and not is_snake_case(slug_name): print(term_color("Project name must be snake case", 'red')) return @@ -46,16 +49,23 @@ def init_project_builder(slug_name: Optional[str] = None, template: Optional[str template_data = None if template is not None: url_start = "https://" - if template[:len(url_start)] == url_start: + if template[: len(url_start)] == url_start: # template is a url response = requests.get(template) if response.status_code == 200: template_data = response.json() else: - print(term_color(f"Failed to fetch template data from {template}. Status code: {response.status_code}", 'red')) + print( + term_color( + f"Failed to fetch template data from {template}. Status code: {response.status_code}", + 'red', + ) + ) sys.exit(1) else: - with importlib.resources.path('agentstack.templates.proj_templates', f'{template}.json') as template_path: + with importlib.resources.path( + 'agentstack.templates.proj_templates', f'{template}.json' + ) as template_path: if template_path is None: print(term_color(f"No such template {template} found", 'red')) sys.exit(1) @@ -67,7 +77,7 @@ def init_project_builder(slug_name: Optional[str] = None, template: Optional[str "version": "0.0.1", "description": template_data['description'], "author": "Name ", - "license": "MIT" + "license": "MIT", } framework = template_data['framework'] design = { @@ -93,24 +103,16 @@ def init_project_builder(slug_name: Optional[str] = None, template: Optional[str "version": "0.0.1", "description": "New agentstack project", "author": "Name ", - "license": "MIT" + "license": "MIT", } framework = "crewai" # TODO: if --no-wizard, require a framework flag - design = { - 'agents': [], - 'tasks': [], - 'inputs': [] - } + design = {'agents': [], 'tasks': [], 'inputs': []} tools = [] - log.debug( - f"project_details: {project_details}" - f"framework: {framework}" - f"design: {design}" - ) + log.debug(f"project_details: {project_details}" f"framework: {framework}" f"design: {design}") insert_template(project_details, framework, design, template_data) for tool_data in tools: generation.add_tool(tool_data['name'], agents=tool_data['agents'], path=project_details['name']) @@ -118,7 +120,12 @@ def init_project_builder(slug_name: Optional[str] = None, template: Optional[str try: packaging.install(f'{AGENTSTACK_PACKAGE}[{framework}]', path=slug_name) except Exception as e: - print(term_color(f"Failed to install dependencies for {slug_name}. Please try again by running `agentstack update`", 'red')) + print( + term_color( + f"Failed to install dependencies for {slug_name}. Please try again by running `agentstack update`", + 'red', + ) + ) def welcome_message(): @@ -138,8 +145,8 @@ def configure_default_model(path: Optional[str] = None): """Set the default model""" agentstack_config = ConfigFile(path) if agentstack_config.default_model: - return # Default model already set - + return # Default model already set + print("Project does not have a default model configured.") other_msg = f"Other (enter a model name)" model = inquirer.list_input( @@ -147,10 +154,10 @@ def configure_default_model(path: Optional[str] = None): choices=PREFERRED_MODELS + [other_msg], ) - if model == other_msg: # If the user selects "Other", prompt for a model name + if model == other_msg: # If the user selects "Other", prompt for a model name print(f'A list of available models is available at: "https://docs.litellm.ai/docs/providers"') model = inquirer.text(message="Enter the model name") - + with ConfigFile(path) as agentstack_config: agentstack_config.default_model = model @@ -160,7 +167,7 @@ def run_project(framework: str, path: str = ''): if not framework in frameworks.SUPPORTED_FRAMEWORKS: print(term_color(f"Framework {framework} is not supported by agentstack.", 'red')) sys.exit(1) - + try: frameworks.validate_project(framework, path) except frameworks.ValidationError as e: @@ -169,7 +176,7 @@ def run_project(framework: str, path: str = ''): sys.exit(1) path = Path(path) - entrypoint = path/frameworks.get_entrypoint_path(framework) + entrypoint = path / frameworks.get_entrypoint_path(framework) os.system(f'python {entrypoint}') @@ -205,10 +212,7 @@ def ask_design() -> dict: ) if not use_wizard: - return { - 'agents': [], - 'tasks': [] - } + return {'agents': [], 'tasks': []} os.system("cls" if os.name == "nt" else "clear") @@ -230,18 +234,24 @@ def ask_design() -> dict: agent_incomplete = True agent = None while agent_incomplete: - agent = inquirer.prompt([ - inquirer.Text("name", message="What's the name of this agent? (snake_case)"), - inquirer.Text("role", message="What role does this agent have?"), - inquirer.Text("goal", message="What is the goal of the agent?"), - inquirer.Text("backstory", message="Give your agent a backstory"), - # TODO: make a list - #2 - inquirer.Text('model', message="What LLM should this agent use? (any LiteLLM provider)", default="openai/gpt-4"), - # inquirer.List("model", message="What LLM should this agent use? (any LiteLLM provider)", choices=[ - # 'mixtral_llm', - # 'mixtral_llm', - # ]), - ]) + agent = inquirer.prompt( + [ + inquirer.Text("name", message="What's the name of this agent? (snake_case)"), + inquirer.Text("role", message="What role does this agent have?"), + inquirer.Text("goal", message="What is the goal of the agent?"), + inquirer.Text("backstory", message="Give your agent a backstory"), + # TODO: make a list - #2 + inquirer.Text( + 'model', + message="What LLM should this agent use? (any LiteLLM provider)", + default="openai/gpt-4", + ), + # inquirer.List("model", message="What LLM should this agent use? (any LiteLLM provider)", choices=[ + # 'mixtral_llm', + # 'mixtral_llm', + # ]), + ] + ) if not agent['name'] or agent['name'] == '': print(term_color("Error: Agent name is required - Try again", 'red')) @@ -273,14 +283,21 @@ def ask_design() -> dict: task_incomplete = True task = None while task_incomplete: - task = inquirer.prompt([ - inquirer.Text("name", message="What's the name of this task? (snake_case)"), - inquirer.Text("description", message="Describe the task in more detail"), - inquirer.Text("expected_output", - message="What do you expect the result to look like? (ex: A 5 bullet point summary of the email)"), - inquirer.List("agent", message="Which agent should be assigned this task?", - choices=[a['name'] for a in agents], ), - ]) + task = inquirer.prompt( + [ + inquirer.Text("name", message="What's the name of this task? (snake_case)"), + inquirer.Text("description", message="Describe the task in more detail"), + inquirer.Text( + "expected_output", + message="What do you expect the result to look like? (ex: A 5 bullet point summary of the email)", + ), + inquirer.List( + "agent", + message="Which agent should be assigned this task?", + choices=[a['name'] for a in agents], + ), + ] + ) if not task['name'] or task['name'] == '': print(term_color("Error: Task name is required - Try again", 'red')) @@ -319,17 +336,13 @@ def ask_tools() -> list: tools_data = open_json_file(tools_json_path) while adding_tools: - tool_type = inquirer.list_input( message="What category tool do you want to add?", - choices=list(tools_data.keys()) + ["~~ Stop adding tools ~~"] + choices=list(tools_data.keys()) + ["~~ Stop adding tools ~~"], ) tools_in_cat = [f"{t['name']} - {t['url']}" for t in tools_data[tool_type] if t not in tools_to_add] - tool_selection = inquirer.list_input( - message="Select your tool", - choices=tools_in_cat - ) + tool_selection = inquirer.list_input(message="Select your tool", choices=tools_in_cat) tools_to_add.append(tool_selection.split(' - ')[0]) @@ -349,36 +362,42 @@ def ask_project_details(slug_name: Optional[str] = None) -> dict: print(term_color("Project name must be snake case", 'red')) return ask_project_details(slug_name) - questions = inquirer.prompt([ - inquirer.Text("version", message="What's the initial version", default="0.1.0"), - inquirer.Text("description", message="Enter a description for your project"), - inquirer.Text("author", message="Who's the author (your name)?"), - ]) + questions = inquirer.prompt( + [ + inquirer.Text("version", message="What's the initial version", default="0.1.0"), + inquirer.Text("description", message="Enter a description for your project"), + inquirer.Text("author", message="Who's the author (your name)?"), + ] + ) questions['name'] = name return questions -def insert_template(project_details: dict, framework_name: str, design: dict, template_data: Optional[dict] = None): +def insert_template( + project_details: dict, framework_name: str, design: dict, template_data: Optional[dict] = None +): framework = FrameworkData(framework_name.lower()) - project_metadata = ProjectMetadata(project_name=project_details["name"], - description=project_details["description"], - author_name=project_details["author"], - version="0.0.1", - license="MIT", - year=datetime.now().year, - template=template_data['name'] if template_data else None, - template_version=template_data['template_version'] if template_data else None) + project_metadata = ProjectMetadata( + project_name=project_details["name"], + description=project_details["description"], + author_name=project_details["author"], + version="0.0.1", + license="MIT", + year=datetime.now().year, + template=template_data['name'] if template_data else None, + template_version=template_data['template_version'] if template_data else None, + ) project_structure = ProjectStructure() project_structure.agents = design["agents"] project_structure.tasks = design["tasks"] project_structure.set_inputs(design["inputs"]) - cookiecutter_data = CookiecutterData(project_metadata=project_metadata, - structure=project_structure, - framework=framework_name.lower()) + cookiecutter_data = CookiecutterData( + project_metadata=project_metadata, structure=project_structure, framework=framework_name.lower() + ) template_path = get_package_path() / f'templates/{framework.name}' with open(f"{template_path}/cookiecutter.json", "w") as json_file: @@ -387,7 +406,8 @@ def insert_template(project_details: dict, framework_name: str, design: dict, te # copy .env.example to .env shutil.copy( f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env.example', - f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env') + f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env', + ) if os.path.isdir(project_details['name']): print(term_color(f"Directory {template_path} already exists. Please check this and try again", "red")) @@ -438,4 +458,4 @@ def list_tools(): print(f": {tool.url if tool.url else 'AgentStack default tool'}") print("\n\n✨ Add a tool with: agentstack tools add ") - print(" https://docs.agentstack.sh/tools/core") \ No newline at end of file + print(" https://docs.agentstack.sh/tools/core") diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index c189312f..2c0c3f1b 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -10,16 +10,18 @@ ENTRYPOINT: Path = Path('src/crew.py') + class CrewFile(asttools.File): """ Parses and manipulates the CrewAI entrypoint file. All AST interactions should happen within the methods of this class. """ + _base_class: ast.ClassDef = None def get_base_class(self) -> ast.ClassDef: """A base class is a class decorated with `@CrewBase`.""" - if self._base_class is None: # Gets cached to save repeat iteration + if self._base_class is None: # Gets cached to save repeat iteration try: self._base_class = asttools.find_class_with_decorator(self.tree, 'CrewBase')[0] except IndexError: @@ -32,7 +34,9 @@ def get_crew_method(self) -> ast.FunctionDef: base_class = self.get_base_class() return asttools.find_decorated_method_in_class(base_class, 'crew')[0] except IndexError: - raise ValidationError(f"`@crew` decorated method not found in `{base_class.name}` class in {ENTRYPOINT}") + raise ValidationError( + f"`@crew` decorated method not found in `{base_class.name}` class in {ENTRYPOINT}" + ) def get_task_methods(self) -> list[ast.FunctionDef]: """A `task` method is a method decorated with `@task`.""" @@ -51,7 +55,7 @@ def add_task_method(self, task: TaskConfig): # Add before the `crew` method crew_method = self.get_crew_method() pos, _ = self.get_node_range(crew_method) - + code = f""" @task def {task.name}(self) -> Task: return Task( @@ -81,7 +85,7 @@ def add_agent_method(self, agent: AgentConfig): # Add before the `crew` method crew_method = self.get_crew_method() pos, _ = self.get_node_range(crew_method) - + code = f""" @agent def {agent.name}(self) -> Agent: return Agent( @@ -98,7 +102,7 @@ def {agent.name}(self) -> Agent: def get_agent_tools(self, agent_name: str) -> ast.List: """ Get the tools used by an agent as AST nodes. - + Tool definitons are inside of the methods marked with an `@agent` decorator. The method returns a new class instance with the tools as a list of callables under the kwarg `tools`. @@ -106,21 +110,25 @@ def get_agent_tools(self, agent_name: str) -> ast.List: method = asttools.find_method(self.get_agent_methods(), agent_name) if method is None: raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") - + agent_class = asttools.find_class_instantiation(method, 'Agent') if agent_class is None: - raise ValidationError(f"`@agent` method `{agent_name}` does not have an `Agent` class instantiation in {ENTRYPOINT}") - + raise ValidationError( + f"`@agent` method `{agent_name}` does not have an `Agent` class instantiation in {ENTRYPOINT}" + ) + tools_kwarg = asttools.find_kwarg_in_method_call(agent_class, 'tools') if not tools_kwarg: - raise ValidationError(f"`@agent` method `{agent_name}` does not have a keyword argument `tools` in {ENTRYPOINT}") + raise ValidationError( + f"`@agent` method `{agent_name}` does not have a keyword argument `tools` in {ENTRYPOINT}" + ) return tools_kwarg.value def add_agent_tools(self, agent_name: str, tool: ToolConfig): """ Add new tools to be used by an agent. - + Tool definitons are inside of the methods marked with an `@agent` decorator. The method returns a new class instance with the tools as a list of callables under the kwarg `tools`. @@ -128,30 +136,27 @@ def add_agent_tools(self, agent_name: str, tool: ToolConfig): method = asttools.find_method(self.get_agent_methods(), agent_name) if method is None: raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") - + new_tool_nodes = [] for tool_name in tool.tools: # This prefixes the tool name with the 'tools' module node = asttools.create_attribute('tools', tool_name) - if tool.tools_bundled: # Splat the variable if it's bundled + if tool.tools_bundled: # Splat the variable if it's bundled node = ast.Starred(value=node, ctx=ast.Load()) new_tool_nodes.append(node) - + existing_node: ast.List = self.get_agent_tools(agent_name) - new_node = ast.List( - elts=set(existing_node.elts + new_tool_nodes), - ctx=ast.Load() - ) + new_node = ast.List(elts=set(existing_node.elts + new_tool_nodes), ctx=ast.Load()) start, end = self.get_node_range(existing_node) self.edit_node_range(start, end, new_node) - + def remove_agent_tools(self, agent_name: str, tool: ToolConfig): """ Remove tools from an agent belonging to `tool`. """ existing_node: ast.List = self.get_agent_tools(agent_name) start, end = self.get_node_range(existing_node) - + # modify the existing node to remove any matching tools for tool_name in tool.tools: for node in existing_node.elts: @@ -161,7 +166,7 @@ def remove_agent_tools(self, agent_name: str, tool: ToolConfig): attr_name = node.attr if attr_name == tool_name: existing_node.elts.remove(node) - + self.edit_node_range(start, end, existing_node) @@ -170,9 +175,10 @@ def validate_project(path: Optional[Path] = None) -> None: Validate that a CrewAI project is ready to run. Raises an `agentstack.VaidationError` if the project is not valid. """ - if path is None: path = Path() + if path is None: + path = Path() try: - crew_file = CrewFile(path/ENTRYPOINT) + crew_file = CrewFile(path / ENTRYPOINT) except ValidationError as e: raise e @@ -192,60 +198,73 @@ def validate_project(path: Optional[Path] = None) -> None: if len(crew_file.get_task_methods()) < 1: raise ValidationError( f"`@task` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" - "Create a new task using `agentstack generate task `.") + "Create a new task using `agentstack generate task `." + ) # The Crew class must have one or more methods decorated with `@agent` if len(crew_file.get_agent_methods()) < 1: raise ValidationError( f"`@agent` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" - "Create a new agent using `agentstack generate agent `.") + "Create a new agent using `agentstack generate agent `." + ) + def get_task_names(path: Optional[Path] = None) -> list[str]: """ Get a list of task names (methods with an @task decorator). """ - if path is None: path = Path() - crew_file = CrewFile(path/ENTRYPOINT) + if path is None: + path = Path() + crew_file = CrewFile(path / ENTRYPOINT) return [method.name for method in crew_file.get_task_methods()] + def add_task(task: TaskConfig, path: Optional[Path] = None) -> None: """ Add a task method to the CrewAI entrypoint. """ - if path is None: path = Path() - with CrewFile(path/ENTRYPOINT) as crew_file: + if path is None: + path = Path() + with CrewFile(path / ENTRYPOINT) as crew_file: crew_file.add_task_method(task) + def get_agent_names(path: Optional[Path] = None) -> list[str]: """ Get a list of agent names (methods with an @agent decorator). """ - if path is None: path = Path() - crew_file = CrewFile(path/ENTRYPOINT) + if path is None: + path = Path() + crew_file = CrewFile(path / ENTRYPOINT) return [method.name for method in crew_file.get_agent_methods()] + def add_agent(agent: AgentConfig, path: Optional[Path] = None) -> None: """ Add an agent method to the CrewAI entrypoint. """ - if path is None: path = Path() - with CrewFile(path/ENTRYPOINT) as crew_file: + if path is None: + path = Path() + with CrewFile(path / ENTRYPOINT) as crew_file: crew_file.add_agent_method(agent) + def add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): """ Add a tool to the CrewAI entrypoint for the specified agent. The agent should already exist in the crew class and have a keyword argument `tools`. """ - if path is None: path = Path() - with CrewFile(path/ENTRYPOINT) as crew_file: + if path is None: + path = Path() + with CrewFile(path / ENTRYPOINT) as crew_file: crew_file.add_agent_tools(agent_name, tool) + def remove_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): """ Remove a tool from the CrewAI framework for the specified agent. """ - if path is None: path = Path() - with CrewFile(path/ENTRYPOINT) as crew_file: + if path is None: + path = Path() + with CrewFile(path / ENTRYPOINT) as crew_file: crew_file.remove_agent_tools(agent_name, tool) - diff --git a/agentstack/generation/agent_generation.py b/agentstack/generation/agent_generation.py index e330e2d9..2bc03157 100644 --- a/agentstack/generation/agent_generation.py +++ b/agentstack/generation/agent_generation.py @@ -9,14 +9,15 @@ def add_agent( - agent_name: str, - role: Optional[str] = None, - goal: Optional[str] = None, - backstory: Optional[str] = None, - llm: Optional[str] = None, - path: Optional[Path] = None): - - if path is None: path = Path() + agent_name: str, + role: Optional[str] = None, + goal: Optional[str] = None, + backstory: Optional[str] = None, + llm: Optional[str] = None, + path: Optional[Path] = None, +): + if path is None: + path = Path() verify_agentstack_project(path) agentstack_config = ConfigFile(path) framework = agentstack_config.framework @@ -27,13 +28,12 @@ def add_agent( config.goal = goal or "Add your goal here" config.backstory = backstory or "Add your backstory here" config.llm = llm or agentstack_config.default_model - + try: frameworks.add_agent(framework, agent, path) print(f" > Added to {AGENTS_FILENAME}") except ValidationError as e: print(f"Error adding agent to project:\n{e}") sys.exit(1) - - print(f"Added agent \"{agent_name}\" to your AgentStack project successfully!") + print(f"Added agent \"{agent_name}\" to your AgentStack project successfully!") diff --git a/agentstack/generation/asttools.py b/agentstack/generation/asttools.py index 9599d132..68ac5220 100644 --- a/agentstack/generation/asttools.py +++ b/agentstack/generation/asttools.py @@ -1,13 +1,14 @@ """ -Tools for working with ASTs. +Tools for working with ASTs. -We include convenience functions here based on real needs inside the codebase, +We include convenience functions here based on real needs inside the codebase, such as finding a method definition in a class, or finding a method by its decorator. It's not optimal to have a fully-featured set of functions as this would be unwieldy, but since our use-cases are well-defined, we can provide a set of functions that are useful for the specific tasks we need to accomplish. """ + from typing import Optional, Union from pathlib import Path import ast @@ -19,22 +20,23 @@ class File: """ Parses and manipulates a Python source file with an AST. - + Use it as a context manager to make and save edits: ```python with File(filename) as f: f.edit_node_range(start, end, new_node) ``` - + Lookups are done using the built-in `ast` module, which we only use to find - and read nodes in the tree. - + and read nodes in the tree. + Edits are done using string indexing on the source code, which preserves a majority of the original formatting and prevents comments from being lost. - + In cases where we are constructing new AST nodes, we use `astor` to render - the node as source code. + the node as source code. """ + filename: Path = None source: str = None atok: asttokens.ASTTokens = None @@ -44,7 +46,7 @@ def __init__(self, filename: Path): self.filename = filename self.read() - def read(self): + def read(self): try: with open(self.filename, 'r') as f: self.source = f.read() @@ -64,19 +66,19 @@ def get_node_range(self, node: ast.AST) -> tuple[int, int]: def edit_node_range(self, start: int, end: int, node: Union[str, ast.AST]): """Splice a new node or string into the source code at the given range.""" if isinstance(node, ast.AST): - module = ast.Module( - body=[ast.Expr(value=node)], - type_ignores=[] - ) + module = ast.Module(body=[ast.Expr(value=node)], type_ignores=[]) node = astor.to_source(module).strip() - + self.source = self.source[:start] + node + self.source[end:] # In order to continue accurately modifying the AST, we need to re-parse the source. self.atok = asttokens.ASTTokens(self.source, parse=True) self.tree = self.atok.tree - def __enter__(self) -> 'File': return self - def __exit__(self, *args): self.write() + def __enter__(self) -> 'File': + return self + + def __exit__(self, *args): + self.write() def get_all_imports(tree: ast.AST) -> list[Union[ast.Import, ast.ImportFrom]]: @@ -87,6 +89,7 @@ def get_all_imports(tree: ast.AST) -> list[Union[ast.Import, ast.ImportFrom]]: imports.append(node) return imports + def find_method(tree: Union[list[ast.AST], ast.AST], method_name: str) -> Optional[ast.FunctionDef]: """Find a method definition in an AST.""" if not isinstance(tree, list): @@ -96,6 +99,7 @@ def find_method(tree: Union[list[ast.AST], ast.AST], method_name: str) -> Option return node return None + def find_kwarg_in_method_call(node: ast.Call, kwarg_name: str) -> Optional[ast.keyword]: """Find a keyword argument in a method call or class instantiation.""" for arg in node.keywords: @@ -103,6 +107,7 @@ def find_kwarg_in_method_call(node: ast.Call, kwarg_name: str) -> Optional[ast.k return arg return None + def find_class_instantiation(tree: Union[list[ast.AST], ast.AST], class_name: str) -> Optional[ast.Call]: """ Find a class instantiation statement in an AST by the class name. @@ -120,6 +125,7 @@ def find_class_instantiation(tree: Union[list[ast.AST], ast.AST], class_name: st return node.value return None + def find_class_with_decorator(tree: ast.AST, decorator_name: str) -> list[ast.ClassDef]: """Find a class definition that is marked by a decorator in an AST.""" nodes = [] @@ -130,6 +136,7 @@ def find_class_with_decorator(tree: ast.AST, decorator_name: str) -> list[ast.Cl nodes.append(node) return nodes + def find_decorated_method_in_class(classdef: ast.ClassDef, decorator_name: str) -> list[ast.FunctionDef]: """Find all method definitions in a class definition which are decorated with a specific decorator.""" nodes = [] @@ -140,11 +147,7 @@ def find_decorated_method_in_class(classdef: ast.ClassDef, decorator_name: str) nodes.append(node) return nodes + def create_attribute(base_name: str, attr_name: str) -> ast.Attribute: """Create an AST node for an attribute""" - return ast.Attribute( - value=ast.Name(id=base_name, ctx=ast.Load()), - attr=attr_name, - ctx=ast.Load() - ) - + return ast.Attribute(value=ast.Name(id=base_name, ctx=ast.Load()), attr=attr_name, ctx=ast.Load()) diff --git a/agentstack/generation/files.py b/agentstack/generation/files.py index b1c226c3..513dba17 100644 --- a/agentstack/generation/files.py +++ b/agentstack/generation/files.py @@ -9,20 +9,21 @@ CONFIG_FILENAME = "agentstack.json" ENV_FILEMANE = ".env" + class ConfigFile(BaseModel): """ Interface for interacting with the agentstack.json file inside a project directory. Handles both data validation and file I/O. - + `path` is the directory where the agentstack.json file is located. Defaults to the current working directory. - + Use it as a context manager to make and save edits: ```python with ConfigFile() as config: config.tools.append('tool_name') ``` - + Config Schema ------------- framework: str @@ -30,15 +31,16 @@ class ConfigFile(BaseModel): tools: list[str] A list of tools that are currently installed in the project. telemetry_opt_out: Optional[bool] - Whether the user has opted out of telemetry. + Whether the user has opted out of telemetry. default_model: Optional[str] The default model to use when generating agent configurations. """ + framework: Optional[str] = DEFAULT_FRAMEWORK tools: list[str] = [] telemetry_opt_out: Optional[bool] = None default_model: Optional[str] = None - + def __init__(self, path: Union[str, Path, None] = None): path = Path(path) if path else Path.cwd() if os.path.exists(path / CONFIG_FILENAME): @@ -46,46 +48,50 @@ def __init__(self, path: Union[str, Path, None] = None): super().__init__(**json.loads(f.read())) else: raise FileNotFoundError(f"File {path / CONFIG_FILENAME} does not exist.") - self._path = path # attribute needs to be set after init + self._path = path # attribute needs to be set after init def model_dump(self, *args, **kwargs) -> dict: # Ignore None values dump = super().model_dump(*args, **kwargs) return {key: value for key, value in dump.items() if value is not None} - + def write(self): with open(self._path / CONFIG_FILENAME, 'w') as f: f.write(json.dumps(self.model_dump(), indent=4)) - - def __enter__(self) -> 'ConfigFile': return self - def __exit__(self, *args): self.write() + + def __enter__(self) -> 'ConfigFile': + return self + + def __exit__(self, *args): + self.write() class EnvFile: """ Interface for interacting with the .env file inside a project directory. - Unlike the ConfigFile, we do not re-write the entire file on every change, + Unlike the ConfigFile, we do not re-write the entire file on every change, and instead just append new lines to the end of the file. This preseres comments and other formatting that the user may have added and prevents opportunities for data loss. - + `path` is the directory where the .env file is located. Defaults to the current working directory. `filename` is the name of the .env file, defaults to '.env'. - + Use it as a context manager to make and save edits: ```python with EnvFile() as env: env.append_if_new('ENV_VAR', 'value') ``` """ + variables: dict[str, str] - + def __init__(self, path: Union[str, Path, None] = None, filename: str = ENV_FILEMANE): self._path = Path(path) if path else Path.cwd() self._filename = filename self.read() - + def __getitem__(self, key): return self.variables[key] @@ -93,15 +99,15 @@ def __setitem__(self, key, value): if key in self.variables: raise ValueError("EnvFile does not allow overwriting values.") self.append_if_new(key, value) - + def __contains__(self, key) -> bool: return key in self.variables - + def append_if_new(self, key, value): if not key in self.variables: self.variables[key] = value self._new_variables[key] = value - + def read(self): def parse_line(line): key, value = line.split('=') @@ -113,12 +119,14 @@ def parse_line(line): else: self.variables = {} self._new_variables = {} - + def write(self): with open(self._path / self._filename, 'a') as f: for key, value in self._new_variables.items(): f.write(f"\n{key}={value}") - - def __enter__(self) -> 'EnvFile': return self - def __exit__(self, *args): self.write() + def __enter__(self) -> 'EnvFile': + return self + + def __exit__(self, *args): + self.write() diff --git a/agentstack/generation/gen_utils.py b/agentstack/generation/gen_utils.py index dc9ac38c..72bf92a7 100644 --- a/agentstack/generation/gen_utils.py +++ b/agentstack/generation/gen_utils.py @@ -18,8 +18,10 @@ def insert_code_after_tag(file_path, tag, code_to_insert, next_line=False): for index, line in enumerate(lines): if tag in line: # Insert the code block after the tag - indented_code = [(line[:len(line)-len(line.lstrip())] + code_line + '\n') for code_line in code_to_insert] - lines[index+1:index+1] = indented_code + indented_code = [ + (line[: len(line) - len(line.lstrip())] + code_line + '\n') for code_line in code_to_insert + ] + lines[index + 1 : index + 1] = indented_code break else: raise ValueError(f"Tag '{tag}' not found in the file.") @@ -38,8 +40,9 @@ def insert_after_tasks(file_path, code_to_insert): last_task_end = None last_task_start = None for node in ast.walk(module): - if isinstance(node, ast.FunctionDef) and \ - any(isinstance(deco, ast.Name) and deco.id == 'task' for deco in node.decorator_list): + if isinstance(node, ast.FunctionDef) and any( + isinstance(deco, ast.Name) and deco.id == 'task' for deco in node.decorator_list + ): last_task_end = node.end_lineno last_task_start = node.lineno @@ -80,9 +83,9 @@ class CrewComponent(str, Enum): def get_crew_components( - framework: str = 'crewai', - component_type: Optional[Union[CrewComponent, List[CrewComponent]]] = None, - path: str = '' + framework: str = 'crewai', + component_type: Optional[Union[CrewComponent, List[CrewComponent]]] = None, + path: str = '', ) -> dict[str, List[str]]: """ Get names of components (agents and/or tasks) defined in a crew file. @@ -98,7 +101,7 @@ def get_crew_components( Dictionary with 'agents' and 'tasks' keys containing lists of names """ path = Path(path) - filename = path/frameworks.get_entrypoint_path(framework) + filename = path / frameworks.get_entrypoint_path(framework) # Convert single component type to list for consistent handling if isinstance(component_type, CrewComponent): @@ -111,10 +114,7 @@ def get_crew_components( # Parse the source into an AST tree = ast.parse(source) - components = { - 'agents': [], - 'tasks': [] - } + components = {'agents': [], 'tasks': []} # Find all function definitions with relevant decorators for node in ast.walk(tree): @@ -122,16 +122,17 @@ def get_crew_components( # Check decorators for decorator in node.decorator_list: if isinstance(decorator, ast.Name): - if (component_type is None or CrewComponent.AGENT in component_type) \ - and decorator.id == 'agent': + if ( + component_type is None or CrewComponent.AGENT in component_type + ) and decorator.id == 'agent': components['agents'].append(node.name) - elif (component_type is None or CrewComponent.TASK in component_type) \ - and decorator.id == 'task': + elif ( + component_type is None or CrewComponent.TASK in component_type + ) and decorator.id == 'task': components['tasks'].append(node.name) # If specific types were requested, only return those if component_type: - return {k: v for k, v in components.items() - if CrewComponent(k[:-1]) in component_type} + return {k: v for k, v in components.items() if CrewComponent(k[:-1]) in component_type} return components diff --git a/agentstack/generation/task_generation.py b/agentstack/generation/task_generation.py index 0b5113dd..a6e1d662 100644 --- a/agentstack/generation/task_generation.py +++ b/agentstack/generation/task_generation.py @@ -9,23 +9,24 @@ def add_task( - task_name: str, - description: Optional[str] = None, - expected_output: Optional[str] = None, - agent: Optional[str] = None, - path: Optional[Path] = None): - - if path is None: path = Path() + task_name: str, + description: Optional[str] = None, + expected_output: Optional[str] = None, + agent: Optional[str] = None, + path: Optional[Path] = None, +): + if path is None: + path = Path() verify_agentstack_project(path) agentstack_config = ConfigFile(path) framework = agentstack_config.framework - + task = TaskConfig(task_name, path) with task as config: config.description = description or "Add your description here" config.expected_output = expected_output or "Add your expected_output here" config.agent = agent or "agent_name" - + try: frameworks.add_task(framework, task, path) print(f" > Added to {TASKS_FILENAME}") @@ -33,4 +34,3 @@ def add_task( print(f"Error adding task to project:\n{e}") sys.exit(1) print(f"Added task \"{task_name}\" to your AgentStack project successfully!") - diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index 3cdf23a5..81237dcc 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -17,16 +17,18 @@ # This is the filename of the location of tool imports in the user's project. TOOLS_INIT_FILENAME: Path = Path("src/tools/__init__.py") + class ToolsInitFile(asttools.File): """ Modifiable AST representation of the tools init file. - + Use it as a context manager to make and save edits: ```python with ToolsInitFile(filename) as tools_init: tools_init.add_import_for_tool(...) ``` """ + def get_import_for_tool(self, tool: ToolConfig) -> Union[ast.Import, ast.ImportFrom]: """ Get the import statement for a tool. @@ -34,10 +36,10 @@ def get_import_for_tool(self, tool: ToolConfig) -> Union[ast.Import, ast.ImportF """ all_imports = asttools.get_all_imports(self.tree) tool_imports = [i for i in all_imports if tool.module_name == i.module] - + if len(tool_imports) > 1: raise ValidationError(f"Multiple imports for tool {tool.name} found in {self.filename}") - + try: return tool_imports[0] except IndexError: @@ -56,7 +58,7 @@ def add_import_for_tool(self, framework: str, tool: ToolConfig): last_import = asttools.get_all_imports(self.tree)[-1] start, end = self.get_node_range(last_import) except IndexError: - start, end = 0, 0 # No imports in the file + start, end = 0, 0 # No imports in the file import_statement = tool.get_import_statement(framework) self.edit_node_range(end, end, f"\n{import_statement}") @@ -75,7 +77,8 @@ def remove_import_for_tool(self, framework: str, tool: ToolConfig): def add_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None): - if path is None: path = Path() + if path is None: + path = Path() agentstack_config = ConfigFile(path) framework = agentstack_config.framework @@ -88,23 +91,23 @@ def add_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Pa if tool.packages: packaging.install(' '.join(tool.packages)) - + # Move tool from package to project - shutil.copy(tool_file_path, path/f'src/tools/{tool.module_name}.py') + shutil.copy(tool_file_path, path / f'src/tools/{tool.module_name}.py') - try: # Edit the user's project tool init file to include the tool - with ToolsInitFile(path/TOOLS_INIT_FILENAME) as tools_init: + try: # Edit the user's project tool init file to include the tool + with ToolsInitFile(path / TOOLS_INIT_FILENAME) as tools_init: tools_init.add_import_for_tool(framework, tool) except ValidationError as e: print(term_color(f"Error adding tool:\n{e}", 'red')) # Edit the framework entrypoint file to include the tool in the agent definition - if not agents: # If no agents are specified, add the tool to all agents + if not agents: # If no agents are specified, add the tool to all agents agents = frameworks.get_agent_names(framework, path) for agent_name in agents: frameworks.add_tool(framework, tool, agent_name, path) - if tool.env: # add environment variables which don't exist + if tool.env: # add environment variables which don't exist with EnvFile(path) as env: for var, value in tool.env.items(): env.append_if_new(var, value) @@ -124,7 +127,8 @@ def add_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Pa def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None): - if path is None: path = Path() + if path is None: + path = Path() agentstack_config = ConfigFile(path) framework = agentstack_config.framework @@ -136,20 +140,20 @@ def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional if tool.packages: packaging.remove(' '.join(tool.packages)) - # TODO ensure that other agents in the project are not using the tool. + # TODO ensure that other agents in the project are not using the tool. try: - os.remove(path/f'src/tools/{tool.module_name}.py') + os.remove(path / f'src/tools/{tool.module_name}.py') except FileNotFoundError: print(f'"src/tools/{tool.module_name}.py" not found') - try: # Edit the user's project tool init file to exclude the tool - with ToolsInitFile(path/TOOLS_INIT_FILENAME) as tools_init: + try: # Edit the user's project tool init file to exclude the tool + with ToolsInitFile(path / TOOLS_INIT_FILENAME) as tools_init: tools_init.remove_import_for_tool(framework, tool) except ValidationError as e: print(term_color(f"Error removing tool:\n{e}", 'red')) # Edit the framework entrypoint file to exclude the tool in the agent definition - if not agents: # If no agents are specified, remove the tool from all agents + if not agents: # If no agents are specified, remove the tool from all agents agents = frameworks.get_agent_names(framework, path) for agent_name in agents: frameworks.remove_tool(framework, tool, agent_name, path) @@ -161,6 +165,8 @@ def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional with agentstack_config as config: config.tools.remove(tool.name) - print(term_color(f'🔨 Tool {tool_name}', 'green'), term_color('removed', 'red'), - term_color('from agentstack project successfully', 'green')) - + print( + term_color(f'🔨 Tool {tool_name}', 'green'), + term_color('removed', 'red'), + term_color('from agentstack project successfully', 'green'), + ) diff --git a/agentstack/logger.py b/agentstack/logger.py index 680d3067..e41c4887 100644 --- a/agentstack/logger.py +++ b/agentstack/logger.py @@ -16,7 +16,9 @@ def get_logger(name, debug=False): handler = logging.StreamHandler(sys.stdout) handler.setLevel(log_level) - formatter = logging.Formatter("%(asctime)s - %(process)d - %(threadName)s - %(filename)s:%(lineno)d - %(name)s - %(levelname)s - %(message)s") + formatter = logging.Formatter( + "%(asctime)s - %(process)d - %(threadName)s - %(filename)s:%(lineno)d - %(name)s - %(levelname)s - %(message)s" + ) handler.setFormatter(formatter) if not logger.handlers: diff --git a/agentstack/main.py b/agentstack/main.py index 15344c66..f3de95ce 100644 --- a/agentstack/main.py +++ b/agentstack/main.py @@ -10,6 +10,7 @@ import webbrowser + def main(): parser = argparse.ArgumentParser( description="AgentStack CLI - The easiest way to build an agent application" @@ -42,7 +43,9 @@ def main(): generate_parser = subparsers.add_parser('generate', aliases=['g'], help='Generate agents or tasks') # Subparsers under 'generate' - generate_subparsers = generate_parser.add_subparsers(dest='generate_command', help='Generate agents or tasks') + generate_subparsers = generate_parser.add_subparsers( + dest='generate_command', help='Generate agents or tasks' + ) # 'agent' command under 'generate' agent_parser = generate_subparsers.add_parser('agent', aliases=['a'], help='Generate an agent') @@ -71,7 +74,9 @@ def main(): # 'add' command under 'tools' tools_add_parser = tools_subparsers.add_parser('add', aliases=['a'], help='Add a new tool') tools_add_parser.add_argument('name', help='Name of the tool to add') - tools_add_parser.add_argument('--agents', '-a', help='Name of agents to add this tool to, comma separated') + tools_add_parser.add_argument( + '--agents', '-a', help='Name of agents to add this tool to, comma separated' + ) tools_add_parser.add_argument('--agent', help='Name of agent to add this tool to') # 'remove' command under 'tools' @@ -124,7 +129,7 @@ def main(): else: tools_parser.print_help() elif args.command in ['update', 'u']: - pass # Update check already done + pass # Update check already done else: parser.print_help() diff --git a/agentstack/packaging.py b/agentstack/packaging.py index 3e8a6adb..fb0e3cb5 100644 --- a/agentstack/packaging.py +++ b/agentstack/packaging.py @@ -3,13 +3,16 @@ PACKAGING_CMD = "poetry" + def install(package: str, path: Optional[str] = None): if path: os.chdir(path) os.system(f"{PACKAGING_CMD} add {package}") + def remove(package: str): os.system(f"{PACKAGING_CMD} remove {package}") + def upgrade(package: str): os.system(f"{PACKAGING_CMD} add {package}") diff --git a/agentstack/tasks.py b/agentstack/tasks.py index fbc75cdb..c50f9d42 100644 --- a/agentstack/tasks.py +++ b/agentstack/tasks.py @@ -12,18 +12,19 @@ yaml = YAML() yaml.preserve_quotes = True # Preserve quotes in existing data + class TaskConfig(pydantic.BaseModel): """ Interface for interacting with a task configuration. - - Multiple tasks are stored in a single YAML file, so we always look up the + + Multiple tasks are stored in a single YAML file, so we always look up the requested task by `name`. - + Use it as a context manager to make and save edits: ```python with TaskConfig('task_name') as config: config.description = "foo" - + Config Schema ------------- name: str @@ -35,20 +36,22 @@ class TaskConfig(pydantic.BaseModel): agent: Optional[str] The agent to use for the task. """ + name: str description: Optional[str] = "" expected_output: Optional[str] = "" agent: Optional[str] = "" def __init__(self, name: str, path: Optional[Path] = None): - if not path: path = Path() - - if not os.path.exists(path/TASKS_FILENAME): - os.makedirs((path/TASKS_FILENAME).parent, exist_ok=True) - (path/TASKS_FILENAME).touch() - + if not path: + path = Path() + + if not os.path.exists(path / TASKS_FILENAME): + os.makedirs((path / TASKS_FILENAME).parent, exist_ok=True) + (path / TASKS_FILENAME).touch() + try: - with open(path/TASKS_FILENAME, 'r') as f: + with open(path / TASKS_FILENAME, 'r') as f: data = yaml.load(f) or {} data = data.get(name, {}) or {} super().__init__(**{**{'name': name}, **data}) @@ -60,26 +63,29 @@ def __init__(self, name: str, path: Optional[Path] = None): for error in e.errors(): error_str += f"{' '.join(error['loc'])}: {error['msg']}\n" raise ValidationError(f"Error loading task {name} from {filename}.\n{error_str}") - + # store the path *after* loading data self._path = path def model_dump(self, *args, **kwargs) -> dict: dump = super().model_dump(*args, **kwargs) - dump.pop('name') # name is the key, so keep it out of the data + dump.pop('name') # name is the key, so keep it out of the data # format these as FoldedScalarStrings for key in ('description', 'expected_output', 'agent'): dump[key] = FoldedScalarString(dump.get(key) or "") return {self.name: dump} def write(self): - with open(self._path/TASKS_FILENAME, 'r') as f: + with open(self._path / TASKS_FILENAME, 'r') as f: data = yaml.load(f) or {} - + data.update(self.model_dump()) - - with open(self._path/TASKS_FILENAME, 'w') as f: + + with open(self._path / TASKS_FILENAME, 'w') as f: yaml.dump(data, f) - - def __enter__(self) -> 'AgentConfig': return self - def __exit__(self, *args): self.write() + + def __enter__(self) -> 'AgentConfig': + return self + + def __exit__(self, *args): + self.write() diff --git a/agentstack/telemetry.py b/agentstack/telemetry.py index db613c8c..2cb38016 100644 --- a/agentstack/telemetry.py +++ b/agentstack/telemetry.py @@ -32,6 +32,7 @@ TELEMETRY_URL = 'https://api.agentstack.sh/telemetry' + def collect_machine_telemetry(command: str): if command != "init" and get_telemetry_opt_out(): return @@ -43,7 +44,7 @@ def collect_machine_telemetry(command: str): 'os_version': platform.version(), 'cpu_count': psutil.cpu_count(logical=True), 'memory': psutil.virtual_memory().total, - 'agentstack_version': get_version() + 'agentstack_version': get_version(), } if command != "init": @@ -56,12 +57,14 @@ def collect_machine_telemetry(command: str): response = requests.get('https://ipinfo.io/json') if response.status_code == 200: location_data = response.json() - telemetry_data.update({ - 'ip': location_data.get('ip'), - 'city': location_data.get('city'), - 'region': location_data.get('region'), - 'country': location_data.get('country') - }) + telemetry_data.update( + { + 'ip': location_data.get('ip'), + 'city': location_data.get('city'), + 'region': location_data.get('region'), + 'country': location_data.get('country'), + } + ) except requests.RequestException as e: telemetry_data['location_error'] = str(e) diff --git a/agentstack/tools.py b/agentstack/tools.py index 209264f7..efda4d58 100644 --- a/agentstack/tools.py +++ b/agentstack/tools.py @@ -10,6 +10,7 @@ class ToolConfig(pydantic.BaseModel): This represents the configuration data for a tool. It parses and validates the `config.json` file for a tool. """ + name: str category: str tools: list[str] @@ -24,7 +25,7 @@ class ToolConfig(pydantic.BaseModel): @classmethod def from_tool_name(cls, name: str) -> 'ToolConfig': path = get_package_path() / f'tools/{name}.json' - if not os.path.exists(path): # TODO raise exceptions and handle message/exit in cli + if not os.path.exists(path): # TODO raise exceptions and handle message/exit in cli print(term_color(f'No known agentstack tool: {name}', 'red')) sys.exit(1) return cls.from_json(path) @@ -49,19 +50,21 @@ def get_import_statement(self, framework: str) -> str: return f"from .{self.module_name} import {', '.join(self.tools)}" def get_impl_file_path(self, framework: str) -> Path: - return get_package_path()/f'templates/{framework}/tools/{self.module_name}.py' + return get_package_path() / f'templates/{framework}/tools/{self.module_name}.py' + def get_all_tool_paths() -> list[Path]: paths = [] - tools_dir = get_package_path()/'tools' + tools_dir = get_package_path() / 'tools' for file in tools_dir.iterdir(): if file.is_file() and file.suffix == '.json': paths.append(file) return paths + def get_all_tool_names() -> list[str]: return [path.stem for path in get_all_tool_paths()] + def get_all_tools() -> list[ToolConfig]: return [ToolConfig.from_json(path) for path in get_all_tool_paths()] - diff --git a/agentstack/update.py b/agentstack/update.py index 1b1bd55d..145623fe 100644 --- a/agentstack/update.py +++ b/agentstack/update.py @@ -20,7 +20,7 @@ def _is_ci_environment(): 'TRAVIS', 'CIRCLECI', 'JENKINS_URL', - 'TEAMCITY_VERSION' + 'TEAMCITY_VERSION', ] return any(os.getenv(var) for var in ci_env_vars) @@ -45,7 +45,10 @@ def _is_ci_environment(): def get_latest_version(package: str) -> Version: """Get version information from PyPi to save a full package manager invocation""" import requests # defer import until we know we need it - response = requests.get(f"{ENDPOINT_URL}/{package}/", headers={"Accept": "application/vnd.pypi.simple.v1+json"}) + + response = requests.get( + f"{ENDPOINT_URL}/{package}/", headers={"Accept": "application/vnd.pypi.simple.v1+json"} + ) if response.status_code != 200: raise Exception(f"Failed to fetch package data from pypi.") data = response.json() @@ -116,14 +119,21 @@ def check_for_updates(update_requested: bool = False): installed_version: Version = parse_version(get_version(AGENTSTACK_PACKAGE)) if latest_version > installed_version: print('') # newline - if inquirer.confirm(f"New version of {AGENTSTACK_PACKAGE} available: {latest_version}! Do you want to install?"): + if inquirer.confirm( + f"New version of {AGENTSTACK_PACKAGE} available: {latest_version}! Do you want to install?" + ): packaging.upgrade(f'{AGENTSTACK_PACKAGE}[{get_framework()}]') - print(term_color(f"{AGENTSTACK_PACKAGE} updated. Re-run your command to use the latest version.", 'green')) + print( + term_color( + f"{AGENTSTACK_PACKAGE} updated. Re-run your command to use the latest version.", 'green' + ) + ) sys.exit(0) else: - print(term_color("Skipping update. Run `agentstack update` to install the latest version.", 'blue')) + print( + term_color("Skipping update. Run `agentstack update` to install the latest version.", 'blue') + ) else: print(f"{AGENTSTACK_PACKAGE} is up to date ({installed_version})") record_update_check() - diff --git a/agentstack/utils.py b/agentstack/utils.py index 75c347a1..e7a6a77e 100644 --- a/agentstack/utils.py +++ b/agentstack/utils.py @@ -9,6 +9,7 @@ from pathlib import Path import importlib.resources + def get_version(package: str = 'agentstack'): try: return version(package) @@ -19,12 +20,15 @@ def get_version(package: str = 'agentstack'): def verify_agentstack_project(path: Optional[str] = None): from agentstack.generation import ConfigFile + try: agentstack_config = ConfigFile(path) except FileNotFoundError: - print("\033[31mAgentStack Error: This does not appear to be an AgentStack project." - "\nPlease ensure you're at the root directory of your project and a file named agentstack.json exists. " - "If you're starting a new project, run `agentstack init`\033[0m") + print( + "\033[31mAgentStack Error: This does not appear to be an AgentStack project." + "\nPlease ensure you're at the root directory of your project and a file named agentstack.json exists. " + "If you're starting a new project, run `agentstack init`\033[0m" + ) sys.exit(1) @@ -37,6 +41,7 @@ def get_package_path() -> Path: def get_framework(path: Optional[str] = None) -> str: from agentstack.generation import ConfigFile + try: agentstack_config = ConfigFile(path) framework = agentstack_config.framework @@ -52,6 +57,7 @@ def get_framework(path: Optional[str] = None) -> str: def get_telemetry_opt_out(path: Optional[str] = None) -> str: from agentstack.generation import ConfigFile + try: agentstack_config = ConfigFile(path) return bool(agentstack_config.telemetry_opt_out) @@ -59,6 +65,7 @@ def get_telemetry_opt_out(path: Optional[str] = None) -> str: print("\033[31mFile agentstack.json does not exist. Are you in the right directory?\033[0m") sys.exit(1) + def camel_to_snake(name): s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() @@ -77,7 +84,7 @@ def open_json_file(path) -> dict: def open_yaml_file(path) -> dict: yaml = YAML() yaml.preserve_quotes = True # Preserve quotes in existing data - + with open(path, 'r') as f: data = yaml.load(f) return data @@ -96,7 +103,7 @@ def term_color(text: str, color: str) -> str: 'blue': '94', 'purple': '95', 'cyan': '96', - 'white': '97' + 'white': '97', } color_code = colors.get(color) if color_code: @@ -105,7 +112,5 @@ def term_color(text: str, color: str) -> str: return text - def is_snake_case(string: str): return bool(re.match('^[a-z0-9_]+$', string)) - From b90095fdced8495b99147269bcec175508839072 Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Fri, 6 Dec 2024 13:49:45 -0800 Subject: [PATCH 12/14] Type checking --- agentstack/agents.py | 19 +++--- agentstack/cli/cli.py | 11 ++-- agentstack/frameworks/__init__.py | 31 +++++++++- agentstack/frameworks/crewai.py | 25 +++++--- agentstack/generation/asttools.py | 74 ++++++++++++++++-------- agentstack/generation/gen_utils.py | 68 ---------------------- agentstack/generation/tool_generation.py | 9 ++- agentstack/tasks.py | 21 ++++--- agentstack/tools.py | 5 +- agentstack/utils.py | 2 +- 10 files changed, 132 insertions(+), 133 deletions(-) diff --git a/agentstack/agents.py b/agentstack/agents.py index 5d9a8bc3..8040f26f 100644 --- a/agentstack/agents.py +++ b/agentstack/agents.py @@ -1,5 +1,5 @@ from typing import Optional -import os, sys +import os from pathlib import Path import pydantic from ruamel.yaml import YAML, YAMLError @@ -50,12 +50,13 @@ def __init__(self, name: str, path: Optional[Path] = None): if not path: path = Path() - if not os.path.exists(path / AGENTS_FILENAME): - os.makedirs((path / AGENTS_FILENAME).parent, exist_ok=True) - (path / AGENTS_FILENAME).touch() + filename = path / AGENTS_FILENAME + if not os.path.exists(filename): + os.makedirs(filename.parent, exist_ok=True) + filename.touch() try: - with open(path / AGENTS_FILENAME, 'r') as f: + with open(filename, 'r') as f: data = yaml.load(f) or {} data = data.get(name, {}) or {} super().__init__(**{**{'name': name}, **data}) @@ -65,7 +66,7 @@ def __init__(self, name: str, path: Optional[Path] = None): except pydantic.ValidationError as e: error_str = "Error validating agent config:\n" for error in e.errors(): - error_str += f"{' '.join(error['loc'])}: {error['msg']}\n" + error_str += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n" raise ValidationError(f"Error loading agent {name} from {filename}.\n{error_str}") # store the path *after* loading data @@ -80,12 +81,14 @@ def model_dump(self, *args, **kwargs) -> dict: return {self.name: dump} def write(self): - with open(self._path / AGENTS_FILENAME, 'r') as f: + filename = self._path / AGENTS_FILENAME + + with open(filename, 'r') as f: data = yaml.load(f) or {} data.update(self.model_dump()) - with open(self._path / AGENTS_FILENAME, 'w') as f: + with open(filename, 'w') as f: yaml.dump(data, f) def __enter__(self) -> 'AgentConfig': diff --git a/agentstack/cli/cli.py b/agentstack/cli/cli.py index 3b9d196e..15b61676 100644 --- a/agentstack/cli/cli.py +++ b/agentstack/cli/cli.py @@ -162,7 +162,7 @@ def configure_default_model(path: Optional[str] = None): ) if model == other_msg: # If the user selects "Other", prompt for a model name - print(f'A list of available models is available at: "https://docs.litellm.ai/docs/providers"') + print('A list of available models is available at: "https://docs.litellm.ai/docs/providers"') model = inquirer.text(message="Enter the model name") with ConfigFile(path) as agentstack_config: @@ -171,19 +171,20 @@ def configure_default_model(path: Optional[str] = None): def run_project(framework: str, path: str = ''): """Validate that the project is ready to run and then run it.""" - if not framework in frameworks.SUPPORTED_FRAMEWORKS: + if framework not in frameworks.SUPPORTED_FRAMEWORKS: print(term_color(f"Framework {framework} is not supported by agentstack.", 'red')) sys.exit(1) + _path = Path(path) + try: - frameworks.validate_project(framework, path) + frameworks.validate_project(framework, _path) except frameworks.ValidationError as e: print(term_color("Project validation failed:", 'red')) print(e) sys.exit(1) - path = Path(path) - entrypoint = path / frameworks.get_entrypoint_path(framework) + entrypoint = _path / frameworks.get_entrypoint_path(framework) os.system(f'python {entrypoint}') diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index dab450f9..eb921735 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -27,7 +27,8 @@ `add_task(task: TaskConfig, path: Optional[Path] = None) -> None`: Add a task to the user's project. """ -from typing import Optional +from typing import Optional, Protocol +from types import ModuleType from importlib import import_module from pathlib import Path from agentstack import ValidationError @@ -39,7 +40,29 @@ CREWAI = 'crewai' SUPPORTED_FRAMEWORKS = [CREWAI, ] -def get_framework_module(framework: str) -> import_module: +class FrameworkModule(Protocol): + ENTRYPOINT: Path + + def validate_project(self, path: Optional[Path] = None) -> None: + ... + + def add_tool(self, tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None: + ... + + def remove_tool(self, tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None: + ... + + def get_agent_names(self, path: Optional[Path] = None) -> list[str]: + ... + + def add_agent(self, agent: AgentConfig, path: Optional[Path] = None) -> None: + ... + + def add_task(self, task: TaskConfig, path: Optional[Path] = None) -> None: + ... + + +def get_framework_module(framework: str) -> FrameworkModule: """ Get the module for a framework. """ @@ -52,7 +75,9 @@ def get_entrypoint_path(framework: str, path: Optional[Path] = None) -> Path: """ Get the path to the entrypoint file for a framework. """ - return path/get_framework_module(framework).ENTRYPOINT + if path is None: + path = Path() + return path / get_framework_module(framework).ENTRYPOINT def validate_project(framework: str, path: Optional[Path] = None): """ diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index 2c0c3f1b..7c7877fc 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -17,7 +17,7 @@ class CrewFile(asttools.File): All AST interactions should happen within the methods of this class. """ - _base_class: ast.ClassDef = None + _base_class: Optional[ast.ClassDef] = None def get_base_class(self) -> ast.ClassDef: """A base class is a class decorated with `@CrewBase`.""" @@ -123,6 +123,11 @@ def get_agent_tools(self, agent_name: str) -> ast.List: f"`@agent` method `{agent_name}` does not have a keyword argument `tools` in {ENTRYPOINT}" ) + if not isinstance(tools_kwarg.value, ast.List): + raise ValidationError( + f"`@agent` method `{agent_name}` has a non-list value for the `tools` kwarg in {ENTRYPOINT}" + ) + return tools_kwarg.value def add_agent_tools(self, agent_name: str, tool: ToolConfig): @@ -137,16 +142,17 @@ def add_agent_tools(self, agent_name: str, tool: ToolConfig): if method is None: raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") - new_tool_nodes = [] + new_tool_nodes: set[ast.expr] = set() for tool_name in tool.tools: # This prefixes the tool name with the 'tools' module - node = asttools.create_attribute('tools', tool_name) + node: ast.expr = asttools.create_attribute('tools', tool_name) if tool.tools_bundled: # Splat the variable if it's bundled node = ast.Starred(value=node, ctx=ast.Load()) - new_tool_nodes.append(node) + new_tool_nodes.add(node) existing_node: ast.List = self.get_agent_tools(agent_name) - new_node = ast.List(elts=set(existing_node.elts + new_tool_nodes), ctx=ast.Load()) + elts: set[ast.expr] = set(existing_node.elts) | new_tool_nodes + new_node = ast.List(elts=list(elts), ctx=ast.Load()) start, end = self.get_node_range(existing_node) self.edit_node_range(start, end, new_node) @@ -161,9 +167,14 @@ def remove_agent_tools(self, agent_name: str, tool: ToolConfig): for tool_name in tool.tools: for node in existing_node.elts: if isinstance(node, ast.Starred): - attr_name = node.value.attr - else: + if isinstance(node.value, ast.Attribute): + attr_name = node.value.attr + else: + continue # not an attribute node + elif isinstance(node, ast.Attribute): attr_name = node.attr + else: + continue # not an attribute node if attr_name == tool_name: existing_node.elts.remove(node) diff --git a/agentstack/generation/asttools.py b/agentstack/generation/asttools.py index 68ac5220..575e0403 100644 --- a/agentstack/generation/asttools.py +++ b/agentstack/generation/asttools.py @@ -9,7 +9,7 @@ functions that are useful for the specific tasks we need to accomplish. """ -from typing import Optional, Union +from typing import TypeVar, Optional, Union, Iterable from pathlib import Path import ast import astor @@ -17,6 +17,10 @@ from agentstack import ValidationError +FileT = TypeVar('FileT', bound='File') +ASTT = TypeVar('ASTT', bound=ast.AST) + + class File: """ Parses and manipulates a Python source file with an AST. @@ -37,10 +41,10 @@ class File: the node as source code. """ - filename: Path = None - source: str = None - atok: asttokens.ASTTokens = None - tree: ast.AST = None + filename: Path + source: str + atok: asttokens.ASTTokens + tree: ast.Module def __init__(self, filename: Path): self.filename = filename @@ -65,36 +69,45 @@ def get_node_range(self, node: ast.AST) -> tuple[int, int]: def edit_node_range(self, start: int, end: int, node: Union[str, ast.AST]): """Splice a new node or string into the source code at the given range.""" - if isinstance(node, ast.AST): + if isinstance(node, ast.expr): module = ast.Module(body=[ast.Expr(value=node)], type_ignores=[]) - node = astor.to_source(module).strip() + _node = astor.to_source(module).strip() + else: + _node = node - self.source = self.source[:start] + node + self.source[end:] + self.source = self.source[:start] + _node + self.source[end:] # In order to continue accurately modifying the AST, we need to re-parse the source. self.atok = asttokens.ASTTokens(self.source, parse=True) - self.tree = self.atok.tree - def __enter__(self) -> 'File': + if self.atok.tree: + self.tree = self.atok.tree + else: + raise ValidationError(f"Failed to parse {self.filename} after edit") + + def __enter__(self: FileT) -> FileT: return self def __exit__(self, *args): self.write() -def get_all_imports(tree: ast.AST) -> list[Union[ast.Import, ast.ImportFrom]]: +def get_all_imports(tree: ast.Module) -> list[ast.ImportFrom]: """Find all import statements in an AST.""" imports = [] for node in ast.iter_child_nodes(tree): - if isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom): + if isinstance(node, ast.ImportFrom): # NOTE must be in format `from x import y` imports.append(node) return imports -def find_method(tree: Union[list[ast.AST], ast.AST], method_name: str) -> Optional[ast.FunctionDef]: +def find_method(tree: Union[Iterable[ASTT], ASTT], method_name: str) -> Optional[ast.FunctionDef]: """Find a method definition in an AST.""" - if not isinstance(tree, list): - tree: generator = ast.iter_child_nodes(tree) - for node in tree: + if isinstance(tree, ast.AST): + _tree = list(ast.iter_child_nodes(tree)) + else: + _tree = list(tree) + + for node in _tree: if isinstance(node, ast.FunctionDef) and node.name == method_name: return node return None @@ -108,25 +121,36 @@ def find_kwarg_in_method_call(node: ast.Call, kwarg_name: str) -> Optional[ast.k return None -def find_class_instantiation(tree: Union[list[ast.AST], ast.AST], class_name: str) -> Optional[ast.Call]: +def find_class_instantiation(tree: Union[Iterable[ast.AST], ast.AST], class_name: str) -> Optional[ast.Call]: """ Find a class instantiation statement in an AST by the class name. This can either be an assignment to a variable or a return statement. """ - if not isinstance(tree, list): - tree: generator = ast.iter_child_nodes(tree) - for node in tree: + if isinstance(tree, ast.AST): + _tree = list(ast.iter_child_nodes(tree)) + else: + _tree = list(tree) + + for node in _tree: if isinstance(node, ast.Assign): for target in node.targets: - if isinstance(target, ast.Name) and target.id == class_name: + if ( + isinstance(target, ast.Name) + and isinstance(node.value, ast.Call) + and target.id == class_name + ): return node.value - elif isinstance(node, ast.Return): - if isinstance(node.value, ast.Call) and node.value.func.id == class_name: - return node.value + elif ( + isinstance(node, ast.Return) + and isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == class_name + ): + return node.value return None -def find_class_with_decorator(tree: ast.AST, decorator_name: str) -> list[ast.ClassDef]: +def find_class_with_decorator(tree: ast.Module, decorator_name: str) -> list[ast.ClassDef]: """Find a class definition that is marked by a decorator in an AST.""" nodes = [] for node in ast.iter_child_nodes(tree): diff --git a/agentstack/generation/gen_utils.py b/agentstack/generation/gen_utils.py index 72bf92a7..551e12cc 100644 --- a/agentstack/generation/gen_utils.py +++ b/agentstack/generation/gen_utils.py @@ -1,11 +1,4 @@ import ast -import sys -from enum import Enum -from typing import Optional, Union, List -from pathlib import Path - -from agentstack.utils import term_color -from agentstack import frameworks def insert_code_after_tag(file_path, tag, code_to_insert, next_line=False): @@ -75,64 +68,3 @@ def string_in_file(file_path: str, str_to_match: str) -> bool: with open(file_path, 'r') as file: file_content = file.read() return str_to_match in file_content - - -class CrewComponent(str, Enum): - AGENT = "agent" - TASK = "task" - - -def get_crew_components( - framework: str = 'crewai', - component_type: Optional[Union[CrewComponent, List[CrewComponent]]] = None, - path: str = '', -) -> dict[str, List[str]]: - """ - Get names of components (agents and/or tasks) defined in a crew file. - - Args: - framework: Name of the framework - component_type: Optional filter for specific component types. - Can be CrewComponentType.AGENT, CrewComponentType.TASK, - or a list of types. If None, returns all components. - path: Optional path to the framework file - - Returns: - Dictionary with 'agents' and 'tasks' keys containing lists of names - """ - path = Path(path) - filename = path / frameworks.get_entrypoint_path(framework) - - # Convert single component type to list for consistent handling - if isinstance(component_type, CrewComponent): - component_type = [component_type] - - # Read the source file - with open(filename, 'r') as f: - source = f.read() - - # Parse the source into an AST - tree = ast.parse(source) - - components = {'agents': [], 'tasks': []} - - # Find all function definitions with relevant decorators - for node in ast.walk(tree): - if isinstance(node, ast.FunctionDef): - # Check decorators - for decorator in node.decorator_list: - if isinstance(decorator, ast.Name): - if ( - component_type is None or CrewComponent.AGENT in component_type - ) and decorator.id == 'agent': - components['agents'].append(node.name) - elif ( - component_type is None or CrewComponent.TASK in component_type - ) and decorator.id == 'task': - components['tasks'].append(node.name) - - # If specific types were requested, only return those - if component_type: - return {k: v for k, v in components.items() if CrewComponent(k[:-1]) in component_type} - - return components diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index 14a73528..4ecb2b21 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -1,8 +1,8 @@ -import os, sys -from typing import Optional, Union, Any +import os +import sys +from typing import Optional from pathlib import Path import shutil -import fileinput import ast from agentstack import frameworks @@ -29,7 +29,7 @@ class ToolsInitFile(asttools.File): ``` """ - def get_import_for_tool(self, tool: ToolConfig) -> Union[ast.Import, ast.ImportFrom]: + def get_import_for_tool(self, tool: ToolConfig) -> Optional[ast.ImportFrom]: """ Get the import statement for a tool. raises a ValidationError if the tool is imported multiple times. @@ -170,4 +170,3 @@ def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional term_color('removed', 'red'), term_color('from agentstack project successfully', 'green'), ) - diff --git a/agentstack/tasks.py b/agentstack/tasks.py index c50f9d42..bad1d52e 100644 --- a/agentstack/tasks.py +++ b/agentstack/tasks.py @@ -1,5 +1,5 @@ from typing import Optional -import os, sys +import os from pathlib import Path import pydantic from ruamel.yaml import YAML, YAMLError @@ -46,12 +46,13 @@ def __init__(self, name: str, path: Optional[Path] = None): if not path: path = Path() - if not os.path.exists(path / TASKS_FILENAME): - os.makedirs((path / TASKS_FILENAME).parent, exist_ok=True) - (path / TASKS_FILENAME).touch() + filename = path / TASKS_FILENAME + if not os.path.exists(filename): + os.makedirs(filename.parent, exist_ok=True) + filename.touch() try: - with open(path / TASKS_FILENAME, 'r') as f: + with open(filename, 'r') as f: data = yaml.load(f) or {} data = data.get(name, {}) or {} super().__init__(**{**{'name': name}, **data}) @@ -61,7 +62,7 @@ def __init__(self, name: str, path: Optional[Path] = None): except pydantic.ValidationError as e: error_str = "Error validating tasks config:\n" for error in e.errors(): - error_str += f"{' '.join(error['loc'])}: {error['msg']}\n" + error_str += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n" raise ValidationError(f"Error loading task {name} from {filename}.\n{error_str}") # store the path *after* loading data @@ -76,15 +77,17 @@ def model_dump(self, *args, **kwargs) -> dict: return {self.name: dump} def write(self): - with open(self._path / TASKS_FILENAME, 'r') as f: + filename = self._path / TASKS_FILENAME + + with open(filename, 'r') as f: data = yaml.load(f) or {} data.update(self.model_dump()) - with open(self._path / TASKS_FILENAME, 'w') as f: + with open(filename, 'w') as f: yaml.dump(data, f) - def __enter__(self) -> 'AgentConfig': + def __enter__(self) -> 'TaskConfig': return self def __exit__(self, *args): diff --git a/agentstack/tools.py b/agentstack/tools.py index efda4d58..1acb8d97 100644 --- a/agentstack/tools.py +++ b/agentstack/tools.py @@ -1,5 +1,6 @@ from typing import Optional -import os, sys +import os +import sys from pathlib import Path import pydantic from agentstack.utils import get_package_path, open_json_file, term_color @@ -39,7 +40,7 @@ def from_json(cls, path: Path) -> 'ToolConfig': # TODO raise exceptions and handle message/exit in cli print(term_color(f"Error validating tool config JSON: \n{path}", 'red')) for error in e.errors(): - print(f"{' '.join(error['loc'])}: {error['msg']}") + print(f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}") sys.exit(1) @property diff --git a/agentstack/utils.py b/agentstack/utils.py index b7af638c..de008489 100644 --- a/agentstack/utils.py +++ b/agentstack/utils.py @@ -16,7 +16,7 @@ def get_version(package: str = 'agentstack'): return "Unknown version" -def verify_agentstack_project(path: Optional[str] = None): +def verify_agentstack_project(path: Optional[Path] = None): from agentstack.generation import ConfigFile try: From de72af074d6ed15a802a87c6d5d592b714b042a5 Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Fri, 6 Dec 2024 13:49:57 -0800 Subject: [PATCH 13/14] Rollback agentops version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 356192a0..1767a397 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ - "agentops>=0.3.19", + "agentops>=0.3.18", "typer>=0.12.5", "inquirer>=3.4.0", "art>=6.3", From 6d3d63a5bfee6bc0bac761bfa79a99c9f529fa75 Mon Sep 17 00:00:00 2001 From: Travis Dent Date: Mon, 9 Dec 2024 08:47:34 -0800 Subject: [PATCH 14/14] Move frameworks module docs into protocol spec --- agentstack/frameworks/__init__.py | 55 +++++++++++++++---------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index eb921735..fef72bc6 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -1,32 +1,3 @@ -""" -Methods for interacting with framework-specific features. - -Each framework should have a module in the `frameworks` package which defines the -following methods: - -`ENTRYPOINT`: Path: - Relative path to the entrypoint file for the framework in the user's project. - ie. `src/crewai.py` - -`validate_project(path: Optional[Path] = None) -> None`: - Validate that a user's project is ready to run. - Raises a `ValidationError` if the project is not valid. - -`add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None`: - Add a tool to an agent in the user's project. - -`remove_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None`: - Remove a tool from an agent in user's project. - -`get_agent_names(path: Optional[Path] = None) -> list[str]`: - Get a list of agent names in the user's project. - -`add_agent(agent: AgentConfig, path: Optional[Path] = None) -> None`: - Add an agent to the user's project. - -`add_task(task: TaskConfig, path: Optional[Path] = None) -> None`: - Add a task to the user's project. -""" from typing import Optional, Protocol from types import ModuleType from importlib import import_module @@ -41,24 +12,50 @@ SUPPORTED_FRAMEWORKS = [CREWAI, ] class FrameworkModule(Protocol): + """ + Protocol spec for a framework implementation module. + """ ENTRYPOINT: Path + """ + Relative path to the entrypoint file for the framework in the user's project. + ie. `src/crewai.py` + """ def validate_project(self, path: Optional[Path] = None) -> None: + """ + Validate that a user's project is ready to run. + Raises a `ValidationError` if the project is not valid. + """ ... def add_tool(self, tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None: + """ + Add a tool to an agent in the user's project. + """ ... def remove_tool(self, tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None: + """ + Remove a tool from an agent in user's project. + """ ... def get_agent_names(self, path: Optional[Path] = None) -> list[str]: + """ + Get a list of agent names in the user's project. + """ ... def add_agent(self, agent: AgentConfig, path: Optional[Path] = None) -> None: + """ + Add an agent to the user's project. + """ ... def add_task(self, task: TaskConfig, path: Optional[Path] = None) -> None: + """ + Add a task to the user's project. + """ ...