diff --git a/codesage/config/governance.py b/codesage/config/governance.py
index b3f708c..3f5cc43 100644
--- a/codesage/config/governance.py
+++ b/codesage/config/governance.py
@@ -1,5 +1,26 @@
 from pydantic import BaseModel, Field
-from typing import Literal
+from typing import Literal, Dict
+
+
+class ValidationConfig(BaseModel):
+    # Commands for syntax checking (linting)
+    # Use {file} as placeholder
+    syntax_commands: Dict[str, str] = Field(
+        default_factory=lambda: {
+            "python": "python -m py_compile {file}",
+            "go": "go vet {file}",
+        },
+        description="Commands to check syntax for different languages."
+    )
+    # Commands for running tests
+    # Use {scope} as placeholder, which might be a file or a package
+    test_commands: Dict[str, str] = Field(
+        default_factory=lambda: {
+            "python": "pytest {scope}",
+            "go": "go test {scope}",
+        },
+        description="Commands to run tests for different languages."
+    )
 
 
 class GovernanceConfig(BaseModel):
@@ -8,6 +29,8 @@ class GovernanceConfig(BaseModel):
     group_by: Literal["rule", "file", "risk_level"] = Field("rule", description="How to group governance tasks.")
     prioritization_strategy: Literal["risk_first", "issue_count_first"] = Field("risk_first", description="Strategy to prioritize governance tasks.")
 
+    validation: ValidationConfig = Field(default_factory=ValidationConfig, description="Validation settings.")
+
     @classmethod
     def default(cls) -> "GovernanceConfig":
         return cls()
diff --git a/codesage/governance/patch_manager.py b/codesage/governance/patch_manager.py
index 34e603b..49beb68 100644
--- a/codesage/governance/patch_manager.py
+++ b/codesage/governance/patch_manager.py
@@ -103,3 +103,26 @@ def restore_backup(self, file_path: str | Path) -> bool:
         except Exception as e:
             logger.error("Failed to restore backup", file_path=str(path), error=str(e))
             return False
+
+    def revert(self, file_path: str | Path) -> bool:
+        """
+        Alias for restore_backup, used for semantic clarity during rollback.
+        """
+        return self.restore_backup(file_path)
+
+    def cleanup_backup(self, file_path: str | Path) -> bool:
+        """
+        Removes the backup file if it exists.
+        """
+        path = Path(file_path)
+        backup_path = path.with_suffix(path.suffix + ".bak")
+
+        if backup_path.exists():
+            try:
+                backup_path.unlink()
+                logger.info("Backup cleaned up", backup_path=str(backup_path))
+                return True
+            except Exception as e:
+                logger.error("Failed to cleanup backup", backup_path=str(backup_path), error=str(e))
+                return False
+        return True
diff --git a/codesage/governance/sandbox.py b/codesage/governance/sandbox.py
new file mode 100644
index 0000000..0b5eed5
--- /dev/null
+++ b/codesage/governance/sandbox.py
@@ -0,0 +1,53 @@
+import subprocess
+import os
+import shlex
+import structlog
+from typing import Dict, Optional, Tuple
+
+logger = structlog.get_logger()
+
+class Sandbox:
+    def __init__(self, timeout: int = 30):
+        self.timeout = timeout
+
+    def run(self, command: str | list[str], env: Optional[Dict[str, str]] = None, cwd: Optional[str] = None) -> Tuple[bool, str]:
+        """
+        Runs a command in a subprocess, never through a shell.
+        A string command is tokenised with shlex.split; a list is used as argv as-is.
+        Returns (success, output): success is True only for exit code 0, and output is
+        stdout followed by stderr, or an error message on timeout/launch failure.
+        """
+        try:
+            # NOTE: the child inherits the FULL parent environment plus the overrides
+            # in `env`; this is not environment isolation.
+            run_env = os.environ.copy()
+            if env:
+                run_env.update(env)
+
+            # Never hand a string to the shell: split it into argv instead.
+            if isinstance(command, str):
+                args = shlex.split(command)
+            else:
+                args = command
+
+            result = subprocess.run(
+                args,
+                shell=False,
+                capture_output=True,
+                text=True,
+                timeout=self.timeout,
+                env=run_env,
+                cwd=cwd
+            )
+
+            output = result.stdout + result.stderr
+            if result.returncode != 0:
+                return False, output
+            return True, output
+
+        except subprocess.TimeoutExpired:
+            logger.error("Sandbox execution timed out", command=command)
+            return False, "Execution timed out"
+        except Exception as e:
+            logger.error("Sandbox execution failed", command=command, error=str(e))
+            return False, str(e)
diff --git a/codesage/governance/task_orchestrator.py b/codesage/governance/task_orchestrator.py
index f0f7309..f623cb0 100644
--- a/codesage/governance/task_orchestrator.py
+++ b/codesage/governance/task_orchestrator.py
@@ -6,17 +6,26 @@
 from codesage.governance.task_models import GovernancePlan, GovernanceTask
 from codesage.llm.client import BaseLLMClient, LLMRequest
 from codesage.governance.patch_manager import PatchManager
+from codesage.governance.validator import CodeValidator
+from codesage.config.governance import GovernanceConfig
 
 logger = structlog.get_logger()
 
 RISK_LEVEL_MAP = {"low": 1, "medium": 2, "high": 3, "unknown": 0}
 
 class TaskOrchestrator:
-    def __init__(self, plan: GovernancePlan, llm_client: Optional[BaseLLMClient] = None) -> None:
+    def __init__(
+        self,
+        plan: GovernancePlan,
+        llm_client: Optional[BaseLLMClient] = None,
+        config: Optional[GovernanceConfig] = None
+    ) -> None:
         self._plan = plan
         self._all_tasks: List[GovernanceTask] = self._flatten_tasks()
         self.llm_client = llm_client
         self.patch_manager = PatchManager()
+        self.config = config or GovernanceConfig.default()
+        self.validator = CodeValidator(self.config)
 
     def _flatten_tasks(self) -> List[GovernanceTask]:
         """Extracts and flattens all tasks from the plan's groups."""
@@ -63,9 +72,10 @@ def select_tasks(
 
         return filtered_tasks
 
-    def execute_task(self, task: GovernanceTask, apply_fix: bool = False) -> bool:
+    def execute_task(self, task: GovernanceTask, apply_fix: bool = False, max_retries: int = 3) -> bool:
         """
         Executes a governance task using the LLM client and optionally applies the fix.
+        Includes a validation loop with rollback and retry.
         """
         if not self.llm_client:
             logger.warning("LLM client not configured, skipping execution", task_id=task.id)
@@ -73,56 +83,83 @@ def execute_task(self, task: GovernanceTask, apply_fix: bool = False) -> bool:
 
         logger.info("Executing task", task_id=task.id, file=task.file_path)
 
-        # 1. Prepare context and prompt
-        # Assuming task.context contains necessary info or we read file
         file_path = Path(task.file_path)
         if not file_path.exists():
             logger.error("File not found", file_path=str(file_path))
             return False
 
-        file_content = file_path.read_text(encoding="utf-8")
+        original_content = file_path.read_text(encoding="utf-8")
 
-        # Construct a prompt (This logic might be moved to a PromptBuilder later)
-        prompt = (
+        # Initial Prompt
+        base_prompt = (
             f"Fix the following issue in {task.file_path}:\n"
-            f"Issue: {task.issue_type} - {task.message}\n"
-            f"Severity: {task.severity}\n\n"
+            f"Issue: {task.rule_id} - {task.description}\n"
+            f"Severity: {task.risk_level}\n\n"
             f"Here is the file content:\n"
-            f"```\n{file_content}\n```\n\n"
+            f"```\n{original_content}\n```\n\n"
             f"Please provide the FULL corrected file content in a markdown code block."
         )
 
-        # 2. Call LLM
-        request = LLMRequest(
-            prompt=prompt,
-            metadata={"task_id": task.id, "file_path": task.file_path}
-        )
+        current_prompt = base_prompt
+        attempts = 0
 
-        try:
-            response = self.llm_client.generate(request)
-        except Exception as e:
-            logger.error("LLM generation failed", error=str(e))
-            return False
+        while attempts <= max_retries:
+            # 1. Call LLM
+            request = LLMRequest(
+                prompt=current_prompt,
+                metadata={"task_id": task.id, "file_path": task.file_path, "attempt": attempts}
+            )
 
-        # 3. Extract Code
-        new_content = self.patch_manager.extract_code_block(response.content)
-        if not new_content:
-            logger.error("Failed to extract code from LLM response")
-            return False
+            try:
+                response = self.llm_client.generate(request)
+            except Exception as e:
+                logger.error("LLM generation failed", error=str(e))
+                return False
 
-        # 4. Apply Fix if requested
-        if apply_fix:
-            success = self.patch_manager.apply_patch(file_path, new_content)
-            if success:
-                task.status = "done"
-                logger.info("Task completed and patch applied", task_id=task.id)
+            # 2. Extract Code
+            new_content = self.patch_manager.extract_code_block(response.content, language=task.language)
+            if not new_content:
+                logger.error("Failed to extract code from LLM response", attempt=attempts)
+                attempts += 1
+                continue
+
+            # 3. Apply Fix (or Dry Run)
+            if not apply_fix:
+                diff = self.patch_manager.create_diff(original_content, new_content, filename=task.file_path)
+                print(f"--- Patch for {task.file_path} (Dry Run) ---\n{diff}\n-----------------------------")
+                logger.info("Dry run completed", task_id=task.id)
                 return True
+
+            # Apply with backup
+            if self.patch_manager.apply_patch(file_path, new_content, create_backup=True):
+                # 4. Validate
+                # FIXME(review): scope is the patched file itself; `pytest <module>` exits 5 when it collects no tests, which fails validation.
+                validation_result = self.validator.validate(
+                    file_path,
+                    language=task.language,
+                    related_test_scope=str(file_path)
+                )
+
+                if validation_result.success:
+                    logger.info("Validation passed", task_id=task.id)
+                    self.patch_manager.cleanup_backup(file_path)
+                    task.status = "done"
+                    return True
+                else:
+                    logger.warning("Validation failed, rolling back", task_id=task.id, error=validation_result.error)
+                    self.patch_manager.revert(file_path)
+
+                    # Prepare retry prompt
+                    current_prompt = (
+                        f"{base_prompt}\n\n"
+                        f"Previous attempt failed validation ({validation_result.stage}):\n"
+                        f"Error:\n{validation_result.error}\n\n"
+                        f"Please try again and fix the error."
+                    )
             else:
-                logger.error("Failed to apply patch", task_id=task.id)
-                return False
-        else:
-            # Just generate diff for dry-run
-            diff = self.patch_manager.create_diff(file_content, new_content, filename=task.file_path)
-            print(f"--- Patch for {task.file_path} ---\n{diff}\n-----------------------------")
-            logger.info("Dry run completed", task_id=task.id)
-            return True
+                logger.error("Failed to apply patch", task_id=task.id)
+
+            attempts += 1
+
+        logger.error("Task failed after retries", task_id=task.id)
+        return False
diff --git a/codesage/governance/validator.py b/codesage/governance/validator.py
new file mode 100644
index 0000000..ece6b87
--- /dev/null
+++ b/codesage/governance/validator.py
@@ -0,0 +1,66 @@
+import shlex
+from pathlib import Path
+from codesage.config.governance import GovernanceConfig
+from codesage.governance.sandbox import Sandbox
+import structlog
+from dataclasses import dataclass
+from typing import Optional
+
+logger = structlog.get_logger()
+
+@dataclass
+class ValidationResult:
+    # `stage` names the step that failed ("syntax" or "test"); empty on success.
+    success: bool
+    error: str = ""
+    stage: str = ""
+
+class CodeValidator:
+    def __init__(self, config: GovernanceConfig, sandbox: Optional[Sandbox] = None):
+        self.config = config
+        self.sandbox = sandbox or Sandbox()
+
+    @staticmethod
+    def _build_command(template: str, **values: str) -> str | list[str]:
+        """
+        Expands a command template into an argv list.
+        The template is tokenised BEFORE the placeholders are substituted, so a
+        path containing spaces or quotes stays one argument instead of being
+        re-split when the sandbox tokenises the command.
+        """
+        try:
+            tokens = shlex.split(template)
+        except ValueError:
+            # Malformed template (e.g. unbalanced quotes): pass the plain string
+            # on so the sandbox reports the parse error as a failed run.
+            return template.format(**values)
+        return [token.format(**values) for token in tokens]
+
+    def validate(self, file_path: Path, language: str, related_test_scope: Optional[str] = None) -> ValidationResult:
+        """
+        Runs the configured syntax command, then the test command if a scope is given.
+        Returns the first failing step, or a successful result when every step passes.
+        """
+        # 1. Syntax Check
+        syntax_cmd_template = self.config.validation.syntax_commands.get(language)
+        if syntax_cmd_template:
+            cmd = self._build_command(syntax_cmd_template, file=str(file_path))
+            logger.info("Running syntax check", command=cmd)
+            success, output = self.sandbox.run(cmd)
+            if not success:
+                logger.warning("Syntax validation failed", file=str(file_path), error=output)
+                return ValidationResult(success=False, error=output, stage="syntax")
+
+        # 2. Test Execution (Optional)
+        # Only run if a scope is provided. In real world, we might infer it.
+        if related_test_scope:
+            test_cmd_template = self.config.validation.test_commands.get(language)
+            if test_cmd_template:
+                cmd = self._build_command(test_cmd_template, scope=related_test_scope)
+                logger.info("Running test check", command=cmd)
+                success, output = self.sandbox.run(cmd)
+                if not success:
+                    logger.warning("Test validation failed", file=str(file_path), scope=related_test_scope, error=output)
+                    return ValidationResult(success=False, error=output, stage="test")
+
+        return ValidationResult(success=True)
diff --git a/tests/test_governance_loop.py b/tests/test_governance_loop.py
new file mode 100644
index 0000000..e054fe3
--- /dev/null
+++ b/tests/test_governance_loop.py
@@ -0,0 +1,97 @@
+import pytest
+from pathlib import Path
+from unittest.mock import MagicMock
+from datetime import datetime
+from codesage.governance.task_orchestrator import TaskOrchestrator
+from codesage.governance.task_models import GovernancePlan, GovernanceTask, GovernanceTaskGroup
+from codesage.llm.client import BaseLLMClient, LLMResponse
+from codesage.config.governance import GovernanceConfig
+from codesage.governance.sandbox import Sandbox
+
+class MockSandbox(Sandbox):
+    def __init__(self):
+        super().__init__()
+        self.calls = []
+
+    def run(self, command: str, env=None, cwd=None):
+        self.calls.append(command)
+        # We don't implement logic here because we will patch 'run' in the test
+        return True, ""
+
+@pytest.fixture
+def mock_llm_client():
+    client = MagicMock(spec=BaseLLMClient)
+    # First response: bad code (trigger syntax error)
+    # Second response: good code
+    client.generate.side_effect = [
+        LLMResponse(content="```python\ndef foo()\n pass\n```"),  # Missing colon
+        LLMResponse(content="```python\ndef foo():\n pass\n```"),  # Correct
+    ]
+    return client
+
+def test_governance_loop_rollback_and_retry(tmp_path, mock_llm_client):
+    # Setup file
+    target_file = tmp_path / "test_file.py"
+    target_file.write_text("def foo():\n pass\n")
+
+    # Setup Plan and Task
+    task = GovernanceTask(
+        id="task-1",
+        project_name="test_project",
+        file_path=str(target_file),
+        language="python",
+        rule_id="R1",
+        description="Fix style",
+        risk_level="low",
+        priority=1
+    )
+    group = GovernanceTaskGroup(
+        id="g1",
+        name="g1",
+        group_by="rule",
+        tasks=[task]
+    )
+    plan = GovernancePlan(
+        project_name="test_project",
+        created_at=datetime.utcnow(),
+        summary={},
+        groups=[group]
+    )
+
+    # Setup Config
+    config = GovernanceConfig()
+
+    # Initialize Orchestrator
+    orchestrator = TaskOrchestrator(plan, llm_client=mock_llm_client, config=config)
+
+    # Inject Mock Sandbox
+    mock_sandbox = MockSandbox()
+    orchestrator.validator.sandbox = mock_sandbox
+
+    # Define side effect for sandbox.run
+    # 1st run: Syntax check on file with "def foo()" -> Fail
+    # 2nd run: Syntax check on file with "def foo():" -> Pass
+    def side_effect_run(command, env=None, cwd=None):
+        mock_sandbox.calls.append(command)
+        content = target_file.read_text()
+        if "def foo()\n" in content:
+            return False, "SyntaxError: invalid syntax"
+        return True, ""
+
+    mock_sandbox.run = side_effect_run
+
+    result = orchestrator.execute_task(task, apply_fix=True)
+
+    assert result is True
+    assert task.status == "done"
+    assert "def foo():" in target_file.read_text()
+    assert not (tmp_path / "test_file.py.bak").exists()  # Backup cleaned up
+
+    # Verify LLM was called twice
+    assert mock_llm_client.generate.call_count == 2
+
+    # Verify the second prompt contained the error
+    call_args_list = mock_llm_client.generate.call_args_list
+    second_call_prompt = call_args_list[1][0][0].prompt
+    assert "Previous attempt failed validation" in second_call_prompt
+    assert "SyntaxError" in second_call_prompt
diff --git a/tests/test_patch_rollback.py b/tests/test_patch_rollback.py
new file mode 100644
index 0000000..946dbb8
--- /dev/null
+++ b/tests/test_patch_rollback.py
@@ -0,0 +1,35 @@
+import shutil
+from pathlib import Path
+from codesage.governance.patch_manager import PatchManager
+
+def test_patch_manager_backup_restore(tmp_path):
+    pm = PatchManager()
+    file_path = tmp_path / "target.txt"
+    file_path.write_text("Original Content")
+
+    # Apply Patch with backup
+    pm.apply_patch(file_path, "New Content", create_backup=True)
+
+    assert file_path.read_text() == "New Content"
+    assert (tmp_path / "target.txt.bak").exists()
+    assert (tmp_path / "target.txt.bak").read_text() == "Original Content"
+
+    # Revert
+    pm.revert(file_path)
+
+    assert file_path.read_text() == "Original Content"
+    # Revert moves the backup back, so backup file should be gone (shutil.move)
+    assert not (tmp_path / "target.txt.bak").exists()
+
+def test_patch_manager_cleanup(tmp_path):
+    pm = PatchManager()
+    file_path = tmp_path / "target.txt"
+    file_path.write_text("Original Content")
+
+    # Create manual backup
+    backup_path = tmp_path / "target.txt.bak"
+    shutil.copy2(file_path, backup_path)
+
+    pm.cleanup_backup(file_path)
+
+    assert not backup_path.exists()