Skip to content

Commit

Permalink
fastapi_poe client: post_message_attachment improvements (#43)
Browse files Browse the repository at this point in the history
Changed `post_message_attachment` so it now returns the parsed response
of its upload request.

Added `is_inline` and `content_type` arguments

Removed the async lock that didn't seem to be protecting anything

Added typing
  • Loading branch information
chris00zeng authored Jan 6, 2024
1 parent accaa5f commit e8a94e5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 24 deletions.
78 changes: 54 additions & 24 deletions src/fastapi_poe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import sys
import warnings
from typing import Any, AsyncIterable, Dict, Optional, Union
from typing import Any, AsyncIterable, BinaryIO, Dict, Optional, Union

import httpx
from fastapi import Depends, FastAPI, HTTPException, Request, Response
Expand All @@ -17,8 +17,10 @@
from starlette.middleware.base import BaseHTTPMiddleware

from fastapi_poe.types import (
AttachmentUploadResponse,
ContentType,
ErrorResponse,
Identifier,
MetaResponse,
PartialResponse,
QueryRequest,
Expand All @@ -34,6 +36,8 @@
class InvalidParameterError(Exception):
pass

class AttachmentUploadError(Exception):
pass

class LoggingMiddleware(BaseHTTPMiddleware):
async def set_body(self, request: Request):
Expand Down Expand Up @@ -110,25 +114,51 @@ async def on_error(self, error_request: ReportErrorRequest) -> None:

# Helpers for generating responses
def __init__(self):
self.pending_file_attachments = {}
self.file_attachment_lock = asyncio.Lock()
self._pending_file_attachment_tasks = {}

async def post_message_attachment(
self, access_key, message_id, download_url=None, file_data=None, filename=None
):
self,
access_key: str,
message_id: Identifier,
*,
download_url: Optional[str] = None,
file_data: Optional[Union[bytes, BinaryIO]] = None,
filename: Optional[str] = None,
content_type: Optional[str] = None,
is_inline: bool = False
) -> AttachmentUploadResponse:
task = asyncio.create_task(
self._make_file_attachment_request(
access_key, message_id, download_url, file_data, filename
access_key=access_key,
message_id=message_id,
download_url=download_url,
file_data=file_data,
filename=filename,
content_type=content_type,
is_inline=is_inline
)
)
async with self.file_attachment_lock:
files_for_message = self.pending_file_attachments.get(message_id, [])
files_for_message.append(task)
self.pending_file_attachments[message_id] = files_for_message
pending_tasks_for_message = self._pending_file_attachment_tasks.get(message_id, None)
if pending_tasks_for_message is None:
pending_tasks_for_message = set()
self._pending_file_attachment_tasks[message_id] = pending_tasks_for_message
pending_tasks_for_message.add(task)
try:
return await task
finally:
pending_tasks_for_message.remove(task)

async def _make_file_attachment_request(
self, access_key, message_id, download_url=None, file_data=None, filename=None
):
self,
access_key: str,
message_id: Identifier,
*,
download_url: Optional[str] = None,
file_data: Optional[Union[bytes, BinaryIO]] = None,
filename: Optional[str] = None,
content_type: Optional[str] = None,
is_inline: bool = False
) -> AttachmentUploadResponse:
url = "https://www.quora.com/poe_api/file_attachment_POST"

async with httpx.AsyncClient(timeout=120) as client:
Expand All @@ -139,11 +169,16 @@ async def _make_file_attachment_request(
raise InvalidParameterError(
"Cannot provide filename or file_data if download_url is provided."
)
data = {"message_id": message_id, "download_url": download_url}
data = {"message_id": message_id, "is_inline": is_inline, "download_url": download_url}
request = httpx.Request("POST", url, data=data, headers=headers)
elif file_data and filename:
data = {"message_id": message_id}
files = {"file": (filename, file_data)}
data = {"message_id": message_id, "is_inline": is_inline}
files = {
"file": (
(filename, file_data) if content_type is None
else (filename, file_data, content_type)
)
}
request = httpx.Request(
"POST", url, files=files, data=data, headers=headers
)
Expand All @@ -154,25 +189,20 @@ async def _make_file_attachment_request(
response = await client.send(request)

if response.status_code != 200:
logger.error(
f"Recieved {response.status_code} when attempting attach file."
)
raise AttachmentUploadError(f"{response.status_code}: {response.reason_phrase}")

return AttachmentUploadResponse(inline_ref=response.json().get("inline_ref"))

return response
except httpx.HTTPError:
logger.error("An HTTP error occurred when attempting to attach file")
raise

async def _process_pending_attachment_requests(self, message_id):
try:
await asyncio.gather(*self.pending_file_attachments.get(message_id, []))
await asyncio.gather(*self._pending_file_attachment_tasks.pop(message_id, []))
except Exception:
logger.error("Error processing pending attachment requests")
raise
finally:
# clear the pending attachments for the message
async with self.file_attachment_lock:
self.pending_file_attachments.pop(message_id, None)

@staticmethod
def text_event(text: str) -> ServerSentEvent:
Expand Down
3 changes: 3 additions & 0 deletions src/fastapi_poe/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class SettingsResponse(BaseModel):
allow_attachments: bool = False
introduction_message: str = ""

class AttachmentUploadResponse(BaseModel):
inline_ref: Optional[str]


class PartialResponse(BaseModel):
"""Representation of a (possibly partial) response from a bot."""
Expand Down

0 comments on commit e8a94e5

Please sign in to comment.