# Adapted from MCTS-AHD: https://github.com/zz1358m/MCTS-AHD-master/blob/main/source/mcts_ahd.py
# Licensed under the MIT License (see THIRD-PARTY-LICENSES.txt)
import copy
import heapq
import logging
from dataclasses import dataclass, field
import numpy as np
from llamda.ga.base import GeneticAlgorithm
from llamda.ga.mcts.mcts_prompts import MCTSOperator
from llamda.ga.mcts.mcts import MCTS, MCTSNode
from llamda.evaluate import Evaluator
from llamda.ga.mcts.evolution_interface import MCTSIndividual, InterfaceEC
from llamda.ga.utils import population_checkpoint
from llamda.llm_client.base import BaseClient
from llamda.problem import EohProblem
# Logger registered under the name "llamda"; all log calls in this module use it.
logger = logging.getLogger("llamda")
@dataclass
class AHDConfig:
    """Settings for one MCTS-AHD run.

    Attributes:
        pop_size: Size of the elite set E (default 10).
        init_size: Number of initial nodes N_I (default 4).
        fe_max: Number of evaluations allowed (default 1000).
        operators: Names of the evolution operators.
        m: Value handed to ``InterfaceEC`` as ``m``. The default of 5 was a
            manual override in the original implementation.
        operator_weights: Default weights for the operators, matched to
            ``operators`` by position.
    """

    # MCTS configuration
    pop_size: int = 10
    init_size: int = 4
    fe_max: int = 1000
    operators: list[str] = field(
        default_factory=lambda: ["e1", "e2", "m1", "m2", "s1"],
    )
    m: int = 5
    operator_weights: list[int] = field(
        default_factory=lambda: [0, 1, 2, 2, 1],
    )
class MCTS_AHD(GeneticAlgorithm[AHDConfig, EohProblem]):
    """MCTS-AHD search over LLM-generated heuristics.

    ``run`` first attaches one ``I1`` individual and then ``E1`` individuals
    to the tree root (``config.init_size`` controls how many), and afterwards
    repeatedly descends the tree by UCT score and calls ``expand`` with the
    configured operators until ``config.fe_max`` evaluations have been used.
    An elite population (``nodes_set``) is maintained next to the tree.
    """

    def __init__(
        self,
        config: AHDConfig,
        problem: EohProblem,
        evaluator: Evaluator,
        llm_client: BaseClient,
        output_dir: str,
    ) -> None:
        """Forward all collaborators to the base class and reset the counter."""
        super().__init__(
            config=config,
            problem=problem,
            evaluator=evaluator,
            llm_client=llm_client,
            output_dir=output_dir,
        )
        # Evaluations consumed so far; ``run`` stops once this reaches
        # ``config.fe_max``.
        self.eval_times = 0

    def _logging_context(self) -> dict:
        """Return the static fields attached as ``extra`` to log records."""
        return {
            "method": "MCTS-AHD",
            "problem_name": self.problem.name,
            "pop_size": self.config.pop_size,
            "init_size": self.config.init_size,
            "fe_max": self.config.fe_max,
        }

    def add2pop(
        self, population: list[MCTSIndividual], offspring: MCTSIndividual
    ) -> None:
        """Append ``offspring`` to ``population`` in place.

        A warning is logged for every existing individual whose ``algorithm``
        equals the offspring's, but the offspring is appended regardless.
        """
        for ind in population:
            if ind.algorithm == offspring.algorithm:
                # TODO: no actual retry logic implemented
                logger.warning("duplicated result, retrying ... ")
        population.append(offspring)

    def expand(
        self,
        mcts: MCTS,
        cur_node: MCTSNode,
        nodes_set: list[MCTSIndividual],
        option: str,
        name: str,
    ) -> list[MCTSIndividual]:
        """Apply one evolution operator at ``cur_node`` and grow the tree.

        Args:
            mcts: Tree being searched; the children of its root feed "e1".
            cur_node: Node that becomes the parent of the new node.
            nodes_set: Current elite population.
            option: Operator name. "s1" and "e1" get special parent sets;
                any other value is converted with ``MCTSOperator(option)``.
            name: Label forwarded to ``evolve_algorithm``.

        Returns:
            The elite population. When an offspring with a finite objective
            is produced, ``nodes_set`` is appended to in place and a newly
            trimmed list is returned; otherwise ``nodes_set`` is returned
            unchanged.
        """
        if option == "s1":
            # Collect the individuals on the path from cur_node up to, but
            # excluding, the root. Only each raw_info is deep-copied: copying
            # a node would clone the whole tree through its parent/children
            # links.
            path_set: list[MCTSIndividual] = []
            node = cur_node
            while node.code != "Root":
                path_set.append(copy.deepcopy(node.raw_info))
                node = node.parent
            path_set = manage_population_s1(path_set, len(path_set))
            if len(path_set) == 1:
                return nodes_set
            self.eval_times, offsprings = self.interface_ec.evolve_algorithm(
                eval_times=self.eval_times,
                pop=path_set,
                node=cur_node.raw_info,
                operator=MCTSOperator.S1,
                name=name,
            )
        elif option == "e1":
            # One randomly chosen individual from the subtree of every root
            # child.
            e1_set = [
                copy.deepcopy(
                    children.subtree[
                        np.random.randint(len(children.subtree))
                    ].raw_info
                )
                for children in mcts.root.children
            ]
            self.eval_times, offsprings = self.interface_ec.evolve_algorithm(
                eval_times=self.eval_times,
                pop=e1_set,
                node=cur_node.raw_info,
                operator=MCTSOperator.E1,
                name=name,
            )
        else:
            self.eval_times, offsprings = self.interface_ec.evolve_algorithm(
                eval_times=self.eval_times,
                pop=nodes_set,
                node=cur_node.raw_info,
                operator=MCTSOperator(option),
                name=name,
            )

        if offsprings is None:
            logger.warning(f"Timeout emerge, no expanding with action {option}.")
            return nodes_set

        if option != "e1":
            logger.info(
                "Action",
                extra={
                    "action": option,
                    "father_obj": cur_node.raw_info.obj,
                    "now_obj": offsprings.obj,
                    "depth": cur_node.depth + 1,
                },
            )
        else:
            if self.interface_ec.check_duplicate_obj(
                mcts.root.children_info, offsprings.obj
            ):
                logger.info(
                    "Duplicated e1, no action, Father is Root",
                    extra={"abandon_obj": offsprings.obj},
                )
                return nodes_set
            else:
                logger.info(
                    "Action, Father is Root",
                    extra={"now_obj": offsprings.obj},
                )

        # NOTE(review): SOURCE lost its indentation; the node is attached only
        # for finite objectives here, matching the upstream MCTS-AHD layout as
        # recalled -- confirm against the repository.
        if offsprings.obj != float("inf"):
            # Logs duplicates, then appends the new offspring.
            self.add2pop(nodes_set, offsprings)
            size_act = min(len(nodes_set), self.config.pop_size)
            nodes_set = manage_population(nodes_set, size_act)
            nownode = MCTSNode(
                offsprings.algorithm,
                offsprings.code,
                offsprings.obj,
                parent=cur_node,
                depth=cur_node.depth + 1,
                visit=1,
                Q=-1 * offsprings.obj,
                raw_info=offsprings,
            )
            if option == "e1":
                nownode.subtree.append(nownode)
            cur_node.add_child(nownode)
            cur_node.children_info.append(offsprings)
            mcts.backpropagate(nownode)
        return nodes_set

    @staticmethod
    def _add_root_child(mcts: MCTS, offspring: MCTSIndividual) -> None:
        """Attach ``offspring`` as a depth-1 child of the tree root."""
        node = MCTSNode(
            offspring.algorithm,
            offspring.code,
            offspring.obj,
            parent=mcts.root,
            depth=1,
            visit=1,
            Q=-1 * offspring.obj,
            raw_info=offspring,
        )
        mcts.root.add_child(node)
        mcts.root.children_info.append(offspring)
        mcts.backpropagate(node)
        node.subtree.append(node)

    def run(self) -> tuple[str, str]:
        """Run the MCTS-AHD search until the evaluation budget is used up.

        Returns:
            A pair ``(code, filename)``: the ``code`` of the first individual
            of the final elite population (the lowest objective, because
            ``manage_population`` returns individuals in ascending order) and
            the value returned by the last ``population_checkpoint`` call.
        """
        logger.info("Starting MCTS-AHD evolution", extra=self._logging_context())
        self.interface_ec = InterfaceEC(
            m=self.config.m,
            problem=self.problem,
            evaluator=self.evaluator,
            llm_client=self.llm_client,
            output_dir=self.output_dir,
        )
        brothers: list[MCTSIndividual] = []
        mcts = MCTS("Root")
        n_op = len(self.config.operators)

        # First root child comes from the I1 operator.
        n_evals, brothers, offspring = self.interface_ec.get_algorithm(
            brothers, MCTSOperator.I1, name="initialization"
        )
        self.eval_times += n_evals
        brothers.append(offspring)
        self._add_root_child(mcts, offspring)
        logger.info(
            "Initial node created",
            extra={
                "objective": offspring.obj,
                "eval_times": self.eval_times,
                **self._logging_context(),
            },
        )

        # Remaining root children come from the E1 operator.
        for i in range(1, self.config.init_size):
            n_evals, brothers, offspring = self.interface_ec.get_algorithm(
                brothers, MCTSOperator.E1, name=f"e1_initialization_{i}"
            )
            self.eval_times += n_evals
            brothers.append(offspring)
            self._add_root_child(mcts, offspring)

        nodes_set = brothers
        size_act = min(len(nodes_set), self.config.pop_size)
        nodes_set = manage_population(nodes_set, size_act)
        logger.info(
            "Initialization completed",
            extra={
                "population_size": len(nodes_set),
                "eval_times": self.eval_times,
                **self._logging_context(),
            },
        )

        iteration = 0
        # Stays None when the budget is already exhausted by initialization.
        filename = None
        while self.eval_times < self.config.fe_max:
            logger.info(
                "MCTS-AHD iteration",
                extra={
                    "rank_list": mcts.rank_list,
                    "eval_times": self.eval_times,
                    **self._logging_context(),
                },
            )

            # Selection: descend by highest UCT score, expanding a visited
            # node on the way whenever visits ** alpha exceeds its child count.
            cur_node = mcts.root
            while len(cur_node.children) > 0 and cur_node.depth < mcts.max_depth:
                uct_scores = [
                    mcts.uct(node, max(1 - self.eval_times / self.config.fe_max, 0))
                    for node in cur_node.children
                ]
                selected_pair_idx = uct_scores.index(max(uct_scores))
                if int((cur_node.visits) ** mcts.alpha) > len(cur_node.children):
                    if cur_node == mcts.root:
                        op = "e1"
                        nodes_set = self.expand(
                            mcts,
                            cur_node,
                            nodes_set,
                            op,
                            f"iteration{iteration}_root_e1",
                        )
                    else:
                        # i = random.randint(1, n_op - 1)
                        i = 1
                        op = self.config.operators[i]
                        nodes_set = self.expand(
                            mcts,
                            cur_node,
                            nodes_set,
                            op,
                            f"iteration{iteration}_max_utc_e2",
                        )
                cur_node = cur_node.children[selected_pair_idx]

            # Expansion: apply each operator operator_weights[i] times at the
            # selected node.
            for i in range(n_op):
                op = self.config.operators[i]
                logger.info(
                    f"Applying operator [{i + 1} / {n_op}]",
                    extra={
                        "operator": op,
                        "eval_times": self.eval_times,
                        **self._logging_context(),
                    },
                )
                op_w = self.config.operator_weights[i]
                for j in range(op_w):
                    nodes_set = self.expand(
                        mcts,
                        cur_node,
                        nodes_set,
                        op,
                        f"iteration{iteration}_{op}_{j}",
                    )
                assert len(cur_node.children) == len(cur_node.children_info)

            filename = population_checkpoint(
                population=nodes_set,
                name=f"iteration_{iteration}_evals_{self.eval_times}",
                output_dir=self.output_dir,
            )
            iteration += 1

        if filename is None:
            # The main loop never ran, so no checkpoint exists yet. Write one
            # for the initial population instead of failing on an unbound
            # name at the return below.
            filename = population_checkpoint(
                population=nodes_set,
                name=f"initialization_evals_{self.eval_times}",
                output_dir=self.output_dir,
            )
        return nodes_set[0].code, filename
def manage_population(
    pop_input: list[MCTSIndividual], size: int
) -> list[MCTSIndividual]:
    """Return up to ``size`` individuals with distinct objectives, best first.

    Individuals whose ``obj`` is ``None`` are ignored. Among individuals that
    share an objective value only the first one encountered is kept. The
    result is ordered by ascending ``obj``.
    """
    scored = [candidate for candidate in pop_input if candidate.obj is not None]
    keep = min(size, len(scored))

    seen_objectives = []
    distinct: list[MCTSIndividual] = []
    for candidate in scored:
        if candidate.obj in seen_objectives:
            continue
        seen_objectives.append(candidate.obj)
        distinct.append(candidate)

    def objective_of(candidate):
        return candidate.obj

    return heapq.nsmallest(keep, distinct, key=objective_of)
def manage_population_s1(
    pop_input: list[MCTSIndividual], size: int
) -> list[MCTSIndividual]:
    """Return up to ``size`` individuals with distinct algorithms.

    Individuals whose ``obj`` is ``None`` are ignored. Among individuals that
    share an ``algorithm`` only the first one encountered is kept. The result
    is ordered by descending ``obj`` (largest objective first).
    """
    scored = [candidate for candidate in pop_input if candidate.obj is not None]
    keep = min(size, len(scored))

    seen_algorithms = []
    distinct: list[MCTSIndividual] = []
    for candidate in scored:
        if candidate.algorithm in seen_algorithms:
            continue
        seen_algorithms.append(candidate.algorithm)
        distinct.append(candidate)

    def objective_of(candidate):
        return candidate.obj

    return heapq.nlargest(keep, distinct, key=objective_of)