feature/use functions not operators (#1171)
gecBurton authored Nov 7, 2024
1 parent 2fd0dac commit 55c9629
Showing 4 changed files with 69 additions and 88 deletions.
82 changes: 15 additions & 67 deletions redbox-core/redbox/chains/runnables.py
@@ -1,6 +1,5 @@
 import logging
 import re
-from operator import attrgetter, itemgetter
 from typing import Any, Callable, Iterable, Iterator

 from langchain_core.callbacks.manager import (
@@ -29,37 +28,14 @@
 from redbox.models.graph import RedboxEventType
 from redbox.transform import (
     flatten_document_state,
-    to_request_metadata,
     tool_calls_to_toolstate,
+    get_all_metadata,
 )

 log = logging.getLogger()
 re_string_pattern = re.compile(r"(\S+)")


-def combine_getters(*getters: Callable[[Any], Any]) -> Callable[[Any], Any]:
-    """Permits chaining of *getter functions in LangChain."""
-
-    def _combined(obj):
-        for getter in getters:
-            obj = getter(obj)
-        return obj
-
-    return _combined
-
-
-def itemgetter_with_default(field: str, default_getter: Callable[[Any], Any]):
-    getter = itemgetter(field)
-
-    def _impl(obj):
-        try:
-            return getter(obj)
-        except Exception:
-            return default_getter(obj)
-
-    return _impl
-
-
 def build_chat_prompt_from_messages_runnable(
     prompt_set: PromptSet,
     tokeniser: Encoding = None,
@@ -131,50 +107,22 @@ def build_llm_chain(
     _llm = llm.with_config(tags=["response_flag"]) if final_response_chain else llm
     _output_parser = output_parser if output_parser else StrOutputParser()

+    _llm_text_and_tools = _llm | {
+        "raw_response": RunnablePassthrough(),
+        "parsed_response": _output_parser,
+        "tool_calls": tool_calls_to_toolstate,
+    }
+
+    text_and_tools = {
+        "text_and_tools": _llm_text_and_tools,
+        "prompt": RunnableLambda(lambda prompt: prompt.to_string()),
+        "model": lambda _: model_name,
+    }
+
     return (
         build_chat_prompt_from_messages_runnable(prompt_set, partial_variables={"format_arg": format_instructions})
-        | {
-            "text_and_tools": (
-                _llm
-                | {
-                    "raw_response": RunnablePassthrough(),
-                    "parsed_response": _output_parser,
-                    "tool_calls": (RunnableLambda(lambda r: r.tool_calls) | tool_calls_to_toolstate),
-                }
-            ),
-            "prompt": RunnableLambda(lambda prompt: prompt.to_string()),
-        }
-        | {
-            "text": RunnableLambda(
-                combine_getters(
-                    itemgetter("text_and_tools"),
-                    itemgetter_with_default(
-                        "parsed_response", combine_getters(itemgetter("raw_response"), attrgetter("content"))
-                    ),
-                )
-            )
-            | (lambda r: r if isinstance(r, str) else r.answer),
-            "tool_calls": combine_getters(itemgetter("text_and_tools"), itemgetter("tool_calls")),
-            "citations": RunnableLambda(
-                combine_getters(
-                    itemgetter("text_and_tools"),
-                    itemgetter_with_default(
-                        "parsed_response", combine_getters(itemgetter("raw_response"), attrgetter("content"))
-                    ),
-                )
-            )
-            | (lambda r: [] if isinstance(r, str) else r.citations),
-            "metadata": (
-                {
-                    "prompt": itemgetter("prompt"),
-                    "response": combine_getters(
-                        itemgetter("text_and_tools"), itemgetter("raw_response"), attrgetter("content")
-                    ),
-                    "model": lambda _: model_name,
-                }
-                | to_request_metadata
-            ),
-        }
+        | text_and_tools
+        | get_all_metadata
         | RunnablePassthrough.assign(
             _log=RunnableLambda(
                 lambda _: log_activity(f"Generating response with {model_name}...") if final_response_chain else None
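For context on the pattern this commit moves to: instead of wiring anonymous itemgetter/attrgetter compositions into the chain, the traversal now lives in a named function that LangChain wraps via RunnableLambda. A minimal sketch of the idea, assuming a payload shaped like the text_and_tools dict above (illustrative only, not Redbox's actual types):

    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableLambda

    # Before, roughly: combine_getters(itemgetter("text_and_tools"),
    #                                  itemgetter("raw_response"), attrgetter("content"))
    # After: one named, debuggable function decorated into a Runnable.
    @RunnableLambda
    def get_response_text(obj: dict) -> str:
        # Walk the nested payload explicitly instead of chaining getters.
        return obj["text_and_tools"]["raw_response"].content

    print(get_response_text.invoke({"text_and_tools": {"raw_response": AIMessage(content="hello")}}))
    # -> hello

A named function shows up in stack traces and can be unit-tested directly, which is what the commit title is getting at.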
38 changes: 30 additions & 8 deletions redbox-core/redbox/transform.py
@@ -4,7 +4,7 @@
 import tiktoken
 from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.documents import Document
-from langchain_core.messages import ToolCall
+from langchain_core.messages import ToolCall, AnyMessage
 from langchain_core.runnables import RunnableLambda

 from redbox.models.chain import (
@@ -130,21 +130,23 @@ def get_document_token_count(state: RedboxState) -> int:
     return sum(d.metadata["token_count"] for d in flatten_document_state(state.get("documents", [])))


-@RunnableLambda
-def to_request_metadata(prompt_response_model: dict):
+def to_request_metadata(obj: dict):
     """Takes a dictionary with keys 'prompt', 'response' and 'model' and creates metadata.
     Will also emit events for metadata updates.
     """
-    model = prompt_response_model["model"]
+    prompt = obj["prompt"]
+    response = obj["text_and_tools"]["raw_response"].content
+    model = obj["model"]

     try:
         tokeniser = tiktoken.encoding_for_model(model)
     except KeyError:
         tokeniser = tiktoken.get_encoding("cl100k_base")

-    input_tokens = len(tokeniser.encode(prompt_response_model["prompt"]))
-    output_tokens = len(tokeniser.encode(prompt_response_model["response"]))
+    input_tokens = len(tokeniser.encode(prompt))
+    output_tokens = len(tokeniser.encode(response))

     metadata_event = RequestMetadata(
         llm_calls=[LLMCallMetadata(llm_model_name=model, input_tokens=input_tokens, output_tokens=output_tokens)]
@@ -155,6 +157,26 @@ def to_request_metadata(prompt_response_model: dict):
     return metadata_event


+@RunnableLambda
+def get_all_metadata(obj: dict):
+    text_and_tools = obj["text_and_tools"]
+
+    if parsed_response := text_and_tools.get("parsed_response"):
+        text = getattr(parsed_response, "answer", parsed_response)
+        citations = getattr(parsed_response, "citations", [])
+    else:
+        text = text_and_tools["raw_response"].content
+        citations = []
+
+    out = {
+        "tool_calls": text_and_tools["tool_calls"],
+        "metadata": to_request_metadata(obj),
+        "text": text,
+        "citations": citations,
+    }
+    return out
+
+
 def merge_documents(initial: list[Document], adjacent: list[Document]) -> list[Document]:
     """Merges a list of adjacent documents with an initial list.
@@ -249,9 +271,9 @@ def process_group(group: list[Document]) -> list[list[Document]]:
     return list(itertools.chain.from_iterable(all_sorted_blocks_by_max_score))


-def tool_calls_to_toolstate(tool_calls: list[ToolCall], called: bool | None = False) -> ToolState:
+def tool_calls_to_toolstate(message: AnyMessage, called: bool | None = False) -> ToolState:
     """Takes a list of tool calls and shapes them into a valid ToolState.
     Sets all tool calls to a called state. Assumes this state is False.
     """
-    return {t["id"]: {"tool": ToolCall(**t), "called": called} for t in tool_calls}
+    return ToolState({t["id"]: {"tool": ToolCall(**t), "called": called} for t in message.tool_calls})
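A usage sketch for the new tool_calls_to_toolstate signature, which now takes the message itself and reads .tool_calls off it. The tool name, args, and id below are placeholders:

    from langchain_core.messages import AIMessage

    from redbox.transform import tool_calls_to_toolstate

    message = AIMessage(
        content="",
        tool_calls=[{"name": "search", "args": {"query": "AI"}, "id": "call_1"}],
    )

    # Maps each tool-call id to {"tool": <ToolCall>, "called": False}.
    state = tool_calls_to_toolstate(message, called=False)
    assert state["call_1"]["called"] is False

Passing the whole message also means callers no longer pre-extract the tool_calls list, as the test changes below show.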
5 changes: 3 additions & 2 deletions redbox-core/tests/graph/test_patterns.py
@@ -104,7 +104,7 @@ def test_build_llm_chain(test_case: RedboxChatTestCase):
     final_state = llm_chain.invoke(state)

     test_case_content = test_case.test_data.llm_responses[-1].content
-    test_case_tool_calls = tool_calls_to_toolstate(test_case.test_data.llm_responses[-1].tool_calls)
+    test_case_tool_calls = tool_calls_to_toolstate(test_case.test_data.llm_responses[-1])

     assert (
         final_state["text"] == test_case_content
@@ -371,12 +371,13 @@ def test_build_tool_pattern(tools: list[StructuredTool], expected: dict[str, str]):
     """Tests some basic tools update the state correctly."""
     tool = build_tool_pattern(tools=tools)
     tool_calls = [{"name": tool.name, "args": {}, "id": tool.name} for tool in tools]
+    message = AIMessage(content="hello", tool_calls=tool_calls)

     state = RedboxState(
         request=RedboxQuery(
             question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[]
         ),
-        tool_calls=tool_calls_to_toolstate(tool_calls=tool_calls, called=False),
+        tool_calls=tool_calls_to_toolstate(message=message, called=False),
     )

     response = tool.invoke(state)
32 changes: 21 additions & 11 deletions redbox-core/tests/test_transform.py
@@ -5,6 +5,8 @@

 import pytest
 from langchain_core.documents.base import Document
+from langchain_core.messages import AIMessage
+from langchain_core.runnables import RunnableLambda

 from redbox.models.chain import LLMCallMetadata, RequestMetadata
 from redbox.retriever.retrievers import filter_by_elbow
@@ -162,24 +164,32 @@ def test_elbow_filter(scores: list[float], target_len: int):
         (
             {
                 "prompt": "Lorem ipsum dolor sit amet.",
-                "response": (
-                    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
-                    "sed do eiusmod tempor incididunt ut labore et dolore magna "
-                    "aliqua. "
-                ),
                 "model": "gpt-4o",
+                "text_and_tools": {
+                    "raw_response": AIMessage(
+                        content=(
+                            "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
+                            "sed do eiusmod tempor incididunt ut labore et dolore magna "
+                            "aliqua. "
+                        )
+                    )
+                },
             },
             RequestMetadata(llm_calls={LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=6, output_tokens=23)}),
         ),
         (
             {
                 "prompt": "Lorem ipsum dolor sit amet.",
-                "response": (
-                    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
-                    "sed do eiusmod tempor incididunt ut labore et dolore magna "
-                    "aliqua. "
-                ),
                 "model": "unknown-model",
+                "text_and_tools": {
+                    "raw_response": AIMessage(
+                        content=(
+                            "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
+                            "sed do eiusmod tempor incididunt ut labore et dolore magna "
+                            "aliqua. "
+                        )
+                    )
+                },
             },
             RequestMetadata(
                 llm_calls={LLMCallMetadata(llm_model_name="unknown-model", input_tokens=6, output_tokens=23)}
@@ -188,7 +198,7 @@ def test_to_request_metadata(output: dict, expected: RequestMetadata):
     ],
 )
 def test_to_request_metadata(output: dict, expected: RequestMetadata):
-    result = to_request_metadata.invoke(output)
+    result = RunnableLambda(to_request_metadata).invoke(output)
     # We assert on token counts here as the id generation causes the LLMCallMetadata objects not to match
     assert (
         result.input_tokens == expected.input_tokens
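Since the @RunnableLambda decorator moved off to_request_metadata, callers wrap the plain function only where a Runnable interface is needed, as the updated test does. A sketch under the same assumptions as the parametrised cases above:

    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableLambda

    from redbox.transform import to_request_metadata

    payload = {
        "prompt": "Lorem ipsum dolor sit amet.",
        "model": "gpt-4o",
        "text_and_tools": {"raw_response": AIMessage(content="Lorem ipsum dolor sit amet.")},
    }

    # Wrapping at the call site supplies the Runnable interface (and a run
    # context for the function's dispatch_custom_event call).
    metadata = RunnableLambda(to_request_metadata).invoke(payload)
    print(metadata.input_tokens)

Keeping the function undecorated leaves it directly importable and composable; RunnableLambda is applied only at the boundary where LangChain needs it.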
