Remove model_name, use model. #248

Merged 1 commit on Oct 18, 2024
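In short, this PR renames the model_name configuration key (and the [rag.llm] section's name key) to model everywhere it appears: the config guides, the default settings.toml files, the web ViewModel, and the embedding/LLM config classes. A representative hunk from the settings files:

[rag.llm]
source = "DashScope"
- name = "qwen-turbo"
+ model = "qwen-turbo"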
6 changes: 3 additions & 3 deletions docs/config_guide_cn.md
@@ -37,16 +37,16 @@ source = [HuggingFace, OpenAI, DashScope]

Currently, pai_rag supports the three embedding sources above.

- If source = "HuggingFace", you need to further specify model_name and embed_batch_size. The default model name and batch size are bge-large-zh-v1.5 and 10, respectively.
+ If source = "HuggingFace", you need to further specify model and embed_batch_size. The default model name and batch size are bge-large-zh-v1.5 and 10, respectively.

source = "HuggingFace"
- model_name = "bge-large-zh-v1.5"
+ model = "bge-large-zh-v1.5"
embed_batch_size = 10

Alternatively, if you want to use other huggingface models, please specify the parameters as below:

source = "HuggingFace"
- model_name = "xxx"
+ model = "xxx"
model_dir = "xxx/xxx"
embed_batch_size = 20 (for example)

6 changes: 3 additions & 3 deletions docs/config_guide_en.md
@@ -37,16 +37,16 @@ source = [HuggingFace, OpenAI, DashScope]

Currently, pai_rag supports three embedding sources.

- If source = "HuggingFace", you need to further specify model_name and embed_batch_size. The default model name and batch size are bge-large-zh-v1.5 and 10, respectively.
+ If source = "HuggingFace", you need to further specify model and embed_batch_size. The default model name and batch size are bge-large-zh-v1.5 and 10, respectively.

source = "HuggingFace"
- model_name = "bge-large-zh-v1.5"
+ model = "bge-large-zh-v1.5"
embed_batch_size = 10

Alternatively, if you want to use other huggingface models, please specify parameters as below:

source = "HuggingFace"
- model_name = "xxx"
+ model = "xxx"
model_dir = "xxx/xxx"
embed_batch_size = 20 (for example)

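Both guides now document the model key. A minimal sketch of loading such a config programmatically, assuming the package layout shown later in this diff maps directly to the import path (values illustrative):

from pai_rag.integrations.embeddings.pai.pai_embedding_config import parse_embed_config

# "model" replaces the pre-PR "model_name" key.
embed_config = parse_embed_config(
    {"source": "HuggingFace", "model": "bge-large-zh-v1.5", "embed_batch_size": 10}
)
print(embed_config)  # a HuggingFaceEmbeddingConfig instance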
@@ -26,10 +26,10 @@ persist_path = "localdata/storage"
type = "SimpleDirectoryReader"

# embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace
- # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model_name
+ # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model
# eg.
# source = "HuggingFace"
- # model_name = "bge-large-zh-v1.5"
+ # model = "bge-large-zh-v1.5"
# embed_batch_size = 10
[rag.embedding]
source = "DashScope"
@@ -49,11 +49,11 @@ vector_store.type = "FAISS"
# token = ""
[rag.llm]
source = "DashScope"
- name = "qwen-turbo"
+ model = "qwen-turbo"

[rag.llm.function_calling_llm]
source = "DashScope"
- name = "qwen2-7b-instruct"
+ model = "qwen2-7b-instruct"

[rag.llm.multi_modal]
source = ""
@@ -26,10 +26,10 @@ persist_path = "localdata/storage"
type = "SimpleDirectoryReader"

# embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace
- # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model_name
+ # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model
# eg.
# source = "HuggingFace"
- # model_name = "bge-large-zh-v1.5"
+ # model = "bge-large-zh-v1.5"
# embed_batch_size = 10
[rag.embedding]
source = "DashScope"
@@ -49,11 +49,11 @@ vector_store.type = "FAISS"
# token = ""
[rag.llm]
source = "DashScope"
- name = "qwen-turbo"
+ model = "qwen-turbo"

[rag.llm.function_calling_llm]
source = "DashScope"
- name = "qwen2-7b-instruct"
+ model = "qwen2-7b-instruct"

[rag.llm.multi_modal]
source = ""
1 change: 0 additions & 1 deletion src/pai_rag/app/web/rag_client.py
@@ -415,7 +415,6 @@ def patch_config(self, update_dict: Any):
view_model: ViewModel = ViewModel.from_app_config(config)
view_model.update(update_dict)
new_config = view_model.to_app_config()
-
r = requests.patch(
self.config_url, json=new_config, timeout=DEFAULT_CLIENT_TIME_OUT
)
28 changes: 14 additions & 14 deletions src/pai_rag/app/web/view_model.py
@@ -47,7 +47,7 @@ class ViewModel(BaseModel):
llm: str = "PaiEas"
llm_eas_url: str = None
llm_eas_token: str = None
- llm_eas_model_name: str = "model_name"
+ llm_eas_model_name: str = "model"
llm_api_key: str = None
llm_api_model_name: str = None
llm_temperature: float = 0.1
@@ -57,7 +57,7 @@ class ViewModel(BaseModel):
mllm: str = None
mllm_eas_url: str = None
mllm_eas_token: str = None
- mllm_eas_model_name: str = "model_name"
+ mllm_eas_model_name: str = "model"
mllm_api_key: str = None
mllm_api_model_name: str = None

@@ -189,7 +189,7 @@ def from_app_config(config):
"source", view_model.embed_source
)
view_model.embed_model = config["embedding"].get(
"model_name", view_model.embed_model
"model", view_model.embed_model
)
view_model.embed_api_key = config["embedding"].get(
"api_key", view_model.embed_api_key
@@ -207,11 +207,11 @@ def from_app_config(config):
)
if view_model.llm.lower() == "paieas":
view_model.llm_eas_model_name = config["llm"].get(
"model_name", view_model.llm_eas_model_name
"model", view_model.llm_eas_model_name
)
else:
view_model.llm_api_model_name = config["llm"].get(
"model_name", view_model.llm_api_model_name
"model", view_model.llm_api_model_name
)

view_model.use_mllm = config["llm"]["multi_modal"].get(
@@ -233,14 +233,14 @@ def from_app_config(config):
"view_model.mllm_eas_model_name",
view_model.mllm_eas_model_name,
"2",
config["llm"]["multi_modal"]["model_name"],
config["llm"]["multi_modal"]["model"],
)
view_model.mllm_eas_model_name = config["llm"]["multi_modal"].get(
"model_name", view_model.mllm_eas_model_name
"model", view_model.mllm_eas_model_name
)
else:
view_model.mllm_api_model_name = config["llm"]["multi_modal"].get(
"model_name", view_model.mllm_api_model_name
"model", view_model.mllm_api_model_name
)

view_model.use_oss = config["oss_store"].get("enable", view_model.use_oss)
@@ -436,7 +436,7 @@ def to_app_config(self):
config = recursive_dict()

config["embedding"]["source"] = self.embed_source
config["embedding"]["model_name"] = self.embed_model
config["embedding"]["model"] = self.embed_model
config["embedding"]["api_key"] = self.embed_api_key
config["embedding"]["embed_batch_size"] = int(self.embed_batch_size)

@@ -446,19 +446,19 @@ def to_app_config(self):
config["llm"]["api_key"] = self.llm_api_key
config["llm"]["temperature"] = self.llm_temperature
if self.llm.lower() == "paieas":
config["llm"]["model_name"] = self.llm_eas_model_name
config["llm"]["model"] = self.llm_eas_model_name
else:
config["llm"]["model_name"] = self.llm_api_model_name
config["llm"]["model"] = self.llm_api_model_name

config["llm"]["multi_modal"]["enable"] = self.use_mllm
config["llm"]["multi_modal"]["source"] = self.mllm
config["llm"]["multi_modal"]["endpoint"] = self.mllm_eas_url
config["llm"]["multi_modal"]["token"] = self.mllm_eas_token
config["llm"]["multi_modal"]["api_key"] = self.mllm_api_key
if self.mllm.lower() == "paieas":
config["llm"]["multi_modal"]["model_name"] = self.mllm_eas_model_name
config["llm"]["multi_modal"]["model"] = self.mllm_eas_model_name
else:
config["llm"]["multi_modal"]["model_name"] = self.mllm_api_model_name
config["llm"]["multi_modal"]["model"] = self.mllm_api_model_name

config["oss_store"]["enable"] = self.use_oss
if os.getenv("OSS_ACCESS_KEY_ID") is None and self.oss_ak:
@@ -728,7 +728,7 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
"visible": self.llm.lower() == "paieas",
}
if self.llm.lower() == "paieas" and not self.llm_eas_model_name:
self.llm_eas_model_name = "model_name"
self.llm_eas_model_name = "model"

settings["llm_eas_model_name"] = {
"value": self.llm_eas_model_name,
8 changes: 4 additions & 4 deletions src/pai_rag/config/settings.toml
@@ -33,10 +33,10 @@ type = "local"
type = "SimpleDirectoryReader"

# embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace
- # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model_name
+ # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model
# eg.
# source = "HuggingFace"
- # model_name = "bge-large-zh-v1.5"
+ # model = "bge-large-zh-v1.5"
# embed_batch_size = 10
[rag.embedding]
source = "DashScope"
@@ -53,12 +53,12 @@ vector_store.type = "FAISS"
# llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment
# eg.
# source = "PaiEas"
- # name = ""
+ # model = ""
# endpoint = ""
# token = ""
[rag.llm]
source = "DashScope"
- name = "qwen-turbo"
+ model = "qwen-turbo"

[rag.llm.function_calling_llm]
source = ""
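As a quick check of the renamed key, a hedged sketch of reading the default settings with the standard library (assuming Python 3.11+ for tomllib and that the file path matches this diff):

import tomllib

with open("src/pai_rag/config/settings.toml", "rb") as f:
    settings = tomllib.load(f)

# [rag.llm] now uses "model" instead of "name".
print(settings["rag"]["llm"]["model"])  # "qwen-turbo"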
10 changes: 5 additions & 5 deletions src/pai_rag/evaluation/settings_eval.toml
@@ -33,10 +33,10 @@ type = "local"
type = "SimpleDirectoryReader"

# embedding configurations, source support API: OpenAI,DashScope; and local model:HuggingFace
- # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model_name
+ # if use API, need set OPENAI_API_KEY or DASHSCOPE_API_KEY in ENV, If HuggingFace, need set model
# eg.
# source = "HuggingFace"
- # model_name = "bge-large-zh-v1.5"
+ # model = "bge-large-zh-v1.5"
# embed_batch_size = 10
[rag.embedding]
source = "DashScope"
@@ -53,20 +53,20 @@ vector_store.type = "FAISS"
# llm configurations, source support API: OpenAI,DashScope or PAI-EAS's deployment
# eg.
# source = "PaiEas"
- # name = ""
+ # model = ""
# endpoint = ""
# token = ""
[rag.llm]
source = "DashScope"
- name = "qwen-turbo"
+ model = "qwen-turbo"

[rag.llm.function_calling_llm]
source = ""

[rag.llm.multi_modal]
enable = true
source = "DashScope"
- name = "qwen-vl-plus"
+ model = "qwen-vl-plus"

[rag.node_enhancement]
tree_depth = 3
8 changes: 4 additions & 4 deletions src/pai_rag/integrations/embeddings/pai/embedding_utils.py
@@ -36,22 +36,22 @@ def create_embedding(embed_config: PaiBaseEmbeddingConfig):
elif isinstance(embed_config, HuggingFaceEmbeddingConfig):
pai_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository")
embed_model = HuggingFaceEmbedding(
- model_name=os.path.join(pai_model_dir, embed_config.model_name),
+ model_name=os.path.join(pai_model_dir, embed_config.model),
embed_batch_size=embed_config.embed_batch_size,
trust_remote_code=True,
)

logger.info(
f"Initialized HuggingFace embedding model {embed_config.model_name} from model_dir_path {pai_model_dir} with {embed_config.embed_batch_size} batch size."
f"Initialized HuggingFace embedding model {embed_config.model} from model_dir_path {pai_model_dir} with {embed_config.embed_batch_size} batch size."
)

elif isinstance(embed_config, CnClipEmbeddingConfig):
embed_model = CnClipEmbedding(
- model_name=embed_config.model_name,
+ model_name=embed_config.model,
embed_batch_size=embed_config.embed_batch_size,
)
logger.info(
f"Initialized CnClip embedding model {embed_config.model_name} with {embed_config.embed_batch_size} batch size."
f"Initialized CnClip embedding model {embed_config.model} with {embed_config.embed_batch_size} batch size."
)

else:
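A usage sketch of the updated create_embedding, assuming local bge weights are available under the default model directory (function names and the PAI_RAG_MODEL_DIR default are from this diff):

import os

from pai_rag.integrations.embeddings.pai.embedding_utils import create_embedding
from pai_rag.integrations.embeddings.pai.pai_embedding_config import parse_embed_config

os.environ.setdefault("PAI_RAG_MODEL_DIR", "./model_repository")
embed_config = parse_embed_config({"source": "HuggingFace", "model": "bge-large-zh-v1.5"})
# create_embedding joins PAI_RAG_MODEL_DIR with embed_config.model, as shown above.
embed_model = create_embedding(embed_config)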
12 changes: 6 additions & 6 deletions src/pai_rag/integrations/embeddings/pai/pai_embedding_config.py
@@ -13,7 +13,7 @@ class SupportedEmbedType(str, Enum):

class PaiBaseEmbeddingConfig(BaseModel):
source: SupportedEmbedType
- model_name: str
+ model: str
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE

class Config:
@@ -36,24 +36,24 @@ def validate_case_insensitive(cls, value):

class DashScopeEmbeddingConfig(PaiBaseEmbeddingConfig):
source: Literal[SupportedEmbedType.dashscope] = SupportedEmbedType.dashscope
- model_name: str | None = None # use default
+ model: str | None = None # use default
api_key: str | None = None # use default


class OpenAIEmbeddingConfig(PaiBaseEmbeddingConfig):
source: Literal[SupportedEmbedType.openai] = SupportedEmbedType.openai
- model_name: str | None = None # use default
+ model: str | None = None # use default
api_key: str | None = None # use default


class HuggingFaceEmbeddingConfig(PaiBaseEmbeddingConfig):
source: Literal[SupportedEmbedType.huggingface] = SupportedEmbedType.huggingface
model_name: str | None = "bge-large-zh-v1.5"
model: str | None = "bge-large-zh-v1.5"


class CnClipEmbeddingConfig(PaiBaseEmbeddingConfig):
source: Literal[SupportedEmbedType.cnclip] = SupportedEmbedType.cnclip
model_name: str | None = "ViT-L-14"
model: str | None = "ViT-L-14"


SupporttedEmbeddingClsMap = {
@@ -73,6 +73,6 @@ def parse_embed_config(config_data):


if __name__ == "__main__":
embedding_config_data = {"source": "Openai", "model_name": "gpt-1", "api_key": None}
embedding_config_data = {"source": "Openai", "model": "gpt-1", "api_key": None}

print(parse_embed_config(embedding_config_data))
10 changes: 5 additions & 5 deletions src/pai_rag/integrations/llms/pai/llm_config.py
@@ -126,7 +126,7 @@ class PaiBaseLlmConfig(BaseModel):
temperature: float = DEFAULT_TEMPERATURE
system_prompt: str = None
max_tokens: int = DEFAULT_MAX_TOKENS
- model_name: str = None
+ model: str = None

@classmethod
def get_subclasses(cls):
@@ -150,20 +150,20 @@ class DashScopeLlmConfig(PaiBaseLlmConfig):
source: Literal[SupportedLlmType.dashscope] = SupportedLlmType.dashscope
api_key: str | None = None
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
model_name: str = "qwen-turbo"
model: str = "qwen-turbo"


class OpenAILlmConfig(PaiBaseLlmConfig):
source: Literal[SupportedLlmType.openai] = SupportedLlmType.openai
api_key: str | None = None
model_name: str = "gpt-3.5-turbo"
model: str = "gpt-3.5-turbo"


class PaiEasLlmConfig(PaiBaseLlmConfig):
source: Literal[SupportedLlmType.paieas] = SupportedLlmType.paieas
endpoint: str
token: str
model_name: str = "default"
model: str = "default"


SupporttedLlmClsMap = {cls.get_type(): cls for cls in PaiBaseLlmConfig.get_subclasses()}
@@ -183,7 +183,7 @@ def parse_llm_config(config_data):
if __name__ == "__main__":
llm_config_data = {
"source": "dashscope",
"model_name": "qwen-turbo",
"model": "qwen-turbo",
"api_key": None,
"max_tokens": 1024,
}
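The __main__ block is truncated in this view; it presumably finishes by parsing and printing this dict, parallel to the embedding config module. A hedged sketch:

from pai_rag.integrations.llms.pai.llm_config import parse_llm_config

llm_config_data = {
    "source": "dashscope",
    "model": "qwen-turbo",  # renamed from "model_name" in this PR
    "api_key": None,
    "max_tokens": 1024,
}
print(parse_llm_config(llm_config_data))  # a DashScopeLlmConfig instance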