From 8399a7160a4f7fa19151cdeff1c3109253c468fd Mon Sep 17 00:00:00 2001 From: Philipp Date: Wed, 14 Feb 2024 11:06:52 +0100 Subject: [PATCH] issue #47: add option to set max_token to be generated (#48) * issue #47: add option to set max_token to be generated * issue #47: rename max_token option to -l (limit, length) --- backends/alephalpha_api.py | 3 ++- backends/anthropic_api.py | 3 ++- backends/cohere_api.py | 1 + backends/huggingface_local_api.py | 12 ++++++------ backends/mistral_api.py | 5 ++--- backends/openai_api.py | 7 +++---- backends/openai_compatible_api.py | 5 ++--- clemgame/benchmark.py | 5 +++-- clemgame/clemgame.py | 6 ++++-- scripts/cli.py | 5 +++++ 10 files changed, 30 insertions(+), 22 deletions(-) diff --git a/backends/alephalpha_api.py b/backends/alephalpha_api.py index ec769a7ddc..1532d14cc5 100644 --- a/backends/alephalpha_api.py +++ b/backends/alephalpha_api.py @@ -22,6 +22,7 @@ def __init__(self): creds = backends.load_credentials(NAME) self.client = aleph_alpha_client.Client(creds[NAME]["api_key"]) self.temperature: float = -1. + self.max_tokens: int = 100 @retry(tries=3, delay=0, logger=logger) def generate_response(self, messages: List[Dict], model: str) -> Tuple[Any, Any, str]: @@ -58,7 +59,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[Any, Any, params = { "prompt": aleph_alpha_client.Prompt.from_text(prompt_text), - "maximum_tokens": 100, + "maximum_tokens": self.max_tokens, "stop_sequences": ['\n'], "temperature": self.temperature } diff --git a/backends/anthropic_api.py b/backends/anthropic_api.py index 3f38332ad2..87f32a6dd6 100644 --- a/backends/anthropic_api.py +++ b/backends/anthropic_api.py @@ -21,6 +21,7 @@ def __init__(self): creds = backends.load_credentials(NAME) self.client = anthropic.Anthropic(api_key=creds[NAME]["api_key"]) self.temperature: float = -1. + self.max_tokens: int = 100 @retry(tries=3, delay=0, logger=logger) def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]: @@ -50,7 +51,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, stop_sequences=[anthropic.HUMAN_PROMPT, '\n'], model=model, temperature=self.temperature, - max_tokens_to_sample=100 + max_tokens_to_sample=self.max_tokens ) response_text = completion.completion.strip() diff --git a/backends/cohere_api.py b/backends/cohere_api.py index e16c7826f6..3737413968 100644 --- a/backends/cohere_api.py +++ b/backends/cohere_api.py @@ -18,6 +18,7 @@ def __init__(self): creds = backends.load_credentials(NAME) self.client = cohere.Client(creds[NAME]["api_key"]) self.temperature: float = -1. + self.max_tokens: int = -1 # not applicable in this backend? @retry(tries=3, delay=0, logger=logger) def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]: diff --git a/backends/huggingface_local_api.py b/backends/huggingface_local_api.py index 8be10a0662..57c7693f1c 100644 --- a/backends/huggingface_local_api.py +++ b/backends/huggingface_local_api.py @@ -28,6 +28,7 @@ class HuggingfaceLocal(backends.Backend): def __init__(self): self.temperature: float = -1. + self.max_tokens: int = 100 # How many tokens to generate ('at most', but no stop sequence is defined). 
self.use_api_key: bool = False self.config_and_tokenizer_loaded: bool = False self.model_loaded: bool = False @@ -311,7 +312,7 @@ def check_context_limit(self, messages: List[Dict], model: str, return fits, tokens_used, tokens_left, self.context_size def generate_response(self, messages: List[Dict], model: str, - max_new_tokens: int = 100, return_full_text: bool = False, + return_full_text: bool = False, log_messages: bool = False) -> Tuple[Any, Any, str]: """ :param messages: for example @@ -322,7 +323,6 @@ def generate_response(self, messages: List[Dict], model: str, {"role": "user", "content": "Where was it played?"} ] :param model: model name - :param max_new_tokens: How many tokens to generate ('at most', but no stop sequence is defined). :param return_full_text: If True, whole input context is returned. :param log_messages: If True, raw and cleaned messages passed will be logged. :return: the continuation @@ -354,11 +354,11 @@ def generate_response(self, messages: List[Dict], model: str, prompt_tokens = prompt_tokens.to(self.device) prompt_text = self.tokenizer.batch_decode(prompt_tokens)[0] - prompt = {"inputs": prompt_text, "max_new_tokens": max_new_tokens, + prompt = {"inputs": prompt_text, "max_new_tokens": self.max_tokens, "temperature": self.temperature, "return_full_text": return_full_text} # check context limit: - context_check = self._check_context_limit(prompt_tokens[0], max_new_tokens=max_new_tokens) + context_check = self._check_context_limit(prompt_tokens[0], max_new_tokens=self.max_tokens) if not context_check[0]: # if context is exceeded, context_check[0] is False logger.info(f"Context token limit for {self.model_name} exceeded: {context_check[1]}/{context_check[3]}") # fail gracefully: @@ -375,13 +375,13 @@ def generate_response(self, messages: List[Dict], model: str, model_output_ids = self.model.generate( prompt_tokens, temperature=self.temperature, - max_new_tokens=max_new_tokens, + max_new_tokens=self.max_tokens, do_sample=do_sample ) else: model_output_ids = self.model.generate( prompt_tokens, - max_new_tokens=max_new_tokens, + max_new_tokens=self.max_tokens, do_sample=do_sample ) diff --git a/backends/mistral_api.py b/backends/mistral_api.py index 141aeec2c0..4fab29e24a 100644 --- a/backends/mistral_api.py +++ b/backends/mistral_api.py @@ -14,14 +14,13 @@ NAME = "mistral" -MAX_TOKENS = 100 - class Mistral(backends.Backend): def __init__(self): creds = backends.load_credentials(NAME) self.client = MistralClient(api_key=creds[NAME]["api_key"]) self.temperature: float = -1. + self.max_tokens: int = 100 def list_models(self): models = self.client.models.list() @@ -50,7 +49,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, api_response = self.client.chat(model=model, messages=prompt, temperature=self.temperature, - max_tokens=MAX_TOKENS) + max_tokens=self.max_tokens) message = api_response.choices[0].message if message.role != "assistant": # safety check raise AttributeError("Response message role is " + message.role + " but should be 'assistant'") diff --git a/backends/openai_api.py b/backends/openai_api.py index 266c882e9c..6c742140b3 100644 --- a/backends/openai_api.py +++ b/backends/openai_api.py @@ -18,8 +18,6 @@ NAME = "openai" -MAX_TOKENS = 100 # 2024-01-10, das: Should this be hardcoded??? 
- class OpenAI(backends.Backend): def __init__(self): @@ -35,6 +33,7 @@ def __init__(self): ) self.chat_models: List = ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview"] self.temperature: float = -1. + self.max_tokens: int = 100 def list_models(self): models = self.client.models.list() @@ -63,7 +62,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, api_response = self.client.chat.completions.create(model=model, messages=prompt, temperature=self.temperature, - max_tokens=MAX_TOKENS) + max_tokens=self.max_tokens) message = api_response.choices[0].message if message.role != "assistant": # safety check raise AttributeError("Response message role is " + message.role + " but should be 'assistant'") @@ -73,7 +72,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, else: # default (text completion) prompt = "\n".join([message["content"] for message in messages]) api_response = self.client.completions.create(model=model, prompt=prompt, - temperature=self.temperature, max_tokens=100) + temperature=self.temperature, max_tokens=self.max_tokens) response = json.loads(api_response.json()) response_text = api_response.choices[0].text.strip() return prompt, response, response_text diff --git a/backends/openai_compatible_api.py b/backends/openai_compatible_api.py index e51438a813..f5ee42e1f7 100644 --- a/backends/openai_compatible_api.py +++ b/backends/openai_compatible_api.py @@ -8,8 +8,6 @@ logger = backends.get_logger(__name__) -MAX_TOKENS = 100 - # For this backend, it makes less sense to talk about "supported models" than for others, # because what is supported depends very much on where this is pointed to. # E.g., if I run FastChat on my local machine, I may have very different models available @@ -36,6 +34,7 @@ def __init__(self): http_client=httpx.Client(verify=False) ) self.temperature: float = -1. 
+ self.max_tokens: int = 100 def list_models(self): models = self.client.models.list() @@ -63,7 +62,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, prompt = messages api_response = self.client.chat.completions.create(model=model, messages=prompt, - temperature=self.temperature, max_tokens=MAX_TOKENS) + temperature=self.temperature, max_tokens=self.max_tokens) message = api_response.choices[0].message if message.role != "assistant": # safety check raise AttributeError("Response message role is " + message.role + " but should be 'assistant'") diff --git a/clemgame/benchmark.py b/clemgame/benchmark.py index 74a72908fc..711a108b6d 100644 --- a/clemgame/benchmark.py +++ b/clemgame/benchmark.py @@ -22,8 +22,9 @@ def list_games(): stdout_logger.info(" Game: %s -> %s", game.name, game.get_description()) -def run(game_name: str, temperature: float, models: List[str] = None, experiment_name: str = None): +def run(game_name: str, max_tokens: int, temperature: float, models: List[str] = None, experiment_name: str = None): assert 0.0 <= temperature <= 1.0, "Temperature must be in [0.,1.]" + assert max_tokens > 0, "Max tokens should be larger than zero" if experiment_name: logger.info("Only running experiment: %s", experiment_name) try: @@ -33,7 +34,7 @@ def run(game_name: str, temperature: float, models: List[str] = None, experiment if experiment_name: benchmark.filter_experiment.append(experiment_name) time_start = datetime.now() - benchmark.run(player_backends=models, temperature=temperature) + benchmark.run(player_backends=models, temperature=temperature, max_tokens=max_tokens) time_end = datetime.now() logger.info(f"Run {benchmark.name} took {str(time_end - time_start)}") except Exception as e: diff --git a/clemgame/clemgame.py b/clemgame/clemgame.py index 8541305a7c..ce91100e30 100644 --- a/clemgame/clemgame.py +++ b/clemgame/clemgame.py @@ -671,7 +671,7 @@ def compute_scores(self): stdout_logger.error( f"{self.name}: '{error_count}' exceptions occurred: See clembench.log for details.") - def run(self, player_backends: List[str], temperature: float): + def run(self, player_backends: List[str], temperature: float, max_tokens: int): """ Runs game-play on all game instances for a game. There must be an instances.json with the following structure: @@ -697,9 +697,11 @@ def run(self, player_backends: List[str], temperature: float): - instance.json - interaction.json """ - self.logger.warning(f"{self.name}: Detected 'temperature={temperature}'") # Setting this directly on the apis for now (not on the players) + self.logger.warning(f"{self.name}: Detected 'temperature={temperature}'") backends.configure(lambda backend: setattr(backend, "temperature", temperature)) + self.logger.warning(f"{self.name}: Detected 'max_tokens={max_tokens}'") + backends.configure(lambda backend: setattr(backend, "max_tokens", max_tokens)) experiments: List = self.instances["experiments"] if not experiments: diff --git a/scripts/cli.py b/scripts/cli.py index 406af5e165..4af3671b0a 100644 --- a/scripts/cli.py +++ b/scripts/cli.py @@ -36,6 +36,7 @@ def main(args): benchmark.list_games() if args.command_name == "run": benchmark.run(args.game, + max_tokens=args.max_tokens, temperature=args.temperature, models=args.models, experiment_name=args.experiment_name) @@ -67,6 +68,10 @@ def main(args): Default: None.""") run_parser.add_argument("-t", "--temperature", type=float, default=0.0, help="Argument to specify sampling temperature for the models. 
Default: 0.0.") + run_parser.add_argument("-l", "--max-tokens", type=int, default=100, + help="Specify the maximum number of tokens to be generated per turn (except for cohere). " + "Be careful with high values which might lead to exceed your API token limits." + "Default: 100.") run_parser.add_argument("-e", "--experiment_name", type=str, help="Optional argument to only run a specific experiment") run_parser.add_argument("-g", "--game", type=str,