From 325498085ece4914c161021cb86b810648a7a377 Mon Sep 17 00:00:00 2001
From: wangyuxin
Date: Sun, 28 Apr 2024 00:14:10 +0800
Subject: [PATCH 1/4] Refactor architecture
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 generate/__init__.py                          |  14 -
 generate/chat_completion/__init__.py          |  15 -
 generate/chat_completion/base.py              |  58 +-
 generate/chat_completion/cost_caculator.py    |  40 +-
 generate/chat_completion/message/converter.py |  84 +++
 generate/chat_completion/message/core.py      |  16 +-
 generate/chat_completion/message/utils.py     |   2 +
 generate/chat_completion/model_output.py      |  19 +-
 generate/chat_completion/models/__init__.py   |  18 +-
 generate/chat_completion/models/anthropic.py  | 289 +++++++---
 generate/chat_completion/models/azure.py      |  34 +-
 generate/chat_completion/models/baichuan.py   | 161 +++---
 generate/chat_completion/models/bailian.py    | 226 --------
 generate/chat_completion/models/dashscope.py  | 201 +++++--
 .../models/dashscope_multimodal.py            |   2 +-
 generate/chat_completion/models/deepseek.py   |  29 +-
 generate/chat_completion/models/hunyuan.py    |   2 +-
 .../chat_completion/models/minimax_legacy.py  | 208 -------
 .../chat_completion/models/minimax_pro.py     |   2 +-
 generate/chat_completion/models/openai.py     |   6 +-
 .../chat_completion/models/openai_like.py     | 429 +++++++-------
 generate/chat_completion/models/openrouter.py | 142 +++++
 generate/chat_completion/models/wenxin.py     |   2 +-
 generate/chat_completion/models/yi.py         |  31 +-
 generate/chat_completion/models/zhipu.py      | 526 ++++++------------
 generate/chat_completion/stream_manager.py    |  11 +-
 generate/chat_completion/tool.py              |  25 +
 generate/http.py                              |  19 +-
 generate/model.py                             |   1 -
 generate/platforms/__init__.py                |   4 +-
 generate/platforms/baichuan.py                |   1 -
 generate/platforms/bailian.py                 |  37 --
 generate/platforms/dashscope.py               |   3 +
 generate/platforms/openrouter.py              |  12 +
 generate/platforms/yi.py                      |   2 +-
 generate/platforms/zhipu.py                   |   3 +-
 generate/types.py                             |   3 +-
 generate/version.py                           |   2 +-
 pyproject.toml                                |  15 +-
 39 files changed, 1280 insertions(+), 1414 deletions(-)
 create mode 100644 generate/chat_completion/message/converter.py
 delete mode 100644 generate/chat_completion/models/bailian.py
 delete mode 100644 generate/chat_completion/models/minimax_legacy.py
 create mode 100644 generate/chat_completion/models/openrouter.py
 delete mode 100644 generate/platforms/bailian.py
 create mode 100644 generate/platforms/openrouter.py

diff --git a/generate/__init__.py b/generate/__init__.py
index adbebbc..4ba466d 100644
--- a/generate/__init__.py
+++ b/generate/__init__.py
@@ -4,8 +4,6 @@
     AzureChat,
     BaichuanChat,
     BaichuanChatParameters,
-    BailianChat,
-    BailianChatParameters,
     ChatCompletionModel,
     ChatCompletionOutput,
     ChatModelRegistry,
@@ -19,8 +17,6 @@
     HunyuanChatParameters,
     MinimaxChat,
     MinimaxChatParameters,
-    MinimaxLegacyChat,
-    MinimaxLegacyChatParameters,
     MinimaxProChat,
     MinimaxProChatParameters,
     MoonshotChat,
@@ -35,8 +31,6 @@
     WenxinChatParameters,
     YiChat,
     YiChatParameters,
-    ZhipuCharacterChat,
-    ZhipuCharacterChatParameters,
     ZhipuChat,
     ZhipuChatParameters,
     tool,
@@ -67,7 +61,6 @@
     AzureSettings,
     BaichuanSettings,
     BaiduCreationSettings,
-    BailianSettings,
     DashScopeSettings,
     DeepSeekSettings,
     HunyuanSettings,
@@ -102,12 +95,8 @@
     'MinimaxChatParameters',
     'MinimaxProChat',
     'MinimaxProChatParameters',
-    'MinimaxLegacyChat',
-    'MinimaxLegacyChatParameters',
     'ZhipuChat',
     'ZhipuChatParameters',
-    'ZhipuCharacterChat',
-    'ZhipuCharacterChatParameters',
     'StepFunChat',
     'StepFunChatParameters',
'StepFunSettings', @@ -117,8 +106,6 @@ 'HunyuanChatParameters', 'BaichuanChat', 'BaichuanChatParameters', - 'BailianChat', - 'BailianChatParameters', 'DashScopeChat', 'DashScopeChatParameters', 'DashScopeMultiModalChat', @@ -161,7 +148,6 @@ 'ZhipuSettings', 'OpenAISettings', 'QianfanSettings', - 'BailianSettings', 'HunyuanSettings', 'DashScopeSettings', 'MoonshotSettings', diff --git a/generate/chat_completion/__init__.py b/generate/chat_completion/__init__.py index 3484b73..50b49e1 100644 --- a/generate/chat_completion/__init__.py +++ b/generate/chat_completion/__init__.py @@ -20,8 +20,6 @@ AzureChat, BaichuanChat, BaichuanChatParameters, - BailianChat, - BailianChatParameters, DashScopeChat, DashScopeChatParameters, DashScopeMultiModalChat, @@ -32,8 +30,6 @@ HunyuanChatParameters, MinimaxChat, MinimaxChatParameters, - MinimaxLegacyChat, - MinimaxLegacyChatParameters, MinimaxProChat, MinimaxProChatParameters, MoonshotChat, @@ -46,8 +42,6 @@ WenxinChatParameters, YiChat, YiChatParameters, - ZhipuCharacterChat, - ZhipuCharacterChatParameters, ZhipuChat, ZhipuChatParameters, ) @@ -61,13 +55,10 @@ (OpenAIChat, OpenAIChatParameters), (MinimaxChat, MinimaxChatParameters), (MinimaxProChat, MinimaxProChatParameters), - (MinimaxLegacyChat, MinimaxProChatParameters), (ZhipuChat, ZhipuChatParameters), - (ZhipuCharacterChat, ZhipuCharacterChatParameters), (WenxinChat, WenxinChatParameters), (HunyuanChat, HunyuanChatParameters), (BaichuanChat, BaichuanChatParameters), - (BailianChat, BailianChatParameters), (DashScopeChat, DashScopeChatParameters), (DashScopeMultiModalChat, DashScopeMultiModalChatParameters), (MoonshotChat, MoonshotChatParameters), @@ -92,22 +83,16 @@ 'MinimaxChatParameters', 'MinimaxProChat', 'MinimaxProChatParameters', - 'MinimaxLegacyChat', - 'MinimaxLegacyChatParameters', 'OpenAIChat', 'OpenAIChatParameters', 'ZhipuChat', 'ZhipuChatParameters', - 'ZhipuCharacterChat', - 'ZhipuCharacterChatParameters', 'WenxinChat', 'WenxinChatParameters', 'HunyuanChat', 'HunyuanChatParameters', 'BaichuanChat', 'BaichuanChatParameters', - 'BailianChat', - 'BailianChatParameters', 'YiChat', 'YiChatParameters', 'StepFunChat', diff --git a/generate/chat_completion/base.py b/generate/chat_completion/base.py index 6aa7f56..2407123 100644 --- a/generate/chat_completion/base.py +++ b/generate/chat_completion/base.py @@ -1,5 +1,7 @@ from __future__ import annotations +import contextlib +import json import logging from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, AsyncIterator, ClassVar, Iterator, List, Type, TypeVar, get_type_hints @@ -7,9 +9,14 @@ from pydantic import BaseModel from typing_extensions import Self, Unpack, override +from generate.chat_completion.cost_caculator import CostCalculator from generate.chat_completion.message import Prompt +from generate.chat_completion.message.converter import MessageConverter +from generate.chat_completion.message.core import Messages +from generate.chat_completion.message.utils import ensure_messages from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput from generate.chat_completion.stream_manager import StreamManager +from generate.chat_completion.tool import ToolCallMixin from generate.http import HttpClient, HttpxPostKwargs from generate.model import GenerateModel, ModelParameters from generate.platforms import PlatformSettings @@ -67,6 +74,7 @@ def hook(self, **kwargs: Unpack['HookModelKwargs']) -> 'HookChatCompletionModel' class RemoteChatCompletionModel(ChatCompletionModel, ABC): 
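+    """HTTP-backed chat-completion base class: subclasses plug in a
+    MessageConverter for prompt formatting and an optional CostCalculator
+    for token pricing."""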
settings: PlatformSettings http_client: HttpClient + message_converter: MessageConverter available_models: ClassVar[List[str]] = [] def __init__( @@ -75,28 +83,53 @@ def __init__( parameters: ModelParameters, settings: PlatformSettings, http_client: HttpClient, + message_converter: MessageConverter, + cost_calculator: CostCalculator | None = None, ) -> None: self.model = model self.parameters = parameters self.settings = settings self.http_client = http_client + self.message_converter = message_converter + self.cost_calculator = cost_calculator @abstractmethod def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: ... @abstractmethod - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: + def _process_stream_response( + self, response: dict[str, Any], stream_manager: StreamManager + ) -> ChatCompletionStreamOutput | None: ... @abstractmethod - def _get_request_parameters(self, prompt: Prompt, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs: + def _get_request_parameters(self, messages: Messages, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs: ... + def cost(self, input_tokens: int, output_tokens: int) -> float | None: + if self.cost_calculator is None: + return None + return self.cost_calculator.calculate( + model_name=self.model, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + def list_models(self) -> List[str]: + return self.available_models + + def process_prompt(self, prompt: Prompt) -> Messages: + messages = ensure_messages(prompt) + if isinstance(self, ToolCallMixin): + self.adapt_tool_calls(messages) + return messages + @override def generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput: timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None - request_parameters = self._get_request_parameters(prompt, **kwargs) + messages = self.process_prompt(prompt) + request_parameters = self._get_request_parameters(messages, **kwargs) request_parameters['timeout'] = timeout response = self.http_client.post(request_parameters=request_parameters) return self._process_reponse(response.json()) @@ -104,7 +137,8 @@ def generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput: @override async def async_generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput: timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None - request_parameters = self._get_request_parameters(prompt, **kwargs) + messages = self.process_prompt(prompt) + request_parameters = self._get_request_parameters(messages, **kwargs) request_parameters['timeout'] = timeout response = await self.http_client.async_post(request_parameters=request_parameters) return self._process_reponse(response.json()) @@ -112,21 +146,27 @@ async def async_generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionO @override def stream_generate(self, prompt: Prompt, **kwargs: Any) -> Iterator[ChatCompletionStreamOutput]: timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None - request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs) + messages = self.process_prompt(prompt) + request_parameters = self._get_request_parameters(messages, stream=True, **kwargs) request_parameters['timeout'] = timeout stream_manager = StreamManager(info=self.model_info) for line in self.http_client.stream_post(request_parameters=request_parameters): - if output := self._process_stream_line(line, stream_manager): - yield output + with 
contextlib.suppress(json.JSONDecodeError): + response = json.loads(line) + if (output := self._process_stream_response(response, stream_manager)) and output: + yield output @override async def async_stream_generate(self, prompt: Prompt, **kwargs: Any) -> AsyncIterator[ChatCompletionStreamOutput]: timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None - request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs) + messages = self.process_prompt(prompt) + request_parameters = self._get_request_parameters(messages, stream=True, **kwargs) request_parameters['timeout'] = timeout stream_manager = StreamManager(info=self.model_info) async for line in self.http_client.async_stream_post(request_parameters=request_parameters): - if output := self._process_stream_line(line, stream_manager): + with contextlib.suppress(json.JSONDecodeError): + response = json.loads(line) + if (output := self._process_stream_response(response, stream_manager)) and output: yield output @classmethod diff --git a/generate/chat_completion/cost_caculator.py b/generate/chat_completion/cost_caculator.py index d7e0c58..7be4636 100644 --- a/generate/chat_completion/cost_caculator.py +++ b/generate/chat_completion/cost_caculator.py @@ -1,26 +1,26 @@ from __future__ import annotations -# yuan per thousand tokens -DefaultPriceMap = { - 'moonshot': { - 'moonshot-v1-8k': (0.012, 0.012), - 'moonshot-v1-32k': (0.024, 0.024), - 'moonshot-v1-128k': (0.06, 0.06), - }, - 'minimax': { - 'abab5.5-chat': (0.015, 0.015), - 'abab5.5s-chat': (0.005, 0.005), - 'abab6-chat': (0.1, 0.1), - }, -} +from typing import Protocol +from generate.types import ModelPrice -class GeneralCostCalculator: - def __init__(self, price_map: dict[str, dict[str, tuple[float, float]]] | None = None) -> None: - self.price_map = price_map or DefaultPriceMap - def calculate(self, model_type: str, model_name: str, input_tokens: int, output_tokens: int) -> float | None: - if model_type in self.price_map and model_name in self.price_map[model_type]: - price = self.price_map[model_type][model_name] - return (input_tokens * price[0] + output_tokens * price[1]) / 1000 +class CostCalculator(Protocol): + def calculate(self, model_name: str, input_tokens: int, output_tokens: int) -> float | None: + ... + + +class GeneralCostCalculator(CostCalculator): + def __init__(self, model_price: ModelPrice, exchange_rate: float = 1) -> None: + # per million tokens + self.model_price = model_price + self.exchange_rate = exchange_rate + + def calculate(self, model_name: str, input_tokens: int, output_tokens: int) -> float | None: + if self.model_price is None: + return None + for model, (input_token_price, output_token_price) in self.model_price.items(): + if model in model_name: + cost = input_token_price * (input_tokens / 1_000_000) + output_token_price * (output_tokens / 1_000_000) + return cost * self.exchange_rate return None diff --git a/generate/chat_completion/message/converter.py b/generate/chat_completion/message/converter.py new file mode 100644 index 0000000..2dc84da --- /dev/null +++ b/generate/chat_completion/message/converter.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from generate.chat_completion.message.core import ( + AssistantMessage, + FunctionMessage, + Messages, + SystemMessage, + ToolMessage, + UserMessage, + UserMultiPartMessage, +) + + +class MessageConverter(Protocol): + def convert_user_message(self, message: UserMessage) -> dict[str, Any]: + ... 
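+
+    # Implementations map each framework message type to a provider-specific
+    # dict, e.g. UserMessage(content='hi') -> {'role': 'user', 'content': 'hi'}
+    # (see SimpleMessageConverter below); convert_message / convert_messages
+    # dispatch on the concrete message type.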
+ + def convert_assistant_message(self, message: AssistantMessage) -> dict[str, Any]: + ... + + def convert_function_message(self, message: FunctionMessage) -> dict[str, Any]: + ... + + def convert_tool_message(self, message: ToolMessage) -> dict[str, Any]: + ... + + def convert_system_message(self, message: SystemMessage) -> dict[str, Any]: + ... + + def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> dict[str, Any]: + ... + + def convert_message(self, message: Any) -> dict[str, Any]: + convert_methods = { + UserMessage: self.convert_user_message, + AssistantMessage: self.convert_assistant_message, + FunctionMessage: self.convert_function_message, + ToolMessage: self.convert_tool_message, + SystemMessage: self.convert_system_message, + UserMultiPartMessage: self.convert_user_multi_part_message, + } + return convert_methods[type(message)](message) + + def convert_messages(self, messages: Messages) -> list[dict[str, Any]]: + convert_methods = { + UserMessage: self.convert_user_message, + AssistantMessage: self.convert_assistant_message, + FunctionMessage: self.convert_function_message, + ToolMessage: self.convert_tool_message, + SystemMessage: self.convert_system_message, + UserMultiPartMessage: self.convert_user_multi_part_message, + } + return [convert_methods[type(message)](message) for message in messages] + + +class SimpleMessageConverter(MessageConverter): + def convert_system_message(self, message: SystemMessage) -> dict[str, Any]: + return { + 'role': 'system', + 'content': message.content, + } + + def convert_user_message(self, message: UserMessage) -> dict[str, Any]: + return { + 'role': 'user', + 'content': message.content, + } + + def convert_assistant_message(self, message: AssistantMessage) -> dict[str, Any]: + return { + 'role': 'assistant', + 'content': message.content, + } + + def convert_function_message(self, message: FunctionMessage) -> dict[str, Any]: + raise NotImplementedError('FunctionMessage is not supported by this converter') + + def convert_tool_message(self, message: ToolMessage) -> dict[str, Any]: + raise NotImplementedError('ToolMessage is not supported by this converter') + + def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> dict[str, Any]: + raise NotImplementedError('UserMultiPartMessage is not supported by this converter') diff --git a/generate/chat_completion/message/core.py b/generate/chat_completion/message/core.py index 9c757a2..7adad0a 100644 --- a/generate/chat_completion/message/core.py +++ b/generate/chat_completion/message/core.py @@ -41,13 +41,16 @@ class ImageUrlPart(BaseModel): class ImagePart(BaseModel): image: bytes - image_format: Optional[str] = None + image_format: str @classmethod - def from_url_or_path(cls, url_or_path: str | Path) -> Self: + def from_url_or_path(cls, url_or_path: str | Path, image_format: str | None = None) -> Self: image_data = fetch_data(str(url_or_path)) - mimetype = mimetypes.guess_type(url=str(url_or_path))[0] - image_format = mimetype.split('/')[1] if mimetype is not None else None + if image_format is None: + mimetype = mimetypes.guess_type(url=str(url_or_path))[0] + image_format = mimetype.split('/')[1] if mimetype is not None else None + if image_format is None: + raise ValueError(f'Cannot determine image format for {url_or_path}') return cls(image=image_data, image_format=image_format) @@ -66,6 +69,7 @@ class ToolMessage(Message): role: Literal['tool'] = 'tool' tool_call_id: str content: Optional[str] = None + is_error: bool = False class 
FunctionCall(BaseModel): @@ -90,6 +94,10 @@ class AssistantMessage(Message): def is_over(self) -> bool: return self.function_call is None and self.tool_calls is None + def model_post_init(self, __context: Any) -> None: + if not self.content and self.function_call is None and self.tool_calls is None: + raise ValueError('AssistantMessage must have content, function_call, or tool_calls') + UnionUserMessage = Union[UserMessage, UserMultiPartMessage] UnionUserPart = Union[TextPart, ImageUrlPart] diff --git a/generate/chat_completion/message/utils.py b/generate/chat_completion/message/utils.py index e21e27d..208a0d7 100644 --- a/generate/chat_completion/message/utils.py +++ b/generate/chat_completion/message/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Literal from generate.chat_completion.message.core import ( diff --git a/generate/chat_completion/model_output.py b/generate/chat_completion/model_output.py index ec1ac6b..67c7a7c 100644 --- a/generate/chat_completion/model_output.py +++ b/generate/chat_completion/model_output.py @@ -1,5 +1,6 @@ from __future__ import annotations +from enum import Enum from typing import Literal, Optional from pydantic import BaseModel @@ -8,9 +9,25 @@ from generate.model import ModelOutput +class FinishReason(str, Enum): + end_turn = 'end_turn' + stop = 'stop' + length = 'length' + content_filter = 'content_filter' + tool_calls = 'tool_calls' + funtion_call = 'function_call' + + +class Usage(BaseModel): + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + cost: Optional[float] = None + + class ChatCompletionOutput(ModelOutput): message: AssistantMessage - finish_reason: Optional[str] = None + usage: Optional[Usage] = None + finish_reason: Optional[FinishReason] = None @property def reply(self) -> str: diff --git a/generate/chat_completion/models/__init__.py b/generate/chat_completion/models/__init__.py index 44f5a2c..be04f36 100644 --- a/generate/chat_completion/models/__init__.py +++ b/generate/chat_completion/models/__init__.py @@ -1,10 +1,6 @@ from generate.chat_completion.models.anthropic import AnthropicChat, AnthropicChatParameters from generate.chat_completion.models.azure import AzureChat from generate.chat_completion.models.baichuan import BaichuanChat, BaichuanChatParameters -from generate.chat_completion.models.bailian import ( - BailianChat, - BailianChatParameters, -) from generate.chat_completion.models.dashscope import ( DashScopeChat, DashScopeChatParameters, @@ -16,19 +12,13 @@ from generate.chat_completion.models.deepseek import DeepSeekChat, DeepSeekChatParameters from generate.chat_completion.models.hunyuan import HunyuanChat, HunyuanChatParameters from generate.chat_completion.models.minimax import MinimaxChat, MinimaxChatParameters -from generate.chat_completion.models.minimax_legacy import MinimaxLegacyChat, MinimaxLegacyChatParameters from generate.chat_completion.models.minimax_pro import MinimaxProChat, MinimaxProChatParameters from generate.chat_completion.models.moonshot import MoonshotChat, MoonshotChatParameters from generate.chat_completion.models.openai import OpenAIChat, OpenAIChatParameters from generate.chat_completion.models.stepfun import StepFunChat, StepFunChatParameters from generate.chat_completion.models.wenxin import WenxinChat, WenxinChatParameters from generate.chat_completion.models.yi import YiChat, YiChatParameters -from generate.chat_completion.models.zhipu import ( - ZhipuCharacterChat, - ZhipuCharacterChatParameters, - ZhipuChat, - 
ZhipuChatParameters, -) +from generate.chat_completion.models.zhipu import ZhipuChat, ZhipuChatParameters __all__ = [ 'AzureChat', @@ -36,12 +26,8 @@ 'AnthropicChatParameters', 'BaichuanChat', 'BaichuanChatParameters', - 'BailianChat', - 'BailianChatParameters', 'HunyuanChat', 'HunyuanChatParameters', - 'MinimaxLegacyChat', - 'MinimaxLegacyChatParameters', 'MinimaxProChat', 'MinimaxProChatParameters', 'MinimaxChat', @@ -56,8 +42,6 @@ 'YiChatParameters', 'ZhipuChat', 'ZhipuChatParameters', - 'ZhipuCharacterChat', - 'ZhipuCharacterChatParameters', 'DashScopeChat', 'DashScopeChatParameters', 'DashScopeMultiModalChat', diff --git a/generate/chat_completion/models/anthropic.py b/generate/chat_completion/models/anthropic.py index 0f99a54..4570a24 100644 --- a/generate/chat_completion/models/anthropic.py +++ b/generate/chat_completion/models/anthropic.py @@ -2,36 +2,53 @@ import base64 import json -from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional +import uuid +from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional from pydantic import Field, PositiveInt from typing_extensions import Annotated, TypedDict, Unpack, override from generate.chat_completion.base import RemoteChatCompletionModel +from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import Prompt +from generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.message.core import ( AssistantMessage, + FunctionCall, + FunctionMessage, ImagePart, ImageUrlPart, - Message, + Messages, SystemMessage, TextPart, + ToolCall, + ToolMessage, UserMessage, UserMultiPartMessage, ) from generate.chat_completion.message.exception import MessageTypeError -from generate.chat_completion.message.utils import ensure_messages -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput +from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage from generate.chat_completion.stream_manager import StreamManager +from generate.chat_completion.tool import Tool, ToolCallMixin from generate.http import HttpClient, HttpxPostKwargs from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms import AnthropicSettings -from generate.types import Probability, Temperature +from generate.types import ModelPrice, OrIterable, Probability, Temperature +from generate.utils import ensure_iterable +AnthropicModelPrice: ModelPrice = { + 'claude-instant': (0.8, 2.4), + 'claude-2': (8, 24), + 'claude-3-haiku': (0.25, 1.25), + 'claude-3-sonnet': (3, 15), + 'claude-3-opus': (15, 75), +} -class AnthropicMessage(TypedDict): - role: Literal['user', 'assistant'] - content: str + +class AnthropicTool(TypedDict): + name: str + description: Optional[str] + input_schema: Dict[str, Any] class AnthropicChatParameters(ModelParameters): @@ -42,6 +59,8 @@ class AnthropicChatParameters(ModelParameters): temperature: Optional[Temperature] = None top_p: Optional[Probability] = None top_k: Optional[PositiveInt] = None + tools: Optional[List[AnthropicTool]] = None + tool_choice: Optional[str] = None class AnthropicParametersDict(RemoteModelParametersDict, total=False): @@ -52,9 +71,88 @@ class AnthropicParametersDict(RemoteModelParametersDict, total=False): temperature: Optional[Temperature] top_p: Optional[Probability] top_k: Optional[PositiveInt] + tools: Optional[List[AnthropicTool]] + 
tool_choice: Optional[str] + + +class AnthropicMessageConverter(MessageConverter): + def __init__(self, http_client: HttpClient) -> None: + super().__init__() + self.http_client = http_client + self.handle_tool_choice = True + + def convert_user_message(self, message: UserMessage) -> Dict[str, Any]: + return {'role': 'user', 'content': message.content} + + def convert_assistant_message(self, message: AssistantMessage) -> Dict[str, Any]: + content = [] + if message.content: + content.append({'type': 'text', 'text': message.content}) + if message.tool_calls: + for tool_call in message.tool_calls: + content.append( + { + 'type': 'tool_use', + 'id': tool_call.id, + 'name': tool_call.function.name, + 'input': json.loads(tool_call.function.arguments), + } + ) + return {'role': 'assistant', 'content': content} + + def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> Dict[str, Any]: + message_dict = {'role': 'user', 'content': []} + for part in message.content: + if isinstance(part, TextPart): + message_dict['content'].append({'type': 'text', 'text': part.text}) + + if isinstance(part, ImagePart): + data = base64.b64encode(part.image).decode() + media_type = 'image/jpeg' if part.image_format is None else f'image/{part.image_format}' + message_dict['content'].append( + {'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}} + ) + + if isinstance(part, ImageUrlPart): + response = self.http_client.get({'url': part.image_url.url}) + data = base64.b64encode(response.content).decode() + media_type = response.headers.get('Content-Type') or 'image/jpeg' + message_dict['content'].append( + {'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}} + ) + return message_dict + def convert_system_message(self, message: SystemMessage) -> Dict[str, Any]: + raise MessageTypeError(message, (UserMessage, AssistantMessage, UserMultiPartMessage)) + + def convert_function_message(self, message: FunctionMessage) -> Dict[str, Any]: + raise MessageTypeError(message, (UserMessage, AssistantMessage, UserMultiPartMessage)) + + def convert_tool_message(self, message: ToolMessage) -> Dict[str, Any]: + tool_result: dict = { + 'type': 'tool_result', + 'tool_use_id': message.tool_call_id, + } + if message.content: + tool_result['content'] = message.content + if message.is_error: + tool_result['is_error'] = True + return { + 'role': 'user', + 'content': [tool_result], + } -class AnthropicChat(RemoteChatCompletionModel): + def convert_messages(self, messages: Messages, tool_choice: str | None = None) -> List[Dict[str, Any]]: + messages_dict = super().convert_messages(messages) + if tool_choice and self.handle_tool_choice: + for message_dict in messages_dict[::-1]: + if message_dict['role'] == 'user': + message_dict['content'] += f'\nUse the {tool_choice} tool in your response.' 
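+                    # nudge only the most recent user message; assumes its
+                    # content is plain text rather than multi-part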
+ break + return messages_dict + + +class AnthropicChat(RemoteChatCompletionModel, ToolCallMixin): model_type: ClassVar[str] = 'anthropic' available_models: ClassVar[List[str]] = [ 'claude-2.1', @@ -67,6 +165,7 @@ class AnthropicChat(RemoteChatCompletionModel): parameters: AnthropicChatParameters settings: AnthropicSettings + message_converter: AnthropicMessageConverter def __init__( self, @@ -74,11 +173,22 @@ def __init__( parameters: AnthropicChatParameters | None = None, settings: AnthropicSettings | None = None, http_client: HttpClient | None = None, + message_converter: AnthropicMessageConverter | None = None, + cost_calculator: CostCalculator | None = None, ) -> None: parameters = parameters or AnthropicChatParameters() settings = settings or AnthropicSettings() # type: ignore http_client = http_client or HttpClient() - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) + message_converter = message_converter or AnthropicMessageConverter(http_client) + cost_calculator = cost_calculator or GeneralCostCalculator(AnthropicModelPrice, exchange_rate=7.3) + super().__init__( + model=model, + parameters=parameters, + settings=settings, + http_client=http_client, + message_converter=message_converter, + cost_calculator=cost_calculator, + ) @override def generate(self, prompt: Prompt, **kwargs: Unpack[AnthropicParametersDict]) -> ChatCompletionOutput: @@ -101,57 +211,32 @@ async def async_stream_generate( async for output in super().async_stream_generate(prompt, **kwargs): yield output - def _convert_message(self, message: Message) -> dict[str, str]: - if isinstance(message, UserMessage): - return {'role': 'user', 'content': message.content} - if isinstance(message, AssistantMessage): - return {'role': 'assistant', 'content': message.content} - if isinstance(message, UserMultiPartMessage): - message_dict = {'role': 'user', 'content': []} - for part in message.content: - if isinstance(part, TextPart): - message_dict['content'].append({'type': 'text', 'text': part.text}) - - if isinstance(part, ImagePart): - data = base64.b64encode(part.image).decode() - media_type = 'image/jpeg' if part.image_format is None else f'image/{part.image_format}' - message_dict['content'].append( - {'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}} - ) - - if isinstance(part, ImageUrlPart): - response = self.http_client.get({'url': part.image_url.url}) - data = base64.b64encode(response.content).decode() - media_type = response.headers.get('Content-Type') or 'image/jpeg' - message_dict['content'].append( - {'type': 'image', 'source': {'type': 'base64', 'media_type': media_type, 'data': data}} - ) - return message_dict - raise MessageTypeError(message, (UserMessage, AssistantMessage, UserMultiPartMessage)) - @override def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[AnthropicParametersDict] + self, messages: Messages, stream: bool = False, **kwargs: Unpack[AnthropicParametersDict] ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) parameters = self.parameters.clone_with_changes(**kwargs) - if isinstance(messages[0], SystemMessage): parameters.system = messages[0].content messages = messages[1:] + anthropic_messages = self.message_converter.convert_messages(messages, parameters.tool_choice) + headers = { + 'Content-Type': 'application/json', + 'anthropic-version': self.settings.api_version, + 'x-api-key': self.settings.api_key.get_secret_value(), + } + if tool_use := 
bool(parameters.tools): + headers['anthropic-beta'] = 'tools-2024-04-04' - anthropic_messages = [self._convert_message(message) for message in messages] json_dict = parameters.custom_model_dump() json_dict['model'] = self.model json_dict['messages'] = anthropic_messages + if stream: + if tool_use: + raise ValueError('Tool calls are not supported in stream mode') json_dict['stream'] = True - headers = { - 'Content-Type': 'application/json', - 'anthropic-version': self.settings.api_version, - 'x-api-key': self.settings.api_key.get_secret_value(), - } return { 'url': self.settings.api_base + '/messages', 'headers': headers, @@ -160,54 +245,86 @@ def _get_request_parameters( @override def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: - content = '' - for i in response['content']: - if i['type'] == 'text': - content += i['text'] return ChatCompletionOutput( model_info=self.model_info, - message=AssistantMessage(content=content), - finish_reason=response['stop_reason'], - cost=self._calculate_cost(**response['usage']), - extra={'usage': response['usage'], 'message_id': response['id']}, + message=self._parse_assistant_message(response), + finish_reason=self._parse_finish_reason(response), + usage=self._parse_usage(response), + extra=self._parse_extra(response), ) - def _calculate_cost(self, input_tokens: int, output_tokens: int) -> float | None: - model_price_mapping = { - 'claude-instant': (0.80, 2.40), - 'claude-2': (8, 24), - 'claude-3-haiku': (0.25, 1.25), - 'claude-3-sonnet': (3, 15), - 'claude-3-opus': (15, 75), - } - dollar_to_yuan = 7 - for model_name, (prompt_price, completion_price) in model_price_mapping.items(): - if model_name in self.model: - cost = (input_tokens * prompt_price / 1_000_000) + (output_tokens * completion_price / 1_000_000) - return cost * dollar_to_yuan - return None - @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - try: - data = json.loads(line) - except json.JSONDecodeError: + def _process_stream_response( + self, response: dict[str, Any], stream_manager: StreamManager + ) -> ChatCompletionStreamOutput | None: + if 'message' in response: + input_tokens = response['message']['usage']['input_tokens'] + stream_manager.usage.input_tokens = input_tokens return None - if 'message' in data: - input_tokens = data['message']['usage']['input_tokens'] - stream_manager.extra.setdefault('usage', {}).update({'input_tokens': input_tokens}) - return None - - if 'delta' in data: - if 'stop_reason' in data['delta']: - delta_dict = data['delta'] + if 'delta' in response: + if 'stop_reason' in response['delta']: + delta_dict = response['delta'] stream_manager.delta = '' - stream_manager.finish_reason = delta_dict['stop_reason'] - stream_manager.extra['usage']['output_tokens'] = data['usage']['output_tokens'] - stream_manager.cost = self._calculate_cost(**stream_manager.extra['usage']) + stream_manager.finish_reason = self._parse_finish_reason(delta_dict) + stream_manager.usage.output_tokens = response['usage']['output_tokens'] + if stream_manager.usage.input_tokens is not None: + stream_manager.usage.cost = self.cost(stream_manager.usage.input_tokens, stream_manager.usage.output_tokens) return stream_manager.build_stream_output() - stream_manager.delta = data['delta']['text'] + stream_manager.delta = response['delta']['text'] return stream_manager.build_stream_output() return None + + @override + def add_tools(self, tools: OrIterable[Tool]) -> None: + new_tools = [ + 
AnthropicTool(name=tool.name, description=tool.description, input_schema=tool.parameters) + for tool in ensure_iterable(tools) + ] + if self.parameters.tools is None: + self.parameters.tools = new_tools + else: + self.parameters.tools.extend(new_tools) + + @override + def generate_tool_call_id(self, function_call: FunctionCall) -> str: + return f'toolu_{uuid.uuid4().hex}' + + def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage: + content = '' + tool_calls = [] + for i in response['content']: + if i['type'] == 'text': + content += i['text'] + if i['type'] == 'tool_use': + tool_call = ToolCall( + id=i['id'], function=FunctionCall(name=i['name'], arguments=json.dumps(i['input'], ensure_ascii=False)) + ) + tool_calls.append(tool_call) + tool_calls = tool_calls or None + return AssistantMessage(content=content, tool_calls=tool_calls) + + def _parse_usage(self, response: dict[str, Any]) -> Usage: + if 'usage' not in response: + return Usage() + + input_tokens = response['usage']['input_tokens'] + output_tokens = response['usage']['output_tokens'] + cost = self.cost(input_tokens, output_tokens) + return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost) + + def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: + finish_reason_mapping = { + 'end_turn': 'end_turn', + 'max_tokens': 'length', + 'stop_sequence': 'stop', + 'tool_use': 'tool_calls', + } + finish_reason = finish_reason_mapping.get(response['stop_reason']) + return FinishReason(finish_reason) if finish_reason else None + + def _parse_extra(self, response: dict[str, Any]) -> Dict[str, Any]: + return { + 'response': response, + } diff --git a/generate/chat_completion/models/azure.py b/generate/chat_completion/models/azure.py index 4062866..abd2bad 100644 --- a/generate/chat_completion/models/azure.py +++ b/generate/chat_completion/models/azure.py @@ -1,24 +1,25 @@ from __future__ import annotations -from typing import Any, AsyncIterator, ClassVar, Iterator +from typing import Any, AsyncIterator, ClassVar, Dict, Iterator from typing_extensions import Unpack, override -from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.message import Prompt, ensure_messages +from generate.chat_completion.message import Prompt +from generate.chat_completion.message.core import Messages from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput from generate.chat_completion.models.openai import OpenAIChatParameters, OpenAIChatParametersDict -from generate.chat_completion.models.openai_like import convert_to_openai_message, process_openai_like_model_reponse +from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAIMessageConverter from generate.chat_completion.stream_manager import StreamManager from generate.http import HttpClient, HttpxPostKwargs from generate.platforms.azure import AzureSettings -class AzureChat(RemoteChatCompletionModel): +class AzureChat(OpenAILikeChat): model_type: ClassVar[str] = 'azure' parameters: OpenAIChatParameters settings: AzureSettings + message_converter: OpenAIMessageConverter def __init__( self, @@ -26,6 +27,7 @@ def __init__( parameters: OpenAIChatParameters | None = None, settings: AzureSettings | None = None, http_client: HttpClient | None = None, + message_converter: OpenAIMessageConverter | None = None, ) -> None: parameters = parameters or OpenAIChatParameters() settings = settings or AzureSettings() # type: ignore @@ -33,7 +35,10 @@ def 
__init__( model = model or settings.chat_api_engine if model is None: raise ValueError('model must be provided or set in settings.chat_api_engine') - super().__init__(model, parameters=parameters, settings=settings, http_client=http_client) + message_converter = message_converter or OpenAIMessageConverter() + super().__init__( + model=model, parameters=parameters, settings=settings, http_client=http_client, message_converter=message_converter + ) @override def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: @@ -56,14 +61,11 @@ def async_stream_generate( raise NotImplementedError('Azure does not support streaming') @override - def _get_request_parameters(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> HttpxPostKwargs: - messages = ensure_messages(prompt) + def _get_request_parameters(self, messages: Messages, **kwargs: Unpack[OpenAIChatParametersDict]) -> HttpxPostKwargs: parameters = self.parameters.clone_with_changes(**kwargs) - - openai_messages = [convert_to_openai_message(message) for message in messages] json_data = { 'model': self.model, - 'messages': openai_messages, + 'messages': self.message_converter.convert_messages(messages), **parameters.custom_model_dump(), } headers = { @@ -77,9 +79,7 @@ def _get_request_parameters(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatPar } @override - def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: - return process_openai_like_model_reponse(response, model_type=self.model_type) - - @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - raise NotImplementedError + def _process_stream_response( + self, response: Dict[str, Any], stream_manager: StreamManager + ) -> ChatCompletionStreamOutput | None: + raise NotImplementedError('Azure does not support streaming') diff --git a/generate/chat_completion/models/baichuan.py b/generate/chat_completion/models/baichuan.py index d2d15ee..fab53df 100644 --- a/generate/chat_completion/models/baichuan.py +++ b/generate/chat_completion/models/baichuan.py @@ -1,45 +1,41 @@ from __future__ import annotations -import json -from datetime import datetime -from typing import AsyncIterator, ClassVar, Iterator, List, Literal, Optional +from typing import Any, AsyncIterator, ClassVar, Iterator, List, Optional from pydantic import Field -from typing_extensions import Annotated, TypedDict, Unpack, override +from typing_extensions import Annotated, Unpack, override from generate.chat_completion.base import RemoteChatCompletionModel +from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import ( AssistantMessage, - Message, Messages, - MessageTypeError, Prompt, SystemMessage, UserMessage, - ensure_messages, ) -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput +from generate.chat_completion.message.converter import SimpleMessageConverter +from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage from generate.chat_completion.stream_manager import StreamManager from generate.http import ( HttpClient, HttpxPostKwargs, - ResponseValue, - UnexpectedResponseError, ) from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms.baichuan import BaichuanSettings -from generate.types import Probability, Temperature +from generate.types import 
ModelPrice, Probability, Temperature - -class BaichuanMessage(TypedDict): - role: Literal['user', 'assistant'] - content: str +BaichuanModelPrice: ModelPrice = { + 'Baichuan2-Turbo-192k': (16, 16), + 'Baichuan2-Turbo': (8, 8), +} class BaichuanChatParameters(ModelParameters): temperature: Optional[Temperature] = None top_k: Optional[Annotated[int, Field(ge=0)]] = None top_p: Optional[Probability] = None + max_tokens: Optional[Annotated[int, Field(ge=0)]] = None search: Optional[bool] = Field(default=None, alias='with_search_enhance') @@ -47,15 +43,17 @@ class BaichuanChatParametersDict(RemoteModelParametersDict, total=False): temperature: Optional[Temperature] top_k: Optional[int] top_p: Optional[Probability] + max_tokens: Optional[int] search: Optional[bool] class BaichuanChat(RemoteChatCompletionModel): model_type: ClassVar[str] = 'baichuan' - available_models: ClassVar[List[str]] = ['Baichuan2-Turbo', 'Baichuan2-53B', 'Baichuan2-Turbo-192k'] + available_models: ClassVar[List[str]] = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k'] parameters: BaichuanChatParameters settings: BaichuanSettings + message_converter: SimpleMessageConverter def __init__( self, @@ -63,11 +61,22 @@ def __init__( parameters: BaichuanChatParameters | None = None, settings: BaichuanSettings | None = None, http_client: HttpClient | None = None, + message_converter: SimpleMessageConverter | None = None, + cost_calculator: CostCalculator | None = None, ) -> None: parameters = parameters or BaichuanChatParameters() settings = settings or BaichuanSettings() # type: ignore http_client = http_client or HttpClient() - super().__init__(model, parameters=parameters, settings=settings, http_client=http_client) + message_converter = message_converter or SimpleMessageConverter() + cost_calculator = cost_calculator or GeneralCostCalculator(BaichuanModelPrice) + super().__init__( + model=model, + parameters=parameters, + settings=settings, + http_client=http_client, + message_converter=message_converter, + cost_calculator=cost_calculator, + ) @override def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput: @@ -92,19 +101,20 @@ async def async_stream_generate( @override def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[BaichuanChatParametersDict] + self, messages: Messages, stream: bool = False, **kwargs: Unpack[BaichuanChatParametersDict] ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) + if isinstance(system_message := messages[0], SystemMessage): + prepend_messages = [UserMessage(content=system_message.content)] + messages = prepend_messages + messages[1:] parameters = self.parameters.clone_with_changes(**kwargs) - baichuan_messages: list[BaichuanMessage] = self._convert_messages(messages) - data = { + json_data = { 'model': self.model, - 'messages': baichuan_messages, + 'messages': self.message_converter.convert_messages(messages), } parameters_dict = parameters.custom_model_dump() - data.update(parameters_dict) + json_data.update(parameters_dict) if stream: - data['stream'] = True + json_data['stream'] = True headers = { 'Content-Type': 'application/json', 'Authorization': 'Bearer ' + self.settings.api_key.get_secret_value(), @@ -112,77 +122,48 @@ def _get_request_parameters( return { 'url': self.settings.api_base + '/chat/completions', 'headers': headers, - 'json': data, + 'json': json_data, } - @staticmethod - def _convert_message(message: Message) -> BaichuanMessage: - if isinstance(message, UserMessage): - return { - 'role': 
'user', - 'content': message.content, - } - - if isinstance(message, AssistantMessage): - return { - 'role': 'assistant', - 'content': message.content, - } - raise MessageTypeError(message, (UserMessage, AssistantMessage)) - - def _convert_messages(self, messages: Messages) -> list[BaichuanMessage]: - if isinstance(system_message := messages[0], SystemMessage): - prepend_messages = [UserMessage(content=system_message.content), AssistantMessage(content='好的')] - messages = prepend_messages + messages[1:] - return [self._convert_message(message) for message in messages] - @override - def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - try: - text = response['choices'][0]['message']['content'] - finish_reason = response['choices'][0]['finish_reason'] or None - usage = response.get('usage') - extra = {'id': response['id']} - if usage is not None: - cost = self._calculate_cost(usage['total_tokens']) - extra['usage'] = usage - else: - cost = None - return ChatCompletionOutput( - model_info=self.model_info, - message=AssistantMessage(content=text), - finish_reason=finish_reason, - cost=cost, - extra=extra, - ) - except (KeyError, IndexError) as e: - raise UnexpectedResponseError(response) from e + def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: + return ChatCompletionOutput( + model_info=self.model_info, + message=self._parse_assistant_message(response), + finish_reason=self._parse_finish_reason(response), + usage=self._parse_usage(response), + extra=self._parse_extra(response), + ) @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> Optional[ChatCompletionStreamOutput]: + def _process_stream_response( + self, response: dict[str, Any], stream_manager: StreamManager + ) -> ChatCompletionStreamOutput | None: + stream_manager.delta = response['choices'][0]['delta']['content'] + stream_manager.finish_reason = self._parse_finish_reason(response) + stream_manager.extra = self._parse_extra(response) + stream_manager.usage = self._parse_usage(response) + return stream_manager.build_stream_output() + + def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage: + return AssistantMessage(content=response['choices'][0]['message']['content']) + + def _parse_usage(self, response: dict[str, Any]) -> Usage: + usage = response.get('usage') + if usage is not None: + input_tokens = usage['prompt_tokens'] + output_tokens = usage['completion_tokens'] + cost = self.cost(input_tokens, output_tokens) + return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost) + return Usage() + + def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: try: - data = json.loads(line) - except json.JSONDecodeError: + choice = response['choices'][0] + if finish_reason := choice.get('finish_reason'): + return FinishReason(finish_reason) + except (KeyError, IndexError, ValueError): return None - stream_manager.delta = data['choices'][0]['delta']['content'] - stream_manager.finish_reason = data['choices'][0].get('finish_reason') or None - stream_manager.extra['id'] = data['id'] - usage = data.get('usage') - if usage: - cost = self._calculate_cost(usage['total_tokens']) - stream_manager.extra['usage'] = usage - stream_manager.cost = cost - return stream_manager.build_stream_output() - - def _calculate_cost(self, total_tokens: int) -> float | None: - if self.model == 'Baichuan2-53B': - eight_am = 8 - if 0 <= datetime.now().hour < eight_am: - return (total_tokens * 0.01) / 1000 - return 
(total_tokens * 0.02) / 1000 - if self.model == 'Baichuan2-Turbo': - return (total_tokens * 0.008) / 1000 - if self.model == 'Baichuan2-Turbo-192k': - return (total_tokens * 0.016) / 1000 - return None + def _parse_extra(self, response: dict[str, Any]) -> dict[str, Any]: + return {'response': response} diff --git a/generate/chat_completion/models/bailian.py b/generate/chat_completion/models/bailian.py deleted file mode 100644 index 1f4698b..0000000 --- a/generate/chat_completion/models/bailian.py +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from typing import Any, AsyncIterator, ClassVar, Iterator, List, Literal, Optional - -from pydantic import Field, PositiveInt -from typing_extensions import Annotated, Self, TypedDict, Unpack, override - -from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.message import ( - AssistantMessage, - Message, - MessageTypeError, - Prompt, - SystemMessage, - UserMessage, - ensure_messages, -) -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.stream_manager import StreamManager -from generate.http import ( - HttpClient, - HttpxPostKwargs, - ResponseValue, - UnexpectedResponseError, -) -from generate.model import ModelParameters, RemoteModelParametersDict -from generate.platforms.bailian import BailianSettings, BailianTokenManager -from generate.types import Probability - - -def generate_default_request_id() -> str: - uuid_obj = uuid.uuid4() - return str(uuid_obj).replace('-', '') - - -class BailianMessage(TypedDict): - Role: Literal['user', 'assistant', 'system'] - Content: str - - -class BailianChatParameters(ModelParameters): - request_id: str = Field(default_factory=generate_default_request_id, alias='RequestId') - top_p: Optional[Probability] = Field(default=None, alias='TopP') - top_k: Optional[Annotated[int, Field(ge=0)]] = None - seed: Optional[int] = None - temperature: Optional[Annotated[float, Field(ge=0, le=2)]] = None - max_tokens: Optional[PositiveInt] = None - stop: Optional[List[str]] = None - - def custom_model_dump(self) -> dict[str, Any]: - output = super().custom_model_dump() - parameters = {} - if 'top_k' in output: - parameters['TopK'] = output.pop('top_k') - if 'seed' in output: - parameters['Seed'] = output.pop('seed') - if 'temperature' in output: - parameters['Temperature'] = output.pop('temperature') - if 'max_tokens' in output: - parameters['MaxTokens'] = output.pop('max_tokens') - if 'stop' in output: - parameters['Stop'] = output.pop('stop') - if parameters: - output['Parameters'] = parameters - return output - - -class BailianChatParametersDict(RemoteModelParametersDict, total=False): - request_id: str - top_p: Optional[Probability] - top_k: Optional[int] - seed: Optional[int] - temperature: Optional[float] - max_tokens: Optional[PositiveInt] - stop: Optional[List[str]] - - -class BailianChat(RemoteChatCompletionModel): - model_type: ClassVar[str] = 'bailian' - - parameters: BailianChatParameters - settings: BailianSettings - - def __init__( - self, - app_id: str | None = None, - parameters: BailianChatParameters | None = None, - settings: BailianSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or BailianChatParameters() - settings = settings or BailianSettings() # type: ignore - http_client = http_client or HttpClient() - self.app_id = app_id or settings.default_app_id - 
super().__init__(model=self.app_id, parameters=parameters, settings=settings, http_client=http_client) - - self.token_manager = BailianTokenManager(self.settings, self.http_client) - - @override - def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) - - @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) - - @override - def stream_generate( - self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict] - ) -> Iterator[ChatCompletionStreamOutput]: - yield from super().stream_generate(prompt, **kwargs) - - @override - async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict] - ) -> AsyncIterator[ChatCompletionStreamOutput]: - async for output in super().async_stream_generate(prompt, **kwargs): - yield output - - @classmethod - @override - def from_name(cls, name: str) -> Self: - return cls(app_id=name) - - @staticmethod - def _convert_message(message: Message) -> BailianMessage: - if isinstance(message, UserMessage): - return {'Role': 'user', 'Content': message.content} - if isinstance(message, AssistantMessage): - return {'Role': 'assistant', 'Content': message.content} - if isinstance(message, SystemMessage): - return {'Role': 'system', 'Content': message.content} - raise MessageTypeError(message, (UserMessage, AssistantMessage, SystemMessage)) - - @override - def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[BailianChatParametersDict] - ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - - if not isinstance(messages[-1], UserMessage): - raise MessageTypeError(messages[-1], allowed_message_type=(UserMessage,)) - - json_dict = parameters.custom_model_dump() - headers = { - 'Content-Type': 'application/json;charset=UTF-8', - 'Authorization': f'Bearer {self.token_manager.token}', - } - json_dict['AppId'] = self.app_id - json_dict['Messages'] = [self._convert_message(i) for i in messages] - - if stream: - headers['Accept'] = 'text/event-stream' - json_dict['Stream'] = True - - return { - 'url': self.settings.completion_api, - 'headers': headers, - 'json': json_dict, - } - - @override - def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - if not response['Success']: - raise UnexpectedResponseError(response) - response_data = response['Data'] - total_tokens = response_data['Usage'][0]['InputTokens'] + response_data['Usage'][0]['OutputTokens'] - model_id = response_data['Usage'][0]['ModelId'] - return ChatCompletionOutput( - model_info=self.model_info, - cost=self._calculate_cost(model_id=model_id, total_tokens=total_tokens), - finish_reason=response_data.get('FinishReason'), - message=AssistantMessage(content=response_data['Text']), - extra={ - 'thoughts': response_data.get('Thoughts'), - 'doc_references': response_data.get('DocReferences'), - 'request_id': response['RequestId'], - 'response_id': response_data.get('ResponseId'), - 'usage': response_data['Usage'][0], - }, - ) - - @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - try: - data = json.loads(line) - except json.JSONDecodeError: - return None - - response_data = data['Data'] - reply: str = response_data['Text'] - stream_manager.extra.update( - { - 
'thoughts': response_data.get('Thoughts'), - 'doc_references': response_data.get('DocReferences'), - 'response_id': response_data.get('ResponseId'), - } - ) - - is_finish = len(reply) == len(stream_manager.content) - if is_finish: - total_tokens = response_data['Usage'][0]['InputTokens'] + response_data['Usage'][0]['OutputTokens'] - model_id = response_data['Usage'][0]['ModelId'] - stream_manager.delta = '' - stream_manager.extra['usage'] = response_data['Usage'][0] - stream_manager.extra['response_id'] = response_data.get('ResponseId') - stream_manager.cost = self._calculate_cost(model_id=model_id, total_tokens=total_tokens) - stream_manager.finish_reason = 'stop' - return stream_manager.build_stream_output() - - delta = reply[len(stream_manager.content) :] - stream_manager.delta = delta - return stream_manager.build_stream_output() - - @staticmethod - def _calculate_cost(model_id: str, total_tokens: int) -> float | None: - if model_id == 'qwen-turbo': - return 0.008 * total_tokens / 1000 - if model_id == 'qwen-plus': - return 0.012 * total_tokens / 1000 - if model_id == 'qwen-max': - return 0.12 * total_tokens / 1000 - return None diff --git a/generate/chat_completion/models/dashscope.py b/generate/chat_completion/models/dashscope.py index 3ad02c7..e12d60b 100644 --- a/generate/chat_completion/models/dashscope.py +++ b/generate/chat_completion/models/dashscope.py @@ -1,23 +1,21 @@ from __future__ import annotations -import json -from typing import AsyncIterator, ClassVar, Iterator, List, Literal, Optional, Dict +from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional from pydantic import Field, PositiveInt -from typing_extensions import Annotated, TypedDict, Unpack, override +from typing_extensions import Annotated, Unpack, override from generate.chat_completion.base import RemoteChatCompletionModel from generate.chat_completion.message import ( AssistantMessage, - Message, - MessageTypeError, Prompt, - SystemMessage, - UserMessage, - ensure_messages, ) -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput +from generate.chat_completion.message.converter import SimpleMessageConverter +from generate.chat_completion.message.core import FunctionCall, FunctionMessage, Messages, ToolCall, ToolMessage +from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage +from generate.chat_completion.models.openai_like import OpenAITool, convert_to_openai_tool from generate.chat_completion.stream_manager import StreamManager +from generate.chat_completion.tool import Tool, ToolCallMixin from generate.http import ( HttpClient, HttpxPostKwargs, @@ -25,12 +23,8 @@ ) from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms.dashscope import DashScopeSettings -from generate.types import Probability - - -class DashscopeMessage(TypedDict): - role: Literal['user', 'assistant', 'system'] - content: str +from generate.types import OrIterable, Probability +from generate.utils import ensure_iterable class DashScopeChatParameters(ModelParameters): @@ -42,6 +36,7 @@ class DashScopeChatParameters(ModelParameters): temperature: Optional[Annotated[float, Field(gt=0, le=2)]] = None stop: Optional[List[str]] = None search: Annotated[Optional[bool], Field(alias='enable_search')] = None + tools: Optional[List[OpenAITool]] = None class DashScopeChatParametersDict(RemoteModelParametersDict, total=False): @@ -53,26 +48,67 @@ class 
DashScopeChatParametersDict(RemoteModelParametersDict, total=False):
     temperature: Optional[Annotated[float, Field(gt=0, le=2)]]
     stop: Optional[List[str]]
     search: Optional[bool]
+    tools: Optional[List[OpenAITool]]
+
+
+class DashScopeMessageConverter(SimpleMessageConverter):
+    def convert_function_message(self, message: FunctionMessage) -> Dict[str, Any]:
+        return {
+            'role': 'tool',
+            'name': message.name,
+            'content': message.content,
+        }
+
+    def convert_assistant_message(self, message: AssistantMessage) -> Dict[str, Any]:
+        base_dict = {
+            'role': 'assistant',
+            'content': message.content or None,
+        }
+        if message.tool_calls:
+            tool_calls = [
+                {
+                    'type': 'function',
+                    'function': {
+                        'name': tool_call.function.name,
+                        'arguments': tool_call.function.arguments,
+                    },
+                }
+                for tool_call in message.tool_calls
+            ]
+            base_dict['tool_calls'] = tool_calls
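+        # Legacy single function_call: emit the same OpenAI-style tool_calls
+        # structure as above (assumes DashScope's OpenAI-compatible tool schema).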
+        if message.function_call:
+            base_dict['tool_calls'] = [
+                {
+                    'type': 'function',
+                    'function': {
+                        'name': message.function_call.name,
+                        'arguments': message.function_call.arguments,
+                    },
+                }
+            ]
+        return base_dict
 
 
-class DashScopeChat(RemoteChatCompletionModel):
+class DashScopeChat(RemoteChatCompletionModel, ToolCallMixin):
     model_type: ClassVar[str] = 'dashscope'
-    available_models: ClassVar[List[str]] = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-1201', 'qwen-max-longcontext']
+    available_models: ClassVar[List[str]] = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext']
 
     parameters: DashScopeChatParameters
     settings: DashScopeSettings
+    message_converter: DashScopeMessageConverter
 
     def __init__(
         self,
-        model: str = 'qwen-max',
+        model: str = 'qwen-plus',
         parameters: DashScopeChatParameters | None = None,
         settings: DashScopeSettings | None = None,
         http_client: HttpClient | None = None,
+        message_converter: DashScopeMessageConverter | None = None,
     ) -> None:
         parameters = parameters or DashScopeChatParameters()
         settings = settings or DashScopeSettings()  # type: ignore
         http_client = http_client or HttpClient()
-        super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client)
+        message_converter = message_converter or DashScopeMessageConverter()
+        super().__init__(
+            model=model, parameters=parameters, settings=settings, http_client=http_client, message_converter=message_converter
+        )
 
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[DashScopeChatParametersDict]) -> ChatCompletionOutput:
@@ -97,21 +133,21 @@ async def async_stream_generate(
 
     @override
     def _get_request_parameters(
-        self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[DashScopeChatParametersDict]
+        self, messages: Messages, stream: bool = False, **kwargs: Unpack[DashScopeChatParametersDict]
     ) -> HttpxPostKwargs:
-        messages = ensure_messages(prompt)
         parameters = self.parameters.clone_with_changes(**kwargs)
-        zhipu_messages = [self._convert_message(message) for message in messages]
         headers = {
             'Authorization': self.settings.api_key.get_secret_value(),
             'Content-Type': 'application/json',
         }
+        if self.settings.workspace is not None:
+            headers['X-DashScope-WorkSpace'] = self.settings.workspace
         if stream:
             headers['Accept'] = 'text/event-stream'
         params = {
             'input': {
-                'messages': zhipu_messages,
+                'messages': self.message_converter.convert_messages(messages),
             },
             'model': self.model,
             'parameters': parameters.custom_model_dump(),
@@ -122,55 +158,100 @@ def _get_request_parameters(
             'json': params,
         }
 
-    def _process_usage(self, usage: Dict[str, int]) -> Dict[str, int]:
-        return {
-            'total_tokens': usage['total_tokens'],
-            'prompt_tokens': usage['input_tokens'],
-            'completion_tokens': usage['output_tokens'],
-        }
-
     @override
     def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         return ChatCompletionOutput(
             model_info=self.model_info,
-            message=AssistantMessage(content=response['output']['text']),
-            cost=self._calculate_cost(response['usage']['total_tokens']),
-            extra={'usage': self._process_usage(response['usage']), 'request_id': response['request_id']},
-            finish_reason=response['output']['finish_reason'],
+            message=self._parse_assistant_message(response),
+            usage=self._parse_usage(response),
+            extra=self._parse_extra(response),
+            finish_reason=self._parse_finish_reason(response['choices'][0]),
         )
 
     @override
-    def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None:
-        try:
-            data = json.loads(line)
-        except json.JSONDecodeError:
-            return None
-
-        finish_reason = data['output']['finish_reason'] or None
-        reply = data['output']['text']
-        stream_manager.extra['usage'] = self._process_usage(data['usage'])
-        stream_manager.extra['request_id'] = data['request_id']
+    def _process_stream_response(
+        self, response: dict[str, Any], stream_manager: StreamManager
+    ) -> ChatCompletionStreamOutput | None:
+        finish_reason = self._parse_finish_reason(response['output'])
+        reply = response['output']['text']
+        stream_manager.usage = self._parse_usage(response)
+        stream_manager.extra = self._parse_extra(response)
-        if finish_reason != 'null':
+        if finish_reason is not None:
             stream_manager.finish_reason = finish_reason
             stream_manager.delta = ''
-            stream_manager.cost = self._calculate_cost(total_tokens=stream_manager.extra['usage']['total_tokens'])
             return stream_manager.build_stream_output()
 
         stream_manager.delta = reply[len(stream_manager.content) :]
         return stream_manager.build_stream_output()
 
-    @staticmethod
-    def _convert_message(message: Message) -> DashscopeMessage:
-        if isinstance(message, UserMessage):
-            return {'role': 'user', 'content': message.content}
-        if isinstance(message, AssistantMessage):
-            return {'role': 'assistant', 'content': message.content}
-        if isinstance(message, SystemMessage):
-            return {'role': 'system', 'content': message.content}
-        raise MessageTypeError(message, (UserMessage, AssistantMessage, SystemMessage))
-
-    def _calculate_cost(self, total_tokens: int) -> Optional[float]:
-        if self.model == 'qwen-turbo':
-            return total_tokens * 0.008 / 1000
-        if self.model == 'qwen-plus':
-            return total_tokens * 0.04 / 1000
+    @override
+    def cost(self, input_tokens: int, output_tokens: int) -> Optional[float]:
+        total_tokens = input_tokens + output_tokens
+        model_price = {
+            'qwen-turbo': 8,
+            'qwen-plus': 20,
+            'qwen-max': 120,
+        }
+        for model_name, price in model_price.items():
+            if model_name in self.model:
+                return total_tokens * price / 1_000_000
         return None
+
+    @override
+    def adapt_tool_calls(self, messages: Messages) -> None:
+        tool_call_id_to_function_name = {}
+        new_messages = []
+        for message in messages:
+            if isinstance(message, AssistantMessage) and message.tool_calls:
+                for tool_call in message.tool_calls:
+                    tool_call_id_to_function_name[tool_call.id] = tool_call.function.name
+            if isinstance(message, ToolMessage):
+                message = FunctionMessage(
+                    name=tool_call_id_to_function_name[message.tool_call_id], content=message.content or ''
+                )
+            new_messages.append(message)
+        messages[:] = new_messages
+
+    @override
+    def add_tools(self, tools: OrIterable[Tool]) -> None:
+        new_tools = [convert_to_openai_tool(tool) for tool in ensure_iterable(tools)]
+        # OpenAI-format tool schemas accumulate on the parameters object:
+        # the first call starts the list, later calls extend it.
+        if self.parameters.tools is None:
+            
self.parameters.tools = new_tools + else: + self.parameters.tools.extend(new_tools) + + def _parse_usage(self, response: dict[str, Any]) -> Usage: + if usage := response.get('usage'): + input_tokens = usage.get('input_tokens') + output_tokens = usage.get('output_tokens') + return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=self.cost(input_tokens, output_tokens)) + return Usage() + + def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage: + message = response['choices'][0]['message'] + if tool_calls_dict := message.get('tool_calls'): + tool_calls = [ + ToolCall( + id=tool_call['function']['name'], + function=FunctionCall( + name=tool_call['function']['name'], + arguments=tool_call['function']['arguments'], + ), + ) + for tool_call in tool_calls_dict + ] + else: + tool_calls = None + return AssistantMessage(content=message.get('content') or '', tool_calls=tool_calls) + + def _parse_extra(self, response: dict[str, Any]) -> Dict[str, Any]: + return { + 'request_id': response['request_id'], + 'response': response, + } + + def _parse_finish_reason(self, choice: dict[str, Any]) -> FinishReason | None: + try: + if finish_reason := choice.get('finish_reason'): + return FinishReason(finish_reason) + except (KeyError, IndexError, ValueError): + return None diff --git a/generate/chat_completion/models/dashscope_multimodal.py b/generate/chat_completion/models/dashscope_multimodal.py index 327c6b1..b15c636 100644 --- a/generate/chat_completion/models/dashscope_multimodal.py +++ b/generate/chat_completion/models/dashscope_multimodal.py @@ -208,7 +208,7 @@ def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: ) @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> Optional[ChatCompletionStreamOutput]: + def _process_stream_response(self, line: str, stream_manager: StreamManager) -> Optional[ChatCompletionStreamOutput]: try: data = json.loads(line) except json.JSONDecodeError: diff --git a/generate/chat_completion/models/deepseek.py b/generate/chat_completion/models/deepseek.py index 250b325..99f277f 100644 --- a/generate/chat_completion/models/deepseek.py +++ b/generate/chat_completion/models/deepseek.py @@ -5,13 +5,19 @@ from pydantic import Field, PositiveInt from typing_extensions import Annotated, Unpack, override +from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import Prompt from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.models.openai_like import OpenAILikeChat +from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAIMessageConverter from generate.http import HttpClient from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms import DeepSeekSettings -from generate.types import Probability +from generate.types import ModelPrice, Probability + +DeepSeekModelPrice: ModelPrice = { + 'deepseek-chat': (1, 2), + 'deepseek-coder': (1, 2), +} class DeepSeekChatParameters(ModelParameters): @@ -21,6 +27,8 @@ class DeepSeekChatParameters(ModelParameters): frequency_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None presence_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None stop: Optional[Union[str, List[str]]] = None + logprobs: Optional[bool] = None + top_logprobs: Optional[Annotated[int, Field(ge=0, le=20)]] = None class 
DeepSeekParametersDict(RemoteModelParametersDict, total=False): @@ -30,6 +38,8 @@ class DeepSeekParametersDict(RemoteModelParametersDict, total=False): frequency_penalty: Optional[float] presence_penalty: Optional[float] stop: Optional[Union[str, List[str]]] + logprobs: Optional[bool] + top_logprobs: Optional[int] class DeepSeekChat(OpenAILikeChat): @@ -38,6 +48,7 @@ class DeepSeekChat(OpenAILikeChat): parameters: DeepSeekChatParameters settings: DeepSeekSettings + message_converter: OpenAIMessageConverter def __init__( self, @@ -45,12 +56,20 @@ def __init__( parameters: DeepSeekChatParameters | None = None, settings: DeepSeekSettings | None = None, http_client: HttpClient | None = None, + message_converter: OpenAIMessageConverter | None = None, + cost_calculator: CostCalculator | None = None, ) -> None: parameters = parameters or DeepSeekChatParameters() settings = settings or DeepSeekSettings() # type: ignore - http_client = http_client or HttpClient() - model = model - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) + cost_calculator = cost_calculator or GeneralCostCalculator(DeepSeekModelPrice) + super().__init__( + model=model, + parameters=parameters, + settings=settings, + http_client=http_client, + message_converter=message_converter, + cost_calculator=cost_calculator, + ) @override def generate(self, prompt: Prompt, **kwargs: Unpack[DeepSeekParametersDict]) -> ChatCompletionOutput: diff --git a/generate/chat_completion/models/hunyuan.py b/generate/chat_completion/models/hunyuan.py index b7b9420..43da0de 100644 --- a/generate/chat_completion/models/hunyuan.py +++ b/generate/chat_completion/models/hunyuan.py @@ -137,7 +137,7 @@ def _get_request_parameters( } @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: + def _process_stream_response(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: try: data = json.loads(line) except json.JSONDecodeError: diff --git a/generate/chat_completion/models/minimax_legacy.py b/generate/chat_completion/models/minimax_legacy.py deleted file mode 100644 index 47cc88f..0000000 --- a/generate/chat_completion/models/minimax_legacy.py +++ /dev/null @@ -1,208 +0,0 @@ -from __future__ import annotations - -import json -from typing import Any, AsyncIterator, ClassVar, Iterator, List, Literal, Optional - -from pydantic import Field, PositiveInt -from typing_extensions import Annotated, TypedDict, Unpack, override - -from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.message import ( - AssistantMessage, - Message, - Messages, - MessageTypeError, - Prompt, - SystemMessage, - UserMessage, - ensure_messages, -) -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.stream_manager import StreamManager -from generate.http import ( - HttpClient, - HttpxPostKwargs, - ResponseValue, - UnexpectedResponseError, -) -from generate.model import ModelParameters, RemoteModelParametersDict -from generate.platforms.minimax import MinimaxSettings -from generate.types import Probability, Temperature - - -class MinimaxMessage(TypedDict): - sender_type: Literal['USER', 'BOT'] - text: str - - -class RoleMeta(TypedDict): - user_name: str - bot_name: str - - -DEFAULT_MINIMAX_SYSTEM_PROMPT = 'MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。' - - -class 
MinimaxLegacyChatParameters(ModelParameters): - system_prompt: str = Field(default=DEFAULT_MINIMAX_SYSTEM_PROMPT, serialization_alias='prompt') - role_meta: RoleMeta = {'user_name': '用户', 'bot_name': 'MM智能助理'} - beam_width: Optional[Annotated[int, Field(ge=1, le=4)]] = None - temperature: Optional[Temperature] = None - top_p: Optional[Probability] = None - max_tokens: Optional[Annotated[PositiveInt, Field(serialization_alias='tokens_to_generate')]] = None - skip_info_mask: Optional[bool] = None - continue_last_message: Optional[bool] = None - - def custom_model_dump(self) -> dict[str, Any]: - output = super().custom_model_dump() - if 'temperature' in output: - output['temperature'] = max(0.01, output['temperature']) - if 'top_p' in output: - output['top_p'] = max(0.01, output['top_p']) - return output - - -class MinimaxLegacyChatParametersDict(RemoteModelParametersDict, total=False): - system_prompt: str - role_meta: RoleMeta - beam_width: Optional[int] - temperature: Optional[Temperature] - top_p: Optional[Probability] - max_tokens: Optional[int] - skip_info_mask: Optional[bool] - continue_last_message: Optional[bool] - - -def _convert_message_to_minimax_message(message: Message) -> MinimaxMessage: - if isinstance(message, UserMessage): - return { - 'sender_type': 'USER', - 'text': message.content, - } - if isinstance(message, AssistantMessage): - return { - 'sender_type': 'BOT', - 'text': message.content, - } - raise MessageTypeError(message, (UserMessage, AssistantMessage)) - - -def _convert_messages(messages: Messages) -> list[MinimaxMessage]: - if isinstance(system_message := messages[0], SystemMessage): - prepend_messages = [UserMessage(content=system_message.content), AssistantMessage(content='好的')] - messages = prepend_messages + messages[1:] - return [_convert_message_to_minimax_message(message) for message in messages] - - -class MinimaxLegacyChat(RemoteChatCompletionModel): - model_type: ClassVar[str] = 'minimax_legacy' - available_models: ClassVar[List[str]] = ['abab5.5-chat', 'abab5.5s-chat'] - - parameters: MinimaxLegacyChatParameters - settings: MinimaxSettings - - def __init__( - self, - model: str = 'abab5.5-chat', - settings: MinimaxSettings | None = None, - parameters: MinimaxLegacyChatParameters | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or MinimaxLegacyChatParameters() - settings = settings or MinimaxSettings() # type: ignore - http_client = http_client or HttpClient() - if not settings.group_id: - raise ValueError( - 'group_id is required for MinimaxLegacyChat, you can set it in settings or environment variable MINIMAX_GROUP_ID' - ) - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) - - @override - def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxLegacyChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) - - @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxLegacyChatParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) - - @override - def stream_generate( - self, prompt: Prompt, **kwargs: Unpack[MinimaxLegacyChatParametersDict] - ) -> Iterator[ChatCompletionStreamOutput]: - yield from super().stream_generate(prompt, **kwargs) - - @override - async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[MinimaxLegacyChatParametersDict] - ) -> AsyncIterator[ChatCompletionStreamOutput]: - async for output in 
super().async_stream_generate(prompt, **kwargs): - yield output - - @override - def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[MinimaxLegacyChatParametersDict] - ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - minimax_messages = _convert_messages(messages) - parameters_dict = parameters.custom_model_dump() - json_data = { - 'model': self.model, - 'messages': minimax_messages, - **parameters_dict, - } - if stream: - json_data['stream'] = True - json_data['use_standard_sse'] = True - - headers = { - 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', - 'Content-Type': 'application/json', - } - return { - 'url': self.settings.api_base + '/text/chatcompletion', - 'json': json_data, - 'headers': headers, - 'params': {'GroupId': self.settings.group_id}, - } - - @override - def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - try: - return ChatCompletionOutput( - model_info=self.model_info, - message=AssistantMessage(content=response['choices'][0]['text']), - finish_reason=response['choices'][0]['finish_reason'], - cost=self._calculate_cost(response['usage']), - extra={ - 'logprobes': response['choices'][0]['logprobes'], - 'input_sensitive': False, - 'output_sensitive': False, - 'usage': response['usage'], - }, - ) - except (KeyError, IndexError, TypeError) as e: - raise UnexpectedResponseError(response) from e - - @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - try: - data = json.loads(line) - except json.JSONDecodeError: - return None - stream_manager.delta = data['choices'][0]['delta'] - - if data['reply']: - stream_manager.finish_reason = data['choices'][0]['finish_reason'] - extra = { - 'logprobes': data['choices'][0]['logprobes'], - 'input_sensitive': False, - 'output_sensitive': False, - 'usage': data['usage'], - } - stream_manager.extra.update(extra) - stream_manager.cost = self._calculate_cost(data['usage']) - return stream_manager.build_stream_output() - - def _calculate_cost(self, usage: dict[str, int]) -> float: - return 0.015 * (usage['total_tokens'] / 1000) diff --git a/generate/chat_completion/models/minimax_pro.py b/generate/chat_completion/models/minimax_pro.py index c9145b4..cb2273f 100644 --- a/generate/chat_completion/models/minimax_pro.py +++ b/generate/chat_completion/models/minimax_pro.py @@ -371,7 +371,7 @@ def _get_request_parameters( } @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: + def _process_stream_response(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: # TODO: implement this raise NotImplementedError diff --git a/generate/chat_completion/models/openai.py b/generate/chat_completion/models/openai.py index a6088f9..0ffd766 100644 --- a/generate/chat_completion/models/openai.py +++ b/generate/chat_completion/models/openai.py @@ -35,11 +35,13 @@ class OpenAIChatParameters(ModelParameters): presence_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None frequency_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None logit_bias: Optional[Dict[int, Annotated[int, Field(ge=-100, le=100)]]] = None + logprobs: Optional[bool] = None + top_logprobs: Optional[Annotated[int, Field(ge=0, le=20)]] = None user: Optional[str] = None response_format: Optional[OpenAIResponseFormat] = None seed: Optional[int] = None 
tools: Optional[List[OpenAITool]] = None - tool_choice: Union[Literal['auto'], OpenAIToolChoice, None] = None + tool_choice: Union[Literal['auto', 'none'], OpenAIToolChoice, None] = None class OpenAIChatParametersDict(RemoteModelParametersDict, total=False): @@ -52,6 +54,8 @@ class OpenAIChatParametersDict(RemoteModelParametersDict, total=False): presence_penalty: Optional[float] frequency_penalty: Optional[float] logit_bias: Optional[Dict[int, int]] + logprobs: Optional[bool] + top_logprobs: Optional[int] user: Optional[str] response_format: Optional[OpenAIResponseFormat] seed: Optional[int] diff --git a/generate/chat_completion/models/openai_like.py b/generate/chat_completion/models/openai_like.py index bc69c4d..a380fda 100644 --- a/generate/chat_completion/models/openai_like.py +++ b/generate/chat_completion/models/openai_like.py @@ -1,41 +1,38 @@ from __future__ import annotations import base64 -import json -import uuid from abc import ABC -from functools import partial -from typing import Any, Callable, Dict, List, Literal, Type, Union, cast +from typing import Any, Dict, List, Literal, Union from typing_extensions import NotRequired, TypedDict, override from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.cost_caculator import GeneralCostCalculator +from generate.chat_completion.cost_caculator import CostCalculator from generate.chat_completion.message import ( AssistantMessage, FunctionCall, FunctionMessage, ImagePart, - Message, - MessageTypeError, - Prompt, SystemMessage, TextPart, ToolCall, ToolMessage, UserMessage, UserMultiPartMessage, - ensure_messages, ) +from generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.message.core import Messages -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput +from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage from generate.chat_completion.stream_manager import StreamManager from generate.chat_completion.tool import FunctionJsonSchema, Tool from generate.http import ( + HttpClient, + HttpGetKwargs, HttpxPostKwargs, ResponseValue, ) -from generate.model import ModelInfo +from generate.model import ModelInfo, ModelParameters +from generate.platforms.base import PlatformSettings from generate.platforms.openai_like import OpenAILikeSettings @@ -77,121 +74,107 @@ class OpenAIResponseFormat(TypedDict): type: Literal['json_object', 'text'] -def _to_text_message_dict(role: str, message: Message) -> OpenAIMessage: - if not isinstance(message.content, str): - raise TypeError(f'Unexpected message content: {type(message.content)}') - return { - 'role': role, - 'content': message.content, - } +class OpenAIMessageConverter(MessageConverter): + def convert_system_message(self, message: SystemMessage) -> Dict[str, Any]: + return { + 'role': 'system', + 'content': message.content, + } + def convert_user_message(self, message: UserMessage) -> Dict[str, Any]: + return { + 'role': 'user', + 'content': message.content, + } -def _to_user_multipart_message_dict(message: UserMultiPartMessage) -> OpenAIMessage: - content = [] - for part in message.content: - if isinstance(part, TextPart): - content.append({'type': 'text', 'text': part.text}) - else: - if isinstance(part, ImagePart): - image_format = part.image_format or 'png' - url: str = f'data:image/{image_format};base64,{base64.b64encode(part.image).decode()}' - image_url_dict = {'url': url} + def 
convert_user_multi_part_message(self, message: UserMultiPartMessage) -> Dict[str, Any]: + content = [] + for part in message.content: + if isinstance(part, TextPart): + content.append({'type': 'text', 'text': part.text}) else: - image_url_dict = {} - image_url_dict['url'] = part.image_url.url - if part.image_url.detail: - image_url_dict['detail'] = part.image_url.detail - image_url_part_dict: dict[str, Any] = { - 'type': 'image_url', - 'image_url': image_url_dict, - } - content.append(image_url_part_dict) - return { - 'role': 'user', - 'content': content, - } - - -def _to_tool_message_dict(message: ToolMessage) -> OpenAIMessage: - return { - 'role': 'tool', - 'tool_call_id': message.tool_call_id, - 'content': message.content, - } - - -def _to_asssistant_message_dict(message: AssistantMessage) -> OpenAIMessage: - base_dict = { - 'role': 'assistant', - 'content': message.content or None, - } - if message.tool_calls: - tool_calls = [ - { - 'id': tool_call.id, - 'type': 'function', - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments, - }, + if isinstance(part, ImagePart): + url: str = f'data:image/{part.image_format};base64,{base64.b64encode(part.image).decode()}' + image_url_dict = {'url': url} + else: + image_url_dict = {} + image_url_dict['url'] = part.image_url.url + if part.image_url.detail: + image_url_dict['detail'] = part.image_url.detail + image_url_part_dict: dict[str, Any] = { + 'type': 'image_url', + 'image_url': image_url_dict, + } + content.append(image_url_part_dict) + return { + 'role': 'user', + 'content': content, + } + + def convert_tool_message(self, message: ToolMessage) -> Dict[str, Any]: + return { + 'role': 'tool', + 'tool_call_id': message.tool_call_id, + 'content': message.content, + } + + def convert_assistant_message(self, message: AssistantMessage) -> Dict[str, Any]: + base_dict = { + 'role': 'assistant', + 'content': message.content or None, + } + if message.tool_calls: + tool_calls = [ + { + 'id': tool_call.id, + 'type': 'function', + 'function': { + 'name': tool_call.function.name, + 'arguments': tool_call.function.arguments, + }, + } + for tool_call in message.tool_calls + ] + base_dict['tool_calls'] = tool_calls + if message.function_call: + base_dict['function_call'] = { + 'name': message.function_call.name, + 'arguments': message.function_call.arguments, } - for tool_call in message.tool_calls - ] - base_dict['tool_calls'] = tool_calls - if message.function_call: - base_dict['function_call'] = { - 'name': message.function_call.name, - 'arguments': message.function_call.arguments, + return base_dict + + def convert_function_message(self, message: FunctionMessage) -> Dict[str, Any]: + return { + 'role': 'function', + 'name': message.name, + 'content': message.content, } - return cast(OpenAIMessage, base_dict) - - -def _to_function_message_dict(message: FunctionMessage) -> OpenAIMessage: - return { - 'role': 'function', - 'name': message.name, - 'content': message.content, - } - - -def convert_to_openai_message(message: Message) -> OpenAIMessage: - to_function_map: dict[Type[Message], Callable[[Any], OpenAIMessage]] = { - SystemMessage: partial(_to_text_message_dict, 'system'), - UserMessage: partial(_to_text_message_dict, 'user'), - AssistantMessage: partial(_to_asssistant_message_dict), - UserMultiPartMessage: _to_user_multipart_message_dict, - ToolMessage: _to_tool_message_dict, - FunctionMessage: _to_function_message_dict, - } - if to_function := to_function_map.get(type(message)): - return to_function(message) - - 
raise MessageTypeError(message, allowed_message_type=tuple(to_function_map.keys())) - - -def openai_calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float | None: - dollar_to_yuan = 7 - if model_name in ('gpt-4-1106-preview', 'gpt-4-1106-vision-preview'): - return (0.01 * dollar_to_yuan) * (input_tokens / 1000) + (0.03 * dollar_to_yuan) * (output_tokens / 1000) - if 'gpt-4-turbo' in model_name: - return (0.01 * dollar_to_yuan) * (input_tokens / 1000) + (0.03 * dollar_to_yuan) * (output_tokens / 1000) - if 'gpt-4-32k' in model_name: - return (0.06 * dollar_to_yuan) * (input_tokens / 1000) + (0.12 * dollar_to_yuan) * (output_tokens / 1000) - if 'gpt-4' in model_name: - return (0.03 * dollar_to_yuan) * (input_tokens / 1000) + (0.06 * dollar_to_yuan) * (output_tokens / 1000) - if 'gpt-3.5-turbo' in model_name: - return (0.001 * dollar_to_yuan) * (input_tokens / 1000) + (0.002 * dollar_to_yuan) * (output_tokens / 1000) - if 'moonshot' in model_name: - if '8k' in model_name: - return 0.012 * (input_tokens / 1000) + 0.012 * (output_tokens / 1000) - if '32k' in model_name: - return 0.024 * (input_tokens / 1000) + 0.024 * (output_tokens / 1000) - if '128k' in model_name: - return 0.06 * (input_tokens / 1000) + 0.06 * (output_tokens / 1000) - return None - - -def _convert_to_assistant_message(message: dict[str, Any]) -> AssistantMessage: + + +# def openai_calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float | None: + +# dollar_to_yuan = 7 +# if model_name in ('gpt-4-1106-preview', 'gpt-4-1106-vision-preview'): +# return (0.01 * dollar_to_yuan) * (input_tokens / 1000) + (0.03 * dollar_to_yuan) * (output_tokens / 1000) +# if 'gpt-4-turbo' in model_name: +# return (0.01 * dollar_to_yuan) * (input_tokens / 1000) + (0.03 * dollar_to_yuan) * (output_tokens / 1000) +# if 'gpt-4-32k' in model_name: +# return (0.06 * dollar_to_yuan) * (input_tokens / 1000) + (0.12 * dollar_to_yuan) * (output_tokens / 1000) +# if 'gpt-4' in model_name: +# return (0.03 * dollar_to_yuan) * (input_tokens / 1000) + (0.06 * dollar_to_yuan) * (output_tokens / 1000) +# if 'gpt-3.5-turbo' in model_name: +# return (0.001 * dollar_to_yuan) * (input_tokens / 1000) + (0.002 * dollar_to_yuan) * (output_tokens / 1000) +# if 'moonshot' in model_name: +# if '8k' in model_name: +# return 0.012 * (input_tokens / 1000) + 0.012 * (output_tokens / 1000) +# if '32k' in model_name: +# return 0.024 * (input_tokens / 1000) + 0.024 * (output_tokens / 1000) +# if '128k' in model_name: +# return 0.06 * (input_tokens / 1000) + 0.06 * (output_tokens / 1000) +# return None + + +def parse_message_dict(message: dict[str, Any]) -> AssistantMessage: if function_call_dict := message.get('function_call'): function_call = FunctionCall( name=function_call_dict.get('name') or '', @@ -221,38 +204,30 @@ def convert_to_openai_tool(tool: Tool) -> OpenAITool: def process_openai_like_model_reponse(response: ResponseValue, model_type: str) -> ChatCompletionOutput: - message = _convert_to_assistant_message(response['choices'][0]['message']) - extra = {'usage': response['usage']} + choice = response['choices'][0] + message = parse_message_dict(choice['message']) + extra = {'response': response} if system_fingerprint := response.get('system_fingerprint'): extra['system_fingerprint'] = system_fingerprint - choice = response['choices'][0] if (finish_reason := choice.get('finish_reason')) is None: finish_reason = finish_details['type'] if (finish_details := choice.get('finish_details')) else None - try: - if model_type == 
'openai': - cost = openai_calculate_cost( - model_name=response['model'], - input_tokens=response['usage']['prompt_tokens'], - output_tokens=response['usage']['completion_tokens'], - ) - else: - cost_calculator = GeneralCostCalculator() - cost = cost_calculator.calculate( - model_type=model_type, - model_name=response['model'], - input_tokens=response['usage']['prompt_tokens'], - output_tokens=response['usage']['completion_tokens'], - ) - except Exception: - cost = None + if finish_reason: + finish_reason = FinishReason(finish_reason) + input_tokens = response['usage']['prompt_tokens'] + output_tokens = response['usage']['completion_tokens'] + cost = None + for k, v in response['usage'].items(): + if k in ('cost', 'total_cost'): + cost = v + break return ChatCompletionOutput( model_info=ModelInfo(task='chat_completion', type=model_type, name=response['model']), message=message, finish_reason=finish_reason, - cost=cost, + usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost), extra=extra, ) @@ -260,17 +235,35 @@ def process_openai_like_model_reponse(response: ResponseValue, model_type: str) class OpenAILikeChat(RemoteChatCompletionModel, ABC): settings: OpenAILikeSettings + def __init__( + self, + model: str, + parameters: ModelParameters, + settings: PlatformSettings, + http_client: HttpClient | None = None, + message_converter: MessageConverter | None = None, + cost_calculator: CostCalculator | None = None, + ) -> None: + http_client = http_client or HttpClient() + message_converter = message_converter or OpenAIMessageConverter() + super().__init__( + model=model, + parameters=parameters, + settings=settings, + http_client=http_client, + message_converter=message_converter, + cost_calculator=cost_calculator, + ) + @override - def _get_request_parameters(self, prompt: Prompt, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs: - messages = ensure_messages(prompt) + def _get_request_parameters(self, messages: Messages, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs: parameters = self.parameters.clone_with_changes(**kwargs) - openai_messages = self._convert_to_openai_messages(messages) headers = { 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', } params = { 'model': self.model, - 'messages': openai_messages, + 'messages': self.message_converter.convert_messages(messages), **parameters.custom_model_dump(), } if stream: @@ -282,31 +275,91 @@ def _get_request_parameters(self, prompt: Prompt, stream: bool = False, **kwargs 'json': params, } - def _convert_to_openai_messages(self, messages: Messages) -> List[OpenAIMessage]: - return [convert_to_openai_message(message) for message in messages] - - @staticmethod - def generate_tool_call_id() -> str: - return f'call_{uuid.uuid4()}' - @override - def _process_reponse(self, response: Dict[str, Any]) -> ChatCompletionOutput: - return process_openai_like_model_reponse(response, model_type=self.model_type) + def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: + return ChatCompletionOutput( + model_info=ModelInfo(task='chat_completion', type=self.model_type, name=response['model']), + message=self._parse_assistant_message(response), + finish_reason=self._parse_finish_reason(response), + usage=self._parse_usage(response), + extra=self._parse_extra_info(response), + ) @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - try: - data = json.loads(line) - except json.JSONDecodeError: - return None - - 
delta_dict = data['choices'][0].get('delta', {}) + def _process_stream_response( + self, response: Dict[str, Any], stream_manager: StreamManager + ) -> ChatCompletionStreamOutput | None: + delta_dict = response['choices'][0].get('delta', {}) self._update_delta(delta_dict, stream_manager=stream_manager) - stream_manager.extra = self._extract_extra_info(data) - stream_manager.cost = self._calculate_cost(data) - stream_manager.finish_reason = self._determine_finish_reason(data) + stream_manager.extra = self._parse_extra_info(response) + stream_manager.usage = self._parse_usage(response) + stream_manager.finish_reason = self._parse_finish_reason(response) return stream_manager.build_stream_output() + def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage: + message = response['choices'][0]['message'] + if function_call_dict := message.get('function_call'): + function_call = FunctionCall( + name=function_call_dict.get('name') or '', + arguments=function_call_dict['arguments'], + ) + else: + function_call = None + + if tool_calls_dict := message.get('tool_calls'): + tool_calls = [ + ToolCall( + id=tool_call['id'], + function=FunctionCall( + name=tool_call['function'].get('name') or '', + arguments=tool_call['function']['arguments'], + ), + ) + for tool_call in tool_calls_dict + ] + else: + tool_calls = None + return AssistantMessage(content=message.get('content') or '', function_call=function_call, tool_calls=tool_calls) + + @override + def list_models(self) -> List[str]: + headers = { + 'Accept': 'application/json', + 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', + } + parameters: HttpGetKwargs = { + 'url': f'{self.settings.api_base}/models', + 'headers': headers, + } + response = self.http_client.get(parameters) + self.http_client.raise_for_status(response) + return [i['id'] for i in response.json()['data'] if i['object'] == 'model'] + + def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: + choice = response['choices'][0] + finish_reason = choice.get('finish_reason') or None + if finish_reason is None: + finish_reason: str | None = finish_details['type'] if (finish_details := choice.get('finish_details')) else None + if finish_reason is not None: + finish_reason = FinishReason(finish_reason) + return finish_reason + + def _parse_usage(self, response: dict[str, Any]) -> Usage: + if usage := response.get('usage'): + input_tokens = usage['prompt_tokens'] + output_tokens = usage['completion_tokens'] + cost = self.cost(input_tokens, output_tokens) + if cost is None: + for k, v in usage.items(): + if k in ('cost', 'total_cost'): + cost = v + break + return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost) + return Usage() + + def _parse_extra_info(self, response: dict[str, Any]) -> dict[str, Any]: + return {'response': response} + def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManager) -> None: delta_content: str = delta_dict.get('content') or '' stream_manager.delta = delta_content @@ -314,7 +367,7 @@ def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManage if delta_dict.get('tool_calls'): index = delta_dict['tool_calls'][0]['index'] if index >= len(stream_manager.tool_calls or []): - new_tool_calls_message = _convert_to_assistant_message(delta_dict).tool_calls + new_tool_calls_message = parse_message_dict(delta_dict).tool_calls assert new_tool_calls_message is not None if stream_manager.tool_calls is None: stream_manager.tool_calls = [] @@ 
-330,47 +383,3 @@ def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManage stream_manager.function_call.name += function_name arguments = delta_dict['function_call'].get('arguments', '') stream_manager.function_call.arguments += arguments - - def _extract_extra_info(self, response: ResponseValue) -> dict[str, Any]: - extra = { - 'id': response['id'], - } - choice = response['choices'][0] - if usage := response.get('usage'): - extra['usage'] = usage - if usage := choice.get('usage'): - extra['usage'] = usage - if system_fingerprint := response.get('system_fingerprint'): - extra['system_fingerprint'] = system_fingerprint - return extra - - def _calculate_cost(self, response: ResponseValue) -> float | None: - if response.get('usage') is None: - return None - - if self.model_type == 'openai': - return openai_calculate_cost( - model_name=response['model'], - input_tokens=response['usage']['prompt_tokens'], - output_tokens=response['usage']['completion_tokens'], - ) - - cost_calculator = GeneralCostCalculator() - input_tokens = response['usage'].get('prompt_tokens', 0) - output_tokens = response['usage'].get('completion_tokens', 0) - if 'total_tokens' in response['usage']: - input_tokens = 0 - output_tokens = response['usage']['total_tokens'] - return cost_calculator.calculate( - model_type=self.model_type, - model_name=response['model'], - input_tokens=input_tokens, - output_tokens=output_tokens, - ) - - def _determine_finish_reason(self, response: ResponseValue) -> str | None: - choice = response['choices'][0] - finish_reason = choice.get('finish_reason') or None - if finish_reason is None: - finish_reason: str | None = finish_details['type'] if (finish_details := choice.get('finish_details')) else None - return finish_reason diff --git a/generate/chat_completion/models/openrouter.py b/generate/chat_completion/models/openrouter.py new file mode 100644 index 0000000..4c96e41 --- /dev/null +++ b/generate/chat_completion/models/openrouter.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional, Union + +from pydantic import BaseModel, Field, PositiveInt +from typing_extensions import Annotated, Unpack, override + +from generate.chat_completion.message import Prompt +from generate.chat_completion.message.core import Messages +from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput +from generate.chat_completion.models.openai_like import ( + OpenAILikeChat, + OpenAIResponseFormat, + OpenAITool, + OpenAIToolChoice, + convert_to_openai_tool, +) +from generate.chat_completion.tool import Tool, ToolCallMixin +from generate.http import HttpClient, HttpxPostKwargs +from generate.model import ModelParameters, RemoteModelParametersDict +from generate.platforms import OpenRouterSettings +from generate.types import OrIterable, Probability, Temperature +from generate.utils import ensure_iterable + + +class ProviderParameters(BaseModel): + allow_fallbacks: bool = True + require_parameters: bool = False + data_collection: str = 'allow' + order: Optional[List[str]] = None + + +class OpenRouterChatParameters(ModelParameters): + temperature: Optional[Temperature] = None + top_p: Optional[Probability] = None + top_k: Optional[PositiveInt] = None + max_tokens: Optional[PositiveInt] = None + stop: Union[str, List[str], None] = None + presence_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None + frequency_penalty: Optional[Annotated[float, 
Field(ge=-2, le=2)]] = None
+    repetition_penalty: Optional[Annotated[float, Field(ge=0.0, le=2.0)]] = None
+    logit_bias: Optional[Dict[int, Annotated[int, Field(ge=-100, le=100)]]] = None
+    user: Optional[str] = None
+    response_format: Optional[OpenAIResponseFormat] = None
+    seed: Optional[int] = None
+    tools: Optional[List[OpenAITool]] = None
+    tool_choice: Union[Literal['auto', 'none'], OpenAIToolChoice, None] = None
+    route: Optional[str] = None
+    transforms: Optional[List[str]] = None
+    provider: Optional[ProviderParameters] = None
+
+
+class OpenRouterParametersDict(RemoteModelParametersDict, total=False):
+    temperature: Optional[Temperature]
+    top_p: Optional[Probability]
+    top_k: Optional[PositiveInt]
+    max_tokens: Optional[PositiveInt]
+    stop: Union[str, List[str], None]
+    presence_penalty: Optional[float]
+    frequency_penalty: Optional[float]
+    repetition_penalty: Optional[float]
+    logit_bias: Optional[Dict[int, int]]
+    user: Optional[str]
+    response_format: Optional[OpenAIResponseFormat]
+    seed: Optional[int]
+    tools: Optional[List[OpenAITool]]
+    tool_choice: Union[Literal['auto', 'none'], OpenAIToolChoice, None]
+    route: Optional[str]
+    transforms: Optional[List[str]]
+    provider: Optional[ProviderParameters]
+
+
+class OpenRouterChat(OpenAILikeChat, ToolCallMixin):
+    model_type: ClassVar[str] = 'openrouter'
+    available_models: ClassVar[List[str]] = ['auto']
+
+    parameters: OpenRouterChatParameters
+    settings: OpenRouterSettings
+
+    def __init__(
+        self,
+        model: str | list[str] = 'auto',
+        parameters: OpenRouterChatParameters | None = None,
+        settings: OpenRouterSettings | None = None,
+        http_client: HttpClient | None = None,
+        app_name: str | None = None,
+        site_url: str | None = None,
+    ) -> None:
+        parameters = parameters or OpenRouterChatParameters()
+        settings = settings or OpenRouterSettings()  # type: ignore
+        http_client = http_client or HttpClient()
+        if isinstance(model, list):
+            self.models = model
+            model = '-'.join(model[:3])
+            if len(self.models) > 3:
+                model += '-etc'
+        else:
+            self.models = None
+        super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client)
+        self.app_name = app_name
+        self.site_url = site_url
+
+    @override
+    def generate(self, prompt: Prompt, **kwargs: Unpack[OpenRouterParametersDict]) -> ChatCompletionOutput:
+        return super().generate(prompt, **kwargs)
+
+    @override
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenRouterParametersDict]) -> ChatCompletionOutput:
+        return await super().async_generate(prompt, **kwargs)
+
+    @override
+    def stream_generate(
+        self, prompt: Prompt, **kwargs: Unpack[OpenRouterParametersDict]
+    ) -> Iterator[ChatCompletionStreamOutput]:
+        yield from super().stream_generate(prompt, **kwargs)
+
+    @override
+    async def async_stream_generate(
+        self, prompt: Prompt, **kwargs: Unpack[OpenRouterParametersDict]
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
+        async for stream_output in super().async_stream_generate(prompt, **kwargs):
+            yield stream_output
+
+    @override
+    def _get_request_parameters(self, messages: Messages, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs:
+        request_parameters: HttpxPostKwargs = super()._get_request_parameters(messages, stream=stream, **kwargs)
+        if self.app_name:
+            request_parameters['headers']['X-Title'] = self.app_name
+        if self.site_url:
+            request_parameters['headers']['HTTP-Referer'] = self.site_url
+        if self.models:
+            request_parameters['json']['models'] = self.models
+            request_parameters['json'].pop('model')
+        return request_parameters
+
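+    # Hypothetical usage sketch (model slugs are illustrative, not from this repo):
+    #     chat = OpenRouterChat(model=['openai/gpt-4-turbo', 'anthropic/claude-3-opus'], app_name='my-app')
+    #     output = chat.generate('Hello')
+    # With a model list, the request body carries `models` (fallback order) instead
+    # of `model`, and `app_name`/`site_url` become the X-Title / HTTP-Referer
+    # attribution headers handled above.
+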
+ @override + def add_tools(self, tools: OrIterable[Tool]) -> None: + new_tools = [convert_to_openai_tool(tool) for tool in ensure_iterable(tools)] + if self.parameters.tools is None: + self.parameters.tools = new_tools + else: + self.parameters.tools.extend(new_tools) diff --git a/generate/chat_completion/models/wenxin.py b/generate/chat_completion/models/wenxin.py index e93e700..524d31e 100644 --- a/generate/chat_completion/models/wenxin.py +++ b/generate/chat_completion/models/wenxin.py @@ -225,7 +225,7 @@ def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: ) @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: + def _process_stream_response(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: try: data = json.loads(line) except json.JSONDecodeError: diff --git a/generate/chat_completion/models/yi.py b/generate/chat_completion/models/yi.py index 2ff6504..5776955 100644 --- a/generate/chat_completion/models/yi.py +++ b/generate/chat_completion/models/yi.py @@ -5,22 +5,32 @@ from pydantic import Field, PositiveInt from typing_extensions import Annotated, Unpack, override +from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import Prompt from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.models.openai_like import OpenAILikeChat +from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAIMessageConverter from generate.http import HttpClient from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms import YiSettings +from generate.types import ModelPrice + +YiModelPrice: ModelPrice = { + 'yi-34b-chat-200k': (12.0, 12.0), + 'yi-34b-chat': (2.5, 2.5), + 'yi-vl-plus': (6, 6), +} class YiChatParameters(ModelParameters): temperature: Optional[Annotated[float, Field(ge=0, lt=2)]] = None max_tokens: Optional[PositiveInt] = None + top_p: Optional[Annotated[float, Field(ge=0, lt=1)]] = None class YiParametersDict(RemoteModelParametersDict, total=False): - temperature: Optional[Annotated[float, Field(ge=0, lt=2)]] - max_tokens: Optional[PositiveInt] + temperature: Optional[float] + max_tokens: Optional[int] + top_p: Optional[float] class YiChat(OpenAILikeChat): @@ -29,6 +39,7 @@ class YiChat(OpenAILikeChat): parameters: YiChatParameters settings: YiSettings + message_converter: OpenAIMessageConverter def __init__( self, @@ -36,12 +47,20 @@ def __init__( parameters: YiChatParameters | None = None, settings: YiSettings | None = None, http_client: HttpClient | None = None, + message_converter: OpenAIMessageConverter | None = None, + cost_calculator: CostCalculator | None = None, ) -> None: parameters = parameters or YiChatParameters() settings = settings or YiSettings() # type: ignore - http_client = http_client or HttpClient() - model = model - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) + cost_calculator = cost_calculator or GeneralCostCalculator(YiModelPrice) + super().__init__( + model=model, + parameters=parameters, + settings=settings, + http_client=http_client, + message_converter=message_converter, + cost_calculator=cost_calculator, + ) @override def generate(self, prompt: Prompt, **kwargs: Unpack[YiParametersDict]) -> ChatCompletionOutput: diff --git a/generate/chat_completion/models/zhipu.py 
b/generate/chat_completion/models/zhipu.py index 0b07b6c..3b211fa 100644 --- a/generate/chat_completion/models/zhipu.py +++ b/generate/chat_completion/models/zhipu.py @@ -1,21 +1,19 @@ from __future__ import annotations import base64 -import json from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional, Union from pydantic import field_validator from typing_extensions import NotRequired, TypedDict, Unpack, override from generate.chat_completion.base import RemoteChatCompletionModel +from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import ( AssistantMessage, FunctionCall, ImagePart, ImageUrlPart, - Message, Messages, - MessageTypeError, Prompt, SystemMessage, TextPart, @@ -23,19 +21,27 @@ ToolMessage, UserMessage, UserMultiPartMessage, - ensure_messages, ) -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, Stream +from generate.chat_completion.message.converter import MessageConverter +from generate.chat_completion.message.core import FunctionMessage +from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage from generate.chat_completion.stream_manager import StreamManager +from generate.chat_completion.tool import Tool, ToolCallMixin from generate.http import ( HttpClient, HttpxPostKwargs, ResponseValue, - UnexpectedResponseError, ) -from generate.model import ModelInfo, ModelParameters, RemoteModelParametersDict +from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms.zhipu import ZhipuSettings, generate_zhipu_token -from generate.types import JsonSchema, Probability, Temperature +from generate.types import JsonSchema, OrIterable, Probability, Temperature +from generate.utils import ensure_iterable + +ZhipuModelPrice = { + 'glm-4v': (100, 100), + 'glm-4': (100, 100), + 'glm-3-turbo': (5, 5), +} class Function(TypedDict): @@ -121,44 +127,20 @@ class ZhipuMessage(TypedDict): tool_call_id: NotRequired[str] -def convert_to_zhipu_message(message: Message) -> ZhipuMessage: - if isinstance(message, UserMessage): +class ZhipuMessageConverter(MessageConverter): + def convert_system_message(self, message: SystemMessage) -> Dict[str, Any]: return { - 'role': 'user', + 'role': 'system', 'content': message.content, } - if isinstance(message, UserMultiPartMessage): - content = [] - for part in message.content: - if isinstance(part, TextPart): - content.append( - { - 'type': 'text', - 'text': part.text, - } - ) - elif isinstance(part, ImageUrlPart): - content.append( - { - 'type': 'image_url', - 'image_url': { - 'url': part.image_url.url, - }, - } - ) - elif isinstance(part, ImagePart): - content.append( - { - 'type': 'image_url', - 'image_url': { - 'url': base64.b64encode(part.image).decode(), - }, - } - ) - return {'role': 'user', 'content': content} + def convert_user_message(self, message: UserMessage) -> Dict[str, Any]: + return { + 'role': 'user', + 'content': message.content, + } - if isinstance(message, AssistantMessage): + def convert_assistant_message(self, message: AssistantMessage) -> Dict[str, Any]: if message.tool_calls is not None: dict_format_toll_calls: list[ZhipuToolCall] = [] for index, tool_call in enumerate(message.tool_calls): @@ -186,156 +168,54 @@ def convert_to_zhipu_message(message: Message) -> ZhipuMessage: 'content': message.content, } - if isinstance(message, SystemMessage): - return { - 'role': 'system', - 'content': 
message.content, - } - - if isinstance(message, ToolMessage): + def convert_tool_message(self, message: ToolMessage) -> Dict[str, Any]: return { 'role': 'tool', 'content': message.content or '', 'tool_call_id': message.tool_call_id, } - raise MessageTypeError(message, (UserMessage, AssistantMessage)) - - -def _convert_to_assistant_message(zhiput_message_dict: dict[str, Any]) -> AssistantMessage: - if 'tool_calls' in zhiput_message_dict: - dict_format_tool_calls = zhiput_message_dict['tool_calls'] - dict_format_tool_calls.sort(key=lambda x: x['index']) - tool_calls = [] - for tool_call_dict in zhiput_message_dict['tool_calls']: - if tool_call_dict['type'] != 'function': - raise ValueError(f'invalid tool type: {tool_call_dict["type"]}, should be function') - tool_calls.append( - ToolCall( - id=tool_call_dict['id'], - type='function', - function=FunctionCall( - name=tool_call_dict['function']['name'], - arguments=tool_call_dict['function']['arguments'], - ), - ) - ) - return AssistantMessage( - role='assistant', - content='', - tool_calls=tool_calls, - ) - return AssistantMessage( - role='assistant', - content=zhiput_message_dict['content'], - ) - - -def _calculate_cost(model_name: str, usage: dict[str, Any]) -> float | None: - if model_name == 'glm-4': - return 0.1 * (usage['total_tokens'] / 1000) - if model_name == 'glm-3-turbo': - return 0.005 * (usage['total_tokens'] / 1000) - if model_name == 'characterglm': - return 0.015 * (usage['total_tokens'] / 1000) - return None - - -class _StreamResponseProcessor: - def __init__(self) -> None: - self.message: AssistantMessage | None = None - self.is_start = True - - def process(self, stream_line: str) -> ChatCompletionStreamOutput | None: - if not stream_line.strip(): - return None - - line = self._preprocess_stream_line(stream_line) - if not line: - return None - response = json.loads(line) - delta_dict = response['choices'][0]['delta'] - - if self.message is None: - if self._is_contains_content(delta_dict): - self.message = self.process_initial_message(delta_dict) - else: - return None - else: - self.update_existing_message(delta_dict) - extra = self.extract_extra_info(response) - cost = cost = self.calculate_response_cost(response) - finish_reason = self.determine_finish_reason(response) - stream_control = 'finish' if finish_reason else 'start' if self.is_start else 'continue' - self.is_start = False - return ChatCompletionStreamOutput( - model_info=ModelInfo(task='chat_completion', type='zhipu', name=response['model']), - message=self.message, - finish_reason=finish_reason, - cost=cost, - extra=extra, - stream=Stream(delta=delta_dict.get('content') or '', control=stream_control), - ) - - @staticmethod - def _preprocess_stream_line(line: str) -> str: - line = line.replace('data:', '') - return line.strip() - - def _is_contains_content(self, delta_dict: dict[str, Any]) -> bool: - return not ( - delta_dict.get('content') is None - and delta_dict.get('tool_calls') is None - and delta_dict.get('function_call') is None - ) - - def process_initial_message(self, delta_dict: dict[str, Any]) -> AssistantMessage: - return _convert_to_assistant_message(delta_dict) - - def update_existing_message(self, delta_dict: dict[str, Any]) -> None: - if not delta_dict: - return - assert self.message is not None - - delta_content = delta_dict.get('content', '') - self.message.content += delta_content - - if delta_dict.get('tool_calls'): - index = delta_dict['tool_calls'][0]['index'] - if index >= len(self.message.tool_calls or []): - new_tool_calls_message = 
_convert_to_assistant_message(delta_dict).tool_calls - assert new_tool_calls_message is not None - if self.message.tool_calls is None: - self.message.tool_calls = [] - self.message.tool_calls.append(new_tool_calls_message[0]) - else: - assert self.message.tool_calls is not None - self.message.tool_calls[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments'] - - def extract_extra_info(self, response: ResponseValue) -> dict[str, Any]: - extra = {} - if usage := response.get('usage'): - extra['usage'] = usage - if system_fingerprint := response.get('system_fingerprint'): - extra['system_fingerprint'] = system_fingerprint - return extra - - @staticmethod - def calculate_response_cost(response: ResponseValue) -> float | None: - if usage := response.get('usage'): - return _calculate_cost(response['model'], usage) - return None + def convert_function_message(self, message: FunctionMessage) -> Dict[str, Any]: + raise NotImplementedError('Zhipu does not support function messages') - def determine_finish_reason(self, response: ResponseValue) -> str | None: - return response['choices'][0].get('finish_reason') + def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> Dict[str, Any]: + content = [] + for part in message.content: + if isinstance(part, TextPart): + content.append( + { + 'type': 'text', + 'text': part.text, + } + ) + elif isinstance(part, ImageUrlPart): + content.append( + { + 'type': 'image_url', + 'image_url': { + 'url': part.image_url.url, + }, + } + ) + elif isinstance(part, ImagePart): + content.append( + { + 'type': 'image_url', + 'image_url': { + 'url': base64.b64encode(part.image).decode(), + }, + } + ) + return {'role': 'user', 'content': content} -class ZhipuChat(RemoteChatCompletionModel): +class ZhipuChat(RemoteChatCompletionModel, ToolCallMixin): model_type: ClassVar[str] = 'zhipu' available_models: ClassVar[List[str]] = ['glm-4', 'glm-3-turbo', 'glm-4v'] parameters: ZhipuChatParameters settings: ZhipuSettings + message_converter: ZhipuMessageConverter def __init__( self, @@ -343,44 +223,23 @@ def __init__( parameters: ZhipuChatParameters | None = None, settings: ZhipuSettings | None = None, http_client: HttpClient | None = None, + message_converter: ZhipuMessageConverter | None = None, + cost_calculator: CostCalculator | None = None, ) -> None: parameters = parameters or ZhipuChatParameters() settings = settings or ZhipuSettings() # type: ignore - http_client = http_client or HttpClient(stream_strategy='basic') - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) - - @override - def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[ZhipuChatParametersDict] - ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - zhipu_messages = self._convert_messages(messages) - headers = { - 'Authorization': generate_zhipu_token(self.settings.api_key.get_secret_value()), - } - params = {'messages': zhipu_messages, 'model': self.model, 'stream': stream, **parameters.custom_model_dump()} - return { - 'url': f'{self.settings.v4_api_base}/chat/completions', - 'headers': headers, - 'json': params, - } - - @override - def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - message_dict = response['choices'][0]['message'] - finish_reason = response['choices'][0]['finish_reason'] - return ChatCompletionOutput( - model_info=self.model_info, - 
message=_convert_to_assistant_message(message_dict), - cost=_calculate_cost(self.model, response['usage']), - extra={'usage': response['usage']}, - finish_reason=finish_reason, + http_client = http_client or HttpClient() + message_converter = message_converter or ZhipuMessageConverter() + cost_calculator = cost_calculator or GeneralCostCalculator(ZhipuModelPrice) + super().__init__( + model=model, + parameters=parameters, + settings=settings, + http_client=http_client, + message_converter=message_converter, + cost_calculator=cost_calculator, ) - def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]: - return [convert_to_zhipu_message(message) for message in messages] - @override def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) -> ChatCompletionOutput: return super().generate(prompt, **kwargs) @@ -393,174 +252,133 @@ async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParamet def stream_generate( self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict] ) -> Iterator[ChatCompletionStreamOutput]: - request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs) - stream_processor = _StreamResponseProcessor() - is_finish = False - for line in self.http_client.stream_post(request_parameters=request_parameters): - if is_finish: - continue - output = stream_processor.process(line) - if output is None: - continue - is_finish = output.is_finish - yield output + yield from super().stream_generate(prompt, **kwargs) @override async def async_stream_generate( self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict] ) -> AsyncIterator[ChatCompletionStreamOutput]: - request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs) - stream_processor = _StreamResponseProcessor() - is_finish = False - async for line in self.http_client.async_stream_post(request_parameters=request_parameters): - if is_finish: - continue - output = stream_processor.process(line) - if output is None: - continue - is_finish = output.is_finish - yield output - - @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - raise NotImplementedError - - -class ZhipuMeta(TypedDict): - user_info: str - bot_info: str - bot_name: str - user_name: str - - -class ZhipuCharacterChatParameters(ModelParameters): - meta: ZhipuMeta = { - 'user_info': '我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。', - 'bot_info': '苏梦远,本名苏远心,是一位当红的国内女歌手及演员。', - 'bot_name': '苏梦远', - 'user_name': '陆星辰', - } - request_id: Optional[str] = None - - def custom_model_dump(self) -> dict[str, Any]: - output = super().custom_model_dump() - output['return_type'] = 'text' - return output - - -class ZhipuCharacterChatParametersDict(RemoteModelParametersDict, total=False): - meta: ZhipuMeta - request_id: Optional[str] - - -class ZhipuCharacterChat(RemoteChatCompletionModel): - model_type: ClassVar[str] = 'zhipu_character' - available_models: ClassVar[List[str]] = ['charglm-3'] - - parameters: ZhipuCharacterChatParameters - settings: ZhipuSettings - - def __init__( - self, - model: str = 'charglm-3', - parameters: ZhipuCharacterChatParameters | None = None, - settings: ZhipuSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or ZhipuCharacterChatParameters() - settings = settings or ZhipuSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) + async 
for stream_output in super().async_stream_generate(prompt, **kwargs):
+            yield stream_output

     @override
     def _get_request_parameters(
-        self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
+        self, messages: Messages, stream: bool = False, **kwargs: Unpack[ZhipuChatParametersDict]
     ) -> HttpxPostKwargs:
-        messages = ensure_messages(prompt)
         parameters = self.parameters.clone_with_changes(**kwargs)
-        zhipu_messages = self._convert_messages(messages)
         headers = {
             'Authorization': generate_zhipu_token(self.settings.api_key.get_secret_value()),
         }
-        params = {'prompt': zhipu_messages, **parameters.custom_model_dump()}
-        if stream:
-            url = f'{self.settings.v3_api_base}/{self.model}/sse-invoke'
-        else:
-            url = f'{self.settings.v3_api_base}/{self.model}/invoke'
+        params = {
+            'messages': self.message_converter.convert_messages(messages),
+            'model': self.model,
+            'stream': stream,
+            **parameters.custom_model_dump(),
+        }
         return {
-            'url': url,
+            'url': self.settings.api_base + '/chat/completions',
             'headers': headers,
             'json': params,
         }

-    def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]:
-        return [convert_to_zhipu_message(message) for message in messages]
-
     @override
     def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
-        if response['success']:
-            text = response['data']['choices'][0]['content']
-            return ChatCompletionOutput(
-                model_info=self.model_info,
-                message=AssistantMessage(content=text),
-                cost=_calculate_cost(self.name, response['data']['usage']),
-                extra={'usage': response['data']['usage']},
-            )
-
-        raise UnexpectedResponseError(response)
-
-    @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]) -> ChatCompletionOutput:
-        return super().generate(prompt, **kwargs)
-
-    @override
-    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]) -> ChatCompletionOutput:
-        return await super().async_generate(prompt, **kwargs)
-
-    @override
-    def stream_generate(
-        self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput]:
-        request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs)
-        message = AssistantMessage(content='')
-        is_start = True
-        for line in self.http_client.stream_post(request_parameters=request_parameters):
-            message.content += line
-            yield ChatCompletionStreamOutput(
-                model_info=self.model_info,
-                message=message,
-                stream=Stream(delta=line, control='start' if is_start else 'continue'),
-            )
-            is_start = False
-        yield ChatCompletionStreamOutput(
+        return ChatCompletionOutput(
             model_info=self.model_info,
-            message=message,
-            finish_reason='stop',
-            stream=Stream(delta='', control='finish'),
+            message=self._parse_assistant_message(response['choices'][0]['message']),
+            usage=self._parse_usage(response),
+            extra=self._parse_extra(response),
+            finish_reason=self._parse_finish_reason(response),
         )

     @override
-    async def async_stream_generate(
-        self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput]:
-        request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs)
-        message = AssistantMessage(content='')
-        is_start = True
-        async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
-            message.content += line
-            yield ChatCompletionStreamOutput(
-                model_info=self.model_info,
-                message=message,
-                stream=Stream(delta=line, control='start' if is_start else 
'continue'), + def _process_stream_response( + self, response: Dict[str, Any], stream_manager: StreamManager + ) -> ChatCompletionStreamOutput | None: + delta_dict = response['choices'][0]['delta'] + self._update_delta(delta_dict, stream_manager) + stream_manager.finish_reason = self._parse_finish_reason(response) + stream_manager.extra = self._parse_extra(response) + stream_manager.usage = self._parse_usage(response) + return stream_manager.build_stream_output() + + def add_tools(self, tools: OrIterable[Tool]) -> None: + new_tools: list[ZhipuTool] = [ + { + 'type': 'function', + 'function': { + 'name': tool.name, + 'description': tool.description, + 'parameters': tool.parameters, + }, + } + for tool in ensure_iterable(tools) + ] + if self.parameters.tools is None: + self.parameters.tools = new_tools + else: + self.parameters.tools.extend(new_tools) + + def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage: + if 'tool_calls' in response: + dict_format_tool_calls = response['tool_calls'] + dict_format_tool_calls.sort(key=lambda x: x['index']) + tool_calls = [] + for tool_call_dict in response['tool_calls']: + if tool_call_dict['type'] != 'function': + raise ValueError(f'invalid tool type: {tool_call_dict["type"]}, should be function') + tool_calls.append( + ToolCall( + id=tool_call_dict['id'], + type='function', + function=FunctionCall( + name=tool_call_dict['function']['name'], + arguments=tool_call_dict['function']['arguments'], + ), + ) + ) + return AssistantMessage( + role='assistant', + content=response.get('content') or '', + tool_calls=tool_calls, ) - is_start = False - yield ChatCompletionStreamOutput( - model_info=self.model_info, - message=message, - finish_reason='stop', - stream=Stream(delta='', control='finish'), + return AssistantMessage( + role='assistant', + content=response['content'], ) - @override - def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - raise NotImplementedError + def _parse_usage(self, response: dict[str, Any]) -> Usage: + usage = response.get('usage') + if usage is not None: + input_tokens = usage['prompt_tokens'] + output_tokens = usage['completion_tokens'] + cost = self.cost(input_tokens, output_tokens) + return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost) + return Usage() + + def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: + try: + choice = response['choices'][0] + if finish_reason := choice.get('finish_reason'): + return FinishReason(finish_reason) + except (KeyError, IndexError, ValueError): + return None + + def _parse_extra(self, response: dict[str, Any]) -> dict[str, Any]: + return {'response': response} + + def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManager) -> None: + delta_content: str = delta_dict.get('content') or '' + stream_manager.delta = delta_content + + if delta_dict.get('tool_calls'): + index = delta_dict['tool_calls'][0]['index'] + if index >= len(stream_manager.tool_calls or []): + new_tool_calls_message = self._parse_assistant_message(delta_dict).tool_calls + assert new_tool_calls_message is not None + if stream_manager.tool_calls is None: + stream_manager.tool_calls = [] + stream_manager.tool_calls.append(new_tool_calls_message[0]) + else: + assert stream_manager.tool_calls is not None + stream_manager.tool_calls[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments'] diff --git a/generate/chat_completion/stream_manager.py 
b/generate/chat_completion/stream_manager.py index 5e5d054..db2c5ff 100644 --- a/generate/chat_completion/stream_manager.py +++ b/generate/chat_completion/stream_manager.py @@ -3,16 +3,16 @@ from pydantic import BaseModel from generate.chat_completion.message.core import AssistantMessage, FunctionCall, ToolCall -from generate.chat_completion.model_output import ChatCompletionStreamOutput, Stream +from generate.chat_completion.model_output import ChatCompletionStreamOutput, FinishReason, Stream, Usage from generate.model import ModelInfo class StreamManager(BaseModel): info: ModelInfo delta: Optional[str] = None - cost: Optional[float] = None + usage: Usage = Usage() history_streams: List[Stream] = [] - finish_reason: Optional[str] = None + finish_reason: Optional[FinishReason] = None function_call: Optional[FunctionCall] = None tool_calls: Optional[List[ToolCall]] = None close: bool = False @@ -50,14 +50,11 @@ def build_stream_output(self) -> Optional[ChatCompletionStreamOutput]: stream = self.current_stream if stream: - if not self.history_streams: - assert self.control == 'start' - self.history_streams.append(stream) self.delta = None output = ChatCompletionStreamOutput( model_info=self.info, - cost=self.cost, + usage=self.usage, extra=self.extra, finish_reason=self.finish_reason, message=AssistantMessage( diff --git a/generate/chat_completion/tool.py b/generate/chat_completion/tool.py index 8d14873..00247e2 100644 --- a/generate/chat_completion/tool.py +++ b/generate/chat_completion/tool.py @@ -1,5 +1,6 @@ from __future__ import annotations +import uuid from collections import UserDict from typing import Any, Callable, Generic, MutableMapping, TypeVar @@ -7,6 +8,14 @@ from pydantic import TypeAdapter, validate_call from typing_extensions import NotRequired, ParamSpec, Self, TypedDict +from generate.chat_completion.message.core import ( + AssistantMessage, + FunctionCall, + FunctionMessage, + Messages, + ToolCall, + ToolMessage, +) from generate.types import JsonSchema, OrIterable from generate.utils import ensure_iterable @@ -100,3 +109,19 @@ def from_iterable(cls, tools: OrIterable[Tool]) -> Self: class ToolCallMixin: def add_tools(self, tools: OrIterable[Tool]) -> None: raise NotImplementedError + + def generate_tool_call_id(self, function_call: FunctionCall) -> str: + return f'tool_{uuid.uuid4().hex}' + + def adapt_tool_calls(self, messages: Messages) -> None: + for index in range(len(messages)): + current_message = messages[index] + if isinstance(current_message, AssistantMessage) and current_message.function_call is not None: + tool_call_id = self.generate_tool_call_id(current_message.function_call) + messages[index].tool_calls = [ToolCall(id=tool_call_id, function=current_message.function_call)] + messages[index].function_call = None + next_message = messages[index + 1] if index + 1 < len(messages) else None + if next_message is not None and isinstance(next_message, FunctionMessage): + messages[index + 1] = ToolMessage( + tool_call_id=tool_call_id, name=next_message.name, content=next_message.content + ) diff --git a/generate/http.py b/generate/http.py index 351ea7a..af13385 100644 --- a/generate/http.py +++ b/generate/http.py @@ -27,8 +27,8 @@ class RetryStrategy(BaseModel): class HttpGetKwargs(TypedDict, total=False): url: Required[str] - params: QueryParams - headers: Headers + params: Optional[QueryParams] + headers: Optional[Headers] timeout: Optional[int] @@ -190,29 +190,36 @@ def async_stream_post(self, request_parameters: HttpxPostKwargs) -> AsyncGenerat def _post(self, 
request_parameters: HttpxPostKwargs) -> Response: logger.debug(f'POST {request_parameters}') http_response = self.client.post(**request_parameters) # type: ignore - http_response.raise_for_status() + self.raise_for_status(http_response) logger.debug(f'Response {http_response}') return http_response async def _async_post(self, request_parameters: HttpxPostKwargs) -> Response: logger.debug(f'POST {request_parameters}') http_response = await self.async_client.post(**request_parameters) # type: ignore - http_response.raise_for_status() + self.raise_for_status(http_response) logger.debug(f'Response {http_response}') return http_response def _get(self, request_parameters: HttpGetKwargs) -> Response: logger.debug(f'GET {request_parameters}') http_response = self.client.get(**request_parameters) - http_response.raise_for_status() + self.raise_for_status(http_response) return http_response async def _async_get(self, request_parameters: HttpGetKwargs) -> Response: logger.debug(f'GET {request_parameters}') http_response = await self.async_client.get(**request_parameters) - http_response.raise_for_status() + self.raise_for_status(http_response) return http_response + @staticmethod + def raise_for_status(response: Response) -> None: + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise UnexpectedResponseError(response.json()) from e + def _stream_post(self, request_parameters: HttpxPostKwargs) -> Generator[str, None, None]: logger.debug(f'POST {request_parameters}') if self.stream_strategy == 'sse': diff --git a/generate/model.py b/generate/model.py index caa6881..fea05dd 100644 --- a/generate/model.py +++ b/generate/model.py @@ -46,7 +46,6 @@ class ModelOutput(BaseModel): model_config = ConfigDict(protected_namespaces=()) model_info: ModelInfo - cost: Optional[float] = None extra: Dict[str, Any] = {} diff --git a/generate/platforms/__init__.py b/generate/platforms/__init__.py index e548507..189d4a5 100644 --- a/generate/platforms/__init__.py +++ b/generate/platforms/__init__.py @@ -2,7 +2,6 @@ from generate.platforms.azure import AzureSettings from generate.platforms.baichuan import BaichuanSettings from generate.platforms.baidu import BaiduCreationSettings, QianfanSettings -from generate.platforms.bailian import BailianSettings from generate.platforms.base import PlatformSettings from generate.platforms.dashscope import DashScopeSettings from generate.platforms.deepseek import DeepSeekSettings @@ -10,6 +9,7 @@ from generate.platforms.minimax import MinimaxSettings from generate.platforms.moonshot import MoonshotSettings from generate.platforms.openai import OpenAISettings +from generate.platforms.openrouter import OpenRouterSettings from generate.platforms.stepfun import StepFunSettings from generate.platforms.yi import YiSettings from generate.platforms.zhipu import ZhipuSettings @@ -23,7 +23,6 @@ 'ZhipuSettings', 'OpenAISettings', 'QianfanSettings', - 'BailianSettings', 'HunyuanSettings', 'DashScopeSettings', 'MoonshotSettings', @@ -31,4 +30,5 @@ 'YiSettings', 'PlatformSettings', 'StepFunSettings', + 'OpenRouterSettings', ] diff --git a/generate/platforms/baichuan.py b/generate/platforms/baichuan.py index b5c5ee4..fe816c7 100644 --- a/generate/platforms/baichuan.py +++ b/generate/platforms/baichuan.py @@ -8,6 +8,5 @@ class BaichuanSettings(PlatformSettings): model_config = SettingsConfigDict(extra='ignore', env_prefix='baichuan_', env_file='.env') api_key: SecretStr - secret_key: SecretStr api_base: str = 'https://api.baichuan-ai.com/v1' platform_url: str = 
'https://platform.baichuan-ai.com/docs/api' diff --git a/generate/platforms/bailian.py b/generate/platforms/bailian.py deleted file mode 100644 index 40aeaa0..0000000 --- a/generate/platforms/bailian.py +++ /dev/null @@ -1,37 +0,0 @@ -from pydantic import SecretStr -from pydantic_settings import SettingsConfigDict - -from generate.access_token_manager import AccessTokenManager -from generate.http import HttpClient -from generate.platforms.base import PlatformSettings - - -class BailianSettings(PlatformSettings): - model_config = SettingsConfigDict(extra='ignore', env_prefix='bailian_', env_file='.env') - - default_app_id: str - access_key_id: SecretStr - access_key_secret: SecretStr - agent_key: str - completion_api: str = 'https://bailian.aliyuncs.com/v2/app/completions' - platform_url: str = 'https://help.aliyun.com/product/2400256.html' - - -class BailianTokenManager(AccessTokenManager): - def __init__(self, settings: BailianSettings, http_client: HttpClient, token_refresh_days: int = 1) -> None: - super().__init__(token_refresh_days) - self.settings = settings - self.http_client = http_client - - def _get_token(self) -> str: - try: - import broadscope_bailian - except ImportError as e: - raise ImportError('Please install broadscope_bailian first: pip install broadscope_bailian') from e - - client = broadscope_bailian.AccessTokenClient( - access_key_id=self.settings.access_key_id.get_secret_value(), - access_key_secret=self.settings.access_key_secret.get_secret_value(), - agent_key=self.settings.agent_key, - ) - return client.get_token() diff --git a/generate/platforms/dashscope.py b/generate/platforms/dashscope.py index b00e6ef..0a62089 100644 --- a/generate/platforms/dashscope.py +++ b/generate/platforms/dashscope.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import SecretStr from pydantic_settings import SettingsConfigDict @@ -8,5 +10,6 @@ class DashScopeSettings(PlatformSettings): model_config = SettingsConfigDict(extra='ignore', env_prefix='dashscope_', env_file='.env') api_key: SecretStr + workspace: Optional[str] = None api_base: str = 'https://dashscope.aliyuncs.com/api/v1' platform_url: str = 'https://help.aliyun.com/zh/dashscope/' diff --git a/generate/platforms/openrouter.py b/generate/platforms/openrouter.py new file mode 100644 index 0000000..4d58b87 --- /dev/null +++ b/generate/platforms/openrouter.py @@ -0,0 +1,12 @@ +from pydantic import SecretStr +from pydantic_settings import SettingsConfigDict + +from generate.platforms.openai_like import OpenAILikeSettings + + +class OpenRouterSettings(OpenAILikeSettings): + model_config = SettingsConfigDict(extra='ignore', env_prefix='openrouter_', env_file='.env') + + api_key: SecretStr + api_base: str = 'https://openrouter.ai/api/v1' + platform_url: str = 'https://openrouter.ai/' diff --git a/generate/platforms/yi.py b/generate/platforms/yi.py index 61ddd94..55b9a5b 100644 --- a/generate/platforms/yi.py +++ b/generate/platforms/yi.py @@ -9,4 +9,4 @@ class YiSettings(OpenAILikeSettings): api_key: SecretStr api_base: str = 'https://api.lingyiwanwu.com/v1' - platform_url: str = 'https://01ai.feishu.cn/docx/Q8Pcdn76uoHBc8xAvKCcPSd0nkc' + platform_url: str = 'https://platform.lingyiwanwu.com/docs' diff --git a/generate/platforms/zhipu.py b/generate/platforms/zhipu.py index 14fcae2..c950803 100644 --- a/generate/platforms/zhipu.py +++ b/generate/platforms/zhipu.py @@ -15,8 +15,7 @@ class ZhipuSettings(PlatformSettings): model_config = SettingsConfigDict(extra='ignore', env_prefix='zhipu_', env_file='.env') api_key: 
SecretStr
-    v3_api_base: str = 'https://open.bigmodel.cn/api/paas/v3/model-api'
-    v4_api_base: str = 'https://open.bigmodel.cn/api/paas/v4'
+    api_base: str = 'https://open.bigmodel.cn/api/paas/v4'
     platform_url: str = 'https://open.bigmodel.cn/dev/howuse/introduction'
diff --git a/generate/types.py b/generate/types.py
index 2ab7247..ffb9334 100644
--- a/generate/types.py
+++ b/generate/types.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, Dict, Iterable, Optional, Sequence, TypeVar, Union
+from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, TypeVar, Union

 from pydantic import Field
 from typing_extensions import Annotated
@@ -12,3 +12,4 @@
 PrimitiveData = Optional[Union[str, int, float, bool]]
 OrSequence = Union[T, Sequence[T]]
 OrIterable = Union[T, Iterable[T]]
+ModelPrice = Dict[str, Tuple[float, float]]
diff --git a/generate/version.py b/generate/version.py
index 908c0bb..2b8877c 100644
--- a/generate/version.py
+++ b/generate/version.py
@@ -1 +1 @@
-__version__ = '0.4.3'
+__version__ = '0.5.0'
diff --git a/pyproject.toml b/pyproject.toml
index 7438944..6150bde 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "generate-core"
-version = "0.4.3"
+version = "0.5.0"
 description = "文本生成,图像生成,语音生成"
 authors = ["wangyuxin "]
 license = "MIT"
@@ -26,12 +26,12 @@ pydantic-settings = "^2.1.0"
 [tool.ruff]
 line-length = 128
 lint.select = [
-    "E", # pycodestyle errors
-    "W", # pycodestyle warnings
-    "F", # pyflakes
-    "I", # isort
-    "C", # flake8-comprehensions
-    "B", # flake8-bugbear
+    "E",   # pycodestyle errors
+    "W",   # pycodestyle warnings
+    "F",   # pyflakes
+    "I",   # isort
+    "C",   # flake8-comprehensions
+    "B",   # flake8-bugbear
     "N",
     "SIM",
     "ANN",
@@ -40,7 +40,6 @@ lint.select = [
     "PT",
     "RET",
     "TRY",
-    "PERF",
 ]
 lint.ignore = [
     "E501", # line too long, handled by black
From 81bb0f777b1f8d27607fa984ac1e6a8a4aa74472 Mon Sep 17 00:00:00 2001
From: wangyuxin
Date: Thu, 23 May 2024 17:47:53 +0800
Subject: Clean up outdated code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 generate/__init__.py                          |  79 +---
 generate/chat_completion/__init__.py          |  25 +-
 generate/chat_completion/base.py              |  38 +-
 generate/chat_completion/cost_caculator.py    |  26 --
 generate/chat_completion/message/converter.py |  14 +-
 generate/chat_completion/message/core.py      |   6 +-
 generate/chat_completion/model_output.py      |   1 -
 generate/chat_completion/models/__init__.py   |  23 +-
 generate/chat_completion/models/anthropic.py  |  56 +--
 generate/chat_completion/models/azure.py      |  18 +-
 generate/chat_completion/models/baichuan.py   | 159 ++++++--
 generate/chat_completion/models/dashscope.py  | 125 +++---
 .../models/dashscope_multimodal.py            | 232 -----------
 generate/chat_completion/models/deepseek.py   |  19 +-
 generate/chat_completion/models/hunyuan.py    | 216 ----------
 generate/chat_completion/models/minimax.py    |  62 ++-
 .../chat_completion/models/minimax_pro.py     | 381 ------------------
 generate/chat_completion/models/openai.py     |  99 +----
 .../chat_completion/models/openai_like.py     | 210 +++++-----
 generate/chat_completion/models/openrouter.py |  24 +-
 generate/chat_completion/models/stepfun.py    |  21 +-
 generate/chat_completion/models/wenxin.py     | 260 ------------
 generate/chat_completion/models/yi.py         |  19 +-
 generate/chat_completion/models/zhipu.py      |  31 +-
 generate/chat_completion/tool.py              |   9 +-
 generate/constant.py                          |   1 +
 generate/highlevel.py                         |  38 --
generate/image_generation/__init__.py | 43 -- generate/image_generation/base.py | 46 --- generate/image_generation/models/__init__.py | 15 - generate/image_generation/models/baidu.py | 190 --------- generate/image_generation/models/openai.py | 178 -------- generate/image_generation/models/qianfan.py | 120 ------ generate/image_generation/models/zhipu.py | 90 ----- generate/modifiers/agent.py | 4 +- generate/modifiers/cache.py | 65 +++ generate/modifiers/structure.py | 92 ++--- generate/platforms/__init__.py | 2 - generate/platforms/azure.py | 7 +- generate/platforms/hunyuan.py | 15 - generate/text_to_speech/__init__.py | 36 -- generate/text_to_speech/base.py | 39 -- generate/text_to_speech/models/__init__.py | 16 - generate/text_to_speech/models/minimax.py | 212 ---------- generate/text_to_speech/models/openai.py | 104 ----- generate/types.py | 3 +- generate/ui.py | 14 +- poetry.lock | 13 +- pyproject.toml | 1 + tests/test_highlevel.py | 31 +- tests/test_text_to_speech_model.py | 26 -- 51 files changed, 599 insertions(+), 2955 deletions(-) delete mode 100644 generate/chat_completion/cost_caculator.py delete mode 100644 generate/chat_completion/models/dashscope_multimodal.py delete mode 100644 generate/chat_completion/models/hunyuan.py delete mode 100644 generate/chat_completion/models/minimax_pro.py delete mode 100644 generate/chat_completion/models/wenxin.py create mode 100644 generate/constant.py delete mode 100644 generate/image_generation/__init__.py delete mode 100644 generate/image_generation/base.py delete mode 100644 generate/image_generation/models/__init__.py delete mode 100644 generate/image_generation/models/baidu.py delete mode 100644 generate/image_generation/models/openai.py delete mode 100644 generate/image_generation/models/qianfan.py delete mode 100644 generate/image_generation/models/zhipu.py create mode 100644 generate/modifiers/cache.py delete mode 100644 generate/platforms/hunyuan.py delete mode 100644 generate/text_to_speech/__init__.py delete mode 100644 generate/text_to_speech/base.py delete mode 100644 generate/text_to_speech/models/__init__.py delete mode 100644 generate/text_to_speech/models/minimax.py delete mode 100644 generate/text_to_speech/models/openai.py delete mode 100644 tests/test_text_to_speech_model.py diff --git a/generate/__init__.py b/generate/__init__.py index 4ba466d..f1baf1c 100644 --- a/generate/__init__.py +++ b/generate/__init__.py @@ -9,52 +9,27 @@ ChatModelRegistry, DashScopeChat, DashScopeChatParameters, - DashScopeMultiModalChat, - DashScopeMultiModalChatParameters, DeepSeekChat, DeepSeekChatParameters, - HunyuanChat, - HunyuanChatParameters, MinimaxChat, MinimaxChatParameters, - MinimaxProChat, - MinimaxProChatParameters, MoonshotChat, MoonshotChatParameters, OpenAIChat, OpenAIChatParameters, + OpenRouterChat, + OpenRouterChatParameters, Prompt, RemoteChatCompletionModel, StepFunChat, StepFunChatParameters, - WenxinChat, - WenxinChatParameters, YiChat, YiChatParameters, ZhipuChat, ZhipuChatParameters, tool, ) -from generate.highlevel import ( - generate_image, - generate_speech, - generate_text, - load_chat_model, - load_image_generation_model, - load_speech_model, -) -from generate.image_generation import ( - BaiduImageGeneration, - BaiduImageGenerationParameters, - ImageGenerationModel, - ImageGenerationModelRegistry, - ImageGenerationOutput, - OpenAIImageGeneration, - OpenAIImageGenerationParameters, - QianfanImageGeneration, - QianfanImageGenerationParameters, - ZhipuImageGeneration, -) +from generate.highlevel import generate_text, 
load_chat_model from generate.modifiers.hook import AfterGenerateContext, BeforeGenerateContext from generate.platforms import ( AnthropicSettings, @@ -63,26 +38,15 @@ BaiduCreationSettings, DashScopeSettings, DeepSeekSettings, - HunyuanSettings, MinimaxSettings, MoonshotSettings, OpenAISettings, + OpenRouterSettings, QianfanSettings, StepFunSettings, YiSettings, ZhipuSettings, ) -from generate.text_to_speech import ( - MinimaxProSpeech, - MinimaxProSpeechParameters, - MinimaxSpeech, - MinimaxSpeechParameters, - OpenAISpeech, - OpenAISpeechParameters, - SpeechModelRegistry, - TextToSpeechModel, - TextToSpeechOutput, -) from generate.version import __version__ __all__ = [ @@ -93,53 +57,28 @@ 'OpenAIChatParameters', 'MinimaxChat', 'MinimaxChatParameters', - 'MinimaxProChat', - 'MinimaxProChatParameters', 'ZhipuChat', 'ZhipuChatParameters', 'StepFunChat', 'StepFunChatParameters', 'StepFunSettings', - 'WenxinChat', - 'WenxinChatParameters', - 'HunyuanChat', - 'HunyuanChatParameters', 'BaichuanChat', 'BaichuanChatParameters', 'DashScopeChat', 'DashScopeChatParameters', - 'DashScopeMultiModalChat', - 'DashScopeMultiModalChatParameters', 'MoonshotChat', 'MoonshotChatParameters', 'DeepSeekChat', 'DeepSeekChatParameters', - 'OpenAISpeech', - 'OpenAISpeechParameters', + 'OpenRouterChat', + 'OpenRouterChatParameters', 'YiChat', 'YiChatParameters', 'AnthropicChat', 'AnthropicChatParameters', - 'MinimaxSpeech', - 'MinimaxSpeechParameters', - 'MinimaxProSpeech', - 'MinimaxProSpeechParameters', - 'OpenAIImageGeneration', - 'OpenAIImageGenerationParameters', - 'BaiduImageGeneration', - 'BaiduImageGenerationParameters', - 'QianfanImageGeneration', - 'QianfanImageGenerationParameters', - 'ZhipuImageGeneration', - 'ImageGenerationModel', - 'TextToSpeechModel', - 'ImageGenerationModelRegistry', 'Prompt', 'ChatCompletionOutput', - 'ImageGenerationOutput', - 'TextToSpeechOutput', 'ChatModelRegistry', - 'SpeechModelRegistry', 'AzureSettings', 'AnthropicSettings', 'BaichuanSettings', @@ -148,19 +87,15 @@ 'ZhipuSettings', 'OpenAISettings', 'QianfanSettings', - 'HunyuanSettings', 'DashScopeSettings', 'MoonshotSettings', 'DeepSeekSettings', + 'OpenRouterSettings', 'YiSettings', 'AfterGenerateContext', 'BeforeGenerateContext', 'tool', 'load_chat_model', - 'load_speech_model', - 'load_image_generation_model', 'generate_text', - 'generate_speech', - 'generate_image', '__version__', ] diff --git a/generate/chat_completion/__init__.py b/generate/chat_completion/__init__.py index 50b49e1..1d410d3 100644 --- a/generate/chat_completion/__init__.py +++ b/generate/chat_completion/__init__.py @@ -22,24 +22,18 @@ BaichuanChatParameters, DashScopeChat, DashScopeChatParameters, - DashScopeMultiModalChat, - DashScopeMultiModalChatParameters, DeepSeekChat, DeepSeekChatParameters, - HunyuanChat, - HunyuanChatParameters, MinimaxChat, MinimaxChatParameters, - MinimaxProChat, - MinimaxProChatParameters, MoonshotChat, MoonshotChatParameters, OpenAIChat, OpenAIChatParameters, + OpenRouterChat, + OpenRouterChatParameters, StepFunChat, StepFunChatParameters, - WenxinChat, - WenxinChatParameters, YiChat, YiChatParameters, ZhipuChat, @@ -54,17 +48,14 @@ (AnthropicChat, AnthropicChatParameters), (OpenAIChat, OpenAIChatParameters), (MinimaxChat, MinimaxChatParameters), - (MinimaxProChat, MinimaxProChatParameters), (ZhipuChat, ZhipuChatParameters), - (WenxinChat, WenxinChatParameters), - (HunyuanChat, HunyuanChatParameters), (BaichuanChat, BaichuanChatParameters), (DashScopeChat, DashScopeChatParameters), - (DashScopeMultiModalChat, 
DashScopeMultiModalChatParameters), (MoonshotChat, MoonshotChatParameters), (DeepSeekChat, DashScopeChatParameters), (StepFunChat, StepFunChatParameters), (YiChat, YiChatParameters), + (OpenRouterChat, OpenRouterChatParameters), ] ChatModelRegistry: dict[str, tuple[Type[ChatCompletionModel], Type[ModelParameters]]] = { @@ -81,16 +72,10 @@ 'AzureChat', 'MinimaxChat', 'MinimaxChatParameters', - 'MinimaxProChat', - 'MinimaxProChatParameters', 'OpenAIChat', 'OpenAIChatParameters', 'ZhipuChat', 'ZhipuChatParameters', - 'WenxinChat', - 'WenxinChatParameters', - 'HunyuanChat', - 'HunyuanChatParameters', 'BaichuanChat', 'BaichuanChatParameters', 'YiChat', @@ -101,12 +86,12 @@ 'AnthropicChatParameters', 'DashScopeChat', 'DashScopeChatParameters', - 'DashScopeMultiModalChat', - 'DashScopeMultiModalChatParameters', 'MoonshotChat', 'MoonshotChatParameters', 'DeepSeekChat', 'DeepSeekChatParameters', + 'OpenRouterChat', + 'OpenRouterChatParameters', 'MessagePrinter', 'SimpleMessagePrinter', 'get_json_schema', diff --git a/generate/chat_completion/base.py b/generate/chat_completion/base.py index 2407123..fbeac02 100644 --- a/generate/chat_completion/base.py +++ b/generate/chat_completion/base.py @@ -1,22 +1,21 @@ from __future__ import annotations -import contextlib import json import logging from abc import ABC, abstractmethod +from pathlib import Path from typing import TYPE_CHECKING, Any, AsyncIterator, ClassVar, Iterator, List, Type, TypeVar, get_type_hints from pydantic import BaseModel from typing_extensions import Self, Unpack, override -from generate.chat_completion.cost_caculator import CostCalculator from generate.chat_completion.message import Prompt from generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.message.core import Messages from generate.chat_completion.message.utils import ensure_messages from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput from generate.chat_completion.stream_manager import StreamManager -from generate.chat_completion.tool import ToolCallMixin +from generate.chat_completion.tool import SupportToolCall from generate.http import HttpClient, HttpxPostKwargs from generate.model import GenerateModel, ModelParameters from generate.platforms import PlatformSettings @@ -26,6 +25,7 @@ if TYPE_CHECKING: from generate.modifiers.agent import Agent, AgentKwargs + from generate.modifiers.cache import CacheChatCompletionModel from generate.modifiers.hook import HookChatCompletionModel, HookModelKwargs from generate.modifiers.session import SessionChatCompletionModel from generate.modifiers.structure import StructureGenerateModel, StructureModelKwargs @@ -70,6 +70,11 @@ def hook(self, **kwargs: Unpack['HookModelKwargs']) -> 'HookChatCompletionModel' return HookChatCompletionModel(model=self, **kwargs) + def cache(self, cache_dir: Path | str | None = None) -> 'CacheChatCompletionModel': + from generate.modifiers.cache import CacheChatCompletionModel + + return CacheChatCompletionModel(model=self, cache_dir=cache_dir) + class RemoteChatCompletionModel(ChatCompletionModel, ABC): settings: PlatformSettings @@ -84,14 +89,12 @@ def __init__( settings: PlatformSettings, http_client: HttpClient, message_converter: MessageConverter, - cost_calculator: CostCalculator | None = None, ) -> None: self.model = model self.parameters = parameters self.settings = settings self.http_client = http_client self.message_converter = message_converter - self.cost_calculator = cost_calculator @abstractmethod def 
_process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: @@ -107,22 +110,13 @@ def _process_stream_response( def _get_request_parameters(self, messages: Messages, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs: ... - def cost(self, input_tokens: int, output_tokens: int) -> float | None: - if self.cost_calculator is None: - return None - return self.cost_calculator.calculate( - model_name=self.model, - input_tokens=input_tokens, - output_tokens=output_tokens, - ) - def list_models(self) -> List[str]: return self.available_models def process_prompt(self, prompt: Prompt) -> Messages: messages = ensure_messages(prompt) - if isinstance(self, ToolCallMixin): - self.adapt_tool_calls(messages) + if isinstance(self, SupportToolCall): + self.process_messages_for_tool_call(messages) return messages @override @@ -151,10 +145,12 @@ def stream_generate(self, prompt: Prompt, **kwargs: Any) -> Iterator[ChatComplet request_parameters['timeout'] = timeout stream_manager = StreamManager(info=self.model_info) for line in self.http_client.stream_post(request_parameters=request_parameters): - with contextlib.suppress(json.JSONDecodeError): + try: response = json.loads(line) if (output := self._process_stream_response(response, stream_manager)) and output: yield output + except json.JSONDecodeError: + continue @override async def async_stream_generate(self, prompt: Prompt, **kwargs: Any) -> AsyncIterator[ChatCompletionStreamOutput]: @@ -164,10 +160,12 @@ async def async_stream_generate(self, prompt: Prompt, **kwargs: Any) -> AsyncIte request_parameters['timeout'] = timeout stream_manager = StreamManager(info=self.model_info) async for line in self.http_client.async_stream_post(request_parameters=request_parameters): - with contextlib.suppress(json.JSONDecodeError): + try: response = json.loads(line) - if (output := self._process_stream_response(response, stream_manager)) and output: - yield output + if (output := self._process_stream_response(response, stream_manager)) and output: + yield output + except json.JSONDecodeError: + continue @classmethod def how_to_settings(cls) -> str: diff --git a/generate/chat_completion/cost_caculator.py b/generate/chat_completion/cost_caculator.py deleted file mode 100644 index 7be4636..0000000 --- a/generate/chat_completion/cost_caculator.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from typing import Protocol - -from generate.types import ModelPrice - - -class CostCalculator(Protocol): - def calculate(self, model_name: str, input_tokens: int, output_tokens: int) -> float | None: - ... 
- - -class GeneralCostCalculator(CostCalculator): - def __init__(self, model_price: ModelPrice, exchange_rate: float = 1) -> None: - # per million tokens - self.model_price = model_price - self.exchange_rate = exchange_rate - - def calculate(self, model_name: str, input_tokens: int, output_tokens: int) -> float | None: - if self.model_price is None: - return None - for model, (input_token_price, output_token_price) in self.model_price.items(): - if model in model_name: - cost = input_token_price * (input_tokens / 1_000_000) + output_token_price * (output_tokens / 1_000_000) - return cost * self.exchange_rate - return None diff --git a/generate/chat_completion/message/converter.py b/generate/chat_completion/message/converter.py index 2dc84da..e004531 100644 --- a/generate/chat_completion/message/converter.py +++ b/generate/chat_completion/message/converter.py @@ -1,19 +1,23 @@ from __future__ import annotations -from typing import Any, Protocol +from typing import Any, Protocol, Type from generate.chat_completion.message.core import ( AssistantMessage, FunctionMessage, + Message, Messages, SystemMessage, ToolMessage, UserMessage, UserMultiPartMessage, ) +from generate.chat_completion.message.exception import MessageTypeError class MessageConverter(Protocol): + allowed_message_types: list[Type[Message]] + def convert_user_message(self, message: UserMessage) -> dict[str, Any]: ... @@ -56,6 +60,8 @@ def convert_messages(self, messages: Messages) -> list[dict[str, Any]]: class SimpleMessageConverter(MessageConverter): + allowed_message_types = [UserMessage, AssistantMessage, SystemMessage] + def convert_system_message(self, message: SystemMessage) -> dict[str, Any]: return { 'role': 'system', @@ -75,10 +81,10 @@ def convert_assistant_message(self, message: AssistantMessage) -> dict[str, Any] } def convert_function_message(self, message: FunctionMessage) -> dict[str, Any]: - raise NotImplementedError('FunctionMessage is not supported by this converter') + raise MessageTypeError(message, allowed_message_type=list(self.allowed_message_types)) def convert_tool_message(self, message: ToolMessage) -> dict[str, Any]: - raise NotImplementedError('ToolMessage is not supported by this converter') + raise MessageTypeError(message, allowed_message_type=list(self.allowed_message_types)) def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> dict[str, Any]: - raise NotImplementedError('UserMultiPartMessage is not supported by this converter') + raise MessageTypeError(message, allowed_message_type=list(self.allowed_message_types)) diff --git a/generate/chat_completion/message/core.py b/generate/chat_completion/message/core.py index 7adad0a..c79d32b 100644 --- a/generate/chat_completion/message/core.py +++ b/generate/chat_completion/message/core.py @@ -61,7 +61,7 @@ class UserMultiPartMessage(Message): class FunctionMessage(Message): role: Literal['function'] = 'function' - name: str + name: str # type: ignore content: str @@ -94,10 +94,6 @@ class AssistantMessage(Message): def is_over(self) -> bool: return self.function_call is None and self.tool_calls is None - def model_post_init(self, __context: Any) -> None: - if not self.content and self.function_call is None and self.tool_calls is None: - raise ValueError('AssistantMessage must have content, function_call, or tool_calls') - UnionUserMessage = Union[UserMessage, UserMultiPartMessage] UnionUserPart = Union[TextPart, ImageUrlPart] diff --git a/generate/chat_completion/model_output.py b/generate/chat_completion/model_output.py index 
67c7a7c..f964616 100644 --- a/generate/chat_completion/model_output.py +++ b/generate/chat_completion/model_output.py @@ -21,7 +21,6 @@ class FinishReason(str, Enum): class Usage(BaseModel): input_tokens: Optional[int] = None output_tokens: Optional[int] = None - cost: Optional[float] = None class ChatCompletionOutput(ModelOutput): diff --git a/generate/chat_completion/models/__init__.py b/generate/chat_completion/models/__init__.py index be04f36..dbca0b9 100644 --- a/generate/chat_completion/models/__init__.py +++ b/generate/chat_completion/models/__init__.py @@ -1,22 +1,13 @@ from generate.chat_completion.models.anthropic import AnthropicChat, AnthropicChatParameters from generate.chat_completion.models.azure import AzureChat from generate.chat_completion.models.baichuan import BaichuanChat, BaichuanChatParameters -from generate.chat_completion.models.dashscope import ( - DashScopeChat, - DashScopeChatParameters, -) -from generate.chat_completion.models.dashscope_multimodal import ( - DashScopeMultiModalChat, - DashScopeMultiModalChatParameters, -) +from generate.chat_completion.models.dashscope import DashScopeChat, DashScopeChatParameters from generate.chat_completion.models.deepseek import DeepSeekChat, DeepSeekChatParameters -from generate.chat_completion.models.hunyuan import HunyuanChat, HunyuanChatParameters from generate.chat_completion.models.minimax import MinimaxChat, MinimaxChatParameters -from generate.chat_completion.models.minimax_pro import MinimaxProChat, MinimaxProChatParameters from generate.chat_completion.models.moonshot import MoonshotChat, MoonshotChatParameters from generate.chat_completion.models.openai import OpenAIChat, OpenAIChatParameters +from generate.chat_completion.models.openrouter import OpenRouterChat, OpenRouterChatParameters from generate.chat_completion.models.stepfun import StepFunChat, StepFunChatParameters -from generate.chat_completion.models.wenxin import WenxinChat, WenxinChatParameters from generate.chat_completion.models.yi import YiChat, YiChatParameters from generate.chat_completion.models.zhipu import ZhipuChat, ZhipuChatParameters @@ -26,16 +17,10 @@ 'AnthropicChatParameters', 'BaichuanChat', 'BaichuanChatParameters', - 'HunyuanChat', - 'HunyuanChatParameters', - 'MinimaxProChat', - 'MinimaxProChatParameters', 'MinimaxChat', 'MinimaxChatParameters', 'OpenAIChat', 'OpenAIChatParameters', - 'WenxinChat', - 'WenxinChatParameters', 'StepFunChat', 'StepFunChatParameters', 'YiChat', @@ -44,10 +29,10 @@ 'ZhipuChatParameters', 'DashScopeChat', 'DashScopeChatParameters', - 'DashScopeMultiModalChat', - 'DashScopeMultiModalChatParameters', 'MoonshotChat', 'MoonshotChatParameters', 'DeepSeekChat', 'DeepSeekChatParameters', + 'OpenRouterChat', + 'OpenRouterChatParameters', ] diff --git a/generate/chat_completion/models/anthropic.py b/generate/chat_completion/models/anthropic.py index 4570a24..f6dc4b4 100644 --- a/generate/chat_completion/models/anthropic.py +++ b/generate/chat_completion/models/anthropic.py @@ -3,13 +3,12 @@ import base64 import json import uuid -from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional +from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional from pydantic import Field, PositiveInt -from typing_extensions import Annotated, TypedDict, Unpack, override +from typing_extensions import Annotated, NotRequired, TypedDict, Unpack, override from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.cost_caculator import 
CostCalculator, GeneralCostCalculator from generate.chat_completion.message import Prompt from generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.message.core import ( @@ -29,21 +28,13 @@ from generate.chat_completion.message.exception import MessageTypeError from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage from generate.chat_completion.stream_manager import StreamManager -from generate.chat_completion.tool import Tool, ToolCallMixin +from generate.chat_completion.tool import SupportToolCall, Tool from generate.http import HttpClient, HttpxPostKwargs from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms import AnthropicSettings -from generate.types import ModelPrice, OrIterable, Probability, Temperature +from generate.types import OrIterable, Probability, Temperature from generate.utils import ensure_iterable -AnthropicModelPrice: ModelPrice = { - 'claude-instant': (0.8, 2.4), - 'claude-2': (8, 24), - 'claude-3-haiku': (0.25, 1.25), - 'claude-3-sonnet': (3, 15), - 'claude-3-opus': (15, 75), -} - class AnthropicTool(TypedDict): name: str @@ -51,6 +42,11 @@ class AnthropicTool(TypedDict): input_schema: Dict[str, Any] +class AnthropicToolChoice(TypedDict): + type: Literal['auto', 'any', 'tool'] + name: NotRequired[str] + + class AnthropicChatParameters(ModelParameters): system: Optional[str] = None max_tokens: PositiveInt = 1024 @@ -60,7 +56,7 @@ class AnthropicChatParameters(ModelParameters): top_p: Optional[Probability] = None top_k: Optional[PositiveInt] = None tools: Optional[List[AnthropicTool]] = None - tool_choice: Optional[str] = None + tool_choice: Optional[AnthropicToolChoice] = None class AnthropicParametersDict(RemoteModelParametersDict, total=False): @@ -72,10 +68,12 @@ class AnthropicParametersDict(RemoteModelParametersDict, total=False): top_p: Optional[Probability] top_k: Optional[PositiveInt] tools: Optional[List[AnthropicTool]] - tool_choice: Optional[str] + tool_choice: Optional[AnthropicToolChoice] class AnthropicMessageConverter(MessageConverter): + allowed_message_types = [UserMessage, AssistantMessage, UserMultiPartMessage, ToolMessage] + def __init__(self, http_client: HttpClient) -> None: super().__init__() self.http_client = http_client @@ -123,10 +121,10 @@ def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> Dict return message_dict def convert_system_message(self, message: SystemMessage) -> Dict[str, Any]: - raise MessageTypeError(message, (UserMessage, AssistantMessage, UserMultiPartMessage)) + raise MessageTypeError(message, self.allowed_message_types) def convert_function_message(self, message: FunctionMessage) -> Dict[str, Any]: - raise MessageTypeError(message, (UserMessage, AssistantMessage, UserMultiPartMessage)) + raise MessageTypeError(message, self.allowed_message_types) def convert_tool_message(self, message: ToolMessage) -> Dict[str, Any]: tool_result: dict = { @@ -142,18 +140,10 @@ def convert_tool_message(self, message: ToolMessage) -> Dict[str, Any]: 'content': [tool_result], } - def convert_messages(self, messages: Messages, tool_choice: str | None = None) -> List[Dict[str, Any]]: - messages_dict = super().convert_messages(messages) - if tool_choice and self.handle_tool_choice: - for message_dict in messages_dict[::-1]: - if message_dict['role'] == 'user': - message_dict['content'] += f'\nUse the {tool_choice} tool in your response.' 
- break - return messages_dict - -class AnthropicChat(RemoteChatCompletionModel, ToolCallMixin): +class AnthropicChat(RemoteChatCompletionModel, SupportToolCall): model_type: ClassVar[str] = 'anthropic' + tools_beta_version: ClassVar[str] = 'tools-2024-05-16' available_models: ClassVar[List[str]] = [ 'claude-2.1', 'claude-2.0', @@ -174,20 +164,17 @@ def __init__( settings: AnthropicSettings | None = None, http_client: HttpClient | None = None, message_converter: AnthropicMessageConverter | None = None, - cost_calculator: CostCalculator | None = None, ) -> None: parameters = parameters or AnthropicChatParameters() settings = settings or AnthropicSettings() # type: ignore http_client = http_client or HttpClient() message_converter = message_converter or AnthropicMessageConverter(http_client) - cost_calculator = cost_calculator or GeneralCostCalculator(AnthropicModelPrice, exchange_rate=7.3) super().__init__( model=model, parameters=parameters, settings=settings, http_client=http_client, message_converter=message_converter, - cost_calculator=cost_calculator, ) @override @@ -219,14 +206,14 @@ def _get_request_parameters( if isinstance(messages[0], SystemMessage): parameters.system = messages[0].content messages = messages[1:] - anthropic_messages = self.message_converter.convert_messages(messages, parameters.tool_choice) + anthropic_messages = self.message_converter.convert_messages(messages) headers = { 'Content-Type': 'application/json', 'anthropic-version': self.settings.api_version, 'x-api-key': self.settings.api_key.get_secret_value(), } if tool_use := bool(parameters.tools): - headers['anthropic-beta'] = 'tools-2024-04-04' + headers['anthropic-beta'] = self.tools_beta_version json_dict = parameters.custom_model_dump() json_dict['model'] = self.model @@ -268,8 +255,6 @@ def _process_stream_response( stream_manager.delta = '' stream_manager.finish_reason = self._parse_finish_reason(delta_dict) stream_manager.usage.output_tokens = response['usage']['output_tokens'] - if stream_manager.usage.input_tokens is not None: - stream_manager.usage.cost = self.cost(stream_manager.usage.input_tokens, stream_manager.usage.output_tokens) return stream_manager.build_stream_output() stream_manager.delta = response['delta']['text'] @@ -311,8 +296,7 @@ def _parse_usage(self, response: dict[str, Any]) -> Usage: input_tokens = response['usage']['input_tokens'] output_tokens = response['usage']['output_tokens'] - cost = self.cost(input_tokens, output_tokens) - return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost) + return Usage(input_tokens=input_tokens, output_tokens=output_tokens) def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: finish_reason_mapping = { diff --git a/generate/chat_completion/models/azure.py b/generate/chat_completion/models/azure.py index abd2bad..e406d4e 100644 --- a/generate/chat_completion/models/azure.py +++ b/generate/chat_completion/models/azure.py @@ -6,9 +6,13 @@ from generate.chat_completion.message import Prompt from generate.chat_completion.message.core import Messages -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.models.openai import OpenAIChatParameters, OpenAIChatParametersDict -from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAIMessageConverter +from generate.chat_completion.model_output import ChatCompletionStreamOutput +from generate.chat_completion.models.openai_like import ( + OpenAIChatParameters, + 
OpenAIChatParametersDict, + OpenAILikeChat, + OpenAIMessageConverter, +) from generate.chat_completion.stream_manager import StreamManager from generate.http import HttpClient, HttpxPostKwargs from generate.platforms.azure import AzureSettings @@ -40,14 +44,6 @@ def __init__( model=model, parameters=parameters, settings=settings, http_client=http_client, message_converter=message_converter ) - @override - def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) - - @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) - @override def stream_generate( self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict] diff --git a/generate/chat_completion/models/baichuan.py b/generate/chat_completion/models/baichuan.py index fab53df..45f72c6 100644 --- a/generate/chat_completion/models/baichuan.py +++ b/generate/chat_completion/models/baichuan.py @@ -1,34 +1,54 @@ from __future__ import annotations -from typing import Any, AsyncIterator, ClassVar, Iterator, List, Optional +from typing import Any, AsyncIterator, ClassVar, Iterator, List, Literal, Optional from pydantic import Field -from typing_extensions import Annotated, Unpack, override +from typing_extensions import Annotated, NotRequired, TypedDict, Unpack, override from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import ( AssistantMessage, Messages, Prompt, SystemMessage, + ToolMessage, UserMessage, ) -from generate.chat_completion.message.converter import SimpleMessageConverter +from generate.chat_completion.message.converter import MessageConverter, SimpleMessageConverter +from generate.chat_completion.message.core import FunctionCall, FunctionMessage, ToolCall, UserMultiPartMessage +from generate.chat_completion.message.exception import MessageTypeError from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage +from generate.chat_completion.models.openai_like import SupportOpenAIToolCall from generate.chat_completion.stream_manager import StreamManager +from generate.chat_completion.tool import FunctionJsonSchema from generate.http import ( HttpClient, HttpxPostKwargs, ) from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms.baichuan import BaichuanSettings -from generate.types import ModelPrice, Probability, Temperature +from generate.types import Probability, Temperature -BaichuanModelPrice: ModelPrice = { - 'Baichuan2-Turbo-192k': (16, 16), - 'Baichuan2-Turbo': (8, 8), -} + +class BaichuanResponseFormat(TypedDict): + type: Literal['json_object'] + + +class BaichuanRetrieval(TypedDict): + kb_ids: List[str] + answer_model: NotRequired[str] + + +class BaichuanWebSearch(TypedDict): + enable: NotRequired[bool] + search_mode: NotRequired[str] + + +class BaichuanTool(TypedDict): + type: Literal['retrieval', 'web_search', 'function'] + retrieval: NotRequired[BaichuanRetrieval] + web_search: NotRequired[BaichuanWebSearch] + function: NotRequired[FunctionJsonSchema] class BaichuanChatParameters(ModelParameters): @@ -36,7 +56,9 @@ class BaichuanChatParameters(ModelParameters): top_k: Optional[Annotated[int, Field(ge=0)]] = None top_p: Optional[Probability] = None max_tokens: 
Optional[Annotated[int, Field(ge=0)]] = None - search: Optional[bool] = Field(default=None, alias='with_search_enhance') + response_format: Optional[BaichuanResponseFormat] = None + tools: Optional[List[BaichuanTool]] = None + tool_choice: Optional[str] = None class BaichuanChatParametersDict(RemoteModelParametersDict, total=False): @@ -44,12 +66,71 @@ class BaichuanChatParametersDict(RemoteModelParametersDict, total=False): top_k: Optional[int] top_p: Optional[Probability] max_tokens: Optional[int] - search: Optional[bool] + response_format: Optional[BaichuanResponseFormat] + tools: Optional[List[BaichuanTool]] + tool_choice: Optional[str] + + +class BaichuanMessageConverter(MessageConverter): + allowed_message_types = [SystemMessage, UserMessage, AssistantMessage, ToolMessage] + + def convert_system_message(self, message: SystemMessage) -> dict[str, Any]: + return { + 'role': 'system', + 'content': message.content, + } + + def convert_user_message(self, message: UserMessage) -> dict[str, Any]: + return { + 'role': 'user', + 'content': message.content, + } + + def convert_tool_message(self, message: ToolMessage) -> dict[str, Any]: + return { + 'role': 'tool', + 'tool_call_id': message.tool_call_id, + 'content': message.content, + } + + def convert_assistant_message(self, message: AssistantMessage) -> dict[str, Any]: + base_dict = { + 'role': 'assistant', + 'content': message.content or None, + } + if message.tool_calls: + tool_calls = [ + { + 'id': tool_call.id, + 'type': 'function', + 'function': { + 'name': tool_call.function.name, + 'arguments': tool_call.function.arguments, + }, + } + for tool_call in message.tool_calls + ] + base_dict['tool_calls'] = tool_calls + if message.function_call: + raise ValueError('Function calls are not supported in Baichuan') + return base_dict + def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> dict[str, Any]: + raise MessageTypeError(message, allowed_message_type=self.allowed_message_types) -class BaichuanChat(RemoteChatCompletionModel): + def convert_function_message(self, message: FunctionMessage) -> dict[str, Any]: + raise MessageTypeError(message, allowed_message_type=self.allowed_message_types) + + +class BaichuanChat(RemoteChatCompletionModel, SupportOpenAIToolCall): model_type: ClassVar[str] = 'baichuan' - available_models: ClassVar[List[str]] = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k'] + available_models: ClassVar[List[str]] = [ + 'Baichuan2-Turbo', + 'Baichuan2-Turbo-192k', + 'Baichuan3-Turbo', + 'Baichuan3-Turbo-128k', + 'Baichuan4', + ] parameters: BaichuanChatParameters settings: BaichuanSettings @@ -57,25 +138,22 @@ class BaichuanChat(RemoteChatCompletionModel): def __init__( self, - model: str = 'Baichuan2-Turbo', + model: str = 'Baichuan3-Turbo', parameters: BaichuanChatParameters | None = None, settings: BaichuanSettings | None = None, http_client: HttpClient | None = None, - message_converter: SimpleMessageConverter | None = None, - cost_calculator: CostCalculator | None = None, + message_converter: MessageConverter | None = None, ) -> None: parameters = parameters or BaichuanChatParameters() settings = settings or BaichuanSettings() # type: ignore http_client = http_client or HttpClient() - message_converter = message_converter or SimpleMessageConverter() - cost_calculator = cost_calculator or GeneralCostCalculator(BaichuanModelPrice) + message_converter = message_converter or BaichuanMessageConverter() super().__init__( model=model, parameters=parameters, settings=settings, http_client=http_client, 
message_converter=message_converter, - cost_calculator=cost_calculator, ) @override @@ -129,7 +207,7 @@ def _get_request_parameters( def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: return ChatCompletionOutput( model_info=self.model_info, - message=self._parse_assistant_message(response), + message=self._parse_assistant_message(response['choices'][0]['message']), finish_reason=self._parse_finish_reason(response), usage=self._parse_usage(response), extra=self._parse_extra(response), @@ -139,22 +217,35 @@ def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: def _process_stream_response( self, response: dict[str, Any], stream_manager: StreamManager ) -> ChatCompletionStreamOutput | None: - stream_manager.delta = response['choices'][0]['delta']['content'] - stream_manager.finish_reason = self._parse_finish_reason(response) + delta_dict = response['choices'][0].get('delta', {}) + self._update_delta(delta_dict, stream_manager=stream_manager) stream_manager.extra = self._parse_extra(response) stream_manager.usage = self._parse_usage(response) + stream_manager.finish_reason = self._parse_finish_reason(response) return stream_manager.build_stream_output() - def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage: - return AssistantMessage(content=response['choices'][0]['message']['content']) + def _parse_assistant_message(self, message: dict[str, Any]) -> AssistantMessage: + if tool_calls_dict := message.get('tool_calls'): + tool_calls = [ + ToolCall( + id=tool_call['id'], + function=FunctionCall( + name=tool_call['function'].get('name') or '', + arguments=tool_call['function']['arguments'], + ), + ) + for tool_call in tool_calls_dict + ] + else: + tool_calls = None + return AssistantMessage(content=message.get('content') or '', tool_calls=tool_calls) def _parse_usage(self, response: dict[str, Any]) -> Usage: usage = response.get('usage') if usage is not None: input_tokens = usage['prompt_tokens'] output_tokens = usage['completion_tokens'] - cost = self.cost(input_tokens, output_tokens) - return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost) + return Usage(input_tokens=input_tokens, output_tokens=output_tokens) return Usage() def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: @@ -167,3 +258,19 @@ def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: def _parse_extra(self, response: dict[str, Any]) -> dict[str, Any]: return {'response': response} + + def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManager) -> None: + delta_content: str = delta_dict.get('content') or '' + stream_manager.delta = delta_content + + if delta_dict.get('tool_calls'): + index = delta_dict['tool_calls'][0]['index'] + if index >= len(stream_manager.tool_calls or []): + new_tool_calls_message = self._parse_assistant_message(delta_dict).tool_calls + assert new_tool_calls_message is not None + if stream_manager.tool_calls is None: + stream_manager.tool_calls = [] + stream_manager.tool_calls.append(new_tool_calls_message[0]) + else: + assert stream_manager.tool_calls is not None + stream_manager.tool_calls[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments'] diff --git a/generate/chat_completion/models/dashscope.py b/generate/chat_completion/models/dashscope.py index e12d60b..50b684e 100644 --- a/generate/chat_completion/models/dashscope.py +++ b/generate/chat_completion/models/dashscope.py @@ -13,9 +13,9 @@ from 
generate.chat_completion.message.converter import SimpleMessageConverter from generate.chat_completion.message.core import FunctionCall, FunctionMessage, Messages, ToolCall, ToolMessage from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage -from generate.chat_completion.models.openai_like import OpenAITool, convert_to_openai_tool +from generate.chat_completion.models.openai_like import OpenAITool from generate.chat_completion.stream_manager import StreamManager -from generate.chat_completion.tool import Tool, ToolCallMixin +from generate.chat_completion.tool import SupportToolCall, Tool from generate.http import ( HttpClient, HttpxPostKwargs, @@ -86,7 +86,34 @@ def convert_assistant_message(self, message: AssistantMessage) -> Dict[str, Any] return base_dict -class DashScopeChat(RemoteChatCompletionModel, ToolCallMixin): +class DashScopeToolCallMixin(SupportToolCall): + parameters: DashScopeChatParameters + + @override + def process_messages_for_tool_call(self, messages: Messages) -> None: + tool_call_id_to_function_name = {} + new_messages = [] + for message in messages: + if isinstance(message, AssistantMessage) and message.tool_calls: + for tool_call in message.tool_calls: + tool_call_id_to_function_name[tool_call.id] = tool_call.function.name + if isinstance(message, ToolMessage): + message = FunctionMessage( + name=tool_call_id_to_function_name[message.tool_call_id], content=message.content or '' + ) + new_messages.append(message) + messages[:] = new_messages + + @override + def add_tools(self, tools: OrIterable[Tool]) -> None: + new_tools = [OpenAITool(type='function', function=tool.json_schema) for tool in ensure_iterable(tools)] + if self.parameters.tools is None: + self.parameters.tools = new_tools + else: + self.parameters.tools.extend(new_tools) + + +class DashScopeChat(RemoteChatCompletionModel, DashScopeToolCallMixin): model_type: ClassVar[str] = 'dashscope' available_models: ClassVar[List[str]] = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext'] @@ -107,7 +134,11 @@ def __init__( http_client = http_client or HttpClient() message_converter = message_converter or DashScopeMessageConverter() super().__init__( - model=model, parameters=parameters, settings=settings, http_client=http_client, message_converter=message_converter + model=model, + parameters=parameters, + settings=settings, + http_client=http_client, + message_converter=message_converter, ) @override @@ -152,6 +183,7 @@ def _get_request_parameters( 'model': self.model, 'parameters': parameters.custom_model_dump(), } + params['parameters']['result_format'] = 'message' return { 'url': self.settings.api_base + '/services/aigc/text-generation/generation', 'headers': headers, @@ -160,74 +192,47 @@ def _get_request_parameters( @override def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: + choice = response['output']['choices'][0] + message = choice['message'] return ChatCompletionOutput( model_info=self.model_info, - message=self._parse_assistant_message(response), + message=self._parse_assistant_message(message), usage=self._parse_usage(response), extra=self._parse_extra(response), - finish_reason=self._parse_finish_reason(response['choices'][0]), + finish_reason=self._parse_finish_reason(choice), ) @override def _process_stream_response( self, response: dict[str, Any], stream_manager: StreamManager ) -> ChatCompletionStreamOutput | None: - finish_reason = self._parse_finish_reason(response['output']) - reply = 
response['output']['text']
-        stream_manager.usage = self._parse_usage(response)
+        choice = response['output']['choices'][0]
+        delta_dict = choice['message']
+        self._update_delta(delta_dict, stream_manager=stream_manager)
         stream_manager.extra = self._parse_extra(response)
-        if finish_reason != 'null':
-            stream_manager.finish_reason = finish_reason
-            stream_manager.delta = ''
-            return stream_manager.build_stream_output()
-        stream_manager.delta = reply[len(stream_manager.content) :]
+        stream_manager.usage = self._parse_usage(response)
+        if choice['finish_reason'] != 'null':
+            stream_manager.finish_reason = self._parse_finish_reason(choice)
         return stream_manager.build_stream_output()
 
-    @override
-    def cost(self, input_tokens: int, output_tokens: int) -> Optional[float]:
-        total_tokens = input_tokens + output_tokens
-        model_price = {
-            'qwen-turbo': 8,
-            'qwen-plus': 20,
-            'qwen-max': 120,
-        }
-        for model_name, price in model_price.items():
-            if model_name in self.model:
-                return total_tokens * price / 1_000_000
-        return None
-
-    @override
-    def adapt_tool_calls(self, messages: Messages) -> None:
-        tool_call_id_to_function_name = {}
-        new_messages = []
-        for message in messages:
-            if isinstance(message, AssistantMessage) and message.tool_calls:
-                for tool_call in message.tool_calls:
-                    tool_call_id_to_function_name[tool_call.id] = tool_call.function.name
-            if isinstance(message, ToolMessage):
-                message = FunctionMessage(
-                    name=tool_call_id_to_function_name[message.tool_call_id], content=message.content or ''
-                )
-            new_messages.append(message)
-        messages[:] = new_messages
-
-    @override
-    def add_tools(self, tools: OrIterable[Tool]) -> None:
-        new_tools = [convert_to_openai_tool(tool) for tool in ensure_iterable(tools)]
-        if self.parameters.tools is None:
-            self.parameters.tools = new_tools
-        else:
-            self.parameters.tools.extend(new_tools)
 
     def _parse_usage(self, response: dict[str, Any]) -> Usage:
         if usage := response.get('usage'):
             input_tokens = usage.get('input_tokens')
             output_tokens = usage.get('output_tokens')
-            return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=self.cost(input_tokens, output_tokens))
+            return Usage(input_tokens=input_tokens, output_tokens=output_tokens)
         return Usage()
 
-    def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage:
-        message = response['choices'][0]['message']
+    def _parse_assistant_message(self, message: dict[str, Any]) -> AssistantMessage:
         if tool_calls_dict := message.get('tool_calls'):
             tool_calls = [
                 ToolCall(
@@ -255,3 +260,19 @@ def _parse_finish_reason(self, choice: dict[str, Any]) -> FinishReason | None:
             return FinishReason(finish_reason)
         except (KeyError, IndexError, ValueError):
             return None
+
+    def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManager) -> None:
+        delta_content: str = delta_dict.get('content') or ''
+        stream_manager.delta = delta_content[len(stream_manager.content) :]
+
+        if delta_dict.get('tool_calls'):
+            index = delta_dict['tool_calls'][0]['index']
+            if index >= len(stream_manager.tool_calls or []):
+                new_tool_calls_message
= self._parse_assistant_message(delta_dict).tool_calls + assert new_tool_calls_message is not None + if stream_manager.tool_calls is None: + stream_manager.tool_calls = [] + stream_manager.tool_calls.append(new_tool_calls_message[0]) + else: + assert stream_manager.tool_calls is not None + stream_manager.tool_calls[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments'] diff --git a/generate/chat_completion/models/dashscope_multimodal.py b/generate/chat_completion/models/dashscope_multimodal.py deleted file mode 100644 index b15c636..0000000 --- a/generate/chat_completion/models/dashscope_multimodal.py +++ /dev/null @@ -1,232 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -from io import BytesIO -from typing import AsyncIterator, ClassVar, Iterator, List, Literal, Optional - -from pydantic import Field, PositiveInt -from typing_extensions import Annotated, TypedDict, Unpack, override - -from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.message import ( - AssistantMessage, - ImagePart, - ImageUrlPart, - Message, - MessageTypeError, - Prompt, - SystemMessage, - TextPart, - UserMessage, - UserMultiPartMessage, - ensure_messages, -) -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.stream_manager import StreamManager -from generate.http import ( - HttpClient, - HttpGetKwargs, - HttpxPostKwargs, - ResponseValue, -) -from generate.model import ModelParameters, RemoteModelParametersDict -from generate.platforms.dashscope import DashScopeSettings -from generate.types import Probability - - -class DashScopeMultiModalChatParameters(ModelParameters): - seed: Optional[PositiveInt] = None - top_p: Optional[Probability] = Field(default=None, alias='TopP') - top_k: Optional[Annotated[int, Field(ge=0, le=100)]] = None - - -class DashScopeMultiModalChatParametersDict(RemoteModelParametersDict, total=False): - seed: Optional[PositiveInt] - top_p: Optional[Probability] - top_k: Optional[int] - - -class DashScopeMultiModalMessage(TypedDict): - role: str - content: list[dict[Literal['image', 'text'], str]] - - -class DashScopeMultiModalChat(RemoteChatCompletionModel): - model_type: ClassVar[str] = 'dashscope_multimodal' - available_models: ClassVar[List[str]] = ['qwen-vl-max', 'qwen-vl-plus'] - - parameters: DashScopeMultiModalChatParameters - settings: DashScopeSettings - - def __init__( - self, - model: str = 'qwen-vl-max', - parameters: DashScopeMultiModalChatParameters | None = None, - settings: DashScopeSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or DashScopeMultiModalChatParameters() - settings = settings or DashScopeSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) - - def upload_image(self, image: bytes, image_format: str) -> str: - get_kwargs = HttpGetKwargs( - url=f'{self.settings.api_base}/uploads', - params={'action': 'getPolicy', 'model': self.model}, - headers={'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}'}, - ) - response_data = self.http_client.get(get_kwargs) - upload_info = response_data.json()['data'] - - form_data = {} - form_data['OSSAccessKeyId'] = upload_info['oss_access_key_id'] - form_data['Signature'] = upload_info['signature'] - form_data['policy'] = upload_info['policy'] - hash_code = 
hashlib.md5(image).hexdigest() - form_data['key'] = upload_info['upload_dir'] + '/' + f'{hash_code}.{image_format}' - form_data['x-oss-object-acl'] = upload_info['x_oss_object_acl'] - form_data['x-oss-forbid-overwrite'] = upload_info['x_oss_forbid_overwrite'] - form_data['success_action_status'] = '200' - form_data['x-oss-content-type'] = f'image/{image_format}' - url = upload_info['upload_host'] - files = {'file': BytesIO(image)} - response = self.http_client.client.post( - url, - data=form_data, - files=files, - ) - response.raise_for_status() - return 'oss://' + form_data['key'] - - @override - def generate(self, prompt: Prompt, **kwargs: Unpack[DashScopeMultiModalChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) - - @override - async def async_generate( - self, prompt: Prompt, **kwargs: Unpack[DashScopeMultiModalChatParametersDict] - ) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) - - @override - def stream_generate( - self, prompt: Prompt, **kwargs: Unpack[DashScopeMultiModalChatParametersDict] - ) -> Iterator[ChatCompletionStreamOutput]: - yield from super().stream_generate(prompt, **kwargs) - - @override - async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[DashScopeMultiModalChatParametersDict] - ) -> AsyncIterator[ChatCompletionStreamOutput]: - async for output in super().async_stream_generate(prompt, **kwargs): - yield output - - def _has_oss_file(self, message: DashScopeMultiModalMessage) -> bool: - for part in message['content']: - for k, v in part.items(): - if k == 'image' and v.startswith('oss://'): - return True - return False - - def _convert_message(self, message: Message) -> DashScopeMultiModalMessage: - if isinstance(message, UserMessage): - return {'role': 'user', 'content': [{'text': message.content}]} - if isinstance(message, AssistantMessage): - return {'role': 'assistant', 'content': [{'text': message.content}]} - if isinstance(message, SystemMessage): - return {'role': 'system', 'content': [{'text': message.content}]} - if isinstance(message, UserMultiPartMessage): - content = [] - for part in message.content: - if isinstance(part, TextPart): - content.append({'text': part.text}) - elif isinstance(part, ImageUrlPart): - content.append({'image': part.image_url.url}) - elif isinstance(part, ImagePart): - image_url = self.upload_image(part.image, part.image_format or 'png') - content.append({'image': image_url}) - else: - raise TypeError(f'Unsupported part type: {part}') - return {'role': 'user', 'content': content} - allowed_message_type = (UserMessage, AssistantMessage, SystemMessage, UserMultiPartMessage) - raise MessageTypeError(message, allowed_message_type=allowed_message_type) - - @override - def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[DashScopeMultiModalChatParametersDict] - ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - dashscope_messages = [self._convert_message(message) for message in messages] - headers = { - 'Authorization': self.settings.api_key.get_secret_value(), - 'Content-Type': 'application/json', - } - if any(self._has_oss_file(message) for message in dashscope_messages): - headers['X-DashScope-OssResourceResolve'] = 'enable' - params = { - 'input': { - 'messages': dashscope_messages, - }, - 'model': self.model, - 'parameters': parameters.custom_model_dump(), - } - if stream: - headers['Accept'] = 'text/event-stream' - - return { - 
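            # Illustrative shape of the HttpxPostKwargs assembled here (dummy values):
            # {'url': '<api_base>/services/aigc/multimodal-generation/generation',
            #  'headers': {'Authorization': '<api-key>', 'Content-Type': 'application/json'},
            #  'json': {'input': {'messages': [...]}, 'model': 'qwen-vl-max', 'parameters': {...}}}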
'url': f'{self.settings.api_base}/services/aigc/multimodal-generation/generation', - 'headers': headers, - 'json': params, - } - - @override - def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - choice = response['output']['choices'][0] - content_list = choice['message']['content'] - text = '' - result_images = [] - for content in content_list: - for k, v in content.items(): - if k != 'result_image': - text += v - else: - result_images.append(v) - return ChatCompletionOutput( - model_info=self.model_info, - finish_reason=choice.get('finish_reason'), - message=AssistantMessage(content=text), - cost=None, - extra={ - 'usage': response['usage'], - 'request_id': response['request_id'], - 'content': content_list, - 'result_images': result_images, - }, - ) - - @override - def _process_stream_response(self, line: str, stream_manager: StreamManager) -> Optional[ChatCompletionStreamOutput]: - try: - data = json.loads(line) - except json.JSONDecodeError: - return None - - choice = data['output']['choices'][0] - finish_reason = choice['finish_reason'] - reply = choice['message']['content'][0]['text'] - usage = data['usage'] - request_id = data['request_id'] - extra = { - 'usage': usage, - 'response_id': request_id, - } - if finish_reason == 'stop': - stream_manager.finish_reason = 'stop' - stream_manager.delta = '' - stream_manager.extra.update(extra) - return stream_manager.build_stream_output() - stream_manager.delta = reply[len(stream_manager.content) :] - return stream_manager.build_stream_output() diff --git a/generate/chat_completion/models/deepseek.py b/generate/chat_completion/models/deepseek.py index 99f277f..05b1458 100644 --- a/generate/chat_completion/models/deepseek.py +++ b/generate/chat_completion/models/deepseek.py @@ -5,19 +5,14 @@ from pydantic import Field, PositiveInt from typing_extensions import Annotated, Unpack, override -from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import Prompt +from generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAIMessageConverter +from generate.chat_completion.models.openai_like import OpenAILikeChat from generate.http import HttpClient from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms import DeepSeekSettings -from generate.types import ModelPrice, Probability - -DeepSeekModelPrice: ModelPrice = { - 'deepseek-chat': (1, 2), - 'deepseek-coder': (1, 2), -} +from generate.types import Probability class DeepSeekChatParameters(ModelParameters): @@ -48,7 +43,6 @@ class DeepSeekChat(OpenAILikeChat): parameters: DeepSeekChatParameters settings: DeepSeekSettings - message_converter: OpenAIMessageConverter def __init__( self, @@ -56,19 +50,16 @@ def __init__( parameters: DeepSeekChatParameters | None = None, settings: DeepSeekSettings | None = None, http_client: HttpClient | None = None, - message_converter: OpenAIMessageConverter | None = None, - cost_calculator: CostCalculator | None = None, + message_converter: MessageConverter | None = None, ) -> None: parameters = parameters or DeepSeekChatParameters() settings = settings or DeepSeekSettings() # type: ignore - cost_calculator = cost_calculator or GeneralCostCalculator(DeepSeekModelPrice) super().__init__( model=model, parameters=parameters, settings=settings, - 
http_client=http_client, message_converter=message_converter, - cost_calculator=cost_calculator, + http_client=http_client, ) @override diff --git a/generate/chat_completion/models/hunyuan.py b/generate/chat_completion/models/hunyuan.py deleted file mode 100644 index 43da0de..0000000 --- a/generate/chat_completion/models/hunyuan.py +++ /dev/null @@ -1,216 +0,0 @@ -from __future__ import annotations - -import base64 -import hashlib -import hmac -import json -import time -import uuid -from typing import Any, AsyncIterator, ClassVar, Iterator, Literal, Optional - -from typing_extensions import Self, TypedDict, Unpack, override - -from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.message import ( - AssistantMessage, - Message, - Messages, - MessageTypeError, - Prompt, - SystemMessage, - UserMessage, - ensure_messages, -) -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.stream_manager import StreamManager -from generate.http import ( - HttpClient, - HttpxPostKwargs, - ResponseValue, - UnexpectedResponseError, -) -from generate.model import ModelParameters, RemoteModelParametersDict -from generate.platforms.hunyuan import HunyuanSettings -from generate.types import Probability, Temperature - - -class HunyuanMessage(TypedDict): - role: Literal['user', 'assistant'] - content: str - - -class HunyuanChatParameters(ModelParameters): - temperature: Optional[Temperature] = None - top_p: Optional[Probability] = None - - -class HunyuanChatParametersDict(RemoteModelParametersDict, total=False): - temperature: Optional[Temperature] - top_p: Optional[Probability] - - -def _convert_message_to_hunyuan_message(message: Message) -> HunyuanMessage: - if isinstance(message, UserMessage): - return {'role': 'user', 'content': message.content} - if isinstance(message, AssistantMessage): - return {'role': 'assistant', 'content': message.content} - raise MessageTypeError(message, (UserMessage, AssistantMessage)) - - -def _convert_messages(messages: Messages) -> list[HunyuanMessage]: - if isinstance(system_message := messages[0], SystemMessage): - prepend_messages = [UserMessage(content=system_message.content), AssistantMessage(content='好的')] - messages = prepend_messages + messages[1:] - return [_convert_message_to_hunyuan_message(message) for message in messages] - - -class HunyuanChat(RemoteChatCompletionModel): - model_type: ClassVar[str] = 'hunyuan' - - parameters: HunyuanChatParameters - settings: HunyuanSettings - - def __init__( - self, - parameters: HunyuanChatParameters | None = None, - settings: HunyuanSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or HunyuanChatParameters() - settings = settings or HunyuanSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(model='', parameters=parameters, settings=settings, http_client=http_client) - - @override - def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) - - @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) - - @override - def stream_generate( - self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict] - ) -> Iterator[ChatCompletionStreamOutput]: - yield from super().stream_generate(prompt, **kwargs) - - 
@override - async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict] - ) -> AsyncIterator[ChatCompletionStreamOutput]: - async for output in super().async_stream_generate(prompt, **kwargs): - yield output - - @override - def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - if response.get('error'): - raise UnexpectedResponseError(response) - return ChatCompletionOutput( - model_info=self.model_info, - message=AssistantMessage(content=response['choices'][0]['messages']['content']), - finish_reason=response['choices'][0]['finish_reason'], - cost=self.calculate_cost(response['usage']), - extra={'usage': response['usage']}, - ) - - @override - def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[HunyuanChatParametersDict] - ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - hunyuan_messages = _convert_messages(messages) - if stream: - json_dict = self.generate_json_dict(hunyuan_messages, parameters, stream=True) - else: - json_dict = self.generate_json_dict(hunyuan_messages, parameters) - signature = self.generate_signature(self.generate_sign_parameters(json_dict)) - headers = { - 'Content-Type': 'application/json', - 'Authorization': signature, - } - return { - 'url': self.settings.completion_api, - 'headers': headers, - 'json': json_dict, - } - - @override - def _process_stream_response(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - try: - data = json.loads(line) - except json.JSONDecodeError: - return None - - message_dict = data['choices'][0] - stream_manager.delta = message_dict['delta']['content'] - if message_dict['finish_reason']: - stream_manager.finish_reason = message_dict['finish_reason'] - stream_manager.cost = self.calculate_cost(data['usage']) - stream_manager.extra.update(usage=data['usage']) - return stream_manager.build_stream_output() - - def generate_json_dict( - self, messages: list[HunyuanMessage], parameters: HunyuanChatParameters, stream: bool = False - ) -> dict[str, Any]: - timestamp = int(time.time()) + 10000 - json_dict = { - 'app_id': self.settings.app_id, - 'secret_id': self.settings.secret_id.get_secret_value(), - 'query_id': 'query_id_' + str(uuid.uuid4()), - 'messages': messages, - 'timestamp': timestamp, - 'expired': timestamp + 24 * 60 * 60, - 'stream': int(stream), - } - json_dict.update(parameters.custom_model_dump()) - return json_dict - - @staticmethod - def generate_sign_parameters(json_dict: dict[str, Any]) -> dict[str, Any]: - params = { - 'app_id': json_dict['app_id'], - 'secret_id': json_dict['secret_id'], - 'query_id': json_dict['query_id'], - 'stream': json_dict['stream'], - } - if 'temperature' in json_dict: - params['temperature'] = f'{json_dict["temperature"]:g}' - if 'top_p' in json_dict: - params['top_p'] = f'{json_dict["top_p"]:g}' - message_str = ','.join( - ['{{"role":"{}","content":"{}"}}'.format(message['role'], message['content']) for message in json_dict['messages']] - ) - message_str = '[{}]'.format(message_str) - params['messages'] = message_str - params['timestamp'] = str(json_dict['timestamp']) - params['expired'] = str(json_dict['expired']) - return params - - def generate_signature(self, sign_parameters: dict[str, Any]) -> str: - sort_dict = sorted(sign_parameters.keys()) - sign_str = self.settings.sign_api + '?' 
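# A minimal, self-contained sketch (hypothetical helper, dummy values) of the
# legacy signing scheme this deleted client used: concatenate the sign
# parameters in sorted key order onto the sign URL, HMAC-SHA1 the string with
# the secret key, and base64-encode the digest.
import base64
import hashlib
import hmac


def sign_request(sign_api: str, sign_parameters: dict[str, str], secret_key: str) -> str:
    # Mirrors the sorted-key loop being removed below.
    sign_str = sign_api + '?' + '&'.join(f'{key}={sign_parameters[key]}' for key in sorted(sign_parameters))
    digest = hmac.new(secret_key.encode('utf-8'), sign_str.encode('utf-8'), hashlib.sha1).digest()
    return base64.b64encode(digest).decode('utf-8')


# e.g. sign_request('<sign-api-url>', {'app_id': '1', 'stream': '0'}, '<secret>')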
- for key in sort_dict: - sign_str = sign_str + key + '=' + str(sign_parameters[key]) + '&' - sign_str = sign_str[:-1] - hmacstr = hmac.new( - self.settings.secret_key.get_secret_value().encode('utf-8'), sign_str.encode('utf-8'), hashlib.sha1 - ).digest() - signature = base64.b64encode(hmacstr) - return signature.decode('utf-8') - - def calculate_cost(self, usage: dict[str, Any]) -> float: - return (usage['total_tokens'] / 1000) * 0.1 - - @property - @override - def name(self) -> str: - return 'v1' - - @classmethod - @override - def from_name(cls, name: str) -> Self: - if name != 'v1': - raise ValueError('Unknown name: {}, only support v1'.format(name)) - return cls() diff --git a/generate/chat_completion/models/minimax.py b/generate/chat_completion/models/minimax.py index 32c466a..2b27b85 100644 --- a/generate/chat_completion/models/minimax.py +++ b/generate/chat_completion/models/minimax.py @@ -9,9 +9,10 @@ from generate.chat_completion.message import ( Prompt, ) -from generate.chat_completion.message.core import AssistantMessage, FunctionMessage, Messages, ToolCall, ToolMessage -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAIMessage, OpenAITool +from generate.chat_completion.message.converter import MessageConverter +from generate.chat_completion.message.core import Messages +from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage +from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAITool from generate.http import ( HttpClient, HttpxPostKwargs, @@ -46,7 +47,7 @@ class MinimaxChatParametersDict(RemoteModelParametersDict, total=False): class MinimaxChat(OpenAILikeChat): model_type: ClassVar[str] = 'minimax' - available_models: ClassVar[List[str]] = ['abab5.5-chat', 'abab5.5s-chat', 'abab6-chat'] + available_models: ClassVar[List[str]] = ['abab5.5-chat', 'abab5.5s-chat', 'abab6-chat', 'abab6.5-chat'] CHAT_COMPLETION_ENDPOINT: ClassVar[str] = '/text/chatcompletion_v2' parameters: MinimaxChatParameters @@ -58,15 +59,22 @@ def __init__( parameters: MinimaxChatParameters | None = None, settings: MinimaxSettings | None = None, http_client: HttpClient | None = None, + message_converter: MessageConverter | None = None, ) -> None: parameters = parameters or MinimaxChatParameters() settings = settings or MinimaxSettings() # type: ignore http_client = http_client or HttpClient() - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) + super().__init__( + model=model, + parameters=parameters, + settings=settings, + http_client=http_client, + message_converter=message_converter, + ) @override - def _get_request_parameters(self, prompt: Prompt, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs: - http_kwargs = super()._get_request_parameters(prompt, stream, **kwargs) + def _get_request_parameters(self, messages: Messages, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs: + http_kwargs = super()._get_request_parameters(messages, stream, **kwargs) http_kwargs['url'] = self.settings.api_base + self.CHAT_COMPLETION_ENDPOINT if 'tools' in http_kwargs['json']: # Serialize jsonschema dict to JSON string for Minimax compatibility @@ -78,55 +86,35 @@ def _get_request_parameters(self, prompt: Prompt, stream: bool = False, **kwargs return http_kwargs @override - def _determine_finish_reason(self, response: Dict[str, Any]) -> str | None: + 
def _parse_finish_reason(self, response: Dict[str, Any]) -> FinishReason | None: choice = response['choices'][0] if 'finish_reason' in choice and 'delta' not in choice: - return choice['finish_reason'] + return FinishReason(choice['finish_reason']) return None @override - def _convert_to_openai_messages(self, messages: Messages) -> List[OpenAIMessage]: - converted_messages = [] - temp_tool_call_id = self.generate_tool_call_id() - for message in messages: - # Convert FunctionMessage to ToolMessage with self-generated tool_call_id - if isinstance(message, AssistantMessage): - if message.function_call is not None: - tool_call = ToolCall( - id=temp_tool_call_id, - function=message.function_call, - ) - message.tool_calls = [tool_call] - message.function_call = None - elif isinstance(message, FunctionMessage): - tool_message = ToolMessage( - name=message.name, - content=message.content, - tool_call_id=temp_tool_call_id, - ) - temp_tool_call_id = self.generate_tool_call_id() - converted_messages.append(tool_message) - continue - converted_messages.append(message.model_copy(deep=True)) - return super()._convert_to_openai_messages(converted_messages) + def _parse_usage(self, response: dict[str, Any]) -> Usage: + if usage := response.get('usage'): + return Usage(input_tokens=0, output_tokens=usage['total_tokens']) + return Usage() @override def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) + return super().generate(prompt, **kwargs) # type: ignore @override async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) + return await super().async_generate(prompt, **kwargs) # type: ignore @override def stream_generate( self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict] ) -> Iterator[ChatCompletionStreamOutput]: - yield from super().stream_generate(prompt, **kwargs) + yield from super().stream_generate(prompt, **kwargs) # type: ignore @override async def async_stream_generate( self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict] ) -> AsyncIterator[ChatCompletionStreamOutput]: - async for stream_output in super().async_stream_generate(prompt, **kwargs): + async for stream_output in super().async_stream_generate(prompt, **kwargs): # type: ignore yield stream_output diff --git a/generate/chat_completion/models/minimax_pro.py b/generate/chat_completion/models/minimax_pro.py deleted file mode 100644 index cb2273f..0000000 --- a/generate/chat_completion/models/minimax_pro.py +++ /dev/null @@ -1,381 +0,0 @@ -from __future__ import annotations - -import json -from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional, cast - -from pydantic import Field, PositiveInt, model_validator -from typing_extensions import Annotated, NotRequired, Self, TypedDict, Unpack, override - -from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.message import ( - AssistantMessage, - FunctionCall, - FunctionMessage, - Message, - MessageTypeError, - MessageValueError, - Prompt, - SystemMessage, - UserMessage, - ensure_messages, -) -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, Stream -from generate.chat_completion.stream_manager import StreamManager -from generate.chat_completion.tool import FunctionJsonSchema, Tool, ToolCallMixin -from generate.http import ( - 
HttpClient, - HttpxPostKwargs, - ResponseValue, - UnexpectedResponseError, -) -from generate.model import ModelInfo, ModelParameters, RemoteModelParametersDict -from generate.platforms.minimax import MinimaxSettings -from generate.types import OrIterable, Probability, Temperature -from generate.utils import ensure_iterable - - -class BotSettingDict(TypedDict): - bot_name: str - content: str - - -class GlyphDict(TypedDict): - type: str - raw_glpyh: str - json_properties: Dict[str, Any] - - -class ReplyConstrainsDict(TypedDict): - sender_type: str - sender_name: str - glyph: NotRequired[GlyphDict] - - -class MinimaxFunctionCall(TypedDict): - name: str - arguments: str - - -class MinimaxProMessage(TypedDict): - sender_type: Literal['USER', 'BOT', 'FUNCTION'] - sender_name: str - text: str - function_call: NotRequired[MinimaxFunctionCall] - - -class MinimaxProChatParameters(ModelParameters): - reply_constraints: ReplyConstrainsDict = {'sender_type': 'BOT', 'sender_name': 'MM智能助理'} - bot_setting: List[BotSettingDict] = [ - { - 'bot_name': 'MM智能助理', - 'content': 'MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。', - } - ] - temperature: Optional[Temperature] = None - top_p: Optional[Probability] = None - max_tokens: Annotated[Optional[PositiveInt], Field(serialization_alias='tokens_to_generate')] = None - mask_sensitive_info: Optional[bool] = None - sample_messages: Optional[List[MinimaxProMessage]] = None - functions: Optional[List[FunctionJsonSchema]] = None - search: Optional[bool] = None - plugins: Optional[List[str]] = None - - @model_validator(mode='after') - def check_bot_name(self) -> Self: - names: set[str] = {bot_setting['bot_name'] for bot_setting in self.bot_setting} - if (sender_name := self.reply_constraints['sender_name']) not in names: - raise ValueError(f'reply_constraints sender_name {sender_name} must be in bot_setting names: {names}') - return self - - @property - def bot_name(self) -> str | None: - if len(self.bot_setting) == 1: - return self.bot_setting[0]['bot_name'] - return None - - def set_system_prompt(self, system_prompt: str) -> None: - if len(self.bot_setting) == 1: - self.bot_setting[0]['content'] = system_prompt - else: - raise ValueError('set system_prompt is not supported when bot_setting has more than one bot') - - def set_bot_name(self, bot_name: str) -> None: - if len(self.bot_setting) == 1: - self.bot_setting[0]['bot_name'] = bot_name - self.reply_constraints['sender_name'] = bot_name - else: - raise ValueError('set bot_name is not supported when bot_setting has more than one bot') - - def custom_model_dump(self) -> dict[str, Any]: - output = super().custom_model_dump() - if 'temperature' in output: - output['temperature'] = max(0.01, output['temperature']) - if 'top_p' in output: - output['top_p'] = max(0.01, output['top_p']) - if 'search' in output: - original_plugins = output.get('plugins', []) - output['plugins'] = list(set(original_plugins + ['plugin_web_search'])) - return output - - -class MinimaxProChatParametersDict(RemoteModelParametersDict, total=False): - reply_constraints: ReplyConstrainsDict - bot_setting: List[BotSettingDict] - temperature: Optional[Temperature] - top_p: Optional[Probability] - max_tokens: Optional[PositiveInt] - mask_sensitive_info: Optional[bool] - sample_messages: Optional[List[MinimaxProMessage]] - functions: Optional[List[FunctionJsonSchema]] - search: Optional[bool] - plugins: Optional[List[str]] - - -def _convert_to_minimax_pro_message( - message: Message, default_bot_name: str | None = None, 
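    # Illustrative result (dummy content): a UserMessage with content 'hi' maps to
    # {'sender_type': 'USER', 'sender_name': '用户', 'text': 'hi'}; an AssistantMessage
    # carrying a function_call additionally gets
    # {'function_call': {'name': ..., 'arguments': ...}}, as the branches below show.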
default_user_name: str = '用户' -) -> MinimaxProMessage: - if isinstance(message, UserMessage): - sender_name = message.name or default_user_name - return {'sender_type': 'USER', 'sender_name': sender_name, 'text': message.content} - - if isinstance(message, AssistantMessage): - sender_name = message.name or default_bot_name - if sender_name is None: - raise MessageValueError(message, 'bot name is required') - if message.function_call is None: - return { - 'sender_type': 'BOT', - 'sender_name': sender_name, - 'text': message.content, - } - return { - 'sender_type': 'BOT', - 'sender_name': sender_name, - 'text': message.content, - 'function_call': { - 'name': message.function_call.name, - 'arguments': message.function_call.arguments, - }, - } - - if isinstance(message, FunctionMessage): - return { - 'sender_type': 'FUNCTION', - 'sender_name': message.name, - 'text': message.content, - } - - raise MessageTypeError(message, allowed_message_type=(UserMessage, AssistantMessage, FunctionMessage)) - - -def _convert_to_message(message: MinimaxProMessage) -> AssistantMessage | FunctionMessage: - if 'function_call' in message: - return AssistantMessage( - name=message['sender_name'], - content=message['text'], - function_call=FunctionCall(name=message['function_call']['name'], arguments=message['function_call']['arguments']), - ) - if message['sender_type'] == 'BOT': - return AssistantMessage( - name=message['sender_name'], - content=message['text'], - ) - if message['sender_type'] == 'FUNCTION': - return FunctionMessage( - name=message['sender_name'], - content=message['text'], - ) - raise ValueError(f'unknown sender_type: {message["sender_type"]}') - - -class _StreamResponseProcessor: - def __init__(self, model_info: ModelInfo) -> None: - self.message: AssistantMessage | None = None - self.model_info = model_info - - def process(self, response: ResponseValue) -> ChatCompletionStreamOutput: - if response.get('usage'): - assert self.message is not None - return ChatCompletionStreamOutput( - model_info=self.model_info, - message=self.message, - finish_reason=response['choices'][0]['finish_reason'], - cost=minimax_calculate_cost(model_name=self.model_info.name, usage=response['usage']), - extra={ - 'input_sensitive': response['input_sensitive'], - 'output_sensitive': response['output_sensitive'], - 'usage': response['usage'], - }, - stream=Stream(delta='', control='finish'), - ) - - if self.message is None: - self.message = self.initial_message(response) - delta = self.message.content if isinstance(self.message, AssistantMessage) else '' - control = 'start' - else: - delta = self.update_existing_message(response) - control = 'continue' - - return ChatCompletionStreamOutput( - model_info=self.model_info, - message=self.message, - finish_reason=None, - stream=Stream(delta=delta, control=control), - ) - - def initial_message(self, response: ResponseValue) -> AssistantMessage: - output_messages = [_convert_to_message(i) for i in response['choices'][0]['messages']] - message = output_messages[-1] - return cast(AssistantMessage, message) - - def update_existing_message(self, response: ResponseValue) -> str: - output_messages = [_convert_to_message(i) for i in response['choices'][0]['messages']] - message = output_messages[-1] - if not isinstance(message, AssistantMessage): - return '' - - if message.function_call is not None: - delta = '' - self.message = message - return delta - - delta = message.content - self.message.content += message.content # type: ignore - return delta - - -def 
minimax_calculate_cost(model_name: str, usage: dict[str, int], num_web_search: int = 0) -> float | None: - if model_name == 'abab6-chat': - model_cost = 0.1 * (usage['total_tokens'] / 1000) - elif model_name == 'abab5.5-chat': - model_cost = 0.015 * (usage['total_tokens'] / 1000) - elif model_name == 'abab5.5s-chat': - model_cost = 0.005 * (usage['total_tokens'] / 1000) - else: - return None - return model_cost + (0.03 * num_web_search) - - -class MinimaxProChat(RemoteChatCompletionModel, ToolCallMixin): - model_type: ClassVar[str] = 'minimax_pro' - available_models: ClassVar[List[str]] = ['abab5.5-chat', 'abab5.5s-chat', 'abab6-chat'] - - parameters: MinimaxProChatParameters - settings: MinimaxSettings - - def __init__( - self, - model: str = 'abab5.5-chat', - parameters: MinimaxProChatParameters | None = None, - settings: MinimaxSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or MinimaxProChatParameters() - settings = settings or MinimaxSettings() # type: ignore - http_client = http_client or HttpClient() - if not settings.group_id: - raise ValueError( - 'group_id is required for MinimaxProChat, you can set it in settings or environment variable MINIMAX_GROUP_ID' - ) - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) - - self.default_user_name = '用户' - - @override - def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) - - @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) - - @override - def stream_generate( - self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict] - ) -> Iterator[ChatCompletionStreamOutput]: - request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs) - stream_processor = _StreamResponseProcessor(model_info=self.model_info) - for line in self.http_client.stream_post(request_parameters=request_parameters): - yield stream_processor.process(json.loads(line)) - - @override - async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict] - ) -> AsyncIterator[ChatCompletionStreamOutput]: - request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs) - stream_processor = _StreamResponseProcessor(model_info=self.model_info) - async for line in self.http_client.async_stream_post(request_parameters=request_parameters): - yield stream_processor.process(json.loads(line)) - - def add_tools(self, tools: OrIterable[Tool]) -> None: - new_functions = [tool.json_schema for tool in ensure_iterable(tools)] - if self.parameters.functions is None: - self.parameters.functions = new_functions - else: - self.parameters.functions.extend(new_functions) - - @override - def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - try: - messages: list[AssistantMessage | FunctionMessage] = [ - _convert_to_message(i) for i in response['choices'][0]['messages'] - ] - message = cast(AssistantMessage, messages[-1]) - finish_reason = response['choices'][0]['finish_reason'] - num_web_search = sum([1 for i in response['choices'][0]['messages'] if i['sender_name'] == 'plugin_web_search']) - return ChatCompletionOutput( - model_info=self.model_info, - message=message, - finish_reason=finish_reason, - cost=minimax_calculate_cost(model_name=self.name, 
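            # Worked example of the pricing helper above (dummy numbers): for
            # 'abab6-chat' with usage={'total_tokens': 2000} and one web-search call,
            # it returns 0.1 * (2000 / 1000) + 0.03 * 1 = 0.23.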
usage=response['usage'], num_web_search=num_web_search), - extra={ - 'input_sensitive': response['input_sensitive'], - 'output_sensitive': response['output_sensitive'], - 'usage': response['usage'], - }, - ) - except (KeyError, IndexError, TypeError) as e: - raise UnexpectedResponseError(response) from e - - @override - def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[MinimaxProChatParametersDict] - ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - if isinstance(messages[0], SystemMessage): - system_message = messages[0] - parameters.set_system_prompt(system_message.content) - messages = messages[1:] - - minimax_pro_messages = [ - _convert_to_minimax_pro_message( - message, default_bot_name=parameters.bot_name, default_user_name=self.default_user_name - ) - for message in messages - ] - json_data = {'model': self.model, 'messages': minimax_pro_messages, **parameters.custom_model_dump()} - if stream: - json_data['stream'] = True - - headers = { - 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', - 'Content-Type': 'application/json', - } - return { - 'url': self.settings.api_base + '/text/chatcompletion_pro', - 'json': json_data, - 'headers': headers, - 'params': {'GroupId': self.settings.group_id}, - } - - @override - def _process_stream_response(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - # TODO: implement this - raise NotImplementedError - - @classmethod - @override - def from_name(cls, name: str) -> Self: - return cls(model=name) diff --git a/generate/chat_completion/models/openai.py b/generate/chat_completion/models/openai.py index 0ffd766..3ae4ee4 100644 --- a/generate/chat_completion/models/openai.py +++ b/generate/chat_completion/models/openai.py @@ -1,69 +1,20 @@ from __future__ import annotations -from typing import AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional, Union +from typing import ClassVar, List -from pydantic import Field, PositiveInt -from typing_extensions import Annotated, Unpack, override - -from generate.chat_completion.message import Prompt -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput +from generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.models.openai_like import ( - FunctionCallName, + OpenAIChatParameters, OpenAILikeChat, - OpenAIResponseFormat, - OpenAITool, - OpenAIToolChoice, - convert_to_openai_tool, + SupportOpenAIToolCall, ) -from generate.chat_completion.tool import FunctionJsonSchema, Tool, ToolCallMixin from generate.http import ( HttpClient, ) -from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms.openai import OpenAISettings -from generate.types import OrIterable, Probability, Temperature -from generate.utils import ensure_iterable - - -class OpenAIChatParameters(ModelParameters): - temperature: Optional[Temperature] = None - top_p: Optional[Probability] = None - max_tokens: Optional[PositiveInt] = None - functions: Optional[List[FunctionJsonSchema]] = None - function_call: Union[Literal['auto'], FunctionCallName, None] = None - stop: Union[str, List[str], None] = None - presence_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None - frequency_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None - logit_bias: Optional[Dict[int, Annotated[int, Field(ge=-100, le=100)]]] = None - logprobs: 
Optional[bool] = None - top_logprobs: Optional[Annotated[int, Field(ge=0, le=20)]] = None - user: Optional[str] = None - response_format: Optional[OpenAIResponseFormat] = None - seed: Optional[int] = None - tools: Optional[List[OpenAITool]] = None - tool_choice: Union[Literal['auto', 'none'], OpenAIToolChoice, None] = None -class OpenAIChatParametersDict(RemoteModelParametersDict, total=False): - temperature: Optional[Temperature] - top_p: Optional[Probability] - max_tokens: Optional[PositiveInt] - functions: Optional[List[FunctionJsonSchema]] - function_call: Union[Literal['auto'], FunctionCallName, None] - stop: Union[str, List[str], None] - presence_penalty: Optional[float] - frequency_penalty: Optional[float] - logit_bias: Optional[Dict[int, int]] - logprobs: Optional[bool] - top_logprobs: Optional[int] - user: Optional[str] - response_format: Optional[OpenAIResponseFormat] - seed: Optional[int] - tools: Optional[List[OpenAITool]] - tool_choice: Union[Literal['auto'], OpenAIToolChoice, None] - - -class OpenAIChat(OpenAILikeChat, ToolCallMixin): +class OpenAIChat(OpenAILikeChat, SupportOpenAIToolCall): model_type: ClassVar[str] = 'openai' available_models: ClassVar[List[str]] = [ 'gpt-4-turbo-preview', @@ -71,8 +22,8 @@ class OpenAIChat(OpenAILikeChat, ToolCallMixin): 'gpt-4-vision-preview', ] - parameters: OpenAIChatParameters settings: OpenAISettings + parameters: OpenAIChatParameters def __init__( self, @@ -80,37 +31,15 @@ def __init__( parameters: OpenAIChatParameters | None = None, settings: OpenAISettings | None = None, http_client: HttpClient | None = None, + message_converter: MessageConverter | None = None, ) -> None: parameters = parameters or OpenAIChatParameters() settings = settings or OpenAISettings() # type: ignore http_client = http_client or HttpClient() - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) - - @override - def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) - - @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) - - @override - def stream_generate( - self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict] - ) -> Iterator[ChatCompletionStreamOutput]: - yield from super().stream_generate(prompt, **kwargs) - - @override - async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict] - ) -> AsyncIterator[ChatCompletionStreamOutput]: - async for stream_output in super().async_stream_generate(prompt, **kwargs): - yield stream_output - - @override - def add_tools(self, tools: OrIterable[Tool]) -> None: - new_tools = [convert_to_openai_tool(tool) for tool in ensure_iterable(tools)] - if self.parameters.tools is None: - self.parameters.tools = new_tools - else: - self.parameters.tools.extend(new_tools) + super().__init__( + model=model, + parameters=parameters, + settings=settings, + message_converter=message_converter, + http_client=http_client, + ) diff --git a/generate/chat_completion/models/openai_like.py b/generate/chat_completion/models/openai_like.py index a380fda..e1e6941 100644 --- a/generate/chat_completion/models/openai_like.py +++ b/generate/chat_completion/models/openai_like.py @@ -1,13 +1,12 @@ from __future__ import annotations import base64 -from abc import ABC -from typing import Any, Dict, List, Literal, Union +from typing import 
Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union, cast -from typing_extensions import NotRequired, TypedDict, override +from pydantic import Field, PositiveInt +from typing_extensions import Annotated, NotRequired, TypedDict, Unpack, override from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.cost_caculator import CostCalculator from generate.chat_completion.message import ( AssistantMessage, FunctionCall, @@ -21,19 +20,19 @@ UserMultiPartMessage, ) from generate.chat_completion.message.converter import MessageConverter -from generate.chat_completion.message.core import Messages +from generate.chat_completion.message.core import Messages, Prompt from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage from generate.chat_completion.stream_manager import StreamManager -from generate.chat_completion.tool import FunctionJsonSchema, Tool +from generate.chat_completion.tool import FunctionJsonSchema, SupportToolCall, Tool from generate.http import ( HttpClient, HttpGetKwargs, HttpxPostKwargs, - ResponseValue, ) -from generate.model import ModelInfo, ModelParameters -from generate.platforms.base import PlatformSettings +from generate.model import ModelInfo, ModelParameters, RemoteModelParametersDict from generate.platforms.openai_like import OpenAILikeSettings +from generate.types import OrIterable, Probability, Temperature +from generate.utils import ensure_iterable class FunctionCallName(TypedDict): @@ -74,7 +73,62 @@ class OpenAIResponseFormat(TypedDict): type: Literal['json_object', 'text'] +class OpenAIChatParameters(ModelParameters): + temperature: Optional[Temperature] = None + top_p: Optional[Probability] = None + max_tokens: Optional[PositiveInt] = None + functions: Optional[List[FunctionJsonSchema]] = None + function_call: Union[Literal['auto'], FunctionCallName, None] = None + stop: Union[str, List[str], None] = None + presence_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None + frequency_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None + logit_bias: Optional[Dict[int, Annotated[int, Field(ge=-100, le=100)]]] = None + logprobs: Optional[bool] = None + top_logprobs: Optional[Annotated[int, Field(ge=0, le=20)]] = None + user: Optional[str] = None + response_format: Optional[OpenAIResponseFormat] = None + seed: Optional[int] = None + tools: Optional[List[OpenAITool]] = None + tool_choice: Union[Literal['auto', 'none'], OpenAIToolChoice, None] = None + + +class OpenAIChatParametersDict(RemoteModelParametersDict, total=False): + temperature: Optional[Temperature] + top_p: Optional[Probability] + max_tokens: Optional[PositiveInt] + functions: Optional[List[FunctionJsonSchema]] + function_call: Union[Literal['auto'], FunctionCallName, None] + stop: Union[str, List[str], None] + presence_penalty: Optional[float] + frequency_penalty: Optional[float] + logit_bias: Optional[Dict[int, int]] + logprobs: Optional[bool] + top_logprobs: Optional[int] + user: Optional[str] + response_format: Optional[OpenAIResponseFormat] + seed: Optional[int] + tools: Optional[List[OpenAITool]] + tool_choice: Union[Literal['auto'], OpenAIToolChoice, None] + + +class SupportOpenAIToolCall(SupportToolCall): + parameters: ModelParameters + + @override + def add_tools(self, tools: OrIterable[Tool]) -> None: + new_tools = [OpenAITool(type='function', function=tool.json_schema) for tool in ensure_iterable(tools)] + if not hasattr(self.parameters, 'tools'): + raise 
ValueError('The parameters must have a tools attribute') + self.parameters = cast(OpenAIChatParameters, self.parameters) + if self.parameters.tools is None: + self.parameters.tools = new_tools + else: + self.parameters.tools.extend(new_tools) + + class OpenAIMessageConverter(MessageConverter): + allowed_message_types = [SystemMessage, UserMessage, UserMultiPartMessage, ToolMessage, AssistantMessage, FunctionMessage] + def convert_system_message(self, message: SystemMessage) -> Dict[str, Any]: return { 'role': 'system', @@ -151,110 +205,55 @@ def convert_function_message(self, message: FunctionMessage) -> Dict[str, Any]: } -# def openai_calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float | None: - -# dollar_to_yuan = 7 -# if model_name in ('gpt-4-1106-preview', 'gpt-4-1106-vision-preview'): -# return (0.01 * dollar_to_yuan) * (input_tokens / 1000) + (0.03 * dollar_to_yuan) * (output_tokens / 1000) -# if 'gpt-4-turbo' in model_name: -# return (0.01 * dollar_to_yuan) * (input_tokens / 1000) + (0.03 * dollar_to_yuan) * (output_tokens / 1000) -# if 'gpt-4-32k' in model_name: -# return (0.06 * dollar_to_yuan) * (input_tokens / 1000) + (0.12 * dollar_to_yuan) * (output_tokens / 1000) -# if 'gpt-4' in model_name: -# return (0.03 * dollar_to_yuan) * (input_tokens / 1000) + (0.06 * dollar_to_yuan) * (output_tokens / 1000) -# if 'gpt-3.5-turbo' in model_name: -# return (0.001 * dollar_to_yuan) * (input_tokens / 1000) + (0.002 * dollar_to_yuan) * (output_tokens / 1000) -# if 'moonshot' in model_name: -# if '8k' in model_name: -# return 0.012 * (input_tokens / 1000) + 0.012 * (output_tokens / 1000) -# if '32k' in model_name: -# return 0.024 * (input_tokens / 1000) + 0.024 * (output_tokens / 1000) -# if '128k' in model_name: -# return 0.06 * (input_tokens / 1000) + 0.06 * (output_tokens / 1000) -# return None - - -def parse_message_dict(message: dict[str, Any]) -> AssistantMessage: - if function_call_dict := message.get('function_call'): - function_call = FunctionCall( - name=function_call_dict.get('name') or '', - arguments=function_call_dict['arguments'], - ) - else: - function_call = None - - if tool_calls_dict := message.get('tool_calls'): - tool_calls = [ - ToolCall( - id=tool_call['id'], - function=FunctionCall( - name=tool_call['function'].get('name') or '', - arguments=tool_call['function']['arguments'], - ), - ) - for tool_call in tool_calls_dict - ] - else: - tool_calls = None - return AssistantMessage(content=message.get('content') or '', function_call=function_call, tool_calls=tool_calls) - - -def convert_to_openai_tool(tool: Tool) -> OpenAITool: - return OpenAITool(type='function', function=tool.json_schema) - - -def process_openai_like_model_reponse(response: ResponseValue, model_type: str) -> ChatCompletionOutput: - choice = response['choices'][0] - message = parse_message_dict(choice['message']) - extra = {'response': response} - if system_fingerprint := response.get('system_fingerprint'): - extra['system_fingerprint'] = system_fingerprint - - if (finish_reason := choice.get('finish_reason')) is None: - finish_reason = finish_details['type'] if (finish_details := choice.get('finish_details')) else None - - if finish_reason: - finish_reason = FinishReason(finish_reason) - input_tokens = response['usage']['prompt_tokens'] - output_tokens = response['usage']['completion_tokens'] - cost = None - for k, v in response['usage'].items(): - if k in ('cost', 'total_cost'): - cost = v - break - - return ChatCompletionOutput( - model_info=ModelInfo(task='chat_completion', 
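        # Illustrative (dummy numbers) for the legacy helper being removed here:
        # a usage dict like {'prompt_tokens': 10, 'completion_tokens': 5, 'total_cost': 0.0003}
        # makes the key scan above pick up cost=0.0003, which feeds the Usage below.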
type=model_type, name=response['model']), - message=message, - finish_reason=finish_reason, - usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost), - extra=extra, - ) - - -class OpenAILikeChat(RemoteChatCompletionModel, ABC): +class OpenAILikeChat(RemoteChatCompletionModel): + settings: OpenAILikeSettings + + message_converter: OpenAIMessageConverter + parameters: ModelParameters settings: OpenAILikeSettings def __init__( self, model: str, - parameters: ModelParameters, - settings: PlatformSettings, + parameters: ModelParameters | None = None, + settings: OpenAILikeSettings | None = None, http_client: HttpClient | None = None, message_converter: MessageConverter | None = None, - cost_calculator: CostCalculator | None = None, ) -> None: http_client = http_client or HttpClient() message_converter = message_converter or OpenAIMessageConverter() + parameters = parameters or OpenAIChatParameters() + if settings is None: + raise ValueError('settings is required') super().__init__( model=model, parameters=parameters, settings=settings, http_client=http_client, message_converter=message_converter, - cost_calculator=cost_calculator, ) + @override + def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: + return super().generate(prompt, **kwargs) + + @override + async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: + return await super().async_generate(prompt, **kwargs) + + @override + def stream_generate( + self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict] + ) -> Iterator[ChatCompletionStreamOutput]: + yield from super().stream_generate(prompt, **kwargs) + + @override + async def async_stream_generate( + self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict] + ) -> AsyncIterator[ChatCompletionStreamOutput]: + async for stream_output in super().async_stream_generate(prompt, **kwargs): + yield stream_output + @override def _get_request_parameters(self, messages: Messages, stream: bool = False, **kwargs: Any) -> HttpxPostKwargs: parameters = self.parameters.clone_with_changes(**kwargs) @@ -279,10 +278,10 @@ def _get_request_parameters(self, messages: Messages, stream: bool = False, **kw def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: return ChatCompletionOutput( model_info=ModelInfo(task='chat_completion', type=self.model_type, name=response['model']), - message=self._parse_assistant_message(response), + message=self._parse_assistant_message(response['choices'][0]['message']), finish_reason=self._parse_finish_reason(response), usage=self._parse_usage(response), - extra=self._parse_extra_info(response), + extra=self._parse_extra(response), ) @override @@ -291,13 +290,12 @@ def _process_stream_response( ) -> ChatCompletionStreamOutput | None: delta_dict = response['choices'][0].get('delta', {}) self._update_delta(delta_dict, stream_manager=stream_manager) - stream_manager.extra = self._parse_extra_info(response) + stream_manager.extra = self._parse_extra(response) stream_manager.usage = self._parse_usage(response) stream_manager.finish_reason = self._parse_finish_reason(response) return stream_manager.build_stream_output() - def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage: - message = response['choices'][0]['message'] + def _parse_assistant_message(self, message: dict[str, Any]) -> AssistantMessage: if function_call_dict := message.get('function_call'): function_call = FunctionCall( 
name=function_call_dict.get('name') or '', @@ -348,16 +346,10 @@ def _parse_usage(self, response: dict[str, Any]) -> Usage: if usage := response.get('usage'): input_tokens = usage['prompt_tokens'] output_tokens = usage['completion_tokens'] - cost = self.cost(input_tokens, output_tokens) - if cost is None: - for k, v in usage.items(): - if k in ('cost', 'total_cost'): - cost = v - break - return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost) + return Usage(input_tokens=input_tokens, output_tokens=output_tokens) return Usage() - def _parse_extra_info(self, response: dict[str, Any]) -> dict[str, Any]: + def _parse_extra(self, response: dict[str, Any]) -> dict[str, Any]: return {'response': response} def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManager) -> None: @@ -367,7 +359,7 @@ def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManage if delta_dict.get('tool_calls'): index = delta_dict['tool_calls'][0]['index'] if index >= len(stream_manager.tool_calls or []): - new_tool_calls_message = parse_message_dict(delta_dict).tool_calls + new_tool_calls_message = self._parse_assistant_message(delta_dict).tool_calls assert new_tool_calls_message is not None if stream_manager.tool_calls is None: stream_manager.tool_calls = [] diff --git a/generate/chat_completion/models/openrouter.py b/generate/chat_completion/models/openrouter.py index 4c96e41..0023111 100644 --- a/generate/chat_completion/models/openrouter.py +++ b/generate/chat_completion/models/openrouter.py @@ -13,14 +13,12 @@ OpenAIResponseFormat, OpenAITool, OpenAIToolChoice, - convert_to_openai_tool, + SupportOpenAIToolCall, ) -from generate.chat_completion.tool import Tool, ToolCallMixin from generate.http import HttpClient, HttpxPostKwargs from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms import OpenRouterSettings -from generate.types import OrIterable, Probability, Temperature -from generate.utils import ensure_iterable +from generate.types import Probability, Temperature class ProviderParameters(BaseModel): @@ -70,7 +68,7 @@ class OpenRouterParametersDict(RemoteModelParametersDict, total=False): provider: ProviderParameters -class OpenRouterChat(OpenAILikeChat, ToolCallMixin): +class OpenRouterChat(OpenAILikeChat, SupportOpenAIToolCall): model_type: ClassVar[str] = 'openrouter' available_models: ClassVar[List[str]] = ['auto'] @@ -102,23 +100,23 @@ def __init__( @override def generate(self, prompt: Prompt, **kwargs: Unpack[OpenRouterParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) + return super().generate(prompt, **kwargs) # type: ignore @override async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenRouterParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) + return await super().async_generate(prompt, **kwargs) # type: ignore @override def stream_generate( self, prompt: Prompt, **kwargs: Unpack[OpenRouterParametersDict] ) -> Iterator[ChatCompletionStreamOutput]: - yield from super().stream_generate(prompt, **kwargs) + yield from super().stream_generate(prompt, **kwargs) # type: ignore @override async def async_stream_generate( self, prompt: Prompt, **kwargs: Unpack[OpenRouterParametersDict] ) -> AsyncIterator[ChatCompletionStreamOutput]: - async for stream_output in super().async_stream_generate(prompt, **kwargs): + async for stream_output in super().async_stream_generate(prompt, **kwargs): # type: ignore yield 
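# A simplified sketch of the merge rule _update_delta applies to streamed
# tool calls, assuming the usual OpenAI delta shape: a delta whose index is
# past the end of the accumulated list opens a new call, otherwise its
# argument fragment is appended to the call already at that index.
def merge_tool_call_delta(accumulated: list, delta: dict) -> None:
    part = delta['tool_calls'][0]
    index = part['index']
    if index >= len(accumulated):
        accumulated.append({
            'name': part['function'].get('name') or '',
            'arguments': part['function'].get('arguments') or '',
        })
    else:
        accumulated[index]['arguments'] += part['function'].get('arguments') or ''

calls: list = []
merge_tool_call_delta(calls, {'tool_calls': [{'index': 0, 'function': {'name': 'f', 'arguments': '{"a"'}}]})
merge_tool_call_delta(calls, {'tool_calls': [{'index': 0, 'function': {'arguments': ': 1}'}}]})
assert calls == [{'name': 'f', 'arguments': '{"a": 1}'}]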
stream_output @override @@ -132,11 +130,3 @@ def _get_request_parameters(self, messages: Messages, stream: bool = False, **kw request_parameters['json']['models'] = self.models request_parameters['json'].pop('model') return request_parameters - - @override - def add_tools(self, tools: OrIterable[Tool]) -> None: - new_tools = [convert_to_openai_tool(tool) for tool in ensure_iterable(tools)] - if self.parameters.tools is None: - self.parameters.tools = new_tools - else: - self.parameters.tools.extend(new_tools) diff --git a/generate/chat_completion/models/stepfun.py b/generate/chat_completion/models/stepfun.py index 4f1d8ee..c7968e4 100644 --- a/generate/chat_completion/models/stepfun.py +++ b/generate/chat_completion/models/stepfun.py @@ -6,6 +6,7 @@ from typing_extensions import Annotated, Unpack, override from generate.chat_completion.message import Prompt +from generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput from generate.chat_completion.models.openai_like import OpenAILikeChat from generate.http import HttpClient @@ -13,11 +14,9 @@ from generate.platforms import StepFunSettings from generate.types import Probability -Temperature = Annotated[float, Field(ge=0, le=2)] - class StepFunChatParameters(ModelParameters): - temperature: Optional[Temperature] = None + temperature: Optional[Annotated[float, Field(ge=0, le=2)]] = None top_p: Optional[Probability] = None max_tokens: Optional[PositiveInt] = None presence_penalty: Optional[Annotated[float, Field(ge=-2, le=2)]] = None @@ -25,9 +24,9 @@ class StepFunChatParameters(ModelParameters): class StepFunParametersDict(RemoteModelParametersDict, total=False): - temperature: Optional[Temperature] - top_p: Optional[Probability] - max_tokens: Optional[PositiveInt] + temperature: Optional[float] + top_p: Optional[float] + max_tokens: Optional[int] presence_penalty: Optional[float] frequency_penalty: Optional[float] @@ -45,12 +44,18 @@ def __init__( parameters: StepFunChatParameters | None = None, settings: StepFunSettings | None = None, http_client: HttpClient | None = None, + message_converter: MessageConverter | None = None, ) -> None: parameters = parameters or StepFunChatParameters() settings = settings or StepFunSettings() # type: ignore http_client = http_client or HttpClient() - model = model - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) + super().__init__( + model=model, + parameters=parameters, + settings=settings, + message_converter=message_converter, + http_client=http_client, + ) @override def generate(self, prompt: Prompt, **kwargs: Unpack[StepFunParametersDict]) -> ChatCompletionOutput: diff --git a/generate/chat_completion/models/wenxin.py b/generate/chat_completion/models/wenxin.py deleted file mode 100644 index 524d31e..0000000 --- a/generate/chat_completion/models/wenxin.py +++ /dev/null @@ -1,260 +0,0 @@ -from __future__ import annotations - -import json -from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional - -from pydantic import Field, model_validator -from typing_extensions import Annotated, NotRequired, Self, TypedDict, Unpack, override - -from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.message import ( - AssistantMessage, - FunctionCall, - FunctionMessage, - Message, - Messages, - MessageTypeError, - Prompt, - SystemMessage, - UserMessage, - ensure_messages, -) -from 
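# A sketch (guard condition simplified) of the request tweak in
# OpenRouterChat._get_request_parameters above: when candidate models are
# configured, the single 'model' field is swapped for a 'models' list so
# OpenRouter's router can choose among them.
def apply_model_routing(json_body: dict, models: list) -> dict:
    if models:
        json_body['models'] = models
        json_body.pop('model', None)
    return json_body

assert apply_model_routing({'model': 'auto'}, ['model-a', 'model-b']) == {'models': ['model-a', 'model-b']}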
generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.stream_manager import StreamManager -from generate.chat_completion.tool import Tool, ToolCallMixin -from generate.http import ( - HttpClient, - HttpxPostKwargs, - ResponseValue, - UnexpectedResponseError, -) -from generate.model import ModelParameters, RemoteModelParametersDict -from generate.platforms.baidu import QianfanSettings, QianfanTokenManager -from generate.types import JsonSchema, OrIterable, Probability, Temperature -from generate.utils import ensure_iterable - - -class WenxinMessage(TypedDict): - role: Literal['user', 'assistant', 'function'] - content: str - name: NotRequired[str] - function_call: NotRequired[WenxinFunctionCall] - - -class WenxinFunctionCall(TypedDict): - name: str - arguments: str - thoughts: NotRequired[str] - - -class WenxinFunction(TypedDict): - name: str - description: str - parameters: JsonSchema - responses: NotRequired[JsonSchema] - examples: NotRequired[List[WenxinMessage]] - - -def convert_to_wenxin_function(tool: Tool) -> WenxinFunction: - return { - 'name': tool.name, - 'description': tool.description, - 'parameters': tool.parameters, - } - - -def convert_to_wenxin_message(message: Message) -> WenxinMessage: - if isinstance(message, UserMessage): - return { - 'role': 'user', - 'content': message.content, - } - - if isinstance(message, AssistantMessage): - if message.function_call: - return { - 'role': 'assistant', - 'function_call': { - 'name': message.function_call.name, - 'arguments': message.function_call.arguments, - 'thoughts': message.function_call.thoughts or '', - }, - 'content': message.content, - } - return { - 'role': 'assistant', - 'content': message.content, - } - - if isinstance(message, FunctionMessage): - return { - 'role': 'function', - 'name': message.name, - 'content': message.content, - } - - raise MessageTypeError(message, allowed_message_type=(UserMessage, AssistantMessage, FunctionMessage)) - - -def _convert_messages(messages: Messages) -> list[WenxinMessage]: - if isinstance(system_message := messages[0], SystemMessage): - prepend_messages = [UserMessage(content=system_message.content), AssistantMessage(content='好的')] - messages = prepend_messages + messages[1:] - return [convert_to_wenxin_message(message) for message in messages] - - -class WenxinChatParameters(ModelParameters): - temperature: Optional[Temperature] = None - top_p: Optional[Probability] = None - functions: Optional[List[WenxinFunction]] = None - penalty_score: Optional[Annotated[float, Field(ge=1, le=2)]] = None - system: Optional[str] = None - user: Optional[str] = Field(default=None, serialization_alias='user_id') - - @model_validator(mode='after') - def system_function_conflict(self) -> Self: - if self.system is not None and self.functions is not None: - raise ValueError('system and functions cannot be used together') - return self - - def custom_model_dump(self) -> dict[str, Any]: - output = super().custom_model_dump() - if 'temperature' in output: - output['temperature'] = max(0.01, output['temperature']) - return output - - -class WenxinChatParametersDict(RemoteModelParametersDict, total=False): - temperature: Optional[Temperature] - top_p: Optional[Probability] - functions: Optional[List[WenxinFunction]] - penalty_score: Optional[float] - system: Optional[str] - user: Optional[str] - - -class WenxinChat(RemoteChatCompletionModel, ToolCallMixin): - model_type: ClassVar[str] = 'wenxin' - model_name_entrypoint_map: 
ClassVar[Dict[str, str]] = { - 'ERNIE-Bot': 'completions', - 'ERNIE-Bot-turbo': 'eb-instant', - 'ERNIE-Bot-4': 'completions_pro', - } - available_models: ClassVar[List[str]] = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4'] - - parameters: WenxinChatParameters - settings: QianfanSettings - - def __init__( - self, - model: str = 'ERNIE-Bot', - parameters: WenxinChatParameters | None = None, - settings: QianfanSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or WenxinChatParameters() - settings = settings or QianfanSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) - self.token_manager = QianfanTokenManager(self.settings, self.http_client) - - @override - def generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -> ChatCompletionOutput: - return super().generate(prompt, **kwargs) - - @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -> ChatCompletionOutput: - return await super().async_generate(prompt, **kwargs) - - @override - def stream_generate( - self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict] - ) -> Iterator[ChatCompletionStreamOutput]: - yield from super().stream_generate(prompt, **kwargs) - - @override - async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict] - ) -> AsyncIterator[ChatCompletionStreamOutput]: - async for output in super().async_stream_generate(prompt, **kwargs): - yield output - - def _get_request_parameters( - self, prompt: Prompt, stream: bool = False, **kwargs: Unpack[WenxinChatParametersDict] - ) -> HttpxPostKwargs: - messages = ensure_messages(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - wenxin_messages: list[WenxinMessage] = _convert_messages(messages) - parameters_dict = parameters.custom_model_dump() - if 'temperature' in parameters_dict: - parameters_dict['temperature'] = max(0.01, parameters_dict['temperature']) - json_data = {'messages': wenxin_messages, **parameters_dict} - if stream: - json_data['stream'] = True - - return { - 'url': self.settings.comlpetion_api_base + self.model_name_entrypoint_map[self.model], - 'json': json_data, - 'params': {'access_token': self.token_manager.token}, - 'headers': {'Content-Type': 'application/json'}, - } - - @override - def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - if response.get('error_msg'): - raise UnexpectedResponseError(response) - if response.get('function_call'): - function_call = FunctionCall( - name=response['function_call']['name'], - arguments=response['function_call']['arguments'], - thoughts=response['function_call']['thoughts'], - ) - else: - function_call = None - message = AssistantMessage(content=response['result'], function_call=function_call) - return ChatCompletionOutput( - model_info=self.model_info, - message=message, - cost=self._calculate_cost(response['usage']), - extra={ - 'is_truncated': response['is_truncated'], - 'need_clear_history': response['need_clear_history'], - 'usage': response['usage'], - }, - finish_reason=response['finish_reason'], - ) - - @override - def _process_stream_response(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: - try: - data = json.loads(line) - except json.JSONDecodeError: - return None - - stream_manager.delta = data['result'] - if data['is_end']: - 
stream_manager.cost = self._calculate_cost(data['usage']) - stream_manager.finish_reason = 'stop' - extra = { - 'is_truncated': data['is_truncated'], - 'need_clear_history': data['need_clear_history'], - 'usage': data['usage'], - } - stream_manager.extra.update(extra) - return stream_manager.build_stream_output() - - def _calculate_cost(self, usage: dict[str, Any]) -> float | None: - if self.name == 'ERNIE-Bot': - return 0.012 * (usage['total_tokens'] / 1000) - if self.name == 'ERNIE-Bot-turbo': - return 0.008 * (usage['total_tokens'] / 1000) - if self.name == 'ERNIE-Bot-4': - return 0.12 * (usage['total_tokens'] / 1000) - return None - - def add_tools(self, tools: OrIterable[Tool]) -> None: - new_functions = [convert_to_wenxin_function(tool) for tool in ensure_iterable(tools)] - if self.parameters.functions is None: - self.parameters.functions = new_functions - else: - self.parameters.functions.extend(new_functions) diff --git a/generate/chat_completion/models/yi.py b/generate/chat_completion/models/yi.py index 5776955..08ab8fc 100644 --- a/generate/chat_completion/models/yi.py +++ b/generate/chat_completion/models/yi.py @@ -5,20 +5,13 @@ from pydantic import Field, PositiveInt from typing_extensions import Annotated, Unpack, override -from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import Prompt +from generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput -from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAIMessageConverter +from generate.chat_completion.models.openai_like import OpenAILikeChat from generate.http import HttpClient from generate.model import ModelParameters, RemoteModelParametersDict from generate.platforms import YiSettings -from generate.types import ModelPrice - -YiModelPrice: ModelPrice = { - 'yi-34b-chat-200k': (12.0, 12.0), - 'yi-34b-chat': (2.5, 2.5), - 'yi-vl-plus': (6, 6), -} class YiChatParameters(ModelParameters): @@ -39,7 +32,6 @@ class YiChat(OpenAILikeChat): parameters: YiChatParameters settings: YiSettings - message_converter: OpenAIMessageConverter def __init__( self, @@ -47,19 +39,16 @@ def __init__( parameters: YiChatParameters | None = None, settings: YiSettings | None = None, http_client: HttpClient | None = None, - message_converter: OpenAIMessageConverter | None = None, - cost_calculator: CostCalculator | None = None, + message_converter: MessageConverter | None = None, ) -> None: parameters = parameters or YiChatParameters() settings = settings or YiSettings() # type: ignore - cost_calculator = cost_calculator or GeneralCostCalculator(YiModelPrice) super().__init__( model=model, parameters=parameters, settings=settings, - http_client=http_client, message_converter=message_converter, - cost_calculator=cost_calculator, + http_client=http_client, ) @override diff --git a/generate/chat_completion/models/zhipu.py b/generate/chat_completion/models/zhipu.py index 3b211fa..e19fa43 100644 --- a/generate/chat_completion/models/zhipu.py +++ b/generate/chat_completion/models/zhipu.py @@ -7,7 +7,6 @@ from typing_extensions import NotRequired, TypedDict, Unpack, override from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.cost_caculator import CostCalculator, GeneralCostCalculator from generate.chat_completion.message import ( AssistantMessage, FunctionCall, @@ -24,9 +23,10 @@ ) from 
generate.chat_completion.message.converter import MessageConverter from generate.chat_completion.message.core import FunctionMessage +from generate.chat_completion.message.exception import MessageTypeError from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage from generate.chat_completion.stream_manager import StreamManager -from generate.chat_completion.tool import Tool, ToolCallMixin +from generate.chat_completion.tool import SupportToolCall, Tool from generate.http import ( HttpClient, HttpxPostKwargs, @@ -128,6 +128,8 @@ class ZhipuMessage(TypedDict): class ZhipuMessageConverter(MessageConverter): + allowed_message_types = [SystemMessage, UserMessage, AssistantMessage, ToolMessage, UserMultiPartMessage] + def convert_system_message(self, message: SystemMessage) -> Dict[str, Any]: return { 'role': 'system', @@ -176,7 +178,7 @@ def convert_tool_message(self, message: ToolMessage) -> Dict[str, Any]: } def convert_function_message(self, message: FunctionMessage) -> Dict[str, Any]: - raise NotImplementedError('Zhipu does not support function messages') + raise MessageTypeError(message, allowed_message_type=self.allowed_message_types) def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> Dict[str, Any]: content = [] @@ -209,13 +211,12 @@ def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> Dict return {'role': 'user', 'content': content} -class ZhipuChat(RemoteChatCompletionModel, ToolCallMixin): +class ZhipuChat(RemoteChatCompletionModel, SupportToolCall): model_type: ClassVar[str] = 'zhipu' available_models: ClassVar[List[str]] = ['glm-4', 'glm-3-turbo', 'glm-4v'] parameters: ZhipuChatParameters settings: ZhipuSettings - message_converter: ZhipuMessageConverter def __init__( self, @@ -224,20 +225,17 @@ def __init__( settings: ZhipuSettings | None = None, http_client: HttpClient | None = None, message_converter: ZhipuMessageConverter | None = None, - cost_calculator: CostCalculator | None = None, ) -> None: parameters = parameters or ZhipuChatParameters() settings = settings or ZhipuSettings() # type: ignore http_client = http_client or HttpClient() message_converter = message_converter or ZhipuMessageConverter() - cost_calculator = cost_calculator or GeneralCostCalculator(ZhipuModelPrice) super().__init__( model=model, parameters=parameters, settings=settings, http_client=http_client, message_converter=message_converter, - cost_calculator=cost_calculator, ) @override @@ -285,7 +283,7 @@ def _get_request_parameters( def _process_reponse(self, response: ResponseValue) -> ChatCompletionOutput: return ChatCompletionOutput( model_info=self.model_info, - message=self._parse_assistant_message(response), + message=self._parse_assistant_message(response['choices'][0]['message']), usage=self._parse_usage(response), extra=self._parse_extra(response), finish_reason=self._parse_finish_reason(response), @@ -319,12 +317,12 @@ def add_tools(self, tools: OrIterable[Tool]) -> None: else: self.parameters.tools.extend(new_tools) - def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage: - if 'tool_calls' in response: - dict_format_tool_calls = response['tool_calls'] + def _parse_assistant_message(self, message: dict[str, Any]) -> AssistantMessage: + if 'tool_calls' in message: + dict_format_tool_calls = message['tool_calls'] dict_format_tool_calls.sort(key=lambda x: x['index']) tool_calls = [] - for tool_call_dict in response['tool_calls']: + for tool_call_dict in 
message['tool_calls']: if tool_call_dict['type'] != 'function': raise ValueError(f'invalid tool type: {tool_call_dict["type"]}, should be function') tool_calls.append( @@ -339,12 +337,12 @@ def _parse_assistant_message(self, response: dict[str, Any]) -> AssistantMessage ) return AssistantMessage( role='assistant', - content=response.get('content') or '', + content=message.get('content') or '', tool_calls=tool_calls, ) return AssistantMessage( role='assistant', - content=response['content'], + content=message['content'], ) def _parse_usage(self, response: dict[str, Any]) -> Usage: @@ -352,8 +350,7 @@ def _parse_usage(self, response: dict[str, Any]) -> Usage: if usage is not None: input_tokens = usage['prompt_tokens'] output_tokens = usage['completion_tokens'] - cost = self.cost(input_tokens, output_tokens) - return Usage(input_tokens=input_tokens, output_tokens=output_tokens, cost=cost) + return Usage(input_tokens=input_tokens, output_tokens=output_tokens) return Usage() def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: diff --git a/generate/chat_completion/tool.py b/generate/chat_completion/tool.py index 00247e2..4f23f9f 100644 --- a/generate/chat_completion/tool.py +++ b/generate/chat_completion/tool.py @@ -2,7 +2,7 @@ import uuid from collections import UserDict -from typing import Any, Callable, Generic, MutableMapping, TypeVar +from typing import Any, Callable, Generic, MutableMapping, Protocol, TypeVar, runtime_checkable from docstring_parser import parse from pydantic import TypeAdapter, validate_call @@ -106,14 +106,15 @@ def from_iterable(cls, tools: OrIterable[Tool]) -> Self: return cls({tool.name: tool for tool in ensure_iterable(tools)}) -class ToolCallMixin: +@runtime_checkable +class SupportToolCall(Protocol): def add_tools(self, tools: OrIterable[Tool]) -> None: - raise NotImplementedError + ... 
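# Why the move from the ToolCallMixin base class to a runtime_checkable
# Protocol matters: isinstance() becomes a structural check, so any model
# exposing the protocol's members passes without inheriting from it. A
# single-method sketch (the real SupportToolCall also defines
# generate_tool_call_id and process_messages_for_tool_call, which a
# non-inheriting class would need as well):
from typing import Iterable, Protocol, runtime_checkable

@runtime_checkable
class SupportsAddTools(Protocol):
    def add_tools(self, tools: Iterable) -> None: ...

class ToyModel:
    def __init__(self) -> None:
        self.tools: list = []

    def add_tools(self, tools: Iterable) -> None:
        self.tools.extend(tools)

assert isinstance(ToyModel(), SupportsAddTools)  # structural check, no subclassing needed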
def generate_tool_call_id(self, function_call: FunctionCall) -> str: return f'tool_{uuid.uuid4().hex}' - def adapt_tool_calls(self, messages: Messages) -> None: + def process_messages_for_tool_call(self, messages: Messages) -> None: for index in range(len(messages)): current_message = messages[index] if isinstance(current_message, AssistantMessage) and current_message.function_call is not None: diff --git a/generate/constant.py b/generate/constant.py new file mode 100644 index 0000000..1c475a9 --- /dev/null +++ b/generate/constant.py @@ -0,0 +1 @@ +SINO_US_EXCHANGE_RATE = 7.3 diff --git a/generate/highlevel.py b/generate/highlevel.py index 830c2bf..c03faf6 100644 --- a/generate/highlevel.py +++ b/generate/highlevel.py @@ -11,16 +11,6 @@ ChatModelRegistry, Prompt, ) -from generate.image_generation import ( - ImageGenerationModel, - ImageGenerationModelRegistry, - ImageGenerationOutput, -) -from generate.text_to_speech import ( - SpeechModelRegistry, - TextToSpeechModel, - TextToSpeechOutput, -) def load_chat_model(model_id: str) -> ChatCompletionModel: @@ -32,39 +22,11 @@ def load_chat_model(model_id: str) -> ChatCompletionModel: return model_cls.from_name(name) -def load_speech_model(model_id: str) -> TextToSpeechModel: - if '/' not in model_id: - model_type = model_id - return SpeechModelRegistry[model_type][0]() - model_type, name = model_id.split('/', maxsplit=1) - model_cls = SpeechModelRegistry[model_type][0] - return model_cls.from_name(name) - - -def load_image_generation_model(model_id: str) -> ImageGenerationModel: - if '/' not in model_id: - model_type = model_id - return ImageGenerationModelRegistry[model_type][0]() - model_type, name = model_id.split('/', maxsplit=1) - model_cls = ImageGenerationModelRegistry[model_type][0] - return model_cls.from_name(name) - - def generate_text(prompt: Prompt, model_id: str = 'openai', **kwargs: Any) -> ChatCompletionOutput: model = load_chat_model(model_id) return model.generate(prompt, **kwargs) -def generate_speech(text: str, model_id: str = 'openai', **kwargs: Any) -> TextToSpeechOutput: - model = load_speech_model(model_id) - return model.generate(text, **kwargs) - - -def generate_image(prompt: str, model_id: str = 'openai', **kwargs: Any) -> ImageGenerationOutput: - model = load_image_generation_model(model_id) - return model.generate(prompt, **kwargs) - - @overload async def multimodel_generate_text( prompt: Prompt, model_ids: Sequence[str], ignore_error: Literal[False] = False, **kwargs: Any diff --git a/generate/image_generation/__init__.py b/generate/image_generation/__init__.py deleted file mode 100644 index 064386a..0000000 --- a/generate/image_generation/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from typing import Type, TypeVar - -from generate.image_generation.base import GeneratedImage, ImageGenerationModel, ImageGenerationOutput -from generate.image_generation.models import ( - BaiduImageGeneration, - BaiduImageGenerationParameters, - OpenAIImageGeneration, - OpenAIImageGenerationParameters, - QianfanImageGeneration, - QianfanImageGenerationParameters, - ZhipuImageGeneration, - ZhipuImageGenerationParameters, -) -from generate.model import ModelParameters - -P = TypeVar('P', bound=ModelParameters) - -ImageGenerationModels: list[tuple[Type[ImageGenerationModel], Type[ModelParameters]]] = [ - (OpenAIImageGeneration, OpenAIImageGenerationParameters), - (BaiduImageGeneration, BaiduImageGenerationParameters), - (QianfanImageGeneration, QianfanImageGenerationParameters), - 
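# A sketch of the 'type/name' convention that load_chat_model keeps for
# chat models: a bare type selects the registry default, while 'type/name'
# also picks a concrete model via from_name.
def split_model_id(model_id: str) -> tuple:
    if '/' not in model_id:
        return model_id, None
    model_type, name = model_id.split('/', maxsplit=1)
    return model_type, name

assert split_model_id('openai') == ('openai', None)
assert split_model_id('openai/gpt-4-turbo') == ('openai', 'gpt-4-turbo')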
(ZhipuImageGeneration, ZhipuImageGenerationParameters), -] - -ImageGenerationModelRegistry: dict[str, tuple[Type[ImageGenerationModel], Type[ModelParameters]]] = { - model_cls.model_type: (model_cls, parameter_cls) for model_cls, parameter_cls in ImageGenerationModels -} - -__all__ = [ - 'ImageGenerationModel', - 'ImageGenerationOutput', - 'GeneratedImage', - 'OpenAIImageGeneration', - 'OpenAIImageGenerationParameters', - 'BaiduImageGeneration', - 'BaiduImageGenerationParameters', - 'QianfanImageGeneration', - 'QianfanImageGenerationParameters', - 'ZhipuImageGeneration', - 'ZhipuImageGenerationParameters', -] diff --git a/generate/image_generation/base.py b/generate/image_generation/base.py deleted file mode 100644 index 9d00b2a..0000000 --- a/generate/image_generation/base.py +++ /dev/null @@ -1,46 +0,0 @@ -import logging -from abc import ABC -from typing import ClassVar, List, Optional, get_type_hints - -from pydantic import BaseModel - -from generate.http import HttpClient -from generate.model import GenerateModel, ModelOutput, ModelParameters -from generate.platforms.base import PlatformSettings - -logger = logging.getLogger(__name__) - - -class GeneratedImage(BaseModel): - url: Optional[str] = None - prompt: str - image_format: str - content: bytes - - -class ImageGenerationOutput(ModelOutput): - images: List[GeneratedImage] = [] - - -class ImageGenerationModel(GenerateModel[str, ImageGenerationOutput], ABC): - model_task: ClassVar[str] = 'image_generation' - model_type: ClassVar[str] - - -class RemoteImageGenerationModel(ImageGenerationModel): - settings: PlatformSettings - http_client: HttpClient - - def __init__( - self, - parameters: ModelParameters, - settings: PlatformSettings, - http_client: HttpClient, - ) -> None: - self.parameters = parameters - self.settings = settings - self.http_client = http_client - - @classmethod - def how_to_settings(cls) -> str: - return f'{cls.__name__} Settings\n\n' + get_type_hints(cls)['settings'].how_to_settings() diff --git a/generate/image_generation/models/__init__.py b/generate/image_generation/models/__init__.py deleted file mode 100644 index d9f405e..0000000 --- a/generate/image_generation/models/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from generate.image_generation.models.baidu import BaiduImageGeneration, BaiduImageGenerationParameters -from generate.image_generation.models.openai import OpenAIImageGeneration, OpenAIImageGenerationParameters -from generate.image_generation.models.qianfan import QianfanImageGeneration, QianfanImageGenerationParameters -from generate.image_generation.models.zhipu import ZhipuImageGeneration, ZhipuImageGenerationParameters - -__all__ = [ - 'OpenAIImageGeneration', - 'OpenAIImageGenerationParameters', - 'BaiduImageGeneration', - 'BaiduImageGenerationParameters', - 'QianfanImageGeneration', - 'QianfanImageGenerationParameters', - 'ZhipuImageGeneration', - 'ZhipuImageGenerationParameters', -] diff --git a/generate/image_generation/models/baidu.py b/generate/image_generation/models/baidu.py deleted file mode 100644 index 7bda46c..0000000 --- a/generate/image_generation/models/baidu.py +++ /dev/null @@ -1,190 +0,0 @@ -from __future__ import annotations - -import asyncio -import time -from typing import Any, Literal, Optional, Union - -from pydantic import Base64Str, Field, HttpUrl -from typing_extensions import Annotated, Self, Unpack, override - -from generate.http import HttpClient, HttpxPostKwargs, UnexpectedResponseError -from generate.image_generation.base import ( - GeneratedImage, - 
ImageGenerationOutput, - RemoteImageGenerationModel, -) -from generate.model import ModelParameters, RemoteModelParametersDict -from generate.platforms.baidu import BaiduCreationSettings, BaiduCreationTokenManager - -ValidSize = Literal[ - '512x512', - '640x360', - '360x640', - '1024x1024', - '1280x720', - '720x1280', - '2048x2048', - '2560x1440', - '1440x2560', -] - - -class BaiduImageGenerationParameters(ModelParameters): - size: ValidSize = '1024x1024' - n: Optional[Annotated[int, Field(ge=1, le=8)]] = None - reference_image: Union[HttpUrl, Base64Str, None] = None - change_degree: Optional[Annotated[int, Field(ge=1, le=10)]] = None - - def custom_model_dump(self) -> dict[str, Any]: - output_data = {} - width, height = self.size.split('x') - output_data['width'] = int(width) - output_data['height'] = int(height) - n = self.n or 1 - output_data['image_num'] = n - if self.reference_image: - if isinstance(self.reference_image, str): - output_data['url'] = self.reference_image - else: - output_data['image'] = self.reference_image - if self.change_degree: - output_data['change_degree'] = self.change_degree - return output_data - - -class BaiduImageGenerationParametersDict(RemoteModelParametersDict, total=False): - size: ValidSize - n: Optional[int] - reference_image: Union[HttpUrl, Base64Str, None] - change_degree: Optional[int] - - -class BaiduImageGeneration(RemoteImageGenerationModel): - model_type = 'baidu' - - parameters: BaiduImageGenerationParameters - settings: BaiduCreationSettings - - def __init__( - self, - parameters: BaiduImageGenerationParameters | None = None, - settings: BaiduCreationSettings | None = None, - http_client: HttpClient | None = None, - task_timeout: int = 60, - ) -> None: - parameters = parameters or BaiduImageGenerationParameters() - settings = settings or BaiduCreationSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(parameters=parameters, settings=settings, http_client=http_client) - - self.token_manager = BaiduCreationTokenManager(self.settings, self.http_client) - self.task_timeout = task_timeout - - def _get_request_parameters(self, prompt: str, parameters: BaiduImageGenerationParameters) -> HttpxPostKwargs: - headers = {'Content-Type': 'application/json', 'Accept': 'application/json'} - json_data = { - 'prompt': prompt, - **parameters.custom_model_dump(), - } - return { - 'url': self.settings.image_generation_api, - 'json': json_data, - 'headers': headers, - 'params': { - 'access_token': self.token_manager.token, - }, - } - - @override - def generate(self, prompt: str, **kwargs: Unpack[BaiduImageGenerationParametersDict]) -> ImageGenerationOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = self.http_client.post(request_parameters=request_parameters) - task_id = response.json()['data']['task_id'] - image_urls = self._get_image_urls(task_id) - generated_images: list[GeneratedImage] = [] - for image_url in image_urls: - image = GeneratedImage( - url=image_url, - prompt=prompt, - image_format='png', - content=self.http_client.get({'url': image_url}).content, - ) - generated_images.append(image) - return ImageGenerationOutput(model_info=self.model_info, cost=0.3 * (parameters.n or 1), images=generated_images) - - @override - async def async_generate(self, prompt: str, **kwargs: Unpack[BaiduImageGenerationParametersDict]) -> ImageGenerationOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - 
request_parameters = self._get_request_parameters(prompt, parameters) - response = await self.http_client.async_post(request_parameters=request_parameters) - image_urls = await self._async_get_image_urls(response.json()['data']['task_id']) - images: list[GeneratedImage] = [] - images_response = await asyncio.gather(*(self.http_client.async_get({'url': image_url}) for image_url in image_urls)) - for image_url, image_response in zip(image_urls, images_response): - image = GeneratedImage( - url=image_url, - prompt=prompt, - image_format='png', - content=image_response.content, - ) - images.append(image) - return ImageGenerationOutput(model_info=self.model_info, cost=0.3 * (parameters.n or 1), images=images) - - def _get_image_request_parameters(self, task_id: str) -> HttpxPostKwargs: - return { - 'url': 'https://aip.baidubce.com/rpc/2.0/ernievilg/v1/getImgv2', - 'params': { - 'access_token': self.token_manager.token, - }, - 'headers': {'Content-Type': 'application/json'}, - 'json': {'task_id': str(task_id)}, - } - - def _parse_task_info(self, task_info: dict[str, Any]) -> list[str] | None: - task_status = task_info['data']['task_status'] - if task_status == 'FAILED': - raise UnexpectedResponseError(task_info, 'Task failed') - if task_status == 'SUCCESS': - image_urls: list[str] = [] - for sub_result in task_info['data']['sub_task_result_list']: - image_url = sub_result['final_image_list'][0]['img_url'] - image_urls.append(image_url) - return image_urls - return None - - def _get_image_urls(self, task_id: str) -> list[str]: - start_time = time.time() - task_info = None - while (time.time() - start_time) < self.task_timeout: - response = self.http_client.post(self._get_image_request_parameters(task_id)) - task_info = response.json() - image_urls = self._parse_task_info(task_info) - if image_urls: - return image_urls - time.sleep(1) - raise UnexpectedResponseError(task_info or {}, 'Timeout') - - async def _async_get_image_urls(self, task_id: str) -> list[str]: - start_time = time.time() - task_info = None - while (time.time() - start_time) < self.task_timeout: - response = await self.http_client.async_post(self._get_image_request_parameters(task_id)) - task_info = response.json() - image_urls = self._parse_task_info(task_info) - if image_urls: - return image_urls - await asyncio.sleep(1) - raise UnexpectedResponseError(task_info or {}, 'Timeout') - - @property - @override - def name(self) -> str: - return 'getImgv2' - - @classmethod - @override - def from_name(cls, name: str) -> Self: - if name and name != 'getImgv2': - raise ValueError(f'Invalid model name: {name}, expected: getImgv2') - return cls() diff --git a/generate/image_generation/models/openai.py b/generate/image_generation/models/openai.py deleted file mode 100644 index 20fcaa3..0000000 --- a/generate/image_generation/models/openai.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import annotations - -import base64 -from typing import Literal, Optional - -from httpx import Response -from pydantic import Field -from typing_extensions import Annotated, Self, Unpack, override - -from generate.http import HttpClient, HttpxPostKwargs -from generate.image_generation.base import GeneratedImage, ImageGenerationOutput, RemoteImageGenerationModel -from generate.model import ModelParameters, RemoteModelParametersDict -from generate.platforms.openai import OpenAISettings - -MAX_PROMPT_LENGTH_DALLE_3 = 4000 -MAX_PROMPT_LENGTH_DALLE_2 = 1000 -OPENAI_IMAGE_GENERATION_PRICE_MAP = { - 'dall-e-3': { - 'hd': { - '1024x1024': 0.04, - '1792x1024': 
0.08, - '1024x1792': 0.08, - }, - 'standard': { - '1024x1024': 0.08, - '1792x1024': 0.12, - '1024x1792': 0.12, - }, - }, - 'dall-e-2': { - 'standard': { - '256x256': 0.016, - '512x512': 0.018, - '1024x1024': 0.02, - } - }, -} - - -class OpenAIImageGenerationParameters(ModelParameters): - quality: Optional[Literal['hd', 'standard']] = None - response_format: Optional[Literal['url', 'b64_json']] = None - size: Optional[Literal['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']] = None - style: Optional[Literal['vivid', 'natural']] = None - n: Optional[Annotated[int, Field(ge=1, le=10)]] = None - user: Optional[str] = None - - -class OpenAIImageGenerationParametersDict(RemoteModelParametersDict, total=False): - quality: Optional[Literal['hd', 'standard']] - response_format: Optional[Literal['url', 'b64_json']] - size: Optional[Literal['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']] - style: Optional[Literal['vivid', 'natural']] - n: Optional[int] - user: Optional[str] - - -class OpenAIImageGeneration(RemoteImageGenerationModel): - model_type = 'openai' - - parameters: OpenAIImageGenerationParameters - settings: OpenAISettings - - def __init__( - self, - model: str = 'dall-e-3', - parameters: OpenAIImageGenerationParameters | None = None, - settings: OpenAISettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or OpenAIImageGenerationParameters() - settings = settings or OpenAISettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(parameters=parameters, settings=settings, http_client=http_client) - - self.model = model - self._check_parameters() - - def _check_parameters(self) -> None: - if self.model == 'dall-e-3': - if self.parameters.n is not None and self.parameters.n != 1: - raise ValueError('dall-e-3 only supports n=1') - size = self.parameters.size - if size is not None and size not in ('1024x1024', '1792x1024', '1024x1792'): - raise ValueError('dall-e-3 only supports size=1024x1024, 1792x1024, 1024x1792') - if self.model == 'dall-e-2': - if self.parameters.quality is not None: - raise ValueError('dall-e-2 does not support quality') - size = self.parameters.size - if size is not None and size not in ('256x256', '512x512', '1024x1024'): - raise ValueError('dall-e-2 only supports size=256x256, 512x512, 1024x1024') - if self.parameters.style is not None: - raise ValueError('dall-e-2 does not support style') - - def _check_prompt(self, prompt: str) -> None: - if self.model == 'dall-e-3' and len(prompt) >= MAX_PROMPT_LENGTH_DALLE_3: - raise ValueError('dall-e-3 does not support prompt length >= 4000') - - if self.model == 'dall-e-2' and len(prompt) >= MAX_PROMPT_LENGTH_DALLE_2: - raise ValueError('dall-e-2 does not support prompt length >= 100') - - def _get_request_parameters(self, prompt: str, parameters: OpenAIImageGenerationParameters) -> HttpxPostKwargs: - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', - } - json_data = { - 'model': self.model, - 'prompt': prompt, - **parameters.custom_model_dump(), - } - return { - 'url': self.settings.api_base + '/images/generations', - 'json': json_data, - 'headers': headers, - } - - @override - def generate(self, prompt: str, **kwargs: Unpack[OpenAIImageGenerationParametersDict]) -> ImageGenerationOutput: - self._check_prompt(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - 
response = self.http_client.post(request_parameters=request_parameters) - return self._construct_model_output(prompt, parameters, response) - - @override - async def async_generate(self, prompt: str, **kwargs: Unpack[OpenAIImageGenerationParametersDict]) -> ImageGenerationOutput: - self._check_prompt(prompt) - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = await self.http_client.async_post(request_parameters=request_parameters) - return self._construct_model_output(prompt, parameters, response) - - def _construct_model_output( - self, prompt: str, parameters: OpenAIImageGenerationParameters, response: Response - ) -> ImageGenerationOutput: - response_data = response.json() - generated_images: list[GeneratedImage] = [] - for image_data in response_data['data']: - image_prompt = image_data.get('revised_prompt') or prompt - url = image_data.get('url') - if url: - content = self.http_client.get({'url': url}).content - else: - b64 = image_data.get('b64_json') - if b64 is None: - raise ValueError('No URL or b64_json found in response') - content = base64.b64decode(b64) - generated_images.append( - GeneratedImage( - url=url, - prompt=image_prompt, - image_format='png', - content=content, - ) - ) - return ImageGenerationOutput( - model_info=self.model_info, - images=generated_images, - cost=self.calculate_cost(parameters), - ) - - def calculate_cost(self, parameters: OpenAIImageGenerationParameters) -> float: - dollar_to_yuan = 7 - quality = parameters.quality or 'standard' - size = parameters.size or '1024x1024' - model_price = OPENAI_IMAGE_GENERATION_PRICE_MAP[self.model][quality][size] - return model_price * dollar_to_yuan - - @property - @override - def name(self) -> str: - return self.model - - @classmethod - @override - def from_name(cls, name: str) -> Self: - return cls(model=name) diff --git a/generate/image_generation/models/qianfan.py b/generate/image_generation/models/qianfan.py deleted file mode 100644 index 29ef6ac..0000000 --- a/generate/image_generation/models/qianfan.py +++ /dev/null @@ -1,120 +0,0 @@ -from __future__ import annotations - -import base64 -from typing import Literal, Optional - -from pydantic import Field -from typing_extensions import Annotated, Self, Unpack, override - -from generate.http import HttpClient, HttpxPostKwargs, ResponseValue -from generate.image_generation.base import GeneratedImage, ImageGenerationOutput, RemoteImageGenerationModel -from generate.model import ModelParameters, RemoteModelParametersDict -from generate.platforms.baidu import QianfanSettings, QianfanTokenManager - -ValidSize = Literal[ - '768x768', - '768x1024', - '1024x768', - '576x1024', - '1024x576', - '1024x1024', -] - - -class QianfanImageGenerationParameters(ModelParameters): - size: Optional[ValidSize] = None - n: Optional[Annotated[int, Field(ge=1, le=4)]] = None - negative_prompt: Optional[str] = None - steps: Optional[Annotated[int, Field(ge=1, le=50)]] = None - sampler: Optional[str] = Field(default=None, serialization_alias='sampler_index') - user: Optional[str] = Field(default=None, serialization_alias='user_id') - - -class QianfanImageGenerationParametersDict(RemoteModelParametersDict, total=False): - size: Optional[ValidSize] - n: Optional[int] - negative_prompt: Optional[str] - steps: Optional[int] - sampler: Optional[str] - user: Optional[str] - - -class QianfanImageGeneration(RemoteImageGenerationModel): - model_type = 'qianfan' - - parameters: QianfanImageGenerationParameters 
- settings: QianfanSettings - - def __init__( - self, - model: str = 'sd_xl', - parameters: QianfanImageGenerationParameters | None = None, - settings: QianfanSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or QianfanImageGenerationParameters() - settings = settings or QianfanSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(parameters=parameters, settings=settings, http_client=http_client) - - self.model = model - self.token_manager = QianfanTokenManager(self.settings, self.http_client) - - @override - def generate(self, prompt: str, **kwargs: Unpack[QianfanImageGenerationParametersDict]) -> ImageGenerationOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = self.http_client.post(request_parameters=request_parameters) - return self._construct_model_output(prompt, response.json()) - - @override - async def async_generate( - self, prompt: str, **kwargs: Unpack[QianfanImageGenerationParametersDict] - ) -> ImageGenerationOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = await self.http_client.async_post(request_parameters=request_parameters) - return self._construct_model_output(prompt, response.json()) - - def _get_request_parameters(self, prompt: str, parameters: QianfanImageGenerationParameters) -> HttpxPostKwargs: - headers = {'Content-Type': 'application/json', 'Accept': 'application/json'} - json_data = { - 'prompt': prompt, - **parameters.custom_model_dump(), - } - return { - 'url': self.settings.image_generation_api_base + self.model, - 'json': json_data, - 'headers': headers, - 'params': { - 'access_token': self.token_manager.token, - }, - } - - def _construct_model_output(self, prompt: str, response_value: ResponseValue) -> ImageGenerationOutput: - images: list[GeneratedImage] = [] - for image_data in response_value['data']: - image = GeneratedImage( - prompt=prompt, - image_format='png', - content=base64.b64decode(image_data['b64_image']), - ) - images.append(image) - return ImageGenerationOutput( - model_info=self.model_info, - images=images, - extra={ - 'usage': response_value['usage'], - 'task_id': response_value['id'], - }, - ) - - @property - @override - def name(self) -> str: - return self.model - - @classmethod - @override - def from_name(cls, name: str) -> Self: - return cls(model=name) diff --git a/generate/image_generation/models/zhipu.py b/generate/image_generation/models/zhipu.py deleted file mode 100644 index 40df00f..0000000 --- a/generate/image_generation/models/zhipu.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -from httpx import Response -from typing_extensions import Self, override - -from generate.http import HttpClient, HttpxPostKwargs -from generate.image_generation.base import GeneratedImage, ImageGenerationOutput, RemoteImageGenerationModel -from generate.model import ModelParameters -from generate.platforms.zhipu import ZhipuSettings, generate_zhipu_token - - -class ZhipuImageGenerationParameters(ModelParameters): - pass - - -class ZhipuImageGeneration(RemoteImageGenerationModel): - model_type = 'zhipu' - - parameters: ZhipuImageGenerationParameters - settings: ZhipuSettings - - def __init__( - self, - model: str = 'cogview-3', - settings: ZhipuSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = 
ZhipuImageGenerationParameters() - settings = settings or ZhipuSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(parameters=parameters, settings=settings, http_client=http_client) - - self.model = model - - def _get_request_parameters(self, prompt: str) -> HttpxPostKwargs: - headers = { - 'Content-Type': 'application/json', - 'Authorization': generate_zhipu_token(api_key=self.settings.api_key.get_secret_value()), - } - json_data = { - 'model': self.model, - 'prompt': prompt, - } - return { - 'url': self.settings.v4_api_base + 'images/generations', - 'json': json_data, - 'headers': headers, - } - - @override - def generate(self, prompt: str) -> ImageGenerationOutput: - request_parameters = self._get_request_parameters(prompt) - response = self.http_client.post(request_parameters=request_parameters) - return self._construct_model_output(prompt, response) - - @override - async def async_generate(self, prompt: str) -> ImageGenerationOutput: - request_parameters = self._get_request_parameters(prompt) - response = await self.http_client.async_post(request_parameters=request_parameters) - return self._construct_model_output(prompt, response) - - def _construct_model_output(self, prompt: str, response: Response) -> ImageGenerationOutput: - response_data = response.json() - generated_images: list[GeneratedImage] = [] - for image_data in response_data['data']: - url = image_data['url'] - content = self.http_client.get({'url': url}).content - generated_images.append( - GeneratedImage( - url=url, - prompt=prompt, - image_format='png', - content=content, - ) - ) - return ImageGenerationOutput( - model_info=self.model_info, - images=generated_images, - cost=0.25, - ) - - @property - @override - def name(self) -> str: - return self.model - - @classmethod - @override - def from_name(cls, name: str) -> Self: - return cls(model=name) diff --git a/generate/modifiers/agent.py b/generate/modifiers/agent.py index 4a313e8..7ec0d31 100644 --- a/generate/modifiers/agent.py +++ b/generate/modifiers/agent.py @@ -17,7 +17,7 @@ ) from generate.chat_completion.message.utils import ensure_messages from generate.chat_completion.model_output import ChatCompletionOutput -from generate.chat_completion.tool import Tool, ToolCallMixin, ToolDict +from generate.chat_completion.tool import SupportToolCall, Tool, ToolDict from generate.types import OrIterable AgentMessage = Union[AssistantMessage, FunctionMessage, ToolMessage] @@ -47,7 +47,7 @@ def __init__( self.tools = ToolDict.from_iterable(tools) if self.tools: - if isinstance(model, ToolCallMixin): + if isinstance(model, SupportToolCall): model.add_tools(self.tools.values()) else: raise ValueError('Model does not support tools') diff --git a/generate/modifiers/cache.py b/generate/modifiers/cache.py new file mode 100644 index 0000000..2f02067 --- /dev/null +++ b/generate/modifiers/cache.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import hashlib +from pathlib import Path +from typing import Any, AsyncIterator, TypeVar, cast + +from diskcache import Cache +from typing_extensions import Self, override + +from generate.chat_completion import ChatCompletionModel, ChatCompletionOutput, RemoteChatCompletionModel +from generate.chat_completion.message.core import Messages, Prompt +from generate.chat_completion.message.utils import ensure_messages +from generate.chat_completion.model_output import ChatCompletionStreamOutput + +T = TypeVar('T') + + +def messages_to_text(messages: Messages) -> str: + return '\n'.join([str(i) for i in 
messages]) + + +def hash_text(text: str) -> str: + return hashlib.md5(text.encode()).hexdigest() + + +class CacheChatCompletionModel(ChatCompletionModel): + defalut_cache_dir = Path.home() / '.cache' / 'generate-chat-completion' + + def __init__(self, model: ChatCompletionModel, cache_dir: Path | str | None = None) -> None: + self.model = cast(RemoteChatCompletionModel, model) + cache_dir = cache_dir or self.defalut_cache_dir + self.disk_cache = Cache(directory=cache_dir) + self.model_type = self.model.model_type # type: ignore + + @property + def name(self) -> str: + return self.model.name + + @classmethod + def from_name(cls, name: str) -> Self: + raise ValueError('Cache model cannot be created from name') + + @override + def generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput: + messages = ensure_messages(prompt) + hash_key = hash_text(f'{self.model.model_id} {messages_to_text(messages)} {self.model.parameters.model_dump_json()}') + if hash_key in self.disk_cache: + return self.disk_cache[hash_key] # type: ignore + model_output = self.model.generate(messages, **kwargs) + self.disk_cache[hash_key] = model_output + return model_output + + @override + async def async_generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput: + messages = ensure_messages(prompt) + hash_key = hash_text(f'{self.model.model_id} {messages_to_text(messages)} {self.model.parameters.model_dump_json()}') + if hash_key in self.disk_cache: + return self.disk_cache[hash_key] # type: ignore + model_output = await self.model.async_generate(messages, **kwargs) + self.disk_cache[hash_key] = model_output + return model_output + + @override + async def async_stream_generate(self, prompt: Prompt, **kwargs: Any) -> AsyncIterator[ChatCompletionStreamOutput]: + raise NotImplementedError('Stream generation is not supported for cache models') diff --git a/generate/modifiers/structure.py b/generate/modifiers/structure.py index 1d1e56d..b70f276 100644 --- a/generate/modifiers/structure.py +++ b/generate/modifiers/structure.py @@ -5,7 +5,7 @@ from copy import deepcopy from typing import Any, ClassVar, Dict, Generic, Iterable, List, Optional, Type, TypeVar -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from typing_extensions import Self, TypedDict, Unpack from generate.chat_completion import ChatCompletionModel @@ -19,18 +19,16 @@ ) from generate.model import GenerateModel, ModelOutput, RemoteModelParametersDict -field_info_title = 'Output JSON strictly based the format and pydantic field information below:\n' -json_schema_title = 'Output JSON strictly based the OpenAI JSON Schema:\n' system_template = """\ # Instruction {instruction} -# Output Format -{output_format_description} +# JSON Schema +{json_schema} """ I = TypeVar('I', BaseModel, str, Dict[str, Any]) # noqa: E741 -O = TypeVar('O', bound=BaseModel) # noqa: E741 +O = TypeVar('O') # noqa: E741 M = TypeVar('M', bound=ChatCompletionModel) @@ -125,15 +123,22 @@ class StructureGenerateModel(GenerateModel[str, StructureModelOutput[O]], Generi def __init__( self, model: M, - output_structure_type: Type[O], + output_structure_type: Type[O] | TypeAdapter[O], instruction: str | None = None, examples: Optional[Iterable[Example[O]]] = None, system_template: str = system_template, - output_exclude_none: bool = True, + output_exclude_none: bool = False, max_num_reask: int = 2, ) -> None: self.model = model - self.instruction = instruction or f'Extract {output_structure_type.__name__}' + + if isinstance(output_structure_type, 
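# A sketch of the cache-key scheme in CacheChatCompletionModel above: the
# disk-cache key is an MD5 digest over the model id, the rendered messages,
# and the parameter dump, so a repeat call with identical inputs is answered
# from disk instead of hitting the API.
import hashlib

def cache_key(model_id: str, messages_text: str, params_json: str) -> str:
    return hashlib.md5(f'{model_id} {messages_text} {params_json}'.encode()).hexdigest()

assert cache_key('openai/gpt-4', 'user: hi', '{}') == cache_key('openai/gpt-4', 'user: hi', '{}')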
TypeAdapter):
+            default_instruction = 'Extract Information'
+        else:
+            default_instruction = f'Extract {output_structure_type.__name__}'
+        self.instruction = instruction or default_instruction
+
         self.output_structure_type = output_structure_type
         self.examples = examples or []
         self.system_template = system_template
@@ -142,6 +147,10 @@ def __init__(
 
         self.model_type = self.model.model_type  # type: ignore
 
+    @property
+    def is_typeadapter(self) -> bool:
+        return isinstance(self.output_structure_type, TypeAdapter)
+
     @property
     def name(self) -> str:
         return self.model.name
@@ -156,51 +165,47 @@ def messages(self) -> List[UnionMessage]:
         messages.append(self.system_message)
         for example in self.examples:
             messages.extend(ensure_messages(example.prompt))
-            messages.append(AssistantMessage(content=example.output.model_dump_json(exclude_none=self.output_exclude_none)))
+            messages.append(AssistantMessage(content=self.model_dump_json(example.output)))
         return messages
 
-    @property
-    def _output_format_description(self) -> str:
-        json_schema = self.output_structure_type.model_json_schema()
-        have_ref = '$defs' in json_schema
-        if have_ref:
-            text = json_schema_title
-            json_schema = json.dumps(json_schema, indent=2)
-            text += json_schema
-            return text
-        text = field_info_title
-        fields_info = str(self.output_structure_type.model_fields)
-        fields_info = fields_info.replace("'", '"')
-        text += fields_info
-        return text
-
     @property
     def system_message(self) -> SystemMessage:
         system_content = self.system_template.format(
-            instruction=self.instruction, output_format_description=self._output_format_description
+            instruction=self.instruction,
+            json_schema=json.dumps(self.model_json_schema(), indent=2, ensure_ascii=False),
         )
         return SystemMessage(content=system_content)
 
+    def model_dump_json(self, item: O) -> str:
+        if self.is_typeadapter:
+            assert isinstance(self.output_structure_type, TypeAdapter)
+            return self.output_structure_type.dump_json(item, exclude_none=self.output_exclude_none).decode('utf-8')
+        assert isinstance(item, BaseModel)
+        return item.model_dump_json(exclude_none=self.output_exclude_none)
+
+    def model_json_schema(self) -> dict[str, Any]:
+        if self.is_typeadapter:
+            assert isinstance(self.output_structure_type, TypeAdapter)
+            return self.output_structure_type.json_schema()
+        return self.output_structure_type.model_json_schema()  # type: ignore
+
+    def model_validate_json(self, json_string: str) -> O:
+        if self.is_typeadapter:
+            assert isinstance(self.output_structure_type, TypeAdapter)
+            return self.output_structure_type.validate_json(json_string)
+        return self.output_structure_type.model_validate_json(json_string)  # type: ignore
+
     def generate(self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]) -> StructureModelOutput[O]:
         messages = deepcopy(self.messages)
         messages.extend(ensure_messages(prompt))
         num_reask = 0
-        cost = None
         while num_reask <= self.max_num_reask:
             model_output = self.model.generate(messages, **kwargs)
             messages.append(model_output.message)
-            if model_output.cost is not None:
-                if cost is None:
-                    cost = model_output.cost
-                else:
-                    cost += model_output.cost
-
             try:
                 json_string = ensure_valid_json(model_output.reply)
-                structure = self.output_structure_type.model_validate_json(json_string)
-                return StructureModelOutput(
-                    model_info=model_output.model_info, structure=structure, cost=cost, extra=model_output.extra
-                )
+                structure = self.model_validate_json(json_string)
+                return 
StructureModelOutput(model_info=model_output.model_info, structure=structure, extra=model_output.extra) except Exception as e: num_reask += 1 messages.append( @@ -213,22 +218,13 @@ async def async_generate(self, prompt: Prompt, **kwargs: Unpack[RemoteModelParam messages = deepcopy(self.messages) messages.extend(ensure_messages(prompt)) num_reask = 0 - cost = None while num_reask <= self.max_num_reask: model_output = await self.model.async_generate(messages, **kwargs) messages.append(model_output.message) - if model_output.cost is not None: - if cost is None: - cost = model_output.cost - else: - cost += model_output.cost - try: json_string = ensure_valid_json(model_output.reply) - structure = self.output_structure_type.model_validate_json(json_string) - return StructureModelOutput( - model_info=model_output.model_info, structure=structure, cost=cost, extra=model_output.extra - ) + structure = self.model_validate_json(json_string) + return StructureModelOutput(model_info=model_output.model_info, structure=structure, extra=model_output.extra) except Exception as e: num_reask += 1 messages.append( diff --git a/generate/platforms/__init__.py b/generate/platforms/__init__.py index 189d4a5..662ab1b 100644 --- a/generate/platforms/__init__.py +++ b/generate/platforms/__init__.py @@ -5,7 +5,6 @@ from generate.platforms.base import PlatformSettings from generate.platforms.dashscope import DashScopeSettings from generate.platforms.deepseek import DeepSeekSettings -from generate.platforms.hunyuan import HunyuanSettings from generate.platforms.minimax import MinimaxSettings from generate.platforms.moonshot import MoonshotSettings from generate.platforms.openai import OpenAISettings @@ -23,7 +22,6 @@ 'ZhipuSettings', 'OpenAISettings', 'QianfanSettings', - 'HunyuanSettings', 'DashScopeSettings', 'MoonshotSettings', 'DeepSeekSettings', diff --git a/generate/platforms/azure.py b/generate/platforms/azure.py index 5362c31..d9015fc 100644 --- a/generate/platforms/azure.py +++ b/generate/platforms/azure.py @@ -1,16 +1,13 @@ from typing import Optional -from pydantic import SecretStr from pydantic_settings import SettingsConfigDict -from generate.platforms.base import PlatformSettings +from generate.platforms.openai_like import OpenAILikeSettings -class AzureSettings(PlatformSettings): +class AzureSettings(OpenAILikeSettings): model_config = SettingsConfigDict(extra='ignore', env_prefix='azure_', env_file='.env') - api_key: SecretStr - api_base: str api_version: str chat_api_engine: Optional[str] = None platform_url: str = 'https://learn.microsoft.com/en-us/azure/ai-services/openai/' diff --git a/generate/platforms/hunyuan.py b/generate/platforms/hunyuan.py deleted file mode 100644 index 6a0de3b..0000000 --- a/generate/platforms/hunyuan.py +++ /dev/null @@ -1,15 +0,0 @@ -from pydantic import SecretStr -from pydantic_settings import SettingsConfigDict - -from generate.platforms.base import PlatformSettings - - -class HunyuanSettings(PlatformSettings): - model_config = SettingsConfigDict(extra='ignore', env_prefix='hunyuan_', env_file='.env') - - app_id: int - secret_id: SecretStr - secret_key: SecretStr - completion_api: str = 'https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions' - sign_api: str = 'hunyuan.cloud.tencent.com/hyllm/v1/chat/completions' - platform_url: str = 'https://cloud.tencent.com/document/product/1729' diff --git a/generate/text_to_speech/__init__.py b/generate/text_to_speech/__init__.py deleted file mode 100644 index 6344c36..0000000 --- a/generate/text_to_speech/__init__.py +++ 
/dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -from typing import Type - -from generate.model import ModelParameters -from generate.text_to_speech.base import TextToSpeechModel, TextToSpeechOutput -from generate.text_to_speech.models import ( - MinimaxProSpeech, - MinimaxProSpeechParameters, - MinimaxSpeech, - MinimaxSpeechParameters, - OpenAISpeech, - OpenAISpeechParameters, -) - -SpeechModels: list[tuple[Type[TextToSpeechModel], Type[ModelParameters]]] = [ - (MinimaxSpeech, MinimaxSpeechParameters), - (MinimaxProSpeech, MinimaxProSpeechParameters), - (OpenAISpeech, OpenAISpeechParameters), -] - -SpeechModelRegistry: dict[str, tuple[Type[TextToSpeechModel], Type[ModelParameters]]] = { - model_cls.model_type: (model_cls, parameter_cls) for model_cls, parameter_cls in SpeechModels -} - - -__all__ = [ - 'TextToSpeechModel', - 'TextToSpeechOutput', - 'MinimaxSpeech', - 'MinimaxProSpeech', - 'MinimaxSpeechParameters', - 'MinimaxProSpeechParameters', - 'OpenAISpeech', - 'OpenAISpeechParameters', -] diff --git a/generate/text_to_speech/base.py b/generate/text_to_speech/base.py deleted file mode 100644 index da23b5a..0000000 --- a/generate/text_to_speech/base.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging -from abc import ABC -from typing import ClassVar, Optional, TypeVar, get_type_hints - -from generate.http import HttpClient -from generate.model import GenerateModel, ModelOutput, ModelParameters -from generate.platforms.base import PlatformSettings - -P = TypeVar('P', bound=ModelParameters) -logger = logging.getLogger(__name__) - - -class TextToSpeechOutput(ModelOutput): - audio: bytes - audio_format: str - cost: Optional[float] = None - - -class TextToSpeechModel(GenerateModel[str, TextToSpeechOutput], ABC): - model_task: ClassVar[str] = 'text_to_speech' - - -class RemoteTextToSpeechModel(TextToSpeechModel): - settings: PlatformSettings - http_client: HttpClient - - def __init__( - self, - parameters: ModelParameters, - settings: PlatformSettings, - http_client: HttpClient, - ) -> None: - self.parameters = parameters - self.settings = settings - self.http_client = http_client - - @classmethod - def how_to_settings(cls) -> str: - return f'{cls.__name__} Settings\n\n' + get_type_hints(cls)['settings'].how_to_settings() diff --git a/generate/text_to_speech/models/__init__.py b/generate/text_to_speech/models/__init__.py deleted file mode 100644 index e05be46..0000000 --- a/generate/text_to_speech/models/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from generate.text_to_speech.models.minimax import ( - MinimaxProSpeech, - MinimaxProSpeechParameters, - MinimaxSpeech, - MinimaxSpeechParameters, -) -from generate.text_to_speech.models.openai import OpenAISpeech, OpenAISpeechParameters - -__all__ = [ - 'MinimaxSpeech', - 'MinimaxProSpeech', - 'MinimaxSpeechParameters', - 'MinimaxProSpeechParameters', - 'OpenAISpeech', - 'OpenAISpeechParameters', -] diff --git a/generate/text_to_speech/models/minimax.py b/generate/text_to_speech/models/minimax.py deleted file mode 100644 index f30d2a9..0000000 --- a/generate/text_to_speech/models/minimax.py +++ /dev/null @@ -1,212 +0,0 @@ -from __future__ import annotations - -from typing import List, Literal, Optional - -from pydantic import Field, model_validator -from typing_extensions import Annotated, Self, TypedDict, Unpack, override - -from generate.http import HttpClient, HttpxPostKwargs, UnexpectedResponseError -from generate.model import ModelParameters -from generate.platforms.minimax import MinimaxSettings -from 
generate.text_to_speech.base import RemoteTextToSpeechModel, TextToSpeechOutput - - -class TimeberWeight(TypedDict): - voice_id: str - weight: int - - -class MinimaxSpeechParameters(ModelParameters): - voice: Optional[str] = Field(default=None, alias='voice_id') - speed: Annotated[Optional[float], Field(ge=0.5, le=2.0)] = None - vol: Annotated[Optional[float], Field(gt=0, le=10)] = None - pitch: Annotated[Optional[float], Field(ge=-12, le=12)] = None - timber_weights: Annotated[Optional[List[TimeberWeight]], Field(min_length=1, max_length=4)] = None - - @model_validator(mode='after') - def voice_exists(self) -> Self: - if self.voice is None and self.timber_weights is None: - self.voice = 'male-qn-qingse' - return self - - -class MinimaxSpeechParametersDict(TypedDict, total=False): - voice: Optional[str] - speed: Optional[float] - vol: Optional[float] - pitch: Optional[float] - timber_weights: Optional[List[TimeberWeight]] - - -class MinimaxProSpeechParameters(MinimaxSpeechParameters): - audio_sample_rate: Annotated[Optional[int], Field(ge=16000, le=24000)] = 24000 - bitrate: Literal[32000, 64000, 128000] = 128000 - - -class MinimaxProSpeechParametersDict(MinimaxSpeechParametersDict, total=False): - audio_sample_rate: Optional[int] - bitrate: Optional[Literal[32000, 64000, 128000]] - - -class MinimaxSpeech(RemoteTextToSpeechModel): - model_type = 'minimax' - - parameters: MinimaxSpeechParameters - settings: MinimaxSettings - - def __init__( - self, - model: str = 'speech-01', - settings: MinimaxSettings | None = None, - parameters: MinimaxSpeechParameters | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or MinimaxSpeechParameters() - settings = settings or MinimaxSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(parameters=parameters, settings=settings, http_client=http_client) - - self.model = model - - def _get_request_parameters(self, text: str, parameters: MinimaxSpeechParameters) -> HttpxPostKwargs: - json_data = { - 'model': self.model, - 'text': text, - **parameters.custom_model_dump(), - } - headers = { - 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', - 'Content-Type': 'application/json', - } - return { - 'url': self.settings.api_base + '/text_to_speech', - 'json': json_data, - 'headers': headers, - 'params': {'GroupId': self.settings.group_id}, - } - - def generate(self, prompt: str, **kwargs: Unpack[MinimaxSpeechParametersDict]) -> TextToSpeechOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = self.http_client.post(request_parameters=request_parameters) - return TextToSpeechOutput( - model_info=self.model_info, - audio=response.content, - audio_format='mp3', - cost=self.calculate_cost(prompt), - ) - - async def async_generate(self, prompt: str, **kwargs: Unpack[MinimaxSpeechParametersDict]) -> TextToSpeechOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = await self.http_client.async_post(request_parameters=request_parameters) - return TextToSpeechOutput( - model_info=self.model_info, - audio=response.content, - audio_format='mp3', - cost=self.calculate_cost(prompt), - ) - - @property - @override - def name(self) -> str: - return self.model - - @classmethod - @override - def from_name(cls, name: str) -> Self: - return cls(model=name) - - @staticmethod - def 
calculate_cost(text: str) -> float: - character_count = sum(2 if '\u4e00' <= char <= '\u9fff' else 1 for char in text) - return character_count / 1000 - - -class MinimaxProSpeech(RemoteTextToSpeechModel): - model_type = 'minimax_pro' - - parameters: MinimaxProSpeechParameters - settings: MinimaxSettings - - def __init__( - self, - model: str = 'speech-01', - parameters: MinimaxProSpeechParameters | None = None, - settings: MinimaxSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or MinimaxProSpeechParameters() - settings = settings or MinimaxSettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(parameters=parameters, settings=settings, http_client=http_client) - - self.model = model - - def _get_request_parameters(self, text: str, parameters: MinimaxProSpeechParameters) -> HttpxPostKwargs: - json_data = { - 'model': self.model, - 'text': text, - **parameters.custom_model_dump(), - } - headers = { - 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', - 'Content-Type': 'application/json', - } - return { - 'url': self.settings.api_base + '/t2a_pro', - 'json': json_data, - 'headers': headers, - 'params': {'GroupId': self.settings.group_id}, - } - - @override - def generate(self, prompt: str, **kwargs: Unpack[MinimaxProSpeechParametersDict]) -> TextToSpeechOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = self.http_client.post(request_parameters=request_parameters) - response_data = response.json() - if response_data['base_resp']['status_code'] != 0: - raise UnexpectedResponseError(response_data) - - model_output = TextToSpeechOutput( - model_info=self.model_info, - audio=self.http_client.get({'url': response_data['audio_file']}).content, - audio_format='mp3', - cost=response_data['extra_info']['word_count'] / 1000, - ) - model_output.extra['subtitle'] = self.http_client.get({'url': response_data['subtitle_file']}).json() - model_output.extra.update(response_data['extra_info']) - return model_output - - @override - async def async_generate(self, prompt: str, **kwargs: Unpack[MinimaxProSpeechParametersDict]) -> TextToSpeechOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = await self.http_client.async_post(request_parameters=request_parameters) - response_data = response.json() - if response_data['base_resp']['status_code'] != 0: - raise UnexpectedResponseError(response_data) - - audio = (await self.http_client.async_get({'url': response_data['audio_file']})).content - subtitle = (await self.http_client.async_get({'url': response_data['subtitle_file']})).json() - - model_output = TextToSpeechOutput( - model_info=self.model_info, - audio=audio, - audio_format='mp3', - cost=response_data['extra_info']['word_count'] / 1000, - extra={'subtitle': subtitle}, - ) - model_output.extra.update(response_data['extra_info']) - return model_output - - @property - @override - def name(self) -> str: - return self.model - - @classmethod - @override - def from_name(cls, name: str) -> Self: - return cls(model=name) diff --git a/generate/text_to_speech/models/openai.py b/generate/text_to_speech/models/openai.py deleted file mode 100644 index 6771f24..0000000 --- a/generate/text_to_speech/models/openai.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -from typing import Literal, 
Optional - -from pydantic import Field -from typing_extensions import Annotated, Self, TypedDict, Unpack, override - -from generate.http import HttpClient, HttpxPostKwargs -from generate.model import ModelParameters -from generate.platforms.openai import OpenAISettings -from generate.text_to_speech.base import RemoteTextToSpeechModel, TextToSpeechOutput - - -class OpenAISpeechParameters(ModelParameters): - voice: str = 'alloy' - response_format: Optional[Literal['mp3', 'aac', 'opus', 'flac']] = None - speed: Annotated[Optional[float], Field(ge=0.25, le=4.0)] = None - - -class OpenAISpeechParametersDict(TypedDict, total=False): - voice: str - response_format: Optional[Literal['mp3', 'aac', 'opus', 'flac']] - speed: Optional[float] - - -class OpenAISpeech(RemoteTextToSpeechModel): - model_type = 'openai' - - parameters: OpenAISpeechParameters - settings: OpenAISettings - - def __init__( - self, - model: str = 'tts-1', - settings: OpenAISettings | None = None, - parameters: OpenAISpeechParameters | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or OpenAISpeechParameters() - settings = settings or OpenAISettings() # type: ignore - http_client = http_client or HttpClient() - super().__init__(parameters=parameters, settings=settings, http_client=http_client) - - self.model = model - - def _get_request_parameters(self, text: str, parameters: OpenAISpeechParameters) -> HttpxPostKwargs: - json_data = { - 'model': self.model, - 'input': text, - **parameters.custom_model_dump(), - } - headers = { - 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', - 'Content-Type': 'application/json', - } - return { - 'url': self.settings.api_base + '/audio/speech', - 'json': json_data, - 'headers': headers, - } - - @override - def generate(self, prompt: str, **kwargs: Unpack[OpenAISpeechParametersDict]) -> TextToSpeechOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = self.http_client.post(request_parameters=request_parameters) - return TextToSpeechOutput( - model_info=self.model_info, - audio=response.content, - audio_format=parameters.response_format or 'mp3', - cost=self.calculate_cost(prompt), - ) - - @override - async def async_generate(self, prompt: str, **kwargs: Unpack[OpenAISpeechParametersDict]) -> TextToSpeechOutput: - parameters = self.parameters.clone_with_changes(**kwargs) - request_parameters = self._get_request_parameters(prompt, parameters) - response = await self.http_client.async_post(request_parameters=request_parameters) - return TextToSpeechOutput( - model_info=self.model_info, - audio=response.content, - audio_format=parameters.response_format or 'mp3', - cost=self.calculate_cost(prompt), - ) - - @property - @override - def name(self) -> str: - return self.model - - @classmethod - @override - def from_name(cls, name: str) -> Self: - return cls(model=name) - - def calculate_cost(self, text: str) -> float | None: - dollar_to_yuan = 7 - if self.model == 'tts-1': - return (len(text) / 1000) * (0.015 * dollar_to_yuan) - - if self.model == 'tts-1-hd': - return (len(text) / 1000) * (0.03 * dollar_to_yuan) - - return None diff --git a/generate/types.py b/generate/types.py index ffb9334..2ab7247 100644 --- a/generate/types.py +++ b/generate/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, Iterable, Optional, 
Sequence, TypeVar, Union from pydantic import Field from typing_extensions import Annotated @@ -12,4 +12,3 @@ PrimitiveData = Optional[Union[str, int, float, bool]] OrSequence = Union[T, Sequence[T]] OrIterable = Union[T, Iterable[T]] -ModelPrice = Dict[str, Tuple[float, float]] diff --git a/generate/ui.py b/generate/ui.py index 426e7de..5b8d94a 100644 --- a/generate/ui.py +++ b/generate/ui.py @@ -5,7 +5,6 @@ from pydantic import BaseModel from generate.chat_completion.base import RemoteChatCompletionModel -from generate.chat_completion.models.dashscope_multimodal import DashScopeMultiModalChat from generate.model import ModelInfo try: @@ -133,15 +132,10 @@ async def main(message: cl.Message) -> None: if element.path is not None: with open(element.path, 'rb') as image_file: image_content = image_file.read() - if isinstance(state.chat_model, DashScopeMultiModalChat): - url = state.chat_model.upload_image(image_content, image_format) - image_url = ImageUrl(url=url) - image_part = ImageUrlPart(image_url=image_url) - else: - image_part = ImagePart( - image=image_content, - image_format=image_format, - ) + image_part = ImagePart( + image=image_content, + image_format=image_format, + ) image_parts.append(image_part) elif element.url is not None: image_url = ImageUrl(url=element.url) diff --git a/poetry.lock b/poetry.lock index 8e0a02c..74b9d15 100644 --- a/poetry.lock +++ b/poetry.lock @@ -794,6 +794,17 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "diskcache" +version = "5.6.3" +description = "Disk Cache -- Disk and file backed persistent cache." +optional = false +python-versions = ">=3" +files = [ + {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"}, + {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, +] + [[package]] name = "docstring-parser" version = "0.15" @@ -2937,4 +2948,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "ce9fe87276bfc2874cef908975e3a4fc5ea8c84470d68445ceaa4018b804bc20" +content-hash = "cf56ac7d423478db8e01c2a6e092d88ac8edaa114c2c5ed0749f7304a609d676" diff --git a/pyproject.toml b/pyproject.toml index 6150bde..acfbcfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ pydantic = "^2.0" docstring-parser = "^0.15" httpx-sse = "^0.4.0" pydantic-settings = "^2.1.0" +diskcache = "^5.6.3" [tool.ruff] line-length = 128 diff --git a/tests/test_highlevel.py b/tests/test_highlevel.py index 834dfa9..b170146 100644 --- a/tests/test_highlevel.py +++ b/tests/test_highlevel.py @@ -1,11 +1,4 @@ -from generate.highlevel import ( - generate_image, - generate_speech, - generate_text, - load_chat_model, - load_image_generation_model, - load_speech_model, -) +from generate.highlevel import generate_text, load_chat_model def test_load_chat_model() -> None: @@ -14,28 +7,6 @@ def test_load_chat_model() -> None: assert model.name == 'gpt-3.5-turbo' -def test_load_speech_model() -> None: - model = load_speech_model('openai/tts-1-hd') - assert model.model_type == 'openai' - assert model.name == 'tts-1-hd' - - -def test_load_image_generation_model() -> None: - model = load_image_generation_model('openai/dall-e-3') - assert model.model_type == 'openai' - assert model.name == 'dall-e-3' - - def test_generate_text() -> None: output = 
generate_text('你好') assert output.reply != '' - - -def test_generate_speech() -> None: - output = generate_speech('这是一个测试用例') - assert len(output.audio) != 0 - - -def test_generate_image() -> None: - output = generate_image('可爱的猫', model_id='openai/dall-e-2') - assert len(output.images[0].content) != 0 diff --git a/tests/test_text_to_speech_model.py b/tests/test_text_to_speech_model.py deleted file mode 100644 index 5426bf8..0000000 --- a/tests/test_text_to_speech_model.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -import asyncio - -import pytest - -from generate.test import get_pytest_params -from generate.text_to_speech import ( - SpeechModelRegistry, - SpeechModels, - TextToSpeechModel, -) - - -def test_model_type_is_unique() -> None: - assert len(SpeechModels) == len(SpeechModelRegistry) - - -@pytest.mark.parametrize('speech_model', get_pytest_params('test_text_to_speech', SpeechModelRegistry, types='model')) -def test_speech_model(speech_model: TextToSpeechModel) -> None: - prompt = '你好,这是一个测试用例' - sync_output = speech_model.generate(prompt) - async_output = asyncio.run(speech_model.async_generate(prompt)) - - assert len(sync_output.audio) != 0 - assert len(async_output.audio) != 0 From afb54392a2d786762befd32a047f675d2652f6e4 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Thu, 23 May 2024 18:30:46 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E6=B8=85=E7=90=86=E5=86=97=E4=BD=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generate/chat_completion/models/baichuan.py | 184 +----------------- generate/chat_completion/models/dashscope.py | 10 - generate/chat_completion/models/minimax.py | 3 +- generate/chat_completion/models/moonshot.py | 12 +- .../chat_completion/models/openai_like.py | 32 ++- generate/chat_completion/stream_manager.py | 4 +- generate/platforms/baichuan.py | 6 +- 7 files changed, 34 insertions(+), 217 deletions(-) diff --git a/generate/chat_completion/models/baichuan.py b/generate/chat_completion/models/baichuan.py index 45f72c6..a5ae75c 100644 --- a/generate/chat_completion/models/baichuan.py +++ b/generate/chat_completion/models/baichuan.py @@ -1,29 +1,18 @@ from __future__ import annotations -from typing import Any, AsyncIterator, ClassVar, Iterator, List, Literal, Optional +from typing import AsyncIterator, ClassVar, Iterator, List, Literal, Optional from pydantic import Field from typing_extensions import Annotated, NotRequired, TypedDict, Unpack, override -from generate.chat_completion.base import RemoteChatCompletionModel from generate.chat_completion.message import ( - AssistantMessage, - Messages, Prompt, - SystemMessage, - ToolMessage, - UserMessage, ) -from generate.chat_completion.message.converter import MessageConverter, SimpleMessageConverter -from generate.chat_completion.message.core import FunctionCall, FunctionMessage, ToolCall, UserMultiPartMessage -from generate.chat_completion.message.exception import MessageTypeError -from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage -from generate.chat_completion.models.openai_like import SupportOpenAIToolCall -from generate.chat_completion.stream_manager import StreamManager +from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput +from generate.chat_completion.models.openai_like import OpenAILikeChat, SupportOpenAIToolCall from generate.chat_completion.tool import FunctionJsonSchema from generate.http import ( HttpClient, - 
HttpxPostKwargs,
 )
 from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.baichuan import BaichuanSettings
@@ -71,58 +60,7 @@ class BaichuanChatParametersDict(RemoteModelParametersDict, total=False):
     tool_choice: Optional[str]
 
 
-class BaichuanMessageConverter(MessageConverter):
-    allowed_message_types = [SystemMessage, UserMessage, AssistantMessage, ToolMessage]
-
-    def convert_system_message(self, message: SystemMessage) -> dict[str, Any]:
-        return {
-            'role': 'system',
-            'content': message.content,
-        }
-
-    def convert_user_message(self, message: UserMessage) -> dict[str, Any]:
-        return {
-            'role': 'user',
-            'content': message.content,
-        }
-
-    def convert_tool_message(self, message: ToolMessage) -> dict[str, Any]:
-        return {
-            'role': 'tool',
-            'tool_call_id': message.tool_call_id,
-            'content': message.content,
-        }
-
-    def convert_assistant_message(self, message: AssistantMessage) -> dict[str, Any]:
-        base_dict = {
-            'role': 'assistant',
-            'content': message.content or None,
-        }
-        if message.tool_calls:
-            tool_calls = [
-                {
-                    'id': tool_call.id,
-                    'type': 'function',
-                    'function': {
-                        'name': tool_call.function.name,
-                        'arguments': tool_call.function.arguments,
-                    },
-                }
-                for tool_call in message.tool_calls
-            ]
-            base_dict['tool_calls'] = tool_calls
-        if message.function_call:
-            raise ValueError('Function calls are not supported in Baichuan')
-        return base_dict
-
-    def convert_user_multi_part_message(self, message: UserMultiPartMessage) -> dict[str, Any]:
-        raise MessageTypeError(message, allowed_message_type=self.allowed_message_types)
-
-    def convert_function_message(self, message: FunctionMessage) -> dict[str, Any]:
-        raise MessageTypeError(message, allowed_message_type=self.allowed_message_types)
-
-
-class BaichuanChat(RemoteChatCompletionModel, SupportOpenAIToolCall):
+class BaichuanChat(OpenAILikeChat, SupportOpenAIToolCall):
     model_type: ClassVar[str] = 'baichuan'
     available_models: ClassVar[List[str]] = [
         'Baichuan2-Turbo',
@@ -134,7 +72,6 @@ class BaichuanChat(RemoteChatCompletionModel, SupportOpenAIToolCall):
 
     parameters: BaichuanChatParameters
     settings: BaichuanSettings
-    message_converter: SimpleMessageConverter
 
     def __init__(
         self,
@@ -142,19 +79,12 @@
         parameters: BaichuanChatParameters | None = None,
         settings: BaichuanSettings | None = None,
         http_client: HttpClient | None = None,
-        message_converter: MessageConverter | None = None,
     ) -> None:
         parameters = parameters or BaichuanChatParameters()
         settings = settings or BaichuanSettings()  # type: ignore
         http_client = http_client or HttpClient()
-        message_converter = message_converter or BaichuanMessageConverter()
-        super().__init__(
-            model=model,
-            parameters=parameters,
-            settings=settings,
-            http_client=http_client,
-            message_converter=message_converter,
-        )
+        super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client)
 
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
@@ -174,103 +104,5 @@ def stream_generate(
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
-        async for output in super().async_stream_generate(prompt, **kwargs):
-            yield output
-
-    @override
-    def _get_request_parameters(
-        self, messages: Messages, stream: bool = False, **kwargs: Unpack[BaichuanChatParametersDict]
-    ) -> HttpxPostKwargs:
-        if isinstance(system_message := messages[0], 
SystemMessage): - prepend_messages = [UserMessage(content=system_message.content)] - messages = prepend_messages + messages[1:] - parameters = self.parameters.clone_with_changes(**kwargs) - json_data = { - 'model': self.model, - 'messages': self.message_converter.convert_messages(messages), - } - parameters_dict = parameters.custom_model_dump() - json_data.update(parameters_dict) - if stream: - json_data['stream'] = True - headers = { - 'Content-Type': 'application/json', - 'Authorization': 'Bearer ' + self.settings.api_key.get_secret_value(), - } - return { - 'url': self.settings.api_base + '/chat/completions', - 'headers': headers, - 'json': json_data, - } - - @override - def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: - return ChatCompletionOutput( - model_info=self.model_info, - message=self._parse_assistant_message(response['choices'][0]['message']), - finish_reason=self._parse_finish_reason(response), - usage=self._parse_usage(response), - extra=self._parse_extra(response), - ) - - @override - def _process_stream_response( - self, response: dict[str, Any], stream_manager: StreamManager - ) -> ChatCompletionStreamOutput | None: - delta_dict = response['choices'][0].get('delta', {}) - self._update_delta(delta_dict, stream_manager=stream_manager) - stream_manager.extra = self._parse_extra(response) - stream_manager.usage = self._parse_usage(response) - stream_manager.finish_reason = self._parse_finish_reason(response) - return stream_manager.build_stream_output() - - def _parse_assistant_message(self, message: dict[str, Any]) -> AssistantMessage: - if tool_calls_dict := message.get('tool_calls'): - tool_calls = [ - ToolCall( - id=tool_call['id'], - function=FunctionCall( - name=tool_call['function'].get('name') or '', - arguments=tool_call['function']['arguments'], - ), - ) - for tool_call in tool_calls_dict - ] - else: - tool_calls = None - return AssistantMessage(content=message.get('content') or '', tool_calls=tool_calls) - - def _parse_usage(self, response: dict[str, Any]) -> Usage: - usage = response.get('usage') - if usage is not None: - input_tokens = usage['prompt_tokens'] - output_tokens = usage['completion_tokens'] - return Usage(input_tokens=input_tokens, output_tokens=output_tokens) - return Usage() - - def _parse_finish_reason(self, response: dict[str, Any]) -> FinishReason | None: - try: - choice = response['choices'][0] - if finish_reason := choice.get('finish_reason'): - return FinishReason(finish_reason) - except (KeyError, IndexError, ValueError): - return None - - def _parse_extra(self, response: dict[str, Any]) -> dict[str, Any]: - return {'response': response} - - def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManager) -> None: - delta_content: str = delta_dict.get('content') or '' - stream_manager.delta = delta_content - - if delta_dict.get('tool_calls'): - index = delta_dict['tool_calls'][0]['index'] - if index >= len(stream_manager.tool_calls or []): - new_tool_calls_message = self._parse_assistant_message(delta_dict).tool_calls - assert new_tool_calls_message is not None - if stream_manager.tool_calls is None: - stream_manager.tool_calls = [] - stream_manager.tool_calls.append(new_tool_calls_message[0]) - else: - assert stream_manager.tool_calls is not None - stream_manager.tool_calls[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments'] + async for stream_output in super().async_stream_generate(prompt, **kwargs): + yield stream_output diff --git 
a/generate/chat_completion/models/dashscope.py b/generate/chat_completion/models/dashscope.py index 50b684e..9d4685b 100644 --- a/generate/chat_completion/models/dashscope.py +++ b/generate/chat_completion/models/dashscope.py @@ -215,16 +215,6 @@ def _process_stream_response( stream_manager.finish_reason = self._parse_finish_reason(choice) return stream_manager.build_stream_output() - # reply = response['output']['text'] - # stream_manager.usage = self._parse_usage(response) - # stream_manager.extra = self._parse_extra(response) - # if choice['finish_reason'] != 'null': - # stream_manager.finish_reason = self._parse_finish_reason(choice) - # stream_manager.delta = '' - # return stream_manager.build_stream_output() - # stream_manager.delta = reply[len(stream_manager.content) :] - # return stream_manager.build_stream_output() - def _parse_usage(self, response: dict[str, Any]) -> Usage: if usage := response.get('usage'): input_tokens = usage.get('input_tokens') diff --git a/generate/chat_completion/models/minimax.py b/generate/chat_completion/models/minimax.py index 2b27b85..237f063 100644 --- a/generate/chat_completion/models/minimax.py +++ b/generate/chat_completion/models/minimax.py @@ -13,6 +13,7 @@ from generate.chat_completion.message.core import Messages from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, FinishReason, Usage from generate.chat_completion.models.openai_like import OpenAILikeChat, OpenAITool +from generate.chat_completion.tool import SupportToolCall from generate.http import ( HttpClient, HttpxPostKwargs, @@ -45,7 +46,7 @@ class MinimaxChatParametersDict(RemoteModelParametersDict, total=False): tools: Optional[List[OpenAITool]] -class MinimaxChat(OpenAILikeChat): +class MinimaxChat(OpenAILikeChat, SupportToolCall): model_type: ClassVar[str] = 'minimax' available_models: ClassVar[List[str]] = ['abab5.5-chat', 'abab5.5s-chat', 'abab6-chat', 'abab6.5-chat'] CHAT_COMPLETION_ENDPOINT: ClassVar[str] = '/text/chatcompletion_v2' diff --git a/generate/chat_completion/models/moonshot.py b/generate/chat_completion/models/moonshot.py index 16cfe2c..856e13d 100644 --- a/generate/chat_completion/models/moonshot.py +++ b/generate/chat_completion/models/moonshot.py @@ -20,7 +20,7 @@ class MoonshotChatParameters(ModelParameters): max_tokens: Optional[PositiveInt] = None -class MoonshotParametersDict(RemoteModelParametersDict, total=False): +class MoonshotChatParametersDict(RemoteModelParametersDict, total=False): temperature: Temperature top_p: Probability max_tokens: PositiveInt @@ -47,20 +47,22 @@ def __init__( super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) @override - def generate(self, prompt: Prompt, **kwargs: Unpack[MoonshotParametersDict]) -> ChatCompletionOutput: + def generate(self, prompt: Prompt, **kwargs: Unpack[MoonshotChatParametersDict]) -> ChatCompletionOutput: return super().generate(prompt, **kwargs) @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MoonshotParametersDict]) -> ChatCompletionOutput: + async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MoonshotChatParametersDict]) -> ChatCompletionOutput: return await super().async_generate(prompt, **kwargs) @override - def stream_generate(self, prompt: Prompt, **kwargs: Unpack[MoonshotParametersDict]) -> Iterator[ChatCompletionStreamOutput]: + def stream_generate( + self, prompt: Prompt, **kwargs: Unpack[MoonshotChatParametersDict] + ) -> Iterator[ChatCompletionStreamOutput]: yield from 
super().stream_generate(prompt, **kwargs) @override async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[MoonshotParametersDict] + self, prompt: Prompt, **kwargs: Unpack[MoonshotChatParametersDict] ) -> AsyncIterator[ChatCompletionStreamOutput]: async for stream_output in super().async_stream_generate(prompt, **kwargs): yield stream_output diff --git a/generate/chat_completion/models/openai_like.py b/generate/chat_completion/models/openai_like.py index e1e6941..3c8c2ab 100644 --- a/generate/chat_completion/models/openai_like.py +++ b/generate/chat_completion/models/openai_like.py @@ -4,7 +4,7 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union, cast from pydantic import Field, PositiveInt -from typing_extensions import Annotated, NotRequired, TypedDict, Unpack, override +from typing_extensions import Annotated, NotRequired, TypedDict, override from generate.chat_completion.base import RemoteChatCompletionModel from generate.chat_completion.message import ( @@ -234,23 +234,19 @@ def __init__( ) @override - def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: + def generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput: return super().generate(prompt, **kwargs) @override - async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput: + async def async_generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput: return await super().async_generate(prompt, **kwargs) @override - def stream_generate( - self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict] - ) -> Iterator[ChatCompletionStreamOutput]: + def stream_generate(self, prompt: Prompt, **kwargs: Any) -> Iterator[ChatCompletionStreamOutput]: yield from super().stream_generate(prompt, **kwargs) @override - async def async_stream_generate( - self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict] - ) -> AsyncIterator[ChatCompletionStreamOutput]: + async def async_stream_generate(self, prompt: Prompt, **kwargs: Any) -> AsyncIterator[ChatCompletionStreamOutput]: async for stream_output in super().async_stream_generate(prompt, **kwargs): yield stream_output @@ -260,18 +256,18 @@ def _get_request_parameters(self, messages: Messages, stream: bool = False, **kw headers = { 'Authorization': f'Bearer {self.settings.api_key.get_secret_value()}', } - params = { + json_data = { 'model': self.model, 'messages': self.message_converter.convert_messages(messages), **parameters.custom_model_dump(), } if stream: - params['stream'] = True + json_data['stream'] = True return { 'url': f'{self.settings.api_base}/chat/completions', 'headers': headers, - 'json': params, + 'json': json_data, } @override @@ -357,16 +353,14 @@ def _update_delta(self, delta_dict: dict[str, Any], stream_manager: StreamManage stream_manager.delta = delta_content if delta_dict.get('tool_calls'): - index = delta_dict['tool_calls'][0]['index'] + tool_calls = delta_dict['tool_calls'][0] + index = tool_calls['index'] if index >= len(stream_manager.tool_calls or []): new_tool_calls_message = self._parse_assistant_message(delta_dict).tool_calls - assert new_tool_calls_message is not None - if stream_manager.tool_calls is None: - stream_manager.tool_calls = [] - stream_manager.tool_calls.append(new_tool_calls_message[0]) + if new_tool_calls_message: + stream_manager.tool_calls.append(new_tool_calls_message[0]) else: - assert stream_manager.tool_calls is not None - 
stream_manager.tool_calls[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments'] + stream_manager.tool_calls[index].function.arguments += tool_calls['function']['arguments'] if delta_dict.get('function_call'): if stream_manager.function_call is None: diff --git a/generate/chat_completion/stream_manager.py b/generate/chat_completion/stream_manager.py index db2c5ff..f07690b 100644 --- a/generate/chat_completion/stream_manager.py +++ b/generate/chat_completion/stream_manager.py @@ -14,7 +14,7 @@ class StreamManager(BaseModel): history_streams: List[Stream] = [] finish_reason: Optional[FinishReason] = None function_call: Optional[FunctionCall] = None - tool_calls: Optional[List[ToolCall]] = None + tool_calls: List[ToolCall] = [] close: bool = False extra: Dict[str, Any] = {} @@ -60,7 +60,7 @@ def build_stream_output(self) -> Optional[ChatCompletionStreamOutput]: message=AssistantMessage( content=self.content, function_call=self.function_call, - tool_calls=self.tool_calls, + tool_calls=self.tool_calls or None, ), stream=stream, ) diff --git a/generate/platforms/baichuan.py b/generate/platforms/baichuan.py index fe816c7..b95e338 100644 --- a/generate/platforms/baichuan.py +++ b/generate/platforms/baichuan.py @@ -1,12 +1,10 @@ -from pydantic import SecretStr from pydantic_settings import SettingsConfigDict -from generate.platforms.base import PlatformSettings +from generate.platforms.openai_like import OpenAILikeSettings -class BaichuanSettings(PlatformSettings): +class BaichuanSettings(OpenAILikeSettings): model_config = SettingsConfigDict(extra='ignore', env_prefix='baichuan_', env_file='.env') - api_key: SecretStr api_base: str = 'https://api.baichuan-ai.com/v1' platform_url: str = 'https://platform.baichuan-ai.com/docs/api' From 70566def4dc67ef1372e8aaea1a7fc6048a809d6 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Thu, 23 May 2024 19:36:15 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20readme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 40 ++++++++++---------- generate/chat_completion/models/anthropic.py | 2 +- generate/version.py | 2 +- pyproject.toml | 2 +- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 2f55d65..3ddd34a 100644 --- a/README.md +++ b/README.md @@ -24,28 +24,30 @@

 
-# Introduction
-
+> [!IMPORTANT]
+> The design of generate changes considerably in v0.5.0. With limited personal bandwidth, generate no longer tries to support ever more platforms and instead focuses on auxiliary features built around the models. Note also that most Chinese platforms now offer OpenAI-SDK-compatible APIs; if you only need basic text generation, using the OpenAI SDK directly is recommended.
 
+# Introduction
 Generate lets users access generative models on many platforms through one unified API. Currently supported:
 
-| Platform 🤖       | Sync 🔄 | Async ⏳ | Stream 🌊 | Vision 👀 | Tools 🛠️ |
-| ----------------- | ------- | ------- | -------- | --------- | -------- |
-| OpenAI            | ✅      | ✅      | ✅       | ✅        | ✅       |
-| Azure             | ✅      | ✅      | ❌       | ✅        | ✅       |
-| Anthropic         | ✅      | ✅      | ✅       | ✅        | ❌       |
-| 文心 Wenxin       | ✅      | ✅      | ✅       | ❌        | ✅       |
-| 百炼 Bailian      | ✅      | ✅      | ✅       | ❌        | ❌       |
-| 灵积 DashScope    | ✅      | ✅      | ✅       | ✅        | ❌       |
-| 百川智能 Baichuan | ✅      | ✅      | ✅       | ❌        | ❌       |
-| Minimax           | ✅      | ✅      | ✅       | ❌        | ✅       |
-| 混元 Hunyuan      | ✅      | ✅      | ✅       | ❌        | ❌       |
-| 智谱 Zhipu        | ✅      | ✅      | ✅       | ✅        | ✅       |
-| 月之暗面 Moonshot | ✅      | ✅      | ✅       | ❌        | ❌       |
-| DeepSeek          | ✅      | ✅      | ✅       | ❌        | ❌       |
-| 零一万物 Yi       | ✅      | ✅      | ✅       | ✅        | ❌       |
-| 阶跃星辰 StepFun  | ✅      | ✅      | ✅       | ✅        | ❌       |
+| Platform 🤖         | Sync 🔄 | Async ⏳ | Stream 🌊 | Vision 👀 | Tools 🛠️ |
+| ------------------- | ------- | ------- | -------- | --------- | -------- |
+| OpenAI              | ✅      | ✅      | ✅       | ✅        | ✅       |
+| Azure               | ✅      | ✅      | ❌       | ✅        | ✅       |
+| Anthropic           | ✅      | ✅      | ✅       | ✅        | ❌       |
+| 文心 Wenxin         | ✅      | ✅      | ✅       | ❌        | ✅       |
+| 灵积/百炼 DashScope | ✅      | ✅      | ✅       | ✅        | ✅       |
+| 百川智能 Baichuan   | ✅      | ✅      | ✅       | ❌        | ✅       |
+| Minimax             | ✅      | ✅      | ✅       | ❌        | ✅       |
+| 混元 Hunyuan        | ✅      | ✅      | ✅       | ❌        | ❌       |
+| 智谱 Zhipu          | ✅      | ✅      | ✅       | ✅        | ✅       |
+| 月之暗面 Moonshot   | ✅      | ✅      | ✅       | ❌        | ✅       |
+| DeepSeek            | ✅      | ✅      | ✅       | ❌        | ❌       |
+| 零一万物 Yi         | ✅      | ✅      | ✅       | ✅        | ❌       |
+| 阶跃星辰 StepFun    | ✅      | ✅      | ✅       | ✅        | ❌       |
+
+> As of v0.5.0-beta, Hunyuan and Wenxin have not been adapted yet.
 
 ## Features
 
@@ -53,7 +55,7 @@
 - **Cross-platform**: supports 10+ platforms at home and abroad, including OpenAI, Azure, Minimax, Zhipu, Moonshot, and Wenxin
 - **One API**: unifies message formats, inference parameters, interface wrapping, and response parsing across platforms, so users never deal with per-platform differences
 - **Async, streaming, and concurrency**: offers streaming and non-streaming, sync and async, and async batch-concurrent calls for different application scenarios
-- **Batteries included**: ships with a chainlit UI, input validation, parameter validation, cost accounting, rate control, _Agent_, _Tool call_, and more
+- **Batteries included**: ships with a chainlit UI, input validation, parameter validation, cost accounting, rate control, Disk Cache, _Agent_, _Tool call_, and more
 - **Lightweight**: minimal dependencies; per-platform request and authentication logic is implemented natively
 - **High-quality code**: 100% typehints, pylance strict, ruff lint & format, test coverage > 85% ...
 
diff --git a/generate/chat_completion/models/anthropic.py b/generate/chat_completion/models/anthropic.py
index f6dc4b4..76d4d02 100644
--- a/generate/chat_completion/models/anthropic.py
+++ b/generate/chat_completion/models/anthropic.py
@@ -159,7 +159,7 @@ class AnthropicChat(RemoteChatCompletionModel, SupportToolCall):
 
     def __init__(
         self,
-        model: str = 'claude-2.1',
+        model: str = 'claude-3-haiku-20240307',
         parameters: AnthropicChatParameters | None = None,
         settings: AnthropicSettings | None = None,
         http_client: HttpClient | None = None,
diff --git a/generate/version.py b/generate/version.py
index 2b8877c..f455207 100644
--- a/generate/version.py
+++ b/generate/version.py
@@ -1 +1 @@
-__version__ = '0.5.0'
+__version__ = '0.5.0-beta'
diff --git a/pyproject.toml b/pyproject.toml
index acfbcfe..a57b8c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "generate-core"
-version = "0.5.0"
+version = "0.5.0-beta"
 description = "文本生成,图像生成,语音生成"
 authors = ["wangyuxin "]
 license = "MIT"
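
For reference, a minimal usage sketch of the disk cache modifier added in `generate/modifiers/cache.py` by this series. The sketch is not part of the diff; the model id and prompt are illustrative, and an `OPENAI_API_KEY` configured via `.env` is assumed.

```python
# Hedged sketch: wrap any chat model so repeated identical requests are
# served from the on-disk cache (default: ~/.cache/generate-chat-completion).
from generate.chat_completion import OpenAIChat  # model id below is illustrative
from generate.modifiers.cache import CacheChatCompletionModel

model = CacheChatCompletionModel(OpenAIChat(model='gpt-3.5-turbo'))

first = model.generate('Hello!')   # goes to the API, output stored via diskcache
second = model.generate('Hello!')  # same md5 key, served from disk, no API call
assert first.reply == second.reply
```

Because the cache key hashes only `model_id`, the stringified messages, and the wrapped model's `parameters`, per-call keyword overrides are not part of the key, and streaming through the wrapper raises `NotImplementedError` by design.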
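The `TypeAdapter` support in `StructureGenerateModel` can be sketched the same way. A hedged example, constructing the class directly since the diff shows no convenience accessor; the prompt, instruction, and expected output are illustrative:

```python
# Hedged sketch: with a TypeAdapter, plain containers such as List[str] are
# valid output structures; the system prompt's JSON Schema then comes from
# TypeAdapter.json_schema() instead of BaseModel.model_json_schema().
from typing import List

from pydantic import TypeAdapter

from generate.chat_completion import OpenAIChat
from generate.modifiers.structure import StructureGenerateModel

extractor = StructureGenerateModel(
    model=OpenAIChat(),
    output_structure_type=TypeAdapter(List[str]),
    instruction='Extract all city names from the text.',
)

output = extractor.generate('I flew from Beijing to Shanghai, then took a train to Hangzhou.')
print(output.structure)  # expected: ['Beijing', 'Shanghai', 'Hangzhou']
```

On a parse failure the model appends the error as a new message and re-asks, up to `max_num_reask` times, as in the pre-existing `BaseModel` path.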