diff --git a/redbox-core/redbox/chains/runnables.py b/redbox-core/redbox/chains/runnables.py
index 62ceedd0d..082311a78 100644
--- a/redbox-core/redbox/chains/runnables.py
+++ b/redbox-core/redbox/chains/runnables.py
@@ -1,6 +1,5 @@
 import logging
 import re
-from operator import attrgetter, itemgetter
 from typing import Any, Callable, Iterable, Iterator
 
 from langchain_core.callbacks.manager import (
@@ -29,37 +28,14 @@
 from redbox.models.graph import RedboxEventType
 from redbox.transform import (
     flatten_document_state,
-    to_request_metadata,
     tool_calls_to_toolstate,
+    get_all_metadata,
 )
 
 log = logging.getLogger()
 re_string_pattern = re.compile(r"(\S+)")
 
 
-def combine_getters(*getters: Callable[[Any], Any]) -> Callable[[Any], Any]:
-    """Permits chaining of *getter functions in LangChain."""
-
-    def _combined(obj):
-        for getter in getters:
-            obj = getter(obj)
-        return obj
-
-    return _combined
-
-
-def itemgetter_with_default(field: str, default_getter: Callable[[Any], Any]):
-    getter = itemgetter(field)
-
-    def _impl(obj):
-        try:
-            return getter(obj)
-        except Exception:
-            return default_getter(obj)
-
-    return _impl
-
-
 def build_chat_prompt_from_messages_runnable(
     prompt_set: PromptSet,
     tokeniser: Encoding = None,
@@ -131,50 +107,22 @@ def build_llm_chain(
     _llm = llm.with_config(tags=["response_flag"]) if final_response_chain else llm
     _output_parser = output_parser if output_parser else StrOutputParser()
 
+    _llm_text_and_tools = _llm | {
+        "raw_response": RunnablePassthrough(),
+        "parsed_response": _output_parser,
+        "tool_calls": tool_calls_to_toolstate,
+    }
+
+    text_and_tools = {
+        "text_and_tools": _llm_text_and_tools,
+        "prompt": RunnableLambda(lambda prompt: prompt.to_string()),
+        "model": lambda _: model_name,
+    }
+
     return (
         build_chat_prompt_from_messages_runnable(prompt_set, partial_variables={"format_arg": format_instructions})
-        | {
-            "text_and_tools": (
-                _llm
-                | {
-                    "raw_response": RunnablePassthrough(),
-                    "parsed_response": _output_parser,
-                    "tool_calls": (RunnableLambda(lambda r: r.tool_calls) | tool_calls_to_toolstate),
-                }
-            ),
-            "prompt": RunnableLambda(lambda prompt: prompt.to_string()),
-        }
-        | {
-            "text": RunnableLambda(
-                combine_getters(
-                    itemgetter("text_and_tools"),
-                    itemgetter_with_default(
-                        "parsed_response", combine_getters(itemgetter("raw_response"), attrgetter("content"))
-                    ),
-                )
-            )
-            | (lambda r: r if isinstance(r, str) else r.answer),
-            "tool_calls": combine_getters(itemgetter("text_and_tools"), itemgetter("tool_calls")),
-            "citations": RunnableLambda(
-                combine_getters(
-                    itemgetter("text_and_tools"),
-                    itemgetter_with_default(
-                        "parsed_response", combine_getters(itemgetter("raw_response"), attrgetter("content"))
-                    ),
-                )
-            )
-            | (lambda r: [] if isinstance(r, str) else r.citations),
-            "metadata": (
-                {
-                    "prompt": itemgetter("prompt"),
-                    "response": combine_getters(
-                        itemgetter("text_and_tools"), itemgetter("raw_response"), attrgetter("content")
-                    ),
-                    "model": lambda _: model_name,
-                }
-                | to_request_metadata
-            ),
-        }
+        | text_and_tools
+        | get_all_metadata
         | RunnablePassthrough.assign(
             _log=RunnableLambda(
                 lambda _: log_activity(f"Generating response with {model_name}...") if final_response_chain else None
diff --git a/redbox-core/redbox/transform.py b/redbox-core/redbox/transform.py
index b6b2e56b6..12fc16b88 100644
--- a/redbox-core/redbox/transform.py
+++ b/redbox-core/redbox/transform.py
@@ -4,7 +4,7 @@
 import tiktoken
 from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.documents import Document
-from langchain_core.messages import ToolCall
+from langchain_core.messages import ToolCall, AnyMessage
 from langchain_core.runnables import RunnableLambda
 
 from redbox.models.chain import (
@@ -130,21 +130,23 @@ def get_document_token_count(state: RedboxState) -> int:
     return sum(d.metadata["token_count"] for d in flatten_document_state(state.get("documents", [])))
 
 
-@RunnableLambda
-def to_request_metadata(prompt_response_model: dict):
+def to_request_metadata(obj: dict):
     """Takes a dictionary with keys 'prompt', 'response' and 'model' and creates metadata.
 
     Will also emit events for metadata updates.
     """
-    model = prompt_response_model["model"]
+
+    prompt = obj["prompt"]
+    response = obj["text_and_tools"]["raw_response"].content
+    model = obj["model"]
 
     try:
         tokeniser = tiktoken.encoding_for_model(model)
     except KeyError:
         tokeniser = tiktoken.get_encoding("cl100k_base")
 
-    input_tokens = len(tokeniser.encode(prompt_response_model["prompt"]))
-    output_tokens = len(tokeniser.encode(prompt_response_model["response"]))
+    input_tokens = len(tokeniser.encode(prompt))
+    output_tokens = len(tokeniser.encode(response))
 
     metadata_event = RequestMetadata(
         llm_calls=[LLMCallMetadata(llm_model_name=model, input_tokens=input_tokens, output_tokens=output_tokens)]
@@ -155,6 +157,26 @@
     return metadata_event
 
 
+@RunnableLambda
+def get_all_metadata(obj: dict):
+    text_and_tools = obj["text_and_tools"]
+
+    if parsed_response := text_and_tools.get("parsed_response"):
+        text = getattr(parsed_response, "answer", parsed_response)
+        citations = getattr(parsed_response, "citations", [])
+    else:
+        text = text_and_tools["raw_response"].content
+        citations = []
+
+    out = {
+        "tool_calls": text_and_tools["tool_calls"],
+        "metadata": to_request_metadata(obj),
+        "text": text,
+        "citations": citations,
+    }
+    return out
+
+
 def merge_documents(initial: list[Document], adjacent: list[Document]) -> list[Document]:
     """Merges a list of adjacent documents with an initial list.
@@ -249,9 +271,9 @@ def process_group(group: list[Document]) -> list[list[Document]]:
     return list(itertools.chain.from_iterable(all_sorted_blocks_by_max_score))
 
 
-def tool_calls_to_toolstate(tool_calls: list[ToolCall], called: bool | None = False) -> ToolState:
+def tool_calls_to_toolstate(message: AnyMessage, called: bool | None = False) -> ToolState:
     """Takes a list of tool calls and shapes them into a valid ToolState. Sets all tool calls to a called state.
 
     Assumes this state is False.
     """
-    return {t["id"]: {"tool": ToolCall(**t), "called": called} for t in tool_calls}
+    return ToolState({t["id"]: {"tool": ToolCall(**t), "called": called} for t in message.tool_calls})
diff --git a/redbox-core/tests/graph/test_patterns.py b/redbox-core/tests/graph/test_patterns.py
index 65243c223..93d6579ba 100644
--- a/redbox-core/tests/graph/test_patterns.py
+++ b/redbox-core/tests/graph/test_patterns.py
@@ -104,7 +104,7 @@ def test_build_llm_chain(test_case: RedboxChatTestCase):
     final_state = llm_chain.invoke(state)
 
     test_case_content = test_case.test_data.llm_responses[-1].content
-    test_case_tool_calls = tool_calls_to_toolstate(test_case.test_data.llm_responses[-1].tool_calls)
+    test_case_tool_calls = tool_calls_to_toolstate(test_case.test_data.llm_responses[-1])
 
     assert (
         final_state["text"] == test_case_content
@@ -371,12 +371,13 @@ def test_build_tool_pattern(tools: list[StructuredTool], expected: dict[str, str
     """Tests some basic tools update the state correctly."""
     tool = build_tool_pattern(tools=tools)
     tool_calls = [{"name": tool.name, "args": {}, "id": tool.name} for tool in tools]
+    message = AIMessage(content="hello", tool_calls=tool_calls)
 
     state = RedboxState(
         request=RedboxQuery(
             question="What is AI?", s3_keys=[], user_uuid=uuid4(), chat_history=[], permitted_s3_keys=[]
         ),
-        tool_calls=tool_calls_to_toolstate(tool_calls=tool_calls, called=False),
+        tool_calls=tool_calls_to_toolstate(message=message, called=False),
     )
 
     response = tool.invoke(state)
diff --git a/redbox-core/tests/test_transform.py b/redbox-core/tests/test_transform.py
index 39f9924ae..860fbe222 100644
--- a/redbox-core/tests/test_transform.py
+++ b/redbox-core/tests/test_transform.py
@@ -5,6 +5,8 @@
 import pytest
 from langchain_core.documents.base import Document
+from langchain_core.messages import AIMessage
+from langchain_core.runnables import RunnableLambda
 
 from redbox.models.chain import LLMCallMetadata, RequestMetadata
 from redbox.retriever.retrievers import filter_by_elbow
@@ -162,24 +164,32 @@ def test_elbow_filter(scores: list[float], target_len: int):
         (
             {
                 "prompt": "Lorem ipsum dolor sit amet.",
-                "response": (
-                    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
-                    "sed do eiusmod tempor incididunt ut labore et dolore magna "
-                    "aliqua. "
-                ),
                 "model": "gpt-4o",
+                "text_and_tools": {
+                    "raw_response": AIMessage(
+                        content=(
+                            "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
+                            "sed do eiusmod tempor incididunt ut labore et dolore magna "
+                            "aliqua. "
+                        )
+                    )
+                },
             },
             RequestMetadata(llm_calls={LLMCallMetadata(llm_model_name="gpt-4o", input_tokens=6, output_tokens=23)}),
         ),
         (
             {
                 "prompt": "Lorem ipsum dolor sit amet.",
-                "response": (
-                    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
-                    "sed do eiusmod tempor incididunt ut labore et dolore magna "
-                    "aliqua. "
-                ),
                 "model": "unknown-model",
+                "text_and_tools": {
+                    "raw_response": AIMessage(
+                        content=(
+                            "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
+                            "sed do eiusmod tempor incididunt ut labore et dolore magna "
+                            "aliqua. "
+                        )
+                    )
+                },
             },
             RequestMetadata(
                 llm_calls={LLMCallMetadata(llm_model_name="unknown-model", input_tokens=6, output_tokens=23)}
             ),
     ],
 )
@@ -188,7 +198,7 @@
 def test_to_request_metadata(output: dict, expected: RequestMetadata):
-    result = to_request_metadata.invoke(output)
+    result = RunnableLambda(to_request_metadata).invoke(output)
 
     # We assert on token counts here as the id generation causes the LLMCallMetadata objects not to match
     assert (
         result.input_tokens == expected.input_tokens