Skip to content

Commit

Permalink
add helper function for people to use bot response api from their ter…
Browse files Browse the repository at this point in the history
…mianl (#26)

* add helper function for people to use bot response api from their terminal

* add base_url to helper func and use empty string instead of missing

---------

Co-authored-by: Anmol Singh <asingh@quora.com>
  • Loading branch information
anmolsingh95 and Anmol Singh authored Oct 4, 2023
1 parent 3f8e66e commit 8f11eda
Showing 1 changed file with 54 additions and 19 deletions.
73 changes: 54 additions & 19 deletions src/fastapi_poe/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Identifier,
MetaResponse as MetaMessage,
PartialResponse as BotMessage,
ProtocolMessage,
QueryRequest,
SettingsResponse,
)
Expand Down Expand Up @@ -55,16 +56,16 @@ def _safe_ellipsis(obj: object, limit: int) -> str:
@dataclass
class _BotContext:
endpoint: str
access_key: str = field(repr=False)
session: httpx.AsyncClient = field(repr=False)
api_key: Optional[str] = field(default=None, repr=False)
on_error: Optional[ErrorHandler] = field(default=None, repr=False)

@property
def headers(self) -> Dict[str, str]:
return {
"Accept": "application/json",
"Authorization": f"Bearer {self.access_key}",
}
headers = {"Accept": "application/json"}
if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}"
return headers

async def report_error(
self, message: str, metadata: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -275,32 +276,30 @@ def _default_error_handler(e: Exception, msg: str) -> None:
async def stream_request(
request: QueryRequest,
bot_name: str,
access_key: str = "",
*,
api_key: str = "",
api_key_deprecation_warning_stacklevel: int = 2,
*,
access_key: str = "",
access_key_deprecation_warning_stacklevel: int = 2,
session: Optional[httpx.AsyncClient] = None,
on_error: ErrorHandler = _default_error_handler,
num_tries: int = 2,
retry_sleep_time: float = 0.5,
base_url: str = "https://api.poe.com/bot/",
) -> AsyncGenerator[BotMessage, None]:
"""Streams BotMessages from a Poe bot."""
if api_key != "":
if access_key != "":
warnings.warn(
"the api_key param is deprecated, pass your key using access_key instead",
"the access_key param is no longer necessary when using this function.",
DeprecationWarning,
stacklevel=api_key_deprecation_warning_stacklevel,
stacklevel=access_key_deprecation_warning_stacklevel,
)
if access_key == "":
access_key = api_key

async with contextlib.AsyncExitStack() as stack:
if session is None:
session = await stack.enter_async_context(httpx.AsyncClient())
url = f"{base_url}{bot_name}"
ctx = _BotContext(
endpoint=url, access_key=access_key, session=session, on_error=on_error
endpoint=url, api_key=api_key, session=session, on_error=on_error
)
got_response = False
for i in range(num_tries):
Expand All @@ -324,12 +323,48 @@ async def stream_request(
await asyncio.sleep(retry_sleep_time)


async def get_bot_response(
messages: List[ProtocolMessage],
bot_name: str,
api_key: str,
*,
temperature: Optional[float] = None,
skip_system_prompt: Optional[bool] = None,
logit_bias: Optional[Dict[str, float]] = None,
stop_sequences: Optional[List[str]] = None,
base_url: str = "https://api.poe.com/bot/",
) -> AsyncGenerator[BotMessage, None]:
additional_params = {}
# This is so that we don't have to redefine the default values for these params.
if temperature is not None:
additional_params["temperature"] = temperature
if skip_system_prompt is not None:
additional_params["skip_system_prompt"] = skip_system_prompt
if logit_bias is not None:
additional_params["logit_bias"] = logit_bias
if stop_sequences is not None:
additional_params["stop_sequences"] = stop_sequences

query = QueryRequest(
query=messages,
user_id="",
conversation_id="",
message_id="",
version=PROTOCOL_VERSION,
type="query",
**additional_params,
)
return stream_request(
request=query, bot_name=bot_name, api_key=api_key, base_url=base_url
)


async def get_final_response(
request: QueryRequest,
bot_name: str,
access_key: str = "",
*,
api_key: str = "",
*,
access_key: str = "",
session: Optional[httpx.AsyncClient] = None,
on_error: ErrorHandler = _default_error_handler,
num_tries: int = 2,
Expand All @@ -341,9 +376,9 @@ async def get_final_response(
async for message in stream_request(
request,
bot_name,
access_key,
api_key=api_key,
api_key_deprecation_warning_stacklevel=3,
api_key,
access_key=access_key,
access_key_deprecation_warning_stacklevel=3,
session=session,
on_error=on_error,
num_tries=num_tries,
Expand Down

0 comments on commit 8f11eda

Please sign in to comment.