diff --git a/README.md b/README.md
index 517226dc5..92778486f 100644
--- a/README.md
+++ b/README.md
@@ -103,7 +103,8 @@ At present, we have introduced several key features to showcase our current capa
We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
- News
- - 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
+ - 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+ - 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
- 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
- [More Supported LLMs](http://docs.dbgpt.site/docs/modules/smmf)
diff --git a/README.zh.md b/README.zh.md
index afe2189e0..857e9a241 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -111,7 +111,8 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱等。当前已支持如下模型:
- 新增支持模型
- - 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
+ - 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
+ - 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat)
- 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
- [更多开源模型](https://www.yuque.com/eosphoros/dbgpt-docs/iqaaqwriwhp6zslc#qQktR)
diff --git a/dbgpt/app/chat_adapter.py b/dbgpt/app/chat_adapter.py
index 474695beb..c1cb192b1 100644
--- a/dbgpt/app/chat_adapter.py
+++ b/dbgpt/app/chat_adapter.py
@@ -245,7 +245,7 @@ def get_conv_template(self, model_path: str) -> Conversation:
class LlamaCppChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
- from dbgpt.model.adapter import LlamaCppAdapater
+ from dbgpt.model.adapter.old_adapter import LlamaCppAdapater
if "llama-cpp" == model_path:
return True
diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index 05a0636b8..136c57e43 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -114,7 +114,9 @@ def get_device() -> str:
# https://huggingface.co/microsoft/Orca-2-13b
"orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"),
# https://huggingface.co/openchat/openchat_3.5
- "openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
+ "openchat-3.5": os.path.join(MODEL_PATH, "openchat_3.5"),
+ # https://huggingface.co/openchat/openchat-3.5-1210
+ "openchat-3.5-1210": os.path.join(MODEL_PATH, "openchat-3.5-1210"),
# https://huggingface.co/hfl/chinese-alpaca-2-7b
"chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"),
# https://huggingface.co/hfl/chinese-alpaca-2-13b
@@ -125,6 +127,10 @@ def get_device() -> str:
"zephyr-7b-alpha": os.path.join(MODEL_PATH, "zephyr-7b-alpha"),
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
"mistral-7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mistral-7B-Instruct-v0.1"),
+ # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
+ "mixtral-8x7b-instruct-v0.1": os.path.join(
+ MODEL_PATH, "Mixtral-8x7B-Instruct-v0.1"
+ ),
# https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca
"mistral-7b-openorca": os.path.join(MODEL_PATH, "Mistral-7B-OpenOrca"),
# https://huggingface.co/Xwin-LM/Xwin-LM-7B-V0.1
diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py
old mode 100644
new mode 100755
index 8009053a5..2493ebb53
--- a/dbgpt/core/interface/message.py
+++ b/dbgpt/core/interface/message.py
@@ -157,14 +157,13 @@ def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
else:
pass
# Move the last user's information to the end
- temp_his = history[::-1]
- last_user_input = None
- for m in temp_his:
- if m["role"] == "user":
- last_user_input = m
+ last_user_input_index = None
+ for i in range(len(history) - 1, -1, -1):
+ if history[i]["role"] == "user":
+ last_user_input_index = i
break
- if last_user_input:
- history.remove(last_user_input)
+ if last_user_input_index:
+ last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)
return history
diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py
old mode 100644
new mode 100755
index ccaa5237a..0650b1f67
--- a/dbgpt/core/interface/tests/test_message.py
+++ b/dbgpt/core/interface/tests/test_message.py
@@ -67,6 +67,23 @@ def conversation_with_messages():
return conv
+@pytest.fixture
+def human_model_message():
+ return ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")
+
+
+@pytest.fixture
+def ai_model_message():
+ return ModelMessage(role=ModelMessageRoleType.AI, content="Hi there")
+
+
+@pytest.fixture
+def system_model_message():
+ return ModelMessage(
+ role=ModelMessageRoleType.SYSTEM, content="You are a helpful chatbot!"
+ )
+
+
def test_init(basic_conversation):
assert basic_conversation.chat_mode == "chat_normal"
assert basic_conversation.user_name == "user1"
@@ -370,3 +387,43 @@ def test_parse_model_messages_multiple_system_messages():
assert user_prompt == "How are you?"
assert system_messages == ["System start", "System check"]
assert history_messages == [["Hey", "Hello!"]]
+
+
+def test_to_openai_messages(
+ human_model_message, ai_model_message, system_model_message
+):
+ none_messages = ModelMessage.to_openai_messages([])
+ assert none_messages == []
+
+ single_messages = ModelMessage.to_openai_messages([human_model_message])
+ assert single_messages == [{"role": "user", "content": human_model_message.content}]
+
+ normal_messages = ModelMessage.to_openai_messages(
+ [
+ system_model_message,
+ human_model_message,
+ ai_model_message,
+ human_model_message,
+ ]
+ )
+ assert normal_messages == [
+ {"role": "system", "content": system_model_message.content},
+ {"role": "user", "content": human_model_message.content},
+ {"role": "assistant", "content": ai_model_message.content},
+ {"role": "user", "content": human_model_message.content},
+ ]
+
+ shuffle_messages = ModelMessage.to_openai_messages(
+ [
+ system_model_message,
+ human_model_message,
+ human_model_message,
+ ai_model_message,
+ ]
+ )
+ assert shuffle_messages == [
+ {"role": "system", "content": system_model_message.content},
+ {"role": "user", "content": human_model_message.content},
+ {"role": "assistant", "content": ai_model_message.content},
+ {"role": "user", "content": human_model_message.content},
+ ]
diff --git a/dbgpt/model/adapter/__init__.py b/dbgpt/model/adapter/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dbgpt/model/adapter/base.py b/dbgpt/model/adapter/base.py
new file mode 100644
index 000000000..df1d9441a
--- /dev/null
+++ b/dbgpt/model/adapter/base.py
@@ -0,0 +1,437 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Any, Tuple, Type, Callable
+import logging
+from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+from dbgpt.model.base import ModelType
+from dbgpt.model.parameter import (
+ BaseModelParameters,
+ ModelParameters,
+ LlamaCppModelParameters,
+ ProxyModelParameters,
+)
+from dbgpt.model.adapter.template import (
+ get_conv_template,
+ ConversationAdapter,
+ ConversationAdapterFactory,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class LLMModelAdapter(ABC):
+ """New Adapter for DB-GPT LLM models"""
+
+ model_name: Optional[str] = None
+ model_path: Optional[str] = None
+ conv_factory: Optional[ConversationAdapterFactory] = None
+ # TODO: more flexible quantization config
+ support_4bit: bool = False
+ support_8bit: bool = False
+ support_system_message: bool = True
+
+ def __repr__(self) -> str:
+ return f"<{self.__class__.__name__} model_name={self.model_name} model_path={self.model_path}>"
+
+ def __str__(self):
+ return self.__repr__()
+
+ @abstractmethod
+ def new_adapter(self, **kwargs) -> "LLMModelAdapter":
+ """Create a new adapter instance
+
+ Args:
+ **kwargs: The parameters of the new adapter instance
+
+ Returns:
+ LLMModelAdapter: The new adapter instance
+ """
+
+ def use_fast_tokenizer(self) -> bool:
+ """Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
+ for a given model.
+ """
+ return False
+
+ def model_type(self) -> str:
+ return ModelType.HF
+
+ def model_param_class(self, model_type: str = None) -> Type[BaseModelParameters]:
+ """Get the startup parameters instance of the model
+
+ Args:
+ model_type (str, optional): The type of model. Defaults to None.
+
+ Returns:
+ Type[BaseModelParameters]: The startup parameters instance of the model
+ """
+ # """Get the startup parameters instance of the model"""
+ model_type = model_type if model_type else self.model_type()
+ if model_type == ModelType.LLAMA_CPP:
+ return LlamaCppModelParameters
+ elif model_type == ModelType.PROXY:
+ return ProxyModelParameters
+ return ModelParameters
+
+ def match(
+ self,
+ model_type: str,
+ model_name: Optional[str] = None,
+ model_path: Optional[str] = None,
+ ) -> bool:
+ """Whether the model adapter can load the given model
+
+ Args:
+ model_type (str): The type of model
+ model_name (Optional[str], optional): The name of model. Defaults to None.
+ model_path (Optional[str], optional): The path of model. Defaults to None.
+ """
+ return False
+
+ def support_quantization_4bit(self) -> bool:
+ """Whether the model adapter can load 4bit model
+
+ If it is True, we will load the 4bit model with :meth:`~LLMModelAdapter.load`
+
+ Returns:
+ bool: Whether the model adapter can load 4bit model, default is False
+ """
+ return self.support_4bit
+
+ def support_quantization_8bit(self) -> bool:
+ """Whether the model adapter can load 8bit model
+
+ If it is True, we will load the 8bit model with :meth:`~LLMModelAdapter.load`
+
+ Returns:
+ bool: Whether the model adapter can load 8bit model, default is False
+ """
+ return self.support_8bit
+
+ def load(self, model_path: str, from_pretrained_kwargs: dict):
+ """Load model and tokenizer"""
+ raise NotImplementedError
+
+ def load_from_params(self, params):
+ """Load the model and tokenizer according to the given parameters"""
+ raise NotImplementedError
+
+ def support_async(self) -> bool:
+ """Whether the loaded model supports asynchronous calls"""
+ return False
+
+ def get_generate_stream_function(self, model, model_path: str):
+ """Get the generate stream function of the model"""
+ raise NotImplementedError
+
+ def get_async_generate_stream_function(self, model, model_path: str):
+ """Get the asynchronous generate stream function of the model"""
+ raise NotImplementedError
+
+ def get_default_conv_template(
+ self, model_name: str, model_path: str
+ ) -> Optional[ConversationAdapter]:
+ """Get the default conversation template
+
+ Args:
+ model_name (str): The name of the model.
+ model_path (str): The path of the model.
+
+ Returns:
+ Optional[ConversationAdapter]: The conversation template.
+ """
+ raise NotImplementedError
+
+ def get_default_message_separator(self) -> str:
+ """Get the default message separator"""
+ try:
+ conv_template = self.get_default_conv_template(
+ self.model_name, self.model_path
+ )
+ return conv_template.sep
+ except Exception:
+ return "\n"
+
+ def transform_model_messages(
+ self, messages: List[ModelMessage]
+ ) -> List[Dict[str, str]]:
+ """Transform the model messages
+
+ Default is the OpenAI format, example:
+ .. code-block:: python
+ return_messages = [
+ {"role": "system", "content": "You are a helpful assistant"},
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi"}
+ ]
+
+ But some model may need to transform the messages to other format(e.g. There is no system message), such as:
+ .. code-block:: python
+ return_messages = [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi"}
+ ]
+ Args:
+ messages (List[ModelMessage]): The model messages
+
+ Returns:
+ List[Dict[str, str]]: The transformed model messages
+ """
+ logger.info(f"support_system_message: {self.support_system_message}")
+ if not self.support_system_message:
+ return self._transform_to_no_system_messages(messages)
+ else:
+ return ModelMessage.to_openai_messages(messages)
+
+ def _transform_to_no_system_messages(
+ self, messages: List[ModelMessage]
+ ) -> List[Dict[str, str]]:
+ """Transform the model messages to no system messages
+
+ Some opensource chat model no system messages, so wo should transform the messages to no system messages.
+
+ Merge the system messages to the last user message, example:
+ .. code-block:: python
+ return_messages = [
+ {"role": "system", "content": "You are a helpful assistant"},
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi"}
+ ]
+ =>
+ return_messages = [
+ {"role": "user", "content": "You are a helpful assistant\nHello"},
+ {"role": "assistant", "content": "Hi"}
+ ]
+
+ Args:
+ messages (List[ModelMessage]): The model messages
+
+ Returns:
+ List[Dict[str, str]]: The transformed model messages
+ """
+ openai_messages = ModelMessage.to_openai_messages(messages)
+ system_messages = []
+ return_messages = []
+ for message in openai_messages:
+ if message["role"] == "system":
+ system_messages.append(message["content"])
+ else:
+ return_messages.append(message)
+ if len(system_messages) > 1:
+ # Too much system messages should be a warning
+ logger.warning("Your system messages have more than one message")
+ if system_messages:
+ sep = self.get_default_message_separator()
+ str_system_messages = ",".join(system_messages)
+ # Update last user message
+ return_messages[-1]["content"] = (
+ str_system_messages + sep + return_messages[-1]["content"]
+ )
+ return return_messages
+
+ def get_str_prompt(
+ self,
+ params: Dict,
+ messages: List[ModelMessage],
+ tokenizer: Any,
+ prompt_template: str = None,
+ ) -> Optional[str]:
+ """Get the string prompt from the given parameters and messages
+
+ If the value of return is not None, we will skip :meth:`~LLMModelAdapter.get_prompt_with_template` and use the value of return.
+
+ Args:
+ params (Dict): The parameters
+ messages (List[ModelMessage]): The model messages
+ tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer
+ prompt_template (str, optional): The prompt template. Defaults to None.
+
+ Returns:
+ Optional[str]: The string prompt
+ """
+ return None
+
+ def get_prompt_with_template(
+ self,
+ params: Dict,
+ messages: List[ModelMessage],
+ model_name: str,
+ model_path: str,
+ model_context: Dict,
+ prompt_template: str = None,
+ ):
+ conv: ConversationAdapter = self.get_default_conv_template(
+ model_name, model_path
+ )
+
+ if prompt_template:
+ logger.info(f"Use prompt template {prompt_template} from config")
+ conv = get_conv_template(prompt_template)
+ if not conv or not messages:
+ # Nothing to do
+ logger.info(
+ f"No conv from model_path {model_path} or no messages in params, {self}"
+ )
+ return None, None, None
+
+ conv = conv.copy()
+ system_messages = []
+ user_messages = []
+ ai_messages = []
+
+ for message in messages:
+ if isinstance(message, ModelMessage):
+ role = message.role
+ content = message.content
+ elif isinstance(message, dict):
+ role = message["role"]
+ content = message["content"]
+ else:
+ raise ValueError(f"Invalid message type: {message}")
+
+ if role == ModelMessageRoleType.SYSTEM:
+ # Support for multiple system messages
+ system_messages.append(content)
+ elif role == ModelMessageRoleType.HUMAN:
+ # conv.append_message(conv.roles[0], content)
+ user_messages.append(content)
+ elif role == ModelMessageRoleType.AI:
+ # conv.append_message(conv.roles[1], content)
+ ai_messages.append(content)
+ else:
+ raise ValueError(f"Unknown role: {role}")
+
+ can_use_systems: [] = []
+ if system_messages:
+ if len(system_messages) > 1:
+ # Compatible with dbgpt complex scenarios, the last system will protect more complete information
+ # entered by the current user
+ user_messages[-1] = system_messages[-1]
+ can_use_systems = system_messages[:-1]
+ else:
+ can_use_systems = system_messages
+
+ for i in range(len(user_messages)):
+ conv.append_message(conv.roles[0], user_messages[i])
+ if i < len(ai_messages):
+ conv.append_message(conv.roles[1], ai_messages[i])
+
+ # TODO join all system messages may not be a good idea
+ conv.set_system_message("".join(can_use_systems))
+ # Add a blank message for the assistant.
+ conv.append_message(conv.roles[1], None)
+ new_prompt = conv.get_prompt()
+ return new_prompt, conv.stop_str, conv.stop_token_ids
+
+ def model_adaptation(
+ self,
+ params: Dict,
+ model_name: str,
+ model_path: str,
+ tokenizer: Any,
+ prompt_template: str = None,
+ ) -> Tuple[Dict, Dict]:
+ """Params adaptation"""
+ messages = params.get("messages")
+ # Some model context to dbgpt server
+ model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
+ if messages:
+ # Dict message to ModelMessage
+ messages = [
+ m if isinstance(m, ModelMessage) else ModelMessage(**m)
+ for m in messages
+ ]
+ params["messages"] = messages
+
+ new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
+ conv_stop_str, conv_stop_token_ids = None, None
+ if not new_prompt:
+ (
+ new_prompt,
+ conv_stop_str,
+ conv_stop_token_ids,
+ ) = self.get_prompt_with_template(
+ params, messages, model_name, model_path, model_context, prompt_template
+ )
+ if not new_prompt:
+ return params, model_context
+
+ # Overwrite the original prompt
+ # TODO remote bos token and eos token from tokenizer_config.json of model
+ prompt_echo_len_char = len(new_prompt.replace("", "").replace("", ""))
+ model_context["prompt_echo_len_char"] = prompt_echo_len_char
+ model_context["echo"] = params.get("echo", True)
+ model_context["has_format_prompt"] = True
+ params["prompt"] = new_prompt
+
+ custom_stop = params.get("stop")
+ custom_stop_token_ids = params.get("stop_token_ids")
+
+ # Prefer the value passed in from the input parameter
+ params["stop"] = custom_stop or conv_stop_str
+ params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
+
+ return params, model_context
+
+
+class AdapterEntry:
+ """The entry of model adapter"""
+
+ def __init__(
+ self,
+ model_adapter: LLMModelAdapter,
+ match_funcs: List[Callable[[str, str, str], bool]] = None,
+ ):
+ self.model_adapter = model_adapter
+ self.match_funcs = match_funcs or []
+
+
+model_adapters: List[AdapterEntry] = []
+
+
+def register_model_adapter(
+ model_adapter_cls: Type[LLMModelAdapter],
+ match_funcs: List[Callable[[str, str, str], bool]] = None,
+) -> None:
+ """Register a model adapter.
+
+ Args:
+ model_adapter_cls (Type[LLMModelAdapter]): The model adapter class.
+ match_funcs (List[Callable[[str, str, str], bool]], optional): The match functions. Defaults to None.
+ """
+ model_adapters.append(AdapterEntry(model_adapter_cls(), match_funcs))
+
+
+def get_model_adapter(
+ model_type: str,
+ model_name: str,
+ model_path: str,
+ conv_factory: Optional[ConversationAdapterFactory] = None,
+) -> Optional[LLMModelAdapter]:
+ """Get a model adapter.
+
+ Args:
+ model_type (str): The type of the model.
+ model_name (str): The name of the model.
+ model_path (str): The path of the model.
+ conv_factory (Optional[ConversationAdapterFactory], optional): The conversation factory. Defaults to None.
+ Returns:
+ Optional[LLMModelAdapter]: The model adapter.
+ """
+ adapter = None
+ # First find adapter by model_name
+ for adapter_entry in model_adapters:
+ if adapter_entry.model_adapter.match(model_type, model_name, None):
+ adapter = adapter_entry.model_adapter
+ break
+ for adapter_entry in model_adapters:
+ if adapter_entry.model_adapter.match(model_type, None, model_path):
+ adapter = adapter_entry.model_adapter
+ break
+ if adapter:
+ new_adapter = adapter.new_adapter()
+ new_adapter.model_name = model_name
+ new_adapter.model_path = model_path
+ if conv_factory:
+ new_adapter.conv_factory = conv_factory
+ return new_adapter
+ return None
diff --git a/dbgpt/model/adapter/fschat_adapter.py b/dbgpt/model/adapter/fschat_adapter.py
new file mode 100644
index 000000000..ca5413645
--- /dev/null
+++ b/dbgpt/model/adapter/fschat_adapter.py
@@ -0,0 +1,262 @@
+"""Adapter for fastchat
+
+You can import fastchat only in this file, so that the user does not need to install fastchat if he does not use it.
+"""
+import os
+import threading
+import logging
+from functools import cache
+from typing import TYPE_CHECKING, Callable, Tuple, List, Optional
+
+try:
+ from fastchat.conversation import (
+ Conversation,
+ register_conv_template,
+ SeparatorStyle,
+ )
+except ImportError as exc:
+ raise ValueError(
+ "Could not import python package: fschat "
+ "Please install fastchat by command `pip install fschat` "
+ ) from exc
+
+from dbgpt.model.adapter.template import ConversationAdapter, PromptType
+from dbgpt.model.adapter.base import LLMModelAdapter
+
+if TYPE_CHECKING:
+ from fastchat.model.model_adapter import BaseModelAdapter
+ from torch.nn import Module as TorchNNModule
+
+logger = logging.getLogger(__name__)
+
+thread_local = threading.local()
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
+
+# If some model is not in the blacklist, but it still affects the loading of DB-GPT, you can add it to the blacklist.
+__BLACK_LIST_MODEL_PROMPT = []
+
+
+class FschatConversationAdapter(ConversationAdapter):
+ """The conversation adapter for fschat."""
+
+ def __init__(self, conv: Conversation):
+ self._conv = conv
+
+ @property
+ def prompt_type(self) -> PromptType:
+ return PromptType.FSCHAT
+
+ @property
+ def roles(self) -> Tuple[str]:
+ return self._conv.roles
+
+ @property
+ def sep(self) -> Optional[str]:
+ return self._conv.sep
+
+ @property
+ def stop_str(self) -> str:
+ return self._conv.stop_str
+
+ @property
+ def stop_token_ids(self) -> Optional[List[int]]:
+ return self._conv.stop_token_ids
+
+ def get_prompt(self) -> str:
+ """Get the prompt string."""
+ return self._conv.get_prompt()
+
+ def set_system_message(self, system_message: str) -> None:
+ """Set the system message."""
+ self._conv.set_system_message(system_message)
+
+ def append_message(self, role: str, message: str) -> None:
+ """Append a new message.
+
+ Args:
+ role (str): The role of the message.
+ message (str): The message content.
+ """
+ self._conv.append_message(role, message)
+
+ def update_last_message(self, message: str) -> None:
+ """Update the last output.
+
+ The last message is typically set to be None when constructing the prompt,
+ so we need to update it in-place after getting the response from a model.
+
+ Args:
+ message (str): The message content.
+ """
+ self._conv.update_last_message(message)
+
+ def copy(self) -> "ConversationAdapter":
+ """Copy the conversation."""
+ return FschatConversationAdapter(self._conv.copy())
+
+
+class FastChatLLMModelAdapterWrapper(LLMModelAdapter):
+ """Wrapping fastchat adapter"""
+
+ def __init__(self, adapter: "BaseModelAdapter") -> None:
+ self._adapter = adapter
+
+ def new_adapter(self, **kwargs) -> "LLMModelAdapter":
+ return FastChatLLMModelAdapterWrapper(self._adapter)
+
+ def use_fast_tokenizer(self) -> bool:
+ return self._adapter.use_fast_tokenizer
+
+ def load(self, model_path: str, from_pretrained_kwargs: dict):
+ return self._adapter.load_model(model_path, from_pretrained_kwargs)
+
+ def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
+ if _IS_BENCHMARK:
+ from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import (
+ generate_stream,
+ )
+
+ return generate_stream
+ else:
+ from fastchat.model.model_adapter import get_generate_stream_function
+
+ return get_generate_stream_function(model, model_path)
+
+ def get_default_conv_template(
+ self, model_name: str, model_path: str
+ ) -> Optional[ConversationAdapter]:
+ conv_template = self._adapter.get_default_conv_template(model_path)
+ return FschatConversationAdapter(conv_template) if conv_template else None
+
+ def __str__(self) -> str:
+ return "{}({}.{})".format(
+ self.__class__.__name__,
+ self._adapter.__class__.__module__,
+ self._adapter.__class__.__name__,
+ )
+
+
+def _get_fastchat_model_adapter(
+ model_name: str,
+ model_path: str,
+ caller: Callable[[str], None] = None,
+ use_fastchat_monkey_patch: bool = False,
+):
+ from fastchat.model import model_adapter
+
+ _bak_get_model_adapter = model_adapter.get_model_adapter
+ try:
+ if use_fastchat_monkey_patch:
+ model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
+ thread_local.model_name = model_name
+ _remove_black_list_model_of_fastchat()
+ if caller:
+ return caller(model_path)
+ finally:
+ del thread_local.model_name
+ model_adapter.get_model_adapter = _bak_get_model_adapter
+
+
+def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
+ if not model_name:
+ if not hasattr(thread_local, "model_name"):
+ raise RuntimeError("fastchat get adapter monkey path need model_name")
+ model_name = thread_local.model_name
+ from fastchat.model.model_adapter import model_adapters
+
+ for adapter in model_adapters:
+ if adapter.match(model_name):
+ logger.info(
+ f"Found llm model adapter with model name: {model_name}, {adapter}"
+ )
+ return adapter
+
+ model_path_basename = (
+ None if not model_path else os.path.basename(os.path.normpath(model_path))
+ )
+ for adapter in model_adapters:
+ if model_path_basename and adapter.match(model_path_basename):
+ logger.info(
+ f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
+ )
+ return adapter
+
+ for adapter in model_adapters:
+ if model_path and adapter.match(model_path):
+ logger.info(
+ f"Found llm model adapter with model path: {model_path}, {adapter}"
+ )
+ return adapter
+
+ raise ValueError(
+ f"Invalid model adapter for model name {model_name} and model path {model_path}"
+ )
+
+
+@cache
+def _remove_black_list_model_of_fastchat():
+ from fastchat.model.model_adapter import model_adapters
+
+ black_list_models = []
+ for adapter in model_adapters:
+ try:
+ if (
+ adapter.get_default_conv_template("/data/not_exist_model_path").name
+ in __BLACK_LIST_MODEL_PROMPT
+ ):
+ black_list_models.append(adapter)
+ except Exception:
+ pass
+ for adapter in black_list_models:
+ model_adapters.remove(adapter)
+
+
+# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat.
+# We also recommend that you modify it directly in the fastchat repository.
+
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
+register_conv_template(
+ Conversation(
+ name="aquila-legacy",
+ system_message="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ roles=("### Human: ", "### Assistant: ", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.NO_COLON_TWO,
+ sep="\n",
+ sep2="",
+ stop_str=["", "[UNK]"],
+ ),
+ override=True,
+)
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
+register_conv_template(
+ Conversation(
+ name="aquila",
+ system_message="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
+ sep="###",
+ sep2="",
+ stop_str=["", "[UNK]"],
+ ),
+ override=True,
+)
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
+register_conv_template(
+ Conversation(
+ name="aquila-v1",
+ roles=("<|startofpiece|>", "<|endofpiece|>", ""),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.NO_COLON_TWO,
+ sep="",
+ sep2="",
+ stop_str=["", "<|endoftext|>"],
+ ),
+ override=True,
+)
diff --git a/dbgpt/model/adapter/hf_adapter.py b/dbgpt/model/adapter/hf_adapter.py
new file mode 100644
index 000000000..673223b49
--- /dev/null
+++ b/dbgpt/model/adapter/hf_adapter.py
@@ -0,0 +1,136 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Optional, List, Any
+import logging
+
+from dbgpt.core import ModelMessage
+from dbgpt.model.base import ModelType
+from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
+
+logger = logging.getLogger(__name__)
+
+
+class NewHFChatModelAdapter(LLMModelAdapter, ABC):
+ """Model adapter for new huggingface chat models
+
+ See https://huggingface.co/docs/transformers/main/en/chat_templating
+
+ We can transform the inference chat messages to chat model instead of create a
+ prompt template for this model
+ """
+
+ def new_adapter(self, **kwargs) -> "NewHFChatModelAdapter":
+ return self.__class__()
+
+ def match(
+ self,
+ model_type: str,
+ model_name: Optional[str] = None,
+ model_path: Optional[str] = None,
+ ) -> bool:
+ if model_type != ModelType.HF:
+ return False
+ if model_name is None and model_path is None:
+ return False
+ model_name = model_name.lower() if model_name else None
+ model_path = model_path.lower() if model_path else None
+ return self.do_match(model_name) or self.do_match(model_path)
+
+ @abstractmethod
+ def do_match(self, lower_model_name_or_path: Optional[str] = None):
+ raise NotImplementedError()
+
+ def load(self, model_path: str, from_pretrained_kwargs: dict):
+ try:
+ import transformers
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
+ except ImportError as exc:
+ raise ValueError(
+ "Could not import depend python package "
+ "Please install it with `pip install transformers`."
+ ) from exc
+ if not transformers.__version__ >= "4.34.0":
+ raise ValueError(
+ "Current model (Load by NewHFChatModelAdapter) require transformers.__version__>=4.34.0"
+ )
+ revision = from_pretrained_kwargs.get("revision", "main")
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ use_fast=self.use_fast_tokenizer(),
+ revision=revision,
+ trust_remote_code=True,
+ )
+ except TypeError:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=False, revision=revision, trust_remote_code=True
+ )
+ try:
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+ )
+ except NameError:
+ model = AutoModel.from_pretrained(
+ model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+ )
+ # tokenizer.use_default_system_prompt = False
+ return model, tokenizer
+
+ def get_generate_stream_function(self, model, model_path: str):
+ """Get the generate stream function of the model"""
+ from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
+
+ return huggingface_chat_generate_stream
+
+ def get_str_prompt(
+ self,
+ params: Dict,
+ messages: List[ModelMessage],
+ tokenizer: Any,
+ prompt_template: str = None,
+ ) -> Optional[str]:
+ from transformers import AutoTokenizer
+
+ if not tokenizer:
+ raise ValueError("tokenizer is is None")
+ tokenizer: AutoTokenizer = tokenizer
+
+ messages = self.transform_model_messages(messages)
+ logger.debug(f"The messages after transform: \n{messages}")
+ str_prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ return str_prompt
+
+
+class YiAdapter(NewHFChatModelAdapter):
+ support_4bit: bool = True
+ support_8bit: bool = True
+ support_system_message: bool = True
+
+ def do_match(self, lower_model_name_or_path: Optional[str] = None):
+ return (
+ lower_model_name_or_path
+ and "yi-" in lower_model_name_or_path
+ and "chat" in lower_model_name_or_path
+ )
+
+
+class Mixtral8x7BAdapter(NewHFChatModelAdapter):
+ """
+ https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
+ """
+
+ support_4bit: bool = True
+ support_8bit: bool = True
+ support_system_message: bool = False
+
+ def do_match(self, lower_model_name_or_path: Optional[str] = None):
+ return (
+ lower_model_name_or_path
+ and "mixtral" in lower_model_name_or_path
+ and "8x7b" in lower_model_name_or_path
+ )
+
+
+register_model_adapter(YiAdapter)
+register_model_adapter(Mixtral8x7BAdapter)
diff --git a/dbgpt/model/adapter/model_adapter.py b/dbgpt/model/adapter/model_adapter.py
new file mode 100644
index 000000000..efaf9f1b9
--- /dev/null
+++ b/dbgpt/model/adapter/model_adapter.py
@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+from typing import (
+ List,
+ Type,
+ Optional,
+)
+import logging
+import threading
+import os
+from functools import cache
+from dbgpt.model.base import ModelType
+from dbgpt.model.parameter import BaseModelParameters
+from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter
+from dbgpt.model.adapter.template import (
+ ConversationAdapter,
+ ConversationAdapterFactory,
+)
+
+logger = logging.getLogger(__name__)
+
+thread_local = threading.local()
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
+
+
+_OLD_MODELS = [
+ "llama-cpp",
+ "proxyllm",
+ "gptj-6b",
+ "codellama-13b-sql-sft",
+ "codellama-7b",
+ "codellama-7b-sql-sft",
+ "codellama-13b",
+]
+
+
+@cache
+def get_llm_model_adapter(
+ model_name: str,
+ model_path: str,
+ use_fastchat: bool = True,
+ use_fastchat_monkey_patch: bool = False,
+ model_type: str = None,
+) -> LLMModelAdapter:
+ conv_factory = DefaultConversationAdapterFactory()
+ if model_type == ModelType.VLLM:
+ logger.info("Current model type is vllm, return VLLMModelAdapterWrapper")
+ from dbgpt.model.adapter.vllm_adapter import VLLMModelAdapterWrapper
+
+ return VLLMModelAdapterWrapper(conv_factory)
+
+ # Import NewHFChatModelAdapter for it can be registered
+ from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter
+
+ new_model_adapter = get_model_adapter(
+ model_type, model_name, model_path, conv_factory
+ )
+ if new_model_adapter:
+ logger.info(f"Current model {model_name} use new adapter {new_model_adapter}")
+ return new_model_adapter
+
+ must_use_old = any(m in model_name for m in _OLD_MODELS)
+ result_adapter: Optional[LLMModelAdapter] = None
+ if use_fastchat and not must_use_old:
+ logger.info("Use fastcat adapter")
+ from dbgpt.model.adapter.fschat_adapter import (
+ _get_fastchat_model_adapter,
+ _fastchat_get_adapter_monkey_patch,
+ FastChatLLMModelAdapterWrapper,
+ )
+
+ adapter = _get_fastchat_model_adapter(
+ model_name,
+ model_path,
+ _fastchat_get_adapter_monkey_patch,
+ use_fastchat_monkey_patch=use_fastchat_monkey_patch,
+ )
+ if adapter:
+ result_adapter = FastChatLLMModelAdapterWrapper(adapter)
+
+ else:
+ from dbgpt.model.adapter.old_adapter import (
+ get_llm_model_adapter as _old_get_llm_model_adapter,
+ OldLLMModelAdapterWrapper,
+ )
+ from dbgpt.app.chat_adapter import get_llm_chat_adapter
+
+ logger.info("Use DB-GPT old adapter")
+ result_adapter = OldLLMModelAdapterWrapper(
+ _old_get_llm_model_adapter(model_name, model_path),
+ get_llm_chat_adapter(model_name, model_path),
+ )
+ if result_adapter:
+ result_adapter.model_name = model_name
+ result_adapter.model_path = model_path
+ result_adapter.conv_factory = conv_factory
+ return result_adapter
+ else:
+ raise ValueError(f"Can not find adapter for model {model_name}")
+
+
+@cache
+def _auto_get_conv_template(
+ model_name: str, model_path: str
+) -> Optional[ConversationAdapter]:
+ """Auto get the conversation template.
+
+ Args:
+ model_name (str): The name of the model.
+ model_path (str): The path of the model.
+
+ Returns:
+ Optional[ConversationAdapter]: The conversation template.
+ """
+ try:
+ adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
+ return adapter.get_default_conv_template(model_name, model_path)
+ except Exception as e:
+ logger.debug(f"Failed to get conv template for {model_name} {model_path}: {e}")
+ return None
+
+
+class DefaultConversationAdapterFactory(ConversationAdapterFactory):
+ def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
+ """Get a conversation adapter by model.
+
+ Args:
+ model_name (str): The name of the model.
+ model_path (str): The path of the model.
+ Returns:
+ ConversationAdapter: The conversation adapter.
+ """
+ return _auto_get_conv_template(model_name, model_path)
+
+
+def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
+ """Dynamic model parser, parse the model parameters from the command line arguments.
+
+ Returns:
+ Optional[List[Type[BaseModelParameters]]]: The model parameters class list.
+ """
+ from dbgpt.util.parameter_utils import _SimpleArgParser
+ from dbgpt.model.parameter import (
+ EmbeddingModelParameters,
+ WorkerType,
+ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+ )
+
+ pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
+ pre_args.parse()
+ model_name = pre_args.get("model_name")
+ model_path = pre_args.get("model_path")
+ worker_type = pre_args.get("worker_type")
+ model_type = pre_args.get("model_type")
+ if model_name is None and model_type != ModelType.VLLM:
+ return None
+ if worker_type == WorkerType.TEXT2VEC:
+ return [
+ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+ model_name, EmbeddingModelParameters
+ )
+ ]
+
+ llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
+ param_class = llm_adapter.model_param_class()
+ return [param_class]
diff --git a/dbgpt/model/adapter.py b/dbgpt/model/adapter/old_adapter.py
similarity index 82%
rename from dbgpt/model/adapter.py
rename to dbgpt/model/adapter/old_adapter.py
index 05be0e464..a63054695 100644
--- a/dbgpt/model/adapter.py
+++ b/dbgpt/model/adapter/old_adapter.py
@@ -9,7 +9,7 @@
import re
import logging
from pathlib import Path
-from typing import List, Tuple
+from typing import List, Tuple, TYPE_CHECKING, Optional
from functools import cache
from transformers import (
AutoModel,
@@ -17,6 +17,9 @@
AutoTokenizer,
LlamaTokenizer,
)
+
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.template import ConversationAdapter, PromptType
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import (
@@ -24,9 +27,13 @@
LlamaCppModelParameters,
ProxyModelParameters,
)
+from dbgpt.model.conversation import Conversation
from dbgpt.configs.model_config import get_device
from dbgpt._private.config import Config
+if TYPE_CHECKING:
+ from dbgpt.app.chat_adapter import BaseChatAdpter
+
logger = logging.getLogger(__name__)
CFG = Config()
@@ -92,17 +99,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
)
-def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
- try:
- llm_adapter = get_llm_model_adapter(model_name, model_path)
- return llm_adapter.model_param_class()
- except Exception as e:
- logger.warn(
- f"Parse model parameters with model name {model_name} and model {model_path} failed {str(e)}, return `ModelParameters`"
- )
- return ModelParameters
-
-
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
@@ -426,6 +422,87 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer
+class OldLLMModelAdapterWrapper(LLMModelAdapter):
+ """Wrapping old adapter, which may be removed later"""
+
+ def __init__(self, adapter: BaseLLMAdaper, chat_adapter: "BaseChatAdpter") -> None:
+ self._adapter = adapter
+ self._chat_adapter = chat_adapter
+
+ def new_adapter(self, **kwargs) -> "LLMModelAdapter":
+ return OldLLMModelAdapterWrapper(self._adapter, self._chat_adapter)
+
+ def use_fast_tokenizer(self) -> bool:
+ return self._adapter.use_fast_tokenizer()
+
+ def model_type(self) -> str:
+ return self._adapter.model_type()
+
+ def model_param_class(self, model_type: str = None) -> ModelParameters:
+ return self._adapter.model_param_class(model_type)
+
+ def get_default_conv_template(
+ self, model_name: str, model_path: str
+ ) -> Optional[ConversationAdapter]:
+ conv_template = self._chat_adapter.get_conv_template(model_path)
+ return OldConversationAdapter(conv_template) if conv_template else None
+
+ def load(self, model_path: str, from_pretrained_kwargs: dict):
+ return self._adapter.loader(model_path, from_pretrained_kwargs)
+
+ def get_generate_stream_function(self, model, model_path: str):
+ return self._chat_adapter.get_generate_stream_func(model_path)
+
+ def __str__(self) -> str:
+ return "{}({}.{})".format(
+ self.__class__.__name__,
+ self._adapter.__class__.__module__,
+ self._adapter.__class__.__name__,
+ )
+
+
+class OldConversationAdapter(ConversationAdapter):
+ """Wrapping old Conversation, which may be removed later"""
+
+ def __init__(self, conv: Conversation) -> None:
+ self._conv = conv
+
+ @property
+ def prompt_type(self) -> PromptType:
+ return PromptType.DBGPT
+
+ @property
+ def roles(self) -> Tuple[str]:
+ return self._conv.roles
+
+ @property
+ def sep(self) -> Optional[str]:
+ return self._conv.sep
+
+ @property
+ def stop_str(self) -> str:
+ return self._conv.stop_str
+
+ @property
+ def stop_token_ids(self) -> Optional[List[int]]:
+ return self._conv.stop_token_ids
+
+ def get_prompt(self) -> str:
+ return self._conv.get_prompt()
+
+ def set_system_message(self, system_message: str) -> None:
+ self._conv.update_system_message(system_message)
+
+ def append_message(self, role: str, message: str) -> None:
+ self._conv.append_message(role, message)
+
+ def update_last_message(self, message: str) -> None:
+ self._conv.update_last_message(message)
+
+ def copy(self) -> "ConversationAdapter":
+ return OldConversationAdapter(self._conv.copy())
+
+
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
diff --git a/dbgpt/model/adapter/template.py b/dbgpt/model/adapter/template.py
new file mode 100644
index 000000000..3fb9a6ec1
--- /dev/null
+++ b/dbgpt/model/adapter/template.py
@@ -0,0 +1,130 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import TYPE_CHECKING, Optional, Tuple, Union, List
+
+if TYPE_CHECKING:
+ from fastchat.conversation import Conversation
+
+
+class PromptType(str, Enum):
+ """Prompt type."""
+
+ FSCHAT: str = "fschat"
+ DBGPT: str = "dbgpt"
+
+
+class ConversationAdapter(ABC):
+ """The conversation adapter."""
+
+ @property
+ def prompt_type(self) -> PromptType:
+ return PromptType.FSCHAT
+
+ @property
+ @abstractmethod
+ def roles(self) -> Tuple[str]:
+ """Get the roles of the conversation.
+
+ Returns:
+ Tuple[str]: The roles of the conversation.
+ """
+
+ @property
+ def sep(self) -> Optional[str]:
+ """Get the separator between messages."""
+ return "\n"
+
+ @property
+ def stop_str(self) -> Optional[Union[str, List[str]]]:
+ """Get the stop criteria."""
+ return None
+
+ @property
+ def stop_token_ids(self) -> Optional[List[int]]:
+ """Stops generation if meeting any token in this list"""
+ return None
+
+ @abstractmethod
+ def get_prompt(self) -> str:
+ """Get the prompt string.
+
+ Returns:
+ str: The prompt string.
+ """
+
+ @abstractmethod
+ def set_system_message(self, system_message: str) -> None:
+ """Set the system message."""
+
+ @abstractmethod
+ def append_message(self, role: str, message: str) -> None:
+ """Append a new message.
+ Args:
+ role (str): The role of the message.
+ message (str): The message content.
+ """
+
+ @abstractmethod
+ def update_last_message(self, message: str) -> None:
+ """Update the last output.
+
+ The last message is typically set to be None when constructing the prompt,
+ so we need to update it in-place after getting the response from a model.
+
+ Args:
+ message (str): The message content.
+ """
+
+ @abstractmethod
+ def copy(self) -> "ConversationAdapter":
+ """Copy the conversation."""
+
+
+class ConversationAdapterFactory(ABC):
+ """The conversation adapter factory."""
+
+ def get_by_name(
+ self,
+ template_name: str,
+ prompt_template_type: Optional[PromptType] = PromptType.FSCHAT,
+ ) -> ConversationAdapter:
+ """Get a conversation adapter by name.
+
+ Args:
+ template_name (str): The name of the template.
+ prompt_template_type (Optional[PromptType]): The type of the prompt template, default to be FSCHAT.
+
+ Returns:
+ ConversationAdapter: The conversation adapter.
+ """
+ raise NotImplementedError()
+
+ def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
+ """Get a conversation adapter by model.
+
+ Args:
+ model_name (str): The name of the model.
+ model_path (str): The path of the model.
+
+ Returns:
+ ConversationAdapter: The conversation adapter.
+ """
+ raise NotImplementedError()
+
+
+def get_conv_template(name: str) -> ConversationAdapter:
+ """Get a conversation template.
+
+ Args:
+ name (str): The name of the template.
+
+ Just return the fastchat conversation template for now.
+ # TODO: More templates should be supported.
+ Returns:
+ Conversation: The conversation template.
+ """
+ from fastchat.conversation import get_conv_template
+ from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter
+
+ conv_template = get_conv_template(name)
+ return FschatConversationAdapter(conv_template)
diff --git a/dbgpt/model/adapter/vllm_adapter.py b/dbgpt/model/adapter/vllm_adapter.py
new file mode 100644
index 000000000..2ffe0c764
--- /dev/null
+++ b/dbgpt/model/adapter/vllm_adapter.py
@@ -0,0 +1,93 @@
+import dataclasses
+import logging
+from dbgpt.model.base import ModelType
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
+from dbgpt.model.parameter import BaseModelParameters
+from dbgpt.util.parameter_utils import (
+ _extract_parameter_details,
+ _build_parameter_class,
+ _get_dataclass_print_str,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class VLLMModelAdapterWrapper(LLMModelAdapter):
+ """Wrapping vllm engine"""
+
+ def __init__(self, conv_factory: ConversationAdapterFactory):
+ self.conv_factory = conv_factory
+
+ def new_adapter(self, **kwargs) -> "VLLMModelAdapterWrapper":
+ return VLLMModelAdapterWrapper(self.conv_factory)
+
+ def model_type(self) -> str:
+ return ModelType.VLLM
+
+ def model_param_class(self, model_type: str = None) -> BaseModelParameters:
+ import argparse
+ from vllm.engine.arg_utils import AsyncEngineArgs
+
+ parser = argparse.ArgumentParser()
+ parser = AsyncEngineArgs.add_cli_args(parser)
+ parser.add_argument("--model_name", type=str, help="model name")
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ help="local model path of the huggingface model to use",
+ )
+ parser.add_argument("--model_type", type=str, help="model type")
+ parser.add_argument("--device", type=str, default=None, help="device")
+ # TODO parse prompt templete from `model_name` and `model_path`
+ parser.add_argument(
+ "--prompt_template",
+ type=str,
+ default=None,
+ help="Prompt template. If None, the prompt template is automatically determined from model path",
+ )
+
+ descs = _extract_parameter_details(
+ parser,
+ "dbgpt.model.parameter.VLLMModelParameters",
+ skip_names=["model"],
+ overwrite_default_values={"trust_remote_code": True},
+ )
+ return _build_parameter_class(descs)
+
+ def load_from_params(self, params):
+ from vllm import AsyncLLMEngine
+ from vllm.engine.arg_utils import AsyncEngineArgs
+ import torch
+
+ num_gpus = torch.cuda.device_count()
+ if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
+ setattr(params, "tensor_parallel_size", num_gpus)
+ logger.info(
+ f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
+ )
+
+ params = dataclasses.asdict(params)
+ params["model"] = params["model_path"]
+ attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
+ vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
+ # Set the attributes from the parsed arguments.
+ engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
+ return engine, engine.engine.tokenizer
+
+ def support_async(self) -> bool:
+ return True
+
+ def get_async_generate_stream_function(self, model, model_path: str):
+ from dbgpt.model.llm_out.vllm_llm import generate_stream
+
+ return generate_stream
+
+ def get_default_conv_template(
+ self, model_name: str, model_path: str
+ ) -> ConversationAdapter:
+ return self.conv_factory.get_by_model(model_name, model_path)
+
+ def __str__(self) -> str:
+ return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
diff --git a/dbgpt/model/cli.py b/dbgpt/model/cli.py
index 67d5eec3c..67e12d7eb 100644
--- a/dbgpt/model/cli.py
+++ b/dbgpt/model/cli.py
@@ -405,7 +405,7 @@ def stop_model_controller(port: int):
def _model_dynamic_factory() -> Callable[[None], List[Type]]:
- from dbgpt.model.model_adapter import _dynamic_model_parser
+ from dbgpt.model.adapter.model_adapter import _dynamic_model_parser
param_class = _dynamic_model_parser()
fix_class = [ModelWorkerParameters]
diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py
index 7345bb5d4..c4967a076 100644
--- a/dbgpt/model/cluster/worker/default_worker.py
+++ b/dbgpt/model/cluster/worker/default_worker.py
@@ -6,7 +6,8 @@
import traceback
from dbgpt.configs.model_config import get_device
-from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.core import ModelOutput, ModelInferenceMetrics
from dbgpt.model.loader import ModelLoader, _get_model_real_path
from dbgpt.model.parameter import ModelParameters
@@ -27,7 +28,7 @@ def __init__(self) -> None:
self.model = None
self.tokenizer = None
self._model_params = None
- self.llm_adapter: LLMModelAdaper = None
+ self.llm_adapter: LLMModelAdapter = None
self._support_async = False
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
diff --git a/dbgpt/model/llm_utils.py b/dbgpt/model/llm_utils.py
index e877778ed..031896e86 100644
--- a/dbgpt/model/llm_utils.py
+++ b/dbgpt/model/llm_utils.py
@@ -37,7 +37,7 @@ def list_supported_models():
def _list_supported_models(
worker_type: str, model_config: Dict[str, str]
) -> List[SupportedModel]:
- from dbgpt.model.model_adapter import get_llm_model_adapter
+ from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.model.loader import _get_model_real_path
ret = []
diff --git a/dbgpt/model/loader.py b/dbgpt/model/loader.py
index 2030eb402..4ed42d630 100644
--- a/dbgpt/model/loader.py
+++ b/dbgpt/model/loader.py
@@ -1,13 +1,13 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-from typing import Optional, Dict
+from typing import Optional, Dict, Any
-from dataclasses import asdict
import logging
from dbgpt.configs.model_config import get_device
from dbgpt.model.base import ModelType
-from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
+from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
@@ -117,7 +117,7 @@ def loader(
raise Exception(f"Unkown model type {model_type}")
def loader_with_params(
- self, model_params: ModelParameters, llm_adapter: LLMModelAdaper
+ self, model_params: ModelParameters, llm_adapter: LLMModelAdapter
):
model_type = llm_adapter.model_type()
self.prompt_template = model_params.prompt_template
@@ -133,7 +133,7 @@ def loader_with_params(
raise Exception(f"Unkown model type {model_type}")
-def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters):
+def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
import torch
from dbgpt.model.compression import compress_module
@@ -174,6 +174,12 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
else:
raise ValueError(f"Invalid device: {device}")
+ model, tokenizer = _try_load_default_quantization_model(
+ llm_adapter, device, num_gpus, model_params, kwargs
+ )
+ if model:
+ return model, tokenizer
+
can_quantization = _check_quantization(model_params)
if can_quantization and (num_gpus > 1 or model_params.load_4bit):
@@ -192,6 +198,46 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
# TODO merge current code into `load_huggingface_quantization_model`
compress_module(model, model_params.device)
+ return _handle_model_and_tokenizer(model, tokenizer, device, num_gpus, model_params)
+
+
+def _try_load_default_quantization_model(
+ llm_adapter: LLMModelAdapter,
+ device: str,
+ num_gpus: int,
+ model_params: ModelParameters,
+ kwargs: Dict[str, Any],
+):
+ """Try load default quantization model(Support by huggingface default)"""
+ cloned_kwargs = {k: v for k, v in kwargs.items()}
+ try:
+ model, tokenizer = None, None
+ if device != "cuda":
+ return None, None
+ elif model_params.load_8bit and llm_adapter.support_8bit:
+ cloned_kwargs["load_in_8bit"] = True
+ model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs)
+ elif model_params.load_4bit and llm_adapter.support_4bit:
+ cloned_kwargs["load_in_4bit"] = True
+ model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs)
+ if model:
+ logger.info(
+ f"Load default quantization model {model_params.model_name} success"
+ )
+ return _handle_model_and_tokenizer(
+ model, tokenizer, device, num_gpus, model_params
+ )
+ return None, None
+ except Exception as e:
+ logger.warning(
+ f"Load default quantization model {model_params.model_name} failed, error: {str(e)}"
+ )
+ return None, None
+
+
+def _handle_model_and_tokenizer(
+ model, tokenizer, device: str, num_gpus: int, model_params: ModelParameters
+):
if (
(device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading)
or device == "mps"
@@ -209,7 +255,7 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter
def load_huggingface_quantization_model(
- llm_adapter: LLMModelAdaper,
+ llm_adapter: LLMModelAdapter,
model_params: ModelParameters,
kwargs: Dict,
max_memory: Dict[int, str],
@@ -344,7 +390,9 @@ def load_huggingface_quantization_model(
return model, tokenizer
-def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters):
+def llamacpp_loader(
+ llm_adapter: LLMModelAdapter, model_params: LlamaCppModelParameters
+):
try:
from dbgpt.model.llm.llama_cpp.llama_cpp import LlamaCppModel
except ImportError as exc:
@@ -358,7 +406,7 @@ def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelPara
return model, tokenizer
-def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters):
+def proxyllm_loader(llm_adapter: LLMModelAdapter, model_params: ProxyModelParameters):
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
logger.info("Load proxyllm")
diff --git a/dbgpt/model/model_adapter.py b/dbgpt/model/model_adapter.py
deleted file mode 100644
index 85243a12b..000000000
--- a/dbgpt/model/model_adapter.py
+++ /dev/null
@@ -1,660 +0,0 @@
-from __future__ import annotations
-
-from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional
-import dataclasses
-import logging
-import threading
-import os
-from functools import cache
-from dbgpt.model.base import ModelType
-from dbgpt.model.parameter import (
- ModelParameters,
- LlamaCppModelParameters,
- ProxyModelParameters,
-)
-from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
-from dbgpt.util.parameter_utils import (
- _extract_parameter_details,
- _build_parameter_class,
- _get_dataclass_print_str,
-)
-
-try:
- from fastchat.conversation import (
- Conversation,
- register_conv_template,
- SeparatorStyle,
- )
-except ImportError as exc:
- raise ValueError(
- "Could not import python package: fschat "
- "Please install fastchat by command `pip install fschat` "
- ) from exc
-
-if TYPE_CHECKING:
- from fastchat.model.model_adapter import BaseModelAdapter
- from dbgpt.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper
- from torch.nn import Module as TorchNNModule
-
-logger = logging.getLogger(__name__)
-
-thread_local = threading.local()
-_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
-
-
-_OLD_MODELS = [
- "llama-cpp",
- "proxyllm",
- "gptj-6b",
- "codellama-13b-sql-sft",
- "codellama-7b",
- "codellama-7b-sql-sft",
- "codellama-13b",
-]
-
-_NEW_HF_CHAT_MODELS = [
- "yi-34b",
- "yi-6b",
-]
-
-# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist.
-_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"]
-
-
-class LLMModelAdaper:
- """New Adapter for DB-GPT LLM models"""
-
- def use_fast_tokenizer(self) -> bool:
- """Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
- for a given model.
- """
- return False
-
- def model_type(self) -> str:
- return ModelType.HF
-
- def model_param_class(self, model_type: str = None) -> ModelParameters:
- """Get the startup parameters instance of the model"""
- model_type = model_type if model_type else self.model_type()
- if model_type == ModelType.LLAMA_CPP:
- return LlamaCppModelParameters
- elif model_type == ModelType.PROXY:
- return ProxyModelParameters
- return ModelParameters
-
- def load(self, model_path: str, from_pretrained_kwargs: dict):
- """Load model and tokenizer"""
- raise NotImplementedError
-
- def load_from_params(self, params):
- """Load the model and tokenizer according to the given parameters"""
- raise NotImplementedError
-
- def support_async(self) -> bool:
- """Whether the loaded model supports asynchronous calls"""
- return False
-
- def get_generate_stream_function(self, model, model_path: str):
- """Get the generate stream function of the model"""
- raise NotImplementedError
-
- def get_async_generate_stream_function(self, model, model_path: str):
- """Get the asynchronous generate stream function of the model"""
- raise NotImplementedError
-
- def get_default_conv_template(
- self, model_name: str, model_path: str
- ) -> "Conversation":
- """Get the default conv template"""
- raise NotImplementedError
-
- def get_str_prompt(
- self,
- params: Dict,
- messages: List[ModelMessage],
- tokenizer: Any,
- prompt_template: str = None,
- ) -> Optional[str]:
- return None
-
- def get_prompt_with_template(
- self,
- params: Dict,
- messages: List[ModelMessage],
- model_name: str,
- model_path: str,
- model_context: Dict,
- prompt_template: str = None,
- ):
- conv = self.get_default_conv_template(model_name, model_path)
-
- if prompt_template:
- logger.info(f"Use prompt template {prompt_template} from config")
- conv = get_conv_template(prompt_template)
- if not conv or not messages:
- # Nothing to do
- logger.info(
- f"No conv from model_path {model_path} or no messages in params, {self}"
- )
- return None, None, None
-
- conv = conv.copy()
- system_messages = []
- user_messages = []
- ai_messages = []
-
- for message in messages:
- role, content = None, None
- if isinstance(message, ModelMessage):
- role = message.role
- content = message.content
- elif isinstance(message, dict):
- role = message["role"]
- content = message["content"]
- else:
- raise ValueError(f"Invalid message type: {message}")
-
- if role == ModelMessageRoleType.SYSTEM:
- # Support for multiple system messages
- system_messages.append(content)
- elif role == ModelMessageRoleType.HUMAN:
- # conv.append_message(conv.roles[0], content)
- user_messages.append(content)
- elif role == ModelMessageRoleType.AI:
- # conv.append_message(conv.roles[1], content)
- ai_messages.append(content)
- else:
- raise ValueError(f"Unknown role: {role}")
-
- can_use_systems: [] = []
- if system_messages:
- if len(system_messages) > 1:
- ## Compatible with dbgpt complex scenarios, the last system will protect more complete information entered by the current user
- user_messages[-1] = system_messages[-1]
- can_use_systems = system_messages[:-1]
- else:
- can_use_systems = system_messages
-
- for i in range(len(user_messages)):
- conv.append_message(conv.roles[0], user_messages[i])
- if i < len(ai_messages):
- conv.append_message(conv.roles[1], ai_messages[i])
-
- if isinstance(conv, Conversation):
- conv.set_system_message("".join(can_use_systems))
- else:
- conv.update_system_message("".join(can_use_systems))
-
- # Add a blank message for the assistant.
- conv.append_message(conv.roles[1], None)
- new_prompt = conv.get_prompt()
- return new_prompt, conv.stop_str, conv.stop_token_ids
-
- def model_adaptation(
- self,
- params: Dict,
- model_name: str,
- model_path: str,
- tokenizer: Any,
- prompt_template: str = None,
- ) -> Tuple[Dict, Dict]:
- """Params adaptation"""
- messages = params.get("messages")
- # Some model scontext to dbgpt server
- model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
- if messages:
- # Dict message to ModelMessage
- messages = [
- m if isinstance(m, ModelMessage) else ModelMessage(**m)
- for m in messages
- ]
- params["messages"] = messages
-
- new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
- conv_stop_str, conv_stop_token_ids = None, None
- if not new_prompt:
- (
- new_prompt,
- conv_stop_str,
- conv_stop_token_ids,
- ) = self.get_prompt_with_template(
- params, messages, model_name, model_path, model_context, prompt_template
- )
- if not new_prompt:
- return params, model_context
-
- # Overwrite the original prompt
- # TODO remote bos token and eos token from tokenizer_config.json of model
- prompt_echo_len_char = len(new_prompt.replace("", "").replace("", ""))
- model_context["prompt_echo_len_char"] = prompt_echo_len_char
- model_context["echo"] = params.get("echo", True)
- model_context["has_format_prompt"] = True
- params["prompt"] = new_prompt
-
- custom_stop = params.get("stop")
- custom_stop_token_ids = params.get("stop_token_ids")
-
- # Prefer the value passed in from the input parameter
- params["stop"] = custom_stop or conv_stop_str
- params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
-
- return params, model_context
-
-
-class OldLLMModelAdaperWrapper(LLMModelAdaper):
- """Wrapping old adapter, which may be removed later"""
-
- def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None:
- self._adapter = adapter
- self._chat_adapter = chat_adapter
-
- def use_fast_tokenizer(self) -> bool:
- return self._adapter.use_fast_tokenizer()
-
- def model_type(self) -> str:
- return self._adapter.model_type()
-
- def model_param_class(self, model_type: str = None) -> ModelParameters:
- return self._adapter.model_param_class(model_type)
-
- def get_default_conv_template(
- self, model_name: str, model_path: str
- ) -> "Conversation":
- return self._chat_adapter.get_conv_template(model_path)
-
- def load(self, model_path: str, from_pretrained_kwargs: dict):
- return self._adapter.loader(model_path, from_pretrained_kwargs)
-
- def get_generate_stream_function(self, model, model_path: str):
- return self._chat_adapter.get_generate_stream_func(model_path)
-
- def __str__(self) -> str:
- return "{}({}.{})".format(
- self.__class__.__name__,
- self._adapter.__class__.__module__,
- self._adapter.__class__.__name__,
- )
-
-
-class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
- """Wrapping fastchat adapter"""
-
- def __init__(self, adapter: "BaseModelAdapter") -> None:
- self._adapter = adapter
-
- def use_fast_tokenizer(self) -> bool:
- return self._adapter.use_fast_tokenizer
-
- def load(self, model_path: str, from_pretrained_kwargs: dict):
- return self._adapter.load_model(model_path, from_pretrained_kwargs)
-
- def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
- if _IS_BENCHMARK:
- from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import (
- generate_stream,
- )
-
- return generate_stream
- else:
- from fastchat.model.model_adapter import get_generate_stream_function
-
- return get_generate_stream_function(model, model_path)
-
- def get_default_conv_template(
- self, model_name: str, model_path: str
- ) -> "Conversation":
- return self._adapter.get_default_conv_template(model_path)
-
- def __str__(self) -> str:
- return "{}({}.{})".format(
- self.__class__.__name__,
- self._adapter.__class__.__module__,
- self._adapter.__class__.__name__,
- )
-
-
-class NewHFChatModelAdapter(LLMModelAdaper):
- def load(self, model_path: str, from_pretrained_kwargs: dict):
- try:
- import transformers
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
- except ImportError as exc:
- raise ValueError(
- "Could not import depend python package "
- "Please install it with `pip install transformers`."
- ) from exc
- if not transformers.__version__ >= "4.34.0":
- raise ValueError(
- "Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0"
- )
- revision = from_pretrained_kwargs.get("revision", "main")
- try:
- tokenizer = AutoTokenizer.from_pretrained(
- model_path,
- use_fast=self.use_fast_tokenizer,
- revision=revision,
- trust_remote_code=True,
- )
- except TypeError:
- tokenizer = AutoTokenizer.from_pretrained(
- model_path, use_fast=False, revision=revision, trust_remote_code=True
- )
- try:
- model = AutoModelForCausalLM.from_pretrained(
- model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
- )
- except NameError:
- model = AutoModel.from_pretrained(
- model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
- )
- # tokenizer.use_default_system_prompt = False
- return model, tokenizer
-
- def get_generate_stream_function(self, model, model_path: str):
- """Get the generate stream function of the model"""
- from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
-
- return huggingface_chat_generate_stream
-
- def get_str_prompt(
- self,
- params: Dict,
- messages: List[ModelMessage],
- tokenizer: Any,
- prompt_template: str = None,
- ) -> Optional[str]:
- from transformers import AutoTokenizer
-
- if not tokenizer:
- raise ValueError("tokenizer is is None")
- tokenizer: AutoTokenizer = tokenizer
-
- messages = ModelMessage.to_openai_messages(messages)
- str_prompt = tokenizer.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
- )
- return str_prompt
-
-
-def get_conv_template(name: str) -> "Conversation":
- """Get a conversation template."""
- from fastchat.conversation import get_conv_template
-
- return get_conv_template(name)
-
-
-@cache
-def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation":
- try:
- adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
- return adapter.get_default_conv_template(model_name, model_path)
- except Exception:
- return None
-
-
-@cache
-def get_llm_model_adapter(
- model_name: str,
- model_path: str,
- use_fastchat: bool = True,
- use_fastchat_monkey_patch: bool = False,
- model_type: str = None,
-) -> LLMModelAdaper:
- if model_type == ModelType.VLLM:
- logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
- return VLLMModelAdaperWrapper()
-
- use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS)
- if use_new_hf_chat_models:
- logger.info(f"Current model {model_name} use NewHFChatModelAdapter")
- return NewHFChatModelAdapter()
-
- must_use_old = any(m in model_name for m in _OLD_MODELS)
- if use_fastchat and not must_use_old:
- logger.info("Use fastcat adapter")
- adapter = _get_fastchat_model_adapter(
- model_name,
- model_path,
- _fastchat_get_adapter_monkey_patch,
- use_fastchat_monkey_patch=use_fastchat_monkey_patch,
- )
- return FastChatLLMModelAdaperWrapper(adapter)
- else:
- from dbgpt.model.adapter import (
- get_llm_model_adapter as _old_get_llm_model_adapter,
- )
- from dbgpt.app.chat_adapter import get_llm_chat_adapter
-
- logger.info("Use DB-GPT old adapter")
- return OldLLMModelAdaperWrapper(
- _old_get_llm_model_adapter(model_name, model_path),
- get_llm_chat_adapter(model_name, model_path),
- )
-
-
-def _get_fastchat_model_adapter(
- model_name: str,
- model_path: str,
- caller: Callable[[str], None] = None,
- use_fastchat_monkey_patch: bool = False,
-):
- from fastchat.model import model_adapter
-
- _bak_get_model_adapter = model_adapter.get_model_adapter
- try:
- if use_fastchat_monkey_patch:
- model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
- thread_local.model_name = model_name
- _remove_black_list_model_of_fastchat()
- if caller:
- return caller(model_path)
- finally:
- del thread_local.model_name
- model_adapter.get_model_adapter = _bak_get_model_adapter
-
-
-def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
- if not model_name:
- if not hasattr(thread_local, "model_name"):
- raise RuntimeError("fastchat get adapter monkey path need model_name")
- model_name = thread_local.model_name
- from fastchat.model.model_adapter import model_adapters
-
- for adapter in model_adapters:
- if adapter.match(model_name):
- logger.info(
- f"Found llm model adapter with model name: {model_name}, {adapter}"
- )
- return adapter
-
- model_path_basename = (
- None if not model_path else os.path.basename(os.path.normpath(model_path))
- )
- for adapter in model_adapters:
- if model_path_basename and adapter.match(model_path_basename):
- logger.info(
- f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
- )
- return adapter
-
- for adapter in model_adapters:
- if model_path and adapter.match(model_path):
- logger.info(
- f"Found llm model adapter with model path: {model_path}, {adapter}"
- )
- return adapter
-
- raise ValueError(
- f"Invalid model adapter for model name {model_name} and model path {model_path}"
- )
-
-
-@cache
-def _remove_black_list_model_of_fastchat():
- from fastchat.model.model_adapter import model_adapters
-
- black_list_models = []
- for adapter in model_adapters:
- try:
- if (
- adapter.get_default_conv_template("/data/not_exist_model_path").name
- in _BLACK_LIST_MODLE_PROMPT
- ):
- black_list_models.append(adapter)
- except Exception:
- pass
- for adapter in black_list_models:
- model_adapters.remove(adapter)
-
-
-def _dynamic_model_parser() -> Callable[[None], List[Type]]:
- from dbgpt.util.parameter_utils import _SimpleArgParser
- from dbgpt.model.parameter import (
- EmbeddingModelParameters,
- WorkerType,
- EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
- )
-
- pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
- pre_args.parse()
- model_name = pre_args.get("model_name")
- model_path = pre_args.get("model_path")
- worker_type = pre_args.get("worker_type")
- model_type = pre_args.get("model_type")
- if model_name is None and model_type != ModelType.VLLM:
- return None
- if worker_type == WorkerType.TEXT2VEC:
- return [
- EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
- model_name, EmbeddingModelParameters
- )
- ]
-
- llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
- param_class = llm_adapter.model_param_class()
- return [param_class]
-
-
-class VLLMModelAdaperWrapper(LLMModelAdaper):
- """Wrapping vllm engine"""
-
- def model_type(self) -> str:
- return ModelType.VLLM
-
- def model_param_class(self, model_type: str = None) -> ModelParameters:
- import argparse
- from vllm.engine.arg_utils import AsyncEngineArgs
-
- parser = argparse.ArgumentParser()
- parser = AsyncEngineArgs.add_cli_args(parser)
- parser.add_argument("--model_name", type=str, help="model name")
- parser.add_argument(
- "--model_path",
- type=str,
- help="local model path of the huggingface model to use",
- )
- parser.add_argument("--model_type", type=str, help="model type")
- parser.add_argument("--device", type=str, default=None, help="device")
- # TODO parse prompt templete from `model_name` and `model_path`
- parser.add_argument(
- "--prompt_template",
- type=str,
- default=None,
- help="Prompt template. If None, the prompt template is automatically determined from model path",
- )
-
- descs = _extract_parameter_details(
- parser,
- "dbgpt.model.parameter.VLLMModelParameters",
- skip_names=["model"],
- overwrite_default_values={"trust_remote_code": True},
- )
- return _build_parameter_class(descs)
-
- def load_from_params(self, params):
- from vllm import AsyncLLMEngine
- from vllm.engine.arg_utils import AsyncEngineArgs
- import torch
-
- num_gpus = torch.cuda.device_count()
- if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
- setattr(params, "tensor_parallel_size", num_gpus)
- logger.info(
- f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
- )
-
- params = dataclasses.asdict(params)
- params["model"] = params["model_path"]
- attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
- vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
- # Set the attributes from the parsed arguments.
- engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
- engine = AsyncLLMEngine.from_engine_args(engine_args)
- return engine, engine.engine.tokenizer
-
- def support_async(self) -> bool:
- return True
-
- def get_async_generate_stream_function(self, model, model_path: str):
- from dbgpt.model.llm_out.vllm_llm import generate_stream
-
- return generate_stream
-
- def get_default_conv_template(
- self, model_name: str, model_path: str
- ) -> "Conversation":
- return _auto_get_conv_template(model_name, model_path)
-
- def __str__(self) -> str:
- return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
-
-
-# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat.
-# We also recommend that you modify it directly in the fastchat repository.
-
-# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
-register_conv_template(
- Conversation(
- name="aquila-legacy",
- system_message="A chat between a curious human and an artificial intelligence assistant. "
- "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
- roles=("### Human: ", "### Assistant: ", "System"),
- messages=(),
- offset=0,
- sep_style=SeparatorStyle.NO_COLON_TWO,
- sep="\n",
- sep2="",
- stop_str=["", "[UNK]"],
- ),
- override=True,
-)
-# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
-register_conv_template(
- Conversation(
- name="aquila",
- system_message="A chat between a curious human and an artificial intelligence assistant. "
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
- roles=("Human", "Assistant", "System"),
- messages=(),
- offset=0,
- sep_style=SeparatorStyle.ADD_COLON_TWO,
- sep="###",
- sep2="",
- stop_str=["", "[UNK]"],
- ),
- override=True,
-)
-# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
-register_conv_template(
- Conversation(
- name="aquila-v1",
- roles=("<|startofpiece|>", "<|endofpiece|>", ""),
- messages=(),
- offset=0,
- sep_style=SeparatorStyle.NO_COLON_TWO,
- sep="",
- sep2="",
- stop_str=["", "<|endoftext|>"],
- ),
- override=True,
-)
diff --git a/dbgpt/model/proxy/llms/bard.py b/dbgpt/model/proxy/llms/bard.py
old mode 100644
new mode 100755
index 5b6ed26a0..fc398fe8b
--- a/dbgpt/model/proxy/llms/bard.py
+++ b/dbgpt/model/proxy/llms/bard.py
@@ -25,14 +25,13 @@ def bard_generate_stream(
else:
pass
- temp_his = history[::-1]
- last_user_input = None
- for m in temp_his:
- if m["role"] == "user":
- last_user_input = m
+ last_user_input_index = None
+ for i in range(len(history) - 1, -1, -1):
+ if history[i]["role"] == "user":
+ last_user_input_index = i
break
- if last_user_input:
- history.remove(last_user_input)
+ if last_user_input_index:
+ last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)
msgs = []
diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py
old mode 100644
new mode 100755
index e9229da44..d81626e7a
--- a/dbgpt/model/proxy/llms/chatgpt.py
+++ b/dbgpt/model/proxy/llms/chatgpt.py
@@ -110,14 +110,13 @@ def _build_request(model: ProxyModel, params):
pass
# Move the last user's information to the end
- temp_his = history[::-1]
- last_user_input = None
- for m in temp_his:
- if m["role"] == "user":
- last_user_input = m
+ last_user_input_index = None
+ for i in range(len(history) - 1, -1, -1):
+ if history[i]["role"] == "user":
+ last_user_input_index = i
break
- if last_user_input:
- history.remove(last_user_input)
+ if last_user_input_index:
+ last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)
payloads = {