Skip to content

Commit

Permalink
refactor: Clean up DB sessions consistently
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreasAlbertQC committed Jul 10, 2024
1 parent 56ab2cf commit 39055c2
Show file tree
Hide file tree
Showing 17 changed files with 53 additions and 82 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ repos:
- id: mypy
files: ^quetz/
additional_dependencies:
- sqlalchemy-stubs
- types-click
- types-Jinja2
- types-mock
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- authlib=0.15.5
- psycopg2
- httpx>=0.22.0
- sqlalchemy
- sqlalchemy >= 2, <3
- sqlalchemy-utils
- sqlite
- python-multipart
Expand Down
7 changes: 2 additions & 5 deletions init_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
def init_test_db():
config = Config()
init_db(config.sqlalchemy_database_url)
db = get_session(config.sqlalchemy_database_url)

testUsers = []
with get_session(config) as db:
testUsers = []

try:
for index, username in enumerate(["alice", "bob", "carol", "dave"]):
user = User(id=uuid.uuid4().bytes, username=username)

Expand Down Expand Up @@ -102,8 +101,6 @@ def init_test_db():

db.add(channel_member)
db.commit()
finally:
db.close()


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions plugins/quetz_conda_suggest/quetz_conda_suggest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import quetz
from quetz.config import Config
from quetz.database import get_db_manager
from quetz.database import get_session
from quetz.db_models import PackageVersion
from quetz.utils import add_entry_for_index

Expand Down Expand Up @@ -45,7 +45,7 @@ def post_add_package_version(version, condainfo):
if command not in suggest_map:
suggest_map[command] = package

with get_db_manager() as db:
with get_session() as db:
if not version.binfiles:
metadata = db_models.CondaSuggestMetadata(
version_id=version.id, data=json.dumps(suggest_map)
Expand Down
11 changes: 4 additions & 7 deletions plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import shutil
import tarfile
import tempfile
from contextlib import contextmanager
from unittest import mock

import pytest
Expand All @@ -27,13 +26,12 @@ def test_post_add_package_version(package_version, db, config):
target.seek(0)
condainfo = CondaInfo(target, filename)

@contextmanager
def get_db():
yield db
return db

from quetz_conda_suggest import main

with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db):
with mock.patch("quetz_conda_suggest.main.get_session", get_db):
main.post_add_package_version(package_version, condainfo)

meta = db.query(db_models.CondaSuggestMetadata).first()
Expand All @@ -50,7 +48,7 @@ def get_db():
b"lib/libtpkg.so\n",
b"lib/pkgconfig/libtpkg.pc\n",
]
with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db):
with mock.patch("quetz_conda_suggest.main.get_session", get_db):
main.post_add_package_version(package_version, condainfo)

meta = db.query(db_models.CondaSuggestMetadata).all()
Expand All @@ -76,7 +74,6 @@ def test_conda_suggest_endpoint_with_upload(
response = client.get("/api/dummylogin/madhurt")
filename = "test-package-0.1-0.tar.bz2"

@contextmanager
def get_db():
yield db

Expand Down Expand Up @@ -114,7 +111,7 @@ def get_db():
tar.addfile(t, io.BytesIO(b))
tar.close()

with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db):
with mock.patch("quetz_conda_suggest.main.get_session", get_db):
url = f"/api/channels/{channel.name}/files/"
files = {"files": (filename, open(filename, "rb"))}
response = client.post(url, files=files)
Expand Down
8 changes: 4 additions & 4 deletions plugins/quetz_content_trust/quetz_content_trust/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from quetz import authorization
from quetz.config import Config
from quetz.database import get_db_manager
from quetz.database import get_session
from quetz.deps import get_rules

from . import db_models
Expand Down Expand Up @@ -102,7 +102,7 @@ def post_role(
):
auth.assert_channel_roles(channel, ["owner"])

with get_db_manager() as db:
with get_session() as db:
existing_role_count = (
db.query(db_models.ContentTrustRole)
.filter(
Expand Down Expand Up @@ -190,7 +190,7 @@ def get_role(
):
auth.assert_channel_roles(channel, ["owner", "maintainer", "member"])

with get_db_manager() as db:
with get_session() as db:
query = (
db.query(db_models.ContentTrustRole)
.filter(db_models.ContentTrustRole.channel == channel)
Expand All @@ -211,7 +211,7 @@ def get_new_key(secret: bool = False):
mamba_key = libmamba_api.Key.from_ed25519(key.public_key)
private_key = key.private_key

with get_db_manager() as db:
with get_session() as db:
db.add(key)
db.commit()

Expand Down
4 changes: 2 additions & 2 deletions plugins/quetz_content_trust/quetz_content_trust/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy import desc

import quetz
from quetz.database import get_db_manager
from quetz.database import get_session

from . import db_models
from .api import router
Expand All @@ -21,7 +21,7 @@ def register_router():
def post_index_creation(raw_repodata: dict, channel_name, subdir):
"""Use available online keys to sign packages"""

with get_db_manager() as db:
with get_session() as db:
query = (
db.query(db_models.SigningKey)
.join(db_models.RoleDelegation.keys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import quetz
from quetz.config import Config
from quetz.database import get_db_manager
from quetz.database import get_session
from quetz.db_models import PackageFormatEnum, PackageVersion
from quetz.utils import add_temp_static_file

Expand Down Expand Up @@ -107,7 +107,7 @@ def _load_instructions(tar, path):

@quetz.hookimpl(tryfirst=True)
def post_package_indexing(tempdir: Path, channel_name, subdirs, files, packages):
with get_db_manager() as db:
with get_session() as db:
query = (
db.query(PackageVersion)
.filter(
Expand Down
8 changes: 4 additions & 4 deletions plugins/quetz_repodata_patching/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_post_package_indexing(
def get_db():
yield db

with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db):
with mock.patch("quetz_repodata_patching.main.get_session", get_db):
indexing.update_indexes(dao, pkgstore, channel_name)

ext = "json.bz2" if compressed_repodata else "json"
Expand Down Expand Up @@ -378,7 +378,7 @@ def test_index_html(
def get_db():
yield db

with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db):
with mock.patch("quetz_repodata_patching.main.get_session", get_db):
indexing.update_indexes(dao, pkgstore, channel_name)

index_path = os.path.join(
Expand Down Expand Up @@ -419,7 +419,7 @@ def test_patches_for_subdir(
def get_db():
yield db

with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db):
with mock.patch("quetz_repodata_patching.main.get_session", get_db):
indexing.update_indexes(dao, pkgstore, channel_name)

index_path = os.path.join(
Expand Down Expand Up @@ -471,7 +471,7 @@ def test_no_repodata_patches_package(
def get_db():
yield db

with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db):
with mock.patch("quetz_repodata_patching.main.get_session", get_db):
indexing.update_indexes(dao, pkgstore, channel_name)

index_path = os.path.join(
Expand Down
4 changes: 2 additions & 2 deletions plugins/quetz_runexports/quetz_runexports/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json

import quetz
from quetz.database import get_db_manager
from quetz.database import get_session

from . import db_models
from .api import router
Expand All @@ -16,7 +16,7 @@ def register_router():
def post_add_package_version(version, condainfo):
run_exports = json.dumps(condainfo.run_exports)

with get_db_manager() as db:
with get_session() as db:
if not version.runexports:
metadata = db_models.PackageVersionMetadata(
version_id=version.id, data=run_exports
Expand Down
4 changes: 2 additions & 2 deletions plugins/quetz_runexports/tests/test_quetz_runexports.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_db():

from quetz_runexports import main

with mock.patch("quetz_runexports.main.get_db_manager", get_db):
with mock.patch("quetz_runexports.main.get_session", get_db):
main.post_add_package_version(package_version, condainfo)

meta = db.query(db_models.PackageVersionMetadata).first()
Expand All @@ -71,7 +71,7 @@ def get_db():

# modify runexport and re-save
condainfo.run_exports = {"weak": ["somepackage < 0.3"]}
with mock.patch("quetz_runexports.main.get_db_manager", get_db):
with mock.patch("quetz_runexports.main.get_session", get_db):
main.post_add_package_version(package_version, condainfo)

meta = db.query(db_models.PackageVersionMetadata).all()
Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ venvPath= "."

[tool.mypy]
ignore_missing_imports = true
plugins = [
"sqlmypy"
]
disable_error_code = [
"misc"
]
Expand Down
26 changes: 12 additions & 14 deletions quetz/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ def add_user_roles(
config = _get_config(path)

with working_directory(path):
db = get_session(config.sqlalchemy_database_url)
_set_user_roles(db, config)
with get_session(config) as db:
_set_user_roles(db, config)


@app.command()
Expand Down Expand Up @@ -500,11 +500,11 @@ def create(
deployment_folder.joinpath("channels").mkdir(exist_ok=True)

with working_directory(db_path):
db = get_session(config.sqlalchemy_database_url)
_run_migrations(config.sqlalchemy_database_url)
if dev:
_fill_test_database(db)
_set_user_roles(db, config)
with get_session(config) as db:
if dev:
_fill_test_database(db)
_set_user_roles(db, config)


def _get_config(path: Union[Path, str]) -> Config:
Expand Down Expand Up @@ -758,14 +758,12 @@ def start_supervisor_daemon(path: Path, num_procs=None):
# is set there (it only matters for sqlite database).
db_path = path if path.joinpath("config.toml").exists() else os.getcwd()
with working_directory(db_path):
db = get_session(config.sqlalchemy_database_url)
supervisor = Supervisor(db, manager)
try:
supervisor.run()
except KeyboardInterrupt:
logger.info("stopping supervisor")
finally:
db.close()
with get_session(config) as db:
supervisor = Supervisor(db, manager)
try:
supervisor.run()
except KeyboardInterrupt:
logger.info("stopping supervisor")


@app.command()
Expand Down
22 changes: 7 additions & 15 deletions quetz/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Distributed under the terms of the Modified BSD License.
import logging
import re
from contextlib import contextmanager
from typing import Callable

from sqlalchemy import create_engine, event
Expand Down Expand Up @@ -66,31 +65,24 @@ def get_session_maker(engine) -> Callable[[], Session]:
return sessionmaker(autocommit=False, autoflush=True, bind=engine)


def get_session(db_url: str, **kwargs) -> Session:
def get_session(config: Config | None) -> Session:
"""Get a database session.
Important note: this function is mocked during tests!
ea
Important note: this function is mocked during tests!
"""
return get_session_maker(get_engine(db_url, **kwargs))()

if config is None:
config = Config()

@contextmanager
def get_db_manager():
config = Config()
db = get_session(
engine = get_engine(
db_url=config.sqlalchemy_database_url,
echo=config.sqlalchemy_echo_sql,
postgres_kwargs=dict(
pool_size=config.sqlalchemy_postgres_pool_size,
max_overflow=config.sqlalchemy_postgres_max_overflow,
),
)

try:
yield db
finally:
db.close()
return get_session_maker(engine)()


def sanitize_db_url(db_url: str) -> str:
Expand Down
13 changes: 1 addition & 12 deletions quetz/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,8 @@ def get_config():


def get_db(config: Config = Depends(get_config)):
database_url = config.sqlalchemy_database_url
db = get_db_session(
database_url,
echo=config.sqlalchemy_echo_sql,
postgres_kwargs=dict(
pool_size=config.sqlalchemy_postgres_pool_size,
max_overflow=config.sqlalchemy_postgres_max_overflow,
),
)
try:
with get_db_session(config) as db:
yield db
finally:
db.close()


def get_dao(db: Session = Depends(get_db)):
Expand Down
Loading

0 comments on commit 39055c2

Please sign in to comment.