From a775779d323f8903acf64c7cec5f0189cfdc8ee9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Fri, 18 Oct 2024 19:08:24 +0800 Subject: [PATCH] Remove model_name, use model. --- docs/config_guide_cn.md | 6 +-- docs/config_guide_en.md | 6 +-- .../settings.toml | 8 ++-- .../settings.toml | 8 ++-- src/pai_rag/app/web/rag_client.py | 1 - src/pai_rag/app/web/view_model.py | 28 ++++++------- src/pai_rag/config/settings.toml | 8 ++-- src/pai_rag/evaluation/settings_eval.toml | 10 ++--- .../embeddings/pai/embedding_utils.py | 8 ++-- .../embeddings/pai/pai_embedding_config.py | 12 +++--- .../integrations/llms/pai/llm_config.py | 10 ++--- .../integrations/llms/pai/llm_utils.py | 24 +++++------ src/pai_rag/integrations/llms/pai/pai_llm.py | 2 +- src/pai_rag/utils/download_models.py | 42 +++++++++---------- tests/modules/agent/test_fc_agent.py | 2 +- .../intentdetection/test_intent_detection.py | 2 +- .../modules/llm/test_function_calling_llm.py | 2 +- tests/modules/llm/test_llm.py | 2 +- 18 files changed, 90 insertions(+), 91 deletions(-) diff --git a/docs/config_guide_cn.md b/docs/config_guide_cn.md index c2fe2356..fe88d599 100644 --- a/docs/config_guide_cn.md +++ b/docs/config_guide_cn.md @@ -37,16 +37,16 @@ source = [HuggingFace, OpenAI, DashScope] 目前, pai_rag 支持以上三种 embedding 源. -如果 source = "HuggingFace", 您需要进一步指定 model_name 和 embed_batch_size。默认的模型名称和批处理大小分别为 bge-large-zh-v1.5 和 10。 +如果 source = "HuggingFace", 您需要进一步指定 model 和 embed_batch_size。默认的模型名称和批处理大小分别为 bge-large-zh-v1.5 和 10。 source = "HuggingFace" - model_name = "bge-large-zh-v1.5" + model = "bge-large-zh-v1.5" embed_batch_size = 10 或者, 如果你想使用其它 huggingface 模型, 请指定如下参数: source = "HuggingFace" - model_name = "xxx" + model = "xxx" model_dir = "xxx/xxx" embed_batch_size = 20 (for example) diff --git a/docs/config_guide_en.md b/docs/config_guide_en.md index e528903d..e3adabb2 100644 --- a/docs/config_guide_en.md +++ b/docs/config_guide_en.md @@ -37,16 +37,16 @@ source = [HuggingFace, OpenAI, DashScope] Currently, pai_rag supports three embedding sources. -If source = "HuggingFace", you need to further specify model_name and embed_batch_size. The default model name and batch size are bge-large-zh-v1.5 and 10, respectively. +If source = "HuggingFace", you need to further specify model_modelname and embed_batch_size. The default model name and batch size are bge-large-zh-v1.5 and 10, respectively. source = "HuggingFace" - model_name = "bge-large-zh-v1.5" + model = "bge-large-zh-v1.5" embed_batch_size = 10 Alternatively, if you want to use other huggingface models, please specify parameters as below: source = "HuggingFace" - model_name = "xxx" + model = "xxx" model_dir = "xxx/xxx" embed_batch_size = 20 (for example) diff --git a/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/settings.toml b/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/settings.toml index aafa4b8b..4095d476 100644 --- a/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/settings.toml +++ b/example_data/function_tools/api-tool-with-intent-detection-for-travel-assistant/settings.toml @@ -26,10 +26,10 @@ persist_path = "localdata/storage" type = "SimpleDirectoryReader" # embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace -# if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model_name +# if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model # eg. # source = "HuggingFace" -# model_name = "bge-large-zh-v1.5" +# model = "bge-large-zh-v1.5" # embed_batch_size = 10 [rag.embedding] source = "DashScope" @@ -49,11 +49,11 @@ vector_store.type = "FAISS" # token = "" [rag.llm] source = "DashScope" -name = "qwen-turbo" +model = "qwen-turbo" [rag.llm.function_calling_llm] source = "DashScope" -name = "qwen2-7b-instruct" +model = "qwen2-7b-instruct" [rag.llm.multi_modal] source = "" diff --git a/example_data/function_tools/python-tool-for-booking-demo/settings.toml b/example_data/function_tools/python-tool-for-booking-demo/settings.toml index 1da5d480..2cb65635 100644 --- a/example_data/function_tools/python-tool-for-booking-demo/settings.toml +++ b/example_data/function_tools/python-tool-for-booking-demo/settings.toml @@ -26,10 +26,10 @@ persist_path = "localdata/storage" type = "SimpleDirectoryReader" # embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace -# if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model_name +# if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model # eg. # source = "HuggingFace" -# model_name = "bge-large-zh-v1.5" +# model = "bge-large-zh-v1.5" # embed_batch_size = 10 [rag.embedding] source = "DashScope" @@ -49,11 +49,11 @@ vector_store.type = "FAISS" # token = "" [rag.llm] source = "DashScope" -name = "qwen-turbo" +model = "qwen-turbo" [rag.llm.function_calling_llm] source = "DashScope" -name = "qwen2-7b-instruct" +model = "qwen2-7b-instruct" [rag.llm.multi_modal] source = "" diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py index 47802aec..5976e2f4 100644 --- a/src/pai_rag/app/web/rag_client.py +++ b/src/pai_rag/app/web/rag_client.py @@ -415,7 +415,6 @@ def patch_config(self, update_dict: Any): view_model: ViewModel = ViewModel.from_app_config(config) view_model.update(update_dict) new_config = view_model.to_app_config() - r = requests.patch( self.config_url, json=new_config, timeout=DEFAULT_CLIENT_TIME_OUT ) diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py index 312caeb2..26e274f0 100644 --- a/src/pai_rag/app/web/view_model.py +++ b/src/pai_rag/app/web/view_model.py @@ -47,7 +47,7 @@ class ViewModel(BaseModel): llm: str = "PaiEas" llm_eas_url: str = None llm_eas_token: str = None - llm_eas_model_name: str = "model_name" + llm_eas_model_name: str = "model" llm_api_key: str = None llm_api_model_name: str = None llm_temperature: float = 0.1 @@ -57,7 +57,7 @@ class ViewModel(BaseModel): mllm: str = None mllm_eas_url: str = None mllm_eas_token: str = None - mllm_eas_model_name: str = "model_name" + mllm_eas_model_name: str = "model" mllm_api_key: str = None mllm_api_model_name: str = None @@ -189,7 +189,7 @@ def from_app_config(config): "source", view_model.embed_source ) view_model.embed_model = config["embedding"].get( - "model_name", view_model.embed_model + "model", view_model.embed_model ) view_model.embed_api_key = config["embedding"].get( "api_key", view_model.embed_api_key @@ -207,11 +207,11 @@ def from_app_config(config): ) if view_model.llm.lower() == "paieas": view_model.llm_eas_model_name = config["llm"].get( - "model_name", view_model.llm_eas_model_name + "model", view_model.llm_eas_model_name ) else: view_model.llm_api_model_name = config["llm"].get( - "model_name", view_model.llm_api_model_name + "model", view_model.llm_api_model_name ) view_model.use_mllm = config["llm"]["multi_modal"].get( @@ -233,14 +233,14 @@ def from_app_config(config): "view_model.mllm_eas_model_name", view_model.mllm_eas_model_name, "2", - config["llm"]["multi_modal"]["model_name"], + config["llm"]["multi_modal"]["model"], ) view_model.mllm_eas_model_name = config["llm"]["multi_modal"].get( - "model_name", view_model.mllm_eas_model_name + "model", view_model.mllm_eas_model_name ) else: view_model.mllm_api_model_name = config["llm"]["multi_modal"].get( - "model_name", view_model.mllm_api_model_name + "model", view_model.mllm_api_model_name ) view_model.use_oss = config["oss_store"].get("enable", view_model.use_oss) @@ -436,7 +436,7 @@ def to_app_config(self): config = recursive_dict() config["embedding"]["source"] = self.embed_source - config["embedding"]["model_name"] = self.embed_model + config["embedding"]["model"] = self.embed_model config["embedding"]["api_key"] = self.embed_api_key config["embedding"]["embed_batch_size"] = int(self.embed_batch_size) @@ -446,9 +446,9 @@ def to_app_config(self): config["llm"]["api_key"] = self.llm_api_key config["llm"]["temperature"] = self.llm_temperature if self.llm.lower() == "paieas": - config["llm"]["model_name"] = self.llm_eas_model_name + config["llm"]["model"] = self.llm_eas_model_name else: - config["llm"]["model_name"] = self.llm_api_model_name + config["llm"]["model"] = self.llm_api_model_name config["llm"]["multi_modal"]["enable"] = self.use_mllm config["llm"]["multi_modal"]["source"] = self.mllm @@ -456,9 +456,9 @@ def to_app_config(self): config["llm"]["multi_modal"]["token"] = self.mllm_eas_token config["llm"]["multi_modal"]["api_key"] = self.mllm_api_key if self.mllm.lower() == "paieas": - config["llm"]["multi_modal"]["model_name"] = self.mllm_eas_model_name + config["llm"]["multi_modal"]["model"] = self.mllm_eas_model_name else: - config["llm"]["multi_modal"]["model_name"] = self.mllm_api_model_name + config["llm"]["multi_modal"]["model"] = self.mllm_api_model_name config["oss_store"]["enable"] = self.use_oss if os.getenv("OSS_ACCESS_KEY_ID") is None and self.oss_ak: @@ -728,7 +728,7 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]: "visible": self.llm.lower() == "paieas", } if self.llm.lower() == "paieas" and not self.llm_eas_model_name: - self.llm_eas_model_name = "model_name" + self.llm_eas_model_name = "model" settings["llm_eas_model_name"] = { "value": self.llm_eas_model_name, diff --git a/src/pai_rag/config/settings.toml b/src/pai_rag/config/settings.toml index 75e87d15..07a7054f 100644 --- a/src/pai_rag/config/settings.toml +++ b/src/pai_rag/config/settings.toml @@ -33,10 +33,10 @@ type = "local" type = "SimpleDirectoryReader" # embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace -# if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model_name +# if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model # eg. # source = "HuggingFace" -# model_name = "bge-large-zh-v1.5" +# model = "bge-large-zh-v1.5" # embed_batch_size = 10 [rag.embedding] source = "DashScope" @@ -53,12 +53,12 @@ vector_store.type = "FAISS" # llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment # eg. # source = "PaiEas" -# name = "" +# model = "" # endpoint = "" # token = "" [rag.llm] source = "DashScope" -name = "qwen-turbo" +model = "qwen-turbo" [rag.llm.function_calling_llm] source = "" diff --git a/src/pai_rag/evaluation/settings_eval.toml b/src/pai_rag/evaluation/settings_eval.toml index 763c6d09..e5b2114f 100644 --- a/src/pai_rag/evaluation/settings_eval.toml +++ b/src/pai_rag/evaluation/settings_eval.toml @@ -33,10 +33,10 @@ type = "local" type = "SimpleDirectoryReader" # embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace -# if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model_name +# if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model # eg. # source = "HuggingFace" -# model_name = "bge-large-zh-v1.5" +# model = "bge-large-zh-v1.5" # embed_batch_size = 10 [rag.embedding] source = "DashScope" @@ -53,12 +53,12 @@ vector_store.type = "FAISS" # llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment # eg. # source = "PaiEas" -# name = "" +# model = "" # endpoint = "" # token = "" [rag.llm] source = "DashScope" -name = "qwen-turbo" +model = "qwen-turbo" [rag.llm.function_calling_llm] source = "" @@ -66,7 +66,7 @@ source = "" [rag.llm.multi_modal] enable = true source = "DashScope" -name = "qwen-vl-plus" +model = "qwen-vl-plus" [rag.node_enhancement] tree_depth = 3 diff --git a/src/pai_rag/integrations/embeddings/pai/embedding_utils.py b/src/pai_rag/integrations/embeddings/pai/embedding_utils.py index 0162382e..4e5b3300 100644 --- a/src/pai_rag/integrations/embeddings/pai/embedding_utils.py +++ b/src/pai_rag/integrations/embeddings/pai/embedding_utils.py @@ -36,22 +36,22 @@ def create_embedding(embed_config: PaiBaseEmbeddingConfig): elif isinstance(embed_config, HuggingFaceEmbeddingConfig): pai_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository") embed_model = HuggingFaceEmbedding( - model_name=os.path.join(pai_model_dir, embed_config.model_name), + model_name=os.path.join(pai_model_dir, embed_config.model), embed_batch_size=embed_config.embed_batch_size, trust_remote_code=True, ) logger.info( - f"Initialized HuggingFace embedding model {embed_config.model_name} from model_dir_path {pai_model_dir} with {embed_config.embed_batch_size} batch size." + f"Initialized HuggingFace embedding model {embed_config.model} from model_dir_path {pai_model_dir} with {embed_config.embed_batch_size} batch size." ) elif isinstance(embed_config, CnClipEmbeddingConfig): embed_model = CnClipEmbedding( - model_name=embed_config.model_name, + model_name=embed_config.model, embed_batch_size=embed_config.embed_batch_size, ) logger.info( - f"Initialized CnClip embedding model {embed_config.model_name} with {embed_config.embed_batch_size} batch size." + f"Initialized CnClip embedding model {embed_config.model} with {embed_config.embed_batch_size} batch size." ) else: diff --git a/src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py b/src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py index 077b8a67..ea7a03c7 100644 --- a/src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py +++ b/src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py @@ -13,7 +13,7 @@ class SupportedEmbedType(str, Enum): class PaiBaseEmbeddingConfig(BaseModel): source: SupportedEmbedType - model_name: str + model: str embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE class Config: @@ -36,24 +36,24 @@ def validate_case_insensitive(cls, value): class DashScopeEmbeddingConfig(PaiBaseEmbeddingConfig): source: Literal[SupportedEmbedType.dashscope] = SupportedEmbedType.dashscope - model_name: str | None = None # use default + model: str | None = None # use default api_key: str | None = None # use default class OpenAIEmbeddingConfig(PaiBaseEmbeddingConfig): source: Literal[SupportedEmbedType.openai] = SupportedEmbedType.openai - model_name: str | None = None # use default + model: str | None = None # use default api_key: str | None = None # use default class HuggingFaceEmbeddingConfig(PaiBaseEmbeddingConfig): source: Literal[SupportedEmbedType.huggingface] = SupportedEmbedType.huggingface - model_name: str | None = "bge-large-zh-v1.5" + model: str | None = "bge-large-zh-v1.5" class CnClipEmbeddingConfig(PaiBaseEmbeddingConfig): source: Literal[SupportedEmbedType.cnclip] = SupportedEmbedType.cnclip - model_name: str | None = "ViT-L-14" + model: str | None = "ViT-L-14" SupporttedEmbeddingClsMap = { @@ -73,6 +73,6 @@ def parse_embed_config(config_data): if __name__ == "__main__": - embedding_config_data = {"source": "Openai", "model_name": "gpt-1", "api_key": None} + embedding_config_data = {"source": "Openai", "model": "gpt-1", "api_key": None} print(parse_embed_config(embedding_config_data)) diff --git a/src/pai_rag/integrations/llms/pai/llm_config.py b/src/pai_rag/integrations/llms/pai/llm_config.py index 4ca108e8..127a1252 100644 --- a/src/pai_rag/integrations/llms/pai/llm_config.py +++ b/src/pai_rag/integrations/llms/pai/llm_config.py @@ -126,7 +126,7 @@ class PaiBaseLlmConfig(BaseModel): temperature: float = DEFAULT_TEMPERATURE system_prompt: str = None max_tokens: int = DEFAULT_MAX_TOKENS - model_name: str = None + model: str = None @classmethod def get_subclasses(cls): @@ -150,20 +150,20 @@ class DashScopeLlmConfig(PaiBaseLlmConfig): source: Literal[SupportedLlmType.dashscope] = SupportedLlmType.dashscope api_key: str | None = None base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1" - model_name: str = "qwen-turbo" + model: str = "qwen-turbo" class OpenAILlmConfig(PaiBaseLlmConfig): source: Literal[SupportedLlmType.openai] = SupportedLlmType.openai api_key: str | None = None - model_name: str = "gpt-3.5-turbo" + model: str = "gpt-3.5-turbo" class PaiEasLlmConfig(PaiBaseLlmConfig): source: Literal[SupportedLlmType.paieas] = SupportedLlmType.paieas endpoint: str token: str - model_name: str = "default" + model: str = "default" SupporttedLlmClsMap = {cls.get_type(): cls for cls in PaiBaseLlmConfig.get_subclasses()} @@ -183,7 +183,7 @@ def parse_llm_config(config_data): if __name__ == "__main__": llm_config_data = { "source": "dashscope", - "model_name": "qwen-turbo", + "model": "qwen-turbo", "api_key": None, "max_tokens": 1024, } diff --git a/src/pai_rag/integrations/llms/pai/llm_utils.py b/src/pai_rag/integrations/llms/pai/llm_utils.py index a82629a9..11a390a5 100644 --- a/src/pai_rag/integrations/llms/pai/llm_utils.py +++ b/src/pai_rag/integrations/llms/pai/llm_utils.py @@ -21,13 +21,13 @@ def create_llm(llm_config: PaiBaseLlmConfig): logger.info( f""" [Parameters][LLM:OpenAI] - model = {llm_config.model_name}, + model = {llm_config.model}, temperature = {llm_config.temperature}, system_prompt = {llm_config.system_prompt} """ ) llm = OpenAI( - model=llm_config.model_name, + model=llm_config.model, temperature=llm_config.temperature, system_prompt=llm_config.system_prompt, api_key=llm_config.api_key, @@ -38,13 +38,13 @@ def create_llm(llm_config: PaiBaseLlmConfig): logger.info( f""" [Parameters][LLM:DashScope] - model = {llm_config.model_name}, + model = {llm_config.model}, temperature = {llm_config.temperature}, system_prompt = {llm_config.system_prompt} """ ) llm = OpenAILike( - model=llm_config.model_name, + model=llm_config.model, api_base=llm_config.base_url, temperature=llm_config.temperature, system_prompt=llm_config.system_prompt, @@ -57,13 +57,13 @@ def create_llm(llm_config: PaiBaseLlmConfig): logger.info( f""" [Parameters][LLM:PAI-EAS] - model = {llm_config.model_name}, + model = {llm_config.model}, endpoint = {llm_config.endpoint}, token = {llm_config.token} """ ) llm = OpenAILike( - model=llm_config.model_name, + model=llm_config.model, api_base=llm_config.endpoint, temperature=llm_config.temperature, system_prompt=llm_config.system_prompt, @@ -82,13 +82,13 @@ def create_multi_modal_llm(llm_config: PaiBaseLlmConfig): logger.info( f""" [Parameters][LLM:OpenAI] - model = {llm_config.model_name}, + model = {llm_config.model}, temperature = {llm_config.temperature}, system_prompt = {llm_config.system_prompt} """ ) llm = OpenAIMultiModal( - model=llm_config.model_name, + model=llm_config.model, temperature=llm_config.temperature, system_prompt=llm_config.system_prompt, api_key=llm_config.api_key, @@ -98,13 +98,13 @@ def create_multi_modal_llm(llm_config: PaiBaseLlmConfig): logger.info( f""" [Parameters][LLM:DashScope] - model = {llm_config.model_name}, + model = {llm_config.model}, temperature = {llm_config.temperature}, system_prompt = {llm_config.system_prompt} """ ) llm = OpenAIAlikeMultiModal( - model=llm_config.model_name, + model=llm_config.model, api_base=llm_config.base_url, temperature=llm_config.temperature, system_prompt=llm_config.system_prompt, @@ -116,13 +116,13 @@ def create_multi_modal_llm(llm_config: PaiBaseLlmConfig): logger.info( f""" [Parameters][LLM:PAI-EAS] - model = {llm_config.model_name}, + model = {llm_config.model}, endpoint = {llm_config.endpoint}, token = {llm_config.token} """ ) llm = OpenAIAlikeMultiModal( - model=llm_config.model_name, + model=llm_config.model, api_base=llm_config.endpoint, temperature=llm_config.temperature, system_prompt=llm_config.system_prompt, diff --git a/src/pai_rag/integrations/llms/pai/pai_llm.py b/src/pai_rag/integrations/llms/pai/pai_llm.py index 73a0c211..1d5cc4a9 100644 --- a/src/pai_rag/integrations/llms/pai/pai_llm.py +++ b/src/pai_rag/integrations/llms/pai/pai_llm.py @@ -34,7 +34,7 @@ class PaiLlm(OpenAILike): def __init__(self, llm_config: PaiBaseLlmConfig): super().__init__() self._llm = create_llm(llm_config) - self.model = llm_config.model_name + self.model = llm_config.model self._llm.callback_manager = Settings.callback_manager self.callback_manager = Settings.callback_manager diff --git a/src/pai_rag/utils/download_models.py b/src/pai_rag/utils/download_models.py index 066b89cd..56bb7ec4 100644 --- a/src/pai_rag/utils/download_models.py +++ b/src/pai_rag/utils/download_models.py @@ -26,25 +26,25 @@ def __init__(self): self.model_info = response.json() logger.info(f"Model info loaded {self.model_info}.") - def load_model(self, model_name): - model_path = os.path.join(self.download_directory_path, model_name) + def load_model(self, model): + model_path = os.path.join(self.download_directory_path, model) with TemporaryDirectory() as temp_dir: if not os.path.exists(model_path): - logger.info(f"start downloading model {model_name}.") + logger.info(f"start downloading model {model}.") start_time = time.time() - if model_name in self.model_info["basic_models"]: - model_id = self.model_info["basic_models"][model_name] - elif model_name in self.model_info["extra_models"]: - model_id = self.model_info["extra_models"][model_name] + if model in self.model_info["basic_models"]: + model_id = self.model_info["basic_models"][model] + elif model in self.model_info["extra_models"]: + model_id = self.model_info["extra_models"][model] else: - raise ValueError(f"{model_name} is not a valid model name.") + raise ValueError(f"{model} is not a valid model name.") temp_model_dir = snapshot_download(model_id, cache_dir=temp_dir) shutil.move(temp_model_dir, model_path) end_time = time.time() duration = end_time - start_time logger.info( - f"Finished downloading model {model_name} to {model_path}, took {duration:.2f} seconds." + f"Finished downloading model {model} to {model_path}, took {duration:.2f} seconds." ) def load_basic_models(self): @@ -52,8 +52,8 @@ def load_basic_models(self): response = requests.get(OSS_URL) response.raise_for_status() self.model_info = response.json() - for model_name in self.model_info["basic_models"].keys(): - self.load_model(model_name) + for model in self.model_info["basic_models"].keys(): + self.load_model(model) def load_mineru_config(self): source_path = "magic-pdf.template.json" @@ -80,15 +80,15 @@ def load_mineru_config(self): "Copy magic-pdf.template.json to ~/magic-pdf.json and modify models-dir to model path." ) - def load_models(self, model_name): - if model_name is None: - model_names = [ - model_name for model_name in self.model_info["basic_models"].keys() - ] + [model_name for model_name in self.model_info["extra_models"].keys()] - for model_name in model_names: - self.load_model(model_name) + def load_models(self, model): + if model is None: + models = [model for model in self.model_info["basic_models"].keys()] + [ + model for model in self.model_info["extra_models"].keys() + ] + for model in models: + self.load_model(model) else: - self.load_model(model_name) + self.load_model(model) @click.command() @@ -100,7 +100,7 @@ def load_models(self, model_name): help="model name. Default: download all models provided", default=None, ) -def load_models(model_name): +def load_models(model): download_models = ModelScopeDownloader() - download_models.load_models(model_name) + download_models.load_models(model) download_models.load_mineru_config() diff --git a/tests/modules/agent/test_fc_agent.py b/tests/modules/agent/test_fc_agent.py index 0022e8c0..410b1e07 100644 --- a/tests/modules/agent/test_fc_agent.py +++ b/tests/modules/agent/test_fc_agent.py @@ -5,7 +5,7 @@ from pai_rag.integrations.llms.pai.pai_llm import PaiLlm from pai_rag.integrations.llms.pai.llm_config import DashScopeLlmConfig -fc_llm_config = DashScopeLlmConfig(model_name="qwen-max") +fc_llm_config = DashScopeLlmConfig(model="qwen-max") fc_llm = PaiLlm(fc_llm_config) diff --git a/tests/modules/intentdetection/test_intent_detection.py b/tests/modules/intentdetection/test_intent_detection.py index 26e09e85..25a0d86a 100644 --- a/tests/modules/intentdetection/test_intent_detection.py +++ b/tests/modules/intentdetection/test_intent_detection.py @@ -3,7 +3,7 @@ from pai_rag.integrations.llms.pai.pai_llm import PaiLlm from pai_rag.integrations.llms.pai.llm_config import DashScopeLlmConfig -fc_llm_config = DashScopeLlmConfig(model_name="qwen-max") +fc_llm_config = DashScopeLlmConfig(model="qwen-max") fc_llm = PaiLlm(fc_llm_config) intents = { diff --git a/tests/modules/llm/test_function_calling_llm.py b/tests/modules/llm/test_function_calling_llm.py index faffb3d3..75c0724b 100644 --- a/tests/modules/llm/test_function_calling_llm.py +++ b/tests/modules/llm/test_function_calling_llm.py @@ -2,7 +2,7 @@ from pai_rag.integrations.llms.pai.pai_llm import PaiLlm from pai_rag.integrations.llms.pai.llm_config import DashScopeLlmConfig -fc_llm_config = DashScopeLlmConfig(model_name="qwen-max") +fc_llm_config = DashScopeLlmConfig(model="qwen-max") fc_llm = PaiLlm(fc_llm_config) diff --git a/tests/modules/llm/test_llm.py b/tests/modules/llm/test_llm.py index fd22dcb3..0d50128c 100644 --- a/tests/modules/llm/test_llm.py +++ b/tests/modules/llm/test_llm.py @@ -6,7 +6,7 @@ @pytest.fixture(scope="module", autouse=True) def llm(): - llm_config = DashScopeLlmConfig(model_name="qwen-turbo") + llm_config = DashScopeLlmConfig(model="qwen-turbo") return PaiLlm(llm_config)