Skip to content

Commit

Permalink
Feature/add govuk search (#1117)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDunne authored Oct 29, 2024
1 parent 6391639 commit ebb45e4
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 45 deletions.
7 changes: 6 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,10 @@
"python.analysis.autoImportCompletions": true,
"python.experiments.optOutFrom": [
"pythonTerminalEnvVarActivation"
]
],
"python.testing.pytestArgs": [
"redbox-core"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
1 change: 0 additions & 1 deletion django_app/.vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.testing.pytestPath": "venv/bin/python -m pytest",

}
6 changes: 2 additions & 4 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@
File,
StatusEnum,
)
from redbox_app.redbox_core.models import (
AISettings as AISettingsModel,
)
from redbox_app.redbox_core.models import AISettings as AISettingsModel

User = get_user_model()
OptFileSeq = Sequence[File] | None
Expand Down Expand Up @@ -218,7 +216,7 @@ def save_ai_message(
url=citation_source.source,
text=citation_source.highlighted_text_in_source,
page_numbers=citation_source.page_numbers,
source=Citation.Origin(citation_source.source_type.title()),
source=Citation.Origin(citation_source.source_type),
)

if self.metadata:
Expand Down
68 changes: 54 additions & 14 deletions django_app/redbox_app/redbox_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ class Providers(models.TextChoices):
GROQ = "groq"
OLLAMA = "ollama"

name = models.CharField(max_length=128, help_text="The name of the model, e.g. “gpt-4o”, “claude-3-opus-20240229”.")
name = models.CharField(
max_length=128,
help_text="The name of the model, e.g. “gpt-4o”, “claude-3-opus-20240229”.",
)
provider = models.CharField(max_length=128, choices=Providers, help_text="The model provider")
description = models.TextField(null=True, blank=True, help_text="brief description of the model")
is_default = models.BooleanField(default=False, help_text="is this the default llm to use.")
Expand Down Expand Up @@ -139,13 +142,22 @@ class AISettings(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings):
rag_num_candidates = models.PositiveIntegerField(default=10)
rag_gauss_scale_size = models.PositiveIntegerField(default=3)
rag_gauss_scale_decay = models.DecimalField(
max_digits=5, decimal_places=2, default=0.5, validators=[validators.MinValueValidator(0.0)]
max_digits=5,
decimal_places=2,
default=0.5,
validators=[validators.MinValueValidator(0.0)],
)
rag_gauss_scale_min = models.DecimalField(
max_digits=5, decimal_places=2, default=1.1, validators=[validators.MinValueValidator(1.0)]
max_digits=5,
decimal_places=2,
default=1.1,
validators=[validators.MinValueValidator(1.0)],
)
rag_gauss_scale_max = models.DecimalField(
max_digits=5, decimal_places=2, default=2.0, validators=[validators.MinValueValidator(1.0)]
max_digits=5,
decimal_places=2,
default=2.0,
validators=[validators.MinValueValidator(1.0)],
)
rag_desired_chunk_size = models.PositiveIntegerField(default=300)
elbow_filter_enabled = models.BooleanField(default=False)
Expand All @@ -158,7 +170,10 @@ class AISettings(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings):
max_digits=5,
decimal_places=2,
default=0.7,
validators=[validators.MinValueValidator(0.0), validators.MaxValueValidator(1.0)],
validators=[
validators.MinValueValidator(0.0),
validators.MaxValueValidator(1.0),
],
)

def __str__(self) -> str:
Expand Down Expand Up @@ -430,7 +445,9 @@ class DurationTasks(models.TextChoices):
profession = models.CharField(null=True, blank=True, max_length=4, choices=Profession)
info_about_user = models.CharField(null=True, blank=True, help_text="user entered info from profile overlay")
redbox_response_preferences = models.CharField(
null=True, blank=True, help_text="user entered info from profile overlay, to be used in custom prompt"
null=True,
blank=True,
help_text="user entered info from profile overlay, to be used in custom prompt",
)
ai_settings = models.ForeignKey(AISettings, on_delete=models.SET_DEFAULT, default="default", to_field="label")
is_developer = models.BooleanField(null=True, blank=True, default=False, help_text="is this user a developer?")
Expand Down Expand Up @@ -555,7 +572,10 @@ class File(UUIDPrimaryKeyBase, TimeStampedModel):
original_file_name = models.TextField(max_length=2048, blank=True, null=True) # delete me
last_referenced = models.DateTimeField(blank=True, null=True)
ingest_error = models.TextField(
max_length=2048, blank=True, null=True, help_text="error, if any, encountered during ingest"
max_length=2048,
blank=True,
null=True,
help_text="error, if any, encountered during ingest",
)

def __str__(self) -> str: # pragma: no cover
Expand Down Expand Up @@ -673,7 +693,10 @@ def get_ordered_by_citation_priority(cls, chat_message_id: uuid.UUID) -> Sequenc
.annotate(min_created_at=Min("citation__created_at"))
.order_by("min_created_at")
.prefetch_related(
Prefetch("citation_set", queryset=Citation.objects.filter(chat_message_id=chat_message_id))
Prefetch(
"citation_set",
queryset=Citation.objects.filter(chat_message_id=chat_message_id),
)
)
)

Expand All @@ -685,7 +708,9 @@ class Chat(UUIDPrimaryKeyBase, TimeStampedModel, AbstractAISettings):

# Exit feedback - this is separate to the ratings for individual ChatMessages
feedback_achieved = models.BooleanField(
null=True, blank=True, help_text="Did Redbox do what you needed it to in this chat?"
null=True,
blank=True,
help_text="Did Redbox do what you needed it to in this chat?",
)
feedback_saved_time = models.BooleanField(null=True, blank=True, help_text="Did Redbox help save you time?")
feedback_improved_work = models.BooleanField(
Expand Down Expand Up @@ -743,16 +768,26 @@ class Origin(models.TextChoices):
GOV_UK = "GOV.UK", _("gov.uk")

file = models.ForeignKey(
File, on_delete=models.CASCADE, null=True, blank=True, help_text="file for internal citation"
File,
on_delete=models.CASCADE,
null=True,
blank=True,
help_text="file for internal citation",
)
url = models.URLField(null=True, blank=True, help_text="url for external")
chat_message = models.ForeignKey("ChatMessage", on_delete=models.CASCADE)
text = models.TextField(null=True, blank=True)
page_numbers = ArrayField(
models.PositiveIntegerField(), null=True, blank=True, help_text="location of citation in document"
models.PositiveIntegerField(),
null=True,
blank=True,
help_text="location of citation in document",
)
source = models.CharField(
max_length=32, choices=Origin, help_text="source of citation", default=Origin.USER_UPLOADED_DOCUMENT
max_length=32,
choices=Origin,
help_text="source of citation",
default=Origin.USER_UPLOADED_DOCUMENT,
)
text_in_answer = models.TextField(null=True, blank=True)

Expand Down Expand Up @@ -797,7 +832,9 @@ class ChatMessage(UUIDPrimaryKeyBase, TimeStampedModel):
source_files = models.ManyToManyField(File, through=Citation)

rating = models.PositiveIntegerField(
blank=True, null=True, validators=[validators.MinValueValidator(1), validators.MaxValueValidator(5)]
blank=True,
null=True,
validators=[validators.MinValueValidator(1), validators.MaxValueValidator(5)],
)
rating_text = models.TextField(blank=True, null=True)
rating_chips = ArrayField(models.CharField(max_length=32), null=True, blank=True)
Expand Down Expand Up @@ -835,7 +872,10 @@ class UseTypeEnum(models.TextChoices):

chat_message = models.ForeignKey(ChatMessage, on_delete=models.CASCADE)
use_type = models.CharField(
max_length=10, choices=UseTypeEnum, help_text="input or output tokens", default=UseTypeEnum.INPUT
max_length=10,
choices=UseTypeEnum,
help_text="input or output tokens",
default=UseTypeEnum.INPUT,
)
model_name = models.CharField(max_length=50, null=True, blank=True)
token_count = models.PositiveIntegerField(null=True, blank=True)
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ services:
condition: service_healthy
minio:
condition: service_healthy
elasticsearch:
condition: service_healthy
networks:
- redbox-app-network
env_file:
Expand Down
7 changes: 3 additions & 4 deletions redbox-core/redbox/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
get_metadata_retriever,
get_parameterised_retriever,
)
from redbox.graph.nodes.tools import (
build_search_documents_tool,
build_search_wikipedia_tool,
)
from redbox.graph.nodes.tools import build_govuk_search_tool, build_search_documents_tool, build_search_wikipedia_tool
from redbox.graph.root import get_root_graph
from redbox.models.chain import RedboxState
from redbox.models.chat import ChatRoute
Expand Down Expand Up @@ -60,9 +57,11 @@ def __init__(
chunk_resolution=ChunkResolution.normal,
)
search_wikipedia = build_search_wikipedia_tool()
search_govuk = build_govuk_search_tool()

tools: dict[str, StructuredTool] = {
"_search_documents": search_documents,
"_search_govuk": search_govuk,
"_search_wikipedia": search_wikipedia,
}

Expand Down
70 changes: 68 additions & 2 deletions redbox-core/redbox/graph/nodes/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Annotated, Any, get_args, get_origin, get_type_hints

import requests
import tiktoken
from elasticsearch import Elasticsearch
from langchain_community.utilities import WikipediaAPIWrapper
Expand All @@ -9,7 +10,7 @@
from langgraph.prebuilt import InjectedState

from redbox.models.chain import RedboxState
from redbox.models.file import ChunkMetadata, ChunkResolution
from redbox.models.file import ChunkCreatorType, ChunkMetadata, ChunkResolution
from redbox.retriever.queries import (
add_document_filter_scores_to_query,
build_document_query,
Expand Down Expand Up @@ -133,6 +134,71 @@ def _search_documents(query: str, state: Annotated[RedboxState, InjectedState])
return _search_documents


def build_govuk_search_tool(num_results: int = 1) -> Tool:
    """Constructs a tool that searches gov.uk and sets state["documents"].

    Args:
        num_results: maximum number of results to request from the gov.uk
            Search API per query.

    Returns:
        A LangChain ``Tool`` that queries the gov.uk Search API and returns
        the matching documents grouped for the Redbox state.
    """

    # Token counts are recorded per document; gpt-4o's encoding is used as a
    # reasonable common tokeniser (mirrors build_search_wikipedia_tool).
    tokeniser = tiktoken.encoding_for_model("gpt-4o")

    @tool
    def _search_govuk(query: str, state: Annotated[dict, InjectedState]) -> dict[str, Any]:
        """
        Search for documents on gov.uk based on a query string.
        This endpoint is used to search for documents on gov.uk. There are many types of documents on gov.uk.
        Types include:
        - guidance
        - policy
        - legislation
        - news
        - travel advice
        - departmental reports
        - statistics
        - consultations
        - appeals
        """

        url_base = "https://www.gov.uk"
        # A result must carry every one of these fields to be mapped into a
        # Document; anything partial is dropped rather than failing the search.
        required_fields = [
            "format",
            "title",
            "description",
            "indexable_content",
            "link",
        ]

        response = requests.get(
            f"{url_base}/api/search.json",
            params={
                "q": query,
                "count": num_results,
                "fields": required_fields,
            },
            headers={"Accept": "application/json"},
            # Without a timeout a stalled upstream would hang the graph run
            # indefinitely; fail fast and let raise_for_status surface errors.
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()  # keep the Response object unshadowed

        mapped_documents = []
        for i, doc in enumerate(payload["results"]):
            if any(field not in doc for field in required_fields):
                continue

            mapped_documents.append(
                Document(
                    page_content=doc["indexable_content"],
                    metadata=ChunkMetadata(
                        # NOTE(review): index keeps the API's result position even
                        # when earlier results are skipped, so indices may have gaps.
                        index=i,
                        uri=f"{url_base}{doc['link']}",
                        token_count=len(tokeniser.encode(doc["indexable_content"])),
                        creator_type=ChunkCreatorType.gov_uk,
                    ).model_dump(),
                )
            )

        return {"documents": structure_documents_by_group_and_indices(mapped_documents)}

    return _search_govuk


def build_search_wikipedia_tool(number_wikipedia_results=1, max_chars_per_wiki_page=12000) -> Tool:
"""Constructs a tool that searches Wikipedia"""
_wikipedia_wrapper = WikipediaAPIWrapper(
Expand Down Expand Up @@ -163,7 +229,7 @@ def _search_wikipedia(query: str, state: Annotated[RedboxState, InjectedState])
index=i,
uri=doc.metadata["source"],
token_count=len(tokeniser.encode(doc.page_content)),
creator_type="Wikipedia",
creator_type=ChunkCreatorType.wikipedia,
).model_dump(),
)
for i, doc in enumerate(response)
Expand Down
14 changes: 3 additions & 11 deletions redbox-core/redbox/graph/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from langgraph.graph import END, START, StateGraph
from langgraph.graph.graph import CompiledGraph


from redbox.chains.components import get_structured_response_with_citations_parser
from redbox.chains.runnables import build_self_route_output_parser
from redbox.graph.edges import (
Expand Down Expand Up @@ -32,18 +31,11 @@
empty_process,
report_sources_process,
)
from redbox.graph.nodes.sends import (
build_document_chunk_send,
build_document_group_send,
build_tool_send,
)
from redbox.graph.nodes.sends import build_document_chunk_send, build_document_group_send, build_tool_send
from redbox.models.chain import RedboxState
from redbox.models.chat import ChatRoute, ErrorRoute
from redbox.models.graph import ROUTABLE_KEYWORDS, RedboxActivityEvent
from redbox.transform import (
structure_documents_by_file_name,
structure_documents_by_group_and_indices,
)
from redbox.transform import structure_documents_by_file_name, structure_documents_by_group_and_indices

# Subgraphs

Expand Down Expand Up @@ -167,7 +159,7 @@ def get_agentic_search_graph(tools: dict[str, StructuredTool], debug: bool = Fal
citations_output_parser, format_instructions = get_structured_response_with_citations_parser()
builder = StateGraph(RedboxState)
# Tools
agent_tool_names = ["_search_documents", "_search_wikipedia"]
agent_tool_names = ["_search_documents", "_search_wikipedia", "_search_govuk"]
agent_tools: list[StructuredTool] = tuple([tools.get(tool_name) for tool_name in agent_tool_names])

# Processes
Expand Down
13 changes: 9 additions & 4 deletions redbox-core/redbox/models/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from enum import StrEnum
from uuid import UUID, uuid4

from typing import Literal
from pydantic import BaseModel, Field, AliasChoices
from pydantic import AliasChoices, BaseModel, Field


class ChunkResolution(StrEnum):
Expand All @@ -16,6 +15,12 @@ class ChunkResolution(StrEnum):
largest = "largest"


class ChunkCreatorType(StrEnum):
    """Origin of a chunk: which search tool or upload path produced it.

    Values are the serialized names stored in chunk metadata (set by the
    search tools and as the default for uploaded-file metadata).
    """

    wikipedia = "Wikipedia"
    user_uploaded_document = "UserUploadedDocument"
    gov_uk = "GOV.UK"


class ChunkMetadata(BaseModel):
"""
Worker model for document metadata for new style chunks.
Expand All @@ -26,7 +31,7 @@ class ChunkMetadata(BaseModel):
index: int = 0 # The order of this chunk in the original resource
created_datetime: datetime.datetime = datetime.datetime.now(datetime.UTC)
chunk_resolution: ChunkResolution = ChunkResolution.normal
creator_type: Literal["Wikipedia", "UserUploadedDocument"]
creator_type: ChunkCreatorType
uri: str = Field(validation_alias=AliasChoices("uri", "file_name")) # URL or file name
token_count: int

Expand All @@ -40,4 +45,4 @@ class UploadedFileMetadata(ChunkMetadata):
name: str
description: str
keywords: list[str]
creator_type: Literal["UserUploadedDocument"] = "UserUploadedDocument"
creator_type: ChunkCreatorType = ChunkCreatorType.user_uploaded_document
Loading

0 comments on commit ebb45e4

Please sign in to comment.