Add llm module #242

Merged
merged 4 commits into from
Oct 10, 2024
Changes from 2 commits
412 changes: 0 additions & 412 deletions src/pai_rag/integrations/llms/dashscope/base.py

This file was deleted.

@@ -1,8 +1,9 @@
from typing import Any, Dict, Optional, Tuple
from llama_index.core.base.llms.generic_utils import get_from_param_or_env
import dashscope
from llama_index.core.base.llms.types import LLMMetadata
from llama_index.llms.openai_like import OpenAILike
from typing import Literal
from pydantic import BaseModel, field_validator
from enum import Enum
from llama_index.core.constants import DEFAULT_TEMPERATURE

DEFAULT_MAX_TOKENS = 2000


class DashScopeGenerationModels:
@@ -29,31 +30,31 @@ class DashScopeGenerationModels:
DASHSCOPE_MODEL_META = {
DashScopeGenerationModels.QWEN_TURBO: {
"context_window": 1024 * 8,
- "num_output": 1024 * 8,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEN_PLUS: {
"context_window": 1024 * 32,
- "num_output": 1024 * 32,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEN_MAX: {
"context_window": 1024 * 8,
- "num_output": 1024 * 8,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEN_MAX_1201: {
"context_window": 1024 * 8,
- "num_output": 1024 * 8,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEN_MAX_LONGCONTEXT: {
"context_window": 1024 * 30,
- "num_output": 1024 * 30,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
@@ -65,118 +66,125 @@ class DashScopeGenerationModels:
},
DashScopeGenerationModels.QWEM1P5_7B_CHAT: {
"context_window": 1024 * 8,
- "num_output": 1024 * 8,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEM1P5_14B_CHAT: {
"context_window": 1024 * 16,
- "num_output": 1024 * 16,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEM1P5_32B_CHAT: {
"context_window": 1024 * 16,
- "num_output": 1024 * 16,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEM1P5_72B_CHAT: {
"context_window": 1024 * 16,
- "num_output": 1024 * 16,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEM1P5_110B_CHAT: {
"context_window": 1024 * 32,
- "num_output": 1024 * 32,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEM2_1P5B_INSTRUCT: {
"context_window": 1024 * 30,
- "num_output": 1024 * 32,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEM2_7B_INSTRUCT: {
"context_window": 1024 * 32,
- "num_output": 1024 * 32,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
DashScopeGenerationModels.QWEM2_72B_INSTRUCT: {
"context_window": 1024 * 32,
- "num_output": 1024 * 32,
+ "num_output": 1024 * 2,
"is_chat_model": True,
"is_function_calling_model": True,
},
}

DEFAULT_DASHSCOPE_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"


def resolve_dashscope_credentials(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Tuple[str, str]:
"""Resolve DashScope credentials.

The order of precedence is:
1. param
2. env
3. dashscope module
4. default
"""
# resolve from param or env
api_key = get_from_param_or_env("api_key", api_key, "DASHSCOPE_API_KEY", "")
api_base = get_from_param_or_env("api_base", api_base, "DASHSCOPE_API_BASE", "")

# resolve from openai module or default
final_api_key = api_key or dashscope.api_key or ""
final_api_base = api_base or DEFAULT_DASHSCOPE_API_BASE

return final_api_key, str(final_api_base)


class MyFCDashScope(OpenAILike):
"""
MyFCDashScope LLM is a thin wrapper around the OpenAILike model that makes it compatible
with Function Calling DashScope.
"""

def __init__(
self,
model: Optional[str] = DashScopeGenerationModels.QWEN_MAX,
additional_kwargs: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
**kwargs: Any,
) -> None:
additional_kwargs = additional_kwargs or {}

api_key, api_base = resolve_dashscope_credentials(
api_key=api_key,
api_base=api_base,
)

super().__init__(
model=model,
api_key=api_key,
api_base=api_base,
additional_kwargs=additional_kwargs,
**kwargs,
)

class SupportedLlmType(str, Enum):
dashscope = "dashscope"
openai = "openai"
paieas = "paieas"


class PaiBaseLlmConfig(BaseModel):
source: SupportedLlmType
temperature: float = DEFAULT_TEMPERATURE
system_prompt: str | None = None
max_tokens: int = DEFAULT_MAX_TOKENS
model_name: str | None = None

@classmethod
def class_name(cls) -> str:
return "fc_dashscope_llm"

@property
def metadata(self) -> LLMMetadata:
DASHSCOPE_MODEL_META[self.model]["num_output"] = (
self.max_tokens or DASHSCOPE_MODEL_META[self.model]["num_output"]
)
return LLMMetadata(
model_name=self.model,
**DASHSCOPE_MODEL_META[self.model],
)

@classmethod
def get_subclasses(cls):
return tuple(cls.__subclasses__())

class Config:
frozen = True

@classmethod
def get_type(cls):
return cls.model_fields["source"].default

@field_validator("source", mode="before")
def validate_case_insensitive(cls, value):
if isinstance(value, str):
return value.lower()
return value


class DashScopeLlmConfig(PaiBaseLlmConfig):
source: Literal[SupportedLlmType.dashscope] = SupportedLlmType.dashscope
api_key: str | None = None
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
model_name: str = "qwen-turbo"


class OpenAILlmConfig(PaiBaseLlmConfig):
source: Literal[SupportedLlmType.openai] = SupportedLlmType.openai
api_key: str | None = None
model_name: str = "gpt-3.5-turbo"


class PaiEasLlmConfig(PaiBaseLlmConfig):
source: Literal[SupportedLlmType.paieas] = SupportedLlmType.paieas
endpoint: str
token: str
model_name: str = "default"


SupportedLlmClsMap = {cls.get_type(): cls for cls in PaiBaseLlmConfig.get_subclasses()}


def parse_llm_config(config_data):
if "source" not in config_data:
raise ValueError("LLM config must contain 'source' field")

llm_cls = SupportedLlmClsMap.get(config_data["source"].lower())
if llm_cls is None:
    raise ValueError(f"Unsupported llm source: {config_data['source']}")

return llm_cls(**config_data)


if __name__ == "__main__":
llm_config_data = {
"source": "dashscope",
"model_name": "qwen-turbo",
"api_key": None,
"max_tokens": 1024,
}
print(parse_llm_config(llm_config_data))
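
The config classes above dispatch on the "source" field: each subclass pins source to a Literal default, get_subclasses() enumerates the subclasses, and the class map keys each default back to its class. Below is a minimal, self-contained sketch of that registry pattern; the class names are stand-ins for illustration, not the PR's actual imports.

from enum import Enum
from typing import Literal

from pydantic import BaseModel


class Source(str, Enum):
    dashscope = "dashscope"
    openai = "openai"


class BaseCfg(BaseModel):
    source: Source

    @classmethod
    def get_type(cls):
        # Each subclass pins a Literal default; that default identifies it.
        return cls.model_fields["source"].default


class DashScopeCfg(BaseCfg):
    source: Literal[Source.dashscope] = Source.dashscope
    model_name: str = "qwen-turbo"


class OpenAICfg(BaseCfg):
    source: Literal[Source.openai] = Source.openai
    model_name: str = "gpt-3.5-turbo"


# A str-mixin enum member compares (and hashes) equal to its string value,
# so plain-string lookup works against this enum-keyed map.
CLS_MAP = {c.get_type(): c for c in BaseCfg.__subclasses__()}

cfg = CLS_MAP["dashscope"](model_name="qwen-max")
print(type(cfg).__name__, cfg.model_name)  # DashScopeCfg qwen-max
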
132 changes: 132 additions & 0 deletions src/pai_rag/integrations/llms/pai/llm_utils.py
@@ -0,0 +1,132 @@
import logging
import os
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from pai_rag.integrations.llms.pai.llm_config import (
PaiBaseLlmConfig,
OpenAILlmConfig,
DashScopeLlmConfig,
PaiEasLlmConfig,
)
from pai_rag.integrations.llms.pai.open_ai_alike_multi_modal import (
OpenAIAlikeMultiModal,
)

logger = logging.getLogger(__name__)


def create_llm(llm_config: PaiBaseLlmConfig):
if isinstance(llm_config, OpenAILlmConfig):
logger.info(
f"""
[Parameters][LLM:OpenAI]
model = {llm_config.model_name},
temperature = {llm_config.temperature},
system_prompt = {llm_config.system_prompt}
"""
)
llm = OpenAI(
model=llm_config.model_name,
temperature=llm_config.temperature,
system_prompt=llm_config.system_prompt,
api_key=llm_config.api_key,
max_tokens=llm_config.max_tokens,
)
elif isinstance(llm_config, DashScopeLlmConfig):
logger.info(
f"""
[Parameters][LLM:DashScope]
model = {llm_config.model_name},
temperature = {llm_config.temperature},
system_prompt = {llm_config.system_prompt}
"""
)
llm = OpenAILike(
model=llm_config.model_name,
api_base=llm_config.base_url,
temperature=llm_config.temperature,
system_prompt=llm_config.system_prompt,
is_chat_model=True,
api_key=llm_config.api_key or os.environ.get("DASHSCOPE_API_KEY"),
max_tokens=llm_config.max_tokens,
)
elif isinstance(llm_config, PaiEasLlmConfig):
logger.info(
f"""
[Parameters][LLM:PAI-EAS]
model = {llm_config.model_name},
endpoint = {llm_config.endpoint},
token = {llm_config.token}
"""
)
llm = OpenAILike(
model=llm_config.model_name,
api_base=llm_config.endpoint,
temperature=llm_config.temperature,
system_prompt=llm_config.system_prompt,
api_key=llm_config.token,
max_tokens=llm_config.max_tokens,
)
else:
raise ValueError(f"Unknown LLM source: '{llm_config}'")

return llm


def create_multi_modal_llm(llm_config: PaiBaseLlmConfig):
if isinstance(llm_config, OpenAILlmConfig):
logger.info(
f"""
[Parameters][LLM:OpenAI]
model = {llm_config.model_name},
temperature = {llm_config.temperature},
system_prompt = {llm_config.system_prompt}
"""
)
llm = OpenAIMultiModal(
model=llm_config.model_name,
temperature=llm_config.temperature,
system_prompt=llm_config.system_prompt,
api_key=llm_config.api_key,
max_new_tokens=llm_config.max_tokens,
)
elif isinstance(llm_config, DashScopeLlmConfig):
logger.info(
f"""
[Parameters][LLM:DashScope]
model = {llm_config.model_name},
temperature = {llm_config.temperature},
system_prompt = {llm_config.system_prompt}
"""
)
llm = OpenAIAlikeMultiModal(
model=llm_config.model_name,
api_base=llm_config.base_url,
temperature=llm_config.temperature,
system_prompt=llm_config.system_prompt,
is_chat_model=True,
api_key=llm_config.api_key or os.environ.get("DASHSCOPE_API_KEY"),
max_new_tokens=llm_config.max_tokens,
)
elif isinstance(llm_config, PaiEasLlmConfig):
logger.info(
f"""
[Parameters][LLM:PAI-EAS]
model = {llm_config.model_name},
endpoint = {llm_config.endpoint},
token = {llm_config.token}
"""
)
llm = OpenAIAlikeMultiModal(
model=llm_config.model_name,
api_base=llm_config.endpoint,
temperature=llm_config.temperature,
system_prompt=llm_config.system_prompt,
api_key=llm_config.token,
max_new_tokens=llm_config.max_tokens,
)
else:
raise ValueError(f"Unknown Multi-modal LLM source: '{llm_config}'")

return llm
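
Taken together, a caller parses a raw dict into a typed config and hands it to create_llm. A usage sketch follows; it assumes parse_llm_config is exported from the same llm_config module the imports above point at, and that a real DASHSCOPE_API_KEY is available in the environment.

import os

from pai_rag.integrations.llms.pai.llm_config import parse_llm_config
from pai_rag.integrations.llms.pai.llm_utils import create_llm

os.environ.setdefault("DASHSCOPE_API_KEY", "sk-...")  # placeholder, not a real key

# "source" picks DashScopeLlmConfig; the remaining keys must match its fields.
config = parse_llm_config(
    {
        "source": "dashscope",
        "model_name": "qwen-turbo",
        "max_tokens": 1024,
    }
)

# create_llm returns an OpenAILike instance pointed at DashScope's
# OpenAI-compatible endpoint, exposing the standard LlamaIndex LLM API.
llm = create_llm(config)
print(llm.complete("Hello!"))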