Skip to content

Commit

Permalink
Merge pull request #74 from klb3713/main
Browse files Browse the repository at this point in the history
支持json返回模式
  • Loading branch information
liuooo committed May 30, 2024
2 parents b4370ad + 52d193c commit b5e878c
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 7 deletions.
3 changes: 3 additions & 0 deletions app/core/runner/llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def run(
extra_body=None,
temperature=None,
top_p=None,
response_format=None,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
chat_params = {
"messages": messages,
Expand All @@ -44,6 +45,8 @@ def run(
if tools:
chat_params["tools"] = tools
chat_params["tool_choice"] = tool_choice if tool_choice else "auto"
if isinstance(response_format, dict) and response_format.get("type") == "json_object":
chat_params["response_format"] = {"type": "json_object"}
logging.info("chat_params: %s", chat_params)
response = self.client.chat.completions.create(**chat_params)
logging.info("chat_response: %s", response)
Expand Down
1 change: 1 addition & 0 deletions app/core/runner/thread_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __run_step(
extra_body=run.extra_body,
temperature=run.temperature,
top_p=run.top_p,
response_format=run.response_format,
)

# create message callback
Expand Down
6 changes: 3 additions & 3 deletions app/models/assistant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union

from sqlalchemy import Column
from sqlmodel import Field, JSON, TEXT
Expand All @@ -15,7 +15,7 @@ class AssistantBase(BaseModel):
name: Optional[str] = Field(default=None)
tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
response_format: Optional[str] = Field(default=None) # 响应格式
response_format: Union[str, dict] = Field(default="auto", sa_column=Column(JSON)) # 响应格式
tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 工具资源
temperature: Optional[float] = Field(default=None) # 温度
top_p: Optional[float] = Field(default=None) # top_p
Expand All @@ -38,7 +38,7 @@ class AssistantUpdate(BaseModel):
name: Optional[str] = Field(default=None)
tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
response_format: Optional[str] = Field(default=None) # 响应格式
response_format: Union[str, dict] = Field(default="auto", sa_column=Column(JSON)) # 响应格式
tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 工具资源
temperature: Optional[float] = Field(default=None) # 温度
top_p: Optional[float] = Field(default=None) # top_p
4 changes: 2 additions & 2 deletions app/models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Run(BaseModel, PrimaryKeyMixin, TimeStampMixin, table=True):
incomplete_details: Optional[str] = Field(default=None) # 未完成详情
max_completion_tokens: Optional[int] = Field(default=None) # 最大完成长度
max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
response_format: Optional[str] = Field(default=None) # 返回格式
response_format: Union[str, dict] = Field(default="auto", sa_column=Column(JSON)) # 返回格式
tool_choice: Optional[str] = Field(default=None) # 工具选择
truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 截断策略
usage: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 调用使用情况
Expand All @@ -77,7 +77,7 @@ class RunCreate(BaseModel):
max_completion_tokens: Optional[int] = None # 最大完成长度
max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 截断策略
response_format: Optional[str] = Field(default=None) # 返回格式
response_format: Union[str, dict] = Field(default="auto", sa_column=Column(JSON)) # 返回格式
tool_choice: Optional[str] = Field(default=None) # 工具选择
temperature: Optional[float] = Field(default=None) # 温度
top_p: Optional[float] = Field(default=None) # top_p
Expand Down
45 changes: 45 additions & 0 deletions migrations/versions/2024-05-28-11-35_1c667e62f698.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""empty message
Revision ID: 1c667e62f698
Revises: aa4bda3363e3
Create Date: 2024-05-28 11:35:33.961196
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel
from sqlalchemy.dialects import mysql

# revision identifiers, used by Alembic.
revision: str = '1c667e62f698'
down_revision: Union[str, None] = 'aa4bda3363e3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('assistant', 'response_format',
existing_type=mysql.VARCHAR(collation='utf8mb4_unicode_ci', length=255),
type_=sa.JSON(),
existing_nullable=True)
op.alter_column('run', 'response_format',
existing_type=mysql.VARCHAR(collation='utf8mb4_unicode_ci', length=255),
type_=sa.JSON(),
existing_nullable=True)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('run', 'response_format',
existing_type=sa.JSON(),
type_=mysql.VARCHAR(collation='utf8mb4_unicode_ci', length=255),
existing_nullable=True)
op.alter_column('assistant', 'response_format',
existing_type=sa.JSON(),
type_=mysql.VARCHAR(collation='utf8mb4_unicode_ci', length=255),
existing_nullable=True)
# ### end Alembic commands ###
5 changes: 3 additions & 2 deletions tests/e2e/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_create_run_with_additional_messages_and_other_parmas():
name="Assistant Demo",
instructions="你是一个有用的助手",
model="gpt-4o",
response_format={"type": "json_object"},
)
thread = client.beta.threads.create(
messages=[
Expand Down Expand Up @@ -42,7 +43,7 @@ def test_create_run_with_additional_messages_and_other_parmas():
stream = client.beta.threads.runs.create(
thread_id=thread.id,
assistant_id=assistant.id,
instructions="",
instructions="请用 json 格式回答",
additional_messages=[
{
"role": "user",
Expand Down Expand Up @@ -75,7 +76,7 @@ def test_create_run_with_additional_messages_and_other_parmas():

query = session.query(Run).filter(Run.thread_id == thread.id)
run = query.one()
assert run.instructions == "你是一个有用的助手"
assert run.instructions == "请用 json 格式回答"
assert run.model == "gpt-4o"
query = session.query(Message).filter(Message.thread_id == thread.id).order_by(Message.created_at)
messages = query.all()
Expand Down

0 comments on commit b5e878c

Please sign in to comment.