Skip to content

Commit

Permalink
Reuse db_engine instead of recreating (#201)
Browse files Browse the repository at this point in the history
* Reuse db_engine instead of recreating

* Reuse test_db_engine too
  • Loading branch information
ricardogsilva authored Aug 23, 2024
1 parent 5358dc5 commit ea23967
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 16 deletions.
32 changes: 28 additions & 4 deletions arpav_ppcv/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,36 @@

logger = logging.getLogger(__name__)

_DB_ENGINE = None
_TEST_DB_ENGINE = None


def get_engine(settings: config.ArpavPpcvSettings, use_test_db: Optional[bool] = False):
db_dsn = settings.test_db_dsn if use_test_db else settings.db_dsn
return sqlmodel.create_engine(
db_dsn.unicode_string(), echo=True if settings.verbose_db_logs else False
)
# This function implements caching of the sqlalchemy engine, relying on the
# value of the module global `_DB_ENGINE` variable. This is done in order to
# - reuse the same database engine throughout the lifecycle of the application
# - provide an opportunity to clear the cache when needed (e.g.: in the fastapi
# lifespan function)
#
# Note: this function cannot use the `functools.cache` decorator because
# the `settings` parameter is not hashable
if use_test_db:
global _TEST_DB_ENGINE
if _TEST_DB_ENGINE is None:
_TEST_DB_ENGINE = sqlmodel.create_engine(
settings.test_db_dsn.unicode_string(),
echo=True if settings.verbose_db_logs else False,
)
result = _TEST_DB_ENGINE
else:
global _DB_ENGINE
if _DB_ENGINE is None:
_DB_ENGINE = sqlmodel.create_engine(
settings.db_dsn.unicode_string(),
echo=True if settings.verbose_db_logs else False,
)
result = _DB_ENGINE
return result


def create_variable(
Expand Down
13 changes: 5 additions & 8 deletions arpav_ppcv/prefect/flows/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# this is a module global because we need to configure the prefect flow and
# task with values from it
settings = get_settings()
db_engine = database.get_engine(settings)


@prefect.task(
Expand Down Expand Up @@ -92,9 +93,8 @@ def refresh_stations(
refresh_stations_with_seasonal_data: bool = True,
refresh_stations_with_yearly_data: bool = True,
):
settings = get_settings()
client = httpx.Client()
with sqlmodel.Session(database.get_engine(settings)) as db_session:
with sqlmodel.Session(db_engine) as db_session:
db_variables = _get_variables(db_session, variable_name)
if len(db_variables) > 0:
to_filter_for_new_stations = set()
Expand Down Expand Up @@ -211,10 +211,9 @@ def refresh_monthly_measurements(
variable_name: str | None = None,
month: int | None = None,
):
settings = get_settings()
client = httpx.Client()
all_created = []
with sqlmodel.Session(database.get_engine(settings)) as db_session:
with sqlmodel.Session(db_engine) as db_session:
if len(db_variables := _get_variables(db_session, variable_name)) > 0:
if len(db_stations := _get_stations(db_session, station_code)) > 0:
for db_station in db_stations:
Expand Down Expand Up @@ -317,10 +316,9 @@ def refresh_seasonal_measurements(
variable_name: str | None = None,
season_name: str | None = None,
):
settings = get_settings()
client = httpx.Client()
all_created = []
with sqlmodel.Session(database.get_engine(settings)) as db_session:
with sqlmodel.Session(db_engine) as db_session:
if len(db_variables := _get_variables(db_session, variable_name)) > 0:
if len(db_stations := _get_stations(db_session, station_code)) > 0:
for db_station in db_stations:
Expand Down Expand Up @@ -410,10 +408,9 @@ def refresh_yearly_measurements(
station_code: str | None = None,
variable_name: str | None = None,
):
settings = get_settings()
client = httpx.Client()
all_created = []
with sqlmodel.Session(database.get_engine(settings)) as db_session:
with sqlmodel.Session(db_engine) as db_session:
if len(db_variables := _get_variables(db_session, variable_name)) > 0:
if len(db_stations := _get_stations(db_session, station_code)) > 0:
for db_station in db_stations:
Expand Down
21 changes: 17 additions & 4 deletions arpav_ppcv/webapp/app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
import fastapi
import contextlib

from starlette.applications import Starlette
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates

from .. import config
from .. import (
config,
database,
)
from .api_v2.app import create_app as create_v2_app
from .admin.app import create_admin
from .routes import routes


def create_app_from_settings(settings: config.ArpavPpcvSettings) -> fastapi.FastAPI:
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
yield
# ensure the database engine is properly disposed of, closing any connections
database._DB_ENGINE.dispose() # noqa
database._DB_ENGINE = None


def create_app_from_settings(settings: config.ArpavPpcvSettings) -> Starlette:
app = Starlette(
debug=settings.debug,
routes=routes,
lifespan=lifespan,
)
settings.static_dir.mkdir(parents=True, exist_ok=True)
app.mount("/static", StaticFiles(directory=settings.static_dir), name="static")
Expand All @@ -28,6 +41,6 @@ def create_app_from_settings(settings: config.ArpavPpcvSettings) -> fastapi.Fast
return app


def create_app() -> fastapi.FastAPI:
def create_app() -> Starlette:
settings = config.get_settings()
return create_app_from_settings(settings)

0 comments on commit ea23967

Please sign in to comment.