Skip to content

Commit

Permalink
Raise exception in tests when blocking code is called in event loop
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Oct 28, 2024
1 parent 924e02f commit 6773382
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def build_component_menu_list(self, file_paths):

async def process_file_async(self, file_path):
try:
file_content = self.read_file_content(file_path)
file_content = await asyncio.to_thread(self.read_file_content, file_path)
except Exception: # noqa: BLE001
logger.exception(f"Error while reading file {file_path}")
return False, f"Could not read {file_path}"
Expand Down
4 changes: 3 additions & 1 deletion src/backend/base/langflow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ async def dispatch(self, request: Request, call_next):


def get_lifespan(*, fix_migration=False, version=None):
telemetry_service = get_telemetry_service()

def _initialize():
initialize_services(fix_migration=fix_migration)
setup_llm_caching()
Expand All @@ -104,7 +106,7 @@ async def lifespan(_app: FastAPI):
await asyncio.to_thread(_initialize)
all_types_dict = await get_and_cache_all_types_dict(get_settings_service())
await asyncio.to_thread(create_or_update_starter_projects, all_types_dict)
get_telemetry_service().start()
telemetry_service.start()
await asyncio.to_thread(load_flows_from_directory)
yield
except Exception as exc:
Expand Down
6 changes: 4 additions & 2 deletions src/backend/base/langflow/services/telemetry/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, settings_service: SettingsService):
self._stopping = False

self.ot = OpenTelemetry(prometheus_enabled=settings_service.settings.prometheus_enabled)
self.architecture = None

# Check for do-not-track settings
self.do_not_track = (
Expand Down Expand Up @@ -93,15 +94,16 @@ async def _queue_event(self, payload) -> None:
async def log_package_version(self) -> None:
python_version = ".".join(platform.python_version().split(".")[:2])
version_info = get_version_info()
architecture = platform.architecture()[0]
if self.architecture is None:
self.architecture = (await asyncio.to_thread(platform.architecture))[0]
payload = VersionPayload(
package=version_info["package"].lower(),
version=version_info["version"],
platform=platform.platform(),
python=python_version,
cache_type=self.settings_service.settings.cache_type,
backend_only=self.settings_service.settings.backend_only,
arch=architecture,
arch=self.architecture,
auto_login=self.settings_service.auth_settings.AUTO_LOGIN,
)
await self._queue_event((self.send_telemetry_data, payload, None))
Expand Down
114 changes: 114 additions & 0 deletions src/backend/tests/blockbuster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import asyncio
import inspect
import io
import os
import socket
import ssl
import sys
from importlib.abc import FileLoader

import forbiddenfruit


class BlockingError(Exception): ...


def _raise_for_function(func):
if inspect.isbuiltin(func):
msg = f"Blocking call to {func.__qualname__} ({func.__self__})"
elif inspect.ismethoddescriptor(func):
msg = f"Blocking call to {func}"
else:
msg = f"Blocking call to {func.__module__}.{func.__qualname__}"
raise BlockingError(msg)


def _raise_if_blocking(func):
def wrapper(*args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(*args, **kwargs)
_raise_for_function(func)
return None

return wrapper


def _wrap_os_blocking(func):
def os_op(fd, *args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(fd, *args, **kwargs)
if os.get_blocking(fd):
_raise_for_function(func)
return func(fd, *args, **kwargs)

return os_op


def _wrap_socket_blocking(func):
def socket_op(self, *args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(self, *args, **kwargs)
if self.getblocking():
msg = f"Blocking call to {func.__module__}.{func.__qualname__}"
raise BlockingError(msg)
return func(self, *args, **kwargs)

return socket_op


def _wrap_file_blocking(func):
def file_op(self, *args, **kwargs):
try:
asyncio.get_running_loop()
except RuntimeError:
return func(self, *args, **kwargs)
# Ignore blocking reads in import file loader
inspect.stack()
for frame_info in inspect.stack():
if isinstance(frame_info.frame.f_locals.get("self"), FileLoader):
return func(self, *args, **kwargs)
if frame_info.filename.endswith("_pytest/assertion/rewrite.py") and frame_info.function == "_read_pyc":
return func(self, *args, **kwargs)
if self.fileno() not in [sys.stdout.fileno(), sys.stderr.fileno()]:
_raise_for_function(func)
return func(self, *args, **kwargs)

return file_op


def init():
# time.sleep = _raise_if_blocking(time.sleep)

os.read = _wrap_os_blocking(os.read)
os.write = _wrap_os_blocking(os.write)

socket.socket.send = _wrap_socket_blocking(socket.socket.send)
socket.socket.sendall = _wrap_socket_blocking(socket.socket.sendall)
socket.socket.sendto = _wrap_socket_blocking(socket.socket.sendto)
socket.socket.recv = _wrap_socket_blocking(socket.socket.recv)
socket.socket.recv_into = _wrap_socket_blocking(socket.socket.recv_into)
socket.socket.recvfrom = _wrap_socket_blocking(socket.socket.recvfrom)
socket.socket.recvfrom_into = _wrap_socket_blocking(socket.socket.recvfrom_into)
socket.socket.recvmsg = _wrap_socket_blocking(socket.socket.recvmsg)
socket.socket.recvmsg_into = _wrap_socket_blocking(socket.socket.recvmsg_into)

ssl.SSLSocket.write = _wrap_socket_blocking(ssl.SSLSocket.write)
ssl.SSLSocket.send = _wrap_socket_blocking(ssl.SSLSocket.send)
ssl.SSLSocket.read = _wrap_socket_blocking(ssl.SSLSocket.read)
ssl.SSLSocket.recv = _wrap_socket_blocking(ssl.SSLSocket.recv)

forbiddenfruit.curse(io.BufferedReader, "read", _wrap_file_blocking(io.BufferedReader.read))
forbiddenfruit.curse(io.BufferedWriter, "read", _wrap_file_blocking(io.BufferedWriter.read))
forbiddenfruit.curse(io.BufferedWriter, "write", _wrap_file_blocking(io.BufferedWriter.write))
forbiddenfruit.curse(io.BufferedRandom, "read", _wrap_file_blocking(io.BufferedRandom.read))
forbiddenfruit.curse(io.BufferedRandom, "write", _wrap_file_blocking(io.BufferedRandom.write))
forbiddenfruit.curse(io.TextIOWrapper, "read", _wrap_file_blocking(io.TextIOWrapper.read))
forbiddenfruit.curse(io.TextIOWrapper, "write", _wrap_file_blocking(io.TextIOWrapper.write))
forbiddenfruit.curse(io.FileIO, "read", _wrap_file_blocking(io.FileIO.read))
forbiddenfruit.curse(io.FileIO, "write", _wrap_file_blocking(io.FileIO.write))
40 changes: 24 additions & 16 deletions src/backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import shutil

Expand Down Expand Up @@ -32,13 +33,15 @@
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner

from tests import blockbuster
from tests.api_keys import get_openai_api_key

if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService


load_dotenv()
blockbuster.init()


def pytest_configure(config):
Expand Down Expand Up @@ -286,23 +289,28 @@ async def client_fixture(
if "noclient" in request.keywords:
yield
else:
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
if "load_flows" in request.keywords:
shutil.copyfile(
pytest.BASIC_EXAMPLE_PATH, Path(load_flows_dir) / "c54f9130-f2fa-4a3e-b22a-3856d946351b.json"
)
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")

from langflow.main import create_app

app = create_app()
db_service = get_db_service()
db_service.database_url = f"sqlite:///{db_path}"
db_service.reload_engine()
def init_app():
db_dir = tempfile.mkdtemp()
db_path = Path(db_dir) / "test.db"
monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}")
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false")
if "load_flows" in request.keywords:
shutil.copyfile(
pytest.BASIC_EXAMPLE_PATH, Path(load_flows_dir) / "c54f9130-f2fa-4a3e-b22a-3856d946351b.json"
)
monkeypatch.setenv("LANGFLOW_LOAD_FLOWS_PATH", load_flows_dir)
monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "true")

from langflow.main import create_app

app = create_app()
db_service = get_db_service()
db_service.database_url = f"sqlite:///{db_path}"
db_service.reload_engine()
return app, db_path

app, db_path = await asyncio.to_thread(init_app)
# app.dependency_overrides[get_session] = get_session_override
async with (
LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager,
Expand Down

0 comments on commit 6773382

Please sign in to comment.