From ebb45e468c5abc84c0a19c97898bba9f07f88e7b Mon Sep 17 00:00:00 2001 From: Mark Dunne Date: Tue, 29 Oct 2024 15:33:13 +0000 Subject: [PATCH] Feature/add govuk search (#1117) --- .vscode/settings.json | 7 +- django_app/.vscode/settings.json | 1 - .../redbox_app/redbox_core/consumers.py | 6 +- django_app/redbox_app/redbox_core/models.py | 68 ++++++++++++++---- docker-compose.yml | 2 + redbox-core/redbox/app.py | 7 +- redbox-core/redbox/graph/nodes/tools.py | 70 ++++++++++++++++++- redbox-core/redbox/graph/root.py | 14 +--- redbox-core/redbox/models/file.py | 13 ++-- redbox-core/tests/graph/test_app.py | 45 +++++++++++- redbox-core/tests/test_tools.py | 39 ++++++++++- 11 files changed, 227 insertions(+), 45 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index aea20205a..4678f2ae3 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,5 +10,10 @@ "python.analysis.autoImportCompletions": true, "python.experiments.optOutFrom": [ "pythonTerminalEnvVarActivation" - ] + ], + "python.testing.pytestArgs": [ + "redbox-core" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } \ No newline at end of file diff --git a/django_app/.vscode/settings.json b/django_app/.vscode/settings.json index 28bd20880..ea9112184 100644 --- a/django_app/.vscode/settings.json +++ b/django_app/.vscode/settings.json @@ -5,5 +5,4 @@ "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, "python.testing.pytestPath": "venv/bin/python -m pytest", - } \ No newline at end of file diff --git a/django_app/redbox_app/redbox_core/consumers.py b/django_app/redbox_app/redbox_core/consumers.py index 3bb9e6184..e3d41de4d 100644 --- a/django_app/redbox_app/redbox_core/consumers.py +++ b/django_app/redbox_app/redbox_core/consumers.py @@ -41,9 +41,7 @@ File, StatusEnum, ) -from redbox_app.redbox_core.models import ( - AISettings as AISettingsModel, -) +from redbox_app.redbox_core.models import AISettings as AISettingsModel User = get_user_model() OptFileSeq = Sequence[File] | None @@ -218,7 +216,7 @@ def save_ai_message( url=citation_source.source, text=citation_source.highlighted_text_in_source, page_numbers=citation_source.page_numbers, - source=Citation.Origin(citation_source.source_type.title()), + source=Citation.Origin(citation_source.source_type), ) if self.metadata: diff --git a/django_app/redbox_app/redbox_core/models.py b/django_app/redbox_app/redbox_core/models.py index df4614dc1..2ae21274e 100644 --- a/django_app/redbox_app/redbox_core/models.py +++ b/django_app/redbox_app/redbox_core/models.py @@ -71,7 +71,10 @@ class Providers(models.TextChoices): GROQ = "groq" OLLAMA = "ollama" - name = models.CharField(max_length=128, help_text="The name of the model, e.g. “gpt-4o”, “claude-3-opus-20240229”.") + name = models.CharField( + max_length=128, + help_text="The name of the model, e.g. “gpt-4o”, “claude-3-opus-20240229”.", + ) provider = models.CharField(max_length=128, choices=Providers, help_text="The model provider") description = models.TextField(null=True, blank=True, help_text="brief description of the model") is_default = models.BooleanField(default=False, help_text="is this the default llm to use.") @@ -139,13 +142,22 @@ class AISettings(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings): rag_num_candidates = models.PositiveIntegerField(default=10) rag_gauss_scale_size = models.PositiveIntegerField(default=3) rag_gauss_scale_decay = models.DecimalField( - max_digits=5, decimal_places=2, default=0.5, validators=[validators.MinValueValidator(0.0)] + max_digits=5, + decimal_places=2, + default=0.5, + validators=[validators.MinValueValidator(0.0)], ) rag_gauss_scale_min = models.DecimalField( - max_digits=5, decimal_places=2, default=1.1, validators=[validators.MinValueValidator(1.0)] + max_digits=5, + decimal_places=2, + default=1.1, + validators=[validators.MinValueValidator(1.0)], ) rag_gauss_scale_max = models.DecimalField( - max_digits=5, decimal_places=2, default=2.0, validators=[validators.MinValueValidator(1.0)] + max_digits=5, + decimal_places=2, + default=2.0, + validators=[validators.MinValueValidator(1.0)], ) rag_desired_chunk_size = models.PositiveIntegerField(default=300) elbow_filter_enabled = models.BooleanField(default=False) @@ -158,7 +170,10 @@ class AISettings(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings): max_digits=5, decimal_places=2, default=0.7, - validators=[validators.MinValueValidator(0.0), validators.MaxValueValidator(1.0)], + validators=[ + validators.MinValueValidator(0.0), + validators.MaxValueValidator(1.0), + ], ) def __str__(self) -> str: @@ -430,7 +445,9 @@ class DurationTasks(models.TextChoices): profession = models.CharField(null=True, blank=True, max_length=4, choices=Profession) info_about_user = models.CharField(null=True, blank=True, help_text="user entered info from profile overlay") redbox_response_preferences = models.CharField( - null=True, blank=True, help_text="user entered info from profile overlay, to be used in custom prompt" + null=True, + blank=True, + help_text="user entered info from profile overlay, to be used in custom prompt", ) ai_settings = models.ForeignKey(AISettings, on_delete=models.SET_DEFAULT, default="default", to_field="label") is_developer = models.BooleanField(null=True, blank=True, default=False, help_text="is this user a developer?") @@ -555,7 +572,10 @@ class File(UUIDPrimaryKeyBase, TimeStampedModel): original_file_name = models.TextField(max_length=2048, blank=True, null=True) # delete me last_referenced = models.DateTimeField(blank=True, null=True) ingest_error = models.TextField( - max_length=2048, blank=True, null=True, help_text="error, if any, encountered during ingest" + max_length=2048, + blank=True, + null=True, + help_text="error, if any, encountered during ingest", ) def __str__(self) -> str: # pragma: no cover @@ -673,7 +693,10 @@ def get_ordered_by_citation_priority(cls, chat_message_id: uuid.UUID) -> Sequenc .annotate(min_created_at=Min("citation__created_at")) .order_by("min_created_at") .prefetch_related( - Prefetch("citation_set", queryset=Citation.objects.filter(chat_message_id=chat_message_id)) + Prefetch( + "citation_set", + queryset=Citation.objects.filter(chat_message_id=chat_message_id), + ) ) ) @@ -685,7 +708,9 @@ class Chat(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings): # Exit feedback - this is separate to the ratings for individual ChatMessages feedback_achieved = models.BooleanField( - null=True, blank=True, help_text="Did Redbox do what you needed it to in this chat?" + null=True, + blank=True, + help_text="Did Redbox do what you needed it to in this chat?", ) feedback_saved_time = models.BooleanField(null=True, blank=True, help_text="Did Redbox help save you time?") feedback_improved_work = models.BooleanField( @@ -743,16 +768,26 @@ class Origin(models.TextChoices): GOV_UK = "GOV.UK", _("gov.uk") file = models.ForeignKey( - File, on_delete=models.CASCADE, null=True, blank=True, help_text="file for internal citation" + File, + on_delete=models.CASCADE, + null=True, + blank=True, + help_text="file for internal citation", ) url = models.URLField(null=True, blank=True, help_text="url for external") chat_message = models.ForeignKey("ChatMessage", on_delete=models.CASCADE) text = models.TextField(null=True, blank=True) page_numbers = ArrayField( - models.PositiveIntegerField(), null=True, blank=True, help_text="location of citation in document" + models.PositiveIntegerField(), + null=True, + blank=True, + help_text="location of citation in document", ) source = models.CharField( - max_length=32, choices=Origin, help_text="source of citation", default=Origin.USER_UPLOADED_DOCUMENT + max_length=32, + choices=Origin, + help_text="source of citation", + default=Origin.USER_UPLOADED_DOCUMENT, ) text_in_answer = models.TextField(null=True, blank=True) @@ -797,7 +832,9 @@ class ChatMessage(UUIDPrimaryKeyBase, TimeStampedModel): source_files = models.ManyToManyField(File, through=Citation) rating = models.PositiveIntegerField( - blank=True, null=True, validators=[validators.MinValueValidator(1), validators.MaxValueValidator(5)] + blank=True, + null=True, + validators=[validators.MinValueValidator(1), validators.MaxValueValidator(5)], ) rating_text = models.TextField(blank=True, null=True) rating_chips = ArrayField(models.CharField(max_length=32), null=True, blank=True) @@ -835,7 +872,10 @@ class UseTypeEnum(models.TextChoices): chat_message = models.ForeignKey(ChatMessage, on_delete=models.CASCADE) use_type = models.CharField( - max_length=10, choices=UseTypeEnum, help_text="input or output tokens", default=UseTypeEnum.INPUT + max_length=10, + choices=UseTypeEnum, + help_text="input or output tokens", + default=UseTypeEnum.INPUT, ) model_name = models.CharField(max_length=50, null=True, blank=True) token_count = models.PositiveIntegerField(null=True, blank=True) diff --git a/docker-compose.yml b/docker-compose.yml index 9ff3c74a9..291792538 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,8 @@ services: condition: service_healthy minio: condition: service_healthy + elasticsearch: + condition: service_healthy networks: - redbox-app-network env_file: diff --git a/redbox-core/redbox/app.py b/redbox-core/redbox/app.py index d030108ba..199d3e7db 100644 --- a/redbox-core/redbox/app.py +++ b/redbox-core/redbox/app.py @@ -8,10 +8,7 @@ get_metadata_retriever, get_parameterised_retriever, ) -from redbox.graph.nodes.tools import ( - build_search_documents_tool, - build_search_wikipedia_tool, -) +from redbox.graph.nodes.tools import build_govuk_search_tool, build_search_documents_tool, build_search_wikipedia_tool from redbox.graph.root import get_root_graph from redbox.models.chain import RedboxState from redbox.models.chat import ChatRoute @@ -60,9 +57,11 @@ def __init__( chunk_resolution=ChunkResolution.normal, ) search_wikipedia = build_search_wikipedia_tool() + search_govuk = build_govuk_search_tool() tools: dict[str, StructuredTool] = { "_search_documents": search_documents, + "_search_govuk": search_govuk, "_search_wikipedia": search_wikipedia, } diff --git a/redbox-core/redbox/graph/nodes/tools.py b/redbox-core/redbox/graph/nodes/tools.py index 2eb7f3c7b..883c0d690 100644 --- a/redbox-core/redbox/graph/nodes/tools.py +++ b/redbox-core/redbox/graph/nodes/tools.py @@ -1,5 +1,6 @@ from typing import Annotated, Any, get_args, get_origin, get_type_hints +import requests import tiktoken from elasticsearch import Elasticsearch from langchain_community.utilities import WikipediaAPIWrapper @@ -9,7 +10,7 @@ from langgraph.prebuilt import InjectedState from redbox.models.chain import RedboxState -from redbox.models.file import ChunkMetadata, ChunkResolution +from redbox.models.file import ChunkCreatorType, ChunkMetadata, ChunkResolution from redbox.retriever.queries import ( add_document_filter_scores_to_query, build_document_query, @@ -133,6 +134,71 @@ def _search_documents(query: str, state: Annotated[RedboxState, InjectedState]) return _search_documents +def build_govuk_search_tool(num_results: int = 1) -> Tool: + """Constructs a tool that searches gov.uk and sets state["documents"].""" + + tokeniser = tiktoken.encoding_for_model("gpt-4o") + + @tool + def _search_govuk(query: str, state: Annotated[dict, InjectedState]) -> dict[str, Any]: + """ + Search for documents on gov.uk based on a query string. + This endpoint is used to search for documents on gov.uk. There are many types of documents on gov.uk. + Types include: + - guidance + - policy + - legislation + - news + - travel advice + - departmental reports + - statistics + - consultations + - appeals + """ + + url_base = "https://www.gov.uk" + required_fields = [ + "format", + "title", + "description", + "indexable_content", + "link", + ] + + response = requests.get( + f"{url_base}/api/search.json", + params={ + "q": query, + "count": num_results, + "fields": required_fields, + }, + headers={"Accept": "application/json"}, + ) + response.raise_for_status() + response = response.json() + + mapped_documents = [] + for i, doc in enumerate(response["results"]): + if any(field not in doc for field in required_fields): + continue + + mapped_documents.append( + Document( + page_content=doc["indexable_content"], + metadata=ChunkMetadata( + index=i, + uri=f"{url_base}{doc['link']}", + token_count=len(tokeniser.encode(doc["indexable_content"])), + creator_type=ChunkCreatorType.gov_uk, + ).model_dump(), + ) + ) + + return {"documents": structure_documents_by_group_and_indices(mapped_documents)} + + return _search_govuk + + def build_search_wikipedia_tool(number_wikipedia_results=1, max_chars_per_wiki_page=12000) -> Tool: """Constructs a tool that searches Wikipedia""" _wikipedia_wrapper = WikipediaAPIWrapper( @@ -163,7 +229,7 @@ def _search_wikipedia(query: str, state: Annotated[RedboxState, InjectedState]) index=i, uri=doc.metadata["source"], token_count=len(tokeniser.encode(doc.page_content)), - creator_type="Wikipedia", + creator_type=ChunkCreatorType.wikipedia, ).model_dump(), ) for i, doc in enumerate(response) diff --git a/redbox-core/redbox/graph/root.py b/redbox-core/redbox/graph/root.py index 6794178a6..c59d17868 100644 --- a/redbox-core/redbox/graph/root.py +++ b/redbox-core/redbox/graph/root.py @@ -3,7 +3,6 @@ from langgraph.graph import END, START, StateGraph from langgraph.graph.graph import CompiledGraph - from redbox.chains.components import get_structured_response_with_citations_parser from redbox.chains.runnables import build_self_route_output_parser from redbox.graph.edges import ( @@ -32,18 +31,11 @@ empty_process, report_sources_process, ) -from redbox.graph.nodes.sends import ( - build_document_chunk_send, - build_document_group_send, - build_tool_send, -) +from redbox.graph.nodes.sends import build_document_chunk_send, build_document_group_send, build_tool_send from redbox.models.chain import RedboxState from redbox.models.chat import ChatRoute, ErrorRoute from redbox.models.graph import ROUTABLE_KEYWORDS, RedboxActivityEvent -from redbox.transform import ( - structure_documents_by_file_name, - structure_documents_by_group_and_indices, -) +from redbox.transform import structure_documents_by_file_name, structure_documents_by_group_and_indices # Subgraphs @@ -167,7 +159,7 @@ def get_agentic_search_graph(tools: dict[str, StructuredTool], debug: bool = Fal citations_output_parser, format_instructions = get_structured_response_with_citations_parser() builder = StateGraph(RedboxState) # Tools - agent_tool_names = ["_search_documents", "_search_wikipedia"] + agent_tool_names = ["_search_documents", "_search_wikipedia", "_search_govuk"] agent_tools: list[StructuredTool] = tuple([tools.get(tool_name) for tool_name in agent_tool_names]) # Processes diff --git a/redbox-core/redbox/models/file.py b/redbox-core/redbox/models/file.py index 90705cb7a..14ed2f7f6 100644 --- a/redbox-core/redbox/models/file.py +++ b/redbox-core/redbox/models/file.py @@ -4,8 +4,7 @@ from enum import StrEnum from uuid import UUID, uuid4 -from typing import Literal -from pydantic import BaseModel, Field, AliasChoices +from pydantic import AliasChoices, BaseModel, Field class ChunkResolution(StrEnum): @@ -16,6 +15,12 @@ class ChunkResolution(StrEnum): largest = "largest" +class ChunkCreatorType(StrEnum): + wikipedia = "Wikipedia" + user_uploaded_document = "UserUploadedDocument" + gov_uk = "GOV.UK" + + class ChunkMetadata(BaseModel): """ Worker model for document metadata for new style chunks. @@ -26,7 +31,7 @@ class ChunkMetadata(BaseModel): index: int = 0 # The order of this chunk in the original resource created_datetime: datetime.datetime = datetime.datetime.now(datetime.UTC) chunk_resolution: ChunkResolution = ChunkResolution.normal - creator_type: Literal["Wikipedia", "UserUploadedDocument"] + creator_type: ChunkCreatorType uri: str = Field(validation_alias=AliasChoices("uri", "file_name")) # URL or file name token_count: int @@ -40,4 +45,4 @@ class UploadedFileMetadata(ChunkMetadata): name: str description: str keywords: list[str] - creator_type: Literal["UserUploadedDocument"] = "UserUploadedDocument" + creator_type: ChunkCreatorType = ChunkCreatorType.user_uploaded_document diff --git a/redbox-core/tests/graph/test_app.py b/redbox-core/tests/graph/test_app.py index d4d68c4fa..dcd295266 100644 --- a/redbox-core/tests/graph/test_app.py +++ b/redbox-core/tests/graph/test_app.py @@ -13,11 +13,11 @@ from redbox.models.chain import ( AISettings, Citation, - StructuredResponseWithCitations, RedboxQuery, RedboxState, RequestMetadata, Source, + StructuredResponseWithCitations, metadata_reducer, ) from redbox.models.chat import ChatRoute, ErrorRoute @@ -462,6 +462,43 @@ def assert_number_of_events(num_of_events: int): ], test_id="No Such Keyword with docs", ), + generate_test_cases( + query=RedboxQuery( + question="@gadget Tell me about travel advice to cuba", + s3_keys=[], + user_uuid=uuid4(), + chat_history=[], + permitted_s3_keys=["s3_key"], + ), + test_data=[ + RedboxTestData( + number_of_docs=1, + tokens_in_all_docs=10000, + llm_responses=[ + AIMessage( + content="", + additional_kwargs={ + "tool_calls": [ + { + "id": "call_e4003b", + "function": { + "arguments": '{\n "query": "travel advice to cuba"\n}', + "name": "_search_govuk", + }, + "type": "function", + } + ] + }, + ), + "answer", + StructuredResponseWithCitations(answer="AI is a lie", citations=[]).model_dump_json(), + ], + expected_text="AI is a lie", + expected_route=ChatRoute.gadget, + ), + ], + test_id="Agentic govuk search", + ), ] for test_case in generated_cases ] @@ -481,7 +518,13 @@ def _search_documents(query: str) -> dict[str, Any]: """Tool to search documents.""" return {"documents": structure_documents_by_group_and_indices(test_case.docs)} + @tool + def _search_govuk(query: str) -> dict[str, Any]: + """Tool to search gov.uk for travel advice and other government information.""" + return {"documents": structure_documents_by_group_and_indices(test_case.docs)} + mocker.patch("redbox.app.build_search_documents_tool", return_value=_search_documents) + mocker.patch("redbox.app.build_govuk_search_tool", return_value=_search_govuk) mocker.patch("redbox.graph.nodes.processes.get_chat_llm", return_value=llm) # Instantiate app diff --git a/redbox-core/tests/test_tools.py b/redbox-core/tests/test_tools.py index f146aceca..d6a468d84 100644 --- a/redbox-core/tests/test_tools.py +++ b/redbox-core/tests/test_tools.py @@ -1,6 +1,6 @@ from typing import Annotated, Any -from uuid import UUID, uuid4 from urllib.parse import urlparse +from uuid import UUID, uuid4 import pytest from elasticsearch import Elasticsearch @@ -9,6 +9,7 @@ from langgraph.prebuilt import InjectedState from redbox.graph.nodes.tools import ( + build_govuk_search_tool, build_search_documents_tool, build_search_wikipedia_tool, has_injected_state, @@ -16,7 +17,7 @@ ) from redbox.models import Settings from redbox.models.chain import AISettings, RedboxQuery, RedboxState -from redbox.models.file import ChunkMetadata, ChunkResolution +from redbox.models.file import ChunkCreatorType, ChunkMetadata, ChunkResolution from redbox.test.data import RedboxChatTestCase from redbox.transform import flatten_document_state from tests.retriever.test_retriever import TEST_CHAIN_PARAMETERS @@ -153,6 +154,37 @@ def test_search_documents_tool( assert group_docs[doc.metadata["uuid"]] == doc +def test_govuk_search_tool(): + tool = build_govuk_search_tool() + + state_update = tool.invoke( + { + "query": "Cuba Travel Advice", + "state": RedboxState( + request=RedboxQuery( + question="Search gov.uk for travel advice to cuba", + s3_keys=[], + user_uuid=uuid4(), + chat_history=[], + ai_settings=AISettings(), + permitted_s3_keys=[], + ) + ), + } + ) + + documents = flatten_document_state(state_update["documents"]) + + # assert at least one document is travel advice + assert any("/foreign-travel-advice/cuba" in document.metadata["uri"] for document in documents) + + for document in documents: + assert document.page_content != "" + metadata = ChunkMetadata.model_validate(document.metadata) + assert urlparse(metadata.uri).hostname == "www.gov.uk" + assert metadata.creator_type == ChunkCreatorType.gov_uk + + def test_wikipedia_tool(): tool = build_search_wikipedia_tool() state_update = tool.invoke( @@ -170,8 +202,9 @@ def test_wikipedia_tool(): ), } ) + for document in flatten_document_state(state_update["documents"]): assert document.page_content != "" metadata = ChunkMetadata.model_validate(document.metadata) assert urlparse(metadata.uri).hostname == "en.wikipedia.org" - assert metadata.creator_type == "Wikipedia" + assert metadata.creator_type == ChunkCreatorType.wikipedia