Skip to content

Commit

Permalink
Raise exception in tests when blocking code is called in event loop
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Nov 2, 2024
1 parent df22924 commit 1ce2a6b
Show file tree
Hide file tree
Showing 13 changed files with 597 additions and 49 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ dev-dependencies = [
"asgi-lifespan>=2.1.0",
"pytest-github-actions-annotate-failures>=0.2.0",
"pytest-codspeed>=3.0.0",
"forbiddenfruit>=0.1.4",
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def build_component_menu_list(self, file_paths):

async def process_file_async(self, file_path):
try:
file_content = self.read_file_content(file_path)
file_content = await asyncio.to_thread(self.read_file_content, file_path)
except Exception: # noqa: BLE001
logger.exception(f"Error while reading file {file_path}")
return False, f"Could not read {file_path}"
Expand Down
4 changes: 3 additions & 1 deletion src/backend/base/langflow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ async def dispatch(self, request: Request, call_next):


def get_lifespan(*, fix_migration=False, version=None):
telemetry_service = get_telemetry_service()

def _initialize():
initialize_services(fix_migration=fix_migration)
setup_llm_caching()
Expand All @@ -104,7 +106,7 @@ async def lifespan(_app: FastAPI):
await asyncio.to_thread(_initialize)
all_types_dict = await get_and_cache_all_types_dict(get_settings_service())
await asyncio.to_thread(create_or_update_starter_projects, all_types_dict)
get_telemetry_service().start()
telemetry_service.start()
await asyncio.to_thread(load_flows_from_directory)
yield
except Exception as exc:
Expand Down
6 changes: 4 additions & 2 deletions src/backend/base/langflow/services/telemetry/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, settings_service: SettingsService):
self._stopping = False

self.ot = OpenTelemetry(prometheus_enabled=settings_service.settings.prometheus_enabled)
self.architecture: str | None = None

# Check for do-not-track settings
self.do_not_track = (
Expand Down Expand Up @@ -93,15 +94,16 @@ async def _queue_event(self, payload) -> None:
async def log_package_version(self) -> None:
python_version = ".".join(platform.python_version().split(".")[:2])
version_info = get_version_info()
architecture = platform.architecture()[0]
if self.architecture is None:
self.architecture = (await asyncio.to_thread(platform.architecture))[0]
payload = VersionPayload(
package=version_info["package"].lower(),
version=version_info["version"],
platform=platform.platform(),
python=python_version,
cache_type=self.settings_service.settings.cache_type,
backend_only=self.settings_service.settings.backend_only,
arch=architecture,
arch=self.architecture,
auto_login=self.settings_service.auth_settings.AUTO_LOGIN,
)
await self._queue_event((self.send_telemetry_data, payload, None))
Expand Down
6 changes: 3 additions & 3 deletions src/backend/base/langflow/services/tracing/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def _reset_io(self) -> None:
async def initialize_tracers(self) -> None:
try:
await self.start()
self._initialize_langsmith_tracer()
self._initialize_langwatch_tracer()
self._initialize_langfuse_tracer()
await asyncio.to_thread(self._initialize_langsmith_tracer)
await asyncio.to_thread(self._initialize_langwatch_tracer)
await asyncio.to_thread(self._initialize_langfuse_tracer)
except Exception: # noqa: BLE001
logger.opt(exception=True).debug("Error initializing tracers")

Expand Down
138 changes: 138 additions & 0 deletions src/backend/tests/blockbuster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import asyncio
import inspect
import io
import os
import socket
import ssl
import sys
import time
from importlib.abc import FileLoader

import forbiddenfruit


class BlockingError(Exception):
    """Raised when a blocking call is made while an asyncio event loop is running."""


def _blocking_error(func):
    """Build a ``BlockingError`` naming the blocking callable *func*.

    The description varies with the kind of callable so the message points
    at something a developer can actually find (builtins carry their
    ``__self__``, descriptors print themselves, plain functions get their
    dotted path).
    """
    if inspect.isbuiltin(func):
        described = f"{func.__qualname__} ({func.__self__})"
    elif inspect.ismethoddescriptor(func):
        described = f"{func}"
    else:
        described = f"{func.__module__}.{func.__qualname__}"
    return BlockingError(f"Blocking call to {described}")


def _wrap_blocking(func):
def wrapper(*args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(*args, **kwargs)
raise _blocking_error(func)

return wrapper


def _wrap_time_blocking(func):
def wrapper(*args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(*args, **kwargs)
for frame_info in inspect.stack():
if frame_info.filename.endswith("pydev/pydevd.py") and frame_info.function == "_do_wait_suspend":
return func(*args, **kwargs)

raise _blocking_error(func)

return wrapper


def _wrap_os_blocking(func):
def os_op(fd, *args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(fd, *args, **kwargs)
if os.get_blocking(fd):
raise _blocking_error(func)
return func(fd, *args, **kwargs)

return os_op


def _wrap_socket_blocking(func):
def socket_op(self, *args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(self, *args, **kwargs)
if self.getblocking():
raise _blocking_error(func)
return func(self, *args, **kwargs)

return socket_op


def _wrap_file_read_blocking(func):
def file_op(self, *args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(self, *args, **kwargs)
for frame_info in inspect.stack():
if isinstance(frame_info.frame.f_locals.get("self"), FileLoader):
return func(self, *args, **kwargs)
if frame_info.filename.endswith("_pytest/assertion/rewrite.py") and frame_info.function in [
"_rewrite_test",
"_read_pyc",
"_write_pyc",
]:
return func(self, *args, **kwargs)
raise _blocking_error(func)

return file_op


def _wrap_file_write_blocking(func):
def file_op(self, *args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(self, *args, **kwargs)
if self not in [sys.stdout, sys.stderr]:
raise _blocking_error(func)
return func(self, *args, **kwargs)

return file_op


def init():
    """Monkeypatch blocking stdlib entry points so they raise ``BlockingError``
    when invoked while an asyncio event loop is running in the current thread.

    Regular module/class attributes are replaced by assignment; methods of
    built-in ``io`` types are immutable, so ``forbiddenfruit.curse`` patches
    those in place.
    """
    time.sleep = _wrap_time_blocking(time.sleep)

    for fd_func_name in ("read", "write"):
        setattr(os, fd_func_name, _wrap_os_blocking(getattr(os, fd_func_name)))

    socket_method_names = (
        "send",
        "sendall",
        "sendto",
        "recv",
        "recv_into",
        "recvfrom",
        "recvfrom_into",
        "recvmsg",
        "recvmsg_into",
    )
    for method_name in socket_method_names:
        setattr(socket.socket, method_name, _wrap_socket_blocking(getattr(socket.socket, method_name)))

    for method_name in ("write", "send", "read", "recv"):
        setattr(ssl.SSLSocket, method_name, _wrap_socket_blocking(getattr(ssl.SSLSocket, method_name)))

    curse = forbiddenfruit.curse
    curse(io.BufferedReader, "read", _wrap_file_read_blocking(io.BufferedReader.read))
    curse(io.BufferedWriter, "write", _wrap_file_write_blocking(io.BufferedWriter.write))
    # NOTE(review): BufferedRandom.read gets the plain wrapper (no importlib /
    # pytest-rewrite exemptions) unlike BufferedReader.read -- confirm this
    # asymmetry is intended.
    curse(io.BufferedRandom, "read", _wrap_blocking(io.BufferedRandom.read))
    curse(io.BufferedRandom, "write", _wrap_file_write_blocking(io.BufferedRandom.write))
    curse(io.TextIOWrapper, "read", _wrap_file_read_blocking(io.TextIOWrapper.read))
    curse(io.TextIOWrapper, "write", _wrap_file_write_blocking(io.TextIOWrapper.write))
40 changes: 24 additions & 16 deletions src/backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import shutil

Expand Down Expand Up @@ -32,13 +33,15 @@
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner

from tests import blockbuster
from tests.api_keys import get_openai_api_key

if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService


load_dotenv()
blockbuster.init()


def pytest_configure(config):
Expand Down Expand Up @@ -286,23 +289,28 @@ async def client_fixture(
if "noclient" in request.keywords:
yield
else:
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
if "load_flows" in request.keywords:
shutil.copyfile(
pytest.BASIC_EXAMPLE_PATH, Path(load_flows_dir) / "c54f9130-f2fa-4a3e-b22a-3856d946351b.json"
)
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")

from langflow.main import create_app

app = create_app()
db_service = get_db_service()
db_service.database_url = f"sqlite:///{db_path}"
db_service.reload_engine()
def init_app():
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
if "load_flows" in request.keywords:
shutil.copyfile(
pytest.BASIC_EXAMPLE_PATH, Path(load_flows_dir) / "c54f9130-f2fa-4a3e-b22a-3856d946351b.json"
)
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")

from langflow.main import create_app

app = create_app()
db_service = get_db_service()
db_service.database_url = f"sqlite:///{db_path}"
db_service.reload_engine()
return app, db_path

app, db_path = await asyncio.to_thread(init_app)
# app.dependency_overrides[get_session] = get_session_override
async with (
LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager,
Expand Down
11 changes: 6 additions & 5 deletions src/backend/tests/unit/graph/graph/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from collections import deque

Expand All @@ -13,7 +14,7 @@
async def test_graph_not_prepared():
chat_input = ChatInput()
chat_output = ChatOutput()
graph = Graph()
graph = await asyncio.to_thread(Graph)
graph.add_component(chat_input)
graph.add_component(chat_output)
with pytest.raises(ValueError, match="Graph not prepared"):
Expand All @@ -23,7 +24,7 @@ async def test_graph_not_prepared():
async def test_graph(caplog: pytest.LogCaptureFixture):
chat_input = ChatInput()
chat_output = ChatOutput()
graph = Graph()
graph = await asyncio.to_thread(Graph)
graph.add_component(chat_input)
graph.add_component(chat_output)
caplog.clear()
Expand All @@ -35,7 +36,7 @@ async def test_graph(caplog: pytest.LogCaptureFixture):
async def test_graph_with_edge():
chat_input = ChatInput()
chat_output = ChatOutput()
graph = Graph()
graph = await asyncio.to_thread(Graph)
input_id = graph.add_component(chat_input)
output_id = graph.add_component(chat_output)
graph.add_component_edge(input_id, (chat_input.outputs[0].name, chat_input.inputs[0].name), output_id)
Expand All @@ -56,7 +57,7 @@ async def test_graph_functional():
chat_input = ChatInput(_id="chat_input")
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
graph = Graph(chat_input, chat_output)
graph = await asyncio.to_thread(Graph, chat_input, chat_output)
assert graph._run_queue == deque(["chat_input"])
await graph.astep()
assert graph._run_queue == deque(["chat_output"])
Expand All @@ -71,7 +72,7 @@ async def test_graph_functional_async_start():
chat_input = ChatInput(_id="chat_input")
chat_output = ChatOutput(input_value="test", _id="chat_output")
chat_output.set(sender_name=chat_input.message_response)
graph = Graph(chat_input, chat_output)
graph = await asyncio.to_thread(Graph, chat_input, chat_output)
# Now iterate through the graph
# and check that the graph is running
# correctly
Expand Down
3 changes: 2 additions & 1 deletion src/backend/tests/unit/test_database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
from typing import NamedTuple
from uuid import UUID, uuid4
Expand Down Expand Up @@ -604,7 +605,7 @@ async def test_delete_nonexistent_flow(client: AsyncClient, logged_in_headers):
@pytest.mark.usefixtures("active_user")
async def test_read_only_starter_projects(client: AsyncClient, logged_in_headers):
response = await client.get("api/v1/flows/basic_examples/", headers=logged_in_headers)
starter_projects = load_starter_projects()
starter_projects = await asyncio.to_thread(load_starter_projects)
assert response.status_code == 200
assert len(response.json()) == len(starter_projects)

Expand Down
34 changes: 20 additions & 14 deletions src/backend/tests/unit/test_files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import re
import shutil
import tempfile
Expand Down Expand Up @@ -37,20 +38,25 @@ async def files_client_fixture(
if "noclient" in request.keywords:
yield
else:
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
if "load_flows" in request.keywords:
shutil.copyfile(
pytest.BASIC_EXAMPLE_PATH, Path(load_flows_dir) / "c54f9130-f2fa-4a3e-b22a-3856d946351b.json"
)
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")

from langflow.main import create_app

app = create_app()

def init_app():
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
if "load_flows" in request.keywords:
shutil.copyfile(
pytest.BASIC_EXAMPLE_PATH, Path(load_flows_dir) / "c54f9130-f2fa-4a3e-b22a-3856d946351b.json"
)
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")

from langflow.main import create_app

app = create_app()
return app, db_path

app, db_path = await asyncio.to_thread(init_app)

app.dependency_overrides[get_storage_service] = lambda: mock_storage_service
async with (
Expand Down
Loading

0 comments on commit 1ce2a6b

Please sign in to comment.