From 7fc9e99e2158e676393d91f4b347014c9589cfe2 Mon Sep 17 00:00:00 2001
From: liuhetian <91518757+liuhetian@users.noreply.github.com>
Date: Sat, 14 Sep 2024 02:42:01 +0800
Subject: [PATCH] openai[patch]: get output_type when using
 with_structured_output (#26307)

- This allows pydantic to correctly resolve the annotations necessary when
  using OpenAI's new `json_schema` param

Resolves issue: #26250

---------

Co-authored-by: Eugene Yurtsev
Co-authored-by: Bagatur
---
 .../langchain_openai/chat_models/base.py      | 14 ++++-----
 .../tests/unit_tests/chat_models/test_base.py | 29 +++++++++++++++++++
 2 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 61d9141df41ce..58e7c702cef16 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -65,7 +65,6 @@
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
-from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
@@ -1421,7 +1420,7 @@ class AnswerWithJustification(BaseModel):
                 strict=strict,
             )
             if is_pydantic_schema:
-                output_parser: OutputParserLike = PydanticToolsParser(
+                output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
                     first_tool_only=True,  # type: ignore[list-item]
                 )
@@ -1445,11 +1444,12 @@ class AnswerWithJustification(BaseModel):
             strict = strict if strict is not None else True
             response_format = _convert_to_openai_response_format(schema, strict=strict)
             llm = self.bind(response_format=response_format)
-            output_parser = (
-                cast(Runnable, _oai_structured_outputs_parser)
-                if is_pydantic_schema
-                else JsonOutputParser()
-            )
+            if is_pydantic_schema:
+                output_parser = _oai_structured_outputs_parser.with_types(
+                    output_type=cast(type, schema)
+                )
+            else:
+                output_parser = JsonOutputParser()
         else:
             raise ValueError(
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 4e959f005990d..ae83849901c23 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -18,6 +18,7 @@
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel as BaseModelV2
 
 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import (
@@ -694,3 +695,31 @@ def test_get_num_tokens_from_messages() -> None:
     expected = 176
     actual = llm.get_num_tokens_from_messages(messages)
     assert expected == actual
+
+
+class Foo(BaseModel):
+    bar: int
+
+
+class FooV2(BaseModelV2):
+    bar: int
+
+
+@pytest.mark.parametrize("schema", [Foo, FooV2])
+def test_schema_from_with_structured_output(schema: Type) -> None:
+    """Test schema from with_structured_output."""
+
+    llm = ChatOpenAI()
+
+    structured_llm = llm.with_structured_output(
+        schema, method="json_schema", strict=True
+    )
+
+    expected = {
+        "properties": {"bar": {"title": "Bar", "type": "integer"}},
+        "required": ["bar"],
+        "title": schema.__name__,
+        "type": "object",
+    }
+    actual = structured_llm.get_output_schema().schema()
+    assert actual == expected
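
Reviewer note: a minimal sketch of what the change enables, not part of the patch itself. It assumes an OPENAI_API_KEY is set in the environment; the Person model, the "gpt-4o-mini" model name, and the commented-out prompt are illustrative only.

# Illustrative only: class, model name, and prompt are assumptions, not from
# the patch. Constructing ChatOpenAI requires OPENAI_API_KEY to be set.
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Person(BaseModel):
    name: str
    age: int


llm = ChatOpenAI(model="gpt-4o-mini")
structured_llm = llm.with_structured_output(Person, method="json_schema", strict=True)

# With this patch the parser carries the schema as its output type, so the
# runnable's output schema reflects the Pydantic model rather than a bare dict.
print(structured_llm.get_output_schema().schema())

# Invoking returns a validated Person instance (makes a network call):
# person = structured_llm.invoke("Ada is 36 years old.")

Before the patch, the json_schema path wrapped the parser with cast(Runnable, _oai_structured_outputs_parser), so the bound runnable reported a generic output type; attaching the schema via with_types(output_type=...) lets pydantic resolve the annotations, which is what the new parametrized test checks against get_output_schema().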