diff --git a/codesage/llm/context_builder.py b/codesage/llm/context_builder.py
index fcb3342..928bbf9 100644
--- a/codesage/llm/context_builder.py
+++ b/codesage/llm/context_builder.py
@@ -2,6 +2,7 @@
 from typing import List, Dict, Any, Optional
 
 from codesage.snapshot.models import ProjectSnapshot, FileSnapshot
+from codesage.snapshot.strategies import CompressionStrategyFactory
 
 class ContextBuilder:
     def __init__(self, model_name: str = "gpt-4", max_tokens: int = 8000, reserve_tokens: int = 1000):
@@ -22,8 +23,7 @@ def fit_to_window(self, snapshot: ProjectSnapshot) -> str:
         """
         Builds a context string that fits within the token window.
 
-        Prioritizes primary files (full content), then reference files (summaries/interfaces),
-        then truncates if necessary.
+        Uses the compression_level specified in each FileSnapshot to determine content.
         """
         available_tokens = self.max_tokens - self.reserve_tokens
@@ -42,49 +42,39 @@ def fit_to_window(
             context_parts.append(project_context)
             current_tokens += tokens
 
-        # 2. Add Primary Files
-        for file in primary_files:
+        # Combine primary and reference files for processing.
+        # The SnapshotCompressor should already have assigned a compression
+        # level to each file based on the global budget, but ContextBuilder
+        # may also receive raw snapshots, so we fall back to "full" below.
+
+        all_files = primary_files + reference_files
+
+        for file in all_files:
             content = self._read_file(file.path)
             if not content:
                 continue
 
-            file_block = f"\n{content}\n\n"
+            # Apply the compression strategy recorded on the snapshot
+            strategy = CompressionStrategyFactory.get_strategy(getattr(file, "compression_level", "full"))
+            processed_content = strategy.compress(content, file.path, file.language)
+
+            # Wrap the processed content in a file block
+            file_block = f"\n{processed_content}\n\n"
            tokens = self.count_tokens(file_block)
 
             if current_tokens + tokens <= available_tokens:
                 context_parts.append(file_block)
                 current_tokens += tokens
             else:
-                # Compression needed
-                # We try to keep imports and signatures
-                compressed = self._compress_file(file, content)
-                tokens = self.count_tokens(compressed)
-                if current_tokens + tokens <= available_tokens:
-                    context_parts.append(compressed)
-                    current_tokens += tokens
+                # Even the compressed content does not fit: hard-truncate if
+                # meaningful space remains, otherwise stop adding files.
+                remaining = available_tokens - current_tokens
+                if remaining > 50:
+                    truncated = processed_content[:(remaining * 3)] + "\n...(truncated due to context limit)"
+                    context_parts.append(f"\n{truncated}\n\n")
+                    current_tokens += remaining  # Approximate; assumes >= 3 chars per token
+                    break
                 else:
-                    # Even compressed is too large, hard truncate
-                    remaining = available_tokens - current_tokens
-                    if remaining > 20:  # Ensure at least some chars
-                        chars_limit = remaining * 4
-                        if chars_limit > len(compressed):
-                            chars_limit = len(compressed)
-
-                        truncated = compressed[:chars_limit] + "\n...(truncated)"
-                        context_parts.append(truncated)
-                        current_tokens += remaining
-                        # Stop here
-                        break
-                    else:
-                        break  # No space even for truncated
-
-        # 3. Add Reference Files (Summaries) if space permits
-        for file in reference_files:
-            if current_tokens >= available_tokens: break
-
-            summary = self._summarize_file(file)
-            tokens = self.count_tokens(summary)
-            if current_tokens + tokens <= available_tokens:
-                context_parts.append(summary)
-                current_tokens += tokens
+                    break
 
         return "\n".join(context_parts)
 
@@ -94,76 +84,3 @@ def _read_file(self, path: str) -> str:
             return f.read()
         except Exception:
             return ""
-
-    def _compress_file(self, file_snapshot: FileSnapshot, content: str) -> str:
-        """
-        Retains imports, structs/classes/interfaces, and function signatures.
-        Removes function bodies.
-        """
-        if not file_snapshot.symbols:
-            # Fallback: keep first 50 lines
-            lines = content.splitlines()
-            return f"\n" + "\n".join(lines[:50]) + "\n... (bodies omitted)\n\n"
-
-        lines = content.splitlines()
-
-        # Intervals to exclude (function bodies)
-        exclude_intervals = []
-
-        funcs = file_snapshot.symbols.get("functions", [])
-
-        for f in funcs:
-            start = f.get("start_line", 0)
-            end = f.get("end_line", 0)
-            if end > start:
-                # To preserve closing brace if it is on end_line, we exclude up to end_line - 1?
-                # It depends on where end_line points. Tree-sitter end_point is row/col.
-                # If end_line is the line index (0-based) where function ends.
-                # Usually closing brace is on end_line.
-
-                # Check if end_line contains ONLY brace.
-                # If we exclude start+1 to end-1, we keep start and end line.
-
-                exclude_start = start + 1
-                exclude_end = end - 1
-
-                if exclude_end >= exclude_start:
-                    exclude_intervals.append((exclude_start, exclude_end))  # inclusive
-
-        # Sort intervals
-        exclude_intervals.sort()
-
-        compressed_lines = []
-        skipping = False
-
-        for i, line in enumerate(lines):
-            is_excluded = False
-            for start_idx, end_idx in exclude_intervals:
-                if start_idx <= i <= end_idx:  # Excluding body
-                    is_excluded = True
-                    break
-
-            if is_excluded:
-                if not skipping:
-                    compressed_lines.append("    ... (body omitted)")
-                    skipping = True
-            else:
-                compressed_lines.append(line)
-                skipping = False
-
-        return f"\n" + "\n".join(compressed_lines) + "\n\n"
-
-    def _summarize_file(self, file_snapshot: FileSnapshot) -> str:
-        lines = [f"File: {file_snapshot.path}"]
-        if file_snapshot.symbols:
-            if "functions" in file_snapshot.symbols:
-                funcs = file_snapshot.symbols["functions"]
-                lines.append("Functions: " + ", ".join([f['name'] for f in funcs]))
-            if "structs" in file_snapshot.symbols:
-                structs = file_snapshot.symbols["structs"]
-                lines.append("Structs: " + ", ".join([s['name'] for s in structs]))
-            if "external_commands" in file_snapshot.symbols:
-                cmds = file_snapshot.symbols["external_commands"]
-                lines.append("External Commands: " + ", ".join(cmds))
-
-        return "\n".join(lines) + "\n"
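Note on the truncation fallback above: the `remaining * 3` slice is a character-count heuristic that assumes at least roughly three characters per token. A token-exact alternative, sketched here against the same tiktoken API this patch already uses elsewhere (`truncate_to_tokens` is a hypothetical helper, not part of the patch):

    import tiktoken

    def truncate_to_tokens(text: str, max_tokens: int, model: str = "gpt-4") -> str:
        # Encode once, cut at the token level, and decode back.
        try:
            enc = tiktoken.encoding_for_model(model)
        except KeyError:
            enc = tiktoken.get_encoding("cl100k_base")
        tokens = enc.encode(text)
        if len(tokens) <= max_tokens:
            return text
        return enc.decode(tokens[:max_tokens]) + "\n...(truncated due to context limit)"

This trades one extra encode/decode pass for an exact fit instead of an approximate one.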
(body omitted)") - skipping = True - else: - compressed_lines.append(line) - skipping = False - - return f"\n" + "\n".join(compressed_lines) + "\n\n" - - def _summarize_file(self, file_snapshot: FileSnapshot) -> str: - lines = [f"File: {file_snapshot.path}"] - if file_snapshot.symbols: - if "functions" in file_snapshot.symbols: - funcs = file_snapshot.symbols["functions"] - lines.append("Functions: " + ", ".join([f['name'] for f in funcs])) - if "structs" in file_snapshot.symbols: - structs = file_snapshot.symbols["structs"] - lines.append("Structs: " + ", ".join([s['name'] for s in structs])) - if "external_commands" in file_snapshot.symbols: - cmds = file_snapshot.symbols["external_commands"] - lines.append("External Commands: " + ", ".join(cmds)) - - return "\n".join(lines) + "\n" diff --git a/codesage/snapshot/compressor.py b/codesage/snapshot/compressor.py index 8712faf..2283b13 100644 --- a/codesage/snapshot/compressor.py +++ b/codesage/snapshot/compressor.py @@ -1,121 +1,112 @@ -import fnmatch -import json -import hashlib -from typing import Any, Dict, List - +from typing import Any, Dict, List, Optional +import os +import tiktoken from codesage.snapshot.models import ProjectSnapshot, FileSnapshot -from codesage.analyzers.ast_models import ASTNode - +from codesage.snapshot.strategies import CompressionStrategyFactory, FullStrategy class SnapshotCompressor: - """Compresses a ProjectSnapshot to reduce its size.""" + """Compresses a ProjectSnapshot to reduce its token usage for LLM context.""" - def __init__(self, config: Dict[str, Any]): - self.config = config.get("compression", {}) - self.exclude_patterns = self.config.get("exclude_patterns", []) - self.trimming_threshold = self.config.get("trimming_threshold", 1000) + def __init__(self, config: Dict[str, Any] = None): + self.config = config or {} + # Default budget if not specified + self.token_budget = self.config.get("token_budget", 8000) + self.model_name = self.config.get("model_name", "gpt-4") - def compress(self, snapshot: ProjectSnapshot) -> ProjectSnapshot: - """ - Compresses the snapshot by applying various techniques. + try: + self.encoding = tiktoken.encoding_for_model(self.model_name) + except KeyError: + self.encoding = tiktoken.get_encoding("cl100k_base") + + def compress_project(self, snapshot: ProjectSnapshot, project_root: str) -> ProjectSnapshot: """ - compressed_snapshot = snapshot.model_copy(deep=True) + Compresses the snapshot by assigning compression levels to files based on risk and budget. - if self.exclude_patterns: - compressed_snapshot.files = self._exclude_files( - compressed_snapshot.files, self.exclude_patterns - ) + Args: + snapshot: The project snapshot. + project_root: The root directory of the project (to read file contents). - compressed_snapshot.files = self._deduplicate_ast_nodes(compressed_snapshot.files) - compressed_snapshot.files = self._trim_large_asts( - compressed_snapshot.files, self.trimming_threshold + Returns: + The modified project snapshot with updated compression_level fields. + """ + # 1. Sort files by Risk Score (Desc) + # Assuming risk.risk_score exists. If not, default to 0. 
+        sorted_files = sorted(
+            snapshot.files,
+            key=lambda f: f.risk.risk_score if f.risk else 0.0,
+            reverse=True
         )
-        return compressed_snapshot
-
-    def _exclude_files(
-        self, files: List[FileSnapshot], patterns: List[str]
-    ) -> List[FileSnapshot]:
-        """Filters out files that match the exclude patterns."""
-        return [
-            file
-            for file in files
-            if not any(fnmatch.fnmatch(file.path, pattern) for pattern in patterns)
-        ]
-
-    def _deduplicate_ast_nodes(
-        self, files: List[FileSnapshot]
-    ) -> List[FileSnapshot]:
-        """
-        Deduplicates AST nodes by replacing identical subtrees with a reference.
-        This is a simplified implementation. A real one would need a more robust
-        hashing and reference mechanism.
-        """
-        node_cache = {}
-        for file in files:
-            if file.ast_summary:  # Assuming ast_summary holds the AST
-                self._traverse_and_deduplicate(file.ast_summary, node_cache)
-        return files
-
-    def _traverse_and_deduplicate(self, node: ASTNode, cache: Dict[str, ASTNode]):
-        """Recursively traverses the AST and deduplicates nodes."""
-        if not isinstance(node, ASTNode):
-            return
-
-        node_hash = self._hash_node(node)
-        if node_hash in cache:
-            # Replace with a reference to the cached node
-            # This is a conceptual implementation. In practice, you might
-            # store the canonical node in a separate structure and use IDs.
-            node = cache[node_hash]
-            return
-
-        cache[node_hash] = node
-        for i, child in enumerate(node.children):
-            node.children[i] = self._traverse_and_deduplicate(child, cache)
-
-    def _hash_node(self, node: ASTNode) -> str:
-        """Creates a stable hash for an AST node."""
-        # A simple hash based on type and value. A real implementation
-        # should be more robust, considering children as well.
-        hasher = hashlib.md5()
-        hasher.update(node.node_type.encode())
-        if node.value:
-            hasher.update(str(node.value).encode())
-        return hasher.hexdigest()
-
-    def _trim_large_asts(
-        self, files: List[FileSnapshot], threshold: int
-    ) -> List[FileSnapshot]:
-        """Trims the AST of very large files to save space."""
-        for file in files:
-            if file.lines > threshold and file.ast_summary:
-                self._traverse_and_trim(file.ast_summary)
-        return files
-
-    def _traverse_and_trim(self, node: ASTNode):
+        # 2. Initial pass: estimate the token cost of every file under every
+        #    strategy.
+        file_costs = {}  # {file_path: {level: token_count}}
+
+        for file in sorted_files:
+            file_path = os.path.join(project_root, file.path)
+            try:
+                with open(file_path, "r", encoding="utf-8", errors="replace") as f:
+                    content = f.read()
+            except Exception:
+                content = ""  # Missing or unreadable files are priced as empty
+
+            # Calculate costs for all strategies
+            costs = {}
+            for level in ["full", "skeleton", "signature"]:
+                strategy = CompressionStrategyFactory.get_strategy(level)
+                compressed_content = strategy.compress(content, file.path, file.language)
+                costs[level] = len(self.encoding.encode(compressed_content))
+
+            file_costs[file.path] = costs
+
+        # 3. Budget allocation loop.
+        # Start with the minimal cost (everything at "signature"), then spend
+        # any remaining budget upgrading files in risk order:
+        # signature -> skeleton -> full.
+        current_total_tokens = sum(file_costs[f.path]["signature"] for f in sorted_files)
+
+        for file in snapshot.files:
+            file.compression_level = "signature"
+
+        # Pass 1: Upgrade to Skeleton, highest risk first. If an upgrade does
+        # not fit, skip it and try the next file. (This is a greedy heuristic;
+        # picking the optimal subset would be a knapsack problem.)
+        for file in sorted_files:
+            costs = file_costs[file.path]
+            cost_increase = costs["skeleton"] - costs["signature"]
+
+            if current_total_tokens + cost_increase <= self.token_budget:
+                file.compression_level = "skeleton"
+                current_total_tokens += cost_increase
+
+        # Pass 2: Upgrade to Full
+        for file in sorted_files:
+            if file.compression_level == "skeleton":
+                costs = file_costs[file.path]
+                cost_increase = costs["full"] - costs["skeleton"]
+
+                if current_total_tokens + cost_increase <= self.token_budget:
+                    file.compression_level = "full"
+                    current_total_tokens += cost_increase
+
+        return snapshot
+
+    def select_strategy(self, file_risk: float, is_focal_file: bool) -> str:
         """
-        Recursively traverses the AST and removes non-essential nodes,
-        like the bodies of functions.
+        Determines the ideal strategy based on risk, ignoring budget.
+        Used as a heuristic or upper bound.
         """
-        if not isinstance(node, ASTNode):
-            return
-
-        # For function nodes, keep the signature but remove the body
-        if node.node_type == "function":
-            node.children = []  # A simple way to trim the function body
-            return
-
-        for child in node.children:
-            self._traverse_and_trim(child)
-
-
-    def calculate_compression_ratio(
-        self, original: ProjectSnapshot, compressed: ProjectSnapshot
-    ) -> float:
-        """Calculates the compression ratio."""
-        original_size = len(json.dumps(original.model_dump(mode='json')))
-        compressed_size = len(json.dumps(compressed.model_dump(mode='json')))
-        if original_size == 0:
-            return 0.0
-        return (original_size - compressed_size) / original_size
+        if is_focal_file or file_risk >= 0.7:  # High risk
+            return "full"
+        elif file_risk >= 0.3:  # Medium risk
+            return "skeleton"
+        else:
+            return "signature"
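To make the two greedy upgrade passes in `compress_project` concrete, here is a toy run with invented token costs (the paths and numbers are illustrative only). Every file starts at `signature`, then gets upgraded in risk order while the running total stays within budget:

    costs = {  # hypothetical per-file token costs
        "risky.py":  {"signature": 40, "skeleton": 200, "full": 900},
        "medium.py": {"signature": 30, "skeleton": 120, "full": 400},
        "low.py":    {"signature": 20, "skeleton": 80,  "full": 150},
    }
    order = ["risky.py", "medium.py", "low.py"]  # sorted by risk, descending
    budget = 600

    levels = {path: "signature" for path in order}
    total = sum(c["signature"] for c in costs.values())  # 90

    for src, dst in [("signature", "skeleton"), ("skeleton", "full")]:
        for path in order:
            if levels[path] != src:
                continue
            delta = costs[path][dst] - costs[path][src]
            if total + delta <= budget:
                levels[path] = dst
                total += delta

    # Pass 1 upgrades all three files to skeleton (total 400); pass 2 can only
    # afford low.py at full (total 470) -- risky.py and medium.py would blow
    # the 600-token budget, so they stay at skeleton.

As the comment in the patch notes, this greedy order can strand budget on a huge high-risk file; optimal selection would be a knapsack problem.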
diff --git a/codesage/snapshot/models.py b/codesage/snapshot/models.py
index bc11f87..2920e0b 100644
--- a/codesage/snapshot/models.py
+++ b/codesage/snapshot/models.py
@@ -101,6 +101,7 @@ class FileSnapshot(BaseModel):
     symbols: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of symbols defined in the file.")
     risk: Optional[FileRisk] = Field(None, description="Risk assessment for the file.")
     issues: List[Issue] = Field(default_factory=list, description="A list of issues identified in the file.")
+    compression_level: Literal["full", "skeleton", "signature"] = Field("full", description="The compression level applied to the file.")
 
     # Old fields for compatibility
     hash: Optional[str] = Field(None, description="The SHA256 hash of the file content.")
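Since `compression_level` is a pydantic `Literal` field, invalid levels fail at validation time rather than deep inside a strategy. A minimal sketch of the behavior with a stand-in model (assuming pydantic v2, consistent with the `model_copy`/`model_dump` calls removed above, and assuming `Literal` is imported in models.py):

    from typing import Literal
    from pydantic import BaseModel, Field, ValidationError

    class FileStub(BaseModel):  # hypothetical stand-in for FileSnapshot
        path: str
        compression_level: Literal["full", "skeleton", "signature"] = Field("full")

    print(FileStub(path="a.py").compression_level)  # "full" (default)
    print(FileStub(path="a.py", compression_level="skeleton").compression_level)
    try:
        FileStub(path="a.py", compression_level="tiny")  # rejected
    except ValidationError as e:
        print(e.error_count(), "validation error")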
diff --git a/codesage/snapshot/strategies.py b/codesage/snapshot/strategies.py
new file mode 100644
index 0000000..1f3315f
--- /dev/null
+++ b/codesage/snapshot/strategies.py
@@ -0,0 +1,304 @@
+from abc import ABC, abstractmethod
+
+from tree_sitter import Tree
+
+from codesage.analyzers.parser_factory import create_parser
+
+
+class CompressionStrategy(ABC):
+    """Abstract base class for code compression strategies."""
+
+    def __init__(self):
+        # Strategies are stateless for now; a tokenizer could be injected later
+        pass
+
+    @abstractmethod
+    def compress(self, code: str, file_path: str, language_id: str) -> str:
+        """
+        Compresses the given code.
+
+        Args:
+            code: The source code content.
+            file_path: The path to the file (for context/logging).
+            language_id: The language identifier (e.g., 'python', 'go').
+
+        Returns:
+            The compressed code string.
+        """
+        pass
+
+
+class FullStrategy(CompressionStrategy):
+    """Retains the full code."""
+
+    def compress(self, code: str, file_path: str, language_id: str) -> str:
+        return code
+
+
+class SkeletonStrategy(CompressionStrategy):
+    """
+    Retains AST structure, imports, signatures, and docstrings.
+    Replaces function bodies with '...'.
+    """
+
+    def compress(self, code: str, file_path: str, language_id: str) -> str:
+        try:
+            parser_instance = create_parser(language_id)
+        except ValueError:
+            # No parser available for this language; return the code unchanged
+            return code
+
+        # BaseParser.parse returns a custom ASTNode, but pruning needs the raw
+        # tree-sitter tree for byte-accurate text replacement, so we go through
+        # the wrapped tree-sitter parser directly.
+        ts_parser = parser_instance.parser
+        tree = ts_parser.parse(bytes(code, "utf8"))
+
+        return self._prune_tree(code, tree, language_id)
+
+    def _prune_tree(self, code: str, tree: Tree, language: str) -> str:
+        # Language-specific queries locate the bodies to prune:
+        # Python: 'block' inside 'function_definition';
+        # Go: 'block' inside 'function_declaration' / 'method_declaration'.
+        root_node = tree.root_node
+
+        # Byte ranges to exclude from the output
+        exclude_ranges = []
+
+        if language == "python":
+            # Only function bodies are pruned. Class bodies are left intact so
+            # that method signatures stay visible.
+            query_scm = """
+            (function_definition
+              body: (block) @body)
+            """
+        elif language == "go":
+            query_scm = """
+            (function_declaration
+              body: (block) @body)
+            (method_declaration
+              body: (block) @body)
+            (func_literal
+              body: (block) @body)
+            """
+        else:
+            # Skeleton compression is not implemented for this language
+            return code
+
+        try:
+            language_obj = tree.language
+            query = language_obj.query(query_scm)
+
+            # The project pins tree-sitter >= 0.22, where Query.captures was
+            # removed; queries are executed through QueryCursor instead.
+            # Depending on the binding version, captures() returns either
+            # {capture_name: [nodes]} or a list of (node, name-or-index)
+            # tuples, so normalize both shapes below.
+            from tree_sitter import QueryCursor
+            cursor = QueryCursor(query)
+            captures = cursor.captures(root_node)
+
+            # Work in bytes so offsets match tree-sitter's byte positions
+            code_bytes_ref = code.encode("utf8")
+
+            flat_captures = []
+            if isinstance(captures, dict):
+                for name, nodes in captures.items():
+                    for node in nodes:
+                        flat_captures.append((node, name))
+            elif isinstance(captures, list):
+                for item in captures:
+                    if len(item) == 2:
+                        node, name_or_idx = item
+                        if isinstance(name_or_idx, int):
+                            name = query.capture_names[name_or_idx]
+                        else:
+                            name = name_or_idx
+                        flat_captures.append((node, name))
+
+            for node, name in flat_captures:
+                if name != "body":
+                    continue
+
+                start_byte = node.start_byte
+                end_byte = node.end_byte
+
+                # Python: if the block starts with a docstring, keep it and
+                # prune only the statements that follow it.
+                if language == "python" and node.child_count > 0:
+                    first_child = node.children[0]
+                    if first_child.type == 'expression_statement':
+                        text_bytes = code_bytes_ref[first_child.start_byte:first_child.end_byte].strip()
+                        if text_bytes.startswith((b'"""', b"'''", b'"', b"'")):
+                            start_byte = first_child.end_byte
+
+                exclude_ranges.append((start_byte, end_byte))
+
+        except Exception as e:
+            print(f"Error executing query for {language}: {e}")
+            return code
+
+        if not exclude_ranges:
+            return code
+
+        # Function bodies can nest (inner functions), so the captured ranges
+        # may be nested or overlapping, and every offset refers to the
+        # ORIGINAL byte string. Merge overlapping/nested ranges first, then
+        # apply the replacements from the end of the file backwards so that
+        # earlier offsets remain valid as we edit.
+        exclude_ranges.sort(key=lambda x: x[0])
+
+        merged = []
+        curr_start, curr_end = exclude_ranges[0]
+        for next_start, next_end in exclude_ranges[1:]:
+            if next_start < curr_end:
+                # Nested or overlapping: extend the current range
+                curr_end = max(curr_end, next_end)
+            else:
+                merged.append((curr_start, curr_end))
+                curr_start, curr_end = next_start, next_end
+        merged.append((curr_start, curr_end))
+
+        code_bytes = bytearray(code, "utf8")
+        replacement = b"\n    ... # Pruned\n"
+
+        for start, end in reversed(merged):
+            if end > start:  # Skip empty blocks
+                # The fixed four-space indent is a heuristic; nested functions
+                # and docstring-adjusted starts may render slightly off.
+                code_bytes[start:end] = replacement
+
+        return code_bytes.decode("utf8")
+
+
+class SignatureStrategy(CompressionStrategy):
+    """
+    Retains only top-level definitions (global variables, class names, function names).
+    Drastic reduction.
+    """
+
+    def compress(self, code: str, file_path: str, language_id: str) -> str:
+        # Use tree-sitter to find top-level definitions and list them
+        try:
+            parser_instance = create_parser(language_id)
+        except ValueError:
+            return ""  # No parser for this language; contribute nothing
+
+        ts_parser = parser_instance.parser
+        tree = ts_parser.parse(bytes(code, "utf8"))
+        root = tree.root_node
+
+        lines = [f"# Signature Digest for {file_path}"]
+
+        # Iterate over top-level children only
+        for child in root.children:
+            if child.type == "function_definition":  # Python
+                name = self._get_name(child, code)
+                lines.append(f"def {name}(...): ...")
+            elif child.type == "class_definition":  # Python
+                name = self._get_name(child, code)
+                lines.append(f"class {name}: ...")
+            elif child.type == "function_declaration":  # Go
+                name = self._get_name(child, code)
+                lines.append(f"func {name}(...) ...")
+            # Add more node types as needed
+
+        return "\n".join(lines)
+
+    def _get_name(self, node, code):
+        # Find the 'name'/'identifier' child. Slice bytes, not the str, so
+        # multi-byte characters earlier in the file do not skew the offsets.
+        code_bytes = code.encode("utf8")
+        for child in node.children:
+            if child.type == "identifier" or child.type == "name":
+                return code_bytes[child.start_byte:child.end_byte].decode("utf8")
+        return "?"
+
+
+class CompressionStrategyFactory:
+    @staticmethod
+    def get_strategy(level: str) -> CompressionStrategy:
+        if level == "skeleton":
+            return SkeletonStrategy()
+        elif level == "signature":
+            return SignatureStrategy()
+        else:
+            return FullStrategy()
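For reference, the QueryCursor flow that `SkeletonStrategy` relies on, reduced to a standalone sketch. This assumes, as the patch does, a py-tree-sitter version where `Tree.language` is available and `QueryCursor.captures` returns a `{capture_name: [nodes]}` dict; grammar setup is omitted:

    from tree_sitter import QueryCursor

    def function_body_nodes(tree):
        # One query, one cursor; returns the captured 'body' block nodes.
        query = tree.language.query("(function_definition body: (block) @body)")
        captures = QueryCursor(query).captures(tree.root_node)
        if isinstance(captures, dict):
            return captures.get("body", [])
        return [node for node, _name in captures]  # older list-of-tuples shape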
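An illustrative input/output pair for `SkeletonStrategy` on Python, assuming the docstring-preserving branch fires (exact whitespace around the pruning marker may differ):

    SOURCE = (
        'def add(a, b):\n'
        '    """Add two numbers."""\n'
        '    result = a + b\n'
        '    return result\n'
    )

    # Expected shape of SkeletonStrategy().compress(SOURCE, "ex.py", "python"):
    #
    #   def add(a, b):
    #       """Add two numbers."""
    #       ... # Pruned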
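A quick sanity check of the merge-then-apply-backwards scheme in `_prune_tree`, using toy byte offsets:

    ranges = [(50, 60), (10, 40), (15, 30)]  # (15, 30) is nested inside (10, 40)
    ranges.sort(key=lambda r: r[0])

    merged = []
    cur_s, cur_e = ranges[0]
    for s, e in ranges[1:]:
        if s < cur_e:
            cur_e = max(cur_e, e)  # swallow the nested/overlapping range
        else:
            merged.append((cur_s, cur_e))
            cur_s, cur_e = s, e
    merged.append((cur_s, cur_e))

    assert merged == [(10, 40), (50, 60)]
    # Replacing (50, 60) before (10, 40) means the (10, 40) offsets still
    # point at the right bytes when their turn comes.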
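Finally, an end-to-end usage sketch tying the factory to the three levels; the file path and language are illustrative, and output sizes depend on the installed grammars:

    from codesage.snapshot.strategies import CompressionStrategyFactory

    with open("example.py", "r", encoding="utf-8") as f:
        code = f.read()

    for level in ("full", "skeleton", "signature"):
        strategy = CompressionStrategyFactory.get_strategy(level)
        compressed = strategy.compress(code, "example.py", "python")
        print(f"{level}: {len(compressed)} chars")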