diff --git a/examples/function_minimization/config_meta_evolution.yaml b/examples/function_minimization/config_meta_evolution.yaml new file mode 100644 index 000000000..c0856ed4a --- /dev/null +++ b/examples/function_minimization/config_meta_evolution.yaml @@ -0,0 +1,45 @@ +# Configuration for testing prompt meta-evolution feature +max_iterations: 25 +checkpoint_interval: 5 +log_level: INFO + +# LLM configuration +llm: + primary_model: "gpt-4o-mini" + primary_model_weight: 1.0 + api_base: "https://api.openai.com/v1" + temperature: 0.7 + max_tokens: 16000 + timeout: 120 + +# Prompt configuration +prompt: + system_message: "You are an expert programmer specializing in optimization algorithms. Your task is to improve a function minimization algorithm to find the global minimum of a complex function with many local minima. The function is f(x, y) = sin(x) * cos(y) + sin(x*y) + (x^2 + y^2)/20. Focus on improving the search_algorithm function to reliably find the global minimum, escaping local minima that might trap simple algorithms." + +# Prompt meta-evolution - ENABLED for testing +prompt_meta_evolution: + enabled: true + archive_size: 20 + min_uses_for_evolution: 5 # Lower for testing + evolution_interval: 20 # Trigger at iteration 20 + exploration_rate: 0.2 + elite_fraction: 0.3 + +# Database configuration +database: + population_size: 50 + archive_size: 20 + num_islands: 3 + elite_selection_ratio: 0.2 + exploitation_ratio: 0.7 + similarity_threshold: 0.99 + +# Evaluator configuration +evaluator: + timeout: 60 + cascade_thresholds: [1.3] + parallel_evaluations: 3 + +# Evolution settings +diff_based_evolution: true +max_code_length: 20000 diff --git a/openevolve/config.py b/openevolve/config.py index bef193da2..b71f943eb 100644 --- a/openevolve/config.py +++ b/openevolve/config.py @@ -397,6 +397,56 @@ class EvolutionTraceConfig: compress: bool = False +@dataclass +class PromptMetaEvolutionConfig: + """Configuration for meta-evolution of prompt templates. + + When enabled, OpenEvolve maintains an archive of prompt templates, + tracks their success rates, and evolves them over time to improve + mutation quality. 
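+
+    A minimal YAML sketch enabling the feature (illustrative; mirrors the
+    example config shipped in examples/function_minimization/config_meta_evolution.yaml):
+
+        prompt_meta_evolution:
+          enabled: true
+          archive_size: 20
+          min_uses_for_evolution: 5
+          evolution_interval: 20
+          exploration_rate: 0.2
+          elite_fraction: 0.3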
+ """ + + # Master switch + enabled: bool = False + + # Archive settings + archive_size: int = 20 # Max templates to keep in archive + + # Evolution triggers + min_uses_for_evolution: int = 10 # Min uses before template can be evolved + evolution_interval: int = 20 # Trigger evolution every N iterations + + # Sampling behavior + exploration_rate: float = 0.2 # Probability of sampling random template + elite_fraction: float = 0.3 # Fraction of top templates protected from pruning + + # Scoring weights (must sum to 1.0) + # score = w_success * success_rate + w_improvement * improvement_rate + w_fitness * normalized_fitness_delta + score_weight_success: float = 0.3 # Weight for success rate (mutations accepted) + score_weight_improvement: float = 0.4 # Weight for improvement rate (fitness increased) + score_weight_fitness_delta: float = 0.3 # Weight for avg fitness delta magnitude + + # Scoring parameters + score_min_uses: int = 5 # Min uses before score is calculated (else neutral prior) + score_neutral_prior: float = 0.5 # Score returned when uses < min_uses + + def __post_init__(self): + """Validate configuration after initialization.""" + weight_sum = ( + self.score_weight_success + + self.score_weight_improvement + + self.score_weight_fitness_delta + ) + tolerance = 1e-6 + if abs(weight_sum - 1.0) > tolerance: + raise ValueError( + f"Scoring weights must sum to 1.0, got {weight_sum:.6f} " + f"(success={self.score_weight_success}, " + f"improvement={self.score_weight_improvement}, " + f"fitness_delta={self.score_weight_fitness_delta})" + ) + + @dataclass class Config: """Master configuration for OpenEvolve""" @@ -416,6 +466,9 @@ class Config: database: DatabaseConfig = field(default_factory=DatabaseConfig) evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig) evolution_trace: EvolutionTraceConfig = field(default_factory=EvolutionTraceConfig) + prompt_meta_evolution: PromptMetaEvolutionConfig = field( + default_factory=PromptMetaEvolutionConfig + ) # Evolution settings diff_based_evolution: bool = True diff --git a/openevolve/controller.py b/openevolve/controller.py index 01ffec73c..161556b4d 100644 --- a/openevolve/controller.py +++ b/openevolve/controller.py @@ -10,6 +10,7 @@ import time import uuid from pathlib import Path +from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Union from openevolve.config import Config, load_config @@ -18,6 +19,7 @@ from openevolve.evolution_trace import EvolutionTracer from openevolve.llm.ensemble import LLMEnsemble from openevolve.process_parallel import ProcessParallelController +from openevolve.prompt.meta_evolution import PromptArchive, evolve_prompt from openevolve.prompt.sampler import PromptSampler from openevolve.utils.code_utils import extract_code_language from openevolve.utils.format_utils import format_improvement_safe, format_metrics_safe @@ -188,6 +190,25 @@ def __init__( # Initialize improved parallel processing components self.parallel_controller = None + # Initialize prompt meta-evolution if enabled + self.prompt_archive = None + if self.config.prompt_meta_evolution.enabled: + self.prompt_archive = PromptArchive( + max_size=self.config.prompt_meta_evolution.archive_size, + min_uses_for_evolution=self.config.prompt_meta_evolution.min_uses_for_evolution, + elite_fraction=self.config.prompt_meta_evolution.elite_fraction, + exploration_rate=self.config.prompt_meta_evolution.exploration_rate, + # Scoring configuration + 
score_weight_success=self.config.prompt_meta_evolution.score_weight_success, + score_weight_improvement=self.config.prompt_meta_evolution.score_weight_improvement, + score_weight_fitness_delta=self.config.prompt_meta_evolution.score_weight_fitness_delta, + score_min_uses=self.config.prompt_meta_evolution.score_min_uses, + score_neutral_prior=self.config.prompt_meta_evolution.score_neutral_prior, + ) + self._initialize_default_prompt_templates() + self.prompt_sampler.set_prompt_archive(self.prompt_archive) + logger.info("Prompt meta-evolution enabled") + def _setup_logging(self) -> None: """Set up logging""" log_dir = self.config.log_dir or os.path.join(self.output_dir, "logs") @@ -225,7 +246,7 @@ def _setup_manual_mode_queue(self) -> None: if not bool(getattr(self.config.llm, "manual_mode", False)): return - qdir = (Path(self.output_dir).expanduser().resolve() / "manual_tasks_queue") + qdir = Path(self.output_dir).expanduser().resolve() / "manual_tasks_queue" # Clear stale tasks from previous runs if qdir.exists(): @@ -246,6 +267,34 @@ def _load_initial_program(self) -> str: with open(self.initial_program_path, "r") as f: return f.read() + def _initialize_default_prompt_templates(self) -> None: + """Initialize the prompt archive with default templates from TemplateManager.""" + if self.prompt_archive is None: + return + + # Get default templates from the sampler's template manager + tm = self.prompt_sampler.template_manager + + # Get system template + system_template = self.config.prompt.system_message + if system_template in tm.templates: + system_template = tm.get_template(system_template) + + # Get user template (diff-based or full rewrite) + if self.config.diff_based_evolution: + user_template = tm.get_template("diff_user") + else: + user_template = tm.get_template("full_rewrite_user") + + # Add as the default template + self.prompt_archive.add_template( + system_template=system_template, + user_template=user_template, + is_default=True, + metadata={"source": "default"}, + ) + logger.info("Added default prompt template to archive") + async def run( self, iterations: Optional[int] = None, @@ -333,6 +382,7 @@ async def run( self.database, self.evolution_tracer, file_suffix=self.config.file_suffix, + prompt_archive=self.prompt_archive, ) # Set up signal handlers for graceful shutdown @@ -493,6 +543,20 @@ def _save_checkpoint(self, iteration: int) -> None: f"{format_metrics_safe(best_program.metrics)}" ) + # Save prompt archive if meta-evolution is enabled + if self.prompt_archive is not None: + import json + + prompt_archive_path = os.path.join(checkpoint_path, "prompt_archive.json") + with open(prompt_archive_path, "w") as f: + json.dump(self.prompt_archive.to_dict(), f, indent=2) + stats = self.prompt_archive.get_statistics() + logger.info( + f"Saved prompt archive (size={stats['size']}, " + f"total_uses={stats['total_uses']}, " + f"success_rate={stats['overall_success_rate']:.1%})" + ) + logger.info(f"Saved checkpoint at iteration {iteration} to {checkpoint_path}") def _load_checkpoint(self, checkpoint_path: str) -> None: @@ -504,6 +568,95 @@ def _load_checkpoint(self, checkpoint_path: str) -> None: self.database.load(checkpoint_path) logger.info(f"Checkpoint loaded successfully (iteration {self.database.last_iteration})") + # Load prompt archive if meta-evolution is enabled + if self.prompt_archive is not None: + import json + + prompt_archive_path = os.path.join(checkpoint_path, "prompt_archive.json") + if os.path.exists(prompt_archive_path): + with open(prompt_archive_path, "r") as 
f: + self.prompt_archive = PromptArchive.from_dict(json.load(f)) + # Re-inject into sampler and parallel controller + self.prompt_sampler.set_prompt_archive(self.prompt_archive) + stats = self.prompt_archive.get_statistics() + logger.info( + f"Loaded prompt archive (size={stats['size']}, " + f"total_uses={stats['total_uses']})" + ) + + def _maybe_evolve_prompts(self, iteration: int) -> None: + """ + Periodically evolve prompt templates if meta-evolution is enabled. + + Args: + iteration: Current iteration number + """ + if self.prompt_archive is None: + return + + # Only evolve at configured intervals + interval = self.config.prompt_meta_evolution.evolution_interval + if iteration == 0 or iteration % interval != 0: + return + + # Get templates ready for evolution + templates_to_evolve = self.prompt_archive.get_templates_for_evolution() + if not templates_to_evolve: + logger.debug("No templates ready for evolution yet") + return + + top_templates = self.prompt_archive.get_top_templates(5) + + # Evolve the top template that's ready for evolution + # Sort by score descending + templates_to_evolve.sort(key=lambda t: t.score, reverse=True) + template = templates_to_evolve[0] + + logger.info( + f"Evolving prompt template {template.id} " + f"(score={template.score:.3f}, uses={template.uses})" + ) + + # Create a sync wrapper for LLM generation that works within an async context. + # We run the async LLM call in a separate thread with its own event loop + # to avoid conflicts with the main event loop. + def llm_generate_sync(system: str, user: str) -> str: + def _run_async_in_thread(): + # asyncio.run() creates a new event loop, runs the coroutine, + # and cleans up the loop automatically + return asyncio.run( + self.llm_ensemble.generate_with_context( + system_message=system, + messages=[{"role": "user", "content": user}], + ) + ) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(_run_async_in_thread) + return future.result() + + # Evolve the template + result = evolve_prompt( + template, + top_templates, + llm_generate_sync, + score_fn=self.prompt_archive.get_template_score, + ) + if result: + new_system, new_user = result + new_template = self.prompt_archive.add_template( + system_template=new_system, + user_template=new_user, + parent_id=template.id, + metadata={"evolved_at_iteration": iteration}, + ) + logger.info( + f"Created evolved template {new_template.id} " + f"(generation {new_template.generation})" + ) + else: + logger.warning(f"Failed to evolve template {template.id}") + async def _run_evolution_with_checkpoints( self, start_iteration: int, max_iterations: int, target_score: Optional[float] ) -> None: @@ -511,9 +664,28 @@ async def _run_evolution_with_checkpoints( logger.info(f"Using island-based evolution with {self.config.database.num_islands} islands") self.database.log_island_status() - # Run the evolution process with checkpoint callback + # Track last prompt evolution for catching up between checkpoints + last_prompt_evolution = [start_iteration] # Use list for closure mutability + + # Create a combined callback that handles checkpoints and prompt evolution + def combined_callback(iteration: int) -> None: + self._save_checkpoint(iteration) + + # Trigger prompt evolution - catch up on any missed intervals + if self.prompt_archive is not None: + evolution_interval = self.config.prompt_meta_evolution.evolution_interval + # Find all evolution points between last_prompt_evolution and current iteration + next_evolution = ( + last_prompt_evolution[0] // 
evolution_interval + 1 + ) * evolution_interval + while next_evolution <= iteration: + self._maybe_evolve_prompts(next_evolution) + next_evolution += evolution_interval + last_prompt_evolution[0] = iteration + + # Run the evolution process with combined callback await self.parallel_controller.run_evolution( - start_iteration, max_iterations, target_score, checkpoint_callback=self._save_checkpoint + start_iteration, max_iterations, target_score, checkpoint_callback=combined_callback ) # Check if shutdown or early stopping was triggered diff --git a/openevolve/process_parallel.py b/openevolve/process_parallel.py index a2fd6592a..41ef135f4 100644 --- a/openevolve/process_parallel.py +++ b/openevolve/process_parallel.py @@ -14,9 +14,14 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING + from openevolve.config import Config from openevolve.database import Program, ProgramDatabase -from openevolve.utils.metrics_utils import safe_numeric_average +from openevolve.utils.metrics_utils import get_fitness_score, safe_numeric_average + +if TYPE_CHECKING: + from openevolve.prompt.meta_evolution import PromptArchive logger = logging.getLogger(__name__) @@ -33,6 +38,7 @@ class SerializableResult: artifacts: Optional[Dict[str, Any]] = None iteration: int = 0 error: Optional[str] = None + template_id: Optional[str] = None # For prompt meta-evolution tracking target_island: Optional[int] = None # Island where child should be placed @@ -132,9 +138,23 @@ def _lazy_init_worker_components(): def _run_iteration_worker( - iteration: int, db_snapshot: Dict[str, Any], parent_id: str, inspiration_ids: List[str] + iteration: int, + db_snapshot: Dict[str, Any], + parent_id: str, + inspiration_ids: List[str], + template_info: Optional[Dict[str, str]] = None, ) -> SerializableResult: - """Run a single iteration in a worker process""" + """Run a single iteration in a worker process + + Args: + iteration: The iteration number + db_snapshot: Snapshot of the database state + parent_id: ID of the parent program to evolve + inspiration_ids: IDs of programs to use as inspiration + template_info: Optional dict with 'template_id', 'system_template', 'user_template' + for prompt meta-evolution. If provided, uses these instead of + sampling from the worker's prompt sampler. 
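+
+    Example of the expected shape (illustrative only):
+
+        template_info = {
+            "template_id": "a1b2c3d4",
+            "system_template": "...",
+            "user_template": "...",
+        }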
+ """ try: # Lazy initialization _lazy_init_worker_components() @@ -191,6 +211,7 @@ def _run_iteration_worker( program_artifacts=parent_artifacts, feature_dimensions=db_snapshot.get("feature_dimensions", []), current_changes_description=parent_changes_desc, + meta_template_info=template_info, # Pass pre-sampled template for meta-evolution ) iteration_start = time.time() @@ -313,6 +334,8 @@ def _run_iteration_worker( iteration_time = time.time() - iteration_start + # Extract template_id for meta-evolution tracking (if present) + template_id = prompt.get("template_id") if prompt else None # Get target island from snapshot (where child should be placed) target_island = db_snapshot.get("sampling_island") @@ -324,6 +347,7 @@ def _run_iteration_worker( llm_response=llm_response, artifacts=artifacts, iteration=iteration, + template_id=template_id, target_island=target_island, ) @@ -342,12 +366,14 @@ def __init__( database: ProgramDatabase, evolution_tracer=None, file_suffix: str = ".py", + prompt_archive: Optional["PromptArchive"] = None, ): self.config = config self.evaluation_file = evaluation_file self.database = database self.evolution_tracer = evolution_tracer self.file_suffix = file_suffix + self.prompt_archive = prompt_archive self.executor: Optional[ProcessPoolExecutor] = None self.shutdown_event = mp.Event() @@ -555,6 +581,13 @@ async def run_evolution( if result.error: logger.warning(f"Iteration {completed_iteration} error: {result.error}") + # Record failed outcome for prompt meta-evolution + if self.prompt_archive is not None and result.template_id: + self.prompt_archive.record_outcome( + result.template_id, + accepted=False, + fitness_delta=0.0, + ) elif result.child_program_dict: # Reconstruct program from dict child_program = Program(**result.child_program_dict) @@ -572,6 +605,22 @@ async def run_evolution( if result.artifacts: self.database.store_artifacts(child_program.id, result.artifacts) + # Record outcome for prompt meta-evolution + if self.prompt_archive is not None and result.template_id: + parent_program = ( + self.database.get(result.parent_id) if result.parent_id else None + ) + if parent_program: + feature_dims = self.config.database.feature_dimensions + child_fitness = get_fitness_score(child_program.metrics, feature_dims) + parent_fitness = get_fitness_score(parent_program.metrics, feature_dims) + fitness_delta = child_fitness - parent_fitness + self.prompt_archive.record_outcome( + result.template_id, + accepted=True, + fitness_delta=fitness_delta, + ) + # Log evolution trace if self.evolution_tracer: # Retrieve parent program for trace logging @@ -811,6 +860,21 @@ def _submit_iteration( db_snapshot = self._create_database_snapshot() db_snapshot["sampling_island"] = target_island # Mark which island this is for + # Sample template from archive if meta-evolution is enabled + # This must happen in the main process since workers don't have the archive + template_info = None + if self.prompt_archive is not None: + sampled_template = self.prompt_archive.sample_template() + template_info = { + "template_id": sampled_template.id, + "system_template": sampled_template.system_template, + "user_template": sampled_template.user_template, + } + logger.debug( + f"Iteration {iteration}: sampled template {sampled_template.id} " + f"(score={sampled_template.score:.3f})" + ) + # Submit to process pool future = self.executor.submit( _run_iteration_worker, @@ -818,6 +882,7 @@ def _submit_iteration( db_snapshot, parent.id, [insp.id for insp in inspirations], + template_info, ) return 
future diff --git a/openevolve/prompt/meta_evolution.py b/openevolve/prompt/meta_evolution.py new file mode 100644 index 000000000..611dd184e --- /dev/null +++ b/openevolve/prompt/meta_evolution.py @@ -0,0 +1,522 @@ +""" +Meta-evolution of prompt templates for OpenEvolve. + +Inspired by the Darwin Gödel Machine paper, this module enables OpenEvolve +to evolve its own prompts based on empirical success rates. +""" + +import logging +import random +import re +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +@dataclass +class PromptTemplate: + """An evolvable prompt template with success tracking.""" + + id: str + system_template: str + user_template: str + # Success tracking + uses: int = 0 + successes: int = 0 # Number of times mutation was accepted + improvements: int = 0 # Number of times mutation improved fitness + total_fitness_delta: float = 0.0 # Sum of fitness changes + # Lineage + parent_id: Optional[str] = None + generation: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def success_rate(self) -> float: + """Fraction of uses that resulted in accepted mutations.""" + return self.successes / self.uses if self.uses > 0 else 0.0 + + @property + def improvement_rate(self) -> float: + """Fraction of uses that resulted in fitness improvement.""" + return self.improvements / self.uses if self.uses > 0 else 0.0 + + @property + def avg_fitness_delta(self) -> float: + """Average fitness change per use.""" + return self.total_fitness_delta / self.uses if self.uses > 0 else 0.0 + + def compute_score( + self, + weight_success: float = 0.3, + weight_improvement: float = 0.4, + weight_fitness_delta: float = 0.3, + min_uses: int = 5, + neutral_prior: float = 0.5, + ) -> float: + """ + Compute score for template quality with configurable weights. + + Args: + weight_success: Weight for success rate (mutations accepted) + weight_improvement: Weight for improvement rate (fitness increased) + weight_fitness_delta: Weight for avg fitness delta magnitude + min_uses: Minimum uses before score is calculated + neutral_prior: Score returned when uses < min_uses + + Returns: + Combined score between 0 and 1 + """ + if self.uses < min_uses: + return neutral_prior + # Weighted combination + return ( + weight_success * self.success_rate + + weight_improvement * self.improvement_rate + + weight_fitness_delta * min(1.0, self.avg_fitness_delta + 0.5) + ) + + @property + def score(self) -> float: + """ + Combined score for template quality using default weights. + For configurable weights, use compute_score() method. 
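+
+        Worked example with the default weights (mirrors test_score_calculation):
+        uses=10, successes=8, improvements=6, total_fitness_delta=0.5 gives
+        0.3 * 0.8 + 0.4 * 0.6 + 0.3 * min(1.0, 0.05 + 0.5) = 0.645.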
+ """ + return self.compute_score() + + def record_use( + self, + accepted: bool, + fitness_delta: float = 0.0, + ) -> None: + """Record the outcome of using this template.""" + self.uses += 1 + if accepted: + self.successes += 1 + if fitness_delta > 0: + self.improvements += 1 + self.total_fitness_delta += fitness_delta + + def to_dict(self) -> Dict[str, Any]: + """Serialize to dictionary.""" + return { + "id": self.id, + "system_template": self.system_template, + "user_template": self.user_template, + "uses": self.uses, + "successes": self.successes, + "improvements": self.improvements, + "total_fitness_delta": self.total_fitness_delta, + "parent_id": self.parent_id, + "generation": self.generation, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PromptTemplate": + """Deserialize from dictionary.""" + return cls( + id=data["id"], + system_template=data["system_template"], + user_template=data["user_template"], + uses=data.get("uses", 0), + successes=data.get("successes", 0), + improvements=data.get("improvements", 0), + total_fitness_delta=data.get("total_fitness_delta", 0.0), + parent_id=data.get("parent_id"), + generation=data.get("generation", 0), + metadata=data.get("metadata", {}), + ) + + +class PromptArchive: + """ + Archive of evolvable prompt templates. + + Maintains a population of templates, tracks their success rates, + and supports sampling and evolution. + """ + + def __init__( + self, + max_size: int = 20, + min_uses_for_evolution: int = 10, + elite_fraction: float = 0.3, + exploration_rate: float = 0.2, + # Scoring weights + score_weight_success: float = 0.3, + score_weight_improvement: float = 0.4, + score_weight_fitness_delta: float = 0.3, + score_min_uses: int = 5, + score_neutral_prior: float = 0.5, + ): + """ + Initialize the prompt archive. 
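+
+        For example (illustrative), PromptArchive(max_size=20, exploration_rate=0.2)
+        keeps at most 20 templates and, on roughly 20% of draws, samples a uniformly
+        random template instead of weighting by score.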
+ + Args: + max_size: Maximum number of templates to keep + min_uses_for_evolution: Minimum uses before a template can be evolved + elite_fraction: Fraction of top templates to preserve + exploration_rate: Probability of sampling a random/new template + score_weight_success: Weight for success rate in scoring + score_weight_improvement: Weight for improvement rate in scoring + score_weight_fitness_delta: Weight for fitness delta in scoring + score_min_uses: Minimum uses before calculating score + score_neutral_prior: Score for templates with insufficient uses + """ + self.max_size = max_size + self.min_uses_for_evolution = min_uses_for_evolution + self.elite_fraction = elite_fraction + self.exploration_rate = exploration_rate + + # Scoring configuration + self.score_weight_success = score_weight_success + self.score_weight_improvement = score_weight_improvement + self.score_weight_fitness_delta = score_weight_fitness_delta + self.score_min_uses = score_min_uses + self.score_neutral_prior = score_neutral_prior + + self.templates: Dict[str, PromptTemplate] = {} + self.default_template_id: Optional[str] = None + + def get_template_score(self, template: PromptTemplate) -> float: + """Get the score for a template using configured weights.""" + return template.compute_score( + weight_success=self.score_weight_success, + weight_improvement=self.score_weight_improvement, + weight_fitness_delta=self.score_weight_fitness_delta, + min_uses=self.score_min_uses, + neutral_prior=self.score_neutral_prior, + ) + + def add_template( + self, + system_template: str, + user_template: str, + parent_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + is_default: bool = False, + ) -> PromptTemplate: + """Add a new template to the archive.""" + template_id = str(uuid.uuid4())[:8] + + # Determine generation + generation = 0 + if parent_id and parent_id in self.templates: + generation = self.templates[parent_id].generation + 1 + + template = PromptTemplate( + id=template_id, + system_template=system_template, + user_template=user_template, + parent_id=parent_id, + generation=generation, + metadata=metadata or {}, + ) + + self.templates[template_id] = template + + # Set as default if first template or explicitly requested + if self.default_template_id is None or is_default: + self.default_template_id = template_id + + # Prune if over capacity + self._prune_if_needed() + + logger.info( + f"Added prompt template {template_id} (generation {generation}, " + f"archive size: {len(self.templates)})" + ) + + return template + + def get_template(self, template_id: str) -> Optional[PromptTemplate]: + """Get a template by ID.""" + return self.templates.get(template_id) + + def sample_template(self) -> PromptTemplate: + """ + Sample a template for use. + + Uses a mix of exploitation (high-scoring templates) and + exploration (less-used or random templates). + """ + if not self.templates: + raise ValueError("No templates in archive") + + # Exploration: occasionally pick a random template + if random.random() < self.exploration_rate: + template = random.choice(list(self.templates.values())) + logger.debug(f"Sampled template {template.id} (exploration)") + return template + + # Exploitation: prefer high-scoring templates + # Weight by score, with bonus for less-used templates + templates = list(self.templates.values()) + weights = [] + for t in templates: + # Exploration bonus for under-used templates: linearly decreases from 0.3 to 0 + # as uses increase from 0 to 20. 
This ensures new templates get enough trials + # before being judged solely on their score. + exploration_bonus = max(0, 1.0 - t.uses / 20) * 0.3 + weights.append(self.get_template_score(t) + exploration_bonus) + + # Normalize weights + total = sum(weights) + if total == 0: + template = random.choice(templates) + else: + weights = [w / total for w in weights] + template = random.choices(templates, weights=weights, k=1)[0] + + logger.debug( + f"Sampled template {template.id} (score={self.get_template_score(template):.3f}, " + f"uses={template.uses})" + ) + return template + + def record_outcome( + self, + template_id: str, + accepted: bool, + fitness_delta: float = 0.0, + ) -> None: + """Record the outcome of using a template.""" + if template_id not in self.templates: + logger.warning(f"Template {template_id} not found in archive") + return + + self.templates[template_id].record_use(accepted, fitness_delta) + logger.debug( + f"Template {template_id}: accepted={accepted}, " + f"fitness_delta={fitness_delta:.4f}, " + f"new_score={self.get_template_score(self.templates[template_id]):.3f}" + ) + + def get_templates_for_evolution(self) -> List[PromptTemplate]: + """Get templates that are ready for evolution (enough uses).""" + return [t for t in self.templates.values() if t.uses >= self.min_uses_for_evolution] + + def get_top_templates(self, n: int = 5) -> List[PromptTemplate]: + """Get the top N templates by score.""" + sorted_templates = sorted( + self.templates.values(), + key=lambda t: self.get_template_score(t), + reverse=True, + ) + return sorted_templates[:n] + + def get_statistics(self) -> Dict[str, Any]: + """Get archive statistics.""" + if not self.templates: + return {"size": 0} + + templates = list(self.templates.values()) + total_uses = sum(t.uses for t in templates) + total_successes = sum(t.successes for t in templates) + + return { + "size": len(templates), + "total_uses": total_uses, + "total_successes": total_successes, + "overall_success_rate": (total_successes / total_uses if total_uses > 0 else 0), + "max_generation": max(t.generation for t in templates), + "avg_score": sum(self.get_template_score(t) for t in templates) / len(templates), + "top_template_id": self.get_top_templates(1)[0].id if templates else None, + } + + def _prune_if_needed(self) -> None: + """Remove lowest-scoring templates if over capacity.""" + if len(self.templates) <= self.max_size: + return + + # Keep elite templates + num_elite = max(1, int(self.max_size * self.elite_fraction)) + sorted_templates = sorted( + self.templates.values(), + key=lambda t: self.get_template_score(t), + reverse=True, + ) + + # Templates to keep: elite + default + elite_ids = {t.id for t in sorted_templates[:num_elite]} + + # Also keep default template + if self.default_template_id: + elite_ids.add(self.default_template_id) + + # Remove lowest scoring non-elite templates + to_remove = [] + for t in reversed(sorted_templates): + if t.id not in elite_ids and len(self.templates) - len(to_remove) > self.max_size: + to_remove.append(t.id) + + for tid in to_remove: + del self.templates[tid] + logger.debug(f"Pruned template {tid} from archive") + + def to_dict(self) -> Dict[str, Any]: + """Serialize archive to dictionary.""" + return { + "max_size": self.max_size, + "min_uses_for_evolution": self.min_uses_for_evolution, + "elite_fraction": self.elite_fraction, + "exploration_rate": self.exploration_rate, + # Scoring configuration + "score_weight_success": self.score_weight_success, + "score_weight_improvement": 
self.score_weight_improvement,
+            "score_weight_fitness_delta": self.score_weight_fitness_delta,
+            "score_min_uses": self.score_min_uses,
+            "score_neutral_prior": self.score_neutral_prior,
+            "default_template_id": self.default_template_id,
+            "templates": {tid: t.to_dict() for tid, t in self.templates.items()},
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "PromptArchive":
+        """Deserialize archive from dictionary."""
+        archive = cls(
+            max_size=data.get("max_size", 20),
+            min_uses_for_evolution=data.get("min_uses_for_evolution", 10),
+            elite_fraction=data.get("elite_fraction", 0.3),
+            exploration_rate=data.get("exploration_rate", 0.2),
+            # Scoring configuration
+            score_weight_success=data.get("score_weight_success", 0.3),
+            score_weight_improvement=data.get("score_weight_improvement", 0.4),
+            score_weight_fitness_delta=data.get("score_weight_fitness_delta", 0.3),
+            score_min_uses=data.get("score_min_uses", 5),
+            score_neutral_prior=data.get("score_neutral_prior", 0.5),
+        )
+        archive.default_template_id = data.get("default_template_id")
+
+        for tid, tdata in data.get("templates", {}).items():
+            archive.templates[tid] = PromptTemplate.from_dict(tdata)
+
+        return archive
+
+
+# Prompt for evolving prompts (meta!)
+PROMPT_EVOLUTION_SYSTEM = """You are an expert at crafting prompts for code evolution systems.
+Your task is to improve prompts that guide an LLM to generate better code mutations.
+
+A good evolution prompt should:
+1. Clearly explain the task and expected output format
+2. Provide useful context without overwhelming detail
+3. Encourage creative yet targeted improvements
+4. Guide the LLM to explain its reasoning
+"""
+
+PROMPT_EVOLUTION_USER = """# Current Prompt Performance
+
+The following prompt template has been used {uses} times:
+- Success rate (mutations accepted): {success_rate:.1%}
+- Improvement rate (fitness increased): {improvement_rate:.1%}
+- Average fitness change: {avg_fitness_delta:+.4f}
+
+## Current System Template
+```
+{system_template}
+```
+
+## Current User Template
+```
+{user_template}
+```
+
+## Top Performing Templates for Reference
+
+{top_templates_section}
+
+# Task
+
+Create an improved version of this prompt that will lead to better mutation success rates.
+
+Focus on:
+1. Clearer instructions for the type of changes to make
+2. Better guidance on analyzing the current program
+3. More effective use of the evolution history
+4. Encouraging both exploitation (improving what works) and exploration (trying new approaches)
+
+Provide your improved templates in the following format:
+
+<system_template>
+Your improved system template here
+</system_template>
+
+<user_template>
+Your improved user template here
+</user_template>
+
+Explain your changes briefly after the templates.
+"""
+
+
+def evolve_prompt(
+    template: PromptTemplate,
+    top_templates: List[PromptTemplate],
+    llm_generate_fn: Callable[[str, str], str],
+    score_fn: Optional[Callable[[PromptTemplate], float]] = None,
+) -> Optional[Tuple[str, str]]:
+    """
+    Evolve a prompt template using an LLM.
+
+    Args:
+        template: The template to evolve
+        top_templates: Top performing templates for reference
+        llm_generate_fn: Function to call LLM (takes system, user, returns str)
+        score_fn: Optional function to compute template scores (defaults to template.score)
+
+    Returns:
+        Tuple of (new_system_template, new_user_template) or None if evolution failed
+    """
+    # Use provided score function or fall back to default
+    get_score = score_fn if score_fn is not None else (lambda t: t.score)
+
+    # Format top templates section
+    top_section = ""
+    for i, t in enumerate(top_templates[:3]):
+        if t.id == template.id:
+            continue
+        top_section += f"""### Template {i + 1} (score: {get_score(t):.3f}, success: {t.success_rate:.1%})
+System (truncated): {t.system_template[:200]}...
+User (truncated): {t.user_template[:300]}...
+
+"""
+
+    user_prompt = PROMPT_EVOLUTION_USER.format(
+        uses=template.uses,
+        success_rate=template.success_rate,
+        improvement_rate=template.improvement_rate,
+        avg_fitness_delta=template.avg_fitness_delta,
+        system_template=template.system_template,
+        user_template=template.user_template,
+        top_templates_section=top_section or "No other templates available yet.",
+    )
+
+    try:
+        response = llm_generate_fn(PROMPT_EVOLUTION_SYSTEM, user_prompt)
+
+        # Parse response
+        new_system = _extract_between_tags(response, "system_template")
+        new_user = _extract_between_tags(response, "user_template")
+
+        if new_system and new_user:
+            logger.info(f"Successfully evolved template {template.id}")
+            return new_system, new_user
+        else:
+            logger.warning("Failed to parse evolved template from response")
+            return None
+
+    except Exception as e:
+        logger.error(f"Error evolving template: {e}")
+        return None
+
+
+def _extract_between_tags(text: str, tag: str) -> Optional[str]:
+    """Extract content between XML-style tags."""
+    pattern = rf"<{tag}>\s*(.*?)\s*</{tag}>"
+    match = re.search(pattern, text, re.DOTALL)
+    if match:
+        return match.group(1).strip()
+    return None
diff --git a/openevolve/prompt/sampler.py b/openevolve/prompt/sampler.py
index 61a5b98ba..6a75513d7 100644
--- a/openevolve/prompt/sampler.py
+++ b/openevolve/prompt/sampler.py
@@ -4,7 +4,7 @@
 import logging
 import random
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from openevolve.config import PromptConfig
 from openevolve.prompt.templates import TemplateManager
@@ -15,13 +15,20 @@
     format_feature_coordinates,
 )

+if TYPE_CHECKING:
+    from openevolve.prompt.meta_evolution import PromptArchive
+
 logger = logging.getLogger(__name__)


 class PromptSampler:
     """Generates prompts for code evolution"""

-    def __init__(self, config: PromptConfig):
+    def __init__(
+        self,
+        config: PromptConfig,
+        prompt_archive: Optional["PromptArchive"] = None,
+    ):
         self.config = config
         self.template_manager = TemplateManager(custom_template_dir=config.template_dir)
@@ -29,6 +36,10 @@ def __init__(self, config: PromptConfig):
         self.system_template_override = None
         self.user_template_override = None

+        # Meta-evolution: optional prompt archive for template sampling
+        self.prompt_archive = prompt_archive
+        self._last_sampled_template_id: Optional[str] = None
+
         # Only log once to reduce duplication
         if not hasattr(logger, "_prompt_sampler_logged"):
             logger.info("Initialized prompt sampler")
@@ -48,6 +59,21 @@ def set_templates(
         self.user_template_override = user_template
         logger.info(f"Set custom templates: system={system_template}, user={user_template}")

+    def set_prompt_archive(self, archive:
Optional["PromptArchive"]) -> None: + """ + Set the prompt archive for meta-evolution. + + Args: + archive: PromptArchive instance or None to disable + """ + self.prompt_archive = archive + if archive is not None: + logger.info(f"Enabled prompt meta-evolution (archive size: {len(archive.templates)})") + + def get_last_template_id(self) -> Optional[str]: + """Get the ID of the last sampled template, or None if not using meta-evolution.""" + return self._last_sampled_template_id + def build_prompt( self, current_program: str = "", @@ -63,6 +89,7 @@ def build_prompt( program_artifacts: Optional[Dict[str, Union[str, bytes]]] = None, feature_dimensions: Optional[List[str]] = None, current_changes_description: Optional[str] = None, + meta_template_info: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> Dict[str, str]: """ @@ -80,41 +107,73 @@ def build_prompt( diff_based_evolution: Whether to use diff-based evolution (True) or full rewrites (False) template_key: Optional override for template key program_artifacts: Optional artifacts from program evaluation + meta_template_info: Optional dict with 'template_id', 'system_template', 'user_template' + for prompt meta-evolution. If provided, uses these templates + instead of sampling from the archive. **kwargs: Additional keys to replace in the user prompt Returns: - Dictionary with 'system' and 'user' keys + Dictionary with 'system', 'user', and optionally 'template_id' keys """ - # Select template based on evolution mode (with overrides) - if template_key: - # Use explicitly provided template key - user_template_key = template_key - elif self.user_template_override: - # Use the override set with set_templates - user_template_key = self.user_template_override + # Reset template tracking + self._last_sampled_template_id = None + + # Priority 1: Use pre-provided meta-evolution template (from worker processes) + if meta_template_info is not None: + self._last_sampled_template_id = meta_template_info.get("template_id") + system_message = meta_template_info.get("system_template") + user_template = meta_template_info.get("user_template") + logger.debug( + f"Using pre-sampled meta-evolution template {self._last_sampled_template_id}" + ) + # Priority 2: Sample from prompt archive (main process with archive) + elif self.prompt_archive is not None: + sampled_template = self.prompt_archive.sample_template() + self._last_sampled_template_id = sampled_template.id + system_message = sampled_template.system_template + user_template = sampled_template.user_template + logger.debug( + f"Using meta-evolved template {sampled_template.id} " + f"(score={sampled_template.score:.3f})" + ) else: - # Default behavior: diff-based vs full rewrite - user_template_key = "diff_user" if diff_based_evolution else "full_rewrite_user" + # Standard template selection (no meta-evolution) + # Select template based on evolution mode (with overrides) + if template_key: + # Use explicitly provided template key + user_template_key = template_key + elif self.user_template_override: + # Use the override set with set_templates + user_template_key = self.user_template_override + else: + # Default behavior: diff-based vs full rewrite + user_template_key = "diff_user" if diff_based_evolution else "full_rewrite_user" - # Get the template - user_template = self.template_manager.get_template(user_template_key) + # Get the template + user_template = self.template_manager.get_template(user_template_key) - # Use system template override if set - if self.system_template_override: - system_message = 
self.template_manager.get_template(self.system_template_override) - else: - system_message = self.config.system_message - # If system_message is a template name rather than content, get the template - if system_message in self.template_manager.templates: - system_message = self.template_manager.get_template(system_message) + # Use system template override if set + if self.system_template_override: + system_message = self.template_manager.get_template(self.system_template_override) + else: + system_message = self.config.system_message + # If system_message is a template name rather than content, get the template + if system_message in self.template_manager.templates: + system_message = self.template_manager.get_template(system_message) if self.config.programs_as_changes_description: if self.config.system_message_changes_description: - system_message_changes_description = self.config.system_message_changes_description.strip() + system_message_changes_description = ( + self.config.system_message_changes_description.strip() + ) else: - system_message_changes_description = self.template_manager.get_template("system_message_changes_description") + system_message_changes_description = self.template_manager.get_template( + "system_message_changes_description" + ) - system_message = self.template_manager.get_template("system_message_with_changes_description").format( + system_message = self.template_manager.get_template( + "system_message_with_changes_description" + ).format( system_message=system_message, system_message_changes_description=system_message_changes_description, ) @@ -161,16 +220,24 @@ def build_prompt( ) if self.config.programs_as_changes_description: - user_message = self.template_manager.get_template("user_message_with_changes_description").format( + user_message = self.template_manager.get_template( + "user_message_with_changes_description" + ).format( user_message=user_message, changes_description=current_changes_description.rstrip(), ) - return { + result = { "system": system_message, "user": user_message, } + # Include template_id if meta-evolution is active + if self._last_sampled_template_id is not None: + result["template_id"] = self._last_sampled_template_id + + return result + def _format_metrics(self, metrics: Dict[str, float]) -> str: """Format metrics for the prompt using safe formatting""" # Use safe formatting to handle mixed numeric and string values @@ -265,11 +332,8 @@ def _format_evolution_history( for i, program in enumerate(reversed(selected_previous)): attempt_number = len(previous_programs) - i - changes = ( - program.get("changes_description") - or program.get("metadata", {}).get( - "changes", self.template_manager.get_fragment("attempt_unknown_changes") - ) + changes = program.get("changes_description") or program.get("metadata", {}).get( + "changes", self.template_manager.get_fragment("attempt_unknown_changes") ) # Format performance metrics using safe formatting @@ -334,9 +398,7 @@ def _format_evolution_history( for i, program in enumerate(selected_top): use_changes = self.config.programs_as_changes_description program_code = ( - program.get("changes_description", "") - if use_changes - else program.get("code", "") + program.get("changes_description", "") if use_changes else program.get("code", "") ) if not program_code: program_code = "" if use_changes else "" @@ -351,11 +413,20 @@ def _format_evolution_history( for name, value in program.get("metrics", {}).items(): if isinstance(value, (int, float)): try: - 
key_features.append(self.template_manager.get_fragment("top_program_metrics_prefix") + f" {name} ({value:.4f})") + key_features.append( + self.template_manager.get_fragment("top_program_metrics_prefix") + + f" {name} ({value:.4f})" + ) except (ValueError, TypeError): - key_features.append(self.template_manager.get_fragment("top_program_metrics_prefix") + f" {name} ({value})") + key_features.append( + self.template_manager.get_fragment("top_program_metrics_prefix") + + f" {name} ({value})" + ) else: - key_features.append(self.template_manager.get_fragment("top_program_metrics_prefix") + f" {name} ({value})") + key_features.append( + self.template_manager.get_fragment("top_program_metrics_prefix") + + f" {name} ({value})" + ) key_features_str = ", ".join(key_features) @@ -385,7 +456,11 @@ def _format_evolution_history( # Use random sampling to get diverse programs diverse_programs = random.sample(remaining_programs, num_diverse) - diverse_programs_str += "\n\n## " + self.template_manager.get_fragment("diverse_programs_title") + "\n\n" + diverse_programs_str += ( + "\n\n## " + + self.template_manager.get_fragment("diverse_programs_title") + + "\n\n" + ) for i, program in enumerate(diverse_programs): use_changes = self.config.programs_as_changes_description @@ -404,7 +479,8 @@ def _format_evolution_history( key_features = program.get("key_features", []) if not key_features: key_features = [ - self.template_manager.get_fragment("diverse_program_metrics_prefix") + f" {name}" + self.template_manager.get_fragment("diverse_program_metrics_prefix") + + f" {name}" for name in list(program.get("metrics", {}).keys())[ :2 ] # Just first 2 metrics @@ -416,7 +492,9 @@ def _format_evolution_history( top_program_template.format( program_number=f"D{i + 1}", score=f"{score:.4f}", - language=("text" if self.config.programs_as_changes_description else language), + language=( + "text" if self.config.programs_as_changes_description else language + ), program_snippet=program_code, key_features=key_features_str, ) @@ -466,9 +544,7 @@ def _format_inspirations_section( for i, program in enumerate(inspirations): use_changes = self.config.programs_as_changes_description program_code = ( - program.get("changes_description", "") - if use_changes - else program.get("code", "") + program.get("changes_description", "") if use_changes else program.get("code", "") ) if not program_code: program_code = "" if use_changes else "" @@ -551,16 +627,24 @@ def _extract_unique_features(self, program: Dict[str, Any]) -> str: and self.config.include_changes_under_chars and len(changes) < self.config.include_changes_under_chars ): - features.append(self.template_manager.get_fragment("inspiration_changes_prefix").format(changes=changes)) + features.append( + self.template_manager.get_fragment("inspiration_changes_prefix").format( + changes=changes + ) + ) # Analyze metrics for standout characteristics metrics = program.get("metrics", {}) for metric_name, value in metrics.items(): if isinstance(value, (int, float)): if value >= 0.9: - features.append(f"{self.template_manager.get_fragment('inspiration_metrics_excellent').format(metric_name=metric_name, value=value)}") + features.append( + f"{self.template_manager.get_fragment('inspiration_metrics_excellent').format(metric_name=metric_name, value=value)}" + ) elif value <= 0.3: - features.append(f"{self.template_manager.get_fragment('inspiration_metrics_alternative').format(metric_name=metric_name)}") + features.append( + 
f"{self.template_manager.get_fragment('inspiration_metrics_alternative').format(metric_name=metric_name)}" + ) # Code-based features (simple heuristics) code = program.get("code", "") @@ -571,22 +655,32 @@ def _extract_unique_features(self, program: Dict[str, Any]) -> str: if "numpy" in code_lower or "np." in code_lower: features.append(self.template_manager.get_fragment("inspiration_code_with_numpy")) if "for" in code_lower and "while" in code_lower: - features.append(self.template_manager.get_fragment("inspiration_code_with_mixed_iteration")) + features.append( + self.template_manager.get_fragment("inspiration_code_with_mixed_iteration") + ) if ( self.config.concise_implementation_max_lines and len(code.split("\n")) <= self.config.concise_implementation_max_lines ): - features.append(self.template_manager.get_fragment("inspiration_code_with_concise_line")) + features.append( + self.template_manager.get_fragment("inspiration_code_with_concise_line") + ) elif ( self.config.comprehensive_implementation_min_lines and len(code.split("\n")) >= self.config.comprehensive_implementation_min_lines ): - features.append(self.template_manager.get_fragment("inspiration_code_with_comprehensive_line")) + features.append( + self.template_manager.get_fragment("inspiration_code_with_comprehensive_line") + ) # Default if no specific features found if not features: program_type = self._determine_program_type(program) - features.append(self.template_manager.get_fragment("inspiration_no_features_postfix").format(program_type=program_type)) + features.append( + self.template_manager.get_fragment("inspiration_no_features_postfix").format( + program_type=program_type + ) + ) # Use num_top_programs as limit for features (similar to how we limit programs) feature_limit = self.config.num_top_programs @@ -629,7 +723,12 @@ def _render_artifacts(self, artifacts: Dict[str, Union[str, bytes]]) -> str: sections.append(f"### {key}\n```\n{content}\n```") if sections: - return "## " + self.template_manager.get_fragment("artifact_title") + "\n\n" + "\n\n".join(sections) + return ( + "## " + + self.template_manager.get_fragment("artifact_title") + + "\n\n" + + "\n\n".join(sections) + ) else: return "" diff --git a/tests/test_prompt_meta_evolution.py b/tests/test_prompt_meta_evolution.py new file mode 100644 index 000000000..1444fb1d7 --- /dev/null +++ b/tests/test_prompt_meta_evolution.py @@ -0,0 +1,406 @@ +""" +Tests for prompt meta-evolution in openevolve.prompt.meta_evolution +""" + +import unittest + +from openevolve.prompt.meta_evolution import ( + PromptTemplate, + PromptArchive, + evolve_prompt, + _extract_between_tags, +) + + +class TestPromptTemplate(unittest.TestCase): + """Tests for PromptTemplate dataclass""" + + def test_initial_score(self): + """Test that new templates have neutral score""" + template = PromptTemplate( + id="test1", + system_template="You are a helpful assistant.", + user_template="Improve this code: {code}", + ) + # With 0 uses, should return 0.5 (neutral prior) + self.assertEqual(template.score, 0.5) + + def test_score_calculation(self): + """Test score calculation with usage data""" + template = PromptTemplate( + id="test1", + system_template="System", + user_template="User", + uses=10, + successes=8, # 80% success rate + improvements=6, # 60% improvement rate + total_fitness_delta=0.5, # avg delta = 0.05 + ) + + # success_rate = 0.8 + # improvement_rate = 0.6 + # avg_fitness_delta = 0.05, normalized = min(1.0, 0.05 + 0.5) = 0.55 + # score = 0.3 * 0.8 + 0.4 * 0.6 + 0.3 * 0.55 = 0.24 + 0.24 + 
0.165 = 0.645 + expected_score = 0.3 * 0.8 + 0.4 * 0.6 + 0.3 * 0.55 + self.assertAlmostEqual(template.score, expected_score, places=3) + + def test_record_use(self): + """Test recording usage outcomes""" + template = PromptTemplate( + id="test1", + system_template="System", + user_template="User", + ) + + # Record successful improvement + template.record_use(accepted=True, fitness_delta=0.1) + self.assertEqual(template.uses, 1) + self.assertEqual(template.successes, 1) + self.assertEqual(template.improvements, 1) + self.assertAlmostEqual(template.total_fitness_delta, 0.1) + + # Record accepted but no improvement + template.record_use(accepted=True, fitness_delta=-0.05) + self.assertEqual(template.uses, 2) + self.assertEqual(template.successes, 2) + self.assertEqual(template.improvements, 1) # No improvement + self.assertAlmostEqual(template.total_fitness_delta, 0.05) + + # Record rejection + template.record_use(accepted=False, fitness_delta=0.0) + self.assertEqual(template.uses, 3) + self.assertEqual(template.successes, 2) + self.assertEqual(template.improvements, 1) + + def test_serialization(self): + """Test to_dict and from_dict""" + template = PromptTemplate( + id="test1", + system_template="System message", + user_template="User message", + uses=5, + successes=3, + improvements=2, + total_fitness_delta=0.25, + parent_id="parent1", + generation=1, + metadata={"source": "test"}, + ) + + data = template.to_dict() + restored = PromptTemplate.from_dict(data) + + self.assertEqual(restored.id, template.id) + self.assertEqual(restored.system_template, template.system_template) + self.assertEqual(restored.user_template, template.user_template) + self.assertEqual(restored.uses, template.uses) + self.assertEqual(restored.successes, template.successes) + self.assertEqual(restored.improvements, template.improvements) + self.assertAlmostEqual(restored.total_fitness_delta, template.total_fitness_delta) + self.assertEqual(restored.parent_id, template.parent_id) + self.assertEqual(restored.generation, template.generation) + self.assertEqual(restored.metadata, template.metadata) + + +class TestPromptArchive(unittest.TestCase): + """Tests for PromptArchive""" + + def setUp(self): + """Set up test archive""" + self.archive = PromptArchive( + max_size=5, + min_uses_for_evolution=3, + elite_fraction=0.4, + exploration_rate=0.0, # Disable exploration for deterministic tests + ) + + def test_add_template(self): + """Test adding templates""" + template = self.archive.add_template( + system_template="System", + user_template="User", + ) + + self.assertIn(template.id, self.archive.templates) + self.assertEqual(self.archive.default_template_id, template.id) + self.assertEqual(len(self.archive.templates), 1) + + def test_add_child_template(self): + """Test adding child template with parent""" + parent = self.archive.add_template( + system_template="Parent system", + user_template="Parent user", + ) + child = self.archive.add_template( + system_template="Child system", + user_template="Child user", + parent_id=parent.id, + ) + + self.assertEqual(child.parent_id, parent.id) + self.assertEqual(child.generation, 1) + + def test_sample_template(self): + """Test template sampling""" + template = self.archive.add_template( + system_template="System", + user_template="User", + ) + + sampled = self.archive.sample_template() + self.assertEqual(sampled.id, template.id) + + def test_sample_prefers_higher_score(self): + """Test that sampling prefers higher-scoring templates""" + # Add low-scoring template + low = 
self.archive.add_template( + system_template="Low", + user_template="Low", + ) + low.uses = 10 + low.successes = 1 + low.improvements = 0 + + # Add high-scoring template + high = self.archive.add_template( + system_template="High", + user_template="High", + ) + high.uses = 10 + high.successes = 9 + high.improvements = 8 + high.total_fitness_delta = 1.0 + + # Sample multiple times and check distribution + high_count = 0 + for _ in range(100): + sampled = self.archive.sample_template() + if sampled.id == high.id: + high_count += 1 + + # High-scoring template should be sampled more often + self.assertGreater(high_count, 50) + + def test_record_outcome(self): + """Test recording outcomes""" + template = self.archive.add_template( + system_template="System", + user_template="User", + ) + + self.archive.record_outcome(template.id, accepted=True, fitness_delta=0.1) + + self.assertEqual(template.uses, 1) + self.assertEqual(template.successes, 1) + + def test_get_templates_for_evolution(self): + """Test getting templates ready for evolution""" + template1 = self.archive.add_template( + system_template="System1", + user_template="User1", + ) + template1.uses = 5 # Above min_uses_for_evolution (3) + + template2 = self.archive.add_template( + system_template="System2", + user_template="User2", + ) + template2.uses = 2 # Below threshold + + ready = self.archive.get_templates_for_evolution() + self.assertEqual(len(ready), 1) + self.assertEqual(ready[0].id, template1.id) + + def test_pruning(self): + """Test that archive prunes when over capacity""" + # Add 6 templates (max_size is 5) + for i in range(6): + t = self.archive.add_template( + system_template=f"System{i}", + user_template=f"User{i}", + ) + t.uses = 10 + t.successes = i # Different scores + + # Should have pruned to max_size + self.assertEqual(len(self.archive.templates), 5) + + def test_serialization(self): + """Test archive serialization""" + t1 = self.archive.add_template( + system_template="System1", + user_template="User1", + ) + t1.uses = 5 + t1.successes = 3 + + t2 = self.archive.add_template( + system_template="System2", + user_template="User2", + parent_id=t1.id, + ) + + data = self.archive.to_dict() + restored = PromptArchive.from_dict(data) + + self.assertEqual(len(restored.templates), 2) + self.assertEqual(restored.default_template_id, self.archive.default_template_id) + self.assertEqual(restored.templates[t1.id].uses, 5) + self.assertEqual(restored.templates[t2.id].parent_id, t1.id) + + def test_serialization_with_scoring_config(self): + """Test that scoring config is preserved during serialization""" + # Create archive with custom scoring config + archive = PromptArchive( + max_size=10, + score_weight_success=0.2, + score_weight_improvement=0.5, + score_weight_fitness_delta=0.3, + score_min_uses=10, + score_neutral_prior=0.6, + ) + archive.add_template(system_template="Test", user_template="Test") + + # Serialize and restore + data = archive.to_dict() + restored = PromptArchive.from_dict(data) + + # Verify scoring config is preserved + self.assertEqual(restored.score_weight_success, 0.2) + self.assertEqual(restored.score_weight_improvement, 0.5) + self.assertEqual(restored.score_weight_fitness_delta, 0.3) + self.assertEqual(restored.score_min_uses, 10) + self.assertEqual(restored.score_neutral_prior, 0.6) + + def test_get_statistics(self): + """Test archive statistics""" + t1 = self.archive.add_template( + system_template="System1", + user_template="User1", + ) + t1.uses = 10 + t1.successes = 8 + + t2 = 
self.archive.add_template(
+            system_template="System2",
+            user_template="User2",
+            parent_id=t1.id,
+        )
+        t2.uses = 5
+        t2.successes = 2
+
+        stats = self.archive.get_statistics()
+
+        self.assertEqual(stats["size"], 2)
+        self.assertEqual(stats["total_uses"], 15)
+        self.assertEqual(stats["total_successes"], 10)
+        self.assertAlmostEqual(stats["overall_success_rate"], 10 / 15)
+        self.assertEqual(stats["max_generation"], 1)
+
+
+class TestExtractBetweenTags(unittest.TestCase):
+    """Tests for tag extraction helper"""
+
+    def test_extract_simple(self):
+        """Test simple tag extraction"""
+        text = "<tag>content</tag>"
+        result = _extract_between_tags(text, "tag")
+        self.assertEqual(result, "content")
+
+    def test_extract_with_whitespace(self):
+        """Test extraction with whitespace"""
+        text = "<tag> content with spaces </tag>"
+        result = _extract_between_tags(text, "tag")
+        self.assertEqual(result, "content with spaces")
+
+    def test_extract_multiline(self):
+        """Test multiline extraction"""
+        text = """<template>
+line 1
+line 2
+</template>"""
+        result = _extract_between_tags(text, "template")
+        self.assertEqual(result, "line 1\nline 2")
+
+    def test_extract_not_found(self):
+        """Test extraction when tag not found"""
+        text = "no tags here"
+        result = _extract_between_tags(text, "tag")
+        self.assertIsNone(result)
+
+
+class TestEvolvePrompt(unittest.TestCase):
+    """Tests for evolve_prompt function"""
+
+    def test_evolve_prompt_success(self):
+        """Test successful prompt evolution"""
+        template = PromptTemplate(
+            id="test1",
+            system_template="Old system",
+            user_template="Old user",
+            uses=10,
+            successes=5,
+            improvements=3,
+            total_fitness_delta=0.2,
+        )
+
+        # Mock LLM that returns valid evolved templates
+        def mock_llm(system: str, user: str) -> str:
+            return """
+Here's an improved version:
+
+<system_template>
+New improved system template
+</system_template>
+
+<user_template>
+New improved user template
+</user_template>
+
+I made these changes because...
+"""
+
+        result = evolve_prompt(template, [], mock_llm)
+
+        self.assertIsNotNone(result)
+        new_system, new_user = result
+        self.assertEqual(new_system, "New improved system template")
+        self.assertEqual(new_user, "New improved user template")
+
+    def test_evolve_prompt_failure(self):
+        """Test prompt evolution when LLM returns invalid format"""
+        template = PromptTemplate(
+            id="test1",
+            system_template="Old system",
+            user_template="Old user",
+            uses=10,
+        )
+
+        # Mock LLM that returns invalid format
+        def mock_llm(system: str, user: str) -> str:
+            return "This response doesn't have the expected tags"
+
+        result = evolve_prompt(template, [], mock_llm)
+
+        self.assertIsNone(result)
+
+    def test_evolve_prompt_exception(self):
+        """Test prompt evolution when LLM raises exception"""
+        template = PromptTemplate(
+            id="test1",
+            system_template="Old system",
+            user_template="Old user",
+            uses=10,
+        )
+
+        # Mock LLM that raises exception
+        def mock_llm(system: str, user: str) -> str:
+            raise RuntimeError("LLM error")
+
+        result = evolve_prompt(template, [], mock_llm)
+
+        self.assertIsNone(result)
+
+
+if __name__ == "__main__":
+    unittest.main()
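
Illustrative usage sketch of the PromptArchive API introduced above; my_llm_call stands in for a hypothetical (system, user) -> str callable:

    from openevolve.prompt.meta_evolution import PromptArchive, evolve_prompt

    archive = PromptArchive(max_size=20, min_uses_for_evolution=10)
    archive.add_template(
        system_template="You are an expert programmer.",
        user_template="Improve this program:\n{current_program}",
        is_default=True,
    )

    # Per iteration: sample a template, build the mutation prompt with it,
    # then feed the outcome back into the archive.
    template = archive.sample_template()
    # ... run one evolution iteration using template.system_template / template.user_template ...
    archive.record_outcome(template.id, accepted=True, fitness_delta=0.02)

    # Periodically: evolve templates that have accumulated enough uses.
    for candidate in archive.get_templates_for_evolution():
        evolved = evolve_prompt(
            candidate,
            archive.get_top_templates(5),
            my_llm_call,
            score_fn=archive.get_template_score,
        )
        if evolved:
            new_system, new_user = evolved
            archive.add_template(new_system, new_user, parent_id=candidate.id)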