diff --git a/libs/core/langchain_core/callbacks/stdout.py b/libs/core/langchain_core/callbacks/stdout.py index bcdb5317ff950..aadc1cc8ebb9a 100644 --- a/libs/core/langchain_core/callbacks/stdout.py +++ b/libs/core/langchain_core/callbacks/stdout.py @@ -32,8 +32,14 @@ def on_chain_start( inputs (Dict[str, Any]): The inputs to the chain. **kwargs (Any): Additional keyword arguments. """ - class_name = serialized.get("name", serialized.get("id", [""])[-1]) - print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201 + if "name" in kwargs: + name = kwargs["name"] + else: + if serialized: + name = serialized.get("name", serialized.get("id", [""])[-1]) + else: + name = "" + print(f"\n\n\033[1m> Entering new {name} chain...\033[0m") # noqa: T201 def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain. diff --git a/libs/langchain/tests/unit_tests/callbacks/test_stdout.py b/libs/langchain/tests/unit_tests/callbacks/test_stdout.py new file mode 100644 index 0000000000000..f983da718d90a --- /dev/null +++ b/libs/langchain/tests/unit_tests/callbacks/test_stdout.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, List, Optional + +import pytest + +from langchain.callbacks import StdOutCallbackHandler +from langchain.chains.base import CallbackManagerForChainRun, Chain + + +class FakeChain(Chain): + """Fake chain class for testing purposes.""" + + be_correct: bool = True + the_input_keys: List[str] = ["foo"] + the_output_keys: List[str] = ["bar"] + + @property + def input_keys(self) -> List[str]: + """Input keys.""" + return self.the_input_keys + + @property + def output_keys(self) -> List[str]: + """Output key of bar.""" + return self.the_output_keys + + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + return {"bar": "bar"} + + +def test_stdoutcallback(capsys: pytest.CaptureFixture) -> Any: + """Test the stdout callback handler.""" + chain_test = FakeChain(callbacks=[StdOutCallbackHandler(color="red")]) + chain_test.invoke({"foo": "bar"}) + # Capture the output + captured = capsys.readouterr() + # Assert the output is as expected + assert captured.out == ( + "\n\n\x1b[1m> Entering new FakeChain " + "chain...\x1b[0m\n\n\x1b[1m> Finished chain.\x1b[0m\n" + )