Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add AsyncSession support for non-blocking db operations #4408

Merged
merged 2 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ dependencies = [
"huggingface-hub[inference]>=0.23.2",
"networkx>=3.1",
"fake-useragent>=1.5.0",
"psycopg2-binary>=2.9.6",
"pyarrow>=14.0.0",
"wikipedia>=1.4.0",
"qdrant-client~=1.9.2",
Expand All @@ -49,7 +48,6 @@ dependencies = [
"pymongo>=4.6.0",
"supabase~=2.6.0",
"certifi>=2023.11.17,<2025.0.0",
"psycopg>=3.1.9",
"fastavro>=1.8.0",
"redis>=5.0.1",
"metaphor-python>=0.1.11",
Expand Down Expand Up @@ -111,6 +109,7 @@ dependencies = [
"langchain-elasticsearch>=0.2.0",
"langchain-ollama>=0.2.0",
"pymupdf~=1.24.13",
"sqlalchemy[aiosqlite,postgresql_psycopg2binary,postgresql_psycopgbinary]>=2.0.36"
]

[project.urls]
Expand Down
50 changes: 29 additions & 21 deletions src/backend/base/langflow/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import platform
import socket
Expand Down Expand Up @@ -27,7 +28,7 @@
create_default_folder_if_it_doesnt_exist,
)
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service, session_scope
from langflow.services.deps import async_session_scope, get_db_service, get_settings_service
from langflow.services.settings.constants import DEFAULT_SUPERUSER
from langflow.services.utils import initialize_services
from langflow.utils.version import fetch_latest_version, get_version_info
Expand Down Expand Up @@ -486,28 +487,35 @@ def api_key(
if not auth_settings.AUTO_LOGIN:
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
return
with session_scope() as session:
from langflow.services.database.models.user.model import User

superuser = session.exec(select(User).where(User.username == DEFAULT_SUPERUSER)).first()
if not superuser:
typer.echo("Default superuser not found. This command requires a superuser and AUTO_LOGIN to be enabled.")
return
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate
from langflow.services.database.models.api_key.crud import (
create_api_key,
delete_api_key,
)

api_key = session.exec(select(ApiKey).where(ApiKey.user_id == superuser.id)).first()
if api_key:
delete_api_key(session, api_key.id)
async def aapi_key():
async with async_session_scope() as session:
from langflow.services.database.models.user.model import User

api_key_create = ApiKeyCreate(name="CLI")
unmasked_api_key = create_api_key(session, api_key_create, user_id=superuser.id)
session.commit()
# Create a banner to display the API key and tell the user it won't be shown again
api_key_banner(unmasked_api_key)
superuser = (await session.exec(select(User).where(User.username == DEFAULT_SUPERUSER))).first()
if not superuser:
typer.echo(
"Default superuser not found. This command requires a superuser and AUTO_LOGIN to be enabled."
)
return None
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate
from langflow.services.database.models.api_key.crud import (
create_api_key,
delete_api_key,
)

api_key = (await session.exec(select(ApiKey).where(ApiKey.user_id == superuser.id))).first()
if api_key:
await delete_api_key(session, api_key.id)

api_key_create = ApiKeyCreate(name="CLI")
unmasked_api_key = await create_api_key(session, api_key_create, user_id=superuser.id)
await session.commit()
return unmasked_api_key

unmasked_api_key = asyncio.run(aapi_key())
Copy link
Collaborator Author

@cbornet cbornet Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since create_api_key and delete_api_key are now coroutines, we run the command in an event loop.
I think that's OK because this is a CLI command, and therefore a kind of entry point.
But please comment if you think there could be issues, such as conflicts with other event loops.
The alternative would be to duplicate create_api_key and delete_api_key so that each has both a sync and an async version.

# Create a banner to display the API key and tell the user it won't be shown again
api_key_banner(unmasked_api_key)


def api_key_banner(unmasked_api_key) -> None:
Expand Down
4 changes: 3 additions & 1 deletion src/backend/base/langflow/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from loguru import logger
from sqlalchemy import delete
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models import User
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.vertex_builds.model import VertexBuildTable
from langflow.services.deps import get_session
from langflow.services.deps import get_async_session, get_session
from langflow.services.store.utils import get_lf_version_from_pypi

if TYPE_CHECKING:
Expand All @@ -31,6 +32,7 @@

CurrentActiveUser = Annotated[User, Depends(get_current_active_user)]
DbSession = Annotated[Session, Depends(get_session)]
AsyncDbSession = Annotated[AsyncSession, Depends(get_async_session)]


def has_api_terms(word: str):
Expand Down
14 changes: 7 additions & 7 deletions src/backend/base/langflow/api/v1/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fastapi import APIRouter, Depends, HTTPException, Response

from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: at the moment we can't replace DbSession with AsyncDbSession in save_store_api_key, because it gets the current user with the sync session, and so we can't modify that user with the async session.

from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
from langflow.services.auth import utils as auth_utils

Expand All @@ -20,12 +20,12 @@

@router.get("/")
async def get_api_keys_route(
db: DbSession,
db: AsyncDbSession,
current_user: CurrentActiveUser,
) -> ApiKeysResponse:
try:
user_id = current_user.id
keys = get_api_keys(db, user_id)
keys = await get_api_keys(db, user_id)

return ApiKeysResponse(total_count=len(keys), user_id=user_id, api_keys=keys)
except Exception as exc:
Expand All @@ -36,22 +36,22 @@ async def get_api_keys_route(
async def create_api_key_route(
req: ApiKeyCreate,
current_user: CurrentActiveUser,
db: DbSession,
db: AsyncDbSession,
) -> UnmaskedApiKeyRead:
try:
user_id = current_user.id
return create_api_key(db, req, user_id=user_id)
return await create_api_key(db, req, user_id=user_id)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) from e


@router.delete("/{api_key_id}", dependencies=[Depends(auth_utils.get_current_active_user)])
async def delete_api_key_route(
api_key_id: UUID,
db: DbSession,
db: AsyncDbSession,
):
try:
delete_api_key(db, api_key_id)
await delete_api_key(db, api_key_id)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) from e
return {"detail": "API Key deleted"}
Expand Down
19 changes: 10 additions & 9 deletions src/backend/base/langflow/services/database/models/api_key/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@
from uuid import UUID

from sqlmodel import Session, select
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead

if TYPE_CHECKING:
from sqlmodel.sql.expression import SelectOfScalar


def get_api_keys(session: Session, user_id: UUID) -> list[ApiKeyRead]:
async def get_api_keys(session: AsyncSession, user_id: UUID) -> list[ApiKeyRead]:
query: SelectOfScalar = select(ApiKey).where(ApiKey.user_id == user_id)
api_keys = session.exec(query).all()
api_keys = (await session.exec(query)).all()
return [ApiKeyRead.model_validate(api_key) for api_key in api_keys]


def create_api_key(session: Session, api_key_create: ApiKeyCreate, user_id: UUID) -> UnmaskedApiKeyRead:
async def create_api_key(session: AsyncSession, api_key_create: ApiKeyCreate, user_id: UUID) -> UnmaskedApiKeyRead:
# Generate a random API key with 32 bytes of randomness
generated_api_key = f"sk-{secrets.token_urlsafe(32)}"

Expand All @@ -30,20 +31,20 @@ def create_api_key(session: Session, api_key_create: ApiKeyCreate, user_id: UUID
)

session.add(api_key)
session.commit()
session.refresh(api_key)
await session.commit()
await session.refresh(api_key)
unmasked = UnmaskedApiKeyRead.model_validate(api_key, from_attributes=True)
unmasked.api_key = generated_api_key
return unmasked


def delete_api_key(session: Session, api_key_id: UUID) -> None:
api_key = session.get(ApiKey, api_key_id)
async def delete_api_key(session: AsyncSession, api_key_id: UUID) -> None:
api_key = await session.get(ApiKey, api_key_id)
if api_key is None:
msg = "API Key not found"
raise ValueError(msg)
session.delete(api_key)
session.commit()
await session.delete(api_key)
await session.commit()


def check_key(session: Session, api_key: str) -> ApiKey | None:
Expand Down
82 changes: 56 additions & 26 deletions src/backend/base/langflow/services/database/service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import asyncio
import sqlite3
import time
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -14,7 +15,9 @@
from sqlalchemy import event, inspect
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import Session, SQLModel, create_engine, select, text
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.base import Service
from langflow.services.database import models
Expand All @@ -39,12 +42,17 @@ def __init__(self, settings_service: SettingsService):
msg = "No database URL provided"
raise ValueError(msg)
self.database_url: str = settings_service.settings.database_url
self._sanitize_database_url()
# This file is in langflow.services.database.manager.py
# the ini is in langflow
langflow_dir = Path(__file__).parent.parent.parent
self.script_location = langflow_dir / "alembic"
self.alembic_cfg_path = langflow_dir / "alembic.ini"
# register the event listener for sqlite as part of this class.
# Using decorator will make the method not able to use self
event.listen(Engine, "connect", self.on_connection)
self.engine = self._create_engine()
self.async_engine = self._create_async_engine()
alembic_log_file = self.settings_service.settings.alembic_log_file

# Check if the provided path is absolute, cross-platform.
Expand All @@ -56,10 +64,47 @@ def __init__(self, settings_service: SettingsService):
self.alembic_log_path = Path(langflow_dir) / alembic_log_file

def reload_engine(self) -> None:
self._sanitize_database_url()
self.engine = self._create_engine()
self.async_engine = self._create_async_engine()

def _sanitize_database_url(self):
if self.database_url.startswith("postgres://"):
self.database_url = self.database_url.replace("postgres://", "postgresql://")
logger.warning(
"Fixed postgres dialect in database URL. Replacing postgres:// with postgresql://. "
"To avoid this warning, update the database URL."
)

def _create_engine(self) -> Engine:
"""Create the engine for the database."""
return create_engine(
self.database_url,
connect_args=self._get_connect_args(),
pool_size=self.settings_service.settings.pool_size,
max_overflow=self.settings_service.settings.max_overflow,
)

def _create_async_engine(self) -> AsyncEngine:
"""Create the engine for the database."""
url_components = self.database_url.split("://", maxsplit=1)
if url_components[0].startswith("sqlite"):
database_url = "sqlite+aiosqlite://"
kwargs = {}
Copy link
Collaborator Author

@cbornet cbornet Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aiosqlite doesn't have a connection pool.
I don't think that's a big issue but please comment if you think it is.

else:
kwargs = {
"pool_size": self.settings_service.settings.pool_size,
"max_overflow": self.settings_service.settings.max_overflow,
}
database_url = "postgresql+psycopg://" if url_components[0].startswith("postgresql") else url_components[0]
database_url += url_components[1]
return create_async_engine(
database_url,
connect_args=self._get_connect_args(),
**kwargs,
)

def _get_connect_args(self):
if self.settings_service.settings.database_url and self.settings_service.settings.database_url.startswith(
"sqlite"
):
Expand All @@ -69,33 +114,12 @@ def _create_engine(self) -> Engine:
}
else:
connect_args = {}
try:
# register the event listener for sqlite as part of this class.
# Using decorator will make the method not able to use self
event.listen(Engine, "connect", self.on_connection)

return create_engine(
self.database_url,
connect_args=connect_args,
pool_size=self.settings_service.settings.pool_size,
max_overflow=self.settings_service.settings.max_overflow,
)
except sa.exc.NoSuchModuleError as exc:
if "postgres" in str(exc) and not self.database_url.startswith("postgresql"):
# https://stackoverflow.com/questions/62688256/sqlalchemy-exc-nosuchmoduleerror-cant-load-plugin-sqlalchemy-dialectspostgre
self.database_url = self.database_url.replace("postgres://", "postgresql://")
logger.warning(
"Fixed postgres dialect in database URL. Replacing postgres:// with postgresql://. "
"To avoid this warning, update the database URL."
)
return self._create_engine()
msg = "Error creating database engine"
raise RuntimeError(msg) from exc
return connect_args

def on_connection(self, dbapi_connection, _connection_record) -> None:
from sqlite3 import Connection as sqliteConnection

if isinstance(dbapi_connection, sqliteConnection):
if isinstance(
dbapi_connection, sqlite3.Connection | sa.dialects.sqlite.aiosqlite.AsyncAdapt_aiosqlite_connection
):
pragmas: dict = self.settings_service.settings.sqlite_pragmas or {}
pragmas_list = []
for key, val in pragmas.items():
Expand All @@ -117,6 +141,11 @@ def with_session(self):
with Session(self.engine) as session:
yield session

@asynccontextmanager
async def with_async_session(self):
async with AsyncSession(self.async_engine) as session:
yield session

def migrate_flows_if_auto_login(self) -> None:
# if auto_login is enabled, we need to migrate the flows
# to the default superuser if they don't have a user id
Expand Down Expand Up @@ -334,3 +363,4 @@ def _teardown(self) -> None:

async def teardown(self) -> None:
await asyncio.to_thread(self._teardown)
await self.async_engine.dispose()
Loading
Loading