From 1c4a00bb4e15de083aea0c2a0ffe14ea97955c70 Mon Sep 17 00:00:00 2001
From: Tzu-ping Chung
Date: Fri, 6 Sep 2024 15:16:53 +0800
Subject: [PATCH] Try to make dataset objects totally unhashable (#42054)

---
 airflow/datasets/__init__.py                | 22 +++------------------
 airflow/lineage/hook.py                     |  2 +-
 airflow/models/dag.py                       |  4 ++--
 newsfragments/42054.significant.rst         |  4 ++++
 tests/datasets/test_dataset.py              | 14 ++++---------
 tests/lineage/test_hook.py                  |  4 ++--
 tests/timetables/test_datasets_timetable.py |  2 +-
 7 files changed, 17 insertions(+), 35 deletions(-)
 create mode 100644 newsfragments/42054.significant.rst

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 55d947544c1d..d4305eeb0494 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -206,20 +206,12 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
         raise NotImplementedError
 
 
-@attr.define()
+@attr.define(unsafe_hash=False)
 class DatasetAlias(BaseDataset):
     """A represeation of dataset alias which is used to create dataset during the runtime."""
 
     name: str
 
-    def __eq__(self, other: Any) -> bool:
-        if isinstance(other, DatasetAlias):
-            return self.name == other.name
-        return NotImplemented
-
-    def __hash__(self) -> int:
-        return hash(self.name)
-
     def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
         """
         Iterate a dataset alias as dag dependency.
@@ -241,7 +233,7 @@ class DatasetAliasEvent(TypedDict):
     dest_dataset_uri: str
 
 
-@attr.define()
+@attr.define(unsafe_hash=False)
 class Dataset(os.PathLike, BaseDataset):
     """A representation of data dependencies between workflows."""
 
@@ -249,21 +241,13 @@ class Dataset(os.PathLike, BaseDataset):
         converter=_sanitize_uri,
         validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
     )
-    extra: dict[str, Any] | None = None
+    extra: dict[str, Any] = attr.field(factory=dict)
 
     __version__: ClassVar[int] = 1
 
     def __fspath__(self) -> str:
         return self.uri
 
-    def __eq__(self, other: Any) -> bool:
-        if isinstance(other, self.__class__):
-            return self.uri == other.uri
-        return NotImplemented
-
-    def __hash__(self) -> int:
-        return hash(self.uri)
-
     @property
     def normalized_uri(self) -> str | None:
         """
diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 4ff35e4d9ce8..45227c524818 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -112,7 +112,7 @@ def create_dataset(
         """
         if uri:
             # Fallback to default factory using the provided URI
-            return Dataset(uri=uri, extra=dataset_extra)
+            return Dataset(uri=uri, extra=dataset_extra or {})
 
         if not scheme:
             self.log.debug(
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 54ebce7392b3..56f7dc89d25b 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2786,12 +2786,12 @@ def bulk_write_to_db(
             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
             for task in dag.tasks:
                 dataset_outlets: list[Dataset] = []
-                dataset_alias_outlets: set[DatasetAlias] = set()
+                dataset_alias_outlets: list[DatasetAlias] = []
                 for outlet in task.outlets:
                     if isinstance(outlet, Dataset):
                         dataset_outlets.append(outlet)
                     elif isinstance(outlet, DatasetAlias):
-                        dataset_alias_outlets.add(outlet)
+                        dataset_alias_outlets.append(outlet)
 
                 if not dataset_outlets:
                     if curr_outlet_references:
diff --git a/newsfragments/42054.significant.rst b/newsfragments/42054.significant.rst
new file mode 100644
index 000000000000..aebf70757fa0
--- /dev/null
+++ b/newsfragments/42054.significant.rst
@@ -0,0 +1,4 @@
+Dataset and DatasetAlias are no longer hashable
+
+This means they can no longer be used as dict keys or put into a set. Dataset's
+equality logic is also tweaked slightly to consider the extra dict.
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 017c87476349..940e445669cc 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -120,12 +120,6 @@ def test_not_equal_when_different_uri():
     assert dataset1 != dataset2
 
 
-def test_hash():
-    uri = "s3://example/dataset"
-    dataset = Dataset(uri=uri)
-    hash(dataset)
-
-
 def test_dataset_logic_operations():
     result_or = dataset1 | dataset2
     assert isinstance(result_or, DatasetAny)
@@ -187,10 +181,10 @@ def test_datasetbooleancondition_evaluate_iter():
     assert all_condition.evaluate({"s3://bucket1/data1": True, "s3://bucket2/data2": False}) is False
 
     # Testing iter_datasets indirectly through the subclasses
-    datasets_any = set(any_condition.iter_datasets())
-    datasets_all = set(all_condition.iter_datasets())
-    assert datasets_any == {("s3://bucket1/data1", dataset1), ("s3://bucket2/data2", dataset2)}
-    assert datasets_all == {("s3://bucket1/data1", dataset1), ("s3://bucket2/data2", dataset2)}
+    datasets_any = dict(any_condition.iter_datasets())
+    datasets_all = dict(all_condition.iter_datasets())
+    assert datasets_any == {"s3://bucket1/data1": dataset1, "s3://bucket2/data2": dataset2}
+    assert datasets_all == {"s3://bucket1/data1": dataset1, "s3://bucket2/data2": dataset2}
 
 
 @pytest.mark.parametrize(
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
index 97e160e1f971..16f386b6849e 100644
--- a/tests/lineage/test_hook.py
+++ b/tests/lineage/test_hook.py
@@ -69,7 +69,7 @@ def test_add_input_dataset(self, mock_dataset):
         self.collector.add_input_dataset(hook, uri="test_uri")
 
         assert next(iter(self.collector._inputs.values())) == (dataset, hook)
-        mock_dataset.assert_called_once_with(uri="test_uri", extra=None)
+        mock_dataset.assert_called_once_with(uri="test_uri", extra={})
 
     def test_grouping_datasets(self):
         hook_1 = MagicMock()
@@ -96,7 +96,7 @@ def test_grouping_datasets(self):
     @patch("airflow.lineage.hook.ProvidersManager")
     def test_create_dataset(self, mock_providers_manager):
         def create_dataset(arg1, arg2="default", extra=None):
-            return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra)
+            return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra or {})
 
         mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset}
         assert self.collector.create_dataset(
diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py
index b055f0d34dc9..b456b9bf5dc9 100644
--- a/tests/timetables/test_datasets_timetable.py
+++ b/tests/timetables/test_datasets_timetable.py
@@ -134,7 +134,7 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An
         "timetable": "mock_serialized_timetable",
         "dataset_condition": {
             "__type": "dataset_all",
-            "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": None}],
+            "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": {}}],
         },
     }
 
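
Editor's note (not part of the patch): a minimal sketch of the user-visible
effect described in the newsfragment, assuming a build with this patch
applied. With attrs' `@attr.define(unsafe_hash=False)` and no explicit
`__hash__`, the generated `__eq__` compares all fields (including `extra`)
and Python sets `__hash__` to None, so instances raise TypeError when hashed.
The URIs and dict contents below are illustrative only:

    from airflow.datasets import Dataset

    d1 = Dataset(uri="s3://example/data")
    d2 = Dataset(uri="s3://example/data", extra={"team": "a"})

    # Equality now considers the extra dict, not just the URI.
    assert d1 != d2
    assert d1 == Dataset(uri="s3://example/data")

    # Hashing raises TypeError, so code that used datasets as dict keys
    # or set members must key on the URI string instead.
    try:
        hash(d1)
    except TypeError:
        pass
    by_uri = {d1.uri: d1}  # works

This is also why the patch switches bulk_write_to_db from a set of
DatasetAlias objects to a list, and why the tests collect iter_datasets()
results into a dict keyed by URI rather than a set of tuples.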