issue #47: add option to set max_token to be generated (#48)
* issue #47: add option to set max_token to be generated

* issue #47: rename max_token option to -l (limit, length)
phisad authored Feb 14, 2024
1 parent e1a9e8e commit 8399a71
Showing 10 changed files with 30 additions and 22 deletions.
backends/alephalpha_api.py (2 additions, 1 deletion)

@@ -22,6 +22,7 @@ def __init__(self):
         creds = backends.load_credentials(NAME)
         self.client = aleph_alpha_client.Client(creds[NAME]["api_key"])
         self.temperature: float = -1.
+        self.max_tokens: int = 100
 
     @retry(tries=3, delay=0, logger=logger)
     def generate_response(self, messages: List[Dict], model: str) -> Tuple[Any, Any, str]:
@@ -58,7 +59,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[Any, Any, str]:
 
         params = {
            "prompt": aleph_alpha_client.Prompt.from_text(prompt_text),
-           "maximum_tokens": 100,
+           "maximum_tokens": self.max_tokens,
            "stop_sequences": ['\n'],
            "temperature": self.temperature
         }

backends/anthropic_api.py (2 additions, 1 deletion)

@@ -21,6 +21,7 @@ def __init__(self):
         creds = backends.load_credentials(NAME)
         self.client = anthropic.Anthropic(api_key=creds[NAME]["api_key"])
         self.temperature: float = -1.
+        self.max_tokens: int = 100
 
     @retry(tries=3, delay=0, logger=logger)
     def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]:
@@ -50,7 +51,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]:
             stop_sequences=[anthropic.HUMAN_PROMPT, '\n'],
             model=model,
             temperature=self.temperature,
-            max_tokens_to_sample=100
+            max_tokens_to_sample=self.max_tokens
         )
 
         response_text = completion.completion.strip()

backends/cohere_api.py (1 addition, 0 deletions)

@@ -18,6 +18,7 @@ def __init__(self):
         creds = backends.load_credentials(NAME)
         self.client = cohere.Client(creds[NAME]["api_key"])
         self.temperature: float = -1.
+        self.max_tokens: int = -1 # not applicable in this backend?
 
     @retry(tries=3, delay=0, logger=logger)
     def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]:

backends/huggingface_local_api.py (6 additions, 6 deletions)

@@ -28,6 +28,7 @@
 class HuggingfaceLocal(backends.Backend):
     def __init__(self):
         self.temperature: float = -1.
+        self.max_tokens: int = 100 # How many tokens to generate ('at most', but no stop sequence is defined).
         self.use_api_key: bool = False
         self.config_and_tokenizer_loaded: bool = False
         self.model_loaded: bool = False
@@ -311,7 +312,7 @@ def check_context_limit(self, messages: List[Dict], model: str,
         return fits, tokens_used, tokens_left, self.context_size
 
     def generate_response(self, messages: List[Dict], model: str,
-                          max_new_tokens: int = 100, return_full_text: bool = False,
+                          return_full_text: bool = False,
                           log_messages: bool = False) -> Tuple[Any, Any, str]:
         """
         :param messages: for example
@@ -322,7 +323,6 @@ def generate_response(self, messages: List[Dict], model: str,
                 {"role": "user", "content": "Where was it played?"}
             ]
         :param model: model name
-        :param max_new_tokens: How many tokens to generate ('at most', but no stop sequence is defined).
         :param return_full_text: If True, whole input context is returned.
         :param log_messages: If True, raw and cleaned messages passed will be logged.
         :return: the continuation
@@ -354,11 +354,11 @@ def generate_response(self, messages: List[Dict], model: str,
         prompt_tokens = prompt_tokens.to(self.device)
 
         prompt_text = self.tokenizer.batch_decode(prompt_tokens)[0]
-        prompt = {"inputs": prompt_text, "max_new_tokens": max_new_tokens,
+        prompt = {"inputs": prompt_text, "max_new_tokens": self.max_tokens,
                   "temperature": self.temperature, "return_full_text": return_full_text}
 
         # check context limit:
-        context_check = self._check_context_limit(prompt_tokens[0], max_new_tokens=max_new_tokens)
+        context_check = self._check_context_limit(prompt_tokens[0], max_new_tokens=self.max_tokens)
         if not context_check[0]: # if context is exceeded, context_check[0] is False
             logger.info(f"Context token limit for {self.model_name} exceeded: {context_check[1]}/{context_check[3]}")
             # fail gracefully:
@@ -375,13 +375,13 @@ def generate_response(self, messages: List[Dict], model: str,
             model_output_ids = self.model.generate(
                 prompt_tokens,
                 temperature=self.temperature,
-                max_new_tokens=max_new_tokens,
+                max_new_tokens=self.max_tokens,
                 do_sample=do_sample
             )
         else:
             model_output_ids = self.model.generate(
                 prompt_tokens,
-                max_new_tokens=max_new_tokens,
+                max_new_tokens=self.max_tokens,
                 do_sample=do_sample
             )

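Since max_new_tokens is no longer a parameter of generate_response in this backend, callers that previously passed it per call would now set the attribute on the backend instance first. A minimal usage sketch, assuming the import path matches the file location; the model name and messages below are placeholders, not taken from this diff:

from backends.huggingface_local_api import HuggingfaceLocal  # assumed import path

backend = HuggingfaceLocal()
backend.max_tokens = 150  # replaces the former max_new_tokens=150 call argument
messages = [{"role": "user", "content": "Say hello."}]  # placeholder conversation
prompt, response, response_text = backend.generate_response(messages, model="some-hf-model")
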
backends/mistral_api.py (2 additions, 3 deletions)

@@ -14,14 +14,13 @@
 
 NAME = "mistral"
 
-MAX_TOKENS = 100
-
 class Mistral(backends.Backend):
 
     def __init__(self):
         creds = backends.load_credentials(NAME)
         self.client = MistralClient(api_key=creds[NAME]["api_key"])
         self.temperature: float = -1.
+        self.max_tokens: int = 100
 
     def list_models(self):
         models = self.client.models.list()
@@ -50,7 +49,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]:
         api_response = self.client.chat(model=model,
                                         messages=prompt,
                                         temperature=self.temperature,
-                                        max_tokens=MAX_TOKENS)
+                                        max_tokens=self.max_tokens)
         message = api_response.choices[0].message
         if message.role != "assistant": # safety check
             raise AttributeError("Response message role is " + message.role + " but should be 'assistant'")

backends/openai_api.py (3 additions, 4 deletions)

@@ -18,8 +18,6 @@
 
 NAME = "openai"
 
-MAX_TOKENS = 100 # 2024-01-10, das: Should this be hardcoded???
-
 class OpenAI(backends.Backend):
 
     def __init__(self):
@@ -35,6 +33,7 @@ def __init__(self):
         )
         self.chat_models: List = ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview"]
         self.temperature: float = -1.
+        self.max_tokens: int = 100
 
     def list_models(self):
         models = self.client.models.list()
@@ -63,7 +62,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]:
             api_response = self.client.chat.completions.create(model=model,
                                                                 messages=prompt,
                                                                 temperature=self.temperature,
-                                                                max_tokens=MAX_TOKENS)
+                                                                max_tokens=self.max_tokens)
             message = api_response.choices[0].message
             if message.role != "assistant": # safety check
                 raise AttributeError("Response message role is " + message.role + " but should be 'assistant'")
@@ -73,7 +72,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]:
         else: # default (text completion)
             prompt = "\n".join([message["content"] for message in messages])
             api_response = self.client.completions.create(model=model, prompt=prompt,
-                                                           temperature=self.temperature, max_tokens=100)
+                                                           temperature=self.temperature, max_tokens=self.max_tokens)
             response = json.loads(api_response.json())
             response_text = api_response.choices[0].text.strip()
             return prompt, response, response_text

backends/openai_compatible_api.py (2 additions, 3 deletions)

@@ -8,8 +8,6 @@
 
 logger = backends.get_logger(__name__)
 
-MAX_TOKENS = 100
-
 # For this backend, it makes less sense to talk about "supported models" than for others,
 # because what is supported depends very much on where this is pointed to.
 # E.g., if I run FastChat on my local machine, I may have very different models available
@@ -36,6 +34,7 @@ def __init__(self):
             http_client=httpx.Client(verify=False)
         )
         self.temperature: float = -1.
+        self.max_tokens: int = 100
 
     def list_models(self):
         models = self.client.models.list()
@@ -63,7 +62,7 @@ def generate_response(self, messages: List[Dict], model: str) -> Tuple[str, Any, str]:
 
         prompt = messages
         api_response = self.client.chat.completions.create(model=model, messages=prompt,
-                                                            temperature=self.temperature, max_tokens=MAX_TOKENS)
+                                                            temperature=self.temperature, max_tokens=self.max_tokens)
         message = api_response.choices[0].message
         if message.role != "assistant": # safety check
             raise AttributeError("Response message role is " + message.role + " but should be 'assistant'")

clemgame/benchmark.py (3 additions, 2 deletions)

@@ -22,8 +22,9 @@ def list_games():
         stdout_logger.info(" Game: %s -> %s", game.name, game.get_description())
 
 
-def run(game_name: str, temperature: float, models: List[str] = None, experiment_name: str = None):
+def run(game_name: str, max_tokens: int, temperature: float, models: List[str] = None, experiment_name: str = None):
     assert 0.0 <= temperature <= 1.0, "Temperature must be in [0.,1.]"
+    assert max_tokens > 0, "Max tokens should be larger than zero"
     if experiment_name:
         logger.info("Only running experiment: %s", experiment_name)
     try:
@@ -33,7 +34,7 @@ def run(game_name: str, temperature: float, models: List[str] = None, experiment_name: str = None):
         if experiment_name:
             benchmark.filter_experiment.append(experiment_name)
         time_start = datetime.now()
-        benchmark.run(player_backends=models, temperature=temperature)
+        benchmark.run(player_backends=models, temperature=temperature, max_tokens=max_tokens)
         time_end = datetime.now()
         logger.info(f"Run {benchmark.name} took {str(time_end - time_start)}")
     except Exception as e:

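For code that calls benchmark.run directly rather than through the CLI, the new keyword is simply threaded through to the game run. A minimal sketch, assuming the import path matches the file location; the game and model names are placeholders, not taken from this diff:

from clemgame import benchmark  # assumed import path

# run() now requires max_tokens and asserts that it is positive,
# alongside the existing temperature-range check.
benchmark.run("somegame", max_tokens=100, temperature=0.0,
              models=["some-model"], experiment_name=None)
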
clemgame/clemgame.py (4 additions, 2 deletions)

@@ -671,7 +671,7 @@ def compute_scores(self):
             stdout_logger.error(
                 f"{self.name}: '{error_count}' exceptions occurred: See clembench.log for details.")
 
-    def run(self, player_backends: List[str], temperature: float):
+    def run(self, player_backends: List[str], temperature: float, max_tokens: int):
         """
         Runs game-play on all game instances for a game.
         There must be an instances.json with the following structure:
@@ -697,9 +697,11 @@ def run(self, player_backends: List[str], temperature: float):
             - instance.json
             - interaction.json
         """
-        self.logger.warning(f"{self.name}: Detected 'temperature={temperature}'")
         # Setting this directly on the apis for now (not on the players)
+        self.logger.warning(f"{self.name}: Detected 'temperature={temperature}'")
         backends.configure(lambda backend: setattr(backend, "temperature", temperature))
+        self.logger.warning(f"{self.name}: Detected 'max_tokens={max_tokens}'")
+        backends.configure(lambda backend: setattr(backend, "max_tokens", max_tokens))
 
         experiments: List = self.instances["experiments"]
         if not experiments:

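backends.configure(...) is what carries the single CLI value into the max_tokens attribute that every backend now defines; its implementation is not part of this diff. A minimal sketch of how such a hook could look, purely as an assumption about its shape:

from typing import Any, Callable, List

# Hypothetical registry; the real backends module presumably tracks its
# instantiated backend objects in some comparable way.
_loaded_backends: List[Any] = []

def configure(setter: Callable[[Any], None]) -> None:
    """Apply a setter to every known backend instance, e.g.
    configure(lambda backend: setattr(backend, "max_tokens", 100))."""
    for backend in _loaded_backends:
        setter(backend)
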
scripts/cli.py (5 additions, 0 deletions)

@@ -36,6 +36,7 @@ def main(args):
         benchmark.list_games()
     if args.command_name == "run":
         benchmark.run(args.game,
+                      max_tokens=args.max_tokens,
                       temperature=args.temperature,
                       models=args.models,
                       experiment_name=args.experiment_name)
@@ -67,6 +68,10 @@ def main(args):
                                 Default: None.""")
     run_parser.add_argument("-t", "--temperature", type=float, default=0.0,
                             help="Argument to specify sampling temperature for the models. Default: 0.0.")
+    run_parser.add_argument("-l", "--max-tokens", type=int, default=100,
+                            help="Specify the maximum number of tokens to be generated per turn (except for cohere). "
+                                 "Be careful with high values which might lead to exceed your API token limits."
+                                 "Default: 100.")
     run_parser.add_argument("-e", "--experiment_name", type=str,
                             help="Optional argument to only run a specific experiment")
     run_parser.add_argument("-g", "--game", type=str,

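With the flag wired into the run subcommand, an invocation could look like the following; the game name is a placeholder, and the model-selection arguments (not part of this hunk) are omitted:

python3 scripts/cli.py run -g somegame -t 0.0 -l 150
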
