Skip to content

Commit

Permalink
Refactor for the newest PMAT (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii committed Sep 5, 2024
1 parent 840104b commit aff9952
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 139 deletions.
7 changes: 0 additions & 7 deletions labs_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,8 @@ class Config(BaseSettings):
PORT: int = 8080
WORKERS: int = 1
RELOAD: bool = True
TAVILY_API_KEY: t.Optional[SecretStr] = None
SQLALCHEMY_DB_URL: t.Optional[SecretStr] = None

@property
def tavily_api_key(self) -> SecretStr:
return check_not_none(
self.TAVILY_API_KEY, "TAVILY_API_KEY missing in the environment."
)

@property
def sqlalchemy_db_url(self) -> SecretStr:
return check_not_none(
Expand Down
38 changes: 15 additions & 23 deletions labs_api/insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
HexAddress,
OmenSubgraphHandler,
)
from prediction_market_agent_tooling.tools.tavily_storage.tavily_storage import (
TavilyStorage,
tavily_search,
)
from prediction_market_agent_tooling.tools.utils import utcnow
from tavily import TavilyClient

from labs_api.config import Config
from labs_api.insights_cache import MarketInsightsResponseCache
from labs_api.models import MarketInsightsResponse, TavilyResponse
from labs_api.models import MarketInsightsResponse


def market_insights_cached(
Expand All @@ -35,28 +37,18 @@ def market_insights(market_id: HexAddress) -> MarketInsightsResponse:
status_code=404, detail=f"Market with id `{market_id}` not found."
)
try:
insights = tavily_insights(market.question_title)
except Exception as e:
logger.error(f"Failed to get insights for market `{market_id}`: {e}")
insights = None
return MarketInsightsResponse.from_tavily_response(
market_id=market_id,
created_at=utcnow(),
tavily_response=insights,
)


def tavily_insights(query: str) -> TavilyResponse:
"""
Create a simple string with the top 5 search results from Tavily with a description.
"""
tavily = TavilyClient(api_key=Config().tavily_api_key.get_secret_value())
response = TavilyResponse.model_validate(
tavily.search(
query=query,
tavily_response = tavily_search(
market.question_title,
search_depth="basic",
include_answer=True,
max_results=5,
tavily_storage=TavilyStorage("market_insights"),
)
except Exception as e:
logger.error(f"Failed to get tavily_response for market `{market_id}`: {e}")
tavily_response = None
return MarketInsightsResponse.from_tavily_response(
market_id=market_id,
created_at=utcnow(),
tavily_response=tavily_response,
)
return response
18 changes: 13 additions & 5 deletions labs_api/insights_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,19 @@ def find(
item = session.exec(
query.order_by(desc(MarketInsightsResponseCacheModel.datetime_))
).first()
return (
MarketInsightsResponse.model_validate_json(item.json_dump)
if item is not None
else None
)
try:
market_insights_response = (
MarketInsightsResponse.model_validate_json(item.json_dump)
if item is not None
else None
)
except ValueError as e:
logger.error(
f"Error deserializing MarketInsightsResponse from cache for {market_id=} and {item=}: {e}"
)

market_insights_response = None
return market_insights_response

def save(
self,
Expand Down
22 changes: 6 additions & 16 deletions labs_api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from datetime import datetime

from prediction_market_agent_tooling.gtypes import HexAddress
from prediction_market_agent_tooling.tools.tavily_storage.tavily_models import (
TavilyResponse,
TavilyResult,
)
from pydantic import BaseModel


Expand All @@ -10,7 +14,7 @@ class MarketInsightResult(BaseModel):
title: str

@staticmethod
def from_tavily_result(tavily_result: "TavilyResult") -> "MarketInsightResult":
def from_tavily_result(tavily_result: TavilyResult) -> "MarketInsightResult":
return MarketInsightResult(url=tavily_result.url, title=tavily_result.title)


Expand All @@ -28,7 +32,7 @@ def has_insights(self) -> bool:
def from_tavily_response(
market_id: HexAddress,
created_at: datetime,
tavily_response: t.Union["TavilyResponse", None],
tavily_response: t.Union[TavilyResponse, None],
) -> "MarketInsightsResponse":
return MarketInsightsResponse(
market_id=market_id,
Expand All @@ -43,17 +47,3 @@ def from_tavily_response(
else []
),
)


class TavilyResult(BaseModel):
title: str
url: str
content: str
score: float


class TavilyResponse(BaseModel):
query: str
answer: str
results: list[TavilyResult]
response_time: float
Loading

0 comments on commit aff9952

Please sign in to comment.