diff --git a/pyproject.toml b/pyproject.toml index 1e92a57..87e98d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fastapi_poe" -version = "0.0.25" +version = "0.0.26" authors = [ { name="Lida Li", email="lli@quora.com" }, { name="Jelle Zijlstra", email="jelle@quora.com" }, diff --git a/src/fastapi_poe/base.py b/src/fastapi_poe/base.py index b3a690f..52a1ac4 100644 --- a/src/fastapi_poe/base.py +++ b/src/fastapi_poe/base.py @@ -1,4 +1,5 @@ import argparse +import asyncio import copy import json import logging @@ -7,6 +8,7 @@ import warnings from typing import Any, AsyncIterable, Dict, Optional, Union +import httpx from fastapi import Depends, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import HTMLResponse, JSONResponse @@ -104,6 +106,57 @@ async def on_error(self, error_request: ReportErrorRequest) -> None: logger.error(f"Error from Poe server: {error_request}") # Helpers for generating responses + def __init__(self): + self.pending_file_attachments = {} + self.file_attachment_lock = asyncio.Lock() + + async def post_message_attachment( + self, message_id, file_data, filename, access_key + ): + task = asyncio.create_task( + self._make_file_attachment_request( + message_id, file_data, filename, access_key + ) + ) + async with self.file_attachment_lock: + files_for_message = self.pending_file_attachments.get(message_id, []) + files_for_message.append(task) + self.pending_file_attachments[message_id] = files_for_message + + async def _make_file_attachment_request( + self, message_id, file_data, filename, access_key + ): + url = "https://www.quora.com/poe_api/file_attachment_POST" + + async with httpx.AsyncClient(timeout=120) as client: + try: + files = {"file": (filename, file_data)} + data = {"message_id": message_id} + headers = {"Authorization": f"{access_key}"} + + request = httpx.Request( + "POST", url, files=files, data=data, headers=headers + ) + response = await client.send(request) + + if response.status_code != 200: + logger.error( + f"Recieved {response.status_code} when attempting attach file." + ) + + return response + except httpx.HTTPError: + logger.error("An HTTP error occurred when attempting to attach file") + + async def _process_pending_attachment_requests(self, message_id): + try: + await asyncio.gather(*self.pending_file_attachments.get(message_id, [])) + except Exception: + logger.error("Error processing pending attachment requests") + # clear the pending attachments for the message + async with self.file_attachment_lock: + self.pending_file_attachments.pop(message_id, None) + @staticmethod def text_event(text: str) -> ServerSentEvent: return ServerSentEvent(data=json.dumps({"text": text}), event="text") @@ -203,6 +256,7 @@ async def handle_query( except Exception as e: logger.exception("Error responding to query") yield self.error_event(repr(e), allow_retry=False) + await self._process_pending_attachment_requests(request.message_id) yield self.done_event()