diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d364db7a..3a0c58db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,6 @@ repos: - id: mypy files: ^quetz/ additional_dependencies: - - sqlalchemy-stubs - types-click - types-Jinja2 - types-mock diff --git a/environment.yml b/environment.yml index 7c2e78c7..a7bb9189 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ dependencies: - authlib=0.15.5 - psycopg2 - httpx>=0.22.0 - - sqlalchemy + - sqlalchemy >= 2, <3 - sqlalchemy-utils - sqlite - python-multipart diff --git a/init_db.py b/init_db.py index e24dd8f6..b01c419f 100644 --- a/init_db.py +++ b/init_db.py @@ -21,11 +21,10 @@ def init_test_db(): config = Config() init_db(config.sqlalchemy_database_url) - db = get_session(config.sqlalchemy_database_url) - testUsers = [] + with get_session(config) as db: + testUsers = [] - try: for index, username in enumerate(["alice", "bob", "carol", "dave"]): user = User(id=uuid.uuid4().bytes, username=username) @@ -102,8 +101,6 @@ def init_test_db(): db.add(channel_member) db.commit() - finally: - db.close() if __name__ == "__main__": diff --git a/plugins/quetz_conda_suggest/quetz_conda_suggest/main.py b/plugins/quetz_conda_suggest/quetz_conda_suggest/main.py index 7f21c095..d6095f21 100644 --- a/plugins/quetz_conda_suggest/quetz_conda_suggest/main.py +++ b/plugins/quetz_conda_suggest/quetz_conda_suggest/main.py @@ -4,7 +4,7 @@ import quetz from quetz.config import Config -from quetz.database import get_db_manager +from quetz.database import get_session from quetz.db_models import PackageVersion from quetz.utils import add_entry_for_index @@ -45,7 +45,7 @@ def post_add_package_version(version, condainfo): if command not in suggest_map: suggest_map[command] = package - with get_db_manager() as db: + with get_session() as db: if not version.binfiles: metadata = db_models.CondaSuggestMetadata( version_id=version.id, data=json.dumps(suggest_map) diff --git a/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py b/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py index 89746340..d22c9388 100644 --- a/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py +++ b/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py @@ -2,7 +2,6 @@ import shutil import tarfile import tempfile -from contextlib import contextmanager from unittest import mock import pytest @@ -27,13 +26,12 @@ def test_post_add_package_version(package_version, db, config): target.seek(0) condainfo = CondaInfo(target, filename) - @contextmanager def get_db(): - yield db + return db from quetz_conda_suggest import main - with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db): + with mock.patch("quetz_conda_suggest.main.get_session", get_db): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.CondaSuggestMetadata).first() @@ -50,7 +48,7 @@ def get_db(): b"lib/libtpkg.so\n", b"lib/pkgconfig/libtpkg.pc\n", ] - with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db): + with mock.patch("quetz_conda_suggest.main.get_session", get_db): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.CondaSuggestMetadata).all() @@ -76,7 +74,6 @@ def test_conda_suggest_endpoint_with_upload( response = client.get("/api/dummylogin/madhurt") filename = "test-package-0.1-0.tar.bz2" - @contextmanager def get_db(): yield db @@ -114,7 +111,7 @@ def get_db(): tar.addfile(t, io.BytesIO(b)) tar.close() - with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db): + with mock.patch("quetz_conda_suggest.main.get_session", get_db): url = f"/api/channels/{channel.name}/files/" files = {"files": (filename, open(filename, "rb"))} response = client.post(url, files=files) diff --git a/plugins/quetz_content_trust/quetz_content_trust/api.py b/plugins/quetz_content_trust/quetz_content_trust/api.py index 5f55dbef..14a05e0f 100644 --- a/plugins/quetz_content_trust/quetz_content_trust/api.py +++ b/plugins/quetz_content_trust/quetz_content_trust/api.py @@ -8,7 +8,7 @@ from quetz import authorization from quetz.config import Config -from quetz.database import get_db_manager +from quetz.database import get_session from quetz.deps import get_rules from . import db_models @@ -102,7 +102,7 @@ def post_role( ): auth.assert_channel_roles(channel, ["owner"]) - with get_db_manager() as db: + with get_session() as db: existing_role_count = ( db.query(db_models.ContentTrustRole) .filter( @@ -190,7 +190,7 @@ def get_role( ): auth.assert_channel_roles(channel, ["owner", "maintainer", "member"]) - with get_db_manager() as db: + with get_session() as db: query = ( db.query(db_models.ContentTrustRole) .filter(db_models.ContentTrustRole.channel == channel) @@ -211,7 +211,7 @@ def get_new_key(secret: bool = False): mamba_key = libmamba_api.Key.from_ed25519(key.public_key) private_key = key.private_key - with get_db_manager() as db: + with get_session() as db: db.add(key) db.commit() diff --git a/plugins/quetz_content_trust/quetz_content_trust/main.py b/plugins/quetz_content_trust/quetz_content_trust/main.py index 5f3723b9..f2f4a996 100644 --- a/plugins/quetz_content_trust/quetz_content_trust/main.py +++ b/plugins/quetz_content_trust/quetz_content_trust/main.py @@ -4,7 +4,7 @@ from sqlalchemy import desc import quetz -from quetz.database import get_db_manager +from quetz.database import get_session from . import db_models from .api import router @@ -21,7 +21,7 @@ def register_router(): def post_index_creation(raw_repodata: dict, channel_name, subdir): """Use available online keys to sign packages""" - with get_db_manager() as db: + with get_session() as db: query = ( db.query(db_models.SigningKey) .join(db_models.RoleDelegation.keys) diff --git a/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py b/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py index 4ed2e43f..57ecca1b 100644 --- a/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py +++ b/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py @@ -9,7 +9,7 @@ import quetz from quetz.config import Config -from quetz.database import get_db_manager +from quetz.database import get_session from quetz.db_models import PackageFormatEnum, PackageVersion from quetz.utils import add_temp_static_file @@ -107,7 +107,7 @@ def _load_instructions(tar, path): @quetz.hookimpl(tryfirst=True) def post_package_indexing(tempdir: Path, channel_name, subdirs, files, packages): - with get_db_manager() as db: + with get_session() as db: query = ( db.query(PackageVersion) .filter( diff --git a/plugins/quetz_repodata_patching/tests/test_main.py b/plugins/quetz_repodata_patching/tests/test_main.py index defb3261..dce0428d 100644 --- a/plugins/quetz_repodata_patching/tests/test_main.py +++ b/plugins/quetz_repodata_patching/tests/test_main.py @@ -300,7 +300,7 @@ def test_post_package_indexing( def get_db(): yield db - with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", get_db): indexing.update_indexes(dao, pkgstore, channel_name) ext = "json.bz2" if compressed_repodata else "json" @@ -378,7 +378,7 @@ def test_index_html( def get_db(): yield db - with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", get_db): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( @@ -419,7 +419,7 @@ def test_patches_for_subdir( def get_db(): yield db - with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", get_db): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( @@ -471,7 +471,7 @@ def test_no_repodata_patches_package( def get_db(): yield db - with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", get_db): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( diff --git a/plugins/quetz_runexports/quetz_runexports/main.py b/plugins/quetz_runexports/quetz_runexports/main.py index 4a580b81..ed72fa8e 100644 --- a/plugins/quetz_runexports/quetz_runexports/main.py +++ b/plugins/quetz_runexports/quetz_runexports/main.py @@ -1,7 +1,7 @@ import json import quetz -from quetz.database import get_db_manager +from quetz.database import get_session from . import db_models from .api import router @@ -16,7 +16,7 @@ def register_router(): def post_add_package_version(version, condainfo): run_exports = json.dumps(condainfo.run_exports) - with get_db_manager() as db: + with get_session() as db: if not version.runexports: metadata = db_models.PackageVersionMetadata( version_id=version.id, data=run_exports diff --git a/plugins/quetz_runexports/tests/test_quetz_runexports.py b/plugins/quetz_runexports/tests/test_quetz_runexports.py index 7a02d94f..178eb191 100644 --- a/plugins/quetz_runexports/tests/test_quetz_runexports.py +++ b/plugins/quetz_runexports/tests/test_quetz_runexports.py @@ -62,7 +62,7 @@ def get_db(): from quetz_runexports import main - with mock.patch("quetz_runexports.main.get_db_manager", get_db): + with mock.patch("quetz_runexports.main.get_session", get_db): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.PackageVersionMetadata).first() @@ -71,7 +71,7 @@ def get_db(): # modify runexport and re-save condainfo.run_exports = {"weak": ["somepackage < 0.3"]} - with mock.patch("quetz_runexports.main.get_db_manager", get_db): + with mock.patch("quetz_runexports.main.get_session", get_db): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.PackageVersionMetadata).all() diff --git a/pyproject.toml b/pyproject.toml index 2e64b5b7..dd674db8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,9 +59,6 @@ venvPath= "." [tool.mypy] ignore_missing_imports = true -plugins = [ - "sqlmypy" -] disable_error_code = [ "misc" ] diff --git a/quetz/cli.py b/quetz/cli.py index bff53c68..34e313cf 100644 --- a/quetz/cli.py +++ b/quetz/cli.py @@ -359,8 +359,8 @@ def add_user_roles( config = _get_config(path) with working_directory(path): - db = get_session(config.sqlalchemy_database_url) - _set_user_roles(db, config) + with get_session(config) as db: + _set_user_roles(db, config) @app.command() @@ -500,11 +500,11 @@ def create( deployment_folder.joinpath("channels").mkdir(exist_ok=True) with working_directory(db_path): - db = get_session(config.sqlalchemy_database_url) _run_migrations(config.sqlalchemy_database_url) - if dev: - _fill_test_database(db) - _set_user_roles(db, config) + with get_session(config) as db: + if dev: + _fill_test_database(db) + _set_user_roles(db, config) def _get_config(path: Union[Path, str]) -> Config: @@ -758,14 +758,12 @@ def start_supervisor_daemon(path: Path, num_procs=None): # is set there (it only matters for sqlite database). db_path = path if path.joinpath("config.toml").exists() else os.getcwd() with working_directory(db_path): - db = get_session(config.sqlalchemy_database_url) - supervisor = Supervisor(db, manager) - try: - supervisor.run() - except KeyboardInterrupt: - logger.info("stopping supervisor") - finally: - db.close() + with get_session(config) as db: + supervisor = Supervisor(db, manager) + try: + supervisor.run() + except KeyboardInterrupt: + logger.info("stopping supervisor") @app.command() diff --git a/quetz/database.py b/quetz/database.py index 9e77db65..1a9cb97b 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -2,7 +2,6 @@ # Distributed under the terms of the Modified BSD License. import logging import re -from contextlib import contextmanager from typing import Callable from sqlalchemy import create_engine, event @@ -66,19 +65,16 @@ def get_session_maker(engine) -> Callable[[], Session]: return sessionmaker(autocommit=False, autoflush=True, bind=engine) -def get_session(db_url: str, **kwargs) -> Session: +def get_session(config: Config | None) -> Session: """Get a database session. - - Important note: this function is mocked during tests! + ea + Important note: this function is mocked during tests! """ - return get_session_maker(get_engine(db_url, **kwargs))() - + if config is None: + config = Config() -@contextmanager -def get_db_manager(): - config = Config() - db = get_session( + engine = get_engine( db_url=config.sqlalchemy_database_url, echo=config.sqlalchemy_echo_sql, postgres_kwargs=dict( @@ -86,11 +82,7 @@ def get_db_manager(): max_overflow=config.sqlalchemy_postgres_max_overflow, ), ) - - try: - yield db - finally: - db.close() + return get_session_maker(engine)() def sanitize_db_url(db_url: str) -> str: diff --git a/quetz/deps.py b/quetz/deps.py index 3ef97665..495333fb 100644 --- a/quetz/deps.py +++ b/quetz/deps.py @@ -43,19 +43,8 @@ def get_config(): def get_db(config: Config = Depends(get_config)): - database_url = config.sqlalchemy_database_url - db = get_db_session( - database_url, - echo=config.sqlalchemy_echo_sql, - postgres_kwargs=dict( - pool_size=config.sqlalchemy_postgres_pool_size, - max_overflow=config.sqlalchemy_postgres_max_overflow, - ), - ) - try: + with get_db_session(config) as db: yield db - finally: - db.close() def get_dao(db: Session = Depends(get_db)): diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index 80671b39..32a4f4f5 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -1,7 +1,7 @@ import os import shutil import tempfile -from typing import List +from typing import List, Iterator import pytest from alembic.command import upgrade as alembic_upgrade @@ -13,6 +13,7 @@ from quetz.dao import Dao from quetz.database import get_engine, get_session_maker from quetz.db_models import Base +from sqlalchemy.orm import Session def pytest_configure(config): @@ -118,7 +119,7 @@ def auto_rollback(): @pytest.fixture -def session_maker(sql_connection, create_tables, auto_rollback): +def session_maker(sql_connection, create_tables, auto_rollback) -> Iterator[Session]: # run the tests with a separate external DB transaction # so that we can easily rollback all db changes (even if committed) # done by the test client diff --git a/quetz/tests/test_dao.py b/quetz/tests/test_dao.py index 0d1591ae..5ccaef7c 100644 --- a/quetz/tests/test_dao.py +++ b/quetz/tests/test_dao.py @@ -7,7 +7,7 @@ from quetz import errors, rest_models from quetz.dao import Dao -from quetz.database import get_session +from quetz.database import get_engine, get_session_maker from quetz.db_models import Channel, Package, PackageVersion from quetz.metrics.db_models import IntervalType, PackageVersionMetric, round_timestamp @@ -406,10 +406,11 @@ def db_extra(database_url): Use only for tests that require two sessions concurrently. For most cases you will want to use the db fixture (from quetz.testing.fixtures)""" - session = get_session(database_url) - + engine = get_engine( + db_url=database_url, + ) + session = get_session_maker(engine)() yield session - session.close()