Skip to content

Commit

Permalink
feature/simplify consumers (#1016)
Browse files Browse the repository at this point in the history
* simplified handle_documents

* wip

* format

* wip

---------

Co-authored-by: George Burton <g.e.c.cburton@gmail.com>
  • Loading branch information
gecBurton and George Burton authored Sep 5, 2024
1 parent 826d7d5 commit 096bc35
Showing 1 changed file with 37 additions and 68 deletions.
105 changes: 37 additions & 68 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,26 @@ async def receive(self, text_data=None, bytes_data=None):
data = json.loads(text_data or bytes_data)
logger.debug("received %s from browser", data)
user_message_text: str = data.get("message", "")
session_id: str | None = data.get("sessionId", None)
selected_file_uuids: Sequence[UUID] = [UUID(u) for u in data.get("selectedFiles", [])]
user: User = self.scope.get("user", None)

session: Chat = await self.get_session(session_id, user, user_message_text)
if session_id := data.get("sessionId"):
session = await Chat.objects.aget(id=session_id)
else:
session = await Chat.objects.acreate(name=user_message_text[0 : settings.CHAT_TITLE_LENGTH], user=user)

# save user message
selected_files = await self.get_files_by_id(selected_file_uuids, user)
selected_files = File.objects.filter(id__in=selected_file_uuids, user=user)
await self.save_message(session, user_message_text, ChatRoleEnum.user, selected_files=selected_files)

await self.llm_conversation(selected_files, session, user, user_message_text)
await self.close()

async def llm_conversation(self, selected_files: Sequence[File], session: Chat, user: User, title: str) -> None:
"""Initiate & close websocket conversation with the core-api message endpoint."""
session_messages = await self.get_messages(session)
session_messages = ChatMessage.objects.filter(chat=session).order_by("created_at")
message_history: Sequence[Mapping[str, str]] = [
{"role": message.role, "text": message.text} for message in session_messages
{"role": message.role, "text": message.text} async for message in session_messages
]
url = URL.build(scheme="ws", host=settings.CORE_API_HOST, port=settings.CORE_API_PORT) / "chat/rag"
try:
Expand All @@ -67,21 +69,18 @@ async def llm_conversation(self, selected_files: Sequence[File], session: Chat,
}
await self.send_to_server(core_websocket, message)
await self.send_to_client("session-id", session.id)
reply, citations, route, metadata = await self.receive_llm_responses(user, core_websocket)
reply, citations, route, metadata = await self.receive_llm_responses(core_websocket)
message = await self.save_message(
session, reply, ChatRoleEnum.ai, sources=citations, route=route, metadata=metadata
)
await self.send_to_client("end", {"message_id": message.id, "title": title, "session_id": session.id})

for file, _ in citations:
file.last_referenced = timezone.now()
await self.file_save(file)
except (TimeoutError, ConnectionClosedError, CancelledError) as e:
logger.exception("Error from core.", exc_info=e)
await self.send_to_client("error", error_messages.CORE_ERROR_MESSAGE)

async def receive_llm_responses(
self, user: User, core_websocket: WebSocketClientProtocol
self, core_websocket: WebSocketClientProtocol
) -> tuple[str, Sequence[tuple[File, SourceDocument]], str, MetadataDetail]:
"""Conduct websocket conversation with the core-api message endpoint."""
full_reply: MutableSequence[str] = []
Expand All @@ -94,7 +93,7 @@ async def receive_llm_responses(
if response.resource_type == "text":
full_reply.append(await self.handle_text(response))
elif response.resource_type == "documents":
citations += await self.handle_documents(response, user)
citations += await self.handle_documents(response)
elif response.resource_type == "route_name":
route = await self.handle_route(response)
elif response.resource_type == "metadata":
Expand All @@ -103,11 +102,14 @@ async def receive_llm_responses(
full_reply.append(await self.handle_error(response))
return "".join(full_reply), citations, route, metadata

async def handle_documents(
    self, response: ClientResponse, user: User | None = None
) -> Sequence[tuple[File, Sequence[SourceDocument]]]:
    """Announce the cited files to the client and pair each with its documents.

    Args:
        response: message whose ``data`` is an iterable of source documents,
            each exposing an ``s3_key`` attribute.
        user: optional owner filter. When given, only files belonging to this
            user are considered; when omitted, files are matched on
            ``original_file`` alone.

    Returns:
        One ``(file, documents)`` pair per matching file, where ``documents``
        are the entries of ``response.data`` whose ``s3_key`` equals
        ``file.unique_name``.
    """
    docs = response.data
    s3_keys = [doc.s3_key for doc in docs]

    matching = File.objects.filter(original_file__in=s3_keys)
    if user is not None:
        matching = matching.filter(user=user)

    # Evaluate the queryset exactly once, through the async iterator, so no
    # synchronous ORM iteration happens inside this coroutine afterwards.
    files = [file async for file in matching]

    for file in files:
        await self.send_to_client("source", {"url": str(file.url), "original_file_name": file.original_file_name})

    return [(file, [doc for doc in docs if doc.s3_key == file.unique_name]) for file in files]

async def handle_text(self, response: ClientResponse) -> str:
await self.send_to_client("text", response.data)
Expand Down Expand Up @@ -154,22 +156,6 @@ async def send_to_server(websocket: WebSocketClientProtocol, data: Mapping[str,
logger.debug("sending %s to core-api", data)
return await websocket.send(json.dumps(data, default=str))

@staticmethod
@database_sync_to_async
def get_session(session_id: str, user: User, user_message_text: str) -> Chat:
    """Return the chat identified by ``session_id``, or create and save a new one.

    When ``session_id`` is falsy a new ``Chat`` owned by ``user`` is saved; its
    name is the first ``settings.CHAT_TITLE_LENGTH`` characters of
    ``user_message_text``.
    """
    if not session_id:
        title = user_message_text[: settings.CHAT_TITLE_LENGTH]
        new_chat = Chat(name=title, user=user)
        new_chat.save()
        return new_chat

    # NOTE(review): the lookup is by id only and is not restricted to ``user``;
    # confirm that ownership is enforced elsewhere.
    return Chat.objects.get(id=session_id)

@staticmethod
@database_sync_to_async
def get_messages(session: Chat) -> Sequence[ChatMessage]:
    """Return the messages of ``session`` as a list ordered by ``created_at``."""
    ordered = ChatMessage.objects.filter(chat=session).order_by("created_at")
    # Evaluate the queryset here, inside the sync wrapper, so the caller
    # receives a plain list rather than a lazy queryset.
    return [*ordered]

@staticmethod
@database_sync_to_async
def save_message(
Expand All @@ -185,6 +171,9 @@ def save_message(
chat_message.save()
if sources:
for file, citations in sources:
file.last_referenced = timezone.now()
file.save()

for citation in citations:
Citation.objects.create(
chat_message=chat_message,
Expand All @@ -194,44 +183,24 @@ def save_message(
)
if selected_files:
chat_message.selected_files.set(selected_files)
if metadata:
if metadata.input_tokens:
for model, token_count in metadata.input_tokens.items():
ChatMessageTokenUse.objects.create(
chat_message=chat_message,
use_type=ChatMessageTokenUse.UseTypeEnum.INPUT,
model_name=model,
token_count=token_count,
)
if metadata.output_tokens:
for model, token_count in metadata.output_tokens.items():
ChatMessageTokenUse.objects.create(
chat_message=chat_message,
use_type=ChatMessageTokenUse.UseTypeEnum.OUTPUT,
model_name=model,
token_count=token_count,
)
return chat_message

@staticmethod
@database_sync_to_async
def get_files_by_id(docs: Sequence[UUID], user: User) -> Sequence[File]:
    """Return, as a list, the files owned by ``user`` whose id is in ``docs``."""
    owned_matches = File.objects.filter(id__in=docs, user=user)
    # Evaluate inside the sync wrapper so a concrete list is handed back.
    return [*owned_matches]

@staticmethod
@database_sync_to_async
def get_sources_with_files(
    docs: Sequence[SourceDocument], user: User
) -> tuple[Sequence[File], Sequence[tuple[File, Sequence[SourceDocument]]]]:
    """Find the user's files referenced by ``docs`` and pair each with its documents.

    Args:
        docs: source documents, each exposing an ``s3_key`` attribute
            (assumed hashable, e.g. a string — TODO confirm).
        user: only files belonging to this user are returned.

    Returns:
        ``(files, citations)`` where ``files`` is a list of the matching
        ``File`` rows and ``citations`` holds one ``(file, documents)`` pair
        per file; ``documents`` are the entries of ``docs`` whose ``s3_key``
        equals ``file.unique_name`` (empty when none match).
    """
    # Group the documents by key once instead of rescanning ``docs`` per file.
    docs_by_key: dict[str, list[SourceDocument]] = {}
    for doc in docs:
        docs_by_key.setdefault(doc.s3_key, []).append(doc)

    # Evaluate the queryset inside the sync wrapper so the async caller gets a
    # concrete list, matching the annotated return type, not a lazy queryset.
    files = list(File.objects.filter(original_file__in=list(docs_by_key), user=user))

    citations = [(file, list(docs_by_key.get(file.unique_name, ()))) for file in files]
    return files, citations

@staticmethod
@database_sync_to_async
def file_save(file):
    """Call ``file.save()`` from the sync wrapper and return its result."""
    outcome = file.save()
    return outcome
if metadata and metadata.input_tokens:
for model, token_count in metadata.input_tokens.items():
ChatMessageTokenUse.objects.create(
chat_message=chat_message,
use_type=ChatMessageTokenUse.UseTypeEnum.INPUT,
model_name=model,
token_count=token_count,
)
if metadata and metadata.output_tokens:
for model, token_count in metadata.output_tokens.items():
ChatMessageTokenUse.objects.create(
chat_message=chat_message,
use_type=ChatMessageTokenUse.UseTypeEnum.OUTPUT,
model_name=model,
token_count=token_count,
)
return chat_message

@staticmethod
@database_sync_to_async
Expand Down

0 comments on commit 096bc35

Please sign in to comment.