Source code for llamda.llm_client.base

# Adapted from ReEvo: https://github.com/ai4co/reevo/blob/main/utils/llm_client/base.py
# Licensed under the MIT License (see THIRD-PARTY-LICENSES.txt)

import os
import time
import logging
import concurrent.futures
from dataclasses import dataclass


logger = logging.getLogger(__name__)


@dataclass
class BaseLLMClientConfig:
    """Settings shared by every LLM client.

    Instances are plain dataclasses; ``BaseClient.__init__`` copies both
    fields onto the client.
    """

    # Model identifier; stored on the client as ``self.model``.
    model: str
    # Sampling temperature used when a call does not pass its own value.
    temperature: float = 1.0
[docs] class BaseClient(object): def __init__( self, config: BaseLLMClientConfig, ) -> None: self.model = config.model self.temperature = config.temperature def _chat_completion_api( self, messages: list[dict], temperature: float, n: int = 1 ) -> list: raise NotImplementedError
[docs] def chat_completion( self, n: int, messages: list[dict], temperature: float | None = None ) -> list: """ Generate n responses using OpenAI Chat Completions API """ temperature = temperature or self.temperature for attempt in range(3): try: return self._chat_completion_api(messages, temperature, n) except Exception as e: if attempt == 2: # Last attempt logger.error("All 3 attempts failed") raise logger.warning(f"Attempt {attempt + 1} failed with error: {e}") time.sleep(1)
[docs] def multi_chat_completion( self, messages_list: list[list[dict]], n: int = 1, temperature: float | None = None, ) -> list[str]: """ Generate multiple chat completions in parallel. param: n: number of responses to generate for each message in messages_list """ # If messages_list is not a list of list (i.e., only one conversation), # convert it to a list of list assert isinstance(messages_list, list), "messages_list should be a list." try: if not isinstance(messages_list[0], list): messages_list = [messages_list] except Exception as e: logger.error(f"Error processing messages_list: {e}") raise IndexError("Something is wrong.") if len(messages_list) > 1: assert n == 1, "Currently, only n=1 is supported for multi-chat completion." num_workers = os.cpu_count() if "gpt" not in self.model: # Transform messages if n > 1 messages_list *= n n = 1 num_workers = 2 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: args = [ dict(n=n, messages=messages, temperature=temperature) for messages in messages_list ] choices = executor.map(lambda p: self.chat_completion(**p), args) contents: list[str] = [] for choice in choices: for c in choice: contents.append(c.message.content) return contents