Source code for llamda.ga.reevo.reevo

# Adapted from ReEvo: https://github.com/ai4co/reevo/blob/main/reevo.py
# Licensed under the MIT License (see THIRD-PARTY-LICENSES.txt)

import json
from pathlib import Path
from typing import Optional
import logging
import numpy as np
import os
from dataclasses import dataclass

from llamda.ga.base import GeneticAlgorithm
from llamda.ga.reevo.reevo_prompts import ReEvoPrompts
from llamda.evaluate import Evaluator
from llamda.individual import Individual
from llamda.llm_client.base import BaseClient
from llamda.problem import Problem
from llamda.utils import print_hyperlink
from llamda.ga.utils import extract_code_from_generator, population_checkpoint

logger = logging.getLogger("llamda")


[docs] class ReEvoLLMClients: def __init__( self, generator_llm: BaseClient, reflector_llm: BaseClient | None = None, short_reflector_llm: BaseClient | None = None, long_reflector_llm: BaseClient | None = None, crossover_llm: BaseClient | None = None, mutation_llm: BaseClient | None = None, ) -> None: self.generator_llm = generator_llm self.reflector_llm = reflector_llm or generator_llm self.short_reflector_llm = short_reflector_llm or self.reflector_llm self.long_reflector_llm = long_reflector_llm or self.reflector_llm self.crossover_llm = crossover_llm or generator_llm self.mutation_llm = mutation_llm or generator_llm
[docs] @dataclass class ReEvoConfig: max_fe: int = 100 # maximum number of function evaluations pop_size: int = 10 # population size for GA init_pop_size: int = 30 # initial population size for GA mutation_rate: float = 0.5 # mutation rate for GA diversify_init_pop: bool = True # whether to diversify the initial population
[docs] class ReEvo(GeneticAlgorithm[ReEvoConfig, Problem]): def __init__( self, config: ReEvoConfig, problem: Problem, evaluator: Evaluator, output_dir: str, llm_client: BaseClient, reflector_llm: Optional[BaseClient] = None, # Support setting different LLMs for each of the four operators: # Short-term Reflection, Long-term Reflection, Crossover, Mutation short_reflector_llm: Optional[BaseClient] = None, long_reflector_llm: Optional[BaseClient] = None, crossover_llm: Optional[BaseClient] = None, mutation_llm: Optional[BaseClient] = None, ) -> None: super().__init__( config=config, problem=problem, evaluator=evaluator, llm_client=llm_client, output_dir=output_dir, ) self.llm_clients = ReEvoLLMClients( generator_llm=llm_client, reflector_llm=reflector_llm, short_reflector_llm=short_reflector_llm, long_reflector_llm=long_reflector_llm, crossover_llm=crossover_llm, mutation_llm=mutation_llm, ) self.evol = ReEvoPrompts(self.problem) self.evaluator = evaluator self.output_dir = output_dir os.makedirs(f"{self.output_dir}/thinking", exist_ok=True) self.mutation_rate = self.config.mutation_rate self.iteration = 0 self.elitist = None self.long_term_reflection_str = "" self.best_obj_overall = None self.best_code_overall = None self.best_code_path_overall = None def _logging_context(self) -> dict: return { "method": "ReEvo", "problem_name": self.problem.name, "max_fe": self.config.max_fe, "pop_size": self.config.pop_size, }
[docs] def init_population(self) -> None: # Evaluate the seed function, and set it as Elite code = extract_code_from_generator(self.problem.seed_func).replace("v1", "v2") seed_ind = Individual(name="seed", code=code) self.population = self.evaluator.batch_evaluate( [seed_ind], Path(self.output_dir) ) self.seed_ind = self.population[0] # If seed function is invalid, stop if not self.seed_ind.exec_success: raise RuntimeError( "Seed function is invalid. " f"Please check the stdout file in {os.getcwd()}." ) self.update_iter() # Generate responses messages = self.evol.get_seed_population_messages(self.long_term_reflection_str) responses = self.llm_clients.generator_llm.multi_chat_completion( [messages], self.config.init_pop_size, temperature=self.llm_clients.generator_llm.temperature + 0.3, ) population = [ self.response_to_individual(resp, messages, f"seed_population_{i}") for i, resp in enumerate(responses) ] # Run code and evaluate population population = self.evaluator.batch_evaluate(population, Path(self.output_dir)) # Update iteration self.population = population self.update_iter()
[docs] def response_to_individual( self, response: str, messages: list[dict], name: str ) -> Individual: """ Convert response to individual """ # Write response to file individual_dir = f"{self.output_dir}/individuals/{name}" os.makedirs(individual_dir, exist_ok=True) file_name = f"{individual_dir}/response.txt" with open(file_name, "w", encoding="utf-8") as file: file.writelines(response + "\n") with open(f"{individual_dir}/message.json", "w") as file: file.write(json.dumps(messages, indent=4)) code = extract_code_from_generator(response) individual = Individual(name=name, code=code) return individual
[docs] def update_iter(self) -> None: """ Update after each iteration """ population = self.population # Save population checkpoint population_checkpoint( population=population, name=f"iteration_{self.iteration}", output_dir=self.output_dir, ) objs = [individual.obj for individual in population] best_obj, best_sample_idx = min(objs), np.argmin(np.array(objs)) # update best overall if self.best_obj_overall is None or best_obj < self.best_obj_overall: self.best_obj_overall = best_obj self.best_code_overall = population[best_sample_idx].code self.best_code_path_overall = ( f"{self.output_dir}/{population[best_sample_idx].name}/code.py" ) # update elitist if self.elitist is None or best_obj < self.elitist.obj: self.elitist = population[best_sample_idx] logger.info( f"Iteration {self.iteration}: Elitist: {self.elitist.obj}", extra={ "iteration": self.iteration, "elitist_obj": self.elitist.obj, **self._logging_context(), }, ) best_path = self.best_code_path_overall.replace(".py", ".txt").replace( "code", "response" ) logger.info( f"Best obj: {self.best_obj_overall}, Best Code Path: {print_hyperlink(best_path, self.best_code_path_overall)}", extra={ "best_objective": self.best_obj_overall, "best_code_path": self.best_code_path_overall, **self._logging_context(), }, ) logger.info( f"Iteration {self.iteration} finished...", extra={ "iteration": self.iteration, "function_evals": self.evaluator.function_evals, **self._logging_context(), }, ) self.iteration += 1
[docs] def rank_select(self, population: list[Individual]) -> list[Individual] | None: """ Rank-based selection, select individuals with probability proportional to rank. """ if self.problem.type == "black_box": population = [ individual for individual in population if individual.exec_success and individual.obj < self.seed_ind.obj ] else: population = [ individual for individual in population if individual.exec_success ] if len(population) < 2: return None # Sort population by objective value population = sorted(population, key=lambda x: x.obj) ranks = [i for i in range(len(population))] probs = [1 / (rank + 1 + len(population)) for rank in ranks] # Normalize probabilities probs = [prob / sum(probs) for prob in probs] selected_population: list[Individual] = [] trial = 0 while len(selected_population) < 2 * self.config.pop_size: trial += 1 parents: list[Individual] = np.random.choice( np.array(population), size=2, replace=False, p=probs ).tolist() if parents[0].obj != parents[1].obj: selected_population.extend(parents) if trial > 1000: return None return selected_population
[docs] def random_select(self, population: list[Individual]) -> list[Individual] | None: """ Random selection, select individuals with equal probability. """ selected_population: list[Individual] = [] # Eliminate invalid individuals if self.problem.type == "black_box": population = [ individual for individual in population if individual.exec_success and individual.obj < self.seed_ind.obj ] else: population = [ individual for individual in population if individual.exec_success ] if len(population) < 2: return None trial = 0 while len(selected_population) < 2 * self.config.pop_size: trial += 1 parents: list[Individual] = np.random.choice( np.array(population), size=2, replace=False ).tolist() # If two parents have the same objective value, consider them as identical; otherwise, add them to the selected population if parents[0].obj != parents[1].obj: selected_population.extend(parents) if trial > 1000: return None return selected_population
[docs] def long_term_reflection(self, short_term_reflections: list[str]) -> None: """ Long-term reflection before mutation. """ messages = self.evol.get_long_term_reflection_messages( short_term_reflections, self.long_term_reflection_str ) self.long_term_reflection_str = ( self.llm_clients.long_reflector_llm.multi_chat_completion([messages])[0] ) # Write reflections to file file_name = f"{self.output_dir}/thinking/iteration{self.iteration}_short_term_reflections.txt" with open(file_name, "w") as file: file.writelines("\n".join(short_term_reflections) + "\n") file_name = f"{self.output_dir}/thinking/iteration{self.iteration}_long_term_reflection.txt" with open(file_name, "w") as file: file.writelines(self.long_term_reflection_str + "\n")
[docs] def evolve(self) -> tuple[str, str]: logger.info("Starting ReEvo evolution", extra=self._logging_context()) self.init_population() logger.info( "Initial population created", extra={ "population_size": len(self.population), **self._logging_context(), }, ) while self.evaluator.function_evals < self.config.max_fe: logger.debug( "Evolution iteration", extra={ "iteration": self.iteration, "function_evals": self.evaluator.function_evals, **self._logging_context(), }, ) # If all individuals are invalid, stop if all([not individual.exec_success for individual in self.population]): logger.error( "All individuals are invalid, stopping evolution", extra=self._logging_context(), ) raise RuntimeError( f"All individuals are invalid. Please check the stdout files in {os.getcwd()}." ) # Select population_to_select = ( self.population if (self.elitist is None or self.elitist in self.population) else [self.elitist] + self.population ) # add elitist to population for selection selected_population = self.random_select(population_to_select) if selected_population is None: logger.error("Selection failed", extra=self._logging_context()) raise RuntimeError("Selection failed. Please check the population.") logger.debug( f"Selected {len(selected_population)} individuals", extra={ "n_selected": len(selected_population), **self._logging_context(), }, ) # Short-term reflection messages_lst, worse_code_lst, better_code_lst = ( self.evol.get_short_term_reflection_messages(selected_population) ) short_term_reflections = ( self.llm_clients.short_reflector_llm.multi_chat_completion(messages_lst) ) logger.debug( "Short-term reflection complete", extra=self._logging_context() ) # Crossover crossover_messages = self.evol.get_crossover_messages( (short_term_reflections, worse_code_lst, better_code_lst) ) crossed_response_lst = self.llm_clients.crossover_llm.multi_chat_completion( crossover_messages ) crossed_population = [ self.response_to_individual( response, messages, f"iteration{self.iteration}_crossover_{i}" ) for i, (response, messages) in enumerate( zip(crossed_response_lst, crossover_messages) ) ] logger.debug( f"Crossover generated {len(crossed_population)} individuals", extra={ "n_crossover": len(crossed_population), **self._logging_context(), }, ) # Evaluate self.population = self.evaluator.batch_evaluate( crossed_population, Path(self.output_dir) ) # Update self.update_iter() # Long-term reflection self.long_term_reflection(short_term_reflections) logger.debug("Long-term reflection complete", extra=self._logging_context()) # Mutate mutation_messages = self.evol.get_mutation_messages( self.long_term_reflection_str, self.elitist ) mutated_response_lst = self.llm_clients.mutation_llm.multi_chat_completion( [mutation_messages], int(self.config.pop_size * self.config.mutation_rate), ) mutated_population = [ self.response_to_individual( response, mutation_messages, f"iteration{self.iteration}_mutation_{i}", ) for i, response in enumerate(mutated_response_lst) ] logger.debug( f"Mutation generated {len(mutated_population)} individuals", extra={ "n_mutation": len(mutated_population), **self._logging_context(), }, ) # Evaluate self.population.extend( self.evaluator.batch_evaluate(mutated_population, Path(self.output_dir)) ) # Update self.update_iter() logger.info( "ReEvo evolution completed", extra={ "best_objective": self.best_obj_overall, "function_evals": self.evaluator.function_evals, **self._logging_context(), }, ) return self.best_code_overall, self.best_code_path_overall