From e8a94e5ddef81ce1f68e80183258fe4bbb3cbe0b Mon Sep 17 00:00:00 2001 From: chris00zeng Date: Fri, 5 Jan 2024 17:07:43 -0800 Subject: [PATCH] fastapi_poe client: `post_message_attachment` improvements (#43) Changed `post_message_attachment` so it now returns the parsed response of its upload request. Added `is_inline` and `content_type` arguments Removed the async lock that didn't seem to be protecting anything Added typing --- src/fastapi_poe/base.py | 78 +++++++++++++++++++++++++++------------- src/fastapi_poe/types.py | 3 ++ 2 files changed, 57 insertions(+), 24 deletions(-) diff --git a/src/fastapi_poe/base.py b/src/fastapi_poe/base.py index d1ab536..299a518 100644 --- a/src/fastapi_poe/base.py +++ b/src/fastapi_poe/base.py @@ -6,7 +6,7 @@ import os import sys import warnings -from typing import Any, AsyncIterable, Dict, Optional, Union +from typing import Any, AsyncIterable, BinaryIO, Dict, Optional, Union import httpx from fastapi import Depends, FastAPI, HTTPException, Request, Response @@ -17,8 +17,10 @@ from starlette.middleware.base import BaseHTTPMiddleware from fastapi_poe.types import ( + AttachmentUploadResponse, ContentType, ErrorResponse, + Identifier, MetaResponse, PartialResponse, QueryRequest, @@ -34,6 +36,8 @@ class InvalidParameterError(Exception): pass +class AttachmentUploadError(Exception): + pass class LoggingMiddleware(BaseHTTPMiddleware): async def set_body(self, request: Request): @@ -110,25 +114,51 @@ async def on_error(self, error_request: ReportErrorRequest) -> None: # Helpers for generating responses def __init__(self): - self.pending_file_attachments = {} - self.file_attachment_lock = asyncio.Lock() + self._pending_file_attachment_tasks = {} async def post_message_attachment( - self, access_key, message_id, download_url=None, file_data=None, filename=None - ): + self, + access_key: str, + message_id: Identifier, + *, + download_url: Optional[str] = None, + file_data: Optional[Union[bytes, BinaryIO]] = None, + filename: Optional[str] = None, + content_type: Optional[str] = None, + is_inline: bool = False + ) -> AttachmentUploadResponse: task = asyncio.create_task( self._make_file_attachment_request( - access_key, message_id, download_url, file_data, filename + access_key=access_key, + message_id=message_id, + download_url=download_url, + file_data=file_data, + filename=filename, + content_type=content_type, + is_inline=is_inline ) ) - async with self.file_attachment_lock: - files_for_message = self.pending_file_attachments.get(message_id, []) - files_for_message.append(task) - self.pending_file_attachments[message_id] = files_for_message + pending_tasks_for_message = self._pending_file_attachment_tasks.get(message_id, None) + if pending_tasks_for_message is None: + pending_tasks_for_message = set() + self._pending_file_attachment_tasks[message_id] = pending_tasks_for_message + pending_tasks_for_message.add(task) + try: + return await task + finally: + pending_tasks_for_message.remove(task) async def _make_file_attachment_request( - self, access_key, message_id, download_url=None, file_data=None, filename=None - ): + self, + access_key: str, + message_id: Identifier, + *, + download_url: Optional[str] = None, + file_data: Optional[Union[bytes, BinaryIO]] = None, + filename: Optional[str] = None, + content_type: Optional[str] = None, + is_inline: bool = False + ) -> AttachmentUploadResponse: url = "https://www.quora.com/poe_api/file_attachment_POST" async with httpx.AsyncClient(timeout=120) as client: @@ -139,11 +169,16 @@ async def _make_file_attachment_request( raise InvalidParameterError( "Cannot provide filename or file_data if download_url is provided." ) - data = {"message_id": message_id, "download_url": download_url} + data = {"message_id": message_id, "is_inline": is_inline, "download_url": download_url} request = httpx.Request("POST", url, data=data, headers=headers) elif file_data and filename: - data = {"message_id": message_id} - files = {"file": (filename, file_data)} + data = {"message_id": message_id, "is_inline": is_inline} + files = { + "file": ( + (filename, file_data) if content_type is None + else (filename, file_data, content_type) + ) + } request = httpx.Request( "POST", url, files=files, data=data, headers=headers ) @@ -154,25 +189,20 @@ async def _make_file_attachment_request( response = await client.send(request) if response.status_code != 200: - logger.error( - f"Recieved {response.status_code} when attempting attach file." - ) + raise AttachmentUploadError(f"{response.status_code}: {response.reason_phrase}") + + return AttachmentUploadResponse(inline_ref=response.json().get("inline_ref")) - return response except httpx.HTTPError: logger.error("An HTTP error occurred when attempting to attach file") raise async def _process_pending_attachment_requests(self, message_id): try: - await asyncio.gather(*self.pending_file_attachments.get(message_id, [])) + await asyncio.gather(*self._pending_file_attachment_tasks.pop(message_id, [])) except Exception: logger.error("Error processing pending attachment requests") raise - finally: - # clear the pending attachments for the message - async with self.file_attachment_lock: - self.pending_file_attachments.pop(message_id, None) @staticmethod def text_event(text: str) -> ServerSentEvent: diff --git a/src/fastapi_poe/types.py b/src/fastapi_poe/types.py index b936570..0630f6e 100644 --- a/src/fastapi_poe/types.py +++ b/src/fastapi_poe/types.py @@ -86,6 +86,9 @@ class SettingsResponse(BaseModel): allow_attachments: bool = False introduction_message: str = "" +class AttachmentUploadResponse(BaseModel): + inline_ref: Optional[str] + class PartialResponse(BaseModel): """Representation of a (possibly partial) response from a bot."""