(v1.0+ breaking change) get_max_tokens -> return int
ishaan-jaff committed Nov 17, 2023
1 parent c162f8b commit bd82559
Showing 3 changed files with 63 additions and 9 deletions.
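The breaking change in one line: get_max_tokens now returns the context window directly as an int, and the dict it previously returned (context window plus per-token costs) moves to the new get_model_info. A minimal migration sketch for callers, assuming the module-level exports shown in this commit (the model name is chosen for illustration):

import litellm

# Before this change: get_max_tokens returned a dict
#   context = litellm.get_max_tokens("gpt-3.5-turbo")["max_tokens"]

# After this change (v1.0+): the int is returned directly
context = litellm.get_max_tokens("gpt-3.5-turbo")  # e.g. 4097

# The old dict shape (context window + per-token costs) now comes from get_model_info
info = litellm.get_model_info("gpt-3.5-turbo")
print(info["max_tokens"], info["input_cost_per_token"], info["output_cost_per_token"])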
1 change: 1 addition & 0 deletions litellm/__init__.py
@@ -346,6 +346,7 @@ def identify(event_details):
     acreate,
     get_model_list,
     get_max_tokens,
+    get_model_info,
     register_prompt_template,
     validate_environment,
     check_valid_key,
17 changes: 9 additions & 8 deletions litellm/tests/test_get_model_cost_map.py
@@ -9,20 +9,21 @@

 def test_get_gpt3_tokens():
     max_tokens = get_max_tokens("gpt-3.5-turbo")
-    results = max_tokens['max_tokens']
-    print(results)
-# test_get_gpt3_tokens()
+    print(max_tokens)
+    assert max_tokens==4097
+    # print(results)
+test_get_gpt3_tokens()

 def test_get_palm_tokens():
     # # 🦄🦄🦄🦄🦄🦄🦄🦄
     max_tokens = get_max_tokens("palm/chat-bison")
-    results = max_tokens['max_tokens']
-    print(results)
-# test_get_palm_tokens()
+    assert max_tokens == 4096
+    print(max_tokens)
+test_get_palm_tokens()

 def test_zephyr_hf_tokens():
     max_tokens = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta")
-    results = max_tokens["max_tokens"]
-    print(results)
+    print(max_tokens)
+    assert max_tokens == 32768

 test_zephyr_hf_tokens()
54 changes: 53 additions & 1 deletion litellm/utils.py
@@ -2358,6 +2358,58 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     return api_key

 def get_max_tokens(model: str):
+    """
+    Get the maximum number of tokens allowed for a given model.
+
+    Parameters:
+    model (str): The name of the model.
+
+    Returns:
+    int: The maximum number of tokens allowed for the given model.
+
+    Raises:
+    Exception: If the model is not mapped yet.
+
+    Example:
+    >>> get_max_tokens("gpt-4")
+    8192
+    """
+    def _get_max_position_embeddings(model_name):
+        # Construct the URL for the config.json file
+        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
+
+        try:
+            # Make the HTTP request to get the raw JSON file
+            response = requests.get(config_url)
+            response.raise_for_status()  # Raise an exception for bad responses (4xx or 5xx)
+
+            # Parse the JSON response
+            config_json = response.json()
+
+            # Extract and return the max_position_embeddings
+            max_position_embeddings = config_json.get("max_position_embeddings")
+
+            if max_position_embeddings is not None:
+                return max_position_embeddings
+            else:
+                return None
+        except requests.exceptions.RequestException as e:
+            return None
+
+    try:
+        if model in litellm.model_cost:
+            return litellm.model_cost[model]["max_tokens"]
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        if custom_llm_provider == "huggingface":
+            max_tokens = _get_max_position_embeddings(model_name=model)
+            return max_tokens
+        else:
+            raise Exception()
+    except:
+        raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json")
+
+
+def get_model_info(model: str):
     """
     Get a dict for the maximum tokens (context window),
     input_cost_per_token, output_cost_per_token for a given model.
@@ -2377,7 +2429,7 @@ def get_max_tokens(model: str):
     Exception: If the model is not mapped yet.

     Example:
-    >>> get_max_tokens("gpt-4")
+    >>> get_model_info("gpt-4")
     {
         "max_tokens": 8192,
         "input_cost_per_token": 0.00003,
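For models missing from litellm.model_cost, the new get_max_tokens falls back to reading max_position_embeddings from the model's config.json on Hugging Face. A standalone sketch of that lookup, mirroring the helper above (the timeout is a safety addition, not in the diff):

import requests

def hf_context_window(model_name):
    # Each Hugging Face repo serves its config.json at a stable raw URL
    config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
    try:
        response = requests.get(config_url, timeout=10)
        response.raise_for_status()
        # max_position_embeddings is the conventional context-window field
        return response.json().get("max_position_embeddings")
    except requests.exceptions.RequestException:
        return None

print(hf_context_window("HuggingFaceH4/zephyr-7b-beta"))  # per the test above: 32768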
