Commit: some tests passing

George Burton committed Sep 5, 2024
1 parent 9f0d60a commit 8f4ff86
Showing 6 changed files with 90 additions and 117 deletions.
7 changes: 4 additions & 3 deletions django_app/redbox_app/redbox_core/consumers.py
@@ -193,9 +193,10 @@ def get_ai_settings(user: User) -> AISettings:
fields=[field.name for field in user.ai_settings._meta.fields if field.name != "label"], # noqa: SLF001
)

async def handle_text(self, response: ClientResponse) -> str:
await self.send_to_client("text", response.data)
self.full_reply.append(response.data)
async def handle_text(self, text: str) -> str:
if text:
await self.send_to_client("text", text)
self.full_reply.append(text)

async def handle_route(self, response: ClientResponse) -> str:
await self.send_to_client("route", response.data)
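handle_text now takes the already-extracted token string and skips empty chunks, rather than unpacking a ClientResponse from the core-API websocket. A rough sketch of the assumed calling pattern — the astream_events loop, the event name, and the "response_flag" tag filter are inferred from the rewritten tests below, not shown in this hunk:

# Sketch only (assumed consumer internals): stream the Redbox graph and
# forward content from chunks tagged "response_flag" to handle_text.
async for event in self.redbox.graph.astream_events(state, version="v2"):
    if event["event"] == "on_chat_model_stream" and "response_flag" in event.get("tags", []):
        await self.handle_text(event["data"]["chunk"].content)
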
183 changes: 78 additions & 105 deletions django_app/tests/test_consumers.py
@@ -13,6 +13,7 @@
from websockets import WebSocketClientProtocol
from websockets.legacy.client import Connect

from redbox.chains.runnables import CannedChatLLM
from redbox_app.redbox_core import error_messages
from redbox_app.redbox_core.consumers import ChatConsumer
from redbox_app.redbox_core.models import Chat, ChatMessage, ChatMessageTokenUse, ChatRoleEnum, File, User
@@ -22,6 +23,12 @@
logger = logging.getLogger(__name__)


class TestCannedChatLLM(CannedChatLLM):
def _convert_input(self, x):
x = x["request"].question
return super()._convert_input(x)


@database_sync_to_async
def get_token_use_model(use_type: str) -> str:
return ChatMessageTokenUse.objects.filter(use_type=use_type).latest("created_at").model_name
@@ -34,11 +41,14 @@ def get_token_use_count(use_type: str) -> int:

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_new_session(alice: User, uploaded_file: File, mocked_connect: Connect):
async def test_chat_consumer_with_new_session(alice: User, uploaded_file: File):
# Given
runnable = TestCannedChatLLM(text=["Good afternoon, ", "Mr. Amor.", "cow", "horse"]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
@@ -81,11 +91,16 @@ async def test_chat_consumer_with_new_session(alice: User, uploaded_file: File,

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_staff_user(staff_user: User, mocked_connect: Connect):
async def test_chat_consumer_staff_user(staff_user: User):
# Given

# Given
runnable = TestCannedChatLLM(text=["Good afternoon, ", "Mr. Amor.", "cow", "horse"]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = staff_user
connected, _ = await communicator.connect()
@@ -114,11 +129,16 @@ async def test_chat_consumer_staff_user(staff_user: User, mocked_connect: Connec

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_existing_session(alice: User, chat: Chat, mocked_connect: Connect):
async def test_chat_consumer_with_existing_session(alice: User, chat: Chat):
# Given

# Given
runnable = TestCannedChatLLM(text=["Good afternoon, ", "Mr. Amor.", "cow", "horse"]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
@@ -140,11 +160,14 @@ async def test_chat_consumer_with_existing_session(alice: User, chat: Chat, mock

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_naughty_question(alice: User, uploaded_file: File, mocked_connect: Connect):
async def test_chat_consumer_with_naughty_question(alice: User, uploaded_file: File):
# Given
runnable = TestCannedChatLLM(text=["Good afternoon, ", "Mr. Amor.", "cow", "horse"]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
@@ -180,12 +203,16 @@ async def test_chat_consumer_with_naughty_question(alice: User, uploaded_file: F
@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_naughty_citation(
alice: User, uploaded_file: File, mocked_connect_with_naughty_citation: Connect
alice: User,
uploaded_file: File, # mocked_connect_with_naughty_citation: Connect
):
# Given
runnable = TestCannedChatLLM(text=["Good afternoon, ", "Mr. Amor.", "cow", "horse"]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_naughty_citation):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
@@ -242,13 +269,16 @@ async def test_chat_consumer_with_selected_files(
alice: User,
several_files: Sequence[File],
chat_with_files: Chat,
mocked_connect_with_several_files: Connect,
# mocked_connect_with_several_files: Connect,
):
# Given
selected_files: Sequence[File] = several_files[2:]
runnable = TestCannedChatLLM(text=["Good afternoon, ", "Mr. Amor.", "cow", "horse"]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_several_files):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
@@ -302,42 +332,49 @@ async def test_chat_consumer_with_selected_files(

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_connection_error(alice: User, mocked_breaking_connect: Connect):
async def test_chat_consumer_with_connection_error(
alice: User,
# mocked_breaking_connect: Connect,
):
# Given

runnable = TestCannedChatLLM(text=[{"error": ""}]).with_config(tags=["response_flag"])

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_breaking_connect):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
assert connected

await communicator.send_json_to({"message": "Hello Hal."})
await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response = await communicator.receive_json_from(timeout=5)

# Then
assert response2["type"] == "error"
assert response["type"] == "error"


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_explicit_unhandled_error(
alice: User, mocked_connect_with_explicit_unhandled_error: Connect
):
async def test_chat_consumer_with_explicit_unhandled_error(alice: User):
# Given

runnable = TestCannedChatLLM(text=["Good afternoon, ", error_messages.CORE_ERROR_MESSAGE]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_explicit_unhandled_error):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
assert connected

await communicator.send_json_to({"message": "Hello Hal."})
response1 = await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response3 = await communicator.receive_json_from(timeout=5)
response1 = await communicator.receive_json_from(timeout=555)
response2 = await communicator.receive_json_from(timeout=555)
response3 = await communicator.receive_json_from(timeout=555)

# Then
assert response1["type"] == "session-id"
@@ -351,11 +388,15 @@ async def test_chat_consumer_with_explicit_unhandled_error(

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_rate_limited_error(alice: User, mocked_connect_with_rate_limited_error: Connect):
async def test_chat_consumer_with_rate_limited_error(alice: User):
# Given

runnable = TestCannedChatLLM(text=["Good afternoon, ", error_messages.RATE_LIMITED]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_rate_limited_error):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
@@ -378,13 +419,14 @@ async def test_chat_consumer_with_rate_limited_error(alice: User, mocked_connect

@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_explicit_no_document_selected_error(
alice: User, mocked_connect_with_explicit_no_document_selected_error: Connect
):
async def test_chat_consumer_with_explicit_no_document_selected_error(alice: User):
# Given
runnable = TestCannedChatLLM(text=["Please select a document.", error_messages.RATE_LIMITED]).with_config(
tags=["response_flag"]
)

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_explicit_no_document_selected_error):
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
@@ -404,10 +446,12 @@ async def test_chat_consumer_with_explicit_no_document_selected_error(

@pytest.mark.django_db()
@pytest.mark.asyncio()
async def test_chat_consumer_get_ai_settings(
alice: User, mocked_connect_with_explicit_no_document_selected_error: Connect
):
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_explicit_no_document_selected_error):
async def test_chat_consumer_get_ai_settings(alice: User):
# Given
runnable = TestCannedChatLLM(text=[]).with_config(tags=["response_flag"])

# When
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=runnable):
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
Expand All @@ -434,45 +478,6 @@ def get_chat_messages(user: User) -> Sequence[ChatMessage]:
)


@pytest.fixture()
def mocked_connect(uploaded_file: File) -> Connect:
mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket")
mocked_connect = MagicMock(spec=Connect, name="mocked_connect")
mocked_connect.return_value.__aenter__.return_value = mocked_websocket
mocked_websocket.__aiter__.return_value = [
json.dumps({"resource_type": "text", "data": "Good afternoon, "}),
json.dumps({"resource_type": "text", "data": "Mr. Amor."}),
json.dumps({"resource_type": "route_name", "data": "gratitude"}),
json.dumps(
{
"resource_type": "documents",
"data": [{"s3_key": uploaded_file.unique_name, "page_content": "Good afternoon Mr Amor"}],
}
),
json.dumps(
{
"resource_type": "documents",
"data": [
{"s3_key": uploaded_file.unique_name, "page_content": "Good afternoon Mr Amor"},
{
"s3_key": uploaded_file.unique_name,
"page_content": "Good afternoon Mr Amor",
"page_numbers": [34, 35],
},
],
}
),
json.dumps(
{
"resource_type": "metadata",
"data": {"input_tokens": {"gpt-4o": 123}, "output_tokens": {"gpt-4o": 1000}},
}
),
json.dumps({"resource_type": "end"}),
]
return mocked_connect


@pytest.fixture()
def mocked_connect_with_naughty_citation(uploaded_file: File) -> Connect:
mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket")
@@ -506,39 +511,7 @@ def mocked_breaking_connect() -> Connect:

@pytest.fixture()
def mocked_connect_with_explicit_unhandled_error() -> Connect:
mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket")
mocked_connect = MagicMock(spec=Connect, name="mocked_connect")
mocked_connect.return_value.__aenter__.return_value = mocked_websocket
mocked_websocket.__aiter__.return_value = [
json.dumps({"resource_type": "text", "data": "Good afternoon, "}),
json.dumps({"resource_type": "error", "data": {"code": "unknown", "message": "Oh dear."}}),
]
return mocked_connect


@pytest.fixture()
def mocked_connect_with_rate_limited_error() -> Connect:
mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket")
mocked_connect = MagicMock(spec=Connect, name="mocked_connect")
mocked_connect.return_value.__aenter__.return_value = mocked_websocket
mocked_websocket.__aiter__.return_value = [
json.dumps({"resource_type": "text", "data": "Good afternoon, "}),
json.dumps(
{"resource_type": "error", "data": {"code": "rate-limit", "message": "HTTP/1.1 429 Too Many Requests"}}
),
]
return mocked_connect


@pytest.fixture()
def mocked_connect_with_explicit_no_document_selected_error() -> Connect:
mocked_websocket = AsyncMock(spec=WebSocketClientProtocol, name="mocked_websocket")
mocked_connect = MagicMock(spec=Connect, name="mocked_connect")
mocked_connect.return_value.__aenter__.return_value = mocked_websocket
mocked_websocket.__aiter__.return_value = [
json.dumps({"resource_type": "error", "data": {"code": "no-document-selected", "message": "whatever"}}),
]
return mocked_connect
return 3


@pytest.fixture()
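The TestCannedChatLLM subclass defined at the top of this file exists because the consumer is assumed to invoke the patched graph with a state dict carrying the request object, whereas CannedChatLLM's base _convert_input expects a plain LanguageModelInput. Unwrapping the question lets the canned chunks stream through untouched. A rough illustration, with the state shape assumed for this sketch:

from types import SimpleNamespace

# Hypothetical stand-in for the state the consumer passes to the graph.
state = {"request": SimpleNamespace(question="Hello Hal.")}

llm = TestCannedChatLLM(text=["Good afternoon, ", "Mr. Amor."])
chunks = [c.content for c in llm.stream(state)]  # ["Good afternoon, ", "Mr. Amor."]
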
7 changes: 3 additions & 4 deletions redbox-core/redbox/chains/runnables.py
@@ -1,6 +1,5 @@
import logging
from typing import Any, Iterator
import re
from operator import itemgetter

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
@@ -21,7 +20,6 @@


log = logging.getLogger()
re_string_pattern = re.compile(r"(\S+)")


def build_chat_prompt_from_messages_runnable(prompt_set: PromptSet, tokeniser: Encoding = None) -> Runnable:
@@ -92,7 +90,7 @@ class CannedChatLLM(BaseChatModel):
Based on https://python.langchain.com/v0.2/docs/how_to/custom_chat_model/
"""

text: str
text: list[str]

def _generate(
self,
@@ -137,7 +135,8 @@ def _stream(
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
for token in re_string_pattern.split(self.text):

for token in self.text:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))

if run_manager:
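With text now a list[str], the caller controls chunk boundaries exactly instead of relying on the old whitespace split of a single string. A minimal usage sketch, assuming no changes to CannedChatLLM beyond this hunk:

from redbox.chains.runnables import CannedChatLLM

llm = CannedChatLLM(text=["Good afternoon, ", "Mr. Amor."])

# Streaming yields exactly the supplied chunks, in order.
for chunk in llm.stream("ignored prompt"):
    print(repr(chunk.content))  # 'Good afternoon, ' then 'Mr. Amor.'
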
6 changes: 3 additions & 3 deletions redbox-core/redbox/graph/nodes/processes.py
@@ -153,15 +153,15 @@ def _passthrough(state: RedboxState) -> dict[str, Any]:
return _passthrough


def build_set_text_pattern(text: str, final_response_chain: bool = False):
def build_set_text_pattern(texts: list[str], final_response_chain: bool = False):
"""Returns a function that can arbitrarily set state["text"] to a value."""
llm = CannedChatLLM(text=text)
llm = CannedChatLLM(text=texts)
_llm = llm.with_config(tags=["response_flag"]) if final_response_chain else llm

def _set_text(state: RedboxState) -> dict[str, Any]:
set_text_chain = _llm | StrOutputParser()

return {"text": set_text_chain.invoke(text)}
return {"text": set_text_chain.invoke(texts)}

return _set_text

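build_set_text_pattern now threads the list of canned chunks straight into CannedChatLLM. A hedged usage sketch — whether the node returns the chunks joined into a single string depends on CannedChatLLM._generate, which this diff does not show:

set_text = build_set_text_pattern(
    ["Redbox cannot answer that.", " Please rephrase."],
    final_response_chain=True,
)

# Used as a graph node: the state argument is ignored here and
# state["text"] is overwritten with the canned reply.
new_state = set_text({})  # e.g. {"text": "Redbox cannot answer that. Please rephrase."}
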
(The remaining two changed files are not shown.)
