Skip to content

Commit

Permalink
Add various ruff rules
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Sep 27, 2024
1 parent 9eb26c5 commit 88f1e9f
Show file tree
Hide file tree
Showing 27 changed files with 103 additions and 52 deletions.
6 changes: 3 additions & 3 deletions libs/core/langchain_core/embeddings/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class FakeEmbeddings(Embeddings, BaseModel):
def _get_embedding(self) -> list[float]:
import numpy as np # type: ignore[import-not-found, import-untyped]

return list(np.random.normal(size=self.size))
return list(np.random.default_rng().normal(size=self.size))

def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self._get_embedding() for _ in texts]
Expand Down Expand Up @@ -109,8 +109,8 @@ def _get_embedding(self, seed: int) -> list[float]:
import numpy as np # type: ignore[import-not-found, import-untyped]

# set the seed for the random generator
np.random.seed(seed)
return list(np.random.normal(size=self.size))
rng = np.random.default_rng(seed)
return list(rng.normal(size=self.size))

def _get_seed(self, text: str) -> int:
"""Get a seed for the random generator, using the hash of the text."""
Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/language_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def with_structured_output(
"""Not implemented on this class."""
# Implement this on child class if there is a way of steering the model to
# generate responses that match a given schema.
raise NotImplementedError()
raise NotImplementedError

@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError()
raise NotImplementedError

async def _astream(
self,
Expand Down Expand Up @@ -1124,7 +1124,7 @@ def bind_tools(
tools: Sequence[Union[typing.Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError()
raise NotImplementedError

def with_structured_output(
self,
Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/language_models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def _stream(
Returns:
An iterator of GenerationChunks.
"""
raise NotImplementedError()
raise NotImplementedError

async def _astream(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def parse(self, text: str) -> Any:
Returns:
The parsed JSON object.
"""
raise NotImplementedError()
raise NotImplementedError


class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def parse(self, text: str) -> Any:
Returns:
The parsed tool calls.
"""
raise NotImplementedError()
raise NotImplementedError


class JsonOutputKeyToolsParser(JsonOutputToolsParser):
Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/output_parsers/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _diff(self, prev: Optional[T], next: T) -> T:
Returns:
The diff between the previous and current parsed output.
"""
raise NotImplementedError()
raise NotImplementedError

def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ def save(self, file_path: Union[Path, str]) -> None:
Args:
file_path: path to file.
"""
raise NotImplementedError()
raise NotImplementedError

def pretty_repr(self, html: bool = False) -> str:
"""Human-readable representation.
Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/prompts/few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,4 +467,4 @@ def pretty_repr(self, html: bool = False) -> str:
Returns:
A pretty representation of the prompt template.
"""
raise NotImplementedError()
raise NotImplementedError
2 changes: 1 addition & 1 deletion libs/core/langchain_core/prompts/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,4 @@ def pretty_repr(self, html: bool = False) -> str:
Returns:
A pretty representation of the prompt.
"""
raise NotImplementedError()
raise NotImplementedError
10 changes: 8 additions & 2 deletions libs/core/langchain_core/runnables/graph_mermaid.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import base64
import re
from dataclasses import asdict
Expand Down Expand Up @@ -291,9 +292,14 @@ async def _render_mermaid_using_pyppeteer(
img_bytes = await page.screenshot({"fullPage": False})
await browser.close()

def write_to_file(path: str, bytes: bytes) -> None:
with open(path, "wb") as file:
file.write(bytes)

if output_file_path is not None:
with open(output_file_path, "wb") as file:
file.write(img_bytes)
await asyncio.get_event_loop().run_in_executor(
None, write_to_file, output_file_path, img_bytes
)

return img_bytes

Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def get_secret_from_env() -> Optional[SecretStr]:
return SecretStr(os.environ[key])
if isinstance(default, str):
return SecretStr(default)
elif isinstance(default, type(None)):
elif default is None:
return None
else:
if error_message:
Expand Down
34 changes: 32 additions & 2 deletions libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,38 @@ python = ">=3.12.4"
[tool.poetry.extras]

[tool.ruff.lint]
select = [ "B", "C4", "E", "F", "I", "N", "PIE", "T201", "UP",]
ignore = [ "UP007",]
select = [
"ASYNC",
"B",
"C4",
"COM",
"DJ",
"E",
"EXE",
"F",
"FLY",
"FURB",
"I",
"ICN",
"INT",
"LOG",
"N",
"NPY",
"PD",
"PIE",
"Q",
"RSE",
"SLOT",
"T10",
"T201",
"TID",
"UP",
"YTT"
]
ignore = [
"COM812", # Messes with the formatter
"UP007", # Incompatible with pydantic + Python 3.9
]

[tool.coverage.run]
omit = [ "tests/*",]
Expand Down
4 changes: 2 additions & 2 deletions libs/core/tests/unit_tests/chat_history/test_chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def add_message(self, message: BaseMessage) -> None:

def clear(self) -> None:
"""Clear the store."""
raise NotImplementedError()
raise NotImplementedError

store: list[BaseMessage] = []
chat_history = SampleChatHistory(store=store)
Expand Down Expand Up @@ -50,7 +50,7 @@ def add_messages(self, message: Sequence[BaseMessage]) -> None:

def clear(self) -> None:
"""Clear the store."""
raise NotImplementedError()
raise NotImplementedError

chat_history = BulkAddHistory(store=store)
chat_history.add_message(HumanMessage(content="Hello"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _generate(
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
raise NotImplementedError()
raise NotImplementedError

def _stream(
self,
Expand Down Expand Up @@ -210,7 +210,7 @@ def _generate(
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
raise NotImplementedError()
raise NotImplementedError

async def _astream( # type: ignore
self,
Expand Down
4 changes: 2 additions & 2 deletions libs/core/tests/unit_tests/language_models/llms/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _generate(
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
raise NotImplementedError()
raise NotImplementedError

def _stream(
self,
Expand Down Expand Up @@ -198,7 +198,7 @@ def _generate(
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
raise NotImplementedError()
raise NotImplementedError

async def _astream(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class StrInvertCase(BaseTransformOutputParser[str]):

def parse(self, text: str) -> str:
"""Parse a single string into a specific format."""
raise NotImplementedError()
raise NotImplementedError

def parse_result(
self, result: list[Generation], *, partial: bool = False
Expand Down
16 changes: 8 additions & 8 deletions libs/core/tests/unit_tests/runnables/test_fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ def chain() -> Runnable:


def _raise_error(inputs: dict) -> str:
raise ValueError()
raise ValueError


def _dont_raise_error(inputs: dict) -> str:
if "exception" in inputs:
return "bar"
raise ValueError()
raise ValueError


@pytest.fixture()
Expand Down Expand Up @@ -99,11 +99,11 @@ def _runnable(inputs: dict) -> str:
if inputs["text"] == "foo":
return "first"
if "exception" not in inputs:
raise ValueError()
raise ValueError
if inputs["text"] == "bar":
return "second"
if isinstance(inputs["exception"], ValueError):
raise RuntimeError()
raise RuntimeError
return "third"


Expand Down Expand Up @@ -251,13 +251,13 @@ def _generate(input: Iterator) -> Iterator[str]:


def _generate_immediate_error(input: Iterator) -> Iterator[str]:
raise ValueError()
raise ValueError
yield ""


def _generate_delayed_error(input: Iterator) -> Iterator[str]:
yield ""
raise ValueError()
raise ValueError


def test_fallbacks_stream() -> None:
Expand All @@ -279,13 +279,13 @@ async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:


async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
raise ValueError()
raise ValueError
yield ""


async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
yield ""
raise ValueError()
raise ValueError


async def test_fallbacks_astream() -> None:
Expand Down
4 changes: 2 additions & 2 deletions libs/core/tests/unit_tests/runnables/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ class InvalidInputTypeRunnable(Runnable[int, int]):
@property
@override
def InputType(self) -> type:
raise TypeError()
raise TypeError

@override
def invoke(
Expand All @@ -380,7 +380,7 @@ class InvalidOutputTypeRunnable(Runnable[int, int]):
@property
@override
def OutputType(self) -> type:
raise TypeError()
raise TypeError

@override
def invoke(
Expand Down
23 changes: 18 additions & 5 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def test_with_types_with_type_generics() -> None:

def foo(x: int) -> None:
"""Add one to the input."""
raise NotImplementedError()
raise NotImplementedError

# Try specifying some
RunnableLambda(foo).with_types(
Expand Down Expand Up @@ -3967,7 +3967,7 @@ def __init__(self, fail_starts_with: str) -> None:
self.fail_starts_with = fail_starts_with

def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
raise NotImplementedError()
raise NotImplementedError

def _batch(
self,
Expand Down Expand Up @@ -4086,7 +4086,7 @@ def __init__(self, fail_starts_with: str) -> None:
self.fail_starts_with = fail_starts_with

def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
raise NotImplementedError()
raise NotImplementedError

async def _abatch(
self,
Expand Down Expand Up @@ -5335,7 +5335,7 @@ def on_end(run: Run) -> None:
assert value2 in shared_state.values(), "Value not found in the dictionary."


async def test_closing_iterator_doesnt_raise_error() -> None:
def test_closing_iterator_doesnt_raise_error() -> None:
"""Test that closing an iterator calls on_chain_end rather than on_chain_error."""
import time

Expand All @@ -5344,9 +5344,10 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
from langchain_core.output_parsers import StrOutputParser

on_chain_error_triggered = False
on_chain_end_triggered = False

class MyHandler(BaseCallbackHandler):
async def on_chain_error(
def on_chain_error(
self,
error: BaseException,
*,
Expand All @@ -5359,6 +5360,17 @@ async def on_chain_error(
nonlocal on_chain_error_triggered
on_chain_error_triggered = True

def on_chain_end(
self,
outputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
nonlocal on_chain_end_triggered
on_chain_end_triggered = True

llm = GenericFakeChatModel(messages=iter(["hi there"]))
chain = llm | StrOutputParser()
chain_ = chain.with_config({"callbacks": [MyHandler()]})
Expand All @@ -5369,6 +5381,7 @@ async def on_chain_error(
# Wait for a bit to make sure that the callback is called.
time.sleep(0.05)
assert on_chain_error_triggered is False
assert on_chain_end_triggered is True


def test_pydantic_protected_namespaces() -> None:
Expand Down
Loading

0 comments on commit 88f1e9f

Please sign in to comment.