Skip to content

Commit

Permalink
core[patch]: On Chain Start Fix for Chain Class (#26593)
Browse files Browse the repository at this point in the history
- **Issue:** #26588

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
  • Loading branch information
keenborder786 and eyurtsev authored Sep 23, 2024
1 parent bba7af9 commit 154a5ff
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
10 changes: 8 additions & 2 deletions libs/core/langchain_core/callbacks/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ def on_chain_start(
inputs (Dict[str, Any]): The inputs to the chain.
**kwargs (Any): Additional keyword arguments.
"""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
if "name" in kwargs:
name = kwargs["name"]
else:
if serialized:
name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
else:
name = "<unknown>"
print(f"\n\n\033[1m> Entering new {name} chain...\033[0m") # noqa: T201

def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.
Expand Down
44 changes: 44 additions & 0 deletions libs/langchain/tests/unit_tests/callbacks/test_stdout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Dict, List, Optional

import pytest

from langchain.callbacks import StdOutCallbackHandler
from langchain.chains.base import CallbackManagerForChainRun, Chain


class FakeChain(Chain):
"""Fake chain class for testing purposes."""

be_correct: bool = True
the_input_keys: List[str] = ["foo"]
the_output_keys: List[str] = ["bar"]

@property
def input_keys(self) -> List[str]:
"""Input keys."""
return self.the_input_keys

@property
def output_keys(self) -> List[str]:
"""Output key of bar."""
return self.the_output_keys

def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
return {"bar": "bar"}


def test_stdoutcallback(capsys: pytest.CaptureFixture) -> Any:
"""Test the stdout callback handler."""
chain_test = FakeChain(callbacks=[StdOutCallbackHandler(color="red")])
chain_test.invoke({"foo": "bar"})
# Capture the output
captured = capsys.readouterr()
# Assert the output is as expected
assert captured.out == (
"\n\n\x1b[1m> Entering new FakeChain "
"chain...\x1b[0m\n\n\x1b[1m> Finished chain.\x1b[0m\n"
)

0 comments on commit 154a5ff

Please sign in to comment.