All model_group_alias should show up in /models, /model/info, /model_group/info #5539

Merged
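Background for reviewers: a `Router` configured with `model_group_alias` exposes alias names that route to an existing model group. Before this change, those aliases worked for routing but never appeared in `/models`, `/model/info`, or `/model_group/info`. A minimal sketch of the setup this PR targets (model names and the API key are illustrative, not from the PR):

```python
# Illustrative router config: one real model group plus one alias.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
        }
    ],
    model_group_alias={"gpt-4": "gpt-3.5-turbo"},  # alias -> underlying group
)

# Before this PR: only "gpt-3.5-turbo" was listed.
# After this PR: ["gpt-3.5-turbo", "gpt-4"].
print(router.get_model_names())
```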
27 changes: 17 additions & 10 deletions litellm/proxy/proxy_server.py
@@ -3005,13 +3005,13 @@ def model_list(

     This is just for compatibility with openai projects like aider.
     """
-    global llm_model_list, general_settings
+    global llm_model_list, general_settings, llm_router
     all_models = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()
     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
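This pattern repeats in each endpoint below: `proxy_model_list` was previously derived from the raw `llm_model_list` config, which has no entries for aliases, so alias names could never appear. Asking the router instead picks up the alias keys via the `get_model_names()` change in `litellm/router.py` further down. A hedged before/after sketch, reusing the illustrative router from above:

```python
# llm_model_list mirrors the proxy config; aliases are not entries in it.
llm_model_list = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
]

# Before: ["gpt-3.5-turbo"] -- the "gpt-4" alias is invisible.
old_proxy_model_list = [m["model_name"] for m in llm_model_list]

# After: ["gpt-3.5-turbo", "gpt-4"] -- aliases included (see router.py below).
new_proxy_model_list = router.get_model_names()
```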
@@ -7503,10 +7503,11 @@ async def model_info_v1(

     all_models: List[dict] = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()
+
     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
@@ -7523,8 +7524,14 @@

     if len(all_models_str) > 0:
         model_names = all_models_str
-        _relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
-        all_models = copy.deepcopy(_relevant_models)
+        llm_model_list = llm_router.get_model_list()
+        if llm_model_list is not None:
+            _relevant_models = [
+                m for m in llm_model_list if m["model_name"] in model_names
+            ]
+            all_models = copy.deepcopy(_relevant_models)  # type: ignore
+        else:
+            all_models = []

     for model in all_models:
         # provided model_info in config.yaml
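Note the shape change here: `/model/info` now reads deployments from `llm_router.get_model_list()`, which can return `None` (hence the guard) and, after this PR, also yields synthetic entries for aliases. Roughly what an alias entry looks like, inferred from the `DeploymentTypedDict` construction in `router.py` below:

```python
# Assumed shape of an alias entry in router.get_model_list() output,
# based on the DeploymentTypedDict construction in this PR.
alias_entry = {
    "model_name": "gpt-4",  # user-facing alias
    "litellm_params": {"model": "gpt-3.5-turbo"},  # underlying model group
}
```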
@@ -7590,12 +7597,12 @@ async def model_group_info(
         raise HTTPException(
             status_code=500, detail={"error": "LLM Router is not loaded in"}
         )
     all_models: List[dict] = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()

     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
49 changes: 44 additions & 5 deletions litellm/router.py
@@ -86,6 +86,7 @@
     Deployment,
     DeploymentTypedDict,
     LiteLLM_Params,
+    LiteLLMParamsTypedDict,
     ModelGroupInfo,
     ModelInfo,
     RetryPolicy,
@@ -4297,7 +4298,9 @@ def get_model_info(self, id: str) -> Optional[dict]:
                 return model
         return None

-    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+    def _set_model_group_info(
+        self, model_group: str, user_facing_model_group_name: str
+    ) -> Optional[ModelGroupInfo]:
         """
         For a given model group name, return the combined model info

@@ -4379,7 +4382,7 @@ def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:

         if model_group_info is None:
             model_group_info = ModelGroupInfo(
-                model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
+                model_group=user_facing_model_group_name, providers=[llm_provider], **model_info  # type: ignore
             )
         else:
             # if max_input_tokens > curr
@@ -4464,6 +4467,26 @@ def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:

         return model_group_info

+    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+        """
+        For a given model group name, return the combined model info
+
+        Returns:
+        - ModelGroupInfo if able to construct a model group
+        - None if error constructing model group info
+        """
+        ## Check if model group alias
+        if model_group in self.model_group_alias:
+            return self._set_model_group_info(
+                model_group=self.model_group_alias[model_group],
+                user_facing_model_group_name=model_group,
+            )
+
+        ## Check if actual model
+        return self._set_model_group_info(
+            model_group=model_group, user_facing_model_group_name=model_group
+        )
+
     async def get_model_group_usage(
         self, model_group: str
     ) -> Tuple[Optional[int], Optional[int]]:
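The refactor splits lookup from construction: `_set_model_group_info` does the heavy lifting, while the public `get_model_group_info` first resolves an alias to its underlying group and passes the alias along as `user_facing_model_group_name`, so stats are computed for the real group but reported under the alias. A hedged usage sketch with the illustrative router from above:

```python
# Resolving group info through an alias: stats come from the underlying
# "gpt-3.5-turbo" group, but the returned info reports the alias name.
info = router.get_model_group_info(model_group="gpt-4")
if info is not None:
    print(info.model_group)  # "gpt-4", via user_facing_model_group_name
```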
@@ -4534,19 +4557,35 @@ def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
         return ids

     def get_model_names(self) -> List[str]:
-        return self.model_names
+        """
+        Returns all possible model names for router.
+
+        Includes model_group_alias models too.
+        """
+        return self.model_names + list(self.model_group_alias.keys())

     def get_model_list(
         self, model_name: Optional[str] = None
     ) -> Optional[List[DeploymentTypedDict]]:
         if hasattr(self, "model_list"):
+            returned_models: List[DeploymentTypedDict] = []
+
+            for model_alias, model_value in self.model_group_alias.items():
+                model_alias_item = DeploymentTypedDict(
+                    model_name=model_alias,
+                    litellm_params=LiteLLMParamsTypedDict(model=model_value),
+                )
+                returned_models.append(model_alias_item)
+
             if model_name is None:
-                return self.model_list
+                returned_models += self.model_list
+
+                return returned_models

-            returned_models: List[DeploymentTypedDict] = []
             for model in self.model_list:
                 if model["model_name"] == model_name:
                     returned_models.append(model)

             return returned_models
         return None
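End-to-end, `get_model_list()` now prepends one synthetic deployment per alias before the real deployments, and `get_model_names()` appends the alias keys. With the illustrative router from the top of this PR:

```python
# Each entry maps a user-facing name to the model it resolves to.
for entry in router.get_model_list() or []:
    print(entry["model_name"], "->", entry["litellm_params"]["model"])
# Expected (per this PR's ordering -- aliases first):
#   gpt-4 -> gpt-3.5-turbo
#   gpt-3.5-turbo -> gpt-3.5-turbo
```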