Skip to content

Commit

Permalink
fixup! Upgrade dependencies to allow us to use SQLAlchemy v2 in Airfl…
Browse files Browse the repository at this point in the history
…ow 3.0/main
  • Loading branch information
ashb committed Oct 11, 2024
1 parent 047ae87 commit b08fac3
Show file tree
Hide file tree
Showing 8 changed files with 1,855 additions and 1,821 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

_STRING_COLUMN_TYPE = sa.String(length=1500).with_variant(
sa.String(length=1500, collation="latin1_general_cs"),
dialect_name="mysql",
"mysql",
)


Expand Down Expand Up @@ -85,7 +85,7 @@ def downgrade():
"uri",
type_=sa.String(length=3000).with_variant(
sa.String(length=3000, collation="latin1_general_cs"),
dialect_name="mysql",
"mysql",
),
nullable=False,
)
Expand Down
16 changes: 13 additions & 3 deletions airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import TYPE_CHECKING, Any

from sqlalchemy import Column, Integer, MetaData, String, text
from sqlalchemy.orm import registry
from sqlalchemy.orm import DeclarativeBase

from airflow.configuration import conf

Expand All @@ -45,13 +45,23 @@ def _get_schema():


metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
mapper_registry = registry(metadata=metadata)
_sentinel = object()

if TYPE_CHECKING:
Base = Any
else:
Base = mapper_registry.generate_base()

class Base(DeclarativeBase):
"""
Base class to ease transition to SQLAv2.
:meta private:
"""

metadata = metadata
# https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
__allow_unmapped__ = True


ID_LEN = 250

Expand Down
23 changes: 15 additions & 8 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ def create_default_connections(session: Session = NEW_SESSION):
)


@contextlib.contextmanager
def _get_flask_db(sql_database_uri):
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
Expand All @@ -744,7 +745,8 @@ def _get_flask_db(sql_database_uri):
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(flask_app)
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
return db
with flask_app.app_context():
yield db


def _create_db_from_orm(session):
Expand All @@ -753,8 +755,8 @@ def _create_db_from_orm(session):
from airflow.models.base import Base

def _create_flask_session_tbl(sql_database_uri):
db = _get_flask_db(sql_database_uri)
db.create_all()
with _get_flask_db(sql_database_uri) as db:
db.create_all()

with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
engine = session.get_bind().engine
Expand Down Expand Up @@ -1284,8 +1286,8 @@ def drop_airflow_models(connection):
from airflow.models.base import Base

Base.metadata.drop_all(connection)
db = _get_flask_db(connection.engine.url)
db.drop_all()
with _get_flask_db(connection.engine.url) as db:
db.drop_all()
# alembic adds significant import time, so we import it lazily
from alembic.migration import MigrationContext

Expand Down Expand Up @@ -1340,13 +1342,18 @@ def create_global_lock(
lock_timeout: int = 1800,
) -> Generator[None, None, None]:
"""Contextmanager that will create and teardown a global db lock."""
conn = session.get_bind().connect()
bind = session.get_bind()
if isinstance(bind, Engine):
conn = bind.connect()
else:
conn = bind
dialect = conn.dialect
mysql_supports_locks = dialect.name == "mysql" and dialect.server_version_info and dialect.server_version_info >= (5, 6)
try:
if dialect.name == "postgresql":
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
elif mysql_supports_locks:
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})

yield
Expand All @@ -1356,7 +1363,7 @@ def create_global_lock(
(unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
if not unlocked:
raise RuntimeError("Error releasing DB lock!")
elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
elif mysql_supports_locks:
conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})


Expand Down
2 changes: 1 addition & 1 deletion dev/breeze/src/airflow_breeze/global_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def get_airflow_extras():
# END OF EXTRAS LIST UPDATED BY PRE COMMIT
]

CHICKEN_EGG_PROVIDERS = " ".join(["standard"])
CHICKEN_EGG_PROVIDERS = " ".join(["standard", "fab", "amazon"])


BASE_PROVIDERS_COMPATIBILITY_CHECKS: list[dict[str, str | list[str]]] = [
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
8bd129828ba299ef05d70305eee66d15b6c0c79dc6ae82f654b9657464e3682a
c4498d5a4d0f05418a13d74b267539927c3f89860610c5973fd040d2b34038e2
Loading

0 comments on commit b08fac3

Please sign in to comment.