-
Notifications
You must be signed in to change notification settings - Fork 15.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Harrison/guarded output parser #1804
Changes from all commits
44d2492
1af560c
6898d83
fa2d98c
bfa858b
325825d
a0cde05
3ee7558
32a8507
ccc1897
5f41f07
86085bc
68e9b7f
97b8724
cbab13c
1de0790
e8f2ed4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from langchain.chains.llm import LLMChain | ||
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT | ||
from langchain.prompts.base import BasePromptTemplate | ||
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException | ||
|
||
|
||
class OutputFixingParser(BaseOutputParser):
    """Wraps a parser and tries to fix parsing errors.

    When the wrapped ``parser`` raises ``OutputParserException``, the format
    instructions, the failing completion, and the error text are fed through
    ``retry_chain`` to obtain a corrected completion, which is parsed once more.
    """

    # The wrapped parser whose output we attempt to repair on failure.
    parser: BaseOutputParser
    # LLM chain used to produce a corrected completion.
    retry_chain: LLMChain

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        parser: BaseOutputParser,
        prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
    ) -> OutputFixingParser:
        """Build an OutputFixingParser from a language model and a parser."""
        return cls(parser=parser, retry_chain=LLMChain(llm=llm, prompt=prompt))

    def parse(self, completion: str) -> Any:
        """Parse ``completion``; on failure, ask the LLM for a fixed version."""
        try:
            return self.parser.parse(completion)
        except OutputParserException as exc:
            fixed = self.retry_chain.run(
                instructions=self.parser.get_format_instructions(),
                completion=completion,
                error=repr(exc),
            )
            # A second failure here propagates to the caller unchanged.
            return self.parser.parse(fixed)

    def get_format_instructions(self) -> str:
        """Delegate format instructions to the wrapped parser."""
        return self.parser.get_format_instructions()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate

# Template asking the LLM to repair a completion that failed to satisfy the
# parser's format instructions. Variables: {instructions}, {completion}, {error}.
NAIVE_FIX = """Instructions:
--------------
{instructions}
--------------
Completion:
--------------
{completion}
--------------

Above, the Completion did not satisfy the constraints given in the Instructions.
Error:
--------------
{error}
--------------

Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:"""


# Default prompt used by OutputFixingParser.from_llm.
NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,8 @@ | |
|
||
from pydantic import BaseModel, ValidationError | ||
|
||
from langchain.output_parsers.base import BaseOutputParser | ||
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS | ||
from langchain.schema import BaseOutputParser, OutputParserException | ||
|
||
|
||
class PydanticOutputParser(BaseOutputParser): | ||
|
@@ -14,7 +14,9 @@ class PydanticOutputParser(BaseOutputParser): | |
def parse(self, text: str) -> BaseModel: | ||
try: | ||
# Greedy search for 1st json candidate. | ||
match = re.search("\{.*\}", text.strip()) | ||
match = re.search( | ||
"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL | ||
) | ||
json_str = "" | ||
if match: | ||
json_str = match.group() | ||
|
@@ -24,16 +26,17 @@ def parse(self, text: str) -> BaseModel: | |
except (json.JSONDecodeError, ValidationError) as e: | ||
name = self.pydantic_object.__name__ | ||
msg = f"Failed to parse {name} from completion {text}. Got: {e}" | ||
raise ValueError(msg) | ||
raise OutputParserException(msg) | ||
|
||
def get_format_instructions(self) -> str: | ||
schema = self.pydantic_object.schema() | ||
|
||
# Remove extraneous fields. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Noting @hwchase17: added this back to support more complex types, like an enum. I tweaked the in-context example as a result, including a positive and a negative case. Seems to be working reasonably well. I'd like to add some tests, but that doesn't have to be in this diff. |
||
reduced_schema = { | ||
prop: {"description": data["description"], "type": data["type"]} | ||
for prop, data in schema["properties"].items() | ||
} | ||
reduced_schema = schema | ||
if "title" in reduced_schema: | ||
del reduced_schema["title"] | ||
if "type" in reduced_schema: | ||
del reduced_schema["type"] | ||
# Ensure json in context is well-formed with double quotes. | ||
schema = json.dumps(reduced_schema) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from langchain.chains.llm import LLMChain | ||
from langchain.prompts.base import BasePromptTemplate | ||
from langchain.prompts.prompt import PromptTemplate | ||
from langchain.schema import ( | ||
BaseLanguageModel, | ||
BaseOutputParser, | ||
OutputParserException, | ||
PromptValue, | ||
) | ||
|
||
# Template asking the LLM to retry, given only the original prompt and the
# failing completion. Variables: {prompt}, {completion}.
NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}
Completion:
{completion}

Above, the Completion did not satisfy the constraints given in the Prompt.
Please try again:"""

# Variant that additionally surfaces the parser error to the LLM.
# Variables: {prompt}, {completion}, {error}.
NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt:
{prompt}
Completion:
{completion}

Above, the Completion did not satisfy the constraints given in the Prompt.
Details: {error}
Please try again:"""

# Default prompt for RetryOutputParser.from_llm.
NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about inlining the template instead of using a proxy variable? We don't re-use the proxy variable anywhere right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nope, but imo it makes it more readable to have it separate |
||
# Default prompt for RetryWithErrorOutputParser.from_llm.
NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
    NAIVE_COMPLETION_RETRY_WITH_ERROR
)
|
||
|
||
class RetryOutputParser(BaseOutputParser):
    """Wraps a parser and tries to fix parsing errors.

    Does this by passing the original prompt and the completion to another
    LLM, and telling it the completion did not satisfy criteria in the prompt.
    Note: unlike OutputFixingParser, the retry prompt here does not include
    the error text, only the original prompt and completion.
    """

    # The wrapped parser whose output is retried on failure.
    parser: BaseOutputParser
    # LLM chain that produces a new completion from prompt + completion.
    retry_chain: LLMChain

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        parser: BaseOutputParser,
        prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
    ) -> RetryOutputParser:
        """Build a RetryOutputParser from a language model and a parser."""
        return cls(parser=parser, retry_chain=LLMChain(llm=llm, prompt=prompt))

    def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
        """Parse ``completion``; on failure, re-prompt the LLM once and re-parse."""
        try:
            return self.parser.parse(completion)
        except OutputParserException:
            retried = self.retry_chain.run(
                prompt=prompt_value.to_string(), completion=completion
            )
            # A second parse failure propagates to the caller unchanged.
            return self.parser.parse(retried)

    def parse(self, completion: str) -> Any:
        """Unsupported: this parser needs the original prompt to retry."""
        raise NotImplementedError(
            "This OutputParser can only be called by the `parse_with_prompt` method."
        )

    def get_format_instructions(self) -> str:
        """Delegate format instructions to the wrapped parser."""
        return self.parser.get_format_instructions()
|
||
|
||
class RetryWithErrorOutputParser(BaseOutputParser):
    """Wraps a parser and tries to fix parsing errors.

    Does this by passing the original prompt, the completion, AND the error
    that was raised to another language and telling it that the completion
    did not work, and raised the given error. Differs from RetryOutputParser
    in that this implementation provides the error that was raised back to the
    LLM, which in theory should give it more information on how to fix it.
    """

    # The wrapped parser whose output is retried on failure.
    parser: BaseOutputParser
    # LLM chain that produces a new completion from prompt + completion + error.
    retry_chain: LLMChain

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        parser: BaseOutputParser,
        prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
    ) -> RetryWithErrorOutputParser:
        """Build a RetryWithErrorOutputParser from a language model and a parser."""
        return cls(parser=parser, retry_chain=LLMChain(llm=llm, prompt=prompt))

    def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
        """Parse ``completion``; on failure, re-prompt with the error and re-parse."""
        try:
            return self.parser.parse(completion)
        except OutputParserException as exc:
            retried = self.retry_chain.run(
                prompt=prompt_value.to_string(), completion=completion, error=repr(exc)
            )
            # A second parse failure propagates to the caller unchanged.
            return self.parser.parse(retried)

    def parse(self, completion: str) -> Any:
        """Unsupported: this parser needs the original prompt to retry."""
        raise NotImplementedError(
            "This OutputParser can only be called by the `parse_with_prompt` method."
        )

    def get_format_instructions(self) -> str:
        """Delegate format instructions to the wrapped parser."""
        return self.parser.get_format_instructions()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any benefits to trying multiple times?