From 6b982e28795ed6767214c5733379ba8bf0943458 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 21 Dec 2023 16:46:29 +0800 Subject: [PATCH 1/2] feat(model): Support Mixtral-8x7B (#959) --- README.md | 3 +- README.zh.md | 3 +- dbgpt/app/chat_adapter.py | 2 +- dbgpt/configs/model_config.py | 8 +- dbgpt/model/adapter/__init__.py | 0 dbgpt/model/adapter/base.py | 437 ++++++++++++ dbgpt/model/adapter/fschat_adapter.py | 262 +++++++ dbgpt/model/adapter/hf_adapter.py | 136 ++++ dbgpt/model/adapter/model_adapter.py | 166 +++++ .../{adapter.py => adapter/old_adapter.py} | 101 ++- dbgpt/model/adapter/template.py | 130 ++++ dbgpt/model/adapter/vllm_adapter.py | 93 +++ dbgpt/model/cli.py | 2 +- dbgpt/model/cluster/worker/default_worker.py | 5 +- dbgpt/model/llm_utils.py | 2 +- dbgpt/model/loader.py | 64 +- dbgpt/model/model_adapter.py | 660 ------------------ 17 files changed, 1386 insertions(+), 688 deletions(-) create mode 100644 dbgpt/model/adapter/__init__.py create mode 100644 dbgpt/model/adapter/base.py create mode 100644 dbgpt/model/adapter/fschat_adapter.py create mode 100644 dbgpt/model/adapter/hf_adapter.py create mode 100644 dbgpt/model/adapter/model_adapter.py rename dbgpt/model/{adapter.py => adapter/old_adapter.py} (82%) create mode 100644 dbgpt/model/adapter/template.py create mode 100644 dbgpt/model/adapter/vllm_adapter.py delete mode 100644 dbgpt/model/model_adapter.py diff --git a/README.md b/README.md index 517226dc5..92778486f 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,8 @@ At present, we have introduced several key features to showcase our current capa We offer extensive model support, including dozens of large language models (LLMs) from both open-source and API agents, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more. - News - - 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat) + - 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) + - 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) - 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat) - [More Supported LLMs](http://docs.dbgpt.site/docs/modules/smmf) diff --git a/README.zh.md b/README.zh.md index 0ba4c40af..61cbc7def 100644 --- a/README.zh.md +++ b/README.zh.md @@ -111,7 +111,8 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模 海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱等。当前已支持如下模型: - 新增支持模型 - - 🔥🔥🔥 [qwen-72b-chat](https://huggingface.co/Qwen/Qwen-72B-Chat) + - 🔥🔥🔥 [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) + - 🔥🔥🔥 [Qwen-72B-Chat](https://huggingface.co/Qwen/Qwen-72B-Chat) - 🔥🔥🔥 [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat) - [更多开源模型](https://www.yuque.com/eosphoros/dbgpt-docs/iqaaqwriwhp6zslc#qQktR) diff --git a/dbgpt/app/chat_adapter.py b/dbgpt/app/chat_adapter.py index 474695beb..c1cb192b1 100644 --- a/dbgpt/app/chat_adapter.py +++ b/dbgpt/app/chat_adapter.py @@ -245,7 +245,7 @@ def get_conv_template(self, model_path: str) -> Conversation: class LlamaCppChatAdapter(BaseChatAdpter): def match(self, model_path: str): - from dbgpt.model.adapter import LlamaCppAdapater + from dbgpt.model.adapter.old_adapter import LlamaCppAdapater if "llama-cpp" == model_path: return True diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index ccbe4e3d4..c9f123677 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -113,7 +113,9 @@ def get_device() -> str: # https://huggingface.co/microsoft/Orca-2-13b "orca-2-13b": os.path.join(MODEL_PATH, "Orca-2-13b"), # https://huggingface.co/openchat/openchat_3.5 - "openchat_3.5": os.path.join(MODEL_PATH, "openchat_3.5"), + "openchat-3.5": os.path.join(MODEL_PATH, "openchat_3.5"), + # https://huggingface.co/openchat/openchat-3.5-1210 + "openchat-3.5-1210": os.path.join(MODEL_PATH, "openchat-3.5-1210"), # https://huggingface.co/hfl/chinese-alpaca-2-7b "chinese-alpaca-2-7b": os.path.join(MODEL_PATH, "chinese-alpaca-2-7b"), # https://huggingface.co/hfl/chinese-alpaca-2-13b @@ -124,6 +126,10 @@ def get_device() -> str: "zephyr-7b-alpha": os.path.join(MODEL_PATH, "zephyr-7b-alpha"), # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 "mistral-7b-instruct-v0.1": os.path.join(MODEL_PATH, "Mistral-7B-Instruct-v0.1"), + # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1 + "mixtral-8x7b-instruct-v0.1": os.path.join( + MODEL_PATH, "Mixtral-8x7B-Instruct-v0.1" + ), # https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca "mistral-7b-openorca": os.path.join(MODEL_PATH, "Mistral-7B-OpenOrca"), # https://huggingface.co/Xwin-LM/Xwin-LM-7B-V0.1 diff --git a/dbgpt/model/adapter/__init__.py b/dbgpt/model/adapter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/model/adapter/base.py b/dbgpt/model/adapter/base.py new file mode 100644 index 000000000..df1d9441a --- /dev/null +++ b/dbgpt/model/adapter/base.py @@ -0,0 +1,437 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Any, Tuple, Type, Callable +import logging +from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType +from dbgpt.model.base import ModelType +from dbgpt.model.parameter import ( + BaseModelParameters, + ModelParameters, + LlamaCppModelParameters, + ProxyModelParameters, +) +from dbgpt.model.adapter.template import ( + get_conv_template, + ConversationAdapter, + ConversationAdapterFactory, +) + +logger = logging.getLogger(__name__) + + +class LLMModelAdapter(ABC): + """New Adapter for DB-GPT LLM models""" + + model_name: Optional[str] = None + model_path: Optional[str] = None + conv_factory: Optional[ConversationAdapterFactory] = None + # TODO: more flexible quantization config + support_4bit: bool = False + support_8bit: bool = False + support_system_message: bool = True + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} model_name={self.model_name} model_path={self.model_path}>" + + def __str__(self): + return self.__repr__() + + @abstractmethod + def new_adapter(self, **kwargs) -> "LLMModelAdapter": + """Create a new adapter instance + + Args: + **kwargs: The parameters of the new adapter instance + + Returns: + LLMModelAdapter: The new adapter instance + """ + + def use_fast_tokenizer(self) -> bool: + """Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported + for a given model. + """ + return False + + def model_type(self) -> str: + return ModelType.HF + + def model_param_class(self, model_type: str = None) -> Type[BaseModelParameters]: + """Get the startup parameters instance of the model + + Args: + model_type (str, optional): The type of model. Defaults to None. + + Returns: + Type[BaseModelParameters]: The startup parameters instance of the model + """ + # """Get the startup parameters instance of the model""" + model_type = model_type if model_type else self.model_type() + if model_type == ModelType.LLAMA_CPP: + return LlamaCppModelParameters + elif model_type == ModelType.PROXY: + return ProxyModelParameters + return ModelParameters + + def match( + self, + model_type: str, + model_name: Optional[str] = None, + model_path: Optional[str] = None, + ) -> bool: + """Whether the model adapter can load the given model + + Args: + model_type (str): The type of model + model_name (Optional[str], optional): The name of model. Defaults to None. + model_path (Optional[str], optional): The path of model. Defaults to None. + """ + return False + + def support_quantization_4bit(self) -> bool: + """Whether the model adapter can load 4bit model + + If it is True, we will load the 4bit model with :meth:`~LLMModelAdapter.load` + + Returns: + bool: Whether the model adapter can load 4bit model, default is False + """ + return self.support_4bit + + def support_quantization_8bit(self) -> bool: + """Whether the model adapter can load 8bit model + + If it is True, we will load the 8bit model with :meth:`~LLMModelAdapter.load` + + Returns: + bool: Whether the model adapter can load 8bit model, default is False + """ + return self.support_8bit + + def load(self, model_path: str, from_pretrained_kwargs: dict): + """Load model and tokenizer""" + raise NotImplementedError + + def load_from_params(self, params): + """Load the model and tokenizer according to the given parameters""" + raise NotImplementedError + + def support_async(self) -> bool: + """Whether the loaded model supports asynchronous calls""" + return False + + def get_generate_stream_function(self, model, model_path: str): + """Get the generate stream function of the model""" + raise NotImplementedError + + def get_async_generate_stream_function(self, model, model_path: str): + """Get the asynchronous generate stream function of the model""" + raise NotImplementedError + + def get_default_conv_template( + self, model_name: str, model_path: str + ) -> Optional[ConversationAdapter]: + """Get the default conversation template + + Args: + model_name (str): The name of the model. + model_path (str): The path of the model. + + Returns: + Optional[ConversationAdapter]: The conversation template. + """ + raise NotImplementedError + + def get_default_message_separator(self) -> str: + """Get the default message separator""" + try: + conv_template = self.get_default_conv_template( + self.model_name, self.model_path + ) + return conv_template.sep + except Exception: + return "\n" + + def transform_model_messages( + self, messages: List[ModelMessage] + ) -> List[Dict[str, str]]: + """Transform the model messages + + Default is the OpenAI format, example: + .. code-block:: python + return_messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"} + ] + + But some model may need to transform the messages to other format(e.g. There is no system message), such as: + .. code-block:: python + return_messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"} + ] + Args: + messages (List[ModelMessage]): The model messages + + Returns: + List[Dict[str, str]]: The transformed model messages + """ + logger.info(f"support_system_message: {self.support_system_message}") + if not self.support_system_message: + return self._transform_to_no_system_messages(messages) + else: + return ModelMessage.to_openai_messages(messages) + + def _transform_to_no_system_messages( + self, messages: List[ModelMessage] + ) -> List[Dict[str, str]]: + """Transform the model messages to no system messages + + Some opensource chat model no system messages, so wo should transform the messages to no system messages. + + Merge the system messages to the last user message, example: + .. code-block:: python + return_messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"} + ] + => + return_messages = [ + {"role": "user", "content": "You are a helpful assistant\nHello"}, + {"role": "assistant", "content": "Hi"} + ] + + Args: + messages (List[ModelMessage]): The model messages + + Returns: + List[Dict[str, str]]: The transformed model messages + """ + openai_messages = ModelMessage.to_openai_messages(messages) + system_messages = [] + return_messages = [] + for message in openai_messages: + if message["role"] == "system": + system_messages.append(message["content"]) + else: + return_messages.append(message) + if len(system_messages) > 1: + # Too much system messages should be a warning + logger.warning("Your system messages have more than one message") + if system_messages: + sep = self.get_default_message_separator() + str_system_messages = ",".join(system_messages) + # Update last user message + return_messages[-1]["content"] = ( + str_system_messages + sep + return_messages[-1]["content"] + ) + return return_messages + + def get_str_prompt( + self, + params: Dict, + messages: List[ModelMessage], + tokenizer: Any, + prompt_template: str = None, + ) -> Optional[str]: + """Get the string prompt from the given parameters and messages + + If the value of return is not None, we will skip :meth:`~LLMModelAdapter.get_prompt_with_template` and use the value of return. + + Args: + params (Dict): The parameters + messages (List[ModelMessage]): The model messages + tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer + prompt_template (str, optional): The prompt template. Defaults to None. + + Returns: + Optional[str]: The string prompt + """ + return None + + def get_prompt_with_template( + self, + params: Dict, + messages: List[ModelMessage], + model_name: str, + model_path: str, + model_context: Dict, + prompt_template: str = None, + ): + conv: ConversationAdapter = self.get_default_conv_template( + model_name, model_path + ) + + if prompt_template: + logger.info(f"Use prompt template {prompt_template} from config") + conv = get_conv_template(prompt_template) + if not conv or not messages: + # Nothing to do + logger.info( + f"No conv from model_path {model_path} or no messages in params, {self}" + ) + return None, None, None + + conv = conv.copy() + system_messages = [] + user_messages = [] + ai_messages = [] + + for message in messages: + if isinstance(message, ModelMessage): + role = message.role + content = message.content + elif isinstance(message, dict): + role = message["role"] + content = message["content"] + else: + raise ValueError(f"Invalid message type: {message}") + + if role == ModelMessageRoleType.SYSTEM: + # Support for multiple system messages + system_messages.append(content) + elif role == ModelMessageRoleType.HUMAN: + # conv.append_message(conv.roles[0], content) + user_messages.append(content) + elif role == ModelMessageRoleType.AI: + # conv.append_message(conv.roles[1], content) + ai_messages.append(content) + else: + raise ValueError(f"Unknown role: {role}") + + can_use_systems: [] = [] + if system_messages: + if len(system_messages) > 1: + # Compatible with dbgpt complex scenarios, the last system will protect more complete information + # entered by the current user + user_messages[-1] = system_messages[-1] + can_use_systems = system_messages[:-1] + else: + can_use_systems = system_messages + + for i in range(len(user_messages)): + conv.append_message(conv.roles[0], user_messages[i]) + if i < len(ai_messages): + conv.append_message(conv.roles[1], ai_messages[i]) + + # TODO join all system messages may not be a good idea + conv.set_system_message("".join(can_use_systems)) + # Add a blank message for the assistant. + conv.append_message(conv.roles[1], None) + new_prompt = conv.get_prompt() + return new_prompt, conv.stop_str, conv.stop_token_ids + + def model_adaptation( + self, + params: Dict, + model_name: str, + model_path: str, + tokenizer: Any, + prompt_template: str = None, + ) -> Tuple[Dict, Dict]: + """Params adaptation""" + messages = params.get("messages") + # Some model context to dbgpt server + model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False} + if messages: + # Dict message to ModelMessage + messages = [ + m if isinstance(m, ModelMessage) else ModelMessage(**m) + for m in messages + ] + params["messages"] = messages + + new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template) + conv_stop_str, conv_stop_token_ids = None, None + if not new_prompt: + ( + new_prompt, + conv_stop_str, + conv_stop_token_ids, + ) = self.get_prompt_with_template( + params, messages, model_name, model_path, model_context, prompt_template + ) + if not new_prompt: + return params, model_context + + # Overwrite the original prompt + # TODO remote bos token and eos token from tokenizer_config.json of model + prompt_echo_len_char = len(new_prompt.replace("", "").replace("", "")) + model_context["prompt_echo_len_char"] = prompt_echo_len_char + model_context["echo"] = params.get("echo", True) + model_context["has_format_prompt"] = True + params["prompt"] = new_prompt + + custom_stop = params.get("stop") + custom_stop_token_ids = params.get("stop_token_ids") + + # Prefer the value passed in from the input parameter + params["stop"] = custom_stop or conv_stop_str + params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids + + return params, model_context + + +class AdapterEntry: + """The entry of model adapter""" + + def __init__( + self, + model_adapter: LLMModelAdapter, + match_funcs: List[Callable[[str, str, str], bool]] = None, + ): + self.model_adapter = model_adapter + self.match_funcs = match_funcs or [] + + +model_adapters: List[AdapterEntry] = [] + + +def register_model_adapter( + model_adapter_cls: Type[LLMModelAdapter], + match_funcs: List[Callable[[str, str, str], bool]] = None, +) -> None: + """Register a model adapter. + + Args: + model_adapter_cls (Type[LLMModelAdapter]): The model adapter class. + match_funcs (List[Callable[[str, str, str], bool]], optional): The match functions. Defaults to None. + """ + model_adapters.append(AdapterEntry(model_adapter_cls(), match_funcs)) + + +def get_model_adapter( + model_type: str, + model_name: str, + model_path: str, + conv_factory: Optional[ConversationAdapterFactory] = None, +) -> Optional[LLMModelAdapter]: + """Get a model adapter. + + Args: + model_type (str): The type of the model. + model_name (str): The name of the model. + model_path (str): The path of the model. + conv_factory (Optional[ConversationAdapterFactory], optional): The conversation factory. Defaults to None. + Returns: + Optional[LLMModelAdapter]: The model adapter. + """ + adapter = None + # First find adapter by model_name + for adapter_entry in model_adapters: + if adapter_entry.model_adapter.match(model_type, model_name, None): + adapter = adapter_entry.model_adapter + break + for adapter_entry in model_adapters: + if adapter_entry.model_adapter.match(model_type, None, model_path): + adapter = adapter_entry.model_adapter + break + if adapter: + new_adapter = adapter.new_adapter() + new_adapter.model_name = model_name + new_adapter.model_path = model_path + if conv_factory: + new_adapter.conv_factory = conv_factory + return new_adapter + return None diff --git a/dbgpt/model/adapter/fschat_adapter.py b/dbgpt/model/adapter/fschat_adapter.py new file mode 100644 index 000000000..ca5413645 --- /dev/null +++ b/dbgpt/model/adapter/fschat_adapter.py @@ -0,0 +1,262 @@ +"""Adapter for fastchat + +You can import fastchat only in this file, so that the user does not need to install fastchat if he does not use it. +""" +import os +import threading +import logging +from functools import cache +from typing import TYPE_CHECKING, Callable, Tuple, List, Optional + +try: + from fastchat.conversation import ( + Conversation, + register_conv_template, + SeparatorStyle, + ) +except ImportError as exc: + raise ValueError( + "Could not import python package: fschat " + "Please install fastchat by command `pip install fschat` " + ) from exc + +from dbgpt.model.adapter.template import ConversationAdapter, PromptType +from dbgpt.model.adapter.base import LLMModelAdapter + +if TYPE_CHECKING: + from fastchat.model.model_adapter import BaseModelAdapter + from torch.nn import Module as TorchNNModule + +logger = logging.getLogger(__name__) + +thread_local = threading.local() +_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true" + +# If some model is not in the blacklist, but it still affects the loading of DB-GPT, you can add it to the blacklist. +__BLACK_LIST_MODEL_PROMPT = [] + + +class FschatConversationAdapter(ConversationAdapter): + """The conversation adapter for fschat.""" + + def __init__(self, conv: Conversation): + self._conv = conv + + @property + def prompt_type(self) -> PromptType: + return PromptType.FSCHAT + + @property + def roles(self) -> Tuple[str]: + return self._conv.roles + + @property + def sep(self) -> Optional[str]: + return self._conv.sep + + @property + def stop_str(self) -> str: + return self._conv.stop_str + + @property + def stop_token_ids(self) -> Optional[List[int]]: + return self._conv.stop_token_ids + + def get_prompt(self) -> str: + """Get the prompt string.""" + return self._conv.get_prompt() + + def set_system_message(self, system_message: str) -> None: + """Set the system message.""" + self._conv.set_system_message(system_message) + + def append_message(self, role: str, message: str) -> None: + """Append a new message. + + Args: + role (str): The role of the message. + message (str): The message content. + """ + self._conv.append_message(role, message) + + def update_last_message(self, message: str) -> None: + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. + + Args: + message (str): The message content. + """ + self._conv.update_last_message(message) + + def copy(self) -> "ConversationAdapter": + """Copy the conversation.""" + return FschatConversationAdapter(self._conv.copy()) + + +class FastChatLLMModelAdapterWrapper(LLMModelAdapter): + """Wrapping fastchat adapter""" + + def __init__(self, adapter: "BaseModelAdapter") -> None: + self._adapter = adapter + + def new_adapter(self, **kwargs) -> "LLMModelAdapter": + return FastChatLLMModelAdapterWrapper(self._adapter) + + def use_fast_tokenizer(self) -> bool: + return self._adapter.use_fast_tokenizer + + def load(self, model_path: str, from_pretrained_kwargs: dict): + return self._adapter.load_model(model_path, from_pretrained_kwargs) + + def get_generate_stream_function(self, model: "TorchNNModule", model_path: str): + if _IS_BENCHMARK: + from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import ( + generate_stream, + ) + + return generate_stream + else: + from fastchat.model.model_adapter import get_generate_stream_function + + return get_generate_stream_function(model, model_path) + + def get_default_conv_template( + self, model_name: str, model_path: str + ) -> Optional[ConversationAdapter]: + conv_template = self._adapter.get_default_conv_template(model_path) + return FschatConversationAdapter(conv_template) if conv_template else None + + def __str__(self) -> str: + return "{}({}.{})".format( + self.__class__.__name__, + self._adapter.__class__.__module__, + self._adapter.__class__.__name__, + ) + + +def _get_fastchat_model_adapter( + model_name: str, + model_path: str, + caller: Callable[[str], None] = None, + use_fastchat_monkey_patch: bool = False, +): + from fastchat.model import model_adapter + + _bak_get_model_adapter = model_adapter.get_model_adapter + try: + if use_fastchat_monkey_patch: + model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch + thread_local.model_name = model_name + _remove_black_list_model_of_fastchat() + if caller: + return caller(model_path) + finally: + del thread_local.model_name + model_adapter.get_model_adapter = _bak_get_model_adapter + + +def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None): + if not model_name: + if not hasattr(thread_local, "model_name"): + raise RuntimeError("fastchat get adapter monkey path need model_name") + model_name = thread_local.model_name + from fastchat.model.model_adapter import model_adapters + + for adapter in model_adapters: + if adapter.match(model_name): + logger.info( + f"Found llm model adapter with model name: {model_name}, {adapter}" + ) + return adapter + + model_path_basename = ( + None if not model_path else os.path.basename(os.path.normpath(model_path)) + ) + for adapter in model_adapters: + if model_path_basename and adapter.match(model_path_basename): + logger.info( + f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}" + ) + return adapter + + for adapter in model_adapters: + if model_path and adapter.match(model_path): + logger.info( + f"Found llm model adapter with model path: {model_path}, {adapter}" + ) + return adapter + + raise ValueError( + f"Invalid model adapter for model name {model_name} and model path {model_path}" + ) + + +@cache +def _remove_black_list_model_of_fastchat(): + from fastchat.model.model_adapter import model_adapters + + black_list_models = [] + for adapter in model_adapters: + try: + if ( + adapter.get_default_conv_template("/data/not_exist_model_path").name + in __BLACK_LIST_MODEL_PROMPT + ): + black_list_models.append(adapter) + except Exception: + pass + for adapter in black_list_models: + model_adapters.remove(adapter) + + +# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat. +# We also recommend that you modify it directly in the fastchat repository. + +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212 +register_conv_template( + Conversation( + name="aquila-legacy", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human: ", "### Assistant: ", "System"), + messages=(), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="", + stop_str=["", "[UNK]"], + ), + override=True, +) +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227 +register_conv_template( + Conversation( + name="aquila", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant", "System"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="###", + sep2="", + stop_str=["", "[UNK]"], + ), + override=True, +) +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242 +register_conv_template( + Conversation( + name="aquila-v1", + roles=("<|startofpiece|>", "<|endofpiece|>", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="", + sep2="", + stop_str=["", "<|endoftext|>"], + ), + override=True, +) diff --git a/dbgpt/model/adapter/hf_adapter.py b/dbgpt/model/adapter/hf_adapter.py new file mode 100644 index 000000000..673223b49 --- /dev/null +++ b/dbgpt/model/adapter/hf_adapter.py @@ -0,0 +1,136 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional, List, Any +import logging + +from dbgpt.core import ModelMessage +from dbgpt.model.base import ModelType +from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter + +logger = logging.getLogger(__name__) + + +class NewHFChatModelAdapter(LLMModelAdapter, ABC): + """Model adapter for new huggingface chat models + + See https://huggingface.co/docs/transformers/main/en/chat_templating + + We can transform the inference chat messages to chat model instead of create a + prompt template for this model + """ + + def new_adapter(self, **kwargs) -> "NewHFChatModelAdapter": + return self.__class__() + + def match( + self, + model_type: str, + model_name: Optional[str] = None, + model_path: Optional[str] = None, + ) -> bool: + if model_type != ModelType.HF: + return False + if model_name is None and model_path is None: + return False + model_name = model_name.lower() if model_name else None + model_path = model_path.lower() if model_path else None + return self.do_match(model_name) or self.do_match(model_path) + + @abstractmethod + def do_match(self, lower_model_name_or_path: Optional[str] = None): + raise NotImplementedError() + + def load(self, model_path: str, from_pretrained_kwargs: dict): + try: + import transformers + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel + except ImportError as exc: + raise ValueError( + "Could not import depend python package " + "Please install it with `pip install transformers`." + ) from exc + if not transformers.__version__ >= "4.34.0": + raise ValueError( + "Current model (Load by NewHFChatModelAdapter) require transformers.__version__>=4.34.0" + ) + revision = from_pretrained_kwargs.get("revision", "main") + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer(), + revision=revision, + trust_remote_code=True, + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, revision=revision, trust_remote_code=True + ) + try: + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + except NameError: + model = AutoModel.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + # tokenizer.use_default_system_prompt = False + return model, tokenizer + + def get_generate_stream_function(self, model, model_path: str): + """Get the generate stream function of the model""" + from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream + + return huggingface_chat_generate_stream + + def get_str_prompt( + self, + params: Dict, + messages: List[ModelMessage], + tokenizer: Any, + prompt_template: str = None, + ) -> Optional[str]: + from transformers import AutoTokenizer + + if not tokenizer: + raise ValueError("tokenizer is is None") + tokenizer: AutoTokenizer = tokenizer + + messages = self.transform_model_messages(messages) + logger.debug(f"The messages after transform: \n{messages}") + str_prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + return str_prompt + + +class YiAdapter(NewHFChatModelAdapter): + support_4bit: bool = True + support_8bit: bool = True + support_system_message: bool = True + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return ( + lower_model_name_or_path + and "yi-" in lower_model_name_or_path + and "chat" in lower_model_name_or_path + ) + + +class Mixtral8x7BAdapter(NewHFChatModelAdapter): + """ + https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1 + """ + + support_4bit: bool = True + support_8bit: bool = True + support_system_message: bool = False + + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return ( + lower_model_name_or_path + and "mixtral" in lower_model_name_or_path + and "8x7b" in lower_model_name_or_path + ) + + +register_model_adapter(YiAdapter) +register_model_adapter(Mixtral8x7BAdapter) diff --git a/dbgpt/model/adapter/model_adapter.py b/dbgpt/model/adapter/model_adapter.py new file mode 100644 index 000000000..efaf9f1b9 --- /dev/null +++ b/dbgpt/model/adapter/model_adapter.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from typing import ( + List, + Type, + Optional, +) +import logging +import threading +import os +from functools import cache +from dbgpt.model.base import ModelType +from dbgpt.model.parameter import BaseModelParameters +from dbgpt.model.adapter.base import LLMModelAdapter, get_model_adapter +from dbgpt.model.adapter.template import ( + ConversationAdapter, + ConversationAdapterFactory, +) + +logger = logging.getLogger(__name__) + +thread_local = threading.local() +_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true" + + +_OLD_MODELS = [ + "llama-cpp", + "proxyllm", + "gptj-6b", + "codellama-13b-sql-sft", + "codellama-7b", + "codellama-7b-sql-sft", + "codellama-13b", +] + + +@cache +def get_llm_model_adapter( + model_name: str, + model_path: str, + use_fastchat: bool = True, + use_fastchat_monkey_patch: bool = False, + model_type: str = None, +) -> LLMModelAdapter: + conv_factory = DefaultConversationAdapterFactory() + if model_type == ModelType.VLLM: + logger.info("Current model type is vllm, return VLLMModelAdapterWrapper") + from dbgpt.model.adapter.vllm_adapter import VLLMModelAdapterWrapper + + return VLLMModelAdapterWrapper(conv_factory) + + # Import NewHFChatModelAdapter for it can be registered + from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter + + new_model_adapter = get_model_adapter( + model_type, model_name, model_path, conv_factory + ) + if new_model_adapter: + logger.info(f"Current model {model_name} use new adapter {new_model_adapter}") + return new_model_adapter + + must_use_old = any(m in model_name for m in _OLD_MODELS) + result_adapter: Optional[LLMModelAdapter] = None + if use_fastchat and not must_use_old: + logger.info("Use fastcat adapter") + from dbgpt.model.adapter.fschat_adapter import ( + _get_fastchat_model_adapter, + _fastchat_get_adapter_monkey_patch, + FastChatLLMModelAdapterWrapper, + ) + + adapter = _get_fastchat_model_adapter( + model_name, + model_path, + _fastchat_get_adapter_monkey_patch, + use_fastchat_monkey_patch=use_fastchat_monkey_patch, + ) + if adapter: + result_adapter = FastChatLLMModelAdapterWrapper(adapter) + + else: + from dbgpt.model.adapter.old_adapter import ( + get_llm_model_adapter as _old_get_llm_model_adapter, + OldLLMModelAdapterWrapper, + ) + from dbgpt.app.chat_adapter import get_llm_chat_adapter + + logger.info("Use DB-GPT old adapter") + result_adapter = OldLLMModelAdapterWrapper( + _old_get_llm_model_adapter(model_name, model_path), + get_llm_chat_adapter(model_name, model_path), + ) + if result_adapter: + result_adapter.model_name = model_name + result_adapter.model_path = model_path + result_adapter.conv_factory = conv_factory + return result_adapter + else: + raise ValueError(f"Can not find adapter for model {model_name}") + + +@cache +def _auto_get_conv_template( + model_name: str, model_path: str +) -> Optional[ConversationAdapter]: + """Auto get the conversation template. + + Args: + model_name (str): The name of the model. + model_path (str): The path of the model. + + Returns: + Optional[ConversationAdapter]: The conversation template. + """ + try: + adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True) + return adapter.get_default_conv_template(model_name, model_path) + except Exception as e: + logger.debug(f"Failed to get conv template for {model_name} {model_path}: {e}") + return None + + +class DefaultConversationAdapterFactory(ConversationAdapterFactory): + def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter: + """Get a conversation adapter by model. + + Args: + model_name (str): The name of the model. + model_path (str): The path of the model. + Returns: + ConversationAdapter: The conversation adapter. + """ + return _auto_get_conv_template(model_name, model_path) + + +def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]: + """Dynamic model parser, parse the model parameters from the command line arguments. + + Returns: + Optional[List[Type[BaseModelParameters]]]: The model parameters class list. + """ + from dbgpt.util.parameter_utils import _SimpleArgParser + from dbgpt.model.parameter import ( + EmbeddingModelParameters, + WorkerType, + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, + ) + + pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type") + pre_args.parse() + model_name = pre_args.get("model_name") + model_path = pre_args.get("model_path") + worker_type = pre_args.get("worker_type") + model_type = pre_args.get("model_type") + if model_name is None and model_type != ModelType.VLLM: + return None + if worker_type == WorkerType.TEXT2VEC: + return [ + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( + model_name, EmbeddingModelParameters + ) + ] + + llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type) + param_class = llm_adapter.model_param_class() + return [param_class] diff --git a/dbgpt/model/adapter.py b/dbgpt/model/adapter/old_adapter.py similarity index 82% rename from dbgpt/model/adapter.py rename to dbgpt/model/adapter/old_adapter.py index 05be0e464..a63054695 100644 --- a/dbgpt/model/adapter.py +++ b/dbgpt/model/adapter/old_adapter.py @@ -9,7 +9,7 @@ import re import logging from pathlib import Path -from typing import List, Tuple +from typing import List, Tuple, TYPE_CHECKING, Optional from functools import cache from transformers import ( AutoModel, @@ -17,6 +17,9 @@ AutoTokenizer, LlamaTokenizer, ) + +from dbgpt.model.adapter.base import LLMModelAdapter +from dbgpt.model.adapter.template import ConversationAdapter, PromptType from dbgpt.model.base import ModelType from dbgpt.model.parameter import ( @@ -24,9 +27,13 @@ LlamaCppModelParameters, ProxyModelParameters, ) +from dbgpt.model.conversation import Conversation from dbgpt.configs.model_config import get_device from dbgpt._private.config import Config +if TYPE_CHECKING: + from dbgpt.app.chat_adapter import BaseChatAdpter + logger = logging.getLogger(__name__) CFG = Config() @@ -92,17 +99,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper: ) -def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters: - try: - llm_adapter = get_llm_model_adapter(model_name, model_path) - return llm_adapter.model_param_class() - except Exception as e: - logger.warn( - f"Parse model parameters with model name {model_name} and model {model_path} failed {str(e)}, return `ModelParameters`" - ) - return ModelParameters - - # TODO support cpu? for practise we support gpt4all or chatglm-6b-int4? @@ -426,6 +422,87 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict): return model, tokenizer +class OldLLMModelAdapterWrapper(LLMModelAdapter): + """Wrapping old adapter, which may be removed later""" + + def __init__(self, adapter: BaseLLMAdaper, chat_adapter: "BaseChatAdpter") -> None: + self._adapter = adapter + self._chat_adapter = chat_adapter + + def new_adapter(self, **kwargs) -> "LLMModelAdapter": + return OldLLMModelAdapterWrapper(self._adapter, self._chat_adapter) + + def use_fast_tokenizer(self) -> bool: + return self._adapter.use_fast_tokenizer() + + def model_type(self) -> str: + return self._adapter.model_type() + + def model_param_class(self, model_type: str = None) -> ModelParameters: + return self._adapter.model_param_class(model_type) + + def get_default_conv_template( + self, model_name: str, model_path: str + ) -> Optional[ConversationAdapter]: + conv_template = self._chat_adapter.get_conv_template(model_path) + return OldConversationAdapter(conv_template) if conv_template else None + + def load(self, model_path: str, from_pretrained_kwargs: dict): + return self._adapter.loader(model_path, from_pretrained_kwargs) + + def get_generate_stream_function(self, model, model_path: str): + return self._chat_adapter.get_generate_stream_func(model_path) + + def __str__(self) -> str: + return "{}({}.{})".format( + self.__class__.__name__, + self._adapter.__class__.__module__, + self._adapter.__class__.__name__, + ) + + +class OldConversationAdapter(ConversationAdapter): + """Wrapping old Conversation, which may be removed later""" + + def __init__(self, conv: Conversation) -> None: + self._conv = conv + + @property + def prompt_type(self) -> PromptType: + return PromptType.DBGPT + + @property + def roles(self) -> Tuple[str]: + return self._conv.roles + + @property + def sep(self) -> Optional[str]: + return self._conv.sep + + @property + def stop_str(self) -> str: + return self._conv.stop_str + + @property + def stop_token_ids(self) -> Optional[List[int]]: + return self._conv.stop_token_ids + + def get_prompt(self) -> str: + return self._conv.get_prompt() + + def set_system_message(self, system_message: str) -> None: + self._conv.update_system_message(system_message) + + def append_message(self, role: str, message: str) -> None: + self._conv.append_message(role, message) + + def update_last_message(self, message: str) -> None: + self._conv.update_last_message(message) + + def copy(self) -> "ConversationAdapter": + return OldConversationAdapter(self._conv.copy()) + + register_llm_model_adapters(VicunaLLMAdapater) register_llm_model_adapters(ChatGLMAdapater) register_llm_model_adapters(GuanacoAdapter) diff --git a/dbgpt/model/adapter/template.py b/dbgpt/model/adapter/template.py new file mode 100644 index 000000000..3fb9a6ec1 --- /dev/null +++ b/dbgpt/model/adapter/template.py @@ -0,0 +1,130 @@ +from abc import ABC, abstractmethod +from enum import Enum +from typing import TYPE_CHECKING, Optional, Tuple, Union, List + +if TYPE_CHECKING: + from fastchat.conversation import Conversation + + +class PromptType(str, Enum): + """Prompt type.""" + + FSCHAT: str = "fschat" + DBGPT: str = "dbgpt" + + +class ConversationAdapter(ABC): + """The conversation adapter.""" + + @property + def prompt_type(self) -> PromptType: + return PromptType.FSCHAT + + @property + @abstractmethod + def roles(self) -> Tuple[str]: + """Get the roles of the conversation. + + Returns: + Tuple[str]: The roles of the conversation. + """ + + @property + def sep(self) -> Optional[str]: + """Get the separator between messages.""" + return "\n" + + @property + def stop_str(self) -> Optional[Union[str, List[str]]]: + """Get the stop criteria.""" + return None + + @property + def stop_token_ids(self) -> Optional[List[int]]: + """Stops generation if meeting any token in this list""" + return None + + @abstractmethod + def get_prompt(self) -> str: + """Get the prompt string. + + Returns: + str: The prompt string. + """ + + @abstractmethod + def set_system_message(self, system_message: str) -> None: + """Set the system message.""" + + @abstractmethod + def append_message(self, role: str, message: str) -> None: + """Append a new message. + Args: + role (str): The role of the message. + message (str): The message content. + """ + + @abstractmethod + def update_last_message(self, message: str) -> None: + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. + + Args: + message (str): The message content. + """ + + @abstractmethod + def copy(self) -> "ConversationAdapter": + """Copy the conversation.""" + + +class ConversationAdapterFactory(ABC): + """The conversation adapter factory.""" + + def get_by_name( + self, + template_name: str, + prompt_template_type: Optional[PromptType] = PromptType.FSCHAT, + ) -> ConversationAdapter: + """Get a conversation adapter by name. + + Args: + template_name (str): The name of the template. + prompt_template_type (Optional[PromptType]): The type of the prompt template, default to be FSCHAT. + + Returns: + ConversationAdapter: The conversation adapter. + """ + raise NotImplementedError() + + def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter: + """Get a conversation adapter by model. + + Args: + model_name (str): The name of the model. + model_path (str): The path of the model. + + Returns: + ConversationAdapter: The conversation adapter. + """ + raise NotImplementedError() + + +def get_conv_template(name: str) -> ConversationAdapter: + """Get a conversation template. + + Args: + name (str): The name of the template. + + Just return the fastchat conversation template for now. + # TODO: More templates should be supported. + Returns: + Conversation: The conversation template. + """ + from fastchat.conversation import get_conv_template + from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter + + conv_template = get_conv_template(name) + return FschatConversationAdapter(conv_template) diff --git a/dbgpt/model/adapter/vllm_adapter.py b/dbgpt/model/adapter/vllm_adapter.py new file mode 100644 index 000000000..2ffe0c764 --- /dev/null +++ b/dbgpt/model/adapter/vllm_adapter.py @@ -0,0 +1,93 @@ +import dataclasses +import logging +from dbgpt.model.base import ModelType +from dbgpt.model.adapter.base import LLMModelAdapter +from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory +from dbgpt.model.parameter import BaseModelParameters +from dbgpt.util.parameter_utils import ( + _extract_parameter_details, + _build_parameter_class, + _get_dataclass_print_str, +) + +logger = logging.getLogger(__name__) + + +class VLLMModelAdapterWrapper(LLMModelAdapter): + """Wrapping vllm engine""" + + def __init__(self, conv_factory: ConversationAdapterFactory): + self.conv_factory = conv_factory + + def new_adapter(self, **kwargs) -> "VLLMModelAdapterWrapper": + return VLLMModelAdapterWrapper(self.conv_factory) + + def model_type(self) -> str: + return ModelType.VLLM + + def model_param_class(self, model_type: str = None) -> BaseModelParameters: + import argparse + from vllm.engine.arg_utils import AsyncEngineArgs + + parser = argparse.ArgumentParser() + parser = AsyncEngineArgs.add_cli_args(parser) + parser.add_argument("--model_name", type=str, help="model name") + parser.add_argument( + "--model_path", + type=str, + help="local model path of the huggingface model to use", + ) + parser.add_argument("--model_type", type=str, help="model type") + parser.add_argument("--device", type=str, default=None, help="device") + # TODO parse prompt templete from `model_name` and `model_path` + parser.add_argument( + "--prompt_template", + type=str, + default=None, + help="Prompt template. If None, the prompt template is automatically determined from model path", + ) + + descs = _extract_parameter_details( + parser, + "dbgpt.model.parameter.VLLMModelParameters", + skip_names=["model"], + overwrite_default_values={"trust_remote_code": True}, + ) + return _build_parameter_class(descs) + + def load_from_params(self, params): + from vllm import AsyncLLMEngine + from vllm.engine.arg_utils import AsyncEngineArgs + import torch + + num_gpus = torch.cuda.device_count() + if num_gpus > 1 and hasattr(params, "tensor_parallel_size"): + setattr(params, "tensor_parallel_size", num_gpus) + logger.info( + f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}" + ) + + params = dataclasses.asdict(params) + params["model"] = params["model_path"] + attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)] + vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs} + # Set the attributes from the parsed arguments. + engine_args = AsyncEngineArgs(**vllm_engine_args_dict) + engine = AsyncLLMEngine.from_engine_args(engine_args) + return engine, engine.engine.tokenizer + + def support_async(self) -> bool: + return True + + def get_async_generate_stream_function(self, model, model_path: str): + from dbgpt.model.llm_out.vllm_llm import generate_stream + + return generate_stream + + def get_default_conv_template( + self, model_name: str, model_path: str + ) -> ConversationAdapter: + return self.conv_factory.get_by_model(model_name, model_path) + + def __str__(self) -> str: + return "{}.{}".format(self.__class__.__module__, self.__class__.__name__) diff --git a/dbgpt/model/cli.py b/dbgpt/model/cli.py index 67d5eec3c..67e12d7eb 100644 --- a/dbgpt/model/cli.py +++ b/dbgpt/model/cli.py @@ -405,7 +405,7 @@ def stop_model_controller(port: int): def _model_dynamic_factory() -> Callable[[None], List[Type]]: - from dbgpt.model.model_adapter import _dynamic_model_parser + from dbgpt.model.adapter.model_adapter import _dynamic_model_parser param_class = _dynamic_model_parser() fix_class = [ModelWorkerParameters] diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py index 7345bb5d4..c4967a076 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -6,7 +6,8 @@ import traceback from dbgpt.configs.model_config import get_device -from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper +from dbgpt.model.adapter.base import LLMModelAdapter +from dbgpt.model.adapter.model_adapter import get_llm_model_adapter from dbgpt.core import ModelOutput, ModelInferenceMetrics from dbgpt.model.loader import ModelLoader, _get_model_real_path from dbgpt.model.parameter import ModelParameters @@ -27,7 +28,7 @@ def __init__(self) -> None: self.model = None self.tokenizer = None self._model_params = None - self.llm_adapter: LLMModelAdaper = None + self.llm_adapter: LLMModelAdapter = None self._support_async = False def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: diff --git a/dbgpt/model/llm_utils.py b/dbgpt/model/llm_utils.py index e877778ed..031896e86 100644 --- a/dbgpt/model/llm_utils.py +++ b/dbgpt/model/llm_utils.py @@ -37,7 +37,7 @@ def list_supported_models(): def _list_supported_models( worker_type: str, model_config: Dict[str, str] ) -> List[SupportedModel]: - from dbgpt.model.model_adapter import get_llm_model_adapter + from dbgpt.model.adapter.model_adapter import get_llm_model_adapter from dbgpt.model.loader import _get_model_real_path ret = [] diff --git a/dbgpt/model/loader.py b/dbgpt/model/loader.py index 2030eb402..4ed42d630 100644 --- a/dbgpt/model/loader.py +++ b/dbgpt/model/loader.py @@ -1,13 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Optional, Dict +from typing import Optional, Dict, Any -from dataclasses import asdict import logging from dbgpt.configs.model_config import get_device from dbgpt.model.base import ModelType -from dbgpt.model.model_adapter import get_llm_model_adapter, LLMModelAdaper +from dbgpt.model.adapter.base import LLMModelAdapter +from dbgpt.model.adapter.model_adapter import get_llm_model_adapter from dbgpt.model.parameter import ( ModelParameters, LlamaCppModelParameters, @@ -117,7 +117,7 @@ def loader( raise Exception(f"Unkown model type {model_type}") def loader_with_params( - self, model_params: ModelParameters, llm_adapter: LLMModelAdaper + self, model_params: ModelParameters, llm_adapter: LLMModelAdapter ): model_type = llm_adapter.model_type() self.prompt_template = model_params.prompt_template @@ -133,7 +133,7 @@ def loader_with_params( raise Exception(f"Unkown model type {model_type}") -def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters): +def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters): import torch from dbgpt.model.compression import compress_module @@ -174,6 +174,12 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter else: raise ValueError(f"Invalid device: {device}") + model, tokenizer = _try_load_default_quantization_model( + llm_adapter, device, num_gpus, model_params, kwargs + ) + if model: + return model, tokenizer + can_quantization = _check_quantization(model_params) if can_quantization and (num_gpus > 1 or model_params.load_4bit): @@ -192,6 +198,46 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter # TODO merge current code into `load_huggingface_quantization_model` compress_module(model, model_params.device) + return _handle_model_and_tokenizer(model, tokenizer, device, num_gpus, model_params) + + +def _try_load_default_quantization_model( + llm_adapter: LLMModelAdapter, + device: str, + num_gpus: int, + model_params: ModelParameters, + kwargs: Dict[str, Any], +): + """Try load default quantization model(Support by huggingface default)""" + cloned_kwargs = {k: v for k, v in kwargs.items()} + try: + model, tokenizer = None, None + if device != "cuda": + return None, None + elif model_params.load_8bit and llm_adapter.support_8bit: + cloned_kwargs["load_in_8bit"] = True + model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs) + elif model_params.load_4bit and llm_adapter.support_4bit: + cloned_kwargs["load_in_4bit"] = True + model, tokenizer = llm_adapter.load(model_params.model_path, cloned_kwargs) + if model: + logger.info( + f"Load default quantization model {model_params.model_name} success" + ) + return _handle_model_and_tokenizer( + model, tokenizer, device, num_gpus, model_params + ) + return None, None + except Exception as e: + logger.warning( + f"Load default quantization model {model_params.model_name} failed, error: {str(e)}" + ) + return None, None + + +def _handle_model_and_tokenizer( + model, tokenizer, device: str, num_gpus: int, model_params: ModelParameters +): if ( (device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading) or device == "mps" @@ -209,7 +255,7 @@ def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameter def load_huggingface_quantization_model( - llm_adapter: LLMModelAdaper, + llm_adapter: LLMModelAdapter, model_params: ModelParameters, kwargs: Dict, max_memory: Dict[int, str], @@ -344,7 +390,9 @@ def load_huggingface_quantization_model( return model, tokenizer -def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters): +def llamacpp_loader( + llm_adapter: LLMModelAdapter, model_params: LlamaCppModelParameters +): try: from dbgpt.model.llm.llama_cpp.llama_cpp import LlamaCppModel except ImportError as exc: @@ -358,7 +406,7 @@ def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelPara return model, tokenizer -def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters): +def proxyllm_loader(llm_adapter: LLMModelAdapter, model_params: ProxyModelParameters): from dbgpt.model.proxy.llms.proxy_model import ProxyModel logger.info("Load proxyllm") diff --git a/dbgpt/model/model_adapter.py b/dbgpt/model/model_adapter.py deleted file mode 100644 index 85243a12b..000000000 --- a/dbgpt/model/model_adapter.py +++ /dev/null @@ -1,660 +0,0 @@ -from __future__ import annotations - -from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional -import dataclasses -import logging -import threading -import os -from functools import cache -from dbgpt.model.base import ModelType -from dbgpt.model.parameter import ( - ModelParameters, - LlamaCppModelParameters, - ProxyModelParameters, -) -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType -from dbgpt.util.parameter_utils import ( - _extract_parameter_details, - _build_parameter_class, - _get_dataclass_print_str, -) - -try: - from fastchat.conversation import ( - Conversation, - register_conv_template, - SeparatorStyle, - ) -except ImportError as exc: - raise ValueError( - "Could not import python package: fschat " - "Please install fastchat by command `pip install fschat` " - ) from exc - -if TYPE_CHECKING: - from fastchat.model.model_adapter import BaseModelAdapter - from dbgpt.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper - from torch.nn import Module as TorchNNModule - -logger = logging.getLogger(__name__) - -thread_local = threading.local() -_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true" - - -_OLD_MODELS = [ - "llama-cpp", - "proxyllm", - "gptj-6b", - "codellama-13b-sql-sft", - "codellama-7b", - "codellama-7b-sql-sft", - "codellama-13b", -] - -_NEW_HF_CHAT_MODELS = [ - "yi-34b", - "yi-6b", -] - -# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist. -_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"] - - -class LLMModelAdaper: - """New Adapter for DB-GPT LLM models""" - - def use_fast_tokenizer(self) -> bool: - """Whether use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported - for a given model. - """ - return False - - def model_type(self) -> str: - return ModelType.HF - - def model_param_class(self, model_type: str = None) -> ModelParameters: - """Get the startup parameters instance of the model""" - model_type = model_type if model_type else self.model_type() - if model_type == ModelType.LLAMA_CPP: - return LlamaCppModelParameters - elif model_type == ModelType.PROXY: - return ProxyModelParameters - return ModelParameters - - def load(self, model_path: str, from_pretrained_kwargs: dict): - """Load model and tokenizer""" - raise NotImplementedError - - def load_from_params(self, params): - """Load the model and tokenizer according to the given parameters""" - raise NotImplementedError - - def support_async(self) -> bool: - """Whether the loaded model supports asynchronous calls""" - return False - - def get_generate_stream_function(self, model, model_path: str): - """Get the generate stream function of the model""" - raise NotImplementedError - - def get_async_generate_stream_function(self, model, model_path: str): - """Get the asynchronous generate stream function of the model""" - raise NotImplementedError - - def get_default_conv_template( - self, model_name: str, model_path: str - ) -> "Conversation": - """Get the default conv template""" - raise NotImplementedError - - def get_str_prompt( - self, - params: Dict, - messages: List[ModelMessage], - tokenizer: Any, - prompt_template: str = None, - ) -> Optional[str]: - return None - - def get_prompt_with_template( - self, - params: Dict, - messages: List[ModelMessage], - model_name: str, - model_path: str, - model_context: Dict, - prompt_template: str = None, - ): - conv = self.get_default_conv_template(model_name, model_path) - - if prompt_template: - logger.info(f"Use prompt template {prompt_template} from config") - conv = get_conv_template(prompt_template) - if not conv or not messages: - # Nothing to do - logger.info( - f"No conv from model_path {model_path} or no messages in params, {self}" - ) - return None, None, None - - conv = conv.copy() - system_messages = [] - user_messages = [] - ai_messages = [] - - for message in messages: - role, content = None, None - if isinstance(message, ModelMessage): - role = message.role - content = message.content - elif isinstance(message, dict): - role = message["role"] - content = message["content"] - else: - raise ValueError(f"Invalid message type: {message}") - - if role == ModelMessageRoleType.SYSTEM: - # Support for multiple system messages - system_messages.append(content) - elif role == ModelMessageRoleType.HUMAN: - # conv.append_message(conv.roles[0], content) - user_messages.append(content) - elif role == ModelMessageRoleType.AI: - # conv.append_message(conv.roles[1], content) - ai_messages.append(content) - else: - raise ValueError(f"Unknown role: {role}") - - can_use_systems: [] = [] - if system_messages: - if len(system_messages) > 1: - ## Compatible with dbgpt complex scenarios, the last system will protect more complete information entered by the current user - user_messages[-1] = system_messages[-1] - can_use_systems = system_messages[:-1] - else: - can_use_systems = system_messages - - for i in range(len(user_messages)): - conv.append_message(conv.roles[0], user_messages[i]) - if i < len(ai_messages): - conv.append_message(conv.roles[1], ai_messages[i]) - - if isinstance(conv, Conversation): - conv.set_system_message("".join(can_use_systems)) - else: - conv.update_system_message("".join(can_use_systems)) - - # Add a blank message for the assistant. - conv.append_message(conv.roles[1], None) - new_prompt = conv.get_prompt() - return new_prompt, conv.stop_str, conv.stop_token_ids - - def model_adaptation( - self, - params: Dict, - model_name: str, - model_path: str, - tokenizer: Any, - prompt_template: str = None, - ) -> Tuple[Dict, Dict]: - """Params adaptation""" - messages = params.get("messages") - # Some model scontext to dbgpt server - model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False} - if messages: - # Dict message to ModelMessage - messages = [ - m if isinstance(m, ModelMessage) else ModelMessage(**m) - for m in messages - ] - params["messages"] = messages - - new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template) - conv_stop_str, conv_stop_token_ids = None, None - if not new_prompt: - ( - new_prompt, - conv_stop_str, - conv_stop_token_ids, - ) = self.get_prompt_with_template( - params, messages, model_name, model_path, model_context, prompt_template - ) - if not new_prompt: - return params, model_context - - # Overwrite the original prompt - # TODO remote bos token and eos token from tokenizer_config.json of model - prompt_echo_len_char = len(new_prompt.replace("", "").replace("", "")) - model_context["prompt_echo_len_char"] = prompt_echo_len_char - model_context["echo"] = params.get("echo", True) - model_context["has_format_prompt"] = True - params["prompt"] = new_prompt - - custom_stop = params.get("stop") - custom_stop_token_ids = params.get("stop_token_ids") - - # Prefer the value passed in from the input parameter - params["stop"] = custom_stop or conv_stop_str - params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids - - return params, model_context - - -class OldLLMModelAdaperWrapper(LLMModelAdaper): - """Wrapping old adapter, which may be removed later""" - - def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None: - self._adapter = adapter - self._chat_adapter = chat_adapter - - def use_fast_tokenizer(self) -> bool: - return self._adapter.use_fast_tokenizer() - - def model_type(self) -> str: - return self._adapter.model_type() - - def model_param_class(self, model_type: str = None) -> ModelParameters: - return self._adapter.model_param_class(model_type) - - def get_default_conv_template( - self, model_name: str, model_path: str - ) -> "Conversation": - return self._chat_adapter.get_conv_template(model_path) - - def load(self, model_path: str, from_pretrained_kwargs: dict): - return self._adapter.loader(model_path, from_pretrained_kwargs) - - def get_generate_stream_function(self, model, model_path: str): - return self._chat_adapter.get_generate_stream_func(model_path) - - def __str__(self) -> str: - return "{}({}.{})".format( - self.__class__.__name__, - self._adapter.__class__.__module__, - self._adapter.__class__.__name__, - ) - - -class FastChatLLMModelAdaperWrapper(LLMModelAdaper): - """Wrapping fastchat adapter""" - - def __init__(self, adapter: "BaseModelAdapter") -> None: - self._adapter = adapter - - def use_fast_tokenizer(self) -> bool: - return self._adapter.use_fast_tokenizer - - def load(self, model_path: str, from_pretrained_kwargs: dict): - return self._adapter.load_model(model_path, from_pretrained_kwargs) - - def get_generate_stream_function(self, model: "TorchNNModule", model_path: str): - if _IS_BENCHMARK: - from dbgpt.util.benchmarks.llm.fastchat_benchmarks_inference import ( - generate_stream, - ) - - return generate_stream - else: - from fastchat.model.model_adapter import get_generate_stream_function - - return get_generate_stream_function(model, model_path) - - def get_default_conv_template( - self, model_name: str, model_path: str - ) -> "Conversation": - return self._adapter.get_default_conv_template(model_path) - - def __str__(self) -> str: - return "{}({}.{})".format( - self.__class__.__name__, - self._adapter.__class__.__module__, - self._adapter.__class__.__name__, - ) - - -class NewHFChatModelAdapter(LLMModelAdaper): - def load(self, model_path: str, from_pretrained_kwargs: dict): - try: - import transformers - from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel - except ImportError as exc: - raise ValueError( - "Could not import depend python package " - "Please install it with `pip install transformers`." - ) from exc - if not transformers.__version__ >= "4.34.0": - raise ValueError( - "Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0" - ) - revision = from_pretrained_kwargs.get("revision", "main") - try: - tokenizer = AutoTokenizer.from_pretrained( - model_path, - use_fast=self.use_fast_tokenizer, - revision=revision, - trust_remote_code=True, - ) - except TypeError: - tokenizer = AutoTokenizer.from_pretrained( - model_path, use_fast=False, revision=revision, trust_remote_code=True - ) - try: - model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs - ) - except NameError: - model = AutoModel.from_pretrained( - model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs - ) - # tokenizer.use_default_system_prompt = False - return model, tokenizer - - def get_generate_stream_function(self, model, model_path: str): - """Get the generate stream function of the model""" - from dbgpt.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream - - return huggingface_chat_generate_stream - - def get_str_prompt( - self, - params: Dict, - messages: List[ModelMessage], - tokenizer: Any, - prompt_template: str = None, - ) -> Optional[str]: - from transformers import AutoTokenizer - - if not tokenizer: - raise ValueError("tokenizer is is None") - tokenizer: AutoTokenizer = tokenizer - - messages = ModelMessage.to_openai_messages(messages) - str_prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - return str_prompt - - -def get_conv_template(name: str) -> "Conversation": - """Get a conversation template.""" - from fastchat.conversation import get_conv_template - - return get_conv_template(name) - - -@cache -def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation": - try: - adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True) - return adapter.get_default_conv_template(model_name, model_path) - except Exception: - return None - - -@cache -def get_llm_model_adapter( - model_name: str, - model_path: str, - use_fastchat: bool = True, - use_fastchat_monkey_patch: bool = False, - model_type: str = None, -) -> LLMModelAdaper: - if model_type == ModelType.VLLM: - logger.info("Current model type is vllm, return VLLMModelAdaperWrapper") - return VLLMModelAdaperWrapper() - - use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS) - if use_new_hf_chat_models: - logger.info(f"Current model {model_name} use NewHFChatModelAdapter") - return NewHFChatModelAdapter() - - must_use_old = any(m in model_name for m in _OLD_MODELS) - if use_fastchat and not must_use_old: - logger.info("Use fastcat adapter") - adapter = _get_fastchat_model_adapter( - model_name, - model_path, - _fastchat_get_adapter_monkey_patch, - use_fastchat_monkey_patch=use_fastchat_monkey_patch, - ) - return FastChatLLMModelAdaperWrapper(adapter) - else: - from dbgpt.model.adapter import ( - get_llm_model_adapter as _old_get_llm_model_adapter, - ) - from dbgpt.app.chat_adapter import get_llm_chat_adapter - - logger.info("Use DB-GPT old adapter") - return OldLLMModelAdaperWrapper( - _old_get_llm_model_adapter(model_name, model_path), - get_llm_chat_adapter(model_name, model_path), - ) - - -def _get_fastchat_model_adapter( - model_name: str, - model_path: str, - caller: Callable[[str], None] = None, - use_fastchat_monkey_patch: bool = False, -): - from fastchat.model import model_adapter - - _bak_get_model_adapter = model_adapter.get_model_adapter - try: - if use_fastchat_monkey_patch: - model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch - thread_local.model_name = model_name - _remove_black_list_model_of_fastchat() - if caller: - return caller(model_path) - finally: - del thread_local.model_name - model_adapter.get_model_adapter = _bak_get_model_adapter - - -def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None): - if not model_name: - if not hasattr(thread_local, "model_name"): - raise RuntimeError("fastchat get adapter monkey path need model_name") - model_name = thread_local.model_name - from fastchat.model.model_adapter import model_adapters - - for adapter in model_adapters: - if adapter.match(model_name): - logger.info( - f"Found llm model adapter with model name: {model_name}, {adapter}" - ) - return adapter - - model_path_basename = ( - None if not model_path else os.path.basename(os.path.normpath(model_path)) - ) - for adapter in model_adapters: - if model_path_basename and adapter.match(model_path_basename): - logger.info( - f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}" - ) - return adapter - - for adapter in model_adapters: - if model_path and adapter.match(model_path): - logger.info( - f"Found llm model adapter with model path: {model_path}, {adapter}" - ) - return adapter - - raise ValueError( - f"Invalid model adapter for model name {model_name} and model path {model_path}" - ) - - -@cache -def _remove_black_list_model_of_fastchat(): - from fastchat.model.model_adapter import model_adapters - - black_list_models = [] - for adapter in model_adapters: - try: - if ( - adapter.get_default_conv_template("/data/not_exist_model_path").name - in _BLACK_LIST_MODLE_PROMPT - ): - black_list_models.append(adapter) - except Exception: - pass - for adapter in black_list_models: - model_adapters.remove(adapter) - - -def _dynamic_model_parser() -> Callable[[None], List[Type]]: - from dbgpt.util.parameter_utils import _SimpleArgParser - from dbgpt.model.parameter import ( - EmbeddingModelParameters, - WorkerType, - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, - ) - - pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type") - pre_args.parse() - model_name = pre_args.get("model_name") - model_path = pre_args.get("model_path") - worker_type = pre_args.get("worker_type") - model_type = pre_args.get("model_type") - if model_name is None and model_type != ModelType.VLLM: - return None - if worker_type == WorkerType.TEXT2VEC: - return [ - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( - model_name, EmbeddingModelParameters - ) - ] - - llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type) - param_class = llm_adapter.model_param_class() - return [param_class] - - -class VLLMModelAdaperWrapper(LLMModelAdaper): - """Wrapping vllm engine""" - - def model_type(self) -> str: - return ModelType.VLLM - - def model_param_class(self, model_type: str = None) -> ModelParameters: - import argparse - from vllm.engine.arg_utils import AsyncEngineArgs - - parser = argparse.ArgumentParser() - parser = AsyncEngineArgs.add_cli_args(parser) - parser.add_argument("--model_name", type=str, help="model name") - parser.add_argument( - "--model_path", - type=str, - help="local model path of the huggingface model to use", - ) - parser.add_argument("--model_type", type=str, help="model type") - parser.add_argument("--device", type=str, default=None, help="device") - # TODO parse prompt templete from `model_name` and `model_path` - parser.add_argument( - "--prompt_template", - type=str, - default=None, - help="Prompt template. If None, the prompt template is automatically determined from model path", - ) - - descs = _extract_parameter_details( - parser, - "dbgpt.model.parameter.VLLMModelParameters", - skip_names=["model"], - overwrite_default_values={"trust_remote_code": True}, - ) - return _build_parameter_class(descs) - - def load_from_params(self, params): - from vllm import AsyncLLMEngine - from vllm.engine.arg_utils import AsyncEngineArgs - import torch - - num_gpus = torch.cuda.device_count() - if num_gpus > 1 and hasattr(params, "tensor_parallel_size"): - setattr(params, "tensor_parallel_size", num_gpus) - logger.info( - f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}" - ) - - params = dataclasses.asdict(params) - params["model"] = params["model_path"] - attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)] - vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs} - # Set the attributes from the parsed arguments. - engine_args = AsyncEngineArgs(**vllm_engine_args_dict) - engine = AsyncLLMEngine.from_engine_args(engine_args) - return engine, engine.engine.tokenizer - - def support_async(self) -> bool: - return True - - def get_async_generate_stream_function(self, model, model_path: str): - from dbgpt.model.llm_out.vllm_llm import generate_stream - - return generate_stream - - def get_default_conv_template( - self, model_name: str, model_path: str - ) -> "Conversation": - return _auto_get_conv_template(model_name, model_path) - - def __str__(self) -> str: - return "{}.{}".format(self.__class__.__module__, self.__class__.__name__) - - -# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat. -# We also recommend that you modify it directly in the fastchat repository. - -# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212 -register_conv_template( - Conversation( - name="aquila-legacy", - system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", - roles=("### Human: ", "### Assistant: ", "System"), - messages=(), - offset=0, - sep_style=SeparatorStyle.NO_COLON_TWO, - sep="\n", - sep2="", - stop_str=["", "[UNK]"], - ), - override=True, -) -# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227 -register_conv_template( - Conversation( - name="aquila", - system_message="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", - roles=("Human", "Assistant", "System"), - messages=(), - offset=0, - sep_style=SeparatorStyle.ADD_COLON_TWO, - sep="###", - sep2="", - stop_str=["", "[UNK]"], - ), - override=True, -) -# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242 -register_conv_template( - Conversation( - name="aquila-v1", - roles=("<|startofpiece|>", "<|endofpiece|>", ""), - messages=(), - offset=0, - sep_style=SeparatorStyle.NO_COLON_TWO, - sep="", - sep2="", - stop_str=["", "<|endoftext|>"], - ), - override=True, -) From d9065227bd61359e4ab71a499c33ace7f61520f8 Mon Sep 17 00:00:00 2001 From: vvycaaa <147325516+vvycaaa@users.noreply.github.com> Date: Fri, 22 Dec 2023 09:42:58 +0800 Subject: [PATCH 2/2] fix(core): Move the last user's information to the end (#960) --- dbgpt/core/interface/message.py | 13 +++-- dbgpt/core/interface/tests/test_message.py | 57 ++++++++++++++++++++++ dbgpt/model/proxy/llms/bard.py | 13 +++-- dbgpt/model/proxy/llms/chatgpt.py | 13 +++-- 4 files changed, 75 insertions(+), 21 deletions(-) mode change 100644 => 100755 dbgpt/core/interface/message.py mode change 100644 => 100755 dbgpt/core/interface/tests/test_message.py mode change 100644 => 100755 dbgpt/model/proxy/llms/bard.py mode change 100644 => 100755 dbgpt/model/proxy/llms/chatgpt.py diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py old mode 100644 new mode 100755 index bd06f0dc7..2b1439c6d --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -157,14 +157,13 @@ def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]: else: pass # Move the last user's information to the end - temp_his = history[::-1] - last_user_input = None - for m in temp_his: - if m["role"] == "user": - last_user_input = m + last_user_input_index = None + for i in range(len(history) - 1, -1, -1): + if history[i]["role"] == "user": + last_user_input_index = i break - if last_user_input: - history.remove(last_user_input) + if last_user_input_index: + last_user_input = history.pop(last_user_input_index) history.append(last_user_input) return history diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py old mode 100644 new mode 100755 index 425f268af..41f5f36c5 --- a/dbgpt/core/interface/tests/test_message.py +++ b/dbgpt/core/interface/tests/test_message.py @@ -67,6 +67,23 @@ def conversation_with_messages(): return conv +@pytest.fixture +def human_model_message(): + return ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello") + + +@pytest.fixture +def ai_model_message(): + return ModelMessage(role=ModelMessageRoleType.AI, content="Hi there") + + +@pytest.fixture +def system_model_message(): + return ModelMessage( + role=ModelMessageRoleType.SYSTEM, content="You are a helpful chatbot!" + ) + + def test_init(basic_conversation): assert basic_conversation.chat_mode == "chat_normal" assert basic_conversation.user_name == "user1" @@ -305,3 +322,43 @@ def test_load_from_storage(storage_conversation, in_memory_storage): assert new_conversation.messages[1].content == "AI response" assert isinstance(new_conversation.messages[0], HumanMessage) assert isinstance(new_conversation.messages[1], AIMessage) + + +def test_to_openai_messages( + human_model_message, ai_model_message, system_model_message +): + none_messages = ModelMessage.to_openai_messages([]) + assert none_messages == [] + + single_messages = ModelMessage.to_openai_messages([human_model_message]) + assert single_messages == [{"role": "user", "content": human_model_message.content}] + + normal_messages = ModelMessage.to_openai_messages( + [ + system_model_message, + human_model_message, + ai_model_message, + human_model_message, + ] + ) + assert normal_messages == [ + {"role": "system", "content": system_model_message.content}, + {"role": "user", "content": human_model_message.content}, + {"role": "assistant", "content": ai_model_message.content}, + {"role": "user", "content": human_model_message.content}, + ] + + shuffle_messages = ModelMessage.to_openai_messages( + [ + system_model_message, + human_model_message, + human_model_message, + ai_model_message, + ] + ) + assert shuffle_messages == [ + {"role": "system", "content": system_model_message.content}, + {"role": "user", "content": human_model_message.content}, + {"role": "assistant", "content": ai_model_message.content}, + {"role": "user", "content": human_model_message.content}, + ] diff --git a/dbgpt/model/proxy/llms/bard.py b/dbgpt/model/proxy/llms/bard.py old mode 100644 new mode 100755 index 5b6ed26a0..fc398fe8b --- a/dbgpt/model/proxy/llms/bard.py +++ b/dbgpt/model/proxy/llms/bard.py @@ -25,14 +25,13 @@ def bard_generate_stream( else: pass - temp_his = history[::-1] - last_user_input = None - for m in temp_his: - if m["role"] == "user": - last_user_input = m + last_user_input_index = None + for i in range(len(history) - 1, -1, -1): + if history[i]["role"] == "user": + last_user_input_index = i break - if last_user_input: - history.remove(last_user_input) + if last_user_input_index: + last_user_input = history.pop(last_user_input_index) history.append(last_user_input) msgs = [] diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py old mode 100644 new mode 100755 index e9229da44..d81626e7a --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -110,14 +110,13 @@ def _build_request(model: ProxyModel, params): pass # Move the last user's information to the end - temp_his = history[::-1] - last_user_input = None - for m in temp_his: - if m["role"] == "user": - last_user_input = m + last_user_input_index = None + for i in range(len(history) - 1, -1, -1): + if history[i]["role"] == "user": + last_user_input_index = i break - if last_user_input: - history.remove(last_user_input) + if last_user_input_index: + last_user_input = history.pop(last_user_input_index) history.append(last_user_input) payloads = {