From 8ee195826837e49a6f91dcecd0b68f94ea002dd7 Mon Sep 17 00:00:00 2001 From: yangzhibo <837805784@qq.com> Date: Fri, 6 Sep 2024 14:47:55 +0800 Subject: [PATCH] feat: add test with action --- app/schemas/tool/action.py | 3 +- app/services/run/run.py | 1 - .../run/run_with_assistant_extra_body_test.py | 58 ++++ tests/run/run_with_auth_action_test.py | 139 ++++++++++ tests/tools/assistant_action_test.py | 251 ++++++++++++++++++ 5 files changed, 450 insertions(+), 2 deletions(-) create mode 100644 tests/run/run_with_assistant_extra_body_test.py create mode 100644 tests/run/run_with_auth_action_test.py create mode 100644 tests/tools/assistant_action_test.py diff --git a/app/schemas/tool/action.py b/app/schemas/tool/action.py index 7aff122..ee7f46b 100644 --- a/app/schemas/tool/action.py +++ b/app/schemas/tool/action.py @@ -88,7 +88,8 @@ def model_validator(cls, data: Any): openapi_schema = data.get("openapi_schema") validate_openapi_schema(openapi_schema) authentication = data.get("authentication") - Authentication.model_validate(authentication).encrypt() + if authentication: + Authentication.model_validate(authentication).encrypt() return data diff --git a/app/services/run/run.py b/app/services/run/run.py index ba96283..f0bcd11 100644 --- a/app/services/run/run.py +++ b/app/services/run/run.py @@ -61,7 +61,6 @@ async def create_run( # create run db_run = Run.model_validate(body.model_dump(by_alias=True), update={"thread_id": thread_id, "file_ids": file_ids}) session.add(db_run) - session.refresh(db_run) run_id = db_run.id if body.additional_messages: # create messages diff --git a/tests/run/run_with_assistant_extra_body_test.py b/tests/run/run_with_assistant_extra_body_test.py new file mode 100644 index 0000000..371a2b2 --- /dev/null +++ b/tests/run/run_with_assistant_extra_body_test.py @@ -0,0 +1,58 @@ +import time + +import openai + + +def test_run_with_assistant_extra_body(): + client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx") + # 创建带有 action 的 assistant + assistant = client.beta.assistants.create( + name="Assistant Demo", + instructions="你是一个有用的助手", + model="gpt-3.5-turbo-1106", + extra_body={ + "extra_body": { + "model_params": { + "frequency_penalty": 0, + "logit_bias": None, + "max_tokens": 1024, + "presence_penalty": 0.6, + "temperature": 1, + "presence_penalty": 0, + "top_p": 1, + } + } + }, + ) + print(assistant, end="\n\n") + + thread = client.beta.threads.create() + print(thread, end="\n\n") + + message = client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="你好,介绍一下你自己", + ) + print(message, end="\n\n") + + run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id, instructions="") + print(run, end="\n\n") + + while True: + # run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + if run.status == "completed": + print("done!", end="\n\n") + messages = client.beta.threads.messages.list(thread_id=thread.id) + + print("messages: ") + for message in messages: + assert message.content[0].type == "text" + print(messages) + print({"role": message.role, "message": message.content[0].text.value}) + + break + else: + print("\nin progress...") + time.sleep(1) diff --git a/tests/run/run_with_auth_action_test.py b/tests/run/run_with_auth_action_test.py new file mode 100644 index 0000000..3d59454 --- /dev/null +++ b/tests/run/run_with_auth_action_test.py @@ -0,0 +1,139 @@ +import time + +import openai +import pytest + +from app.providers.database import session +from app.schemas.tool.action import ActionBulkCreateRequest +from app.schemas.tool.authentication import Authentication, AuthenticationType +from app.services.tool.action import ActionService + + +@pytest.fixture +def api_url(): + return "http://127.0.0.1:8086/api/v1/actions" + + +@pytest.fixture +def create_workspace_with_authentication(): + return { + "openapi_schema": { + "openapi": "3.0.0", + "info": {"title": "Create New Workspace", "version": "1.0.0"}, + "servers": [{"url": "https://tx.c.csvfx.com/api"}], + "paths": { + "/tx/v1/workspaces": { + "post": { + "summary": "Create a new workspace", + "description": "This endpoint creates a new workspace with the provided data.", + "operationId": "createWorkspace", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the workspace"}, + "description": { + "type": "string", + "description": "The description of the workspace", + }, + "ui_settings": { + "type": "object", + "properties": { + "color": { + "type": "string", + "description": "The color of the workspace UI", + }, + "icon": { + "type": "string", + "description": "The icon of the workspace UI", + }, + }, + }, + "tenant_id": {"type": "string", "description": "The tenant ID"}, + }, + } + } + }, + }, + "responses": { + "200": { + "description": "Workspace created successfully", + "content": {"application/json": {"schema": {"type": "object", "properties": {}}}}, + }, + "401": {"description": "Unauthorized - Authentication credentials are missing or invalid"}, + "403": {"description": "Forbidden - The authenticated user does not have permission to perform this action"}, + "500": {"description": "Internal Server Error - Something went wrong on the server side"}, + }, + } + } + }, + } + } + + +# 测试带有action的助手,run 的时候传递自己的auth信息 +def test_run_with_action_auth(create_workspace_with_authentication): + body = ActionBulkCreateRequest(**create_workspace_with_authentication) + body.authentication = Authentication(type=AuthenticationType.none) + actions = ActionService.create_actions_sync(session=session, body=body) + [create_workspace_with_authentication] = actions + + client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx") + + # 创建带有 action 的 assistant + assistant = client.beta.assistants.create( + name="Assistant Demo", + instructions="你是一个有用的助手", + tools=[{"type": "action", "id": create_workspace_with_authentication.id}], + model="gpt-3.5-turbo-1106", + ) + print(assistant, end="\n\n") + + thread = client.beta.threads.create() + print(thread, end="\n\n") + + message = client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="在组织63db49f7dcc8bf7b0990903c下,创建一个随机名字的工作空间", + ) + print(message, end="\n\n") + + run = client.beta.threads.runs.create( + # model="gpt-3.5-turbo-1106", + thread_id=thread.id, + assistant_id=assistant.id, + instructions="", + extra_body={ + "extra_body": { + "action_authentications": { + create_workspace_with_authentication.id: { + "type": "bearer", + "secret": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI2M2RiNDlhY2RjYzhiZjdiMDk5MDhmZDYiLCJhdWQiOiI2M2RiNDlmN2RjYzhiZjdiMDk5MDkwM2MiLCJ1aWQiOiI2M2RiNDlhY2RjYzhiZjdiMDk5MDhmZDYiLCJpYXQiOjE3MTAxNDkxODcsImV4cCI6MTcxMDIzNTU4N30.h96cKhB8rPGKM2PEq6bg4k2j09gR82HCJHUws232Oe4", + } + } + } + }, + ) + print(run, end="\n\n") + + while True: + # run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + if run.status == "completed": + print("done!", end="\n\n") + messages = client.beta.threads.messages.list(thread_id=thread.id) + + print("messages: ") + for message in messages: + assert message.content[0].type == "text" + print(messages) + print({"role": message.role, "message": message.content[0].text.value}) + + break + else: + print("\nin progress...") + time.sleep(1) diff --git a/tests/tools/assistant_action_test.py b/tests/tools/assistant_action_test.py new file mode 100644 index 0000000..9ae8e24 --- /dev/null +++ b/tests/tools/assistant_action_test.py @@ -0,0 +1,251 @@ +import time + +import openai +import pytest + +from app.providers.database import session +from app.schemas.tool.action import ActionBulkCreateRequest +from app.schemas.tool.authentication import Authentication, AuthenticationType +from app.services.tool.action import ActionService + + +@pytest.fixture +def api_url(): + return "http://127.0.0.1:8086/api/v1/actions" + + +@pytest.fixture +def get_weather_data_valid_payload(): + return { + "openapi_schema": { + "openapi": "3.0.0", + "info": { + "title": "OpenWeatherMap One Call API", + "description": "API for accessing comprehensive weather data from OpenWeatherMap.", + "version": "1.0.0", + }, + "servers": [ + { + "url": "https://api.openweathermap.org/data/3.0", + "description": "OpenWeatherMap One Call API server", + } + ], + "paths": { + "/onecall": { + "get": { + "summary": "Get Comprehensive Weather Data", + "description": "Retrieves weather data for a specific latitude and longitude.", + "operationId": "get_weather_data", + "parameters": [ + { + "in": "query", + "name": "lat", + "schema": {"type": "number", "format": "float", "minimum": -90, "maximum": 90}, + "required": True, + "description": "Latitude, decimal (-90 to 90).", + }, + { + "in": "query", + "name": "lon", + "schema": {"type": "number", "format": "float", "minimum": -180, "maximum": 180}, + "required": True, + "description": "Longitude, decimal (-180 to 180).", + }, + { + "in": "query", + "name": "exclude", + "schema": {"type": "string"}, + "required": False, + "description": "Exclude some parts of the weather data(current, minutely, hourly, daily, alerts).", + }, + { + "in": "query", + "name": "appid", + "schema": {"type": "string", "enum": ["101f41d3ff4095824722d57a513cb80a"]}, + "required": True, + "description": "Your unique API key.", + }, + ], + "responses": { + "200": { + "description": "Successful response with comprehensive weather data.", + "content": {"application/json": {"schema": {"type": "object", "properties": {}}}}, + } + }, + } + } + }, + } + } + + +@pytest.fixture +def get_number_fact_valid_payload(): + return { + "openapi_schema": { + "openapi": "3.0.0", + "info": { + "title": "Numbers API", + "version": "1.0.0", + "description": "API for fetching interesting number facts", + }, + "servers": [{"url": "http://numbersapi.com"}], + "paths": { + "/{number}": { + "get": { + "description": "Get fact about a number", + "operationId": "get_number_fact", + "parameters": [ + { + "name": "number", + "in": "path", + "required": True, + "description": "The number to get the fact for", + "schema": {"type": "integer"}, + } + ], + "responses": { + "200": { + "description": "A fact about the number", + "content": {"text/plain": {"schema": {"type": "string"}}}, + } + }, + } + } + }, + } + } + + +@pytest.fixture +def create_workspace_with_authentication(): + return { + "openapi_schema": { + "openapi": "3.0.0", + "info": {"title": "Create New Workspace", "version": "1.0.0"}, + "servers": [{"url": "https://tx.c.csvfx.com/api"}], + "paths": { + "/tx/v1/workspaces": { + "post": { + "summary": "Create a new workspace", + "description": "This endpoint creates a new workspace with the provided data.", + "operationId": "createWorkspace", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the workspace"}, + "description": { + "type": "string", + "description": "The description of the workspace", + }, + "ui_settings": { + "type": "object", + "properties": { + "color": { + "type": "string", + "description": "The color of the workspace UI", + }, + "icon": { + "type": "string", + "description": "The icon of the workspace UI", + }, + }, + }, + "tenant_id": {"type": "string", "description": "The tenant ID"}, + }, + } + } + }, + }, + "responses": { + "200": { + "description": "Workspace created successfully", + "content": {"application/json": {"schema": {"type": "object", "properties": {}}}}, + }, + "401": {"description": "Unauthorized - Authentication credentials are missing or invalid"}, + "403": { + "description": "Forbidden - The authenticated user does not have permission to perform this action" + }, + "500": {"description": "Internal Server Error - Something went wrong on the server side"}, + }, + } + } + }, + }, + "authentication": { + "type": "bearer", + "secret": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiI2M2RiNDlhY2RjYzhiZjdiMDk5MDhmZDYiLCJhdWQiOiI2M2RiNDlmN2RjYzhiZjdiMDk5MDkwM2MiLCJ1aWQiOiI2M2RiNDlhY2RjYzhiZjdiMDk5MDhmZDYiLCJpYXQiOjE3MDg5MTQ0MDcsImV4cCI6MTcwOTAwMDgwN30.vK6cH1qxPqjgeSdDym4b4fzwTwhW66s_L1suCvh3W98", + }, + } + + +# 测试带有action的助手 +def test_assistant_with_action_tools( + get_weather_data_valid_payload, get_number_fact_valid_payload, create_workspace_with_authentication +): + body = ActionBulkCreateRequest(**get_weather_data_valid_payload) + body.authentication = Authentication( + type=AuthenticationType.none, + ) + actions = ActionService.create_actions_sync(session=session, body=body) + [get_weather_data] = actions + # get_number_fact + body = ActionBulkCreateRequest(**get_number_fact_valid_payload) + body.authentication = Authentication( + type=AuthenticationType.none, + ) + actions = ActionService.create_actions_sync(session=session, body=body) + [get_number_fact] = actions + + client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx") + + # 创建带有 action 的 assistant + assistant = client.beta.assistants.create( + name="Assistant Demo", + instructions="你是一个有用的助手", + tools=[ + {"type": "action", "id": get_weather_data.id}, + {"type": "action", "id": get_number_fact.id}, + ], + model="gpt-3.5-turbo-1106", + ) + print(assistant, end="\n\n") + + thread = client.beta.threads.create() + print(thread, end="\n\n") + + message = client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="武汉的天气", + ) + print(message, end="\n\n") + + run = client.beta.threads.runs.create( + # model="gpt-3.5-turbo-1106", + thread_id=thread.id, + assistant_id=assistant.id, + instructions="", + ) + print(run, end="\n\n") + + while True: + run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id) + if run.status == "completed": + print("done!", end="\n\n") + messages = client.beta.threads.messages.list(thread_id=thread.id) + + print("messages: ") + for message in messages: + assert message.content[0].type == "text" + print(messages) + print({"role": message.role, "message": message.content[0].text.value}) + + break + else: + print("\nin progress...") + time.sleep(1)