Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update agent module #259

Merged
merged 4 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
{
"intents": {
"retrieval": "关于一些通用的信息检索,比如搜索旅游攻略、搜索美食攻略、搜索注意事项等信息。",
"agent": "实时性的信息查询,比如查询航班信息、查询高铁信息、查询天气等时效性很强的信息。"
},
"system_prompt": "你是一个旅游小助手,可以帮助用户查询指定时间从A地区到B地区的机票信息,火车票信息以及天气信息等。请严格使用输入的工具,不要虚构任何细节。",
"function_tools": [
{
Expand All @@ -25,7 +21,7 @@
"api_tools": [
{
"name": "search_flight_ticket_api",
"url": "http://127.0.0.1:8070/demo/api/flights",
"url": "http://127.0.0.1:8001/demo/api/flights",
"headers": {
"Authorization": "Bearer YOUR_ACCESS_TOKEN"
},
Expand All @@ -42,14 +38,14 @@
},
"date": {
"type": "str",
"description": "出发时间,如'2024-03-29'"
"description": "出发时间,YYYY-MM-DD格式,如'2024-03-29'"
}
},
"required": ["from_city", "to_city", "date"]
},
{
"name": "search_train_ticket_api",
"url": "http://127.0.0.1:8070/demo/api/trains",
"url": "http://127.0.0.1:8001/demo/api/trains",
"headers": {
"Authorization": "Bearer YOUR_ACCESS_TOKEN"
},
Expand All @@ -66,14 +62,14 @@
},
"date": {
"type": "str",
"description": "出发时间,如'2024-03-29'"
"description": "出发时间,YYYY-MM-DD格式,如'2024-03-29'"
}
},
"required": ["from_city", "to_city", "date"]
},
{
"name": "search_hotels_api",
"url": "http://127.0.0.1:8070/demo/api/hotels",
"url": "http://127.0.0.1:8001/demo/api/hotels",
"headers": {
"Authorization": "Bearer YOUR_ACCESS_TOKEN"
},
Expand All @@ -85,12 +81,16 @@
"type": "str",
"description": "查询的城市,如'北京'、'上海'、'南京''"
},
"date": {
"checkin_date": {
"type": "str",
"description": "入住时间,YYYY-MM-DD格式,如'2024-03-29'"
},
"checkout_date": {
"type": "str",
"description": "出发时间,如'2024-03-29'"
"description": "离店时间,YYYY-MM-DD格式,如'2024-03-31'"
}
},
"required": ["city", "date"]
"required": ["city", "checkin_date", "checkout_date"]
}
]
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
import requests
import os
import logging

logger = logging.getLogger(__name__)


def get_place_weather(city: str) -> str:
print(f"[Agent] Checking realtime weather info for {city}")
logger.info(f"[Agent] Checking realtime weather info for {city}")

"""Get city name and return city weather"""
api_key = os.environ.get("weather_api_key")

# 可以直接赋值给api_key,原始代码的config只有type类型。
base_url = "http://api.openweathermap.org/data/2.5/forecast?"
complete_url = f"{base_url}q={city}&appid={api_key}&lang=zh_cn&units=metric"
print(complete_url)
response = requests.get(complete_url)
logger.info(f"Requesting {complete_url}...")
response = requests.get(complete_url, timeout=5)
weather_data = response.json()

if weather_data["cod"] != "200":
print(f"获取天气信息失败,错误代码:{weather_data['cod']}")
return None
logger.error(
f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}"
)
return f"获取天气信息失败,错误代码:{weather_data['cod']} 错误信息:{weather_data['message']}"

element = weather_data["list"][0]

return str(
f"{city}的天气:\n 时间: {element['dt_txt']}\n 温度: {element['main']['temp']} °C\n 天气描述: {element['weather'][0]['description']}\n"
)
return f"""
{city}的天气:
时间: {element['dt_txt']}
温度: {element['main']['temp']} °C
天气描述: {element['weather'][0]['description']}
"""
192 changes: 192 additions & 0 deletions src/pai_rag/app/api/agent_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from datetime import datetime
from fastapi import APIRouter
import logging

from pydantic import BaseModel

logger = logging.getLogger(__name__)

demo_router = APIRouter()


# Mock 数据
flights_data = [
{
"flight_number": "CA123",
"from": "北京",
"to": "上海",
"departure_time": "08:00",
"arrival_time": "10:00",
"price": 1200,
},
{
"flight_number": "MU456",
"from": "北京",
"to": "上海",
"departure_time": "14:00",
"arrival_time": "16:00",
"price": 1300,
},
{
"flight_number": "HU789",
"from": "北京",
"to": "上海",
"departure_time": "18:00",
"arrival_time": "20:00",
"price": 1100,
},
{
"flight_number": "CA234",
"from": "北京",
"to": "上海",
"departure_time": "06:00",
"arrival_time": "08:00",
"price": 1250,
},
{
"flight_number": "MU567",
"from": "北京",
"to": "上海",
"departure_time": "21:00",
"arrival_time": "23:00",
"price": 1350,
},
]

highspeed_trains_data = [
{
"train_number": "G1234",
"from": "北京",
"to": "上海",
"departure_time": "09:00",
"arrival_time": "11:30",
"price": 800,
},
{
"train_number": "G5678",
"from": "北京",
"to": "上海",
"departure_time": "15:00",
"arrival_time": "17:30",
"price": 850,
},
{
"train_number": "G9101",
"from": "北京",
"to": "上海",
"departure_time": "18:30",
"arrival_time": "21:00",
"price": 780,
},
{
"train_number": "G1123",
"from": "北京",
"to": "上海",
"departure_time": "07:00",
"arrival_time": "09:30",
"price": 820,
},
{
"train_number": "G4578",
"from": "北京",
"to": "上海",
"departure_time": "22:00",
"arrival_time": "00:30",
"price": 870,
},
]

hotels_data = [
{
"hotel_name": "万豪酒店",
"city": "上海",
"price_per_night": 600,
},
{
"hotel_name": "希尔顿酒店",
"city": "上海",
"price_per_night": 850,
},
{
"hotel_name": "洲际酒店",
"city": "上海",
"price_per_night": 700,
},
{
"hotel_name": "皇冠假日酒店",
"city": "上海",
"price_per_night": 750,
},
{
"hotel_name": "如家酒店",
"city": "上海",
"price_per_night": 300,
},
]


@demo_router.get("/flights")
async def get_flights(date: str, to_city: str, from_city: str):
try:
_ = datetime.strptime(date, "%Y-%m-%d")
except Exception as _:
return {
"error": f"Invalid date format '{date}'. Please provide a date in YYYY-MM-DD format."
}

raw_fights = [
flight
for flight in flights_data
if flight["from"] == from_city and flight["to"] == to_city
]

for flight in raw_fights:
flight["date"] = date

return raw_fights


@demo_router.get("/trains")
async def get_trains(date: str, to_city: str, from_city: str):
try:
_ = datetime.strptime(date, "%Y-%m-%d")
except Exception as _:
return {
"error": f"Invalid date format '{date}'. Please provide a date in YYYY-MM-DD format."
}

raw_trains = [
train
for train in highspeed_trains_data
if train["from"] == from_city and train["to"] == to_city
]

for train in raw_trains:
train["date"] = date

return raw_trains


class HotelInput(BaseModel):
checkin_date: str
checkout_date: str
city: str


@demo_router.post("/hotels")
async def get_hotels(input: HotelInput):
try:
_ = datetime.strptime(input.checkin_date, "%Y-%m-%d")
_ = datetime.strptime(input.checkout_date, "%Y-%m-%d")
except Exception as _:
return {
"error": f"Invalid date format '{input}'. Please provide a date in YYYY-MM-DD format."
}

hotels = [hotel for hotel in hotels_data if hotel["city"] == input.city]

for hotel in hotels:
hotel["checkin_date"] = input.checkin_date
hotel["checkout_date"] = input.checkout_date

return hotels
9 changes: 8 additions & 1 deletion src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,14 @@ async def aquery_retrieval(query: RetrievalQuery):

@router.post("/query/agent")
async def aquery_agent(query: RagQuery):
return await rag_service.aquery_agent(query)
response = await rag_service.aquery_agent(query)
if not query.stream:
return response
else:
return StreamingResponse(
response,
media_type="text/event-stream",
)


@router.post("/config/agent")
Expand Down
8 changes: 4 additions & 4 deletions src/pai_rag/app/api/service.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from fastapi import APIRouter, FastAPI
from fastapi import FastAPI
from pai_rag.core.rag_config_manager import RagConfigManager
from pai_rag.core.rag_service import rag_service
from pai_rag.app.api import query
from pai_rag.app.api import agent_demo
from pai_rag.app.api.middleware import init_middleware
from pai_rag.app.api.error_handler import config_app_errors


def init_router(app: FastAPI):
api_router = APIRouter()
api_router.include_router(query.router, tags=["RagQuery"])
app.include_router(api_router, prefix="/service")
app.include_router(query.router, prefix="/service", tags=["RAG"])
app.include_router(agent_demo.demo_router, tags=["AgentDemo"], prefix="/demo/api")


def configure_app(app: FastAPI, rag_configuration: RagConfigManager):
Expand Down
17 changes: 12 additions & 5 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def query(
else:
full_content = ""
for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
chunk_response = dotdict(json.loads(chunk))
if not chunk.startswith("data:"):
continue
chunk_response = dotdict(json.loads(chunk[5:]))
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(
Expand All @@ -249,7 +251,9 @@ def query_search(
else:
full_content = ""
for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
chunk_response = dotdict(json.loads(chunk))
if not chunk.startswith("data:"):
continue
chunk_response = dotdict(json.loads(chunk[5:]))
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(text, chunk_response, stream=stream)
Expand All @@ -275,7 +279,9 @@ def query_data_analysis(
else:
full_content = ""
for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
chunk_response = dotdict(json.loads(chunk))
if not chunk.startswith("data:"):
continue
chunk_response = dotdict(json.loads(chunk[5:]))
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(text, chunk_response, stream=stream)
Expand Down Expand Up @@ -308,7 +314,9 @@ def query_llm(
else:
full_content = ""
for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
chunk_response = dotdict(json.loads(chunk))
if not chunk.startswith("data:"):
continue
chunk_response = dotdict(json.loads(chunk[5:]))
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(
Expand Down Expand Up @@ -448,7 +456,6 @@ def get_config(self):
r = requests.get(self.config_url, timeout=DEFAULT_CLIENT_TIME_OUT)
if r.status_code != HTTPStatus.OK:
raise RagApiError(code=r.status_code, msg=r.text)

config = RagConfig.model_validate_json(json_data=r.text)
return config

Expand Down
Loading
Loading