Skip to content

Commit

Permalink
fix for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
curiousgeorgios committed Jul 21, 2023
1 parent c0e1bac commit 9977631
Showing 1 changed file with 35 additions and 17 deletions.
52 changes: 35 additions & 17 deletions langchain/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ class Chain(Serializable, ABC):
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True
)
"""Deprecated, use `callbacks` instead."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
Expand All @@ -92,7 +94,7 @@ class Config:
@property
def _chain_type(self) -> str:
warnings.warn("Saving not supported for this chain type.", UserWarning)
return 'NotImplemented'
return "NotImplemented"

@root_validator()
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
Expand Down Expand Up @@ -181,7 +183,9 @@ async def _acall(
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
raise NotImplementedError("Async call not supported for this chain type.")
raise NotImplementedError(
"Async call not supported for this chain type."
)

def __call__(
self,
Expand Down Expand Up @@ -228,7 +232,9 @@ def __call__(
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
new_arg_supported = inspect.signature(self._call).parameters.get(
"run_manager"
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
Expand Down Expand Up @@ -295,7 +301,9 @@ async def acall(
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
new_arg_supported = inspect.signature(self._acall).parameters.get(
"run_manager"
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
Expand Down Expand Up @@ -343,7 +351,9 @@ def prep_outputs(
else:
return {**inputs, **outputs}

def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any]
) -> Dict[str, str]:
"""Validate and prepare chain inputs, including adding inputs from memory.
Args:
Expand All @@ -360,7 +370,9 @@ def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
_input_keys = _input_keys.difference(
self.memory.memory_variables
)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
Expand Down Expand Up @@ -437,15 +449,17 @@ def run(

if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
raise ValueError(
"`run` supports only one positional argument."
)
return self(
args[0], callbacks=callbacks, tags=tags, metadata=metadata
)[_output_key]

if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
return self(
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
)[_output_key]

if not kwargs and not args:
raise ValueError(
Expand Down Expand Up @@ -513,7 +527,9 @@ async def arun(
)
elif args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
raise ValueError(
"`run` supports only one positional argument."
)
return (
await self.acall(
args[0], callbacks=callbacks, tags=tags, metadata=metadata
Expand Down Expand Up @@ -552,8 +568,10 @@ def dict(self, **kwargs: Any) -> Dict:
# -> {"_type": "foo", "verbose": False, ...}
"""
if self.memory is not None:
warnings.warn("Saving not supported for this chain type.", UserWarning)
return 'NotImplemented'
warnings.warn(
"Saving not supported for this chain type.", UserWarning
)
return {}

_dict = super().dict(**kwargs)
_dict["_type"] = self._chain_type
Expand Down

0 comments on commit 9977631

Please sign in to comment.