Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harrison/guarded output parser #1804

Merged
merged 17 commits into from
Mar 22, 2023
Merged
343 changes: 321 additions & 22 deletions docs/modules/prompts/examples/output_parsers.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion langchain/agents/conversational_chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.output_parsers.base import BaseOutputParser
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
Expand All @@ -26,6 +25,7 @@
AIMessage,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
HumanMessage,
)
from langchain.tools.base import BaseTool
Expand Down
7 changes: 5 additions & 2 deletions langchain/output_parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.fix import OutputFixingParser
from langchain.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
Expand All @@ -7,16 +7,19 @@
from langchain.output_parsers.rail_parser import GuardrailsOutputParser
from langchain.output_parsers.regex import RegexParser
from langchain.output_parsers.regex_dict import RegexDictParser
from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser

__all__ = [
"RegexParser",
"RegexDictParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"BaseOutputParser",
"StructuredOutputParser",
"ResponseSchema",
"GuardrailsOutputParser",
"PydanticOutputParser",
"RetryOutputParser",
"RetryWithErrorOutputParser",
"OutputFixingParser",
]
28 changes: 0 additions & 28 deletions langchain/output_parsers/base.py

This file was deleted.

41 changes: 41 additions & 0 deletions langchain/output_parsers/fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from typing import Any

from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException


class OutputFixingParser(BaseOutputParser):
"""Wraps a parser and tries to fix parsing errors."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the any benefits to trying multiple times?


parser: BaseOutputParser
retry_chain: LLMChain

@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
) -> OutputFixingParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)

def parse(self, completion: str) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
parsed_completion = self.parser.parse(new_completion)

return parsed_completion

def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()
5 changes: 4 additions & 1 deletion langchain/output_parsers/format_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
}}
```"""

PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. For example, the object {{"foo": ["bar", "baz"]}} conforms to the schema {{"foo": {{"description": "a list of strings field", "type": "string"}}}}.
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.

As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.

Here is the output schema:
```
Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import abstractmethod
from typing import List

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser


class ListOutputParser(BaseOutputParser):
Expand Down
22 changes: 22 additions & 0 deletions langchain/output_parsers/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate

NAIVE_FIX = """Instructions:
--------------
{instructions}
--------------
Completion:
--------------
{completion}
--------------

Above, the Completion did not satisfy the constraints given in the Instructions.
Error:
--------------
{error}
--------------

Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:"""


NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX)
17 changes: 10 additions & 7 deletions langchain/output_parsers/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from pydantic import BaseModel, ValidationError

from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException


class PydanticOutputParser(BaseOutputParser):
Expand All @@ -14,7 +14,9 @@ class PydanticOutputParser(BaseOutputParser):
def parse(self, text: str) -> BaseModel:
try:
# Greedy search for 1st json candidate.
match = re.search("\{.*\}", text.strip())
match = re.search(
"\{.*\}", text.strip(), re.MULTILINE | re.IGNORECASE | re.DOTALL
)
json_str = ""
if match:
json_str = match.group()
Expand All @@ -24,16 +26,17 @@ def parse(self, text: str) -> BaseModel:
except (json.JSONDecodeError, ValidationError) as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {text}. Got: {e}"
raise ValueError(msg)
raise OutputParserException(msg)

def get_format_instructions(self) -> str:
schema = self.pydantic_object.schema()

# Remove extraneous fields.
Copy link
Contributor

@jerwelborn jerwelborn Mar 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noting @hwchase17 : added this back to support more complex types, like an enum.. I tweaked the in-context example as a result, including a positive + negative case. Seems to be working reasonably well. I'd like to add some tests, but doesn't have to be in this diff

reduced_schema = {
prop: {"description": data["description"], "type": data["type"]}
for prop, data in schema["properties"].items()
}
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema = json.dumps(reduced_schema)

Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/rail_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Dict

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser


class GuardrailsOutputParser(BaseOutputParser):
Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser


class RegexParser(BaseOutputParser, BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/regex_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel

from langchain.output_parsers.base import BaseOutputParser
from langchain.schema import BaseOutputParser


class RegexDictParser(BaseOutputParser, BaseModel):
Expand Down
118 changes: 118 additions & 0 deletions langchain/output_parsers/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

from typing import Any

from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLanguageModel,
BaseOutputParser,
OutputParserException,
PromptValue,
)

NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}
Completion:
{completion}

Above, the Completion did not satisfy the constraints given in the Prompt.
Please try again:"""

NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt:
{prompt}
Completion:
{completion}

Above, the Completion did not satisfy the constraints given in the Prompt.
Details: {error}
Please try again:"""

NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about inlining the template instead of using a proxy variable? We don't re-use the proxy variable anywhere right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope, but imo it makes it more readable to have it separate

NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
NAIVE_COMPLETION_RETRY_WITH_ERROR
)


class RetryOutputParser(BaseOutputParser):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between this parser and the fix parser above?

What do you think about expanding the doc-string significantly to explain use case and how the retry works? e.g., 2-5 lines of documentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup will good, good call

"""Wraps a parser and tries to fix parsing errors.

Does this by passing the original prompt and the completion to another
LLM, and telling it the completion did not satisfy criteria in the prompt.
"""

parser: BaseOutputParser
retry_chain: LLMChain
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type system doesn't seem to help us much here -- since it looks like the run api of the retry_chain interface is defined by the prompt (i.e., ability to accept prompt and completion)

have you thought of a way to surface that type information? I assume the challenge is maintaining things serializable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what you mean, would love to discuss


@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
) -> RetryOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)

def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
parsed_completion = self.parser.parse(new_completion)

return parsed_completion

def parse(self, completion: str) -> Any:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)

def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()


class RetryWithErrorOutputParser(BaseOutputParser):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not combine this parser with the one above and add a variable to control behavior on error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prompt inputs differ... its def possible but youd need to make sure the variable for whether to pass error in is aligned with the prompt which just feels like too many levers to make in sync

"""Wraps a parser and tries to fix parsing errors.

Does this by passing the original prompt, the completion, AND the error
that was raised to another language and telling it that the completion
did not work, and raised the given error. Differs from RetryOutputParser
in that this implementation provides the error that was raised back to the
LLM, which in theory should give it more information on how to fix it.
"""

parser: BaseOutputParser
retry_chain: LLMChain

@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
parser: BaseOutputParser,
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
) -> RetryWithErrorOutputParser:
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)

def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> Any:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
)
parsed_completion = self.parser.parse(new_completion)

return parsed_completion

def parse(self, completion: str) -> Any:
raise NotImplementedError(
"This OutputParser can only be called by the `parse_with_prompt` method."
)

def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()
4 changes: 2 additions & 2 deletions langchain/output_parsers/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from pydantic import BaseModel

from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException

line_template = '\t"{name}": {type} // {description}'

Expand Down Expand Up @@ -42,7 +42,7 @@ def parse(self, text: str) -> BaseModel:
json_obj = json.loads(json_string)
for schema in self.response_schemas:
if schema.name not in json_obj:
raise ValueError(
raise OutputParserException(
f"Got invalid return object. Expected key `{schema.name}` "
f"to be present, but got {json_obj}"
)
Expand Down
8 changes: 1 addition & 7 deletions langchain/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@
from pydantic import BaseModel, Extra, Field, root_validator

from langchain.formatting import formatter
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.list import ( # noqa: F401
CommaSeparatedListOutputParser,
ListOutputParser,
)
from langchain.output_parsers.regex import RegexParser # noqa: F401
from langchain.schema import BaseMessage, HumanMessage, PromptValue
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue


def jinja2_formatter(template: str, **kwargs: Any) -> str:
Expand Down
Loading