diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py index 6500d8008ea4..332a04e7250b 100644 --- a/airflow/lineage/__init__.py +++ b/airflow/lineage/__init__.py @@ -130,10 +130,10 @@ def wrapper(self, context, *args, **kwargs): # Remove auto and task_ids self.inlets = [i for i in self.inlets if not isinstance(i, str)] - # We manually create a session here since xcom_pull returns a LazyXComAccess iterator. - # If we do not pass a session a new session will be created, however that session will not be - # properly closed and will remain open. After we are done iterating we can safely close this - # session. + # We manually create a session here since xcom_pull returns a + # LazySelectSequence proxy. If we do not pass a session, a new one + # will be created, but that session will not be properly closed. + # After we are done iterating, we can safely close this session. with create_session() as session: _inlets = self.xcom_pull( context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session diff --git a/airflow/models/base.py b/airflow/models/base.py index 1cde7d75613b..e9f86f8d7e67 100644 --- a/airflow/models/base.py +++ b/airflow/models/base.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from sqlalchemy import Column, Integer, MetaData, String, text from sqlalchemy.orm import registry @@ -48,7 +48,10 @@ def _get_schema(): mapper_registry = registry(metadata=metadata) _sentinel = object() -Base: Any = mapper_registry.generate_base() +if TYPE_CHECKING: + Base = Any +else: + Base = mapper_registry.generate_base() ID_LEN = 250 diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ef2a41ac91ab..1a9d1e00362e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -97,7 +97,7 @@ from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule -from airflow.models.xcom import LazyXComAccess, XCom +from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.plugins_manager import integrate_macros_plugins from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook @@ -3358,34 +3358,37 @@ def xcom_pull( return default if map_indexes is not None or first.map_index < 0: return XCom.deserialize_value(first) - query = query.order_by(None).order_by(XCom.map_index.asc()) - return LazyXComAccess.build_from_xcom_query(query) + return LazyXComSelectSequence.from_select( + query.with_entities(XCom.value).order_by(None).statement, + order_by=[XCom.map_index], + session=session, + ) # At this point either task_ids or map_indexes is explicitly multi-value. # Order return values to match task_ids and map_indexes ordering. - query = query.order_by(None) + ordering = [] if task_ids is None or isinstance(task_ids, str): - query = query.order_by(XCom.task_id) + ordering.append(XCom.task_id) + elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}: + ordering.append(case(task_id_whens, value=XCom.task_id)) else: - task_id_whens = {tid: i for i, tid in enumerate(task_ids)} - if task_id_whens: - query = query.order_by(case(task_id_whens, value=XCom.task_id)) - else: - query = query.order_by(XCom.task_id) + ordering.append(XCom.task_id) if map_indexes is None or isinstance(map_indexes, int): - query = query.order_by(XCom.map_index) + ordering.append(XCom.map_index) elif isinstance(map_indexes, range): order = XCom.map_index if map_indexes.step < 0: order = order.desc() - query = query.order_by(order) + ordering.append(order) + elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}: + ordering.append(case(map_index_whens, value=XCom.map_index)) else: - map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)} - if map_index_whens: - query = query.order_by(case(map_index_whens, value=XCom.map_index)) - else: - query = query.order_by(XCom.map_index) - return LazyXComAccess.build_from_xcom_query(query) + ordering.append(XCom.map_index) + return LazyXComSelectSequence.from_select( + query.with_entities(XCom.value).order_by(None).statement, + order_by=ordering, + session=session, + ) @provide_session def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int: diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 7a6695c1b30f..fe1ebadc2e59 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -17,18 +17,14 @@ # under the License. from __future__ import annotations -import collections.abc -import contextlib import inspect -import itertools import json import logging import pickle import warnings -from functools import cached_property, wraps -from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload +from functools import wraps +from typing import TYPE_CHECKING, Any, Iterable, cast, overload -import attr from sqlalchemy import ( Column, ForeignKeyConstraint, @@ -38,6 +34,7 @@ PrimaryKeyConstraint, String, delete, + select, text, ) from sqlalchemy.dialects.mysql import LONGBLOB @@ -45,12 +42,12 @@ from sqlalchemy.orm import Query, reconstructor, relationship from sqlalchemy.orm.exc import NoResultFound -from airflow import settings from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies from airflow.utils import timezone +from airflow.utils.db import LazySelectSequence from airflow.utils.helpers import exactly_one, is_container from airflow.utils.json import XComDecoder, XComEncoder from airflow.utils.log.logging_mixin import LoggingMixin @@ -70,7 +67,9 @@ import datetime import pendulum + from sqlalchemy.engine import Row from sqlalchemy.orm import Session + from sqlalchemy.sql.expression import Select, TextClause from airflow.models.taskinstancekey import TaskInstanceKey @@ -222,11 +221,11 @@ def set( if dag_run_id is None: raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}") - # Seamlessly resolve LazyXComAccess to a list. This is intended to work + # Seamlessly resolve LazySelectSequence to a list. This intends to work # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if # it's pushed into XCom, the user should be aware of the performance # implications, and this avoids leaking the implementation detail. - if isinstance(value, LazyXComAccess): + if isinstance(value, LazySelectSequence): warning_message = ( "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) " "to list, which may degrade performance. Review resource " @@ -716,111 +715,19 @@ def orm_deserialize_value(self) -> Any: return BaseXCom._deserialize_value(self, True) -class _LazyXComAccessIterator(collections.abc.Iterator): - def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None: - self._cm = cm - self._entered = False - - def __del__(self) -> None: - if self._entered: - self._cm.__exit__(None, None, None) - - def __iter__(self) -> collections.abc.Iterator: - return self - - def __next__(self) -> Any: - return XCom.deserialize_value(next(self._it)) - - @cached_property - def _it(self) -> collections.abc.Iterator: - self._entered = True - return iter(self._cm.__enter__()) - - -@attr.define(slots=True) -class LazyXComAccess(collections.abc.Sequence): - """Wrapper to lazily pull XCom with a sequence-like interface. - - Note that since the session bound to the parent query may have died when we - actually access the sequence's content, we must create a new session - for every function call with ``with_session()``. +class LazyXComSelectSequence(LazySelectSequence[Any]): + """List-like interface to lazily access XCom values. :meta private: """ - _query: Query - _len: int | None = attr.ib(init=False, default=None) - - @classmethod - def build_from_xcom_query(cls, query: Query) -> LazyXComAccess: - return cls(query=query.with_entities(XCom.value)) - - def __repr__(self) -> str: - return f"LazyXComAccess([{len(self)} items])" - - def __str__(self) -> str: - return str(list(self)) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, (list, LazyXComAccess)): - z = itertools.zip_longest(iter(self), iter(other), fillvalue=object()) - return all(x == y for x, y in z) - return NotImplemented - - def __getstate__(self) -> Any: - # We don't want to go to the trouble of serializing the entire Query - # object, including its filters, hints, etc. (plus SQLAlchemy does not - # provide a public API to inspect a query's contents). Converting the - # query into a SQL string is the best we can get. Theoratically we can - # do the same for count(), but I think it should be performant enough to - # calculate only that eagerly. - with self._get_bound_query() as query: - statement = query.statement.compile( - query.session.get_bind(), - # This inlines all the values into the SQL string to simplify - # cross-process commuinication as much as possible. - compile_kwargs={"literal_binds": True}, - ) - return (str(statement), query.count()) - - def __setstate__(self, state: Any) -> None: - statement, self._len = state - self._query = Query(XCom.value).from_statement(text(statement)) - - def __len__(self): - if self._len is None: - with self._get_bound_query() as query: - self._len = query.count() - return self._len - - def __iter__(self): - return _LazyXComAccessIterator(self._get_bound_query()) + @staticmethod + def _rebuild_select(stmt: TextClause) -> Select: + return select(XCom.value).from_statement(stmt) - def __getitem__(self, key): - if not isinstance(key, int): - raise ValueError("only support index access for now") - try: - with self._get_bound_query() as query: - r = query.offset(key).limit(1).one() - except NoResultFound: - raise IndexError(key) from None - return XCom.deserialize_value(r) - - @contextlib.contextmanager - def _get_bound_query(self) -> Generator[Query, None, None]: - # Do we have a valid session already? - if self._query.session and self._query.session.is_active: - yield self._query - return - - Session = getattr(settings, "Session", None) - if Session is None: - raise RuntimeError("Session must be set before!") - session = Session() - try: - yield self._query.with_session(session) - finally: - session.close() + @staticmethod + def _process_row(row: Row) -> Any: + return XCom.deserialize_value(row) def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None: diff --git a/airflow/typing_compat.py b/airflow/typing_compat.py index 5ae2d236abc0..ba96c92d77f0 100644 --- a/airflow/typing_compat.py +++ b/airflow/typing_compat.py @@ -23,6 +23,7 @@ "Literal", "ParamSpec", "Protocol", + "Self", "TypedDict", "TypeGuard", "runtime_checkable", @@ -45,3 +46,8 @@ from typing import ParamSpec, TypeGuard else: from typing_extensions import ParamSpec, TypeGuard + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 2b7302098d15..58e688f2d995 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -32,25 +32,26 @@ KeysView, Mapping, MutableMapping, - Sequence, SupportsIndex, ValuesView, - overload, ) import attrs import lazy_object_proxy +from sqlalchemy import select from airflow.datasets import Dataset, coerce_to_uri from airflow.exceptions import RemovedInAirflow3Warning +from airflow.models.dataset import DatasetEvent, DatasetModel +from airflow.utils.db import LazySelectSequence from airflow.utils.types import NOTSET if TYPE_CHECKING: + from sqlalchemy.engine import Row from sqlalchemy.orm import Session - from sqlalchemy.sql.expression import Select + from sqlalchemy.sql.expression import Select, TextClause from airflow.models.baseoperator import BaseOperator - from airflow.models.dataset import DatasetEvent # NOTE: Please keep this in sync with the following: # * Context in airflow/utils/context.pyi. @@ -187,57 +188,23 @@ def __getitem__(self, key: str | Dataset) -> OutletEventAccessor: return self._dict[uri] -@attrs.define() -class InletEventsAccessor(Sequence["DatasetEvent"]): - """Lazy sequence to access inlet dataset events. +class LazyDatasetEventSelectSequence(LazySelectSequence[DatasetEvent]): + """List-like interface to lazily access DatasetEvent rows. :meta private: """ - _uri: str - _session: Session - - def _get_select_stmt(self, *, reverse: bool = False) -> Select: - from sqlalchemy import select - - from airflow.models.dataset import DatasetEvent, DatasetModel - - stmt = select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == self._uri) - if reverse: - return stmt.order_by(DatasetEvent.timestamp.desc()) - return stmt.order_by(DatasetEvent.timestamp.asc()) - - def __reversed__(self) -> Iterator[DatasetEvent]: - return iter(self._session.scalar(self._get_select_stmt(reverse=True))) - - def __iter__(self) -> Iterator[DatasetEvent]: - return iter(self._session.scalar(self._get_select_stmt())) - - @overload - def __getitem__(self, key: int) -> DatasetEvent: ... - - @overload - def __getitem__(self, key: slice) -> Sequence[DatasetEvent]: ... - - def __getitem__(self, key: int | slice) -> DatasetEvent | Sequence[DatasetEvent]: - if not isinstance(key, int): - raise ValueError("non-index access is not supported") - if key >= 0: - stmt = self._get_select_stmt().offset(key) - else: - stmt = self._get_select_stmt(reverse=True).offset(-1 - key) - if (event := self._session.scalar(stmt.limit(1))) is not None: - return event - raise IndexError(key) - - def __len__(self) -> int: - from sqlalchemy import func, select + @staticmethod + def _rebuild_select(stmt: TextClause) -> Select: + return select(DatasetEvent).from_statement(stmt) - return self._session.scalar(select(func.count()).select_from(self._get_select_stmt())) + @staticmethod + def _process_row(row: Row) -> DatasetEvent: + return row[0] @attrs.define(init=False) -class InletEventsAccessors(Mapping[str, InletEventsAccessor]): +class InletEventsAccessors(Mapping[str, LazyDatasetEventSelectSequence]): """Lazy mapping for inlet dataset events accessors. :meta private: @@ -258,14 +225,18 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return len(self._inlets) - def __getitem__(self, key: int | str | Dataset) -> InletEventsAccessor: + def __getitem__(self, key: int | str | Dataset) -> LazyDatasetEventSelectSequence: if isinstance(key, int): # Support index access; it's easier for trivial cases. dataset = self._inlets[key] if not isinstance(dataset, Dataset): raise IndexError(key) else: dataset = self._datasets[coerce_to_uri(key)] - return InletEventsAccessor(dataset.uri, session=self._session) + return LazyDatasetEventSelectSequence.from_select( + select(DatasetEvent).join(DatasetEvent.dataset).where(DatasetModel.uri == dataset.uri), + order_by=[DatasetEvent.timestamp], + session=self._session, + ) class AirflowContextDeprecationWarning(RemovedInAirflow3Warning): diff --git a/airflow/utils/db.py b/airflow/utils/db.py index e7afb04014e2..fb6416966b33 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -17,8 +17,10 @@ # under the License. from __future__ import annotations +import collections.abc import contextlib import enum +import itertools import json import logging import os @@ -27,8 +29,20 @@ import warnings from dataclasses import dataclass from tempfile import gettempdir -from typing import TYPE_CHECKING, Callable, Generator, Iterable +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Iterator, + Protocol, + Sequence, + TypeVar, + overload, +) +import attrs from sqlalchemy import ( Table, and_, @@ -54,16 +68,28 @@ # TODO: remove create_session once we decide to break backward compatibility from airflow.utils.session import NEW_SESSION, create_session, provide_session # noqa: F401 +from airflow.utils.task_instance_session import get_current_task_instance_session if TYPE_CHECKING: from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory + from sqlalchemy.engine import Row from sqlalchemy.orm import Query, Session - from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.elements import ClauseElement, TextClause from sqlalchemy.sql.selectable import Select - from airflow.models.base import Base from airflow.models.connection import Connection + from airflow.typing_compat import Self + + # TODO: Import this from sqlalchemy.orm instead when switching to SQLA 2. + # https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.MappedClassProtocol + class MappedClassProtocol(Protocol): + """Protocol for SQLALchemy model base.""" + + __tablename__: str + + +T = TypeVar("T") log = logging.getLogger(__name__) @@ -1028,7 +1054,7 @@ def check_username_duplicates(session: Session) -> Iterable[str]: ) -def reflect_tables(tables: list[Base | str] | None, session): +def reflect_tables(tables: list[MappedClassProtocol | str] | None, session): """ When running checks prior to upgrades, we use reflection to determine current state of the database. @@ -1416,7 +1442,7 @@ class BadReferenceConfig: ref_table="task_instance", ) - models_list: list[tuple[Base, str, BadReferenceConfig]] = [ + models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [ (TaskInstance, "2.2", missing_dag_run_config), (TaskReschedule, "2.2", missing_ti_config), (RenderedTaskInstanceFields, "2.3", missing_ti_config), @@ -1875,7 +1901,7 @@ def get_sqla_model_classes(): def get_query_count(query_stmt: Select, *, session: Session) -> int: - """Get count of query. + """Get count of a query. A SELECT COUNT() FROM is issued against the subquery built from the given statement. The ORDER BY clause is stripped from the statement @@ -1888,8 +1914,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int: return session.scalar(count_stmt) +def check_query_exists(query_stmt: Select, *, session: Session) -> bool: + """Check whether there is at least one row matching a query. + + A SELECT 1 FROM is issued against the subquery built from the given + statement. The ORDER BY clause is stripped from the statement since it's + unnecessary, and can impact query planning and degrade performance. + + :meta private: + """ + count_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery()) + return session.scalar(count_stmt) + + def exists_query(*where: ClauseElement, session: Session) -> bool: - """Check whether there is at least one row matching given clause. + """Check whether there is at least one row matching given clauses. This does a SELECT 1 WHERE ... LIMIT 1 and check the result. @@ -1897,3 +1936,122 @@ def exists_query(*where: ClauseElement, session: Session) -> bool: """ stmt = select(literal(True)).where(*where).limit(1) return session.scalar(stmt) is not None + + +@attrs.define(slots=True) +class LazySelectSequence(Sequence[T]): + """List-like interface to lazily access a database model query. + + The intended use case is inside a task execution context, where we manage an + active SQLAlchemy session in the background. + + This is an abstract base class. Each use case should subclass, and implement + the following static methods: + + * ``_rebuild_select`` is called when a lazy sequence is unpickled. Since it + is not easy to pickle SQLAlchemy constructs, this class serializes the + SELECT statements into plain text to storage. This method is called on + deserialization to convert the textual clause back into an ORM SELECT. + * ``_process_row`` is called when an item is accessed. The lazy sequence + uses ``session.execute()`` to fetch rows from the database, and this + method should know how to process each row into a value. + + :meta private: + """ + + _select_asc: ClauseElement + _select_desc: ClauseElement + _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session) + _len: int | None = attrs.field(init=False, default=None) + + @classmethod + def from_select( + cls, + select: Select, + *, + order_by: Sequence[ClauseElement], + session: Session | None = None, + ) -> Self: + s1 = select + for col in order_by: + s1 = s1.order_by(col.asc()) + s2 = select + for col in order_by: + s2 = s2.order_by(col.desc()) + return cls(s1, s2, session=session or get_current_task_instance_session()) + + @staticmethod + def _rebuild_select(stmt: TextClause) -> Select: + """Rebuild a textual statement into an ORM-configured SELECT statement. + + This should do something like ``select(field).from_statement(stmt)`` to + reconfigure ORM information to the textual SQL statement. + """ + raise NotImplementedError + + @staticmethod + def _process_row(row: Row) -> T: + """Process a SELECT-ed row into the end value.""" + raise NotImplementedError + + def __repr__(self) -> str: + counter = "item" if (length := len(self)) == 1 else "items" + return f"LazySelectSequence([{length} {counter}])" + + def __str__(self) -> str: + counter = "item" if (length := len(self)) == 1 else "items" + return f"LazySelectSequence([{length} {counter}])" + + def __getstate__(self) -> Any: + # We don't want to go to the trouble of serializing SQLAlchemy objects. + # Converting the statement into a SQL string is the best we can get. + # The literal_binds compile argument inlines all the values into the SQL + # string to simplify cross-process commuinication as much as possible. + # Theoratically we can do the same for count(), but I think it should be + # performant enough to calculate only that eagerly. + s1 = str(self._select_asc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True})) + s2 = str(self._select_desc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True})) + return (s1, s2, len(self)) + + def __setstate__(self, state: Any) -> None: + s1, s2, self._len = state + self._select_asc = self._rebuild_select(text(s1)) + self._select_desc = self._rebuild_select(text(s2)) + self._session = get_current_task_instance_session() + + def __bool__(self) -> bool: + return check_query_exists(self._select_asc, session=self._session) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, collections.abc.Sequence): + return NotImplemented + z = itertools.zip_longest(iter(self), iter(other), fillvalue=object()) + return all(x == y for x, y in z) + + def __reversed__(self) -> Iterator[T]: + return iter(self._process_row(r) for r in self._session.execute(self._select_desc)) + + def __iter__(self) -> Iterator[T]: + return iter(self._process_row(r) for r in self._session.execute(self._select_asc)) + + def __len__(self) -> int: + if self._len is None: + self._len = get_query_count(self._select_asc, session=self._session) + return self._len + + @overload + def __getitem__(self, key: int) -> T: ... + + @overload + def __getitem__(self, key: slice) -> Self: ... + + def __getitem__(self, key: int | slice) -> T | Self: + if not isinstance(key, int): + raise ValueError("non-index access is not supported") + if key >= 0: + stmt = self._select_asc.offset(key) + else: + stmt = self._select_desc.offset(-1 - key) + if (row := self._session.execute(stmt.limit(1)).one_or_none()) is None: + raise IndexError(key) + return self._process_row(row) diff --git a/airflow/utils/task_instance_session.py b/airflow/utils/task_instance_session.py index 9d4dd958347c..bb9741bf5256 100644 --- a/airflow/utils/task_instance_session.py +++ b/airflow/utils/task_instance_session.py @@ -22,7 +22,7 @@ import traceback from typing import TYPE_CHECKING -from airflow.utils.session import create_session +from airflow import settings if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -41,7 +41,7 @@ def get_current_task_instance_session() -> Session: log.warning('File: "%s", %s , in %s', filename, line_number, name) if line: log.warning(" %s", line.strip()) - __current_task_instance_session = create_session() + __current_task_instance_session = settings.Session() return __current_task_instance_session diff --git a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst index 739bc6fee956..dd6f42bc852c 100644 --- a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst +++ b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst @@ -53,7 +53,7 @@ The grid view also provides visibility into your mapped tasks in the details pan In the above example, ``values`` received by ``sum_it`` is an aggregation of all values returned by each mapped instance of ``add_one``. However, since it is impossible to know how many instances of ``add_one`` we will have in advance, ``values`` is not a normal list, but a "lazy sequence" that retrieves each individual value only when asked. Therefore, if you run ``print(values)`` directly, you would get something like this:: - LazyXComAccess(dag_id='simple_mapping', run_id='test_run', task_id='add_one') + LazySelectSequence([15 items]) You can use normal sequence syntax on this object (e.g. ``values[0]``), or iterate through it normally with a ``for`` loop. ``list(values)`` will give you a "real" ``list``, but since this would eagerly load values from *all* of the referenced upstream mapped tasks, you must be aware of the potential performance implications if the mapped number is large. diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 8afc1abaa179..11d833a21c50 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -73,7 +73,7 @@ from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule from airflow.models.variable import Variable -from airflow.models.xcom import LazyXComAccess, XCom +from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.operators.bash import BashOperator from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator @@ -93,6 +93,7 @@ from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import TaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.types import DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER @@ -4355,20 +4356,22 @@ def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session): run: DagRun = dag_maker.create_dagrun() run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session) - query = session.query(XCom.value).filter_by( - dag_id=run.dag_id, - run_id=run.run_id, - task_id="t", - map_index=-1, - key="xxx", - ) - - original = LazyXComAccess.build_from_xcom_query(query) - processed = pickle.loads(pickle.dumps(original)) + with set_current_task_instance_session(session=session): + original = LazyXComSelectSequence.from_select( + select(XCom.value).filter_by( + dag_id=run.dag_id, + run_id=run.run_id, + task_id="t", + map_index=-1, + key="xxx", + ), + order_by=(), + ) + processed = pickle.loads(pickle.dumps(original)) # After the object went through pickling, the underlying ORM query should be # replaced by one backed by a literal SQL string with all variables binded. - sql_lines = [line.strip() for line in str(processed._query.statement.compile(None)).splitlines()] + sql_lines = [line.strip() for line in str(processed._select_asc.compile(None)).splitlines()] assert sql_lines == _get_lazy_xcom_access_expected_sql_lines() assert len(processed) == 1 @@ -4398,7 +4401,7 @@ def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_v # Simply pulling the joined XCom value should not deserialize. joined = ti_2.xcom_pull("task_1", session=session) - assert isinstance(joined, LazyXComAccess) + assert isinstance(joined, LazyXComSelectSequence) assert mock_deserialize_value.call_count == 0 # Only when we go through the iterable does deserialization happen.