Skip to content

Commit

Permalink
Unify lazy db sequence implementations (apache#39426)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored and pateash committed May 13, 2024
1 parent ec6b0f5 commit f86a46d
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 205 deletions.
8 changes: 4 additions & 4 deletions airflow/lineage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def wrapper(self, context, *args, **kwargs):
# Remove auto and task_ids
self.inlets = [i for i in self.inlets if not isinstance(i, str)]

# We manually create a session here since xcom_pull returns a LazyXComAccess iterator.
# If we do not pass a session a new session will be created, however that session will not be
# properly closed and will remain open. After we are done iterating we can safely close this
# session.
# We manually create a session here since xcom_pull returns a
# LazySelectSequence proxy. If we do not pass a session, a new one
# will be created, but that session will not be properly closed.
# After we are done iterating, we can safely close this session.
with create_session() as session:
_inlets = self.xcom_pull(
context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
Expand Down
7 changes: 5 additions & 2 deletions airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

from sqlalchemy import Column, Integer, MetaData, String, text
from sqlalchemy.orm import registry
Expand Down Expand Up @@ -48,7 +48,10 @@ def _get_schema():
mapper_registry = registry(metadata=metadata)
_sentinel = object()

Base: Any = mapper_registry.generate_base()
if TYPE_CHECKING:
Base = Any
else:
Base = mapper_registry.generate_base()

ID_LEN = 250

Expand Down
39 changes: 21 additions & 18 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComAccess, XCom
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
Expand Down Expand Up @@ -3358,34 +3358,37 @@ def xcom_pull(
return default
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)
query = query.order_by(None).order_by(XCom.map_index.asc())
return LazyXComAccess.build_from_xcom_query(query)
return LazyXComSelectSequence.from_select(
query.with_entities(XCom.value).order_by(None).statement,
order_by=[XCom.map_index],
session=session,
)

# At this point either task_ids or map_indexes is explicitly multi-value.
# Order return values to match task_ids and map_indexes ordering.
query = query.order_by(None)
ordering = []
if task_ids is None or isinstance(task_ids, str):
query = query.order_by(XCom.task_id)
ordering.append(XCom.task_id)
elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}:
ordering.append(case(task_id_whens, value=XCom.task_id))
else:
task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
if task_id_whens:
query = query.order_by(case(task_id_whens, value=XCom.task_id))
else:
query = query.order_by(XCom.task_id)
ordering.append(XCom.task_id)
if map_indexes is None or isinstance(map_indexes, int):
query = query.order_by(XCom.map_index)
ordering.append(XCom.map_index)
elif isinstance(map_indexes, range):
order = XCom.map_index
if map_indexes.step < 0:
order = order.desc()
query = query.order_by(order)
ordering.append(order)
elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}:
ordering.append(case(map_index_whens, value=XCom.map_index))
else:
map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
if map_index_whens:
query = query.order_by(case(map_index_whens, value=XCom.map_index))
else:
query = query.order_by(XCom.map_index)
return LazyXComAccess.build_from_xcom_query(query)
ordering.append(XCom.map_index)
return LazyXComSelectSequence.from_select(
query.with_entities(XCom.value).order_by(None).statement,
order_by=ordering,
session=session,
)

@provide_session
def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
Expand Down
125 changes: 16 additions & 109 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@
# under the License.
from __future__ import annotations

import collections.abc
import contextlib
import inspect
import itertools
import json
import logging
import pickle
import warnings
from functools import cached_property, wraps
from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable, cast, overload

import attr
from sqlalchemy import (
Column,
ForeignKeyConstraint,
Expand All @@ -38,19 +34,20 @@
PrimaryKeyConstraint,
String,
delete,
select,
text,
)
from sqlalchemy.dialects.mysql import LONGBLOB
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils import timezone
from airflow.utils.db import LazySelectSequence
from airflow.utils.helpers import exactly_one, is_container
from airflow.utils.json import XComDecoder, XComEncoder
from airflow.utils.log.logging_mixin import LoggingMixin
Expand All @@ -70,7 +67,9 @@
import datetime

import pendulum
from sqlalchemy.engine import Row
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import Select, TextClause

from airflow.models.taskinstancekey import TaskInstanceKey

Expand Down Expand Up @@ -222,11 +221,11 @@ def set(
if dag_run_id is None:
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

# Seamlessly resolve LazyXComAccess to a list. This is intended to work
# Seamlessly resolve LazySelectSequence to a list. This is intended to work
# as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
# it's pushed into XCom, the user should be aware of the performance
# implications, and this avoids leaking the implementation detail.
if isinstance(value, LazyXComAccess):
if isinstance(value, LazySelectSequence):
warning_message = (
"Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
"to list, which may degrade performance. Review resource "
Expand Down Expand Up @@ -716,111 +715,19 @@ def orm_deserialize_value(self) -> Any:
return BaseXCom._deserialize_value(self, True)


class _LazyXComAccessIterator(collections.abc.Iterator):
def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
self._cm = cm
self._entered = False

def __del__(self) -> None:
if self._entered:
self._cm.__exit__(None, None, None)

def __iter__(self) -> collections.abc.Iterator:
return self

def __next__(self) -> Any:
return XCom.deserialize_value(next(self._it))

@cached_property
def _it(self) -> collections.abc.Iterator:
self._entered = True
return iter(self._cm.__enter__())


@attr.define(slots=True)
class LazyXComAccess(collections.abc.Sequence):
"""Wrapper to lazily pull XCom with a sequence-like interface.
Note that since the session bound to the parent query may have died when we
actually access the sequence's content, we must create a new session
for every function call with ``with_session()``.
class LazyXComSelectSequence(LazySelectSequence[Any]):
"""List-like interface to lazily access XCom values.
:meta private:
"""

_query: Query
_len: int | None = attr.ib(init=False, default=None)

@classmethod
def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
return cls(query=query.with_entities(XCom.value))

def __repr__(self) -> str:
return f"LazyXComAccess([{len(self)} items])"

def __str__(self) -> str:
return str(list(self))

def __eq__(self, other: Any) -> bool:
if isinstance(other, (list, LazyXComAccess)):
z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
return all(x == y for x, y in z)
return NotImplemented

def __getstate__(self) -> Any:
# We don't want to go to the trouble of serializing the entire Query
# object, including its filters, hints, etc. (plus SQLAlchemy does not
# provide a public API to inspect a query's contents). Converting the
# query into a SQL string is the best we can get. Theoretically we can
# do the same for count(), but I think it should be performant enough to
# calculate only that eagerly.
with self._get_bound_query() as query:
statement = query.statement.compile(
query.session.get_bind(),
# This inlines all the values into the SQL string to simplify
# cross-process communication as much as possible.
compile_kwargs={"literal_binds": True},
)
return (str(statement), query.count())

def __setstate__(self, state: Any) -> None:
statement, self._len = state
self._query = Query(XCom.value).from_statement(text(statement))

def __len__(self):
if self._len is None:
with self._get_bound_query() as query:
self._len = query.count()
return self._len

def __iter__(self):
return _LazyXComAccessIterator(self._get_bound_query())
@staticmethod
def _rebuild_select(stmt: TextClause) -> Select:
return select(XCom.value).from_statement(stmt)

def __getitem__(self, key):
if not isinstance(key, int):
raise ValueError("only support index access for now")
try:
with self._get_bound_query() as query:
r = query.offset(key).limit(1).one()
except NoResultFound:
raise IndexError(key) from None
return XCom.deserialize_value(r)

@contextlib.contextmanager
def _get_bound_query(self) -> Generator[Query, None, None]:
# Do we have a valid session already?
if self._query.session and self._query.session.is_active:
yield self._query
return

Session = getattr(settings, "Session", None)
if Session is None:
raise RuntimeError("Session must be set before!")
session = Session()
try:
yield self._query.with_session(session)
finally:
session.close()
@staticmethod
def _process_row(row: Row) -> Any:
return XCom.deserialize_value(row)


def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
Expand Down
6 changes: 6 additions & 0 deletions airflow/typing_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"Literal",
"ParamSpec",
"Protocol",
"Self",
"TypedDict",
"TypeGuard",
"runtime_checkable",
Expand All @@ -45,3 +46,8 @@
from typing import ParamSpec, TypeGuard
else:
from typing_extensions import ParamSpec, TypeGuard

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
Loading

0 comments on commit f86a46d

Please sign in to comment.