Process outputs #911

Merged 4 commits on Aug 6, 2024.
40 changes: 29 additions & 11 deletions python/langsmith/run_helpers.py
@@ -25,6 +25,7 @@
     Mapping,
     Optional,
     Protocol,
+    Sequence,
     Tuple,
     Type,
     TypedDict,
@@ -242,9 +243,10 @@ def traceable(
     metadata: Optional[Mapping[str, Any]] = None,
     tags: Optional[List[str]] = None,
     client: Optional[ls_client.Client] = None,
-    reduce_fn: Optional[Callable] = None,
+    reduce_fn: Optional[Callable[[Sequence], dict]] = None,
     project_name: Optional[str] = None,
     process_inputs: Optional[Callable[[dict], dict]] = None,
+    process_outputs: Optional[Callable[..., dict]] = None,
     _invocation_params_fn: Optional[Callable[[dict], dict]] = None,
 ) -> Callable[[Callable[P, R]], SupportsLangsmithExtra[P, R]]: ...
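The tightened reduce_fn annotation documents the existing contract: when the traced function is a generator, every yielded value is collected and handed to the reducer as one sequence, and the reducer must return a dict to record as the run's outputs. A minimal standalone sketch of a conforming reducer (the function and names are illustrative, not part of this diff):

from typing import Sequence

from langsmith import traceable


def concat_chunks(outputs: Sequence) -> dict:
    # Every value the traced generator yielded arrives here, in order;
    # the returned dict becomes the run's logged outputs.
    return {"output": "".join(str(chunk) for chunk in outputs)}


@traceable(run_type="chain", reduce_fn=concat_chunks)
def stream_tokens(prompt: str):
    for token in ["str", "eam", "ed"]:
        yield token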

@@ -270,7 +272,11 @@ def traceable(
             called, and the run itself will be stuck in a pending state.
         project_name: The name of the project to log the run to. Defaults to None,
             which will use the default project.
-        process_inputs: A function to filter the inputs to the run. Defaults to None.
+        process_inputs: Custom serialization / processing function for inputs.
+            Defaults to None.
+        process_outputs: Custom serialization / processing function for outputs.
+            Defaults to None.
+


     Returns:
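Mirroring the unit test added later in this PR, here is a small standalone sketch of the two hooks together. process_inputs receives the call's inputs as a dict and returns the dict to log in their place; process_outputs receives the function's raw return value and returns the dict to log as outputs. Names below are illustrative:

from langsmith import traceable


def redact_inputs(inputs: dict) -> dict:
    # The returned dict is logged in place of the raw inputs.
    return {k: "***" if k == "api_key" else v for k, v in inputs.items()}


def serialize_outputs(outputs) -> dict:
    # Receives the raw return value (any type); must return a dict to log.
    return {"answer": str(outputs)}


@traceable(process_inputs=redact_inputs, process_outputs=serialize_outputs)
def answer(question: str, api_key: str) -> int:
    return 42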
@@ -415,6 +421,18 @@ def manual_extra_function(x):
         process_inputs=kwargs.pop("process_inputs", None),
         invocation_params_fn=kwargs.pop("_invocation_params_fn", None),
     )
+    outputs_processor = kwargs.pop("process_outputs", None)
+
+    def _on_run_end(
+        container: _TraceableContainer,
+        outputs: Optional[Any] = None,
+        error: Optional[BaseException] = None,
+    ) -> None:
+        """Handle the end of run."""
+        if outputs and outputs_processor is not None:
+            outputs = outputs_processor(outputs)
+        _container_end(container, outputs=outputs, error=error)
+
     if kwargs:
         warnings.warn(
             f"The following keyword arguments are not recognized and will be ignored: "
@@ -463,11 +481,11 @@ async def async_wrapper(
             except BaseException as e:
                 # shield from cancellation, given we're catching all exceptions
                 await asyncio.shield(
-                    aitertools.aio_to_thread(_container_end, run_container, error=e)
+                    aitertools.aio_to_thread(_on_run_end, run_container, error=e)
                 )
                 raise e
             await aitertools.aio_to_thread(
-                _container_end, run_container, outputs=function_result
+                _on_run_end, run_container, outputs=function_result
             )
             return function_result

@@ -536,7 +554,7 @@ async def async_generator_wrapper(
                 pass
             except BaseException as e:
                 await asyncio.shield(
-                    aitertools.aio_to_thread(_container_end, run_container, error=e)
+                    aitertools.aio_to_thread(_on_run_end, run_container, error=e)
                 )
                 raise e
             if results:
@@ -551,7 +569,7 @@ async def async_generator_wrapper(
                     function_result = results
             else:
                 function_result = None
             await aitertools.aio_to_thread(
-                _container_end, run_container, outputs=function_result
+                _on_run_end, run_container, outputs=function_result
             )

         @functools.wraps(func)
@@ -578,9 +596,9 @@ def wrapper(
                 kwargs.pop("config", None)
                 function_result = run_container["context"].run(func, *args, **kwargs)
             except BaseException as e:
-                _container_end(run_container, error=e)
+                _on_run_end(run_container, error=e)
                 raise e
-            _container_end(run_container, outputs=function_result)
+            _on_run_end(run_container, outputs=function_result)
             return function_result

         @functools.wraps(func)
@@ -630,7 +648,7 @@ def generator_wrapper(
                 pass

             except BaseException as e:
-                _container_end(run_container, error=e)
+                _on_run_end(run_container, error=e)
                 raise e
             if results:
                 if reduce_fn:
@@ -643,7 +661,7 @@ def generator_wrapper(
                     function_result = results
             else:
                 function_result = None
-            _container_end(run_container, outputs=function_result)
+            _on_run_end(run_container, outputs=function_result)

         if inspect.isasyncgenfunction(func):
             selected_wrapper: Callable = async_generator_wrapper
@@ -1131,7 +1149,7 @@ def _container_end(
     container: _TraceableContainer,
     outputs: Optional[Any] = None,
     error: Optional[BaseException] = None,
-):
+) -> None:
     """End the run."""
     run_tree = container.get("new_run")
     if run_tree is None:
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langsmith"
-version = "0.1.97"
+version = "0.1.98"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 authors = ["LangChain <support@langchain.dev>"]
 license = "MIT"
6 changes: 3 additions & 3 deletions python/tests/integration_tests/test_runs.py
@@ -3,7 +3,7 @@
 import uuid
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import AsyncGenerator, Generator, Optional
+from typing import AsyncGenerator, Generator, Optional, Sequence

 import pytest  # type: ignore

@@ -330,7 +330,7 @@ def test_sync_generator_reduce_fn(langchain_client: Client):
     project_name = "__My Tracer Project - test_sync_generator_reduce_fn"
     run_meta = uuid.uuid4().hex

-    def reduce_fn(outputs: list) -> dict:
+    def reduce_fn(outputs: Sequence) -> dict:
         return {"my_output": " ".join(outputs)}

     @traceable(run_type="chain", reduce_fn=reduce_fn)
@@ -411,7 +411,7 @@ async def test_async_generator_reduce_fn(langchain_client: Client):
     project_name = "__My Tracer Project - test_async_generator_reduce_fn"
     run_meta = uuid.uuid4().hex

-    def reduce_fn(outputs: list) -> dict:
+    def reduce_fn(outputs: Sequence) -> dict:
         return {"my_output": " ".join(outputs)}

     @traceable(run_type="chain", reduce_fn=reduce_fn)
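Loosening the reducer's parameter from list to Sequence mirrors the public annotation change in run_helpers.py and documents that the reducer only needs an ordered container, not a list specifically. A standalone illustration:

from typing import Sequence


def reduce_fn(outputs: Sequence) -> dict:
    return {"my_output": " ".join(outputs)}


# The same reducer accepts a tuple just as well as a list:
assert reduce_fn(("hello", "world")) == {"my_output": "hello world"}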
107 changes: 107 additions & 0 deletions python/tests/unit_tests/test_run_helpers.py
@@ -1341,3 +1341,110 @@ async def test_trace_respects_env_var(env_var: bool, context: Optional[bool]):
         assert len(mock_calls) >= 1
     else:
         assert not mock_calls


+async def test_process_inputs_outputs():
+    mock_client = _get_mock_client()
+    in_s = "what's life's meaning"
+
+    def process_inputs(inputs: dict) -> dict:
+        assert inputs == {"val": in_s, "ooblek": "nada"}
+        inputs["val2"] = "this is mutated"
+        return {"serialized_in": "what's the meaning of life?"}
+
+    def process_outputs(outputs: int) -> dict:
+        assert outputs == 42
+        return {"serialized_out": 24}
+
+    @traceable(process_inputs=process_inputs, process_outputs=process_outputs)
+    def my_function(val: str, **kwargs: Any) -> int:
+        assert not kwargs.get("val2")
+        return 42
+
+    with tracing_context(enabled=True):
+        my_function(
+            in_s,
+            ooblek="nada",
+            langsmith_extra={"client": mock_client},
+        )
+
+    def _check_client(client: Client) -> None:
+        mock_calls = _get_calls(client)
+        assert len(mock_calls) == 1
+        call = mock_calls[0]
+        assert call.args[0] == "POST"
+        assert call.args[1].startswith("https://api.smith.langchain.com")
+        body = json.loads(call.kwargs["data"])
+        assert body["post"]
+        assert body["post"][0]["inputs"] == {
+            "serialized_in": "what's the meaning of life?"
+        }
+        assert body["post"][0]["outputs"] == {"serialized_out": 24}
+
+    _check_client(mock_client)
+
+    @traceable(process_inputs=process_inputs, process_outputs=process_outputs)
+    async def amy_function(val: str, **kwargs: Any) -> int:
+        assert not kwargs.get("val2")
+        return 42
+
+    mock_client = _get_mock_client()
+    with tracing_context(enabled=True):
+        await amy_function(
+            in_s,
+            ooblek="nada",
+            langsmith_extra={"client": mock_client},
+        )
+
+    _check_client(mock_client)
+
+    # Do generator
+
+    def reducer(outputs: list) -> dict:
+        return {"reduced": outputs[0]}
+
+    def process_reduced_outputs(outputs: dict) -> dict:
+        assert outputs == {"reduced": 42}
+        return {"serialized_out": 24}
+
+    @traceable(
+        process_inputs=process_inputs,
+        process_outputs=process_reduced_outputs,
+        reduce_fn=reducer,
+    )
+    def my_gen(val: str, **kwargs: Any) -> Generator[int, None, None]:
+        assert not kwargs.get("val2")
+        yield 42
+
+    mock_client = _get_mock_client()
+    with tracing_context(enabled=True):
+        result = list(
+            my_gen(
+                in_s,
+                ooblek="nada",
+                langsmith_extra={"client": mock_client},
+            )
+        )
+        assert result == [42]
+
+    _check_client(mock_client)
+
+    @traceable(
+        process_inputs=process_inputs,
+        process_outputs=process_reduced_outputs,
+        reduce_fn=reducer,
+    )
+    async def amy_gen(val: str, **kwargs: Any) -> AsyncGenerator[int, None]:
+        assert not kwargs.get("val2")
+        yield 42
+
+    mock_client = _get_mock_client()
+    with tracing_context(enabled=True):
+        result = [
+            i
+            async for i in amy_gen(
+                in_s, ooblek="nada", langsmith_extra={"client": mock_client}
+            )
+        ]
+        assert result == [42]
+    _check_client(mock_client)
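To exercise the new test in isolation, something like "pytest python/tests/unit_tests/test_run_helpers.py -k test_process_inputs_outputs" run from the repository root should work, assuming the project's dev dependencies are installed.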