133 changes: 25 additions & 108 deletions codesage/llm/context_builder.py
@@ -2,6 +2,7 @@
from typing import List, Dict, Any, Optional

from codesage.snapshot.models import ProjectSnapshot, FileSnapshot
from codesage.snapshot.strategies import CompressionStrategyFactory

class ContextBuilder:
def __init__(self, model_name: str = "gpt-4", max_tokens: int = 8000, reserve_tokens: int = 1000):
@@ -22,8 +23,7 @@ def fit_to_window(self,
snapshot: ProjectSnapshot) -> str:
"""
Builds a context string that fits within the token window.
Prioritizes primary files (full content), then reference files (summaries/interfaces),
then truncates if necessary.
Uses the compression_level specified in FileSnapshot to determine content.
"""
available_tokens = self.max_tokens - self.reserve_tokens

@@ -42,49 +42,39 @@ def fit_to_window(self,
context_parts.append(project_context)
current_tokens += tokens

# 2. Add Primary Files
for file in primary_files:
# Combine primary and reference files for processing.
# Note: the SnapshotCompressor should already have assigned appropriate compression
# levels based on the global budget, but ContextBuilder may also receive raw
# snapshots, so we fall back to "full" when compression_level is unset.

all_files = primary_files + reference_files

for file in all_files:
content = self._read_file(file.path)
if not content: continue

file_block = f"<file path=\"{file.path}\">\n{content}\n</file>\n"
# Apply compression strategy
strategy = CompressionStrategyFactory.get_strategy(getattr(file, "compression_level", "full"))
processed_content = strategy.compress(content, file.path, file.language)

# Wrap the processed content in a <file path=...> block
file_block = f"<file path=\"{file.path}\">\n{processed_content}\n</file>\n"
tokens = self.count_tokens(file_block)

if current_tokens + tokens <= available_tokens:
context_parts.append(file_block)
current_tokens += tokens
else:
# Compression needed
# We try to keep imports and signatures
compressed = self._compress_file(file, content)
tokens = self.count_tokens(compressed)
if current_tokens + tokens <= available_tokens:
context_parts.append(compressed)
current_tokens += tokens
# The already-compressed content does not fit: truncate it if enough space
# remains, otherwise stop adding files.
remaining = available_tokens - current_tokens
if remaining > 50:
truncated = processed_content[:(remaining * 3)] + "\n...(truncated due to context limit)"
context_parts.append(f"<file path=\"{file.path}\">\n{truncated}\n</file>\n")
current_tokens += remaining # Approximate
break
else:
# Even compressed is too large, hard truncate
remaining = available_tokens - current_tokens
if remaining > 20: # Ensure at least some chars
chars_limit = remaining * 4
if chars_limit > len(compressed):
chars_limit = len(compressed)

truncated = compressed[:chars_limit] + "\n...(truncated)"
context_parts.append(truncated)
current_tokens += remaining # Stop here
break
else:
break # No space even for truncated

# 3. Add Reference Files (Summaries) if space permits
for file in reference_files:
if current_tokens >= available_tokens: break

summary = self._summarize_file(file)
tokens = self.count_tokens(summary)
if current_tokens + tokens <= available_tokens:
context_parts.append(summary)
current_tokens += tokens
break

return "\n".join(context_parts)

@@ -94,76 +84,3 @@ def _read_file(self, path: str) -> str:
return f.read()
except Exception:
return ""

def _compress_file(self, file_snapshot: FileSnapshot, content: str) -> str:
"""
Retains imports, structs/classes/interfaces, and function signatures.
Removes function bodies.
"""
if not file_snapshot.symbols:
# Fallback: keep first 50 lines
lines = content.splitlines()
return f"<file path=\"{file_snapshot.path}\" compressed=\"true\">\n" + "\n".join(lines[:50]) + "\n... (bodies omitted)\n</file>\n"

lines = content.splitlines()

# Intervals to exclude (function bodies)
exclude_intervals = []

funcs = file_snapshot.symbols.get("functions", [])

for f in funcs:
start = f.get("start_line", 0)
end = f.get("end_line", 0)
if end > start:
# Tree-sitter's end_point gives the row/column where the function ends, and
# the closing brace usually sits on end_line. Excluding lines start+1 .. end-1
# keeps both the signature line and the closing line.

exclude_start = start + 1
exclude_end = end - 1

if exclude_end >= exclude_start:
exclude_intervals.append((exclude_start, exclude_end)) # inclusive

# Sort intervals
exclude_intervals.sort()

compressed_lines = []
skipping = False

for i, line in enumerate(lines):
is_excluded = False
for start_idx, end_idx in exclude_intervals:
if start_idx <= i <= end_idx: # Excluding body
is_excluded = True
break

if is_excluded:
if not skipping:
compressed_lines.append(" ... (body omitted)")
skipping = True
else:
compressed_lines.append(line)
skipping = False

return f"<file path=\"{file_snapshot.path}\" compressed=\"true\">\n" + "\n".join(compressed_lines) + "\n</file>\n"

def _summarize_file(self, file_snapshot: FileSnapshot) -> str:
lines = [f"File: {file_snapshot.path}"]
if file_snapshot.symbols:
if "functions" in file_snapshot.symbols:
funcs = file_snapshot.symbols["functions"]
lines.append("Functions: " + ", ".join([f['name'] for f in funcs]))
if "structs" in file_snapshot.symbols:
structs = file_snapshot.symbols["structs"]
lines.append("Structs: " + ", ".join([s['name'] for s in structs]))
if "external_commands" in file_snapshot.symbols:
cmds = file_snapshot.symbols["external_commands"]
lines.append("External Commands: " + ", ".join(cmds))

return "\n".join(lines) + "\n"
209 changes: 100 additions & 109 deletions codesage/snapshot/compressor.py
@@ -1,121 +1,112 @@
import fnmatch
import json
import hashlib
from typing import Any, Dict, List

from typing import Any, Dict, List, Optional
import os
import tiktoken
from codesage.snapshot.models import ProjectSnapshot, FileSnapshot
from codesage.analyzers.ast_models import ASTNode

from codesage.snapshot.strategies import CompressionStrategyFactory, FullStrategy

class SnapshotCompressor:
"""Compresses a ProjectSnapshot to reduce its size."""
"""Compresses a ProjectSnapshot to reduce its token usage for LLM context."""

def __init__(self, config: Dict[str, Any]):
self.config = config.get("compression", {})
self.exclude_patterns = self.config.get("exclude_patterns", [])
self.trimming_threshold = self.config.get("trimming_threshold", 1000)
def __init__(self, config: Optional[Dict[str, Any]] = None):
self.config = config or {}
# Default budget if not specified
self.token_budget = self.config.get("token_budget", 8000)
self.model_name = self.config.get("model_name", "gpt-4")

def compress(self, snapshot: ProjectSnapshot) -> ProjectSnapshot:
"""
Compresses the snapshot by applying various techniques.
try:
self.encoding = tiktoken.encoding_for_model(self.model_name)
except KeyError:
self.encoding = tiktoken.get_encoding("cl100k_base")

def compress_project(self, snapshot: ProjectSnapshot, project_root: str) -> ProjectSnapshot:
"""
compressed_snapshot = snapshot.model_copy(deep=True)
Compresses the snapshot by assigning compression levels to files based on risk and budget.

if self.exclude_patterns:
compressed_snapshot.files = self._exclude_files(
compressed_snapshot.files, self.exclude_patterns
)
Args:
snapshot: The project snapshot.
project_root: The root directory of the project (to read file contents).

compressed_snapshot.files = self._deduplicate_ast_nodes(compressed_snapshot.files)
compressed_snapshot.files = self._trim_large_asts(
compressed_snapshot.files, self.trimming_threshold
Returns:
The modified project snapshot with updated compression_level fields.
"""
# 1. Sort files by risk score, descending; files without a risk assessment
# default to 0.0.
sorted_files = sorted(
snapshot.files,
key=lambda f: f.risk.risk_score if f.risk else 0.0,
reverse=True
)

return compressed_snapshot

def _exclude_files(
self, files: List[FileSnapshot], patterns: List[str]
) -> List[FileSnapshot]:
"""Filters out files that match the exclude patterns."""
return [
file
for file in files
if not any(fnmatch.fnmatch(file.path, pattern) for pattern in patterns)
]

def _deduplicate_ast_nodes(
self, files: List[FileSnapshot]
) -> List[FileSnapshot]:
"""
Deduplicates AST nodes by replacing identical subtrees with a reference.
This is a simplified implementation. A real one would need a more robust
hashing and reference mechanism.
"""
node_cache = {}
for file in files:
if file.ast_summary: # Assuming ast_summary holds the AST
self._traverse_and_deduplicate(file.ast_summary, node_cache)
return files

def _traverse_and_deduplicate(self, node: ASTNode, cache: Dict[str, ASTNode]):
"""Recursively traverses the AST and deduplicates nodes."""
if not isinstance(node, ASTNode):
return

node_hash = self._hash_node(node)
if node_hash in cache:
# Replace with a reference to the cached node
# This is a conceptual implementation. In practice, you might
# store the canonical node in a separate structure and use IDs.
node = cache[node_hash]
return

cache[node_hash] = node
for i, child in enumerate(node.children):
node.children[i] = self._traverse_and_deduplicate(child, cache)

def _hash_node(self, node: ASTNode) -> str:
"""Creates a stable hash for an AST node."""
# A simple hash based on type and value. A real implementation
# should be more robust, considering children as well.
hasher = hashlib.md5()
hasher.update(node.node_type.encode())
if node.value:
hasher.update(str(node.value).encode())
return hasher.hexdigest()

def _trim_large_asts(
self, files: List[FileSnapshot], threshold: int
) -> List[FileSnapshot]:
"""Trims the AST of very large files to save space."""
for file in files:
if file.lines > threshold and file.ast_summary:
self._traverse_and_trim(file.ast_summary)
return files

def _traverse_and_trim(self, node: ASTNode):
# 2. Initial pass: Estimate costs for different levels
file_costs = {} # {file_path: {level: token_count}}

# Read each file from disk so the cost of every compression level can be measured.
for file in sorted_files:
file_path = os.path.join(project_root, file.path)
try:
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
content = f.read()
except Exception:
content = "" # Should we handle missing files?

# Calculate costs for all strategies
costs = {}
for level in ["full", "skeleton", "signature"]:
strategy = CompressionStrategyFactory.get_strategy(level)
compressed_content = strategy.compress(content, file.path, file.language)
costs[level] = len(self.encoding.encode(compressed_content))

file_costs[file.path] = costs

# 3. Budget allocation loop
# Start with minimal cost (all signature)
current_total_tokens = sum(file_costs[f.path]["signature"] for f in sorted_files)

# Assign initial level
for file in snapshot.files:
file.compression_level = "signature"

# If we have budget left, upgrade files based on risk
# We iterate sorted_files (highest risk first)

# Upgrades: signature -> skeleton -> full

# Pass 1: Upgrade to Skeleton
for file in sorted_files:
costs = file_costs[file.path]
cost_increase = costs["skeleton"] - costs["signature"]

if current_total_tokens + cost_increase <= self.token_budget:
file.compression_level = "skeleton"
current_total_tokens += cost_increase
else:
# This file's upgrade does not fit. A smaller, lower-risk file might still fit
# (this is essentially a knapsack problem), but for now we keep the simple
# greedy rule: iterate by risk and upgrade only when the increase fits.
pass

# Pass 2: Upgrade to Full
for file in sorted_files:
if file.compression_level == "skeleton":
costs = file_costs[file.path]
cost_increase = costs["full"] - costs["skeleton"]

if current_total_tokens + cost_increase <= self.token_budget:
file.compression_level = "full"
current_total_tokens += cost_increase

return snapshot

def select_strategy(self, file_risk: float, is_focal_file: bool) -> str:
"""
Recursively traverses the AST and removes non-essential nodes,
like the bodies of functions.
Determines the ideal strategy based on risk, ignoring budget.
Used as a heuristic or upper bound.
"""
if not isinstance(node, ASTNode):
return

# For function nodes, keep the signature but remove the body
if node.node_type == "function":
node.children = [] # A simple way to trim the function body
return

for child in node.children:
self._traverse_and_trim(child)


def calculate_compression_ratio(
self, original: ProjectSnapshot, compressed: ProjectSnapshot
) -> float:
"""Calculates the compression ratio."""
original_size = len(json.dumps(original.model_dump(mode='json')))
compressed_size = len(json.dumps(compressed.model_dump(mode='json')))
if original_size == 0:
return 0.0
return (original_size - compressed_size) / original_size
if is_focal_file or file_risk >= 0.7: # High risk
return "full"
elif file_risk >= 0.3: # Medium risk
return "skeleton"
else:
return "signature"
1 change: 1 addition & 0 deletions codesage/snapshot/models.py
@@ -101,6 +101,7 @@ class FileSnapshot(BaseModel):
symbols: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of symbols defined in the file.")
risk: Optional[FileRisk] = Field(None, description="Risk assessment for the file.")
issues: List[Issue] = Field(default_factory=list, description="A list of issues identified in the file.")
compression_level: Literal["full", "skeleton", "signature"] = Field("full", description="The compression level applied to the file.")

# Old fields for compatibility
hash: Optional[str] = Field(None, description="The SHA256 hash of the file content.")
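
The new field defaults to "full" and is constrained to the three levels via `Literal`. A minimal sketch of that behaviour; FileSnapshot's other required fields are outside this diff, so `model_construct` is used here purely for illustration, bypassing validation.

```python
# Minimal sketch of the new compression_level field; FileSnapshot's remaining
# required fields are not shown in this diff, so validation is bypassed.
from codesage.snapshot.models import FileSnapshot

f = FileSnapshot.model_construct(path="codesage/llm/context_builder.py")
print(f.compression_level)        # "full" -- the declared default

f.compression_level = "skeleton"  # allowed values: "full", "skeleton", "signature"
```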