Commit 7d980bb

Add ruff rules for flake8-simplify (SIM)

cbornet committed Sep 25, 2024
1 parent 51c4393 commit 7d980bb
Showing 46 changed files with 282 additions and 389 deletions.
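flake8-simplify contributes ruff's SIM rules, which flag code that can be written more simply: try/except/pass blocks, if/else assignments, collapsible nested ifs, redundant dict.get defaults, iteration over dict.keys(), and similar patterns. The hunks below are largely mechanical applications of those rewrites; short notes after the affected files name the rule that appears to drive each change.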
5 changes: 2 additions & 3 deletions libs/core/langchain_core/_api/beta_decorator.py
@@ -127,10 +127,9 @@ async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:

             def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
                 """Finalize the annotation of a class."""
-                try:
+                # Can't set new_doc on some extension objects.
+                with contextlib.suppress(AttributeError):
                     obj.__doc__ = new_doc
-                except AttributeError:  # Can't set on some extension objects.
-                    pass

             def warn_if_direct_instance(
                 self: Any, *args: Any, **kwargs: Any
5 changes: 2 additions & 3 deletions libs/core/langchain_core/_api/deprecation.py
@@ -198,10 +198,9 @@ async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:

             def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
                 """Finalize the deprecation of a class."""
-                try:
+                # Can't set new_doc on some extension objects.
+                with contextlib.suppress(AttributeError):
                     obj.__doc__ = new_doc
-                except AttributeError:  # Can't set on some extension objects.
-                    pass

             def warn_if_direct_instance(
                 self: Any, *args: Any, **kwargs: Any
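The two _api hunks above are ruff's SIM105 (suppressible-exception): a try/except that exists only to swallow one exception type is replaced by contextlib.suppress. A minimal sketch of the rewrite, with illustrative names not taken from the diff:

    import contextlib
    import os

    # Before - flagged by SIM105:
    try:
        os.remove("scratch.tmp")
    except FileNotFoundError:
        pass

    # After - equivalent, and the intent is explicit:
    with contextlib.suppress(FileNotFoundError):
        os.remove("scratch.tmp")

suppress also accepts several exception types at once, e.g. contextlib.suppress(OSError, ValueError).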
2 changes: 1 addition & 1 deletion libs/core/langchain_core/callbacks/file.py
@@ -27,7 +27,7 @@ def __init__(
             mode: The mode to open the file in. Defaults to "a".
             color: The color to use for the text. Defaults to None.
         """
-        self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))
+        self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))  # noqa: SIM115
         self.color = color

     def __del__(self) -> None:
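The # noqa here addresses SIM115 (open-file-with-context-handler), which wants open() wrapped in a with block. This callback intentionally keeps the file open for its lifetime and closes it in __del__, so the rule is suppressed rather than applied. For the common case the fix looks like this (illustrative file name):

    # Flagged by SIM115 - the handle leaks if write() raises:
    f = open("run.log", "a", encoding="utf-8")
    f.write("started\n")
    f.close()

    # Preferred - the file is closed even on error:
    with open("run.log", "a", encoding="utf-8") as f:
        f.write("started\n")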
10 changes: 5 additions & 5 deletions libs/core/langchain_core/callbacks/manager.py
@@ -2252,14 +2252,14 @@ def _configure(
         else:
             parent_run_id_ = inheritable_callbacks.parent_run_id
         # Break ties between the external tracing context and inherited context
-        if parent_run_id is not None:
-            if parent_run_id_ is None:
-                parent_run_id_ = parent_run_id
+        if parent_run_id is not None and (
+            parent_run_id_ is None
             # If the LC parent has already been reflected
             # in the run tree, we know the run_tree is either the
             # same parent or a child of the parent.
-            elif run_tree and str(parent_run_id_) in run_tree.dotted_order:
-                parent_run_id_ = parent_run_id
+            or (run_tree and str(parent_run_id_) in run_tree.dotted_order)
+        ):
+            parent_run_id_ = parent_run_id
         # Otherwise, we assume the LC context has progressed
         # beyond the run tree and we should not inherit the parent.
         callback_manager = callback_manager_cls(
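Here an if and an elif assigned the same value, the pattern behind SIM114 (if-with-same-arms); folding the outer not-None guard into the same condition is SIM102 (collapsible-if). A reduced sketch with placeholder conditions:

    # Before - flagged by SIM114 (identical branch bodies):
    def label(a: bool, b: bool) -> str:
        if a:
            return "match"
        elif b:
            return "match"
        return "miss"

    # After - conditions joined with `or`, one branch:
    def label_simplified(a: bool, b: bool) -> str:
        if a or b:
            return "match"
        return "miss"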
11 changes: 5 additions & 6 deletions libs/core/langchain_core/exceptions.py
@@ -40,12 +40,11 @@ def __init__(
         send_to_llm: bool = False,
     ):
         super().__init__(error)
-        if send_to_llm:
-            if observation is None or llm_output is None:
-                raise ValueError(
-                    "Arguments 'observation' & 'llm_output'"
-                    " are required if 'send_to_llm' is True"
-                )
+        if send_to_llm and (observation is None or llm_output is None):
+            raise ValueError(
+                "Arguments 'observation' & 'llm_output'"
+                " are required if 'send_to_llm' is True"
+            )
         self.observation = observation
         self.llm_output = llm_output
         self.send_to_llm = send_to_llm
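A nested if whose outer block holds nothing else is SIM102 (collapsible-if): the two tests merge into one condition with and, dropping a nesting level. A sketch with illustrative arguments:

    from typing import Optional

    # Before - flagged by SIM102:
    def validate(send_to_llm: bool, observation: Optional[str]) -> None:
        if send_to_llm:
            if observation is None:
                raise ValueError("'observation' is required if 'send_to_llm' is True")

    # After - same control flow, flatter:
    def validate_flat(send_to_llm: bool, observation: Optional[str]) -> None:
        if send_to_llm and observation is None:
            raise ValueError("'observation' is required if 'send_to_llm' is True")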
2 changes: 1 addition & 1 deletion libs/core/langchain_core/indexing/api.py
@@ -92,7 +92,7 @@ def calculate_hashes(cls, values: dict[str, Any]) -> Any:
         values["metadata_hash"] = metadata_hash
         values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash))

-        _uid = values.get("uid", None)
+        _uid = values.get("uid")

         if _uid is None:
             values["uid"] = values["hash_"]
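None is already the default that dict.get returns, so spelling it out is redundant; this appears to be SIM910 (dict-get-with-none-default). Quick illustration:

    values = {"uid": "abc123"}

    # Before - flagged by SIM910:
    uid = values.get("uid", None)

    # After - identical behavior, including for missing keys:
    uid = values.get("uid")
    assert values.get("missing") is values.get("missing", None) is None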
39 changes: 12 additions & 27 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -802,10 +802,7 @@ def _generate_with_cache(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if isinstance(self.cache, BaseCache):
-            llm_cache = self.cache
-        else:
-            llm_cache = get_llm_cache()
+        llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
         # We should check the cache unless it's explicitly set to False
         # A None cache means we should use the default global cache
         # if it's configured.
@@ -879,10 +876,7 @@ async def _agenerate_with_cache(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if isinstance(self.cache, BaseCache):
-            llm_cache = self.cache
-        else:
-            llm_cache = get_llm_cache()
+        llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
         # We should check the cache unless it's explicitly set to False
         # A None cache means we should use the default global cache
         # if it's configured.
@@ -1054,10 +1048,7 @@ def call_as_llm(
     def predict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
         if isinstance(result.content, str):
             return result.content
@@ -1072,20 +1063,14 @@ def predict_messages(
         stop: Optional[Sequence[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return self(messages, stop=_stop, **kwargs)

     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
     async def apredict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         result = await self._call_async(
             [HumanMessage(content=text)], stop=_stop, **kwargs
         )
@@ -1102,10 +1087,7 @@ async def apredict_messages(
         stop: Optional[Sequence[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return await self._call_async(messages, stop=_stop, **kwargs)

     @property
@@ -1333,9 +1315,12 @@ def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
     if not isinstance(serialized, dict):
         return

-    if "type" in serialized and serialized["type"] == "not_implemented":
-        if "repr" in serialized:
-            del serialized["repr"]
+    if (
+        "type" in serialized
+        and serialized["type"] == "not_implemented"
+        and "repr" in serialized
+    ):
+        del serialized["repr"]

     if "graph" in serialized:
         del serialized["graph"]
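The ternary rewrites here (and in fake_chat_models.py and llms.py below) are SIM108 (if-else-block-instead-of-if-exp): a four-line if/else that only assigns a single variable collapses into a conditional expression; the final _cleanup_llm_representation hunk is another SIM102 collapse. The recurring stop-list normalization, sketched standalone:

    from typing import Optional, Sequence

    stop: Optional[Sequence[str]] = ("Observation:",)

    # Before - flagged by SIM108:
    if stop is None:
        _stop = None
    else:
        _stop = list(stop)

    # After - one line, same result:
    _stop = None if stop is None else list(stop)
    assert _stop == ["Observation:"]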
5 changes: 1 addition & 4 deletions libs/core/langchain_core/language_models/fake_chat_models.py
@@ -194,10 +194,7 @@ def _generate(
     ) -> ChatResult:
         """Top Level call"""
         message = next(self.messages)
-        if isinstance(message, str):
-            message_ = AIMessage(content=message)
-        else:
-            message_ = message
+        message_ = AIMessage(content=message) if isinstance(message, str) else message
         generation = ChatGeneration(message=message_)
         return ChatResult(generations=[generation])
25 changes: 5 additions & 20 deletions libs/core/langchain_core/language_models/llms.py
@@ -1305,10 +1305,7 @@ async def _call_async(
     def predict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return self(text, stop=_stop, **kwargs)

     @deprecated("0.1.7", alternative="invoke", removal="1.0")
@@ -1320,21 +1317,15 @@ def predict_messages(
         **kwargs: Any,
     ) -> BaseMessage:
         text = get_buffer_string(messages)
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         content = self(text, stop=_stop, **kwargs)
         return AIMessage(content=content)

     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
     async def apredict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return await self._call_async(text, stop=_stop, **kwargs)

     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@@ -1346,10 +1337,7 @@ async def apredict_messages(
         **kwargs: Any,
     ) -> BaseMessage:
         text = get_buffer_string(messages)
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         content = await self._call_async(text, stop=_stop, **kwargs)
         return AIMessage(content=content)

@@ -1384,10 +1372,7 @@ def save(self, file_path: Union[Path, str]) -> None:
             llm.save(file_path="path/llm.yaml")
         """
         # Convert file to Path object.
-        if isinstance(file_path, str):
-            save_path = Path(file_path)
-        else:
-            save_path = file_path
+        save_path = Path(file_path) if isinstance(file_path, str) else file_path

         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)
27 changes: 14 additions & 13 deletions libs/core/langchain_core/load/load.py
@@ -86,9 +86,9 @@ def __init__(

     def __call__(self, value: dict[str, Any]) -> Any:
         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "secret"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "secret"
+            and value.get("id") is not None
         ):
             [key] = value["id"]
             if key in self.secrets_map:
@@ -99,27 +99,28 @@ def __call__(self, value: dict[str, Any]) -> Any:
                 raise KeyError(f'Missing key "{key}" in load(secrets_map)')

         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "not_implemented"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "not_implemented"
+            and value.get("id") is not None
         ):
             raise NotImplementedError(
                 "Trying to load an object that doesn't implement "
                 f"serialization: {value}"
             )

         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "constructor"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "constructor"
+            and value.get("id") is not None
         ):
             [*namespace, name] = value["id"]
             mapping_key = tuple(value["id"])

-            if namespace[0] not in self.valid_namespaces:
-                raise ValueError(f"Invalid namespace: {value}")
-            # The root namespace ["langchain"] is not a valid identifier.
-            elif namespace == ["langchain"]:
+            if (
+                namespace[0] not in self.valid_namespaces
+                # The root namespace ["langchain"] is not a valid identifier.
+                or namespace == ["langchain"]
+            ):
                 raise ValueError(f"Invalid namespace: {value}")
             # Has explicit import path.
             elif mapping_key in self.import_mappings:
7 changes: 3 additions & 4 deletions libs/core/langchain_core/load/serializable.py
@@ -1,3 +1,4 @@
+import contextlib
 from abc import ABC
 from typing import (
     Any,
@@ -238,7 +239,7 @@ def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:

         # include all secrets, even if not specified in kwargs
         # as these secrets may be passed as an environment variable instead
-        for key in secrets.keys():
+        for key in secrets:
             secret_value = getattr(self, key, None) or lc_kwargs.get(key)
             if secret_value is not None:
                 lc_kwargs.update({key: secret_value})
@@ -357,8 +358,6 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
         "id": _id,
         "repr": None,
     }
-    try:
+    with contextlib.suppress(Exception):
         result["repr"] = repr(obj)
-    except Exception:
-        pass
     return result
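The .keys() change is SIM118 (in-dict-keys): iterating a dict already yields its keys, and membership tests work the same way. Sketch:

    secrets = {"api_key": "MY_API_KEY_ENV_VAR"}

    # Before - flagged by SIM118:
    for key in secrets.keys():
        print(key)

    # After - equivalent; the .keys() view adds nothing here:
    for key in secrets:
        print(key)
    assert "api_key" in secrets  # membership also needs no .keys()

The try/except/pass hunk is SIM105 again, which is why import contextlib lands at the top of the file.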
28 changes: 12 additions & 16 deletions libs/core/langchain_core/messages/utils.py
@@ -435,23 +435,22 @@ def filter_messages(
     messages = convert_to_messages(messages)
     filtered: list[BaseMessage] = []
     for msg in messages:
-        if exclude_names and msg.name in exclude_names:
-            continue
-        elif exclude_types and _is_message_type(msg, exclude_types):
-            continue
-        elif exclude_ids and msg.id in exclude_ids:
+        if (
+            (exclude_names and msg.name in exclude_names)
+            or (exclude_types and _is_message_type(msg, exclude_types))
+            or (exclude_ids and msg.id in exclude_ids)
+        ):
             continue
         else:
             pass

         # default to inclusion when no inclusion criteria given.
-        if not (include_types or include_ids or include_names):
-            filtered.append(msg)
-        elif include_names and msg.name in include_names:
-            filtered.append(msg)
-        elif include_types and _is_message_type(msg, include_types):
-            filtered.append(msg)
-        elif include_ids and msg.id in include_ids:
+        if (
+            not (include_types or include_ids or include_names)
+            or (include_names and msg.name in include_names)
+            or (include_types and _is_message_type(msg, include_types))
+            or (include_ids and msg.id in include_ids)
+        ):
             filtered.append(msg)
         else:
             pass
@@ -960,10 +959,7 @@ def _last_max_tokens(
     while messages and not _is_message_type(messages[-1], end_on):
         messages.pop()
     swapped_system = include_system and isinstance(messages[0], SystemMessage)
-    if swapped_system:
-        reversed_ = messages[:1] + messages[1:][::-1]
-    else:
-        reversed_ = messages[::-1]
+    reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1]

     reversed_ = _first_max_tokens(
         reversed_,
5 changes: 2 additions & 3 deletions libs/core/langchain_core/output_parsers/base.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import contextlib
 from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
@@ -311,8 +312,6 @@ def _type(self) -> str:
     def dict(self, **kwargs: Any) -> dict:
         """Return dictionary representation of output parser."""
         output_parser_dict = super().dict(**kwargs)
-        try:
+        with contextlib.suppress(NotImplementedError):
             output_parser_dict["_type"] = self._type
-        except NotImplementedError:
-            pass
         return output_parser_dict
10 changes: 4 additions & 6 deletions libs/core/langchain_core/output_parsers/json.py
@@ -49,12 +49,10 @@ def _diff(self, prev: Optional[Any], next: Any) -> Any:
         return jsonpatch.make_patch(prev, next).patch

     def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
-        if PYDANTIC_MAJOR_VERSION == 2:
-            if issubclass(pydantic_object, pydantic.BaseModel):
-                return pydantic_object.model_json_schema()
-            elif issubclass(pydantic_object, pydantic.v1.BaseModel):
-                return pydantic_object.schema()
-        return pydantic_object.schema()
+        if issubclass(pydantic_object, pydantic.BaseModel):
+            return pydantic_object.model_json_schema()
+        elif issubclass(pydantic_object, pydantic.v1.BaseModel):
+            return pydantic_object.schema()

     def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
         """Parse the result of an LLM call to a JSON object.