diff --git a/src/fastapi_poe/client.py b/src/fastapi_poe/client.py index c2b8696..cdde64c 100644 --- a/src/fastapi_poe/client.py +++ b/src/fastapi_poe/client.py @@ -19,6 +19,7 @@ Identifier, MetaResponse as MetaMessage, PartialResponse as BotMessage, + ProtocolMessage, QueryRequest, SettingsResponse, ) @@ -55,16 +56,16 @@ def _safe_ellipsis(obj: object, limit: int) -> str: @dataclass class _BotContext: endpoint: str - access_key: str = field(repr=False) session: httpx.AsyncClient = field(repr=False) + api_key: Optional[str] = field(default=None, repr=False) on_error: Optional[ErrorHandler] = field(default=None, repr=False) @property def headers(self) -> Dict[str, str]: - return { - "Accept": "application/json", - "Authorization": f"Bearer {self.access_key}", - } + headers = {"Accept": "application/json"} + if self.api_key is not None: + headers["Authorization"] = f"Bearer {self.api_key}" + return headers async def report_error( self, message: str, metadata: Optional[Dict[str, Any]] = None @@ -275,10 +276,10 @@ def _default_error_handler(e: Exception, msg: str) -> None: async def stream_request( request: QueryRequest, bot_name: str, - access_key: str = "", - *, api_key: str = "", - api_key_deprecation_warning_stacklevel: int = 2, + *, + access_key: str = "", + access_key_deprecation_warning_stacklevel: int = 2, session: Optional[httpx.AsyncClient] = None, on_error: ErrorHandler = _default_error_handler, num_tries: int = 2, @@ -286,21 +287,19 @@ async def stream_request( base_url: str = "https://api.poe.com/bot/", ) -> AsyncGenerator[BotMessage, None]: """Streams BotMessages from a Poe bot.""" - if api_key != "": + if access_key != "": warnings.warn( - "the api_key param is deprecated, pass your key using access_key instead", + "the access_key param is no longer necessary when using this function.", DeprecationWarning, - stacklevel=api_key_deprecation_warning_stacklevel, + stacklevel=access_key_deprecation_warning_stacklevel, ) - if access_key == "": - access_key = api_key async with contextlib.AsyncExitStack() as stack: if session is None: session = await stack.enter_async_context(httpx.AsyncClient()) url = f"{base_url}{bot_name}" ctx = _BotContext( - endpoint=url, access_key=access_key, session=session, on_error=on_error + endpoint=url, api_key=api_key, session=session, on_error=on_error ) got_response = False for i in range(num_tries): @@ -324,12 +323,48 @@ async def stream_request( await asyncio.sleep(retry_sleep_time) +async def get_bot_response( + messages: List[ProtocolMessage], + bot_name: str, + api_key: str, + *, + temperature: Optional[float] = None, + skip_system_prompt: Optional[bool] = None, + logit_bias: Optional[Dict[str, float]] = None, + stop_sequences: Optional[List[str]] = None, + base_url: str = "https://api.poe.com/bot/", +) -> AsyncGenerator[BotMessage, None]: + additional_params = {} + # This is so that we don't have to redefine the default values for these params. + if temperature is not None: + additional_params["temperature"] = temperature + if skip_system_prompt is not None: + additional_params["skip_system_prompt"] = skip_system_prompt + if logit_bias is not None: + additional_params["logit_bias"] = logit_bias + if stop_sequences is not None: + additional_params["stop_sequences"] = stop_sequences + + query = QueryRequest( + query=messages, + user_id="", + conversation_id="", + message_id="", + version=PROTOCOL_VERSION, + type="query", + **additional_params, + ) + return stream_request( + request=query, bot_name=bot_name, api_key=api_key, base_url=base_url + ) + + async def get_final_response( request: QueryRequest, bot_name: str, - access_key: str = "", - *, api_key: str = "", + *, + access_key: str = "", session: Optional[httpx.AsyncClient] = None, on_error: ErrorHandler = _default_error_handler, num_tries: int = 2, @@ -341,9 +376,9 @@ async def get_final_response( async for message in stream_request( request, bot_name, - access_key, - api_key=api_key, - api_key_deprecation_warning_stacklevel=3, + api_key, + access_key=access_key, + access_key_deprecation_warning_stacklevel=3, session=session, on_error=on_error, num_tries=num_tries,