# Adapted from MCTS-AHD: https://github.com/zz1358m/MCTS-AHD-master/blob/main/source/mcts.py
# Licensed under the MIT License (see THIRD-PARTY-LICENSES.txt)
from __future__ import annotations
import math
from typing import Any
class MCTSNode:
    """A single node of the MCTS search tree.

    A node carries the candidate's algorithm description and code, its place
    in the tree (``parent``, ``depth``, ``children``), its search statistics
    (``visits`` and ``Q``) and a ``reward`` equal to the negated objective.
    """

    def __init__(
        self,
        algorithm: str,
        code: str,
        obj: float,
        depth: int = 0,
        parent: MCTSNode | None = None,
        visit: int = 0,
        raw_info: Any | None = None,
        Q: float = 0,
        is_root: bool = False,
    ) -> None:
        """Create a node.

        Args:
            algorithm: Description of the candidate held by this node.
            code: Source code of the candidate.
            obj: Objective value; stored negated as ``reward``.
            depth: Depth of the node in the tree.
            parent: Parent node, or ``None`` when there is none.
            visit: Initial visit count (stored as ``visits``).
            raw_info: Opaque extra payload, kept as given.
            Q: Initial value estimate.
            is_root: Whether this node is the search root.
        """
        # Candidate payload.
        self.algorithm, self.code, self.raw_info = algorithm, code, raw_info
        # Position in the tree; the lists are created fresh for every node.
        self.parent, self.depth, self.is_root = parent, depth, is_root
        self.children: list[MCTSNode] = []
        self.children_info: list[Any] = []
        self.subtree: list[MCTSNode] = []
        # Search statistics.
        self.visits, self.Q = visit, Q
        # Negated objective: a lower ``obj`` yields a higher ``reward``.
        self.reward = -obj

    def add_child(self, child_node: MCTSNode) -> None:
        """Append *child_node* to ``children``.

        Only this node's list changes; ``child_node.parent`` and ``depth`` are
        left as they were given at construction.
        """
        self.children += [child_node]

    def __repr__(self) -> str:
        q, n = self.Q, self.visits
        return f"MCTSNode(algorithm={self.algorithm}, Q={q:.2f}, visits={n})"
class MCTS:
    """Monte Carlo tree search state: the root node, Q-value bounds and logs.

    The running ``q_min``/``q_max`` bounds normalise the exploitation term of
    :meth:`uct`. ``rank_list`` keeps every distinct backed-up ``Q`` in sorted
    order. The ``critiques``/``refinements``/``rewards``/``selected_nodes``
    lists are log buffers that are not written by the methods of this class.
    """

    # Width cap consulted by ``is_fully_expanded``. ``math.inf`` means "no cap".
    # It can be overridden per instance via the constructor or by assigning
    # ``mcts.max_children`` directly.
    max_children: float = math.inf

    def __init__(self, root_answer: str, *, max_children: float = math.inf) -> None:
        """Create a search tree whose root holds *root_answer*.

        Args:
            root_answer: Used as both ``algorithm`` and ``code`` of the root.
            max_children: Child count at which a node counts as fully
                expanded (see ``is_fully_expanded``); defaults to no cap.
        """
        self.exploration_constant_0 = 0.1
        # alpha / max_depth are not read inside this class; presumably used by
        # the driver that owns the tree.
        self.alpha = 0.5
        self.max_depth = 10
        # Tolerance used by ``uct`` to detect a (near-)zero Q range.
        self.epsilon = 1e-10
        # With the default of 1, ``backpropagate`` copies the best child's Q.
        self.discount_factor = 1
        # Running Q bounds. q_max starts below q_min, so the range is only
        # meaningful once ``backpropagate`` has recorded at least one value.
        self.q_min: float = 0
        self.q_max: float = -10000
        self.rank_list: list[float] = []
        self.max_children = max_children
        self.root = MCTSNode(
            algorithm=root_answer, code=root_answer, depth=0, obj=0, is_root=True
        )
        # Logs
        self.critiques: list[str] = []
        self.refinements: list[str] = []
        self.rewards: list[float] = []
        self.selected_nodes: list[MCTSNode] = []

    @staticmethod
    def _is_root(node: MCTSNode) -> bool:
        """Return True if *node* is the search root.

        Honours the ``is_root`` flag set by the constructor as well as the
        legacy ``"Root"`` code sentinel (which only matches the real root
        when ``root_answer == "Root"``).
        """
        return bool(node.is_root or node.code == "Root")

    def backpropagate(self, node: MCTSNode) -> None:
        """Propagate *node*'s ``Q`` value up through all of its ancestors.

        The method first records ``node.Q`` in the deduplicated, sorted
        ``rank_list`` and updates ``q_min``/``q_max``. It then walks the
        parent chain. For each ancestor it:

        - blends the ancestor's ``Q`` with its best child's ``Q`` using
          ``discount_factor``;
        - increments the ancestor's ``visits``;
        - if the ancestor is a direct child of the root, appends *node* to
          the ancestor's ``subtree``.

        Assumes each ancestor has a non-empty ``children`` list; ``max``
        raises ``ValueError`` otherwise.
        """
        if node.Q not in self.rank_list:
            self.rank_list.append(node.Q)
            self.rank_list.sort()
        self.q_min = min(self.q_min, node.Q)
        self.q_max = max(self.q_max, node.Q)

        parent = node.parent
        while parent is not None:
            best_child_Q = max(child.Q for child in parent.children)
            parent.Q = (
                parent.Q * (1 - self.discount_factor)
                + best_child_Q * self.discount_factor
            )
            parent.visits += 1
            grandparent = parent.parent
            # The root has no parent, so the grandparent must be checked for
            # None before asking whether it is the root.
            if (
                not self._is_root(parent)
                and grandparent is not None
                and self._is_root(grandparent)
            ):
                parent.subtree.append(node)
            parent = grandparent

    def uct(self, node: MCTSNode, eval_remain: float) -> float:
        """Return the UCT score of *node*; higher means more worth selecting.

        The score is computed as::

            (Q - q_min) / (q_max - q_min)
            + exploration_constant_0 * eval_remain
              * sqrt(ln(parent.visits + 1) / visits)

        Side effect: ``self.exploration_constant`` is set to the scaled
        exploration weight.

        Args:
            node: Node to score. ``node.parent`` must not be ``None``.
            eval_remain: Multiplier on the exploration weight (presumably the
                remaining share of the evaluation budget; confirm with callers).

        Returns:
            The UCT score. Two cases are handled specially:

            - Unvisited nodes (``visits == 0``) score ``math.inf``, so they are
              tried first.
            - When ``q_max - q_min`` is within ``epsilon`` of zero (e.g. right
              after the first ``backpropagate``), the exploitation term is 0
              instead of dividing by zero.
        """
        self.exploration_constant = self.exploration_constant_0 * eval_remain
        if node.visits == 0:
            return math.inf
        q_range = self.q_max - self.q_min
        if abs(q_range) < self.epsilon:
            exploitation = 0.0
        else:
            exploitation = (node.Q - self.q_min) / q_range
        exploration = self.exploration_constant * math.sqrt(
            math.log(node.parent.visits + 1) / node.visits
        )
        return exploitation + exploration

    def is_fully_expanded(self, node: MCTSNode) -> bool:
        """Return True when *node* should not receive further children.

        This holds in any of three cases:

        - the node has at least ``max_children`` children;
        - some child already has a higher ``Q`` than the node;
        - the node is the root.
        """
        return (
            len(node.children) >= self.max_children
            or any(child.Q > node.Q for child in node.children)
            or self._is_root(node)
        )