diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 8d7c524a411d..f4f3a1e586db 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3005,13 +3005,13 @@ def model_list(
 
     This is just for compatibility with openai projects like aider.
     """
-    global llm_model_list, general_settings
+    global llm_model_list, general_settings, llm_router
     all_models = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()
     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
@@ -7503,10 +7503,11 @@ async def model_info_v1(
 
     all_models: List[dict] = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()
+
     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
@@ -7523,8 +7524,14 @@ async def model_info_v1(
 
         if len(all_models_str) > 0:
             model_names = all_models_str
-            _relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
-            all_models = copy.deepcopy(_relevant_models)
+            llm_model_list = llm_router.get_model_list()
+            if llm_model_list is not None:
+                _relevant_models = [
+                    m for m in llm_model_list if m["model_name"] in model_names
+                ]
+                all_models = copy.deepcopy(_relevant_models)  # type: ignore
+            else:
+                all_models = []
 
     for model in all_models:
         # provided model_info in config.yaml
@@ -7590,12 +7597,12 @@ async def model_group_info(
         raise HTTPException(
             status_code=500, detail={"error": "LLM Router is not loaded in"}
         )
-    all_models: List[dict] = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()
+
     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
diff --git a/litellm/router.py b/litellm/router.py
index 233331e8004a..bcd0b6221de0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -86,6 +86,7 @@
     Deployment,
     DeploymentTypedDict,
     LiteLLM_Params,
+    LiteLLMParamsTypedDict,
     ModelGroupInfo,
     ModelInfo,
     RetryPolicy,
@@ -4297,7 +4298,9 @@ def get_model_info(self, id: str) -> Optional[dict]:
                 return model
         return None
 
-    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+    def _set_model_group_info(
+        self, model_group: str, user_facing_model_group_name: str
+    ) -> Optional[ModelGroupInfo]:
         """
         For a given model group name, return the combined model info
 
@@ -4379,7 +4382,7 @@ def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
 
             if model_group_info is None:
                 model_group_info = ModelGroupInfo(
-                    model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
+                    model_group=user_facing_model_group_name, providers=[llm_provider], **model_info  # type: ignore
                 )
             else:
                 # if max_input_tokens > curr
@@ -4464,6 +4467,26 @@ def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
 
         return model_group_info
 
+    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+        """
+        For a given model group name, return the combined model info
+
+        Returns:
+        - ModelGroupInfo if able to construct a model group
+        - None if error constructing model group info
+        """
+        ## Check if model group alias
+        if model_group in self.model_group_alias:
+            return self._set_model_group_info(
+                model_group=self.model_group_alias[model_group],
+                user_facing_model_group_name=model_group,
+            )
+
+        ## Check if actual model
+        return self._set_model_group_info(
+            model_group=model_group, user_facing_model_group_name=model_group
+        )
+
     async def get_model_group_usage(
         self, model_group: str
     ) -> Tuple[Optional[int], Optional[int]]:
@@ -4534,19 +4557,35 @@ def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
         return ids
 
     def get_model_names(self) -> List[str]:
-        return self.model_names
+        """
+        Returns all possible model names for router.
+
+        Includes model_group_alias models too.
+        """
+        return self.model_names + list(self.model_group_alias.keys())
 
     def get_model_list(
         self, model_name: Optional[str] = None
    ) -> Optional[List[DeploymentTypedDict]]:
         if hasattr(self, "model_list"):
+            returned_models: List[DeploymentTypedDict] = []
+
+            for model_alias, model_value in self.model_group_alias.items():
+                model_alias_item = DeploymentTypedDict(
+                    model_name=model_alias,
+                    litellm_params=LiteLLMParamsTypedDict(model=model_value),
+                )
+                returned_models.append(model_alias_item)
+
             if model_name is None:
-                return self.model_list
+                returned_models += self.model_list
+
+                return returned_models
 
-            returned_models: List[DeploymentTypedDict] = []
             for model in self.model_list:
                 if model["model_name"] == model_name:
                     returned_models.append(model)
+            return returned_models
         return None
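
A minimal usage sketch of the behavior this patch adds, assuming litellm with this change applied; the "my-custom-gpt" alias, the model names, and the API key below are placeholders for illustration, not part of the patch.

# sketch: model_group_alias entries now surface through the router's
# read APIs, which is what lets the proxy endpoints above report them.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "sk-placeholder",  # placeholder credential
            },
        }
    ],
    # alias -> underlying model group name
    model_group_alias={"my-custom-gpt": "gpt-3.5-turbo"},
)

# get_model_names() now appends the alias keys, so /v1/models and
# /model/info (which build proxy_model_list from it) include aliases.
print(router.get_model_names())  # ['gpt-3.5-turbo', 'my-custom-gpt']

# get_model_list() now prepends one synthetic deployment per alias,
# built from LiteLLMParamsTypedDict(model=<underlying group>).
for deployment in router.get_model_list() or []:
    print(deployment["model_name"], "->", deployment["litellm_params"]["model"])

# get_model_group_info() resolves the alias to its underlying group for
# lookup, but reports the user-facing alias name in ModelGroupInfo.
info = router.get_model_group_info(model_group="my-custom-gpt")
print(info.model_group if info else None)  # 'my-custom-gpt'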