Skip to content

Commit

Permalink
Merge pull request #732 from RockChinQ/feat/claude-3
Browse files Browse the repository at this point in the history
Feat: 接入 claude 3 系列模型
  • Loading branch information
RockChinQ authored Mar 18, 2024
2 parents 0cfb8bb + a723c8c commit cca48a3
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 5 deletions.
1 change: 1 addition & 0 deletions pkg/core/bootutils/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
required_deps = {
"requests": "requests",
"openai": "openai",
"anthropic": "anthropic",
"colorlog": "colorlog",
"mirai": "yiri-mirai-rc",
"aiocqhttp": "aiocqhttp",
Expand Down
82 changes: 82 additions & 0 deletions pkg/provider/modelmgr/apis/anthropicmsgs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

import typing
import traceback

import anthropic

from .. import api, entities, errors

from .. import api, entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities


@api.requester_class("anthropic-messages")
class AnthropicMessages(api.LLMAPIRequester):
"""Anthropic Messages API 请求器"""

client: anthropic.AsyncAnthropic

async def initialize(self):
self.client = anthropic.AsyncAnthropic(
api_key="",
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
timeout=self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout'],
proxies=self.ap.proxy_mgr.get_forward_proxies()
)

async def request(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
self.client.api_key = query.use_model.token_mgr.get_token()

args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name

req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages
] + [m.dict(exclude_none=True) for m in query.messages]

# 删除所有 role=system & content='' 的消息
req_messages = [
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]

# 检查是否有 role=system 的消息,若有,改为 role=user,并在后面加一个 role=assistant 的消息
system_role_index = []
for i, m in enumerate(req_messages):
if m["role"] == "system":
system_role_index.append(i)
m["role"] = "user"

if system_role_index:
for i in system_role_index[::-1]:
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})

# 忽略掉空消息,用户可能发送空消息,而上层未过滤
req_messages = [
m for m in req_messages if m["content"].strip() != ""
]

args["messages"] = req_messages

try:

resp = await self.client.messages.create(**args)

yield llm_entities.Message(
content=resp.content[0].text,
role=resp.role
)
except anthropic.AuthenticationError as e:
raise errors.RequesterError(f'api-key 无效: {e.message}')
except anthropic.BadRequestError as e:
raise errors.RequesterError(str(e.message))
except anthropic.NotFoundError as e:
if 'model: ' in str(e):
raise errors.RequesterError(f'模型无效: {e.message}')
else:
raise errors.RequesterError(f'请求地址无效: {e.message}')
4 changes: 1 addition & 3 deletions pkg/provider/modelmgr/apis/chatcmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import openai.types.chat.chat_completion as chat_completion
import httpx

from pkg.provider.entities import Message

from .. import api, entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
Expand Down Expand Up @@ -127,7 +125,7 @@ async def _request(

req_messages.append(msg.dict(exclude_none=True))

async def request(self, query: core_entities.Query) -> AsyncGenerator[Message, None]:
async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
try:
async for msg in self._request(query):
yield msg
Expand Down
2 changes: 1 addition & 1 deletion pkg/provider/modelmgr/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ class RequesterError(Exception):
"""Base class for all Requester errors."""

def __init__(self, message: str):
super().__init__("模型请求失败: "+message)
super().__init__("模型请求失败: "+message)
2 changes: 1 addition & 1 deletion pkg/provider/modelmgr/modelmgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ...core import app

from . import token, api
from .apis import chatcmpl
from .apis import chatcmpl, anthropicmsgs

FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
requests
openai>1.0.0
anthropic
colorlog~=6.6.0
yiri-mirai-rc
aiocqhttp
Expand Down
15 changes: 15 additions & 0 deletions templates/metadata/llm-models.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@
{
"model_name": "gemini-pro",
"name": "OneAPI/gemini-pro"
},
{
"name": "claude-3-opus-20240229",
"requester": "anthropic-messages",
"token_mgr": "anthropic"
},
{
"name": "claude-3-sonnet-20240229",
"requester": "anthropic-messages",
"token_mgr": "anthropic"
},
{
"name": "claude-3-haiku-20240307",
"requester": "anthropic-messages",
"token_mgr": "anthropic"
}
]
}
10 changes: 10 additions & 0 deletions templates/provider.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@
"keys": {
"openai": [
"sk-1234567890"
],
"anthropic": [
"sk-1234567890"
]
},
"requester": {
"openai-chat-completions": {
"base-url": "https://api.openai.com/v1",
"args": {},
"timeout": 120
},
"anthropic-messages": {
"base-url": "https://api.anthropic.com/v1",
"args": {
"max_tokens": 1024
},
"timeout": 120
}
},
"model": "gpt-3.5-turbo",
Expand Down

0 comments on commit cca48a3

Please sign in to comment.