feat: Add AsyncSession support for non-blocking db operations #4408
The first changed file is the API key router:

@@ -3,7 +3,7 @@

 from fastapi import APIRouter, Depends, HTTPException, Response

-from langflow.api.utils import CurrentActiveUser, DbSession
+from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
Note: we can't replace DbSession with AsyncDbSession in save_store_api_key at the moment, because that route gets the current user through the sync session, so we can't modify the user with the async session (see the sketch after this hunk).
 from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
 from langflow.services.auth import utils as auth_utils

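As a hypothetical, self-contained illustration of that constraint (the model and sessions below are made up, not taken from this PR): an ORM instance loaded through one session cannot simply be handed to a second session, which is what would happen if the user from the sync dependency were persisted through the async session.

```python
# Illustration only: a SQLModel object attached to one Session cannot be
# added to another session, which mirrors mixing the sync current_user
# with an AsyncSession.
from sqlmodel import Field, Session, SQLModel, create_engine


class User(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    store_api_key: str | None = None


engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

with Session(engine) as sync_session:
    user = User()
    sync_session.add(user)
    sync_session.commit()
    sync_session.refresh(user)

    # `user` is now attached to sync_session; handing it to a second session
    # (sync or async) fails the same way:
    with Session(engine) as other_session:
        user.store_api_key = "new-key"
        other_session.add(user)  # raises InvalidRequestError: object is already attached to another Session
```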
@@ -20,12 +20,12 @@

 @router.get("/")
 async def get_api_keys_route(
-    db: DbSession,
+    db: AsyncDbSession,
     current_user: CurrentActiveUser,
 ) -> ApiKeysResponse:
     try:
         user_id = current_user.id
-        keys = get_api_keys(db, user_id)
+        keys = await get_api_keys(db, user_id)

         return ApiKeysResponse(total_count=len(keys), user_id=user_id, api_keys=keys)
     except Exception as exc:
@@ -36,22 +36,22 @@ async def get_api_keys_route(
 async def create_api_key_route(
     req: ApiKeyCreate,
     current_user: CurrentActiveUser,
-    db: DbSession,
+    db: AsyncDbSession,
 ) -> UnmaskedApiKeyRead:
     try:
         user_id = current_user.id
-        return create_api_key(db, req, user_id=user_id)
+        return await create_api_key(db, req, user_id=user_id)
     except Exception as e:
         raise HTTPException(status_code=400, detail=str(e)) from e


 @router.delete("/{api_key_id}", dependencies=[Depends(auth_utils.get_current_active_user)])
 async def delete_api_key_route(
     api_key_id: UUID,
-    db: DbSession,
+    db: AsyncDbSession,
 ):
     try:
-        delete_api_key(db, api_key_id)
+        await delete_api_key(db, api_key_id)
     except Exception as e:
         raise HTTPException(status_code=400, detail=str(e)) from e
     return {"detail": "API Key deleted"}
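These route changes imply that get_api_keys, create_api_key, and delete_api_key are now coroutines that work on the async session. They are not shown in this excerpt; as a rough sketch of the pattern (the model import path below is an assumption):

```python
# Sketch only -- the real helper lives elsewhere in the PR and may differ.
from uuid import UUID

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.database.models.api_key import ApiKey  # assumed path


async def get_api_keys(session: AsyncSession, user_id: UUID) -> list[ApiKey]:
    # Same query as before; only the execution is awaited, so the event loop
    # is free while the database round-trip is in flight.
    result = await session.exec(select(ApiKey).where(ApiKey.user_id == user_id))
    return list(result.all())
```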
The second changed file is the database service, which manages the engines and sessions:
@@ -1,8 +1,9 @@
 from __future__ import annotations

 import asyncio
+import sqlite3
 import time
-from contextlib import contextmanager
+from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import TYPE_CHECKING

@@ -14,7 +15,9 @@
 from sqlalchemy import event, inspect
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import OperationalError
+from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
 from sqlmodel import Session, SQLModel, create_engine, select, text
+from sqlmodel.ext.asyncio.session import AsyncSession

 from langflow.services.base import Service
 from langflow.services.database import models
@@ -39,12 +42,17 @@ def __init__(self, settings_service: SettingsService):
             msg = "No database URL provided"
             raise ValueError(msg)
         self.database_url: str = settings_service.settings.database_url
+        self._sanitize_database_url()
         # This file is in langflow.services.database.manager.py
         # the ini is in langflow
         langflow_dir = Path(__file__).parent.parent.parent
         self.script_location = langflow_dir / "alembic"
         self.alembic_cfg_path = langflow_dir / "alembic.ini"
+        # register the event listener for sqlite as part of this class.
+        # Using decorator will make the method not able to use self
+        event.listen(Engine, "connect", self.on_connection)
         self.engine = self._create_engine()
+        self.async_engine = self._create_async_engine()
         alembic_log_file = self.settings_service.settings.alembic_log_file

         # Check if the provided path is absolute, cross-platform.
@@ -56,10 +64,47 @@ def __init__(self, settings_service: SettingsService):
             self.alembic_log_path = Path(langflow_dir) / alembic_log_file

     def reload_engine(self) -> None:
+        self._sanitize_database_url()
         self.engine = self._create_engine()
+        self.async_engine = self._create_async_engine()

+    def _sanitize_database_url(self):
+        if self.database_url.startswith("postgres://"):
+            self.database_url = self.database_url.replace("postgres://", "postgresql://")
+            logger.warning(
+                "Fixed postgres dialect in database URL. Replacing postgres:// with postgresql://. "
+                "To avoid this warning, update the database URL."
+            )
+
     def _create_engine(self) -> Engine:
         """Create the engine for the database."""
+        return create_engine(
+            self.database_url,
+            connect_args=self._get_connect_args(),
+            pool_size=self.settings_service.settings.pool_size,
+            max_overflow=self.settings_service.settings.max_overflow,
+        )
+
+    def _create_async_engine(self) -> AsyncEngine:
+        """Create the engine for the database."""
+        url_components = self.database_url.split("://", maxsplit=1)
+        if url_components[0].startswith("sqlite"):
+            database_url = "sqlite+aiosqlite://"
+            kwargs = {}
Note: aiosqlite doesn't have a connection pool, which is why pool_size and max_overflow are only passed for non-SQLite engines.
+        else:
+            kwargs = {
+                "pool_size": self.settings_service.settings.pool_size,
+                "max_overflow": self.settings_service.settings.max_overflow,
+            }
+            database_url = "postgresql+psycopg://" if url_components[0].startswith("postgresql") else url_components[0]
+        database_url += url_components[1]
+        return create_async_engine(
+            database_url,
+            connect_args=self._get_connect_args(),
+            **kwargs,
+        )
+
+    def _get_connect_args(self):
         if self.settings_service.settings.database_url and self.settings_service.settings.database_url.startswith(
             "sqlite"
         ):
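The effect of the URL handling in _create_async_engine is to swap in an async driver before creating the engine; restated as a standalone helper (illustrative only, not code from the PR):

```python
# Condensed restatement of the URL rewriting done in _create_async_engine.
def to_async_url(database_url: str) -> str:
    scheme, rest = database_url.split("://", maxsplit=1)
    if scheme.startswith("sqlite"):
        scheme = "sqlite+aiosqlite"      # async SQLite driver
    elif scheme.startswith("postgresql"):
        scheme = "postgresql+psycopg"    # psycopg 3, which supports asyncio
    return f"{scheme}://{rest}"


assert to_async_url("sqlite:///./langflow.db") == "sqlite+aiosqlite:///./langflow.db"
assert to_async_url("postgresql://user:pw@host:5432/db") == "postgresql+psycopg://user:pw@host:5432/db"
```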
@@ -69,33 +114,12 @@ def _create_engine(self) -> Engine:
             }
         else:
             connect_args = {}
-        try:
-            # register the event listener for sqlite as part of this class.
-            # Using decorator will make the method not able to use self
-            event.listen(Engine, "connect", self.on_connection)
-
-            return create_engine(
-                self.database_url,
-                connect_args=connect_args,
-                pool_size=self.settings_service.settings.pool_size,
-                max_overflow=self.settings_service.settings.max_overflow,
-            )
-        except sa.exc.NoSuchModuleError as exc:
-            if "postgres" in str(exc) and not self.database_url.startswith("postgresql"):
-                # https://stackoverflow.com/questions/62688256/sqlalchemy-exc-nosuchmoduleerror-cant-load-plugin-sqlalchemy-dialectspostgre
-                self.database_url = self.database_url.replace("postgres://", "postgresql://")
-                logger.warning(
-                    "Fixed postgres dialect in database URL. Replacing postgres:// with postgresql://. "
-                    "To avoid this warning, update the database URL."
-                )
-                return self._create_engine()
-            msg = "Error creating database engine"
-            raise RuntimeError(msg) from exc
+        return connect_args

     def on_connection(self, dbapi_connection, _connection_record) -> None:
-        from sqlite3 import Connection as sqliteConnection
-
-        if isinstance(dbapi_connection, sqliteConnection):
+        if isinstance(
+            dbapi_connection, sqlite3.Connection | sa.dialects.sqlite.aiosqlite.AsyncAdapt_aiosqlite_connection
+        ):
             pragmas: dict = self.settings_service.settings.sqlite_pragmas or {}
             pragmas_list = []
             for key, val in pragmas.items():
@@ -117,6 +141,11 @@ def with_session(self):
         with Session(self.engine) as session:
             yield session

+    @asynccontextmanager
+    async def with_async_session(self):
+        async with AsyncSession(self.async_engine) as session:
+            yield session
+
     def migrate_flows_if_auto_login(self) -> None:
         # if auto_login is enabled, we need to migrate the flows
         # to the default superuser if they don't have a user id
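with_async_session mirrors the existing with_session context manager. The AsyncDbSession dependency used by the routes is presumably wired on top of it roughly like this (a sketch; the actual definition in langflow.api.utils is not part of this diff and the get_db_service accessor is an assumption):

```python
# Possible shape of the AsyncDbSession dependency -- illustrative only.
from collections.abc import AsyncGenerator
from typing import Annotated

from fastapi import Depends
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.deps import get_db_service  # assumed accessor


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    async with get_db_service().with_async_session() as session:
        yield session


AsyncDbSession = Annotated[AsyncSession, Depends(get_async_session)]
```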
@@ -334,3 +363,4 @@ def _teardown(self) -> None:

     async def teardown(self) -> None:
         await asyncio.to_thread(self._teardown)
+        await self.async_engine.dispose()
Since create_api_key and delete_api_key are now coroutines, we run the command in an event loop. I think that's OK because the CLI is a kind of entry point, but please comment if you think there could be issues, such as conflicts with other event loops. The other solution would be to duplicate create_api_key and delete_api_key so that both sync and async versions exist.
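For context on that trade-off, a minimal, self-contained sketch (names are illustrative): asyncio.run() is safe at a genuine entry point such as a CLI command, but it cannot be called from a thread that already has a running event loop.

```python
import asyncio


async def create_api_key_async() -> str:
    """Stand-in for the real coroutine that creates an API key."""
    return "sk-..."


def cli_command() -> None:
    # Fine: the CLI command is the entry point, so no loop is running yet.
    print(asyncio.run(create_api_key_async()))


async def request_handler() -> None:
    # Inside an already-running loop (e.g. a FastAPI route), asyncio.run()
    # would raise "RuntimeError: asyncio.run() cannot be called from a
    # running event loop"; here we simply await the coroutine instead.
    print(await create_api_key_async())


if __name__ == "__main__":
    cli_command()
```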