From 00e1336816360399e99c1d5cefb01f42bee5d1d4 Mon Sep 17 00:00:00 2001
From: Jakub Dardzinski
Date: Thu, 9 May 2024 20:54:09 +0200
Subject: [PATCH 1/7] Migrate OpenLineage provider to V2 facets.

Signed-off-by: Jakub Dardzinski
---
 .../providers/openlineage/extractors/base.py |  20 +-
 .../providers/openlineage/extractors/bash.py |   6 +-
 .../openlineage/extractors/manager.py        |  34 +--
 .../openlineage/extractors/python.py         |   6 +-
 .../providers/openlineage/plugins/adapter.py |  78 ++++---
 .../providers/openlineage/plugins/facets.py  |  16 +-
 airflow/providers/openlineage/sqlparser.py   |  31 ++-
 airflow/providers/openlineage/utils/sql.py   |  10 +-
 .../openlineage/extractors/test_base.py      |  20 +-
 .../openlineage/extractors/test_bash.py      |   4 +-
 .../openlineage/extractors/test_manager.py   |  44 ++--
 .../openlineage/extractors/test_python.py    |   4 +-
 .../openlineage/plugins/test_adapter.py      | 200 ++++++++++--------
 tests/providers/openlineage/test_sqlparser.py |  44 ++--
 tests/providers/openlineage/utils/test_sql.py |  17 +-
 15 files changed, 276 insertions(+), 258 deletions(-)

diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index 352c319ca194..05aa659d7558 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -18,26 +18,28 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import Generic, TypeVar, Union
 
 from attrs import Factory, define
+from openlineage.client.event_v2 import Dataset as OLDataset
+from openlineage.client.facet import BaseFacet as BaseFacet_V1
+from openlineage.client.facet_v2 import JobFacet, RunFacet
 
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import TaskInstanceState
 
-if TYPE_CHECKING:
-    from openlineage.client.facet import BaseFacet
-    from openlineage.client.run import Dataset
+DatasetSubclass = TypeVar("DatasetSubclass", bound=OLDataset)
+BaseFacetSubclass = TypeVar("BaseFacetSubclass", bound=Union[BaseFacet_V1, RunFacet, JobFacet])
 
 
 @define
-class OperatorLineage:
+class OperatorLineage(Generic[DatasetSubclass, BaseFacetSubclass]):
     """Structure returned from lineage extraction."""
 
-    inputs: list[Dataset] = Factory(list)
-    outputs: list[Dataset] = Factory(list)
-    run_facets: dict[str, BaseFacet] = Factory(dict)
-    job_facets: dict[str, BaseFacet] = Factory(dict)
+    inputs: list[DatasetSubclass] = Factory(list)
+    outputs: list[DatasetSubclass] = Factory(list)
+    run_facets: dict[str, BaseFacetSubclass] = Factory(dict)
+    job_facets: dict[str, BaseFacetSubclass] = Factory(dict)
 
 
 class BaseExtractor(ABC, LoggingMixin):
diff --git a/airflow/providers/openlineage/extractors/bash.py b/airflow/providers/openlineage/extractors/bash.py
index 39c3c10781f4..6e1b3f28eefe 100644
--- a/airflow/providers/openlineage/extractors/bash.py
+++ b/airflow/providers/openlineage/extractors/bash.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from openlineage.client.facet import SourceCodeJobFacet
+from openlineage.client.facet_v2 import source_code_job
 
 from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.extractors.base import BaseExtractor, OperatorLineage
@@ -47,10 +47,10 @@ def _execute_extraction(self) -> OperatorLineage | None:
         job_facets: dict = {}
         if conf.is_source_enabled():
             job_facets = {
-                "sourceCode": SourceCodeJobFacet(
+                "sourceCode": source_code_job.SourceCodeJobFacet(
                     language="bash",
                     # We're on worker and should have access to DAG files
- source=self.operator.bash_command, + sourceCode=self.operator.bash_command, ) } else: diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py index 4f28f7dafa80..5b9ad6ac1b59 100644 --- a/airflow/providers/openlineage/extractors/manager.py +++ b/airflow/providers/openlineage/extractors/manager.py @@ -30,7 +30,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.lineage.entities import Table from airflow.models import Operator @@ -172,7 +172,7 @@ def extract_inlets_and_outlets( def convert_to_ol_dataset_from_object_storage_uri(uri: str) -> Dataset | None: from urllib.parse import urlparse - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset if "/" not in uri: return None @@ -196,21 +196,19 @@ def convert_to_ol_dataset_from_object_storage_uri(uri: str) -> Dataset | None: @staticmethod def convert_to_ol_dataset_from_table(table: Table) -> Dataset: - from openlineage.client.facet import ( - BaseFacet, - DocumentationDatasetFacet, - OwnershipDatasetFacet, - OwnershipDatasetFacetOwners, - SchemaDatasetFacet, - SchemaField, + from openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import ( + DatasetFacet, + documentation_dataset, + ownership_dataset, + schema_dataset, ) - from openlineage.client.run import Dataset - facets: dict[str, BaseFacet] = {} + facets: dict[str, DatasetFacet] = {} if table.columns: - facets["schema"] = SchemaDatasetFacet( + facets["schema"] = schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name=column.name, type=column.data_type, description=column.description, @@ -219,9 +217,9 @@ def convert_to_ol_dataset_from_table(table: Table) -> Dataset: ] ) if table.owners: - facets["ownership"] = OwnershipDatasetFacet( + facets["ownership"] = ownership_dataset.OwnershipDatasetFacet( owners=[ - OwnershipDatasetFacetOwners( + ownership_dataset.Owner( # f.e. 
"user:John Doe " or just "user:" name=f"user:" f"{user.first_name + ' ' if user.first_name else ''}" @@ -233,7 +231,9 @@ def convert_to_ol_dataset_from_table(table: Table) -> Dataset: ] ) if table.description: - facets["documentation"] = DocumentationDatasetFacet(description=table.description) + facets["documentation"] = documentation_dataset.DocumentationDatasetFacet( + description=table.description + ) return Dataset( namespace=f"{table.cluster}", name=f"{table.database}.{table.name}", @@ -242,7 +242,7 @@ def convert_to_ol_dataset_from_table(table: Table) -> Dataset: @staticmethod def convert_to_ol_dataset(obj) -> Dataset | None: - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.lineage.entities import File, Table diff --git a/airflow/providers/openlineage/extractors/python.py b/airflow/providers/openlineage/extractors/python.py index 8f7efad0937f..c716e28b4d8f 100644 --- a/airflow/providers/openlineage/extractors/python.py +++ b/airflow/providers/openlineage/extractors/python.py @@ -20,7 +20,7 @@ import inspect from typing import Callable -from openlineage.client.facet import SourceCodeJobFacet +from openlineage.client.facet_v2 import source_code_job from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors.base import BaseExtractor, OperatorLineage @@ -51,10 +51,10 @@ def _execute_extraction(self) -> OperatorLineage | None: job_facet: dict = {} if conf.is_source_enabled() and source_code: job_facet = { - "sourceCode": SourceCodeJobFacet( + "sourceCode": source_code_job.SourceCodeJobFacet( language="python", # We're on worker and should have access to DAG files - source=source_code, + sourceCode=source_code, ) } else: diff --git a/airflow/providers/openlineage/plugins/adapter.py b/airflow/providers/openlineage/plugins/adapter.py index 8e1d924bb979..6f62e772a8be 100644 --- a/airflow/providers/openlineage/plugins/adapter.py +++ b/airflow/providers/openlineage/plugins/adapter.py @@ -22,19 +22,19 @@ import yaml from openlineage.client import OpenLineageClient, set_producer -from openlineage.client.facet import ( - BaseFacet, - DocumentationJobFacet, - ErrorMessageRunFacet, - JobTypeJobFacet, - NominalTimeRunFacet, - OwnershipJobFacet, - OwnershipJobFacetOwners, - ParentRunFacet, - ProcessingEngineRunFacet, - SourceCodeLocationJobFacet, +from openlineage.client.event_v2 import Job, Run, RunEvent, RunState +from openlineage.client.facet_v2 import ( + JobFacet, + RunFacet, + documentation_job, + error_message_run, + job_type_job, + nominal_time_run, + ownership_job, + parent_run, + processing_engine_run, + source_code_location_job, ) -from openlineage.client.run import Job, Run, RunEvent, RunState from openlineage.client.uuid import generate_static_uuid from airflow.providers.openlineage import __version__ as OPENLINEAGE_PROVIDER_VERSION, conf @@ -60,8 +60,8 @@ # https://openlineage.io/docs/spec/facets/job-facets/job-type # They must be set after the `set_producer(_PRODUCER)` # otherwise the `JobTypeJobFacet._producer` will be set with the default value -_JOB_TYPE_DAG = JobTypeJobFacet(jobType="DAG", integration="AIRFLOW", processingType="BATCH") -_JOB_TYPE_TASK = JobTypeJobFacet(jobType="TASK", integration="AIRFLOW", processingType="BATCH") +_JOB_TYPE_DAG = job_type_job.JobTypeJobFacet(jobType="DAG", integration="AIRFLOW", processingType="BATCH") +_JOB_TYPE_TASK = job_type_job.JobTypeJobFacet(jobType="TASK", integration="AIRFLOW", processingType="BATCH") class OpenLineageAdapter(LoggingMixin): @@ 
-149,7 +149,7 @@ def emit(self, event: RunEvent): if not self._client: self._client = self.get_or_create_openlineage_client() redacted_event: RunEvent = self._redacter.redact(event, max_depth=20) # type: ignore[assignment] - event_type = event.eventType.value.lower() + event_type = event.eventType.value.lower() if event.eventType else "" transport_type = f"{self._client.transport.kind}".lower() try: @@ -178,7 +178,7 @@ def start_task( nominal_end_time: str | None, owners: list[str], task: OperatorLineage | None, - run_facets: dict[str, BaseFacet] | None = None, # Custom run facets + run_facets: dict[str, RunFacet] | None = None, # Custom run facets ) -> RunEvent: """ Emit openlineage event of type START. @@ -199,14 +199,13 @@ def start_task( """ from airflow.version import version as AIRFLOW_VERSION - processing_engine_version_facet = ProcessingEngineRunFacet( + processing_engine_version_facet = processing_engine_run.ProcessingEngineRunFacet( version=AIRFLOW_VERSION, name="Airflow", openlineageAdapterVersion=OPENLINEAGE_PROVIDER_VERSION, ) - if not run_facets: - run_facets = {} + run_facets = run_facets or {} if task: run_facets = {**task.run_facets, **run_facets} run_facets["processing_engine"] = processing_engine_version_facet # type: ignore @@ -302,10 +301,9 @@ def fail_task( import traceback stack_trace = "\\n".join(traceback.format_exception(type(error), error, error.__traceback__)) - error_facet = { - "errorMessage": ErrorMessageRunFacet( - message=str(error), programmingLanguage="python", stackTrace=stack_trace + "errorMessage": error_message_run.ErrorMessageRunFacet( + message=error, programmingLanguage="python", stackTrace=stack_trace ) } @@ -403,7 +401,9 @@ def dag_failed(self, dag_run: DagRun, msg: str): execution_date=dag_run.execution_date, ), facets={ - "errorMessage": ErrorMessageRunFacet(message=msg, programmingLanguage="python"), + "errorMessage": error_message_run.ErrorMessageRunFacet( + message=msg, programmingLanguage="python" + ), **get_airflow_state_run_facet(dag_run), }, ), @@ -426,13 +426,15 @@ def _build_run( parent_run_id: str | None = None, nominal_start_time: str | None = None, nominal_end_time: str | None = None, - run_facets: dict[str, BaseFacet] | None = None, + run_facets: dict[str, RunFacet] | None = None, ) -> Run: - facets: dict[str, BaseFacet] = {} + facets: dict[str, RunFacet] = {} if nominal_start_time: - facets.update({"nominalTime": NominalTimeRunFacet(nominal_start_time, nominal_end_time)}) + facets.update( + {"nominalTime": nominal_time_run.NominalTimeRunFacet(nominal_start_time, nominal_end_time)} + ) if parent_run_id: - parent_run_facet = ParentRunFacet.create( + parent_run_facet = parent_run.ParentRunFacet.create( runId=parent_run_id, namespace=conf.namespace(), name=parent_job_name or job_name, @@ -447,23 +449,31 @@ def _build_run( @staticmethod def _build_job( job_name: str, - job_type: JobTypeJobFacet, + job_type: job_type_job.JobTypeJobFacet, job_description: str | None = None, code_location: str | None = None, owners: list[str] | None = None, - job_facets: dict[str, BaseFacet] | None = None, + job_facets: dict[str, JobFacet] | None = None, ): - facets: dict[str, BaseFacet] = {} + facets: dict[str, JobFacet] = {} if job_description: - facets.update({"documentation": DocumentationJobFacet(description=job_description)}) + facets.update( + {"documentation": documentation_job.DocumentationJobFacet(description=job_description)} + ) if code_location: - facets.update({"sourceCodeLocation": SourceCodeLocationJobFacet("", url=code_location)}) + 
facets.update( + { + "sourceCodeLocation": source_code_location_job.SourceCodeLocationJobFacet( + "", url=code_location + ) + } + ) if owners: facets.update( { - "ownership": OwnershipJobFacet( - owners=[OwnershipJobFacetOwners(name=owner) for owner in owners] + "ownership": ownership_job.OwnershipJobFacet( + owners=[ownership_job.Owner(name=owner) for owner in owners] ) } ) diff --git a/airflow/providers/openlineage/plugins/facets.py b/airflow/providers/openlineage/plugins/facets.py index d282c72ac813..24f411477acc 100644 --- a/airflow/providers/openlineage/plugins/facets.py +++ b/airflow/providers/openlineage/plugins/facets.py @@ -18,7 +18,7 @@ from attrs import define from deprecated import deprecated -from openlineage.client.facet import BaseFacet +from openlineage.client.facet_v2 import JobFacet, RunFacet from openlineage.client.utils import RedactMixin from airflow.exceptions import AirflowProviderDeprecationWarning @@ -29,7 +29,7 @@ category=AirflowProviderDeprecationWarning, ) @define(slots=False) -class AirflowMappedTaskRunFacet(BaseFacet): +class AirflowMappedTaskRunFacet(RunFacet): """Run facet containing information about mapped tasks.""" mapIndex: int @@ -48,7 +48,7 @@ def from_task_instance(cls, task_instance): @define(slots=False) -class AirflowJobFacet(BaseFacet): +class AirflowJobFacet(JobFacet): """ Composite Airflow job facet. @@ -71,7 +71,7 @@ class AirflowJobFacet(BaseFacet): @define(slots=False) -class AirflowStateRunFacet(BaseFacet): +class AirflowStateRunFacet(RunFacet): """ Airflow facet providing state information. @@ -90,8 +90,8 @@ class AirflowStateRunFacet(BaseFacet): @define(slots=False) -class AirflowRunFacet(BaseFacet): - """Composite Airflow task run facet.""" +class AirflowRunFacet(RunFacet): + """Composite Airflow run facet.""" dag: dict dagRun: dict @@ -101,7 +101,7 @@ class AirflowRunFacet(BaseFacet): @define(slots=False) -class AirflowDagRunFacet(BaseFacet): +class AirflowDagRunFacet(RunFacet): """Composite Airflow DAG run facet.""" dag: dict @@ -128,7 +128,7 @@ class UnknownOperatorInstance(RedactMixin): category=AirflowProviderDeprecationWarning, ) @define(slots=False) -class UnknownOperatorAttributeRunFacet(BaseFacet): +class UnknownOperatorAttributeRunFacet(RunFacet): """RunFacet that describes unknown operators in an Airflow DAG.""" unknownItems: list[UnknownOperatorInstance] diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py index 9906f3db3cda..900bbf277269 100644 --- a/airflow/providers/openlineage/sqlparser.py +++ b/airflow/providers/openlineage/sqlparser.py @@ -20,16 +20,8 @@ import sqlparse from attrs import define -from openlineage.client.facet import ( - BaseFacet, - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional, - ColumnLineageDatasetFacetFieldsAdditionalInputFields, - ExtractionError, - ExtractionErrorRunFacet, - SqlJobFacet, -) -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import column_lineage_dataset, extraction_error_run, sql_job from openlineage.common.sql import DbTableMeta, SqlMeta, parse from airflow.providers.openlineage.extractors.base import OperatorLineage @@ -42,6 +34,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: + from openlineage.client.facet_v2 import JobFacet, RunFacet from sqlalchemy.engine import Engine from airflow.hooks.base import BaseHook @@ -206,11 +199,13 @@ def attach_column_lineage( if not len(parse_result.column_lineage): 
return for dataset in datasets: - dataset.facets["columnLineage"] = ColumnLineageDatasetFacet( + if not dataset.facets: + continue + dataset.facets["columnLineage"] = column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - column_lineage.descendant.name: ColumnLineageDatasetFacetFieldsAdditional( + column_lineage.descendant.name: column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=dataset.namespace, name=".".join( filter( @@ -260,18 +255,18 @@ def generate_openlineage_metadata_from_sql( :param database: when passed it takes precedence over parsed database name :param sqlalchemy_engine: when passed, engine's dialect is used to compile SQL queries """ - job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=self.normalize_sql(sql))} - parse_result = self.parse(self.split_sql_string(sql)) + job_facets: dict[str, JobFacet] = {"sql": sql_job.SQLJobFacet(query=self.normalize_sql(sql))} + parse_result = self.parse(sql=self.split_sql_string(sql)) if not parse_result: return OperatorLineage(job_facets=job_facets) - run_facets: dict[str, BaseFacet] = {} + run_facets: dict[str, RunFacet] = {} if parse_result.errors: - run_facets["extractionError"] = ExtractionErrorRunFacet( + run_facets["extractionError"] = extraction_error_run.ExtractionErrorRunFacet( totalTasks=len(sql) if isinstance(sql, list) else 1, failedTasks=len(parse_result.errors), errors=[ - ExtractionError( + extraction_error_run.Error( errorMessage=error.message, stackTrace=None, task=error.origin_statement, diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py index a4ebe44740ca..7cbd9531945b 100644 --- a/airflow/providers/openlineage/utils/sql.py +++ b/airflow/providers/openlineage/utils/sql.py @@ -23,8 +23,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional from attrs import define -from openlineage.client.facet import SchemaDatasetFacet, SchemaField -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import schema_dataset from sqlalchemy import Column, MetaData, Table, and_, or_, union_all if TYPE_CHECKING: @@ -60,7 +60,7 @@ class TableSchema: table: str schema: str | None database: str | None - fields: list[SchemaField] + fields: list[schema_dataset.SchemaDatasetFacetFields] def to_dataset(self, namespace: str, database: str | None = None, schema: str | None = None) -> Dataset: # Prefix the table name with database and schema name using @@ -73,7 +73,7 @@ def to_dataset(self, namespace: str, database: str | None = None, schema: str | return Dataset( namespace=namespace, name=name, - facets={"schema": SchemaDatasetFacet(fields=self.fields)} if self.fields else {}, + facets={"schema": schema_dataset.SchemaDatasetFacet(fields=self.fields)} if self.fields else {}, ) @@ -122,7 +122,7 @@ def parse_query_result(cursor) -> list[TableSchema]: for row in cursor.fetchall(): table_schema_name: str = row[ColumnIndex.SCHEMA] table_name: str = row[ColumnIndex.TABLE_NAME] - table_column: SchemaField = SchemaField( + table_column = schema_dataset.SchemaDatasetFacetFields( name=row[ColumnIndex.COLUMN_NAME], type=row[ColumnIndex.UDT_NAME], description=None, diff --git a/tests/providers/openlineage/extractors/test_base.py b/tests/providers/openlineage/extractors/test_base.py index d81210605167..20ceba45dd87 100644 --- a/tests/providers/openlineage/extractors/test_base.py +++ 
b/tests/providers/openlineage/extractors/test_base.py @@ -16,13 +16,13 @@ # under the License. from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from unittest import mock import pytest from attrs import Factory, define, field -from openlineage.client.facet import BaseFacet, ParentRunFacet, SqlJobFacet -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import BaseFacet, JobFacet, parent_run, sql_job from airflow.models.baseoperator import BaseOperator from airflow.operators.python import PythonOperator @@ -34,23 +34,27 @@ from airflow.providers.openlineage.extractors.manager import ExtractorManager from airflow.providers.openlineage.extractors.python import PythonExtractor +if TYPE_CHECKING: + from openlineage.client.facet_v2 import RunFacet pytestmark = pytest.mark.db_test INPUTS = [Dataset(namespace="database://host:port", name="inputtable")] OUTPUTS = [Dataset(namespace="database://host:port", name="inputtable")] -RUN_FACETS: dict[str, BaseFacet] = { - "parent": ParentRunFacet.create("3bb703d1-09c1-4a42-8da5-35a0b3216072", "namespace", "parentjob") +RUN_FACETS: dict[str, RunFacet] = { + "parent": parent_run.ParentRunFacet.create( + "3bb703d1-09c1-4a42-8da5-35a0b3216072", "namespace", "parentjob" + ) } -JOB_FACETS: dict[str, BaseFacet] = {"sql": SqlJobFacet(query="SELECT * FROM inputtable")} +JOB_FACETS: dict[str, JobFacet] = {"sql": sql_job.SQLJobFacet(query="SELECT * FROM inputtable")} @define -class CompleteRunFacet(BaseFacet): +class CompleteRunFacet(JobFacet): finished: bool = field(default=False) -FINISHED_FACETS: dict[str, BaseFacet] = {"complete": CompleteRunFacet(True)} +FINISHED_FACETS: dict[str, JobFacet] = {"complete": CompleteRunFacet(True)} class ExampleExtractor(BaseExtractor): diff --git a/tests/providers/openlineage/extractors/test_bash.py b/tests/providers/openlineage/extractors/test_bash.py index 8f33af535d35..de65a1d176d8 100644 --- a/tests/providers/openlineage/extractors/test_bash.py +++ b/tests/providers/openlineage/extractors/test_bash.py @@ -22,7 +22,7 @@ from unittest.mock import patch import pytest -from openlineage.client.facet import SourceCodeJobFacet +from openlineage.client.facet_v2 import source_code_job from airflow import DAG from airflow.exceptions import AirflowProviderDeprecationWarning @@ -67,7 +67,7 @@ def test_extract_operator_bash_command_enabled(mocked_source_enabled): with warnings.catch_warnings(): warnings.simplefilter("ignore", AirflowProviderDeprecationWarning) result = BashExtractor(operator).extract() - assert result.job_facets["sourceCode"] == SourceCodeJobFacet("bash", "exit 0;") + assert result.job_facets["sourceCode"] == source_code_job.SourceCodeJobFacet("bash", "exit 0;") assert "unknownSourceAttribute" in result.run_facets unknown_items = result.run_facets["unknownSourceAttribute"]["unknownItems"] assert len(unknown_items) == 1 diff --git a/tests/providers/openlineage/extractors/test_manager.py b/tests/providers/openlineage/extractors/test_manager.py index 10f04eb342cc..ccfd04d5e2b9 100644 --- a/tests/providers/openlineage/extractors/test_manager.py +++ b/tests/providers/openlineage/extractors/test_manager.py @@ -18,14 +18,8 @@ from __future__ import annotations import pytest -from openlineage.client.facet import ( - DocumentationDatasetFacet, - OwnershipDatasetFacet, - OwnershipDatasetFacetOwners, - SchemaDatasetFacet, - SchemaField, -) -from openlineage.client.run import Dataset +from 
openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import documentation_dataset, ownership_dataset, schema_dataset from airflow.lineage.entities import Column, File, Table, User from airflow.providers.openlineage.extractors.manager import ExtractorManager @@ -98,29 +92,29 @@ def test_convert_to_ol_dataset_from_table_with_columns_and_owners(): description="test description", ) expected_facets = { - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="col1", type="type1", description="desc1", ), - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="col2", type="type2", description="desc2", ), ] ), - "ownership": OwnershipDatasetFacet( + "ownership": ownership_dataset.OwnershipDatasetFacet( owners=[ - OwnershipDatasetFacetOwners(name="user:Mike Smith ", type=""), - OwnershipDatasetFacetOwners(name="user:Theo ", type=""), - OwnershipDatasetFacetOwners(name="user:Smith ", type=""), - OwnershipDatasetFacetOwners(name="user:", type=""), + ownership_dataset.Owner(name="user:Mike Smith ", type=""), + ownership_dataset.Owner(name="user:Theo ", type=""), + ownership_dataset.Owner(name="user:Smith ", type=""), + ownership_dataset.Owner(name="user:", type=""), ] ), - "documentation": DocumentationDatasetFacet(description="test description"), + "documentation": documentation_dataset.DocumentationDatasetFacet(description="test description"), } result = ExtractorManager.convert_to_ol_dataset_from_table(table) assert result.namespace == "c1" @@ -145,26 +139,26 @@ def test_convert_to_ol_dataset_table(): ], ) expected_facets = { - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="col1", type="type1", description="desc1", ), - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="col2", type="type2", description="desc2", ), ] ), - "ownership": OwnershipDatasetFacet( + "ownership": ownership_dataset.OwnershipDatasetFacet( owners=[ - OwnershipDatasetFacetOwners(name="user:Mike Smith ", type=""), - OwnershipDatasetFacetOwners(name="user:Theo ", type=""), - OwnershipDatasetFacetOwners(name="user:Smith ", type=""), - OwnershipDatasetFacetOwners(name="user:", type=""), + ownership_dataset.Owner(name="user:Mike Smith ", type=""), + ownership_dataset.Owner(name="user:Theo ", type=""), + ownership_dataset.Owner(name="user:Smith ", type=""), + ownership_dataset.Owner(name="user:", type=""), ] ), } diff --git a/tests/providers/openlineage/extractors/test_python.py b/tests/providers/openlineage/extractors/test_python.py index 7d47b9ebc6f8..81284383d864 100644 --- a/tests/providers/openlineage/extractors/test_python.py +++ b/tests/providers/openlineage/extractors/test_python.py @@ -24,7 +24,7 @@ from unittest.mock import patch import pytest -from openlineage.client.facet import SourceCodeJobFacet +from openlineage.client.facet_v2 import source_code_job from airflow import DAG from airflow.exceptions import AirflowProviderDeprecationWarning @@ -88,7 +88,7 @@ def test_extract_operator_code_enabled(mocked_source_enabled): with warnings.catch_warnings(): warnings.simplefilter("ignore", AirflowProviderDeprecationWarning) result = PythonExtractor(operator).extract() - assert result.job_facets["sourceCode"] == SourceCodeJobFacet("python", CODE) + assert result.job_facets["sourceCode"] == source_code_job.SourceCodeJobFacet("python", CODE) assert "unknownSourceAttribute" in result.run_facets 
unknown_items = result.run_facets["unknownSourceAttribute"]["unknownItems"] assert len(unknown_items) == 1 diff --git a/tests/providers/openlineage/plugins/test_adapter.py b/tests/providers/openlineage/plugins/test_adapter.py index fb60b5cc8c50..93b49e150f6e 100644 --- a/tests/providers/openlineage/plugins/test_adapter.py +++ b/tests/providers/openlineage/plugins/test_adapter.py @@ -24,19 +24,18 @@ from unittest.mock import ANY, MagicMock, call, patch import pytest -from openlineage.client.facet import ( - DocumentationJobFacet, - ErrorMessageRunFacet, - ExternalQueryRunFacet, - JobTypeJobFacet, - NominalTimeRunFacet, - OwnershipJobFacet, - OwnershipJobFacetOwners, - ParentRunFacet, - ProcessingEngineRunFacet, - SqlJobFacet, +from openlineage.client.event_v2 import Dataset, Job, Run, RunEvent, RunState +from openlineage.client.facet_v2 import ( + documentation_job, + error_message_run, + external_query_run, + job_type_job, + nominal_time_run, + ownership_job, + parent_run, + processing_engine_run, + sql_job, ) -from openlineage.client.run import Dataset, Job, Run, RunEvent, RunState from airflow import DAG from airflow.models.dagrun import DagRun, DagRunState @@ -161,11 +160,11 @@ def test_emit_start_event(mock_stats_incr, mock_stats_timer): run=Run( runId=run_id, facets={ - "nominalTime": NominalTimeRunFacet( + "nominalTime": nominal_time_run.NominalTimeRunFacet( nominalStartTime="2022-01-01T00:00:00", nominalEndTime="2022-01-01T00:00:00", ), - "processing_engine": ProcessingEngineRunFacet( + "processing_engine": processing_engine_run.ProcessingEngineRunFacet( version=ANY, name="Airflow", openlineageAdapterVersion=ANY ), }, @@ -174,8 +173,8 @@ def test_emit_start_event(mock_stats_incr, mock_stats_timer): namespace=namespace(), name="job", facets={ - "documentation": DocumentationJobFacet(description="description"), - "jobType": JobTypeJobFacet( + "documentation": documentation_job.DocumentationJobFacet(description="description"), + "jobType": job_type_job.JobTypeJobFacet( processingType="BATCH", integration="AIRFLOW", jobType="TASK" ), }, @@ -199,6 +198,7 @@ def test_emit_start_event_with_additional_information(mock_stats_incr, mock_stat adapter = OpenLineageAdapter(client) run_id = str(uuid.uuid4()) + parent_run_id = str(uuid.uuid4()) event_time = datetime.datetime.now().isoformat() adapter.start_task( run_id=run_id, @@ -206,7 +206,7 @@ def test_emit_start_event_with_additional_information(mock_stats_incr, mock_stat job_description="description", event_time=event_time, parent_job_name="parent_job_name", - parent_run_id="parent_run_id", + parent_run_id=parent_run_id, code_location=None, nominal_start_time=datetime.datetime(2022, 1, 1).isoformat(), nominal_end_time=datetime.datetime(2022, 1, 1).isoformat(), @@ -214,10 +214,16 @@ def test_emit_start_event_with_additional_information(mock_stats_incr, mock_stat task=OperatorLineage( inputs=[Dataset(namespace="bigquery", name="a.b.c"), Dataset(namespace="bigquery", name="x.y.z")], outputs=[Dataset(namespace="gs://bucket", name="exported_folder")], - job_facets={"sql": SqlJobFacet(query="SELECT 1;")}, - run_facets={"externalQuery1": ExternalQueryRunFacet(externalQueryId="123", source="source")}, + job_facets={"sql": sql_job.SQLJobFacet(query="SELECT 1;")}, + run_facets={ + "externalQuery1": external_query_run.ExternalQueryRunFacet( + externalQueryId="123", source="source" + ) + }, ), - run_facets={"externalQuery2": ExternalQueryRunFacet(externalQueryId="999", source="source")}, + run_facets={ + "externalQuery2": 
external_query_run.ExternalQueryRunFacet(externalQueryId="999", source="source") + }, ) assert ( @@ -228,34 +234,38 @@ def test_emit_start_event_with_additional_information(mock_stats_incr, mock_stat run=Run( runId=run_id, facets={ - "nominalTime": NominalTimeRunFacet( + "nominalTime": nominal_time_run.NominalTimeRunFacet( nominalStartTime="2022-01-01T00:00:00", nominalEndTime="2022-01-01T00:00:00", ), - "processing_engine": ProcessingEngineRunFacet( + "processing_engine": processing_engine_run.ProcessingEngineRunFacet( version=ANY, name="Airflow", openlineageAdapterVersion=ANY ), - "parent": ParentRunFacet( - run={"runId": "parent_run_id"}, - job={"namespace": namespace(), "name": "parent_job_name"}, + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId=parent_run_id), + job=parent_run.Job(namespace=namespace(), name="parent_job_name"), + ), + "externalQuery1": external_query_run.ExternalQueryRunFacet( + externalQueryId="123", source="source" + ), + "externalQuery2": external_query_run.ExternalQueryRunFacet( + externalQueryId="999", source="source" ), - "externalQuery1": ExternalQueryRunFacet(externalQueryId="123", source="source"), - "externalQuery2": ExternalQueryRunFacet(externalQueryId="999", source="source"), }, ), job=Job( namespace=namespace(), name="job", facets={ - "documentation": DocumentationJobFacet(description="description"), - "ownership": OwnershipJobFacet( + "documentation": documentation_job.DocumentationJobFacet(description="description"), + "ownership": ownership_job.OwnershipJobFacet( owners=[ - OwnershipJobFacetOwners(name="owner1", type=None), - OwnershipJobFacetOwners(name="owner2", type=None), + ownership_job.Owner(name="owner1", type=None), + ownership_job.Owner(name="owner2", type=None), ] ), - "sql": SqlJobFacet(query="SELECT 1;"), - "jobType": JobTypeJobFacet( + "sql": sql_job.SQLJobFacet(query="SELECT 1;"), + "jobType": job_type_job.JobTypeJobFacet( processingType="BATCH", integration="AIRFLOW", jobType="TASK" ), }, @@ -302,7 +312,7 @@ def test_emit_complete_event(mock_stats_incr, mock_stats_timer): namespace=namespace(), name="job", facets={ - "jobType": JobTypeJobFacet( + "jobType": job_type_job.JobTypeJobFacet( processingType="BATCH", integration="AIRFLOW", jobType="TASK" ) }, @@ -326,18 +336,23 @@ def test_emit_complete_event_with_additional_information(mock_stats_incr, mock_s adapter = OpenLineageAdapter(client) run_id = str(uuid.uuid4()) + parent_run_id = str(uuid.uuid4()) event_time = datetime.datetime.now().isoformat() adapter.complete_task( run_id=run_id, end_time=event_time, parent_job_name="parent_job_name", - parent_run_id="parent_run_id", + parent_run_id=parent_run_id, job_name="job", task=OperatorLineage( inputs=[Dataset(namespace="bigquery", name="a.b.c"), Dataset(namespace="bigquery", name="x.y.z")], outputs=[Dataset(namespace="gs://bucket", name="exported_folder")], - job_facets={"sql": SqlJobFacet(query="SELECT 1;")}, - run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="123", source="source")}, + job_facets={"sql": sql_job.SQLJobFacet(query="SELECT 1;")}, + run_facets={ + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId="123", source="source" + ) + }, ), ) @@ -349,19 +364,21 @@ def test_emit_complete_event_with_additional_information(mock_stats_incr, mock_s run=Run( runId=run_id, facets={ - "parent": ParentRunFacet( - run={"runId": "parent_run_id"}, - job={"namespace": namespace(), "name": "parent_job_name"}, + "parent": parent_run.ParentRunFacet( + 
run=parent_run.Run(runId=parent_run_id), + job=parent_run.Job(namespace=namespace(), name="parent_job_name"), + ), + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId="123", source="source" ), - "externalQuery": ExternalQueryRunFacet(externalQueryId="123", source="source"), }, ), job=Job( namespace="default", name="job", facets={ - "sql": SqlJobFacet(query="SELECT 1;"), - "jobType": JobTypeJobFacet( + "sql": sql_job.SQLJobFacet(query="SELECT 1;"), + "jobType": job_type_job.JobTypeJobFacet( processingType="BATCH", integration="AIRFLOW", jobType="TASK" ), }, @@ -408,7 +425,7 @@ def test_emit_failed_event(mock_stats_incr, mock_stats_timer): namespace=namespace(), name="job", facets={ - "jobType": JobTypeJobFacet( + "jobType": job_type_job.JobTypeJobFacet( processingType="BATCH", integration="AIRFLOW", jobType="TASK" ) }, @@ -432,59 +449,60 @@ def test_emit_failed_event_with_additional_information(mock_stats_incr, mock_sta adapter = OpenLineageAdapter(client) run_id = str(uuid.uuid4()) + parent_run_id = str(uuid.uuid4()) event_time = datetime.datetime.now().isoformat() adapter.fail_task( run_id=run_id, end_time=event_time, parent_job_name="parent_job_name", - parent_run_id="parent_run_id", + parent_run_id=parent_run_id, job_name="job", task=OperatorLineage( inputs=[Dataset(namespace="bigquery", name="a.b.c"), Dataset(namespace="bigquery", name="x.y.z")], outputs=[Dataset(namespace="gs://bucket", name="exported_folder")], - run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="123", source="source")}, - job_facets={"sql": SqlJobFacet(query="SELECT 1;")}, + run_facets={ + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId="123", source="source" + ) + }, + job_facets={"sql": sql_job.SQLJobFacet(query="SELECT 1;")}, ), error=ValueError("Error message"), ) - assert ( - call( - RunEvent( - eventType=RunState.FAIL, - eventTime=event_time, - run=Run( - runId=run_id, - facets={ - "parent": ParentRunFacet( - run={"runId": "parent_run_id"}, - job={"namespace": namespace(), "name": "parent_job_name"}, - ), - "externalQuery": ExternalQueryRunFacet(externalQueryId="123", source="source"), - "errorMessage": ErrorMessageRunFacet( - message="Error message", programmingLanguage="python", stackTrace=None - ), - }, - ), - job=Job( - namespace="default", - name="job", - facets={ - "sql": SqlJobFacet(query="SELECT 1;"), - "jobType": JobTypeJobFacet( - processingType="BATCH", integration="AIRFLOW", jobType="TASK" - ), - }, - ), - producer=_PRODUCER, - inputs=[ - Dataset(namespace="bigquery", name="a.b.c"), - Dataset(namespace="bigquery", name="x.y.z"), - ], - outputs=[Dataset(namespace="gs://bucket", name="exported_folder")], - ) + assert client.emit.mock_calls[0] == call( + RunEvent( + eventType=RunState.FAIL, + eventTime=event_time, + run=Run( + runId=run_id, + facets={ + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId=parent_run_id), + job=parent_run.Job(namespace=namespace(), name="parent_job_name"), + ), + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId="123", source="source" + ), + }, + ), + job=Job( + namespace="default", + name="job", + facets={ + "sql": sql_job.SQLJobFacet(query="SELECT 1;"), + "jobType": job_type_job.JobTypeJobFacet( + processingType="BATCH", integration="AIRFLOW", jobType="TASK" + ), + }, + ), + producer=_PRODUCER, + inputs=[ + Dataset(namespace="bigquery", name="a.b.c"), + Dataset(namespace="bigquery", name="x.y.z"), + ], + outputs=[Dataset(namespace="gs://bucket", 
name="exported_folder")], ) - in client.emit.mock_calls ) mock_stats_incr.assert_not_called() @@ -538,7 +556,7 @@ def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, generate_stat run=Run( runId=random_uuid, facets={ - "nominalTime": NominalTimeRunFacet( + "nominalTime": nominal_time_run.NominalTimeRunFacet( nominalStartTime=event_time.isoformat(), nominalEndTime=event_time.isoformat(), ), @@ -569,14 +587,14 @@ def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, generate_stat namespace=namespace(), name="dag_id", facets={ - "documentation": DocumentationJobFacet(description="dag desc"), - "ownership": OwnershipJobFacet( + "documentation": documentation_job.DocumentationJobFacet(description="dag desc"), + "ownership": ownership_job.OwnershipJobFacet( owners=[ - OwnershipJobFacetOwners(name="airflow", type=None), + ownership_job.Owner(name="airflow", type=None), ] ), **job_facets, - "jobType": JobTypeJobFacet( + "jobType": job_type_job.JobTypeJobFacet( processingType="BATCH", integration="AIRFLOW", jobType="DAG" ), }, @@ -654,7 +672,7 @@ def test_emit_dag_complete_event(mock_stats_incr, mock_stats_timer, generate_sta namespace=namespace(), name=dag_id, facets={ - "jobType": JobTypeJobFacet( + "jobType": job_type_job.JobTypeJobFacet( processingType="BATCH", integration="AIRFLOW", jobType="DAG" ) }, @@ -716,7 +734,7 @@ def test_emit_dag_failed_event(mock_stats_incr, mock_stats_timer, generate_stati run=Run( runId=random_uuid, facets={ - "errorMessage": ErrorMessageRunFacet( + "errorMessage": error_message_run.ErrorMessageRunFacet( message="error msg", programmingLanguage="python" ), "airflowState": AirflowStateRunFacet( @@ -733,7 +751,7 @@ def test_emit_dag_failed_event(mock_stats_incr, mock_stats_timer, generate_stati namespace=namespace(), name=dag_id, facets={ - "jobType": JobTypeJobFacet( + "jobType": job_type_job.JobTypeJobFacet( processingType="BATCH", integration="AIRFLOW", jobType="DAG" ) }, diff --git a/tests/providers/openlineage/test_sqlparser.py b/tests/providers/openlineage/test_sqlparser.py index 020d8384c6bc..edde376fad3a 100644 --- a/tests/providers/openlineage/test_sqlparser.py +++ b/tests/providers/openlineage/test_sqlparser.py @@ -20,14 +20,8 @@ from unittest.mock import MagicMock import pytest -from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional, - ColumnLineageDatasetFacetFieldsAdditionalInputFields, - SchemaDatasetFacet, - SchemaField, -) -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import column_lineage_dataset, schema_dataset, sql_job from openlineage.common.sql import DbTableMeta from airflow.providers.openlineage.sqlparser import DatabaseInfo, GetTableSchemasParams, SQLParser @@ -199,13 +193,13 @@ def rows(name): rows("popular_orders_day_of_week"), ] - expected_schema_facet = SchemaDatasetFacet( + expected_schema_facet = schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="ID", type="int4"), - SchemaField(name="AMOUNT_OFF", type="int4"), - SchemaField(name="CUSTOMER_EMAIL", type="varchar"), - SchemaField(name="STARTS_ON", type="timestamp"), - SchemaField(name="ENDS_ON", type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="ID", type="int4"), + schema_dataset.SchemaDatasetFacetFields(name="AMOUNT_OFF", type="int4"), + schema_dataset.SchemaDatasetFacetFields(name="CUSTOMER_EMAIL", type="varchar"), + schema_dataset.SchemaDatasetFacetFields(name="STARTS_ON", type="timestamp"), + 
schema_dataset.SchemaDatasetFacetFields(name="ENDS_ON", type="timestamp"), ] ) @@ -324,11 +318,11 @@ def test_generate_openlineage_metadata_from_sql(self, mock_parse, parser_returns namespace="myscheme://host:port", name=f"{expected_schema}.top_delivery_times", facets={ - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="order_id", type="int4"), - SchemaField(name="order_placed_on", type="timestamp"), - SchemaField(name="customer_email", type="varchar"), + schema_dataset.SchemaDatasetFacetFields(name="order_id", type="int4"), + schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="customer_email", type="varchar"), ] ) }, @@ -337,18 +331,18 @@ def test_generate_openlineage_metadata_from_sql(self, mock_parse, parser_returns assert len(metadata.outputs) == 1 assert metadata.outputs[0].namespace == "myscheme://host:port" assert metadata.outputs[0].name == f"{expected_schema}.popular_orders_day_of_week" - assert metadata.outputs[0].facets["schema"] == SchemaDatasetFacet( + assert metadata.outputs[0].facets["schema"] == schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="order_day_of_week", type="varchar"), - SchemaField(name="order_placed_on", type="timestamp"), - SchemaField(name="orders_placed", type="int4"), + schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), ] ) - assert metadata.outputs[0].facets["columnLineage"] == ColumnLineageDatasetFacet( + assert metadata.outputs[0].facets["columnLineage"] == column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - "order_day_of_week": ColumnLineageDatasetFacetFieldsAdditional( + "order_day_of_week": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace="myscheme://host:port", name=f"{expected_schema}.top_delivery_times", field="order_placed_on", diff --git a/tests/providers/openlineage/utils/test_sql.py b/tests/providers/openlineage/utils/test_sql.py index f094fdaf1f7c..2e2f36ca74d3 100644 --- a/tests/providers/openlineage/utils/test_sql.py +++ b/tests/providers/openlineage/utils/test_sql.py @@ -19,8 +19,9 @@ from unittest.mock import MagicMock import pytest -from openlineage.client.facet import SchemaDatasetFacet, SchemaField, set_producer -from openlineage.client.run import Dataset +from openlineage.client import set_producer +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import schema_dataset from openlineage.common.sql import DbTableMeta from sqlalchemy import Column, MetaData, Table @@ -38,13 +39,13 @@ DB_SCHEMA_NAME = "PUBLIC" DB_TABLE_NAME = DbTableMeta("DISCOUNTS") -SCHEMA_FACET = SchemaDatasetFacet( +SCHEMA_FACET = schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="ID", type="int4"), - SchemaField(name="AMOUNT_OFF", type="int4"), - SchemaField(name="CUSTOMER_EMAIL", type="varchar"), - SchemaField(name="STARTS_ON", type="timestamp"), - SchemaField(name="ENDS_ON", type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="ID", type="int4"), + schema_dataset.SchemaDatasetFacetFields(name="AMOUNT_OFF", type="int4"), + schema_dataset.SchemaDatasetFacetFields(name="CUSTOMER_EMAIL", type="varchar"), + schema_dataset.SchemaDatasetFacetFields(name="STARTS_ON", 
type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="ENDS_ON", type="timestamp"), ] ) From 36a759ff9c97a72ac724ce670ebd77bafe5cc2e0 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Fri, 10 May 2024 10:03:33 +0200 Subject: [PATCH 2/7] Migrate amazon to v2 facets. Signed-off-by: Jakub Dardzinski --- .../providers/amazon/aws/operators/athena.py | 42 +++--- airflow/providers/amazon/aws/operators/s3.py | 20 ++- .../amazon/aws/operators/sagemaker.py | 4 +- .../amazon/aws/operators/test_athena.py | 43 +++--- .../amazon/aws/operators/test_redshift_sql.py | 134 +++++++++--------- .../providers/amazon/aws/operators/test_s3.py | 26 ++-- .../operators/test_sagemaker_processing.py | 2 +- .../aws/operators/test_sagemaker_training.py | 2 +- .../aws/operators/test_sagemaker_transform.py | 2 +- 9 files changed, 127 insertions(+), 148 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 0178d60a12c9..ea7bc063145e 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -30,8 +30,8 @@ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: - from openlineage.client.facet import BaseFacet - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import BaseFacet from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.utils.context import Context @@ -217,20 +217,17 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: path where the results are saved (user's prefix + some UUID), we are creating a dataset with the user-provided path only. This should make it easier to match this dataset across different processes. 
""" - from openlineage.client.facet import ( - ExternalQueryRunFacet, - ExtractionError, - ExtractionErrorRunFacet, - SqlJobFacet, - ) - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import extraction_error_run, external_query_run, sql_job from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser sql_parser = SQLParser(dialect="generic") - job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=sql_parser.normalize_sql(self.query))} + job_facets: dict[str, BaseFacet] = { + "sql": sql_job.SQLJobFacet(query=sql_parser.normalize_sql(self.query)) + } parse_result = sql_parser.parse(sql=self.query) if not parse_result: @@ -238,11 +235,11 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: run_facets: dict[str, BaseFacet] = {} if parse_result.errors: - run_facets["extractionError"] = ExtractionErrorRunFacet( + run_facets["extractionError"] = extraction_error_run.ExtractionErrorRunFacet( totalTasks=len(self.query) if isinstance(self.query, list) else 1, failedTasks=len(parse_result.errors), errors=[ - ExtractionError( + extraction_error_run.Error( errorMessage=error.message, stackTrace=None, task=error.origin_statement, @@ -273,7 +270,7 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: ) if self.query_execution_id: - run_facets["externalQuery"] = ExternalQueryRunFacet( + run_facets["externalQuery"] = external_query_run.ExternalQueryRunFacet( externalQueryId=self.query_execution_id, source="awsathena" ) @@ -284,13 +281,8 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: return OperatorLineage(job_facets=job_facets, run_facets=run_facets, inputs=inputs, outputs=outputs) def get_openlineage_dataset(self, database, table) -> Dataset | None: - from openlineage.client.facet import ( - SchemaDatasetFacet, - SchemaField, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers, - ) - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import schema_dataset, symlinks_dataset client = self.hook.get_conn() try: @@ -302,9 +294,9 @@ def get_openlineage_dataset(self, database, table) -> Dataset | None: s3_location = table_metadata["TableMetadata"]["Parameters"]["location"] parsed_path = urlparse(s3_location) facets: dict[str, BaseFacet] = { - "symlinks": SymlinksDatasetFacet( + "symlinks": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace=f"{parsed_path.scheme}://{parsed_path.netloc}", name=str(parsed_path.path), type="TABLE", @@ -313,11 +305,13 @@ def get_openlineage_dataset(self, database, table) -> Dataset | None: ) } fields = [ - SchemaField(name=column["Name"], type=column["Type"], description=column.get("Comment")) + schema_dataset.SchemaDatasetFacetFields( + name=column["Name"], type=column["Type"], description=column["Comment"] + ) for column in table_metadata["TableMetadata"]["Columns"] ] if fields: - facets["schema"] = SchemaDatasetFacet(fields=fields) + facets["schema"] = schema_dataset.SchemaDatasetFacet(fields=fields) return Dataset( namespace=f"awsathena://athena.{self.hook.region_name}.amazonaws.com", name=".".join(filter(None, (self.catalog, database, table))), diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py index f2733495efc0..a1464cc32816 100644 --- 
a/airflow/providers/amazon/aws/operators/s3.py +++ b/airflow/providers/amazon/aws/operators/s3.py @@ -324,7 +324,7 @@ def execute(self, context: Context): ) def get_openlineage_facets_on_start(self): - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -439,7 +439,7 @@ def execute(self, context: Context): s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy) def get_openlineage_facets_on_start(self): - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -546,12 +546,8 @@ def execute(self, context: Context): def get_openlineage_facets_on_complete(self, task_instance): """Implement _on_complete because object keys are resolved in execute().""" - from openlineage.client.facet import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - LifecycleStateChangeDatasetFacetPreviousIdentifier, - ) - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import lifecycle_state_change_dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -568,9 +564,9 @@ def get_openlineage_facets_on_complete(self, task_instance): namespace=bucket_url, name=key, facets={ - "lifecycleStateChange": LifecycleStateChangeDatasetFacet( - lifecycleStateChange=LifecycleStateChange.DROP.value, - previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier( + "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( + lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, + previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( namespace=bucket_url, name=key, ), @@ -725,7 +721,7 @@ def execute(self, context: Context): self.log.info("Upload successful") def get_openlineage_facets_on_start(self): - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 63560c66a21c..0ac5ad700996 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -46,7 +46,7 @@ from airflow.utils.json import AirflowJsonEncoder if TYPE_CHECKING: - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.utils.context import Context @@ -208,7 +208,7 @@ def hook(self): @staticmethod def path_to_s3_dataset(path) -> Dataset: - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset path = path.replace("s3://", "") split_path = path.split("/") diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 5d5a6b88c35f..51fef31c68d4 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -20,15 +20,8 @@ from unittest import mock import pytest -from openlineage.client.facet import ( - ExternalQueryRunFacet, - SchemaDatasetFacet, - SchemaField, - SqlJobFacet, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers, -) -from 
openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import external_query_run, schema_dataset, sql_job, symlinks_dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance @@ -312,38 +305,38 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): namespace="awsathena://athena.eu-west-1.amazonaws.com", name="AwsDataCatalog.TEST_DATABASE.DISCOUNTS", facets={ - "symlinks": SymlinksDatasetFacet( + "symlinks": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace="s3://bucket", name="/discount/data/path/", type="TABLE", ) ], ), - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="ID", type="int", description="from deserializer", ), - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="AMOUNT_OFF", type="int", description="from deserializer", ), - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="CUSTOMER_EMAIL", type="varchar", description="from deserializer", ), - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="STARTS_ON", type="timestamp", description="from deserializer", ), - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="ENDS_ON", type="timestamp", description="from deserializer", @@ -358,18 +351,18 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): namespace="awsathena://athena.eu-west-1.amazonaws.com", name="AwsDataCatalog.TEST_DATABASE.TEST_TABLE", facets={ - "symlinks": SymlinksDatasetFacet( + "symlinks": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace="s3://bucket", name="/data/test_table/data/path", type="TABLE", ) ], ), - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField( + schema_dataset.SchemaDatasetFacetFields( name="column", type="string", description="from deserializer", @@ -381,10 +374,14 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): Dataset(namespace="s3://test_s3_bucket", name="/"), ], job_facets={ - "sql": SqlJobFacet( + "sql": sql_job.SQLJobFacet( query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM DISCOUNTS", ) }, - run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="12345", source="awsathena")}, + run_facets={ + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId="12345", source="awsathena" + ) + }, ) assert op.get_openlineage_facets_on_complete(None) == expected_lineage diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py index 003c40e615b7..6df87ac1a443 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_sql.py +++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py @@ -20,15 +20,8 @@ from unittest.mock import MagicMock, PropertyMock, call, patch import pytest -from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional, - ColumnLineageDatasetFacetFieldsAdditionalInputFields, - SchemaDatasetFacet, - SchemaField, - SqlJobFacet, -) -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import column_lineage_dataset, schema_dataset, sql_job from airflow.models.connection 
import Connection from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook as OriginalRedshiftSQLHook @@ -208,66 +201,69 @@ def get_db_hook(self): assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [] expected_namespace = f"redshift://{expected_identity}:5439" - if is_over_210: - assert lineage.inputs == [ - Dataset( - namespace=expected_namespace, - name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week", - facets={ - "schema": SchemaDatasetFacet( - fields=[ - SchemaField(name="order_day_of_week", type="varchar"), - SchemaField(name="order_placed_on", type="timestamp"), - SchemaField(name="orders_placed", type="int4"), - ] - ) - }, - ), - Dataset( - namespace=expected_namespace, - name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table", - facets={ - "schema": SchemaDatasetFacet( - fields=[ - SchemaField(name="order_day_of_week", type="varchar"), - SchemaField(name="additional_constant", type="varchar"), - ] - ) - }, - ), - ] - assert lineage.outputs == [ - Dataset( - namespace=expected_namespace, - name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table", - facets={ - "schema": SchemaDatasetFacet( - fields=[ - SchemaField(name="order_day_of_week", type="varchar"), - SchemaField(name="order_placed_on", type="timestamp"), - SchemaField(name="orders_placed", type="int4"), - SchemaField(name="additional_constant", type="varchar"), - ] - ), - "columnLineage": ColumnLineageDatasetFacet( - fields={ - "additional_constant": ColumnLineageDatasetFacetFieldsAdditional( - inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( - namespace=expected_namespace, - name="database.public.little_table", - field="additional_constant", - ) - ], - transformationDescription="", - transformationType="", - ) - } - ), - }, - ) - ] + assert lineage.inputs == [ + Dataset( + namespace=expected_namespace, + name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week", + facets={ + "schema": schema_dataset.SchemaDatasetFacet( + fields=[ + schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), + ] + ) + }, + ), + Dataset( + namespace=expected_namespace, + name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table", + facets={ + "schema": schema_dataset.SchemaDatasetFacet( + fields=[ + schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + schema_dataset.SchemaDatasetFacetFields( + name="additional_constant", type="varchar" + ), + ] + ) + }, + ), + ] + assert lineage.outputs == [ + Dataset( + namespace=expected_namespace, + name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table", + facets={ + "schema": schema_dataset.SchemaDatasetFacet( + fields=[ + schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), + schema_dataset.SchemaDatasetFacetFields( + name="additional_constant", type="varchar" + ), + ] + ), + "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( + fields={ + "additional_constant": column_lineage_dataset.Fields( + inputFields=[ + column_lineage_dataset.InputField( + namespace=expected_namespace, + name="database.public.little_table", + field="additional_constant", + ) + ], + transformationDescription="", + transformationType="", + ) + } + ), 
+ }, + ) + ] - assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)} + assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} assert lineage.run_facets["extractionError"].failedTasks == 1 diff --git a/tests/providers/amazon/aws/operators/test_s3.py b/tests/providers/amazon/aws/operators/test_s3.py index 5e4bbffbd07d..639b8a87932d 100644 --- a/tests/providers/amazon/aws/operators/test_s3.py +++ b/tests/providers/amazon/aws/operators/test_s3.py @@ -28,12 +28,8 @@ import boto3 import pytest from moto import mock_aws -from openlineage.client.facet import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - LifecycleStateChangeDatasetFacetPreviousIdentifier, -) -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import lifecycle_state_change_dataset from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -771,9 +767,9 @@ def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys) namespace=f"s3://{bucket}", name="path/data.txt", facets={ - "lifecycleStateChange": LifecycleStateChangeDatasetFacet( - lifecycleStateChange=LifecycleStateChange.DROP.value, - previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier( + "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( + lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, + previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( namespace=f"s3://{bucket}", name="path/data.txt", ), @@ -797,9 +793,9 @@ def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook): namespace=f"s3://{bucket}", name="path/data1.txt", facets={ - "lifecycleStateChange": LifecycleStateChangeDatasetFacet( - lifecycleStateChange=LifecycleStateChange.DROP.value, - previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier( + "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( + lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, + previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( namespace=f"s3://{bucket}", name="path/data1.txt", ), @@ -810,9 +806,9 @@ def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook): namespace=f"s3://{bucket}", name="path/data2.txt", facets={ - "lifecycleStateChange": LifecycleStateChangeDatasetFacet( - lifecycleStateChange=LifecycleStateChange.DROP.value, - previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier( + "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( + lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, + previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( namespace=f"s3://{bucket}", name="path/data2.txt", ), diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index 3a9c9c21f1aa..b08a4eff5e6d 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -20,7 +20,7 @@ import pytest from botocore.exceptions import ClientError -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import 
SageMakerHook diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index 9d3ad5aee2ab..9316347de57e 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -21,7 +21,7 @@ import pytest from botocore.exceptions import ClientError -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py index 2634ac72bf36..4c9a2c3a6201 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py @@ -22,7 +22,7 @@ import pytest from botocore.exceptions import ClientError -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook From 37b11aa23bfaf72ff41971b5b72eca37711b6392 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Fri, 10 May 2024 10:21:05 +0200 Subject: [PATCH 3/7] Migrate OL-dependent providers to v2 facets. Signed-off-by: Jakub Dardzinski --- .../providers/amazon/aws/operators/athena.py | 4 +- .../common/io/operators/file_transfer.py | 2 +- airflow/providers/ftp/operators/ftp.py | 2 +- .../google/cloud/openlineage/utils.py | 243 +++++++++++++++++- .../google/cloud/operators/bigquery.py | 2 +- .../providers/google/cloud/operators/gcs.py | 18 +- .../google/cloud/transfers/bigquery_to_gcs.py | 18 +- .../google/cloud/transfers/gcs_to_bigquery.py | 16 +- .../google/cloud/transfers/gcs_to_gcs.py | 2 +- .../providers/openlineage/extractors/base.py | 7 +- airflow/providers/openlineage/sqlparser.py | 3 +- airflow/providers/sftp/operators/sftp.py | 2 +- .../providers/snowflake/hooks/snowflake.py | 4 +- .../transfers/copy_into_snowflake.py | 17 +- .../guides/developer.rst | 12 +- .../common/io/operators/test_file_transfer.py | 2 +- .../common/sql/operators/test_sql_execute.py | 14 +- tests/providers/ftp/operators/test_ftp.py | 2 +- .../google/cloud/openlineage/test_utils.py | 45 ++-- .../google/cloud/operators/test_bigquery.py | 5 +- .../google/cloud/operators/test_gcs.py | 23 +- .../cloud/transfers/test_bigquery_to_gcs.py | 76 +++--- .../cloud/transfers/test_gcs_to_bigquery.py | 120 ++++----- .../google/cloud/transfers/test_gcs_to_gcs.py | 8 +- tests/providers/mysql/operators/test_mysql.py | 14 +- tests/providers/sftp/operators/test_sftp.py | 2 +- .../snowflake/operators/test_snowflake_sql.py | 17 +- .../transfers/test_copy_into_snowflake.py | 27 +- tests/providers/trino/operators/test_trino.py | 20 +- 29 files changed, 467 insertions(+), 260 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index ea7bc063145e..08da5ef0e4df 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import BaseFacet + from openlineage.client.facet_v2 import BaseFacet, DatasetFacet from 
airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.utils.context import Context @@ -293,7 +293,7 @@ def get_openlineage_dataset(self, database, table) -> Dataset | None: # Dataset has also its' physical location which we can add in symlink facet. s3_location = table_metadata["TableMetadata"]["Parameters"]["location"] parsed_path = urlparse(s3_location) - facets: dict[str, BaseFacet] = { + facets: dict[str, DatasetFacet] = { "symlinks": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ symlinks_dataset.Identifier( diff --git a/airflow/providers/common/io/operators/file_transfer.py b/airflow/providers/common/io/operators/file_transfer.py index 9a396c86e490..273984e94136 100644 --- a/airflow/providers/common/io/operators/file_transfer.py +++ b/airflow/providers/common/io/operators/file_transfer.py @@ -75,7 +75,7 @@ def execute(self, context: Context) -> None: src.copy(dst) def get_openlineage_facets_on_start(self) -> OperatorLineage: - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/ftp/operators/ftp.py b/airflow/providers/ftp/operators/ftp.py index 691df6824d34..24f0a3d35130 100644 --- a/airflow/providers/ftp/operators/ftp.py +++ b/airflow/providers/ftp/operators/ftp.py @@ -146,7 +146,7 @@ def get_openlineage_facets_on_start(self): input: file://hostname/path output file://:/path. """ - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/google/cloud/openlineage/utils.py b/airflow/providers/google/cloud/openlineage/utils.py index 06a56ee5ab79..c4c4c72b5b85 100644 --- a/airflow/providers/google/cloud/openlineage/utils.py +++ b/airflow/providers/google/cloud/openlineage/utils.py @@ -17,17 +17,20 @@ # under the License. 
from __future__ import annotations +import copy +import json +import traceback from typing import TYPE_CHECKING, Any from attr import define, field -from openlineage.client.facet import ( +from openlineage.client.facet_v2 import ( BaseFacet, - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional, - ColumnLineageDatasetFacetFieldsAdditionalInputFields, - DocumentationDatasetFacet, - SchemaDatasetFacet, - SchemaField, + column_lineage_dataset, + documentation_dataset, + error_message_run, + external_query_run, + output_statistics_output_dataset, + schema_dataset, ) from airflow.providers.google import __version__ as provider_version @@ -44,13 +47,13 @@ def get_facets_from_bq_table(table: Table) -> dict[Any, Any]: """Get facets from BigQuery table object.""" facets = { - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name=field.name, type=field.field_type, description=field.description) + schema_dataset.SchemaDatasetFacetFields(name=field.name, type=field.field_type, description=field.description) for field in table.schema ] ), - "documentation": DocumentationDatasetFacet(description=table.description or ""), + "documentation": documentation_dataset.DocumentationDatasetFacet(description=table.description or ""), } return facets @@ -59,7 +62,7 @@ def get_facets_from_bq_table(table: Table) -> dict[Any, Any]: def get_identity_column_lineage_facet( field_names: list[str], input_datasets: list[Dataset], -) -> ColumnLineageDatasetFacet: +) -> column_lineage_dataset.ColumnLineageDatasetFacet: """ Get column lineage facet. @@ -69,11 +72,11 @@ def get_identity_column_lineage_facet( if field_names and not input_datasets: raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.") - column_lineage_facet = ColumnLineageDatasetFacet( + column_lineage_facet = column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - field: ColumnLineageDatasetFacetFieldsAdditional( + field: column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=dataset.namespace, name=dataset.name, field=field ) for dataset in input_datasets @@ -175,3 +178,215 @@ def get_from_nullable_chain(source: Any, chain: list[str]) -> Any | None: return source except AttributeError: return None + + +class _BigQueryOpenLineageMixin: + def get_openlineage_facets_on_complete(self, _): + """ + Retrieve OpenLineage data for a COMPLETE BigQuery job. + + This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider. + It calls the BigQuery API, retrieving input and output dataset info from it, as well as run-level + usage statistics.
+ + Run facets should contain: + - ExternalQueryRunFacet + - BigQueryJobRunFacet + + Run facets may contain: + - ErrorMessageRunFacet + + Job facets should contain: + - SQLJobFacet if operator has self.sql + + Input datasets should contain facets: + - DataSourceDatasetFacet + - SchemaDatasetFacet + + Output datasets should contain facets: + - DataSourceDatasetFacet + - SchemaDatasetFacet + - OutputStatisticsOutputDatasetFacet + """ + from openlineage.client.facet_v2 import sql_job + + from airflow.providers.openlineage.extractors import OperatorLineage + from airflow.providers.openlineage.sqlparser import SQLParser + + if not self.job_id: + return OperatorLineage() + + run_facets: dict[str, BaseFacet] = { + "externalQuery": external_query_run.ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery") + } + + job_facets = {"sql": sql_job.SQLJobFacet(query=SQLParser.normalize_sql(self.sql))} + + self.client = self.hook.get_client(project_id=self.hook.project_id) + job_ids = self.job_id + if isinstance(self.job_id, str): + job_ids = [self.job_id] + inputs, outputs = [], [] + for job_id in job_ids: + inner_inputs, inner_outputs, inner_run_facets = self.get_facets(job_id=job_id) + inputs.extend(inner_inputs) + outputs.extend(inner_outputs) + run_facets.update(inner_run_facets) + + return OperatorLineage( + inputs=inputs, + outputs=outputs, + run_facets=run_facets, + job_facets=job_facets, + ) + + def get_facets(self, job_id: str): + inputs = [] + outputs = [] + run_facets: dict[str, BaseFacet] = {} + if hasattr(self, "log"): + self.log.debug("Extracting data from bigquery job: `%s`", job_id) + try: + job = self.client.get_job(job_id=job_id) # type: ignore + props = job._properties + + if get_from_nullable_chain(props, ["status", "state"]) != "DONE": + raise ValueError(f"Trying to extract data from running bigquery job: `{job_id}`") + + # TODO: remove bigQuery_job in next release + run_facets["bigQuery_job"] = run_facets["bigQueryJob"] = self._get_bigquery_job_run_facet(props) + + if get_from_nullable_chain(props, ["statistics", "numChildJobs"]): + if hasattr(self, "log"): + self.log.debug("Found SCRIPT job. Extracting lineage from child jobs instead.") + # SCRIPT job type has no input / output information but spawns child jobs that have one + # https://cloud.google.com/bigquery/docs/information-schema-jobs#multi-statement_query_job + for child_job_id in self.client.list_jobs(parent_job=job_id): + child_job = self.client.get_job(job_id=child_job_id) # type: ignore + child_inputs, child_output = self._get_inputs_outputs_from_job(child_job._properties) + inputs.extend(child_inputs) + outputs.append(child_output) + else: + inputs, _output = self._get_inputs_outputs_from_job(props) + outputs.append(_output) + except Exception as e: + if hasattr(self, "log"): + self.log.warning("Cannot retrieve job details from BigQuery.Client.
%s", e, exc_info=True) + exception_msg = traceback.format_exc() + # TODO: remove BigQueryErrorRunFacet in next release + run_facets.update( + { + "errorMessage": error_message_run.ErrorMessageRunFacet( + message=f"{e}: {exception_msg}", + programmingLanguage="python", + ), + "bigQuery_error": BigQueryErrorRunFacet( + clientError=f"{e}: {exception_msg}", + ), + } + ) + deduplicated_outputs = self._deduplicate_outputs(outputs) + return inputs, deduplicated_outputs, run_facets + + def _deduplicate_outputs(self, outputs: list[Dataset | None]) -> list[Dataset]: + # Sources are the same so we can compare only names + final_outputs = {} + for single_output in outputs: + if not single_output: + continue + key = single_output.name + if key not in final_outputs: + final_outputs[key] = single_output + continue + + # No OutputStatisticsOutputDatasetFacet is added to duplicated outputs as we can not determine + # if the rowCount or size can be summed together. + single_output.facets.pop("outputStatistics", None) + final_outputs[key] = single_output + + return list(final_outputs.values()) + + def _get_inputs_outputs_from_job(self, properties: dict) -> tuple[list[Dataset], Dataset | None]: + input_tables = get_from_nullable_chain(properties, ["statistics", "query", "referencedTables"]) or [] + output_table = get_from_nullable_chain(properties, ["configuration", "query", "destinationTable"]) + inputs = [self._get_dataset(input_table) for input_table in input_tables] + if output_table: + output = self._get_dataset(output_table) + dataset_stat_facet = self._get_statistics_dataset_facet(properties) + if dataset_stat_facet: + output.facets.update({"outputStatistics": dataset_stat_facet}) + + return inputs, output + + @staticmethod + def _get_bigquery_job_run_facet(properties: dict) -> BigQueryJobRunFacet: + if get_from_nullable_chain(properties, ["configuration", "query", "query"]): + # Exclude the query to avoid event size issues and duplicating SqlJobFacet information. 
+ properties = copy.deepcopy(properties) + properties["configuration"]["query"].pop("query") + cache_hit = get_from_nullable_chain(properties, ["statistics", "query", "cacheHit"]) + billed_bytes = get_from_nullable_chain(properties, ["statistics", "query", "totalBytesBilled"]) + return BigQueryJobRunFacet( + cached=str(cache_hit).lower() == "true", + billedBytes=int(billed_bytes) if billed_bytes else None, + properties=json.dumps(properties), + ) + + @staticmethod + def _get_statistics_dataset_facet(properties) -> output_statistics_output_dataset.OutputStatisticsOutputDatasetFacet | None: + query_plan = get_from_nullable_chain(properties, chain=["statistics", "query", "queryPlan"]) + if not query_plan: + return None + + out_stage = query_plan[-1] + out_rows = out_stage.get("recordsWritten", None) + out_bytes = out_stage.get("shuffleOutputBytes", None) + if out_bytes and out_rows: + return output_statistics_output_dataset.OutputStatisticsOutputDatasetFacet(rowCount=int(out_rows), size=int(out_bytes)) + return None + + def _get_dataset(self, table: dict) -> Dataset: + project = table.get("projectId") + dataset = table.get("datasetId") + table_name = table.get("tableId") + dataset_name = f"{project}.{dataset}.{table_name}" + + dataset_schema = self._get_table_schema_safely(dataset_name) + return Dataset( + namespace=BIGQUERY_NAMESPACE, + name=dataset_name, + facets={ + "schema": dataset_schema, + } + if dataset_schema + else {}, + ) + + def _get_table_schema_safely(self, table_name: str) -> schema_dataset.SchemaDatasetFacet | None: + try: + return self._get_table_schema(table_name) + except Exception as e: + if hasattr(self, "log"): + self.log.warning("Could not extract output schema from bigquery. %s", e) + return None + + def _get_table_schema(self, table: str) -> schema_dataset.SchemaDatasetFacet | None: + bq_table = self.client.get_table(table) + + if not bq_table._properties: + return None + + fields = get_from_nullable_chain(bq_table._properties, ["schema", "fields"]) + if not fields: + return None + + return schema_dataset.SchemaDatasetFacet( + fields=[ + schema_dataset.SchemaDatasetFacetFields( + name=field.get("name"), + type=field.get("type"), + description=field.get("description"), + ) + for field in fields + ] + ) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index d55651d06b43..2e5a38d90a77 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -3066,4 +3066,4 @@ def on_kill(self) -> None: job_id=self.job_id, project_id=self.project_id, location=self.location ) else: - self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id) + self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id) \ No newline at end of file diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index 2871cb6e7aa3..8027e3d27dd6 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -336,12 +336,8 @@ def execute(self, context: Context) -> None: hook.delete(bucket_name=self.bucket_name, object_name=object_name) def get_openlineage_facets_on_start(self): - from openlineage.client.facet import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - LifecycleStateChangeDatasetFacetPreviousIdentifier, - ) - from openlineage.client.run import Dataset + from 
openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import lifecycle_state_change_dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -363,9 +359,9 @@ def get_openlineage_facets_on_start(self): namespace=bucket_url, name=object_name, facets={ - "lifecycleStateChange": LifecycleStateChangeDatasetFacet( - lifecycleStateChange=LifecycleStateChange.DROP.value, - previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier( + "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( + lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, + previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( namespace=bucket_url, name=object_name, ), @@ -645,7 +641,7 @@ def execute(self, context: Context) -> None: ) def get_openlineage_facets_on_start(self): - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -921,7 +917,7 @@ def execute(self, context: Context) -> list[str]: def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as execute() resolves object prefixes.""" - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 3ca33e384986..800ed952b193 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -289,12 +289,8 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will include final BQ job id.""" from pathlib import Path - from openlineage.client.facet import ( - ExternalQueryRunFacet, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers, - ) - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import external_query_run, symlinks_dataset from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url from airflow.providers.google.cloud.openlineage.utils import ( @@ -334,11 +330,9 @@ def get_openlineage_facets_on_complete(self, task_instance): # If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name, # but we create a symlink to the full object path with wildcard. 
additional_facets = { - "symlink": SymlinksDatasetFacet( + "symlink": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( - namespace=f"gs://{bucket}", name=blob, type="file" - ) + symlinks_dataset.Identifier(namespace=f"gs://{bucket}", name=blob, type="file") ] ), } @@ -357,7 +351,9 @@ def get_openlineage_facets_on_complete(self, task_instance): run_facets = {} if self.job_id: run_facets = { - "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery"), + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId=self.job_id, source="bigquery" + ), } return OperatorLineage(inputs=[input_dataset], outputs=output_datasets, run_facets=run_facets) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 03741b7adc45..22ecf93d4d01 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -746,12 +746,8 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will include final BQ job id.""" from pathlib import Path - from openlineage.client.facet import ( - ExternalQueryRunFacet, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers, - ) - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import external_query_run, symlinks_dataset from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, @@ -785,9 +781,9 @@ def get_openlineage_facets_on_complete(self, task_instance): # If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name, # but we create a symlink to the full object path with wildcard. 
additional_facets = { - "symlink": SymlinksDatasetFacet( + "symlink": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace=f"gs://{self.bucket}", name=blob, type="file" ) ] @@ -818,7 +814,9 @@ def get_openlineage_facets_on_complete(self, task_instance): run_facets = {} if self.job_id: run_facets = { - "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery"), + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId=self.job_id, source="bigquery" + ), } return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index 0b3d330b65f9..b9730e70c86e 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -552,7 +552,7 @@ def get_openlineage_facets_on_complete(self, task_instance): """ from pathlib import Path - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py index 05aa659d7558..1fe60bd77373 100644 --- a/airflow/providers/openlineage/extractors/base.py +++ b/airflow/providers/openlineage/extractors/base.py @@ -17,17 +17,22 @@ from __future__ import annotations +import warnings from abc import ABC, abstractmethod from typing import Generic, TypeVar, Union from attrs import Factory, define from openlineage.client.event_v2 import Dataset as OLDataset -from openlineage.client.facet import BaseFacet as BaseFacet_V1 + +with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from openlineage.client.facet import BaseFacet as BaseFacet_V1 from openlineage.client.facet_v2 import JobFacet, RunFacet from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState +# this is not to break static checks compatibility with v1 OpenLineage facet classes DatasetSubclass = TypeVar("DatasetSubclass", bound=OLDataset) BaseFacetSubclass = TypeVar("BaseFacetSubclass", bound=Union[BaseFacet_V1, RunFacet, JobFacet]) diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py index 900bbf277269..323ed8a11b8f 100644 --- a/airflow/providers/openlineage/sqlparser.py +++ b/airflow/providers/openlineage/sqlparser.py @@ -199,8 +199,7 @@ def attach_column_lineage( if not len(parse_result.column_lineage): return for dataset in datasets: - if not dataset.facets: - continue + dataset.facets = dataset.facets or {} dataset.facets["columnLineage"] = column_lineage_dataset.ColumnLineageDatasetFacet( fields={ column_lineage.descendant.name: column_lineage_dataset.Fields( diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index 13a12979040e..68fc87350b4d 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -201,7 +201,7 @@ def get_openlineage_facets_on_start(self): input: file:///path output: file://:/path. 
""" - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index ff34461cafab..844cbaf9e298 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -480,7 +480,7 @@ def _get_openlineage_authority(self, _) -> str | None: return urlparse(uri).hostname def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None: - from openlineage.client.facet import ExternalQueryRunFacet + from openlineage.client.facet_v2 import external_query_run from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser @@ -491,7 +491,7 @@ def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection)) return OperatorLineage( run_facets={ - "externalQuery": ExternalQueryRunFacet( + "externalQuery": external_query_run.ExternalQueryRunFacet( externalQueryId=self.query_ids[0], source=namespace ) } diff --git a/airflow/providers/snowflake/transfers/copy_into_snowflake.py b/airflow/providers/snowflake/transfers/copy_into_snowflake.py index 3606eba12c49..8624a22aae1e 100644 --- a/airflow/providers/snowflake/transfers/copy_into_snowflake.py +++ b/airflow/providers/snowflake/transfers/copy_into_snowflake.py @@ -228,13 +228,8 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement _on_complete because we rely on return value of a query.""" import re - from openlineage.client.facet import ( - ExternalQueryRunFacet, - ExtractionError, - ExtractionErrorRunFacet, - SqlJobFacet, - ) - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset + from openlineage.client.facet_v2 import external_query_run, extraction_error_run, sql_job from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser @@ -257,11 +252,11 @@ def get_openlineage_facets_on_complete(self, task_instance): "Unable to extract Dataset namespace and name for the following files: `%s`.", extraction_error_files, ) - run_facets["extractionError"] = ExtractionErrorRunFacet( + run_facets["extractionError"] = extraction_error_run.ExtractionErrorRunFacet( totalTasks=len(query_results), failedTasks=len(extraction_error_files), errors=[ - ExtractionError( + extraction_error_run.Error( errorMessage="Unable to extract Dataset namespace and name.", stackTrace=None, task=file_uri, @@ -286,13 +281,13 @@ def get_openlineage_facets_on_complete(self, task_instance): query = SQLParser.normalize_sql(self._sql) query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query)) - run_facets["externalQuery"] = ExternalQueryRunFacet( + run_facets["externalQuery"] = external_query_run.ExternalQueryRunFacet( externalQueryId=self.hook.query_ids[0], source=snowflake_namespace ) return OperatorLineage( inputs=input_datasets, outputs=[Dataset(namespace=snowflake_namespace, name=dest_name)], - job_facets={"sql": SqlJobFacet(query=query)}, + job_facets={"sql": sql_job.SQLJobFacet(query=query)}, run_facets=run_facets, ) diff --git a/docs/apache-airflow-providers-openlineage/guides/developer.rst b/docs/apache-airflow-providers-openlineage/guides/developer.rst index 86f57ac3e113..5d69a1e0bac7 100644 --- 
a/docs/apache-airflow-providers-openlineage/guides/developer.rst +++ b/docs/apache-airflow-providers-openlineage/guides/developer.rst @@ -152,7 +152,7 @@ As there is some processing made in ``execute`` method, and there is no relevant This means we won't have to normalize self.source_object and self.source_objects, destination bucket and so on. """ - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset from airflow.providers.openlineage.extractors import OperatorLineage return OperatorLineage( @@ -303,8 +303,8 @@ like extracting column level lineage and inputs/outputs from SQL query with SQL .. code-block:: python - from openlineage.client.facet import BaseFacet, ExternalQueryRunFacet, SqlJobFacet - from openlineage.client.run import Dataset + from openlineage.client.facet_v2 import BaseFacet, external_query_run, sql_job + from openlineage.client.event_v2 import Dataset from airflow.models.baseoperator import BaseOperator from airflow.providers.openlineage.extractors.base import BaseExtractor @@ -333,7 +333,7 @@ like extracting column level lineage and inputs/outputs from SQL query with SQL inputs=[Dataset(namespace="bigquery", name=self.bq_table_reference)], outputs=[Dataset(namespace=self.s3_path, name=self.s3_file_name)], job_facets={ - "sql": SqlJobFacet( + "sql": sql_job.SQLJobFacet( query="EXPORT INTO ... OPTIONS(FORMAT=csv, SEP=';' ...) AS SELECT * FROM ... " ) }, @@ -343,7 +343,9 @@ like extracting column level lineage and inputs/outputs from SQL query with SQL """Add what we received after Operator's extract call.""" lineage_metadata = self.extract() lineage_metadata.run_facets = { - "parent": ExternalQueryRunFacet(externalQueryId=self._job_id, source="bigquery") + "parent": external_query_run.ExternalQueryRunFacet( + externalQueryId=self._job_id, source="bigquery" + ) } return lineage_metadata diff --git a/tests/providers/common/io/operators/test_file_transfer.py b/tests/providers/common/io/operators/test_file_transfer.py index 3f50d379eb1f..b8b9f12d3cdb 100644 --- a/tests/providers/common/io/operators/test_file_transfer.py +++ b/tests/providers/common/io/operators/test_file_transfer.py @@ -19,7 +19,7 @@ from unittest import mock -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset from tests.test_utils.compat import ignore_provider_compatibility_error diff --git a/tests/providers/common/sql/operators/test_sql_execute.py b/tests/providers/common/sql/operators/test_sql_execute.py index 0ba52abba97a..c76140065ae1 100644 --- a/tests/providers/common/sql/operators/test_sql_execute.py +++ b/tests/providers/common/sql/operators/test_sql_execute.py @@ -22,8 +22,8 @@ from unittest.mock import MagicMock import pytest -from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import schema_dataset, sql_job from airflow.models import Connection from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler @@ -338,18 +338,18 @@ def get_db_hook(self): namespace=f"sqlscheme://host:{expected_port}", name="PUBLIC.popular_orders_day_of_week", facets={ - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="order_day_of_week", type="varchar"), - SchemaField(name="order_placed_on", type="timestamp"), - SchemaField(name="orders_placed", type="int4"), + 
schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), ] ) }, ) ] - assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)} + assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} assert lineage.run_facets["extractionError"].failedTasks == 1 diff --git a/tests/providers/ftp/operators/test_ftp.py b/tests/providers/ftp/operators/test_ftp.py index 3e2930743ebc..12e1d4363e9a 100644 --- a/tests/providers/ftp/operators/test_ftp.py +++ b/tests/providers/ftp/operators/test_ftp.py @@ -21,7 +21,7 @@ from unittest import mock import pytest -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset from airflow.models import DAG, Connection from airflow.providers.ftp.operators.ftp import ( diff --git a/tests/providers/google/cloud/openlineage/test_utils.py b/tests/providers/google/cloud/openlineage/test_utils.py index 19c4b1d0bee0..bfa63943a145 100644 --- a/tests/providers/google/cloud/openlineage/test_utils.py +++ b/tests/providers/google/cloud/openlineage/test_utils.py @@ -21,15 +21,8 @@ import pytest from google.cloud.bigquery.table import Table -from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional, - ColumnLineageDatasetFacetFieldsAdditionalInputFields, - DocumentationDatasetFacet, - SchemaDatasetFacet, - SchemaField, -) -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import column_lineage_dataset, documentation_dataset, schema_dataset from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, @@ -76,13 +69,15 @@ def _properties(self): def test_get_facets_from_bq_table(): expected_facets = { - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="field1", type="STRING", description="field1 description"), - SchemaField(name="field2", type="INTEGER"), + schema_dataset.SchemaDatasetFacetFields( + name="field1", type="STRING", description="field1 description" + ), + schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ), - "documentation": DocumentationDatasetFacet(description="Table description."), + "documentation": documentation_dataset.DocumentationDatasetFacet(description="Table description."), } result = get_facets_from_bq_table(TEST_TABLE) assert result == expected_facets @@ -90,8 +85,8 @@ def test_get_facets_from_bq_table(): def test_get_facets_from_empty_bq_table(): expected_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), + "schema": schema_dataset.SchemaDatasetFacet(fields=[]), + "documentation": documentation_dataset.DocumentationDatasetFacet(description=""), } result = get_facets_from_bq_table(TEST_EMPTY_TABLE) assert result == expected_facets @@ -103,16 +98,16 @@ def test_get_identity_column_lineage_facet_multiple_input_datasets(): Dataset(namespace="gs://first_bucket", name="dir1"), Dataset(namespace="gs://second_bucket", name="dir2"), ] - expected_facet = ColumnLineageDatasetFacet( + expected_facet = column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - "field1": ColumnLineageDatasetFacetFieldsAdditional( + "field1": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + 
column_lineage_dataset.InputField( namespace="gs://first_bucket", name="dir1", field="field1", ), - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace="gs://second_bucket", name="dir2", field="field1", @@ -121,14 +116,14 @@ def test_get_identity_column_lineage_facet_multiple_input_datasets(): transformationType="IDENTITY", transformationDescription="identical", ), - "field2": ColumnLineageDatasetFacetFieldsAdditional( + "field2": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace="gs://first_bucket", name="dir1", field="field2", ), - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace="gs://second_bucket", name="dir2", field="field2", @@ -149,8 +144,10 @@ def test_get_identity_column_lineage_facet_no_field_names(): Dataset(namespace="gs://first_bucket", name="dir1"), Dataset(namespace="gs://second_bucket", name="dir2"), ] - expected_facet = ColumnLineageDatasetFacet(fields={}) - result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) + expected_facet = column_lineage_dataset.ColumnLineageDatasetFacet(fields={}) + result = get_identity_column_lineage_facet( + field_names=field_names, input_datasets=input_datasets + ) assert result == expected_facet diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 346c50382d93..342c138f9bd2 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -26,7 +26,8 @@ import pytest from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter from google.cloud.exceptions import Conflict -from openlineage.client.facet import ErrorMessageRunFacet, ExternalQueryRunFacet, SqlJobFacet +from openlineage.client.facet import DataSourceDatasetFacet, ExternalQueryRunFacet +from openlineage.client.facet_v2 import sql_job from openlineage.client.run import Dataset from airflow.exceptions import ( @@ -1852,7 +1853,7 @@ def test_execute_openlineage_events(self, mock_hook): "bigQueryJob": mock.ANY, "externalQuery": ExternalQueryRunFacet(externalQueryId=mock.ANY, source="bigquery"), } - assert lineage.job_facets == {"sql": SqlJobFacet(query="SELECT * FROM test_table")} + assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query="SELECT * FROM test_table")} @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_execute_fails_openlineage_events(self, mock_hook): diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index 0024ad6407ab..6299b043598f 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -22,12 +22,8 @@ from unittest import mock import pytest -from openlineage.client.facet import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - LifecycleStateChangeDatasetFacetPreviousIdentifier, -) -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import lifecycle_state_change_dataset from airflow.providers.google.cloud.operators.gcs import ( GCSBucketCreateAclEntryOperator, @@ -206,9 +202,9 @@ def test_get_openlineage_facets_on_start(self, objects, prefix, inputs): namespace=bucket_url, name=name, facets={ - 
"lifecycleStateChange": LifecycleStateChangeDatasetFacet( - lifecycleStateChange=LifecycleStateChange.DROP.value, - previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier( + "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( + lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, + previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( namespace=bucket_url, name=name, ), @@ -224,7 +220,8 @@ def test_get_openlineage_facets_on_start(self, objects, prefix, inputs): lineage = operator.get_openlineage_facets_on_start() assert len(lineage.inputs) == len(inputs) assert len(lineage.outputs) == 0 - assert sorted(lineage.inputs) == sorted(expected_inputs) + assert all(element in lineage.inputs for element in expected_inputs) + assert all(element in expected_inputs for element in lineage.inputs) class TestGoogleCloudStorageListOperator: @@ -619,8 +616,10 @@ def test_get_openlineage_facets_on_complete( lineage = op.get_openlineage_facets_on_complete(None) assert len(lineage.inputs) == len(inputs) assert len(lineage.outputs) == len(outputs) - assert sorted(lineage.inputs) == sorted(inputs) - assert sorted(lineage.outputs) == sorted(outputs) + assert all(element in lineage.inputs for element in inputs) + assert all(element in inputs for element in lineage.inputs) + assert all(element in lineage.outputs for element in outputs) + assert all(element in outputs for element in lineage.outputs) class TestGCSDeleteBucketOperator: diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py index 14ed8e40b5fb..8e77d505c3ae 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py @@ -23,18 +23,14 @@ import pytest from google.cloud.bigquery.retry import DEFAULT_RETRY from google.cloud.bigquery.table import Table -from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional, - ColumnLineageDatasetFacetFieldsAdditionalInputFields, - DocumentationDatasetFacet, - ExternalQueryRunFacet, - SchemaDatasetFacet, - SchemaField, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers, +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import ( + column_lineage_dataset, + documentation_dataset, + external_query_run, + schema_dataset, + symlinks_dataset, ) -from openlineage.client.run import Dataset from airflow.exceptions import TaskDeferred from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator @@ -267,13 +263,17 @@ def test_get_openlineage_facets_on_complete_bq_dataset(self, mock_hook): source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" expected_input_dataset_facets = { - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="field1", type="STRING", description="field1 description"), - SchemaField(name="field2", type="INTEGER"), + schema_dataset.SchemaDatasetFacetFields( + name="field1", type="STRING", description="field1 description" + ), + schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ), - "documentation": DocumentationDatasetFacet(description="Table description."), + "documentation": documentation_dataset.DocumentationDatasetFacet( + description="Table description." 
+ ), } mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) @@ -300,8 +300,8 @@ def test_get_openlineage_facets_on_complete_bq_dataset_empty_table(self, mock_ho source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" expected_input_dataset_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), + "schema": schema_dataset.SchemaDatasetFacet(fields=[]), + "documentation": documentation_dataset.DocumentationDatasetFacet(description=""), } mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) @@ -331,13 +331,13 @@ def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mo bq_namespace = "bigquery" expected_input_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), + "schema": schema_dataset.SchemaDatasetFacet(fields=[]), + "documentation": documentation_dataset.DocumentationDatasetFacet(description=""), } expected_output_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "columnLineage": ColumnLineageDatasetFacet(fields={}), + "schema": schema_dataset.SchemaDatasetFacet(fields=[]), + "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet(fields={}), } mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) @@ -365,7 +365,9 @@ def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mo facets=expected_output_facets, ) assert lineage.run_facets == { - "externalQuery": ExternalQueryRunFacet(externalQueryId=real_job_id, source=bq_namespace) + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId=real_job_id, source=bq_namespace + ) } assert lineage.job_facets == {} @@ -376,33 +378,37 @@ def test_get_openlineage_facets_on_complete_gcs_wildcard_full_table(self, mock_h real_job_id = "123456_hash" bq_namespace = "bigquery" - schema_facet = SchemaDatasetFacet( + schema_facet = schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="field1", type="STRING", description="field1 description"), - SchemaField(name="field2", type="INTEGER"), + schema_dataset.SchemaDatasetFacetFields( + name="field1", type="STRING", description="field1 description" + ), + schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ) expected_input_facets = { "schema": schema_facet, - "documentation": DocumentationDatasetFacet(description="Table description."), + "documentation": documentation_dataset.DocumentationDatasetFacet( + description="Table description." 
+ ), } expected_output_facets = { "schema": schema_facet, - "columnLineage": ColumnLineageDatasetFacet( + "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - "field1": ColumnLineageDatasetFacetFieldsAdditional( + "field1": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=bq_namespace, name=source_project_dataset_table, field="field1" ) ], transformationType="IDENTITY", transformationDescription="identical", ), - "field2": ColumnLineageDatasetFacetFieldsAdditional( + "field2": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=bq_namespace, name=source_project_dataset_table, field="field2" ) ], @@ -411,9 +417,9 @@ def test_get_openlineage_facets_on_complete_gcs_wildcard_full_table(self, mock_h ), } ), - "symlink": SymlinksDatasetFacet( + "symlink": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace=f"gs://{TEST_BUCKET}", name=f"{TEST_FOLDER}/{TEST_OBJECT_WILDCARD}", type="file", @@ -445,6 +451,8 @@ def test_get_openlineage_facets_on_complete_gcs_wildcard_full_table(self, mock_h namespace=f"gs://{TEST_BUCKET}", name=TEST_FOLDER, facets=expected_output_facets ) assert lineage.run_facets == { - "externalQuery": ExternalQueryRunFacet(externalQueryId=real_job_id, source=bq_namespace) + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId=real_job_id, source=bq_namespace + ) } assert lineage.job_facets == {} diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py index 3d465744f2c8..99ad21dbb1ee 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py @@ -24,18 +24,14 @@ import pytest from google.cloud.bigquery import DEFAULT_RETRY, Table from google.cloud.exceptions import Conflict -from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional, - ColumnLineageDatasetFacetFieldsAdditionalInputFields, - DocumentationDatasetFacet, - ExternalQueryRunFacet, - SchemaDatasetFacet, - SchemaField, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers, +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import ( + column_lineage_dataset, + documentation_dataset, + external_query_run, + schema_dataset, + symlinks_dataset, ) -from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG @@ -1257,9 +1253,9 @@ def test_get_openlineage_facets_on_complete_gcs_dataset_name( destination_project_dataset_table=TEST_EXPLICIT_DEST, ) - expected_symlink = SymlinksDatasetFacet( + expected_symlink = symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace=f"gs://{TEST_BUCKET}", name=source_object, type="file", @@ -1298,9 +1294,9 @@ def test_get_openlineage_facets_on_complete_gcs_multiple_uris(self, hook): assert len(lineage.inputs) == 4 assert lineage.inputs[0].name == TEST_OBJECT_NO_WILDCARD assert lineage.inputs[1].name == "/" - assert lineage.inputs[1].facets.get("symlink") == SymlinksDatasetFacet( + assert lineage.inputs[1].facets.get("symlink") == symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - 
SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_WILDCARD, type="file", @@ -1309,9 +1305,9 @@ def test_get_openlineage_facets_on_complete_gcs_multiple_uris(self, hook): ) assert lineage.inputs[2].name == f"{TEST_FOLDER}1/{TEST_OBJECT_NO_WILDCARD}" assert lineage.inputs[3].name == f"{TEST_FOLDER}2" - assert lineage.inputs[3].facets.get("symlink") == SymlinksDatasetFacet( + assert lineage.inputs[3].facets.get("symlink") == symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace=f"gs://{TEST_BUCKET}", name=f"{TEST_FOLDER}2/{TEST_OBJECT_WILDCARD}", type="file", @@ -1326,27 +1322,29 @@ def test_get_openlineage_facets_on_complete_bq_dataset(self, hook): hook.return_value.get_client.return_value.get_table.return_value = TEST_TABLE expected_output_dataset_facets = { - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="field1", type="STRING", description="field1 description"), - SchemaField(name="field2", type="INTEGER"), + schema_dataset.SchemaDatasetFacetFields( + name="field1", type="STRING", description="field1 description" + ), + schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ), - "documentation": DocumentationDatasetFacet(description="Test Description"), - "columnLineage": ColumnLineageDatasetFacet( + "documentation": documentation_dataset.DocumentationDatasetFacet(description="Test Description"), + "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - "field1": ColumnLineageDatasetFacetFieldsAdditional( + "field1": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field1" ) ], transformationType="IDENTITY", transformationDescription="identical", ), - "field2": ColumnLineageDatasetFacetFieldsAdditional( + "field2": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field2" ) ], @@ -1382,33 +1380,35 @@ def test_get_openlineage_facets_on_complete_bq_dataset_multiple_gcs_uris(self, h hook.return_value.get_client.return_value.get_table.return_value = TEST_TABLE expected_output_dataset_facets = { - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="field1", type="STRING", description="field1 description"), - SchemaField(name="field2", type="INTEGER"), + schema_dataset.SchemaDatasetFacetFields( + name="field1", type="STRING", description="field1 description" + ), + schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ), - "documentation": DocumentationDatasetFacet(description="Test Description"), - "columnLineage": ColumnLineageDatasetFacet( + "documentation": documentation_dataset.DocumentationDatasetFacet(description="Test Description"), + "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - "field1": ColumnLineageDatasetFacetFieldsAdditional( + "field1": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field1" ), - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + 
column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name="/", field="field1" ), ], transformationType="IDENTITY", transformationDescription="identical", ), - "field2": ColumnLineageDatasetFacetFieldsAdditional( + "field2": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field2" ), - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name="/", field="field2" ), ], @@ -1444,9 +1444,9 @@ def test_get_openlineage_facets_on_complete_empty_table(self, hook): hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE expected_output_dataset_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - "columnLineage": ColumnLineageDatasetFacet(fields={}), + "schema": schema_dataset.SchemaDatasetFacet(fields=[]), + "documentation": documentation_dataset.DocumentationDatasetFacet(description=""), + "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet(fields={}), } operator = GCSToBigQueryOperator( @@ -1470,16 +1470,16 @@ def test_get_openlineage_facets_on_complete_empty_table(self, hook): assert lineage.inputs[0] == Dataset( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, - facets={"schema": SchemaDatasetFacet(fields=[])}, + facets={"schema": schema_dataset.SchemaDatasetFacet(fields=[])}, ) assert lineage.inputs[1] == Dataset( namespace=f"gs://{TEST_BUCKET}", name="/", facets={ - "schema": SchemaDatasetFacet(fields=[]), - "symlink": SymlinksDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet(fields=[]), + "symlink": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_WILDCARD, type="file", @@ -1496,18 +1496,20 @@ def test_get_openlineage_facets_on_complete_full_table_multiple_gcs_uris(self, h hook.return_value.get_client.return_value.get_table.return_value = TEST_TABLE hook.return_value.generate_job_id.return_value = REAL_JOB_ID - schema_facet = SchemaDatasetFacet( + schema_facet = schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="field1", type="STRING", description="field1 description"), - SchemaField(name="field2", type="INTEGER"), + schema_dataset.SchemaDatasetFacetFields( + name="field1", type="STRING", description="field1 description" + ), + schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ) expected_input_wildcard_dataset_facets = { "schema": schema_facet, - "symlink": SymlinksDatasetFacet( + "symlink": symlinks_dataset.SymlinksDatasetFacet( identifiers=[ - SymlinksDatasetFacetIdentifiers( + symlinks_dataset.Identifier( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_WILDCARD, type="file", @@ -1519,27 +1521,27 @@ def test_get_openlineage_facets_on_complete_full_table_multiple_gcs_uris(self, h expected_output_dataset_facets = { "schema": schema_facet, - "documentation": DocumentationDatasetFacet(description="Test Description"), - "columnLineage": ColumnLineageDatasetFacet( + "documentation": documentation_dataset.DocumentationDatasetFacet(description="Test Description"), + "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - "field1": ColumnLineageDatasetFacetFieldsAdditional( + "field1": column_lineage_dataset.Fields( inputFields=[ - 
ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field1" ), - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name="/", field="field1" ), ], transformationType="IDENTITY", transformationDescription="identical", ), - "field2": ColumnLineageDatasetFacetFieldsAdditional( + "field2": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field2" ), - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace=f"gs://{TEST_BUCKET}", name="/", field="field2" ), ], @@ -1576,7 +1578,9 @@ def test_get_openlineage_facets_on_complete_full_table_multiple_gcs_uris(self, h namespace=f"gs://{TEST_BUCKET}", name="/", facets=expected_input_wildcard_dataset_facets ) assert lineage.run_facets == { - "externalQuery": ExternalQueryRunFacet(externalQueryId=REAL_JOB_ID, source="bigquery") + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId=REAL_JOB_ID, source="bigquery" + ) } assert lineage.job_facets == {} diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py index 59c0b09eb076..f8fda2eb0a49 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py @@ -21,7 +21,7 @@ from unittest import mock import pytest -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.transfers.gcs_to_gcs import WILDCARD, GCSToGCSOperator @@ -989,5 +989,7 @@ def test_get_openlineage_facets_on_complete( lineage = operator.get_openlineage_facets_on_complete(None) assert len(lineage.inputs) == len(inputs) assert len(lineage.outputs) == len(outputs) - assert sorted(lineage.inputs) == sorted(inputs) - assert sorted(lineage.outputs) == sorted(outputs) + assert all(element in lineage.inputs for element in inputs) + assert all(element in inputs for element in lineage.inputs) + assert all(element in lineage.outputs for element in outputs) + assert all(element in outputs for element in lineage.outputs) diff --git a/tests/providers/mysql/operators/test_mysql.py b/tests/providers/mysql/operators/test_mysql.py index 91e909700fea..a3715decfb2e 100644 --- a/tests/providers/mysql/operators/test_mysql.py +++ b/tests/providers/mysql/operators/test_mysql.py @@ -22,8 +22,8 @@ from unittest.mock import MagicMock import pytest -from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import schema_dataset, sql_job from airflow.models.connection import Connection from airflow.models.dag import DAG @@ -167,17 +167,17 @@ class MySqlHookForTests(MySqlHook): namespace=f"mysql://host:{connection_port or 3306}", name="PUBLIC.popular_orders_day_of_week", facets={ - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="order_day_of_week", type="varchar"), - SchemaField(name="order_placed_on", type="timestamp"), - SchemaField(name="orders_placed", type="int4"), 
+ schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), ] ) }, ) ] - assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)} + assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} assert lineage.run_facets["extractionError"].failedTasks == 1 diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index ca562515f2dc..f8c2c1d21e56 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -25,7 +25,7 @@ import paramiko import pytest -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import DAG, Connection diff --git a/tests/providers/snowflake/operators/test_snowflake_sql.py b/tests/providers/snowflake/operators/test_snowflake_sql.py index e3955d454047..3a3562002c57 100644 --- a/tests/providers/snowflake/operators/test_snowflake_sql.py +++ b/tests/providers/snowflake/operators/test_snowflake_sql.py @@ -36,13 +36,8 @@ def Row(*args, **kwargs): return MagicMock() -from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional, - ColumnLineageDatasetFacetFieldsAdditionalInputFields, - SqlJobFacet, -) -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import column_lineage_dataset, sql_job from airflow.models.connection import Connection from airflow.providers.common.sql.hooks.sql import fetch_all_handler @@ -252,11 +247,11 @@ def get_db_hook(self): namespace="snowflake://test_account.us-east.aws", name=f"{DB_NAME}.{DB_SCHEMA_NAME}.TEST_TABLE", facets={ - "columnLineage": ColumnLineageDatasetFacet( + "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( fields={ - "additional_constant": ColumnLineageDatasetFacetFieldsAdditional( + "additional_constant": column_lineage_dataset.Fields( inputFields=[ - ColumnLineageDatasetFacetFieldsAdditionalInputFields( + column_lineage_dataset.InputField( namespace="snowflake://test_account.us-east.aws", name="DATABASE.PUBLIC.little_table", field="additional_constant", @@ -271,6 +266,6 @@ def get_db_hook(self): ) ] - assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)} + assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} assert lineage.run_facets["extractionError"].failedTasks == 1 diff --git a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py index 220de18c2b7f..17b0d1297752 100644 --- a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py +++ b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py @@ -20,13 +20,8 @@ from unittest import mock import pytest -from openlineage.client.facet import ( - ExternalQueryRunFacet, - ExtractionError, - ExtractionErrorRunFacet, - SqlJobFacet, -) -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import external_query_run, extraction_error_run, sql_job from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import DatabaseInfo @@ -124,11 +119,11 @@ def 
test_get_openlineage_facets_on_complete(self, mock_hook): inputs=expected_inputs, outputs=expected_outputs, run_facets={ - "externalQuery": ExternalQueryRunFacet( + "externalQuery": external_query_run.ExternalQueryRunFacet( externalQueryId="query_id_123", source="snowflake_scheme://authority" ) }, - job_facets={"sql": SqlJobFacet(query=expected_sql)}, + job_facets={"sql": sql_job.SQLJobFacet(query=expected_sql)}, ) @pytest.mark.parametrize("rows", (None, [])) @@ -160,11 +155,11 @@ def test_get_openlineage_facets_on_complete_with_empty_inputs(self, mock_hook, r inputs=[], outputs=expected_outputs, run_facets={ - "externalQuery": ExternalQueryRunFacet( + "externalQuery": external_query_run.ExternalQueryRunFacet( externalQueryId="query_id_123", source="snowflake_scheme://authority" ) }, - job_facets={"sql": SqlJobFacet(query=expected_sql)}, + job_facets={"sql": sql_job.SQLJobFacet(query=expected_sql)}, ) @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook") @@ -190,17 +185,17 @@ def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hoo ] expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV""" expected_run_facets = { - "extractionError": ExtractionErrorRunFacet( + "extractionError": extraction_error_run.ExtractionErrorRunFacet( totalTasks=4, failedTasks=2, errors=[ - ExtractionError( + extraction_error_run.Error( errorMessage="Unable to extract Dataset namespace and name.", stackTrace=None, task="azure://my_account.another_weird-url.net/con/file.csv", taskNumber=None, ), - ExtractionError( + extraction_error_run.Error( errorMessage="Unable to extract Dataset namespace and name.", stackTrace=None, task="azure://my_account.weird-url.net/azure_container/dir3/file.csv", @@ -208,7 +203,7 @@ def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hoo ), ], ), - "externalQuery": ExternalQueryRunFacet( + "externalQuery": external_query_run.ExternalQueryRunFacet( externalQueryId="query_id_123", source="snowflake_scheme://authority" ), } @@ -227,5 +222,5 @@ def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hoo inputs=expected_inputs, outputs=expected_outputs, run_facets=expected_run_facets, - job_facets={"sql": SqlJobFacet(query=expected_sql)}, + job_facets={"sql": sql_job.SQLJobFacet(query=expected_sql)}, ) diff --git a/tests/providers/trino/operators/test_trino.py b/tests/providers/trino/operators/test_trino.py index 4bc235e5a19a..126b3db6335c 100644 --- a/tests/providers/trino/operators/test_trino.py +++ b/tests/providers/trino/operators/test_trino.py @@ -19,8 +19,8 @@ from unittest import mock -from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import Dataset +from openlineage.client.facet_v2 import schema_dataset, sql_job from airflow.models.connection import Connection from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator @@ -91,14 +91,14 @@ def get_first(self, *_): namespace="trino://trino:8080", name=f"{DB_NAME}.{DB_SCHEMA_NAME}.customer", facets={ - "schema": SchemaDatasetFacet( + "schema": schema_dataset.SchemaDatasetFacet( fields=[ - SchemaField(name="custkey", type="bigint"), - SchemaField(name="name", type="varchar(25)"), - SchemaField(name="address", type="varchar(40)"), - SchemaField(name="nationkey", type="bigint"), - SchemaField(name="phone", type="varchar(15)"), - SchemaField(name="acctbal", type="double"), + 
schema_dataset.SchemaDatasetFacetFields(name="custkey", type="bigint"), + schema_dataset.SchemaDatasetFacetFields(name="name", type="varchar(25)"), + schema_dataset.SchemaDatasetFacetFields(name="address", type="varchar(40)"), + schema_dataset.SchemaDatasetFacetFields(name="nationkey", type="bigint"), + schema_dataset.SchemaDatasetFacetFields(name="phone", type="varchar(15)"), + schema_dataset.SchemaDatasetFacetFields(name="acctbal", type="double"), ] ) }, @@ -107,4 +107,4 @@ def get_first(self, *_): assert len(lineage.outputs) == 0 - assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)} + assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} From 9c60d952eda2d3e75a488922f1a5cb7da436c9d1 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Wed, 15 May 2024 20:49:37 +0200 Subject: [PATCH 4/7] Migrate Google provider to V2 facets. Signed-off-by: Jakub Dardzinski --- .../providers/amazon/aws/operators/athena.py | 68 ++++++++++++--- .../google/cloud/openlineage/utils.py | 86 +++++++++++++------ .../google/cloud/operators/bigquery.py | 2 +- .../providers/openlineage/plugins/adapter.py | 7 +- .../providers/snowflake/hooks/snowflake.py | 10 ++- .../amazon/aws/operators/test_athena.py | 51 ++++++----- .../amazon/aws/operators/test_redshift_sql.py | 58 ++++++++----- .../google/cloud/openlineage/test_utils.py | 10 ++- .../google/cloud/operators/test_bigquery.py | 13 +-- .../openlineage/extractors/test_base.py | 5 +- 10 files changed, 211 insertions(+), 99 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 08da5ef0e4df..b27ddf33ebb1 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -217,17 +217,38 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: path where the results are saved (user's prefix + some UUID), we are creating a dataset with the user-provided path only. This should make it easier to match this dataset across different processes. 
""" - from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import extraction_error_run, external_query_run, sql_job + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.extraction_error_run import ( + Error, + ExtractionErrorRunFacet, + ) + from openlineage.client.generated.sql_job import SQLJobFacet as SqlJobFacet + else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.extraction_error_run import ( + Error, + ExtractionErrorRunFacet, + ) + from openlineage.client.generated.sql_job import SQLJobFacet as SqlJobFacet + except ImportError: + from openlineage.client.facet import ( + ExternalQueryRunFacet, + ExtractionError as Error, + ExtractionErrorRunFacet, + SqlJobFacet, + ) + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser sql_parser = SQLParser(dialect="generic") - job_facets: dict[str, BaseFacet] = { - "sql": sql_job.SQLJobFacet(query=sql_parser.normalize_sql(self.query)) - } + job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=sql_parser.normalize_sql(self.query))} parse_result = sql_parser.parse(sql=self.query) if not parse_result: @@ -235,11 +256,11 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: run_facets: dict[str, BaseFacet] = {} if parse_result.errors: - run_facets["extractionError"] = extraction_error_run.ExtractionErrorRunFacet( + run_facets["extractionError"] = ExtractionErrorRunFacet( totalTasks=len(self.query) if isinstance(self.query, list) else 1, failedTasks=len(parse_result.errors), errors=[ - extraction_error_run.Error( + Error( errorMessage=error.message, stackTrace=None, task=error.origin_statement, @@ -270,7 +291,7 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: ) if self.query_execution_id: - run_facets["externalQuery"] = external_query_run.ExternalQueryRunFacet( + run_facets["externalQuery"] = ExternalQueryRunFacet( externalQueryId=self.query_execution_id, source="awsathena" ) @@ -281,8 +302,27 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: return OperatorLineage(job_facets=job_facets, run_facets=run_facets, inputs=inputs, outputs=outputs) def get_openlineage_dataset(self, database, table) -> Dataset | None: - from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import schema_dataset, symlinks_dataset + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.schema_dataset import ( + SchemaDatasetFacet, + SchemaDatasetFacetFields, + ) + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.schema_dataset import ( + SchemaDatasetFacet, + SchemaDatasetFacetFields, + ) + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + except ImportError: + from openlineage.client.facet import ( + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SymlinksDatasetFacet, + ) client = self.hook.get_conn() try: @@ -294,9 +334,9 @@ def get_openlineage_dataset(self, database, table) -> Dataset | None: s3_location = 
table_metadata["TableMetadata"]["Parameters"]["location"] parsed_path = urlparse(s3_location) facets: dict[str, DatasetFacet] = { - "symlinks": symlinks_dataset.SymlinksDatasetFacet( + "symlinks": SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace=f"{parsed_path.scheme}://{parsed_path.netloc}", name=str(parsed_path.path), type="TABLE", @@ -305,13 +345,13 @@ def get_openlineage_dataset(self, database, table) -> Dataset | None: ) } fields = [ - schema_dataset.SchemaDatasetFacetFields( + SchemaDatasetFacetFields( name=column["Name"], type=column["Type"], description=column["Comment"] ) for column in table_metadata["TableMetadata"]["Columns"] ] if fields: - facets["schema"] = schema_dataset.SchemaDatasetFacet(fields=fields) + facets["schema"] = SchemaDatasetFacet(fields=fields) return Dataset( namespace=f"awsathena://athena.{self.hook.region_name}.amazonaws.com", name=".".join(filter(None, (self.catalog, database, table))), diff --git a/airflow/providers/google/cloud/openlineage/utils.py b/airflow/providers/google/cloud/openlineage/utils.py index c4c4c72b5b85..f7ebe3641dfb 100644 --- a/airflow/providers/google/cloud/openlineage/utils.py +++ b/airflow/providers/google/cloud/openlineage/utils.py @@ -19,18 +19,19 @@ import copy import json -from msilib import schema import traceback -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from attr import define, field +from openlineage.client.event_v2 import Dataset, InputDataset, OutputDataset from openlineage.client.facet_v2 import ( BaseFacet, column_lineage_dataset, documentation_dataset, - error_message_run,external_query_run, + error_message_run, + external_query_run, output_statistics_output_dataset, - schema_dataset + schema_dataset, ) from airflow.providers.google import __version__ as provider_version @@ -49,7 +50,9 @@ def get_facets_from_bq_table(table: Table) -> dict[Any, Any]: facets = { "schema": schema_dataset.SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields(name=field.name, type=field.field_type, description=field.description) + schema_dataset.SchemaDatasetFacetFields( + name=field.name, type=field.field_type, description=field.description + ) for field in table.schema ] ), @@ -208,7 +211,7 @@ def get_openlineage_facets_on_complete(self, _): - SchemaDatasetFacet - OutputStatisticsOutputDatasetFacet """ - from openlineage.client.facet import ExternalQueryRunFacet, SqlJobFacet + from openlineage.client.facet_v2 import sql_job from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser @@ -217,10 +220,12 @@ def get_openlineage_facets_on_complete(self, _): return OperatorLineage() run_facets: dict[str, BaseFacet] = { - "externalQuery": external_query_run.ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery") + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId=self.job_id, source="bigquery" + ) } - job_facets = {"sql": SqlJobFacet(query=SQLParser.normalize_sql(self.sql))} + job_facets = {"sql": sql_job.SQLJobFacet(query=SQLParser.normalize_sql(self.sql))} self.client = self.hook.get_client(project_id=self.hook.project_id) job_ids = self.job_id @@ -288,7 +293,7 @@ def get_facets(self, job_id: str): deduplicated_outputs = self._deduplicate_outputs(outputs) return inputs, deduplicated_outputs, run_facets - def _deduplicate_outputs(self, outputs: list[Dataset | None]) -> list[Dataset]: + def _deduplicate_outputs(self, outputs: 
list[OutputDataset | None]) -> list[OutputDataset]: # Sources are the same so we can compare only names final_outputs = {} for single_output in outputs: @@ -301,20 +306,24 @@ def _deduplicate_outputs(self, outputs: list[Dataset | None]) -> list[Dataset]: # No OutputStatisticsOutputDatasetFacet is added to duplicated outputs as we can not determine # if the rowCount or size can be summed together. - single_output.facets.pop("outputStatistics", None) + if single_output.outputFacets: + single_output.outputFacets.pop("outputStatistics", None) final_outputs[key] = single_output return list(final_outputs.values()) - def _get_inputs_outputs_from_job(self, properties: dict) -> tuple[list[Dataset], Dataset | None]: + def _get_inputs_outputs_from_job( + self, properties: dict + ) -> tuple[list[InputDataset], OutputDataset | None]: input_tables = get_from_nullable_chain(properties, ["statistics", "query", "referencedTables"]) or [] output_table = get_from_nullable_chain(properties, ["configuration", "query", "destinationTable"]) - inputs = [self._get_dataset(input_table) for input_table in input_tables] + inputs = [(self._get_input_dataset(input_table)) for input_table in input_tables] if output_table: - output = self._get_dataset(output_table) + output = self._get_output_dataset(output_table) dataset_stat_facet = self._get_statistics_dataset_facet(properties) + output.outputFacets = output.outputFacets or {} if dataset_stat_facet: - output.facets.update({"outputStatistics": dataset_stat_facet}) + output.outputFacets["outputStatistics"] = dataset_stat_facet return inputs, output @@ -333,7 +342,9 @@ def _get_bigquery_job_run_facet(properties: dict) -> BigQueryJobRunFacet: ) @staticmethod - def _get_statistics_dataset_facet(properties) -> output_statistics_output_dataset.OutputStatisticsOutputDatasetFacet | None: + def _get_statistics_dataset_facet( + properties, + ) -> output_statistics_output_dataset.OutputStatisticsOutputDatasetFacet | None: query_plan = get_from_nullable_chain(properties, chain=["statistics", "query", "queryPlan"]) if not query_plan: return None @@ -342,25 +353,48 @@ def _get_statistics_dataset_facet(properties) -> output_statistics_output_datase out_rows = out_stage.get("recordsWritten", None) out_bytes = out_stage.get("shuffleOutputBytes", None) if out_bytes and out_rows: - return output_statistics_output_dataset.OutputStatisticsOutputDatasetFacet(rowCount=int(out_rows), size=int(out_bytes)) + return output_statistics_output_dataset.OutputStatisticsOutputDatasetFacet( + rowCount=int(out_rows), size=int(out_bytes) + ) return None - def _get_dataset(self, table: dict) -> Dataset: + def _get_input_dataset(self, table: dict) -> InputDataset: + return cast(InputDataset, self._get_dataset(table, "input")) + + def _get_output_dataset(self, table: dict) -> OutputDataset: + return cast(OutputDataset, self._get_dataset(table, "output")) + + def _get_dataset(self, table: dict, dataset_type: str) -> Dataset: project = table.get("projectId") dataset = table.get("datasetId") table_name = table.get("tableId") dataset_name = f"{project}.{dataset}.{table_name}" dataset_schema = self._get_table_schema_safely(dataset_name) - return Dataset( - namespace=BIGQUERY_NAMESPACE, - name=dataset_name, - facets={ - "schema": dataset_schema, - } - if dataset_schema - else {}, - ) + if dataset_type == "input": + # Logic specific to creating InputDataset (if needed) + return InputDataset( + namespace=BIGQUERY_NAMESPACE, + name=dataset_name, + facets={ + "schema": dataset_schema, + } + if dataset_schema + else 
{}, + ) + elif dataset_type == "output": + # Logic specific to creating OutputDataset (if needed) + return OutputDataset( + namespace=BIGQUERY_NAMESPACE, + name=dataset_name, + facets={ + "schema": dataset_schema, + } + if dataset_schema + else {}, + ) + else: + raise ValueError("Invalid dataset_type. Must be 'input' or 'output'") def _get_table_schema_safely(self, table_name: str) -> schema_dataset.SchemaDatasetFacet | None: try: diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 2e5a38d90a77..d55651d06b43 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -3066,4 +3066,4 @@ def on_kill(self) -> None: job_id=self.job_id, project_id=self.project_id, location=self.location ) else: - self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id) \ No newline at end of file + self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id) diff --git a/airflow/providers/openlineage/plugins/adapter.py b/airflow/providers/openlineage/plugins/adapter.py index 6f62e772a8be..bdd34e8abac4 100644 --- a/airflow/providers/openlineage/plugins/adapter.py +++ b/airflow/providers/openlineage/plugins/adapter.py @@ -434,10 +434,9 @@ def _build_run( {"nominalTime": nominal_time_run.NominalTimeRunFacet(nominal_start_time, nominal_end_time)} ) if parent_run_id: - parent_run_facet = parent_run.ParentRunFacet.create( - runId=parent_run_id, - namespace=conf.namespace(), - name=parent_job_name or job_name, + parent_run_facet = parent_run.ParentRunFacet( + run=parent_run.Run(runId=parent_run_id), + job=parent_run.Job(namespace=conf.namespace(), name=parent_job_name or job_name), ) facets.update({"parent": parent_run_facet}) diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 844cbaf9e298..bc6462efdaa9 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -480,7 +480,13 @@ def _get_openlineage_authority(self, _) -> str | None: return urlparse(uri).hostname def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None: - from openlineage.client.facet_v2 import external_query_run + if TYPE_CHECKING: + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + else: + try: + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + except ImportError: + from openlineage.client.facet import ExternalQueryRunFacet from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser @@ -491,7 +497,7 @@ def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection)) return OperatorLineage( run_facets={ - "externalQuery": external_query_run.ExternalQueryRunFacet( + "externalQuery": ExternalQueryRunFacet( externalQueryId=self.query_ids[0], source=namespace ) } diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 51fef31c68d4..d603d7c5c1aa 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -20,8 +20,23 @@ from unittest import mock import pytest -from openlineage.client.event_v2 import Dataset -from 
openlineage.client.facet_v2 import external_query_run, schema_dataset, sql_job, symlinks_dataset + +try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet +except ImportError: + from openlineage.client.facet import ( + ExternalQueryRunFacet, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SqlJobFacet as SQLJobFacet, + SymlinksDatasetFacet, + SymlinksDatasetFacetIdentifiers as Identifier, + ) + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance @@ -305,38 +320,38 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): namespace="awsathena://athena.eu-west-1.amazonaws.com", name="AwsDataCatalog.TEST_DATABASE.DISCOUNTS", facets={ - "symlinks": symlinks_dataset.SymlinksDatasetFacet( + "symlinks": SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace="s3://bucket", name="/discount/data/path/", type="TABLE", ) ], ), - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( + SchemaDatasetFacetFields( name="ID", type="int", description="from deserializer", ), - schema_dataset.SchemaDatasetFacetFields( + SchemaDatasetFacetFields( name="AMOUNT_OFF", type="int", description="from deserializer", ), - schema_dataset.SchemaDatasetFacetFields( + SchemaDatasetFacetFields( name="CUSTOMER_EMAIL", type="varchar", description="from deserializer", ), - schema_dataset.SchemaDatasetFacetFields( + SchemaDatasetFacetFields( name="STARTS_ON", type="timestamp", description="from deserializer", ), - schema_dataset.SchemaDatasetFacetFields( + SchemaDatasetFacetFields( name="ENDS_ON", type="timestamp", description="from deserializer", @@ -351,18 +366,18 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): namespace="awsathena://athena.eu-west-1.amazonaws.com", name="AwsDataCatalog.TEST_DATABASE.TEST_TABLE", facets={ - "symlinks": symlinks_dataset.SymlinksDatasetFacet( + "symlinks": SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace="s3://bucket", name="/data/test_table/data/path", type="TABLE", ) ], ), - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( + SchemaDatasetFacetFields( name="column", type="string", description="from deserializer", @@ -374,14 +389,10 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): Dataset(namespace="s3://test_s3_bucket", name="/"), ], job_facets={ - "sql": sql_job.SQLJobFacet( + "sql": SQLJobFacet( query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM DISCOUNTS", ) }, - run_facets={ - "externalQuery": external_query_run.ExternalQueryRunFacet( - externalQueryId="12345", source="awsathena" - ) - }, + run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="12345", source="awsathena")}, ) assert op.get_openlineage_facets_on_complete(None) == expected_lineage diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py index 6df87ac1a443..94e2b5b2f397 100644 --- 
a/tests/providers/amazon/aws/operators/test_redshift_sql.py +++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py @@ -20,8 +20,26 @@ from unittest.mock import MagicMock, PropertyMock, call, patch import pytest -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import column_lineage_dataset, schema_dataset, sql_job + +try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet +except ImportError: + from openlineage.client.facet import ( + ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional as Fields, + ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import Dataset from airflow.models.connection import Connection from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook as OriginalRedshiftSQLHook @@ -206,11 +224,11 @@ def get_db_hook(self): namespace=expected_namespace, name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week", facets={ - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), - schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), - schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + SchemaDatasetFacetFields(name="orders_placed", type="int4"), ] ) }, @@ -219,12 +237,10 @@ def get_db_hook(self): namespace=expected_namespace, name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table", facets={ - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), - schema_dataset.SchemaDatasetFacetFields( - name="additional_constant", type="varchar" - ), + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="additional_constant", type="varchar"), ] ) }, @@ -235,21 +251,19 @@ def get_db_hook(self): namespace=expected_namespace, name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table", facets={ - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), - schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), - schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), - schema_dataset.SchemaDatasetFacetFields( - name="additional_constant", type="varchar" - ), + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + SchemaDatasetFacetFields(name="orders_placed", type="int4"), + SchemaDatasetFacetFields(name="additional_constant", type="varchar"), ] ), - "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( + "columnLineage": ColumnLineageDatasetFacet( fields={ - "additional_constant": column_lineage_dataset.Fields( + "additional_constant": Fields( inputFields=[ - 
column_lineage_dataset.InputField( + InputField( namespace=expected_namespace, name="database.public.little_table", field="additional_constant", @@ -264,6 +278,6 @@ def get_db_hook(self): ) ] - assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} + assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)} assert lineage.run_facets["extractionError"].failedTasks == 1 diff --git a/tests/providers/google/cloud/openlineage/test_utils.py b/tests/providers/google/cloud/openlineage/test_utils.py index bfa63943a145..972337f92aac 100644 --- a/tests/providers/google/cloud/openlineage/test_utils.py +++ b/tests/providers/google/cloud/openlineage/test_utils.py @@ -21,8 +21,14 @@ import pytest from google.cloud.bigquery.table import Table -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import column_lineage_dataset, documentation_dataset, schema_dataset +from openlineage.client.event_v2 import Dataset, InputDataset, OutputDataset +from openlineage.client.facet_v2 import ( + column_lineage_dataset, + documentation_dataset, + external_query_run, + output_statistics_output_dataset, + schema_dataset, +) from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 342c138f9bd2..13e550b16ff3 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -26,9 +26,8 @@ import pytest from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter from google.cloud.exceptions import Conflict -from openlineage.client.facet import DataSourceDatasetFacet, ExternalQueryRunFacet -from openlineage.client.facet_v2 import sql_job -from openlineage.client.run import Dataset +from openlineage.client.event_v2 import InputDataset +from openlineage.client.facet_v2 import error_message_run, external_query_run, sql_job from airflow.exceptions import ( AirflowException, @@ -1845,13 +1844,15 @@ def test_execute_openlineage_events(self, mock_hook): lineage = op.get_openlineage_facets_on_complete(None) assert lineage.inputs == [ - Dataset(namespace="bigquery", name="airflow-openlineage.new_dataset.test_table") + InputDataset(namespace="bigquery", name="airflow-openlineage.new_dataset.test_table") ] assert lineage.run_facets == { "bigQuery_job": mock.ANY, "bigQueryJob": mock.ANY, - "externalQuery": ExternalQueryRunFacet(externalQueryId=mock.ANY, source="bigquery"), + "externalQuery": external_query_run.ExternalQueryRunFacet( + externalQueryId=mock.ANY, source="bigquery" + ), } assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query="SELECT * FROM test_table")} @@ -1880,7 +1881,7 @@ def test_execute_fails_openlineage_events(self, mock_hook): operator.execute(MagicMock()) lineage = operator.get_openlineage_facets_on_complete(None) - assert isinstance(lineage.run_facets["errorMessage"], ErrorMessageRunFacet) + assert isinstance(lineage.run_facets["errorMessage"], error_message_run.ErrorMessageRunFacet) @pytest.mark.db_test @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") diff --git a/tests/providers/openlineage/extractors/test_base.py b/tests/providers/openlineage/extractors/test_base.py index 20ceba45dd87..2f847140070a 100644 --- a/tests/providers/openlineage/extractors/test_base.py +++ b/tests/providers/openlineage/extractors/test_base.py @@ -42,8 +42,9 @@ INPUTS = [Dataset(namespace="database://host:port", 
name="inputtable")] OUTPUTS = [Dataset(namespace="database://host:port", name="inputtable")] RUN_FACETS: dict[str, RunFacet] = { - "parent": parent_run.ParentRunFacet.create( - "3bb703d1-09c1-4a42-8da5-35a0b3216072", "namespace", "parentjob" + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId="3bb703d1-09c1-4a42-8da5-35a0b3216072"), + job=parent_run.Job(namespace="namespace", name="parentjob"), ) } JOB_FACETS: dict[str, JobFacet] = {"sql": sql_job.SQLJobFacet(query="SELECT * FROM inputtable")} From 0483741cbcb8c1aea9e69fb7025f0b7401f2bc3a Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Mon, 3 Jun 2024 17:54:14 +0200 Subject: [PATCH 5/7] Make migration backwards-compatible with previous OL provider versions. Signed-off-by: Jakub Dardzinski --- .../providers/amazon/aws/operators/athena.py | 2 + airflow/providers/amazon/aws/operators/s3.py | 45 ++- .../amazon/aws/operators/sagemaker.py | 6 +- .../common/io/operators/file_transfer.py | 8 +- airflow/providers/ftp/operators/ftp.py | 10 +- .../google/cloud/openlineage/mixins.py | 150 ++++++--- .../google/cloud/openlineage/utils.py | 311 +++--------------- .../providers/google/cloud/operators/gcs.py | 44 ++- .../google/cloud/transfers/bigquery_to_gcs.py | 29 +- .../google/cloud/transfers/gcs_to_bigquery.py | 30 +- .../google/cloud/transfers/gcs_to_gcs.py | 8 +- .../providers/openlineage/plugins/adapter.py | 2 +- airflow/providers/openlineage/utils/utils.py | 7 +- airflow/providers/sftp/operators/sftp.py | 10 +- .../transfers/copy_into_snowflake.py | 31 +- .../amazon/aws/operators/test_athena.py | 29 +- .../amazon/aws/operators/test_redshift_sql.py | 33 +- .../providers/amazon/aws/operators/test_s3.py | 44 ++- .../operators/test_sagemaker_processing.py | 10 +- .../aws/operators/test_sagemaker_training.py | 10 +- .../aws/operators/test_sagemaker_transform.py | 10 +- .../common/io/operators/test_file_transfer.py | 9 +- .../common/sql/operators/test_sql_execute.py | 31 +- .../dbt/cloud/utils/test_openlineage.py | 18 +- tests/providers/ftp/operators/test_ftp.py | 10 +- .../google/cloud/openlineage/test_mixins.py | 89 +++-- .../google/cloud/openlineage/test_utils.py | 73 ++-- .../google/cloud/operators/test_bigquery.py | 30 +- .../google/cloud/operators/test_gcs.py | 32 +- .../cloud/transfers/test_bigquery_to_gcs.py | 102 +++--- .../cloud/transfers/test_gcs_to_bigquery.py | 154 +++++---- .../google/cloud/transfers/test_gcs_to_gcs.py | 10 +- tests/providers/mysql/operators/test_mysql.py | 30 +- tests/providers/sftp/operators/test_sftp.py | 10 +- .../snowflake/operators/test_snowflake_sql.py | 36 +- .../transfers/test_copy_into_snowflake.py | 43 ++- tests/providers/trino/operators/test_trino.py | 35 +- 37 files changed, 901 insertions(+), 640 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index b27ddf33ebb1..255819b6bfc3 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -322,7 +322,9 @@ def get_openlineage_dataset(self, database, table) -> Dataset | None: SchemaDatasetFacet, SchemaField as SchemaDatasetFacetFields, SymlinksDatasetFacet, + SymlinksDatasetFacetIdentifiers as Identifier, ) + from openlineage.client.run import Dataset client = self.hook.get_conn() try: diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py index a1464cc32816..f9c4b8808fe2 100644 --- a/airflow/providers/amazon/aws/operators/s3.py +++ 
b/airflow/providers/amazon/aws/operators/s3.py @@ -324,7 +324,10 @@ def execute(self, context: Context): ) def get_openlineage_facets_on_start(self): - from openlineage.client.event_v2 import Dataset + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -439,7 +442,10 @@ def execute(self, context: Context): s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy) def get_openlineage_facets_on_start(self): - from openlineage.client.event_v2 import Dataset + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -546,8 +552,28 @@ def execute(self, context: Context): def get_openlineage_facets_on_complete(self, task_instance): """Implement _on_complete because object keys are resolved in execute().""" - from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import lifecycle_state_change_dataset + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + except ImportError: + from openlineage.client.facet import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + LifecycleStateChangeDatasetFacetPreviousIdentifier as PreviousIdentifier, + ) + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -564,9 +590,9 @@ def get_openlineage_facets_on_complete(self, task_instance): namespace=bucket_url, name=key, facets={ - "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( - lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, - previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.DROP.value, + previousIdentifier=PreviousIdentifier( namespace=bucket_url, name=key, ), @@ -721,7 +747,10 @@ def execute(self, context: Context): self.log.info("Upload successful") def get_openlineage_facets_on_start(self): - from openlineage.client.event_v2 import Dataset + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 0ac5ad700996..5e7a1dfbb641 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -208,7 +208,11 @@ def hook(self): @staticmethod def path_to_s3_dataset(path) -> Dataset: - from openlineage.client.event_v2 import Dataset + if not TYPE_CHECKING: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset path = path.replace("s3://", "") split_path = path.split("/") diff --git 
a/airflow/providers/common/io/operators/file_transfer.py b/airflow/providers/common/io/operators/file_transfer.py
index 273984e94136..25f5d7169f04 100644
--- a/airflow/providers/common/io/operators/file_transfer.py
+++ b/airflow/providers/common/io/operators/file_transfer.py
@@ -75,7 +75,13 @@ def execute(self, context: Context) -> None:
         src.copy(dst)
 
     def get_openlineage_facets_on_start(self) -> OperatorLineage:
-        from openlineage.client.event_v2 import Dataset
+        if TYPE_CHECKING:
+            from openlineage.client.event_v2 import Dataset
+        else:
+            try:
+                from openlineage.client.event_v2 import Dataset
+            except ImportError:
+                from openlineage.client.run import Dataset
 
         from airflow.providers.openlineage.extractors import OperatorLineage
 
diff --git a/airflow/providers/ftp/operators/ftp.py b/airflow/providers/ftp/operators/ftp.py
index 24f0a3d35130..856d70dcd533 100644
--- a/airflow/providers/ftp/operators/ftp.py
+++ b/airflow/providers/ftp/operators/ftp.py
@@ -26,6 +26,8 @@
 from pathlib import Path
 from typing import Any, Sequence
 
+from typing import TYPE_CHECKING
+
 from airflow.models import BaseOperator
 from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook
 
@@ -146,7 +148,13 @@ def get_openlineage_facets_on_start(self):
         input: file://hostname/path
         output file://<conn.host>:<conn.port>/path.
         """
-        from openlineage.client.event_v2 import Dataset
+        if TYPE_CHECKING:
+            from openlineage.client.event_v2 import Dataset
+        else:
+            try:
+                from openlineage.client.event_v2 import Dataset
+            except ImportError:
+                from openlineage.client.run import Dataset
 
         from airflow.providers.openlineage.extractors import OperatorLineage
 
diff --git a/airflow/providers/google/cloud/openlineage/mixins.py b/airflow/providers/google/cloud/openlineage/mixins.py
index 48ff695c72eb..71c41273c107 100644
--- a/airflow/providers/google/cloud/openlineage/mixins.py
+++ b/airflow/providers/google/cloud/openlineage/mixins.py
@@ -20,19 +20,22 @@
 import copy
 import json
 import traceback
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 if TYPE_CHECKING:
-    from openlineage.client.facet import (
-        BaseFacet,
+    from openlineage.client.event_v2 import Dataset, InputDataset, OutputDataset
+    from openlineage.client.generated.base import RunFacet
+    from openlineage.client.generated.output_statistics_output_dataset import (
         OutputStatisticsOutputDatasetFacet,
-        SchemaDatasetFacet,
     )
-    from openlineage.client.run import Dataset
+    from openlineage.client.generated.schema_dataset import SchemaDatasetFacet
 
     from airflow.providers.google.cloud.openlineage.utils import BigQueryJobRunFacet
 
 
+BIGQUERY_NAMESPACE = "bigquery"
+
+
 class _BigQueryOpenLineageMixin:
     def get_openlineage_facets_on_complete(self, _):
         """
@@ -61,7 +64,15 @@ def get_openlineage_facets_on_complete(self, _):
         - SchemaDatasetFacet
         - OutputStatisticsOutputDatasetFacet
         """
-        from openlineage.client.facet import ExternalQueryRunFacet, SqlJobFacet
+        if TYPE_CHECKING:
+            from openlineage.client.generated.external_query_run import ExternalQueryRunFacet
+            from openlineage.client.generated.sql_job import SQLJobFacet
+        else:
+            try:
+                from openlineage.client.generated.external_query_run import ExternalQueryRunFacet
+                from openlineage.client.generated.sql_job import SQLJobFacet
+            except ImportError:
+                from openlineage.client.facet import ExternalQueryRunFacet, SqlJobFacet as SQLJobFacet
 
         from airflow.providers.openlineage.extractors import OperatorLineage
         from airflow.providers.openlineage.sqlparser import SQLParser
@@ -79,11 +90,11 @@
impersonation_chain=self.impersonation_chain, ) - run_facets: dict[str, BaseFacet] = { + run_facets: dict[str, RunFacet] = { "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery") } - job_facets = {"sql": SqlJobFacet(query=SQLParser.normalize_sql(self.sql))} + job_facets = {"sql": SQLJobFacet(query=SQLParser.normalize_sql(self.sql))} self.client = self.hook.get_client(project_id=self.hook.project_id) job_ids = self.job_id @@ -104,16 +115,22 @@ def get_openlineage_facets_on_complete(self, _): ) def get_facets(self, job_id: str): - from openlineage.client.facet import ErrorMessageRunFacet - from airflow.providers.google.cloud.openlineage.utils import ( BigQueryErrorRunFacet, get_from_nullable_chain, ) + if TYPE_CHECKING: + from openlineage.client.generated.error_message_run import ErrorMessageRunFacet + else: + try: + from openlineage.client.generated.error_message_run import ErrorMessageRunFacet + except ImportError: + from openlineage.client.facet import ErrorMessageRunFacet + inputs = [] outputs = [] - run_facets: dict[str, BaseFacet] = {} + run_facets: dict[str, RunFacet] = {} if hasattr(self, "log"): self.log.debug("Extracting data from bigquery job: `%s`", job_id) try: @@ -158,7 +175,7 @@ def get_facets(self, job_id: str): deduplicated_outputs = self._deduplicate_outputs(outputs) return inputs, deduplicated_outputs, run_facets - def _deduplicate_outputs(self, outputs: list[Dataset | None]) -> list[Dataset]: + def _deduplicate_outputs(self, outputs: list[OutputDataset | None]) -> list[OutputDataset]: # Sources are the same so we can compare only names final_outputs = {} for single_output in outputs: @@ -171,22 +188,26 @@ def _deduplicate_outputs(self, outputs: list[Dataset | None]) -> list[Dataset]: # No OutputStatisticsOutputDatasetFacet is added to duplicated outputs as we can not determine # if the rowCount or size can be summed together. 
- single_output.facets.pop("outputStatistics", None) + if single_output.outputFacets: + single_output.outputFacets.pop("outputStatistics", None) final_outputs[key] = single_output return list(final_outputs.values()) - def _get_inputs_outputs_from_job(self, properties: dict) -> tuple[list[Dataset], Dataset | None]: + def _get_inputs_outputs_from_job( + self, properties: dict + ) -> tuple[list[InputDataset], OutputDataset | None]: from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain input_tables = get_from_nullable_chain(properties, ["statistics", "query", "referencedTables"]) or [] output_table = get_from_nullable_chain(properties, ["configuration", "query", "destinationTable"]) - inputs = [self._get_dataset(input_table) for input_table in input_tables] + inputs = [(self._get_input_dataset(input_table)) for input_table in input_tables] if output_table: - output = self._get_dataset(output_table) + output = self._get_output_dataset(output_table) dataset_stat_facet = self._get_statistics_dataset_facet(properties) + output.outputFacets = output.outputFacets or {} if dataset_stat_facet: - output.facets.update({"outputStatistics": dataset_stat_facet}) + output.outputFacets["outputStatistics"] = dataset_stat_facet return inputs, output @@ -210,11 +231,19 @@ def _get_bigquery_job_run_facet(properties: dict) -> BigQueryJobRunFacet: ) @staticmethod - def _get_statistics_dataset_facet(properties) -> OutputStatisticsOutputDatasetFacet | None: - from openlineage.client.facet import OutputStatisticsOutputDatasetFacet - + def _get_statistics_dataset_facet( + properties, + ) -> OutputStatisticsOutputDatasetFacet | None: from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain + if not TYPE_CHECKING: + try: + from openlineage.client.generated.output_statistics_output_dataset import ( + OutputStatisticsOutputDatasetFacet, + ) + except ImportError: + from openlineage.client.facet import OutputStatisticsOutputDatasetFacet + query_plan = get_from_nullable_chain(properties, chain=["statistics", "query", "queryPlan"]) if not query_plan: return None @@ -226,26 +255,58 @@ def _get_statistics_dataset_facet(properties) -> OutputStatisticsOutputDatasetFa return OutputStatisticsOutputDatasetFacet(rowCount=int(out_rows), size=int(out_bytes)) return None - def _get_dataset(self, table: dict) -> Dataset: - from openlineage.client.run import Dataset - - BIGQUERY_NAMESPACE = "bigquery" - + def _get_input_dataset(self, table: dict) -> InputDataset: + if not TYPE_CHECKING: + try: + from openlineage.client.generated.base import InputDataset + except ImportError: + from openlineage.client.run import InputDataset + return cast(InputDataset, self._get_dataset(table, "input")) + + def _get_output_dataset(self, table: dict) -> OutputDataset: + if not TYPE_CHECKING: + try: + from openlineage.client.generated.base import OutputDataset + except ImportError: + from openlineage.client.run import OutputDataset + return cast(OutputDataset, self._get_dataset(table, "output")) + + def _get_dataset(self, table: dict, dataset_type: str) -> Dataset: + if not TYPE_CHECKING: + try: + from openlineage.client.generated.base import InputDataset, OutputDataset + except ImportError: + from openlineage.client.run import InputDataset, OutputDataset project = table.get("projectId") dataset = table.get("datasetId") table_name = table.get("tableId") dataset_name = f"{project}.{dataset}.{table_name}" dataset_schema = self._get_table_schema_safely(dataset_name) - return Dataset( - 
namespace=BIGQUERY_NAMESPACE, - name=dataset_name, - facets={ - "schema": dataset_schema, - } - if dataset_schema - else {}, - ) + if dataset_type == "input": + # Logic specific to creating InputDataset (if needed) + return InputDataset( + namespace=BIGQUERY_NAMESPACE, + name=dataset_name, + facets={ + "schema": dataset_schema, + } + if dataset_schema + else {}, + ) + elif dataset_type == "output": + # Logic specific to creating OutputDataset (if needed) + return OutputDataset( + namespace=BIGQUERY_NAMESPACE, + name=dataset_name, + facets={ + "schema": dataset_schema, + } + if dataset_schema + else {}, + ) + else: + raise ValueError("Invalid dataset_type. Must be 'input' or 'output'") def _get_table_schema_safely(self, table_name: str) -> SchemaDatasetFacet | None: try: @@ -256,10 +317,25 @@ def _get_table_schema_safely(self, table_name: str) -> SchemaDatasetFacet | None return None def _get_table_schema(self, table: str) -> SchemaDatasetFacet | None: - from openlineage.client.facet import SchemaDatasetFacet, SchemaField - from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain + if TYPE_CHECKING: + from openlineage.client.generated.schema_dataset import ( + SchemaDatasetFacet, + SchemaDatasetFacetFields, + ) + else: + try: + from openlineage.client.generated.schema_dataset import ( + SchemaDatasetFacet, + SchemaDatasetFacetFields, + ) + except ImportError: + from openlineage.client.facet import ( + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + ) + bq_table = self.client.get_table(table) if not bq_table._properties: @@ -271,7 +347,7 @@ def _get_table_schema(self, table: str) -> SchemaDatasetFacet | None: return SchemaDatasetFacet( fields=[ - SchemaField( + SchemaDatasetFacetFields( name=field.get("name"), type=field.get("type"), description=field.get("description"), diff --git a/airflow/providers/google/cloud/openlineage/utils.py b/airflow/providers/google/cloud/openlineage/utils.py index f7ebe3641dfb..1c0624592217 100644 --- a/airflow/providers/google/cloud/openlineage/utils.py +++ b/airflow/providers/google/cloud/openlineage/utils.py @@ -17,29 +17,43 @@ # under the License. 
from __future__ import annotations -import copy -import json -import traceback -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from attr import define, field -from openlineage.client.event_v2 import Dataset, InputDataset, OutputDataset -from openlineage.client.facet_v2 import ( - BaseFacet, - column_lineage_dataset, - documentation_dataset, - error_message_run, - external_query_run, - output_statistics_output_dataset, - schema_dataset, -) - -from airflow.providers.google import __version__ as provider_version if TYPE_CHECKING: from google.cloud.bigquery.table import Table - from openlineage.client.run import Dataset + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.base import RunFacet + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields +else: + try: + from openlineage.client.generated.base import RunFacet + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + except ImportError: + from openlineage.client.facet import ( + BaseFacet as RunFacet, + ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional as Fields, + ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, + DocumentationDatasetFacet, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + ) +from airflow.providers.google import __version__ as provider_version BIGQUERY_NAMESPACE = "bigquery" BIGQUERY_URI = "bigquery" @@ -48,15 +62,15 @@ def get_facets_from_bq_table(table: Table) -> dict[Any, Any]: """Get facets from BigQuery table object.""" facets = { - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( + SchemaDatasetFacetFields( name=field.name, type=field.field_type, description=field.description ) for field in table.schema ] ), - "documentation": documentation_dataset.DocumentationDatasetFacet(description=table.description or ""), + "documentation": DocumentationDatasetFacet(description=table.description or ""), } return facets @@ -65,7 +79,7 @@ def get_facets_from_bq_table(table: Table) -> dict[Any, Any]: def get_identity_column_lineage_facet( field_names: list[str], input_datasets: list[Dataset], -) -> column_lineage_dataset.ColumnLineageDatasetFacet: +) -> ColumnLineageDatasetFacet: """ Get column lineage facet. 
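The import block added above is the compatibility pattern repeated throughout this patch: type checkers always resolve against the generated V2 facet modules, while at runtime the V1 names from openlineage.client.facet are aliased in when the generated modules are unavailable. A minimal self-contained sketch of the same pattern (not part of this patch; it does not assert which client version first ships the generated modules):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static analysis always sees the newer, generated V2 facets.
    from openlineage.client.generated.schema_dataset import SchemaDatasetFacet
else:
    try:
        # Newer openlineage-python releases ship the generated V2 facet modules.
        from openlineage.client.generated.schema_dataset import SchemaDatasetFacet
    except ImportError:
        # Older clients only expose the V1 facets; alias them to the V2 name.
        from openlineage.client.facet import SchemaDatasetFacet

empty_schema = SchemaDatasetFacet(fields=[])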
@@ -75,13 +89,11 @@ def get_identity_column_lineage_facet( if field_names and not input_datasets: raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.") - column_lineage_facet = column_lineage_dataset.ColumnLineageDatasetFacet( + column_lineage_facet = ColumnLineageDatasetFacet( fields={ - field: column_lineage_dataset.Fields( + field: Fields( inputFields=[ - column_lineage_dataset.InputField( - namespace=dataset.namespace, name=dataset.name, field=field - ) + InputField(namespace=dataset.namespace, name=dataset.name, field=field) for dataset in input_datasets ], transformationType="IDENTITY", @@ -94,7 +106,7 @@ def get_identity_column_lineage_facet( @define -class BigQueryJobRunFacet(BaseFacet): +class BigQueryJobRunFacet(RunFacet): """ Facet that represents relevant statistics of bigquery run. @@ -120,7 +132,7 @@ def _get_schema() -> str: # TODO: remove BigQueryErrorRunFacet in next release @define -class BigQueryErrorRunFacet(BaseFacet): +class BigQueryErrorRunFacet(RunFacet): """ Represents errors that can happen during execution of BigqueryExtractor. @@ -181,246 +193,3 @@ def get_from_nullable_chain(source: Any, chain: list[str]) -> Any | None: return source except AttributeError: return None - - -class _BigQueryOpenLineageMixin: - def get_openlineage_facets_on_complete(self, _): - """ - Retrieve OpenLineage data for a COMPLETE BigQuery job. - - This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider. - It calls BigQuery API, retrieving input and output dataset info from it, as well as run-level - usage statistics. - - Run facets should contain: - - ExternalQueryRunFacet - - BigQueryJobRunFacet - - Run facets may contain: - - ErrorMessageRunFacet - - Job facets should contain: - - SqlJobFacet if operator has self.sql - - Input datasets should contain facets: - - DataSourceDatasetFacet - - SchemaDatasetFacet - - Output datasets should contain facets: - - DataSourceDatasetFacet - - SchemaDatasetFacet - - OutputStatisticsOutputDatasetFacet - """ - from openlineage.client.facet_v2 import sql_job - - from airflow.providers.openlineage.extractors import OperatorLineage - from airflow.providers.openlineage.sqlparser import SQLParser - - if not self.job_id: - return OperatorLineage() - - run_facets: dict[str, BaseFacet] = { - "externalQuery": external_query_run.ExternalQueryRunFacet( - externalQueryId=self.job_id, source="bigquery" - ) - } - - job_facets = {"sql": sql_job.SQLJobFacet(query=SQLParser.normalize_sql(self.sql))} - - self.client = self.hook.get_client(project_id=self.hook.project_id) - job_ids = self.job_id - if isinstance(self.job_id, str): - job_ids = [self.job_id] - inputs, outputs = [], [] - for job_id in job_ids: - inner_inputs, inner_outputs, inner_run_facets = self.get_facets(job_id=job_id) - inputs.extend(inner_inputs) - outputs.extend(inner_outputs) - run_facets.update(inner_run_facets) - - return OperatorLineage( - inputs=inputs, - outputs=outputs, - run_facets=run_facets, - job_facets=job_facets, - ) - - def get_facets(self, job_id: str): - inputs = [] - outputs = [] - run_facets: dict[str, BaseFacet] = {} - if hasattr(self, "log"): - self.log.debug("Extracting data from bigquery job: `%s`", job_id) - try: - job = self.client.get_job(job_id=job_id) # type: ignore - props = job._properties - - if get_from_nullable_chain(props, ["status", "state"]) != "DONE": - raise ValueError(f"Trying to extract data from running bigquery job: `{job_id}`") - - # TODO: remove bigQuery_job in next release - 
run_facets["bigQuery_job"] = run_facets["bigQueryJob"] = self._get_bigquery_job_run_facet(props) - - if get_from_nullable_chain(props, ["statistics", "numChildJobs"]): - if hasattr(self, "log"): - self.log.debug("Found SCRIPT job. Extracting lineage from child jobs instead.") - # SCRIPT job type has no input / output information but spawns child jobs that have one - # https://cloud.google.com/bigquery/docs/information-schema-jobs#multi-statement_query_job - for child_job_id in self.client.list_jobs(parent_job=job_id): - child_job = self.client.get_job(job_id=child_job_id) # type: ignore - child_inputs, child_output = self._get_inputs_outputs_from_job(child_job._properties) - inputs.extend(child_inputs) - outputs.append(child_output) - else: - inputs, _output = self._get_inputs_outputs_from_job(props) - outputs.append(_output) - except Exception as e: - if hasattr(self, "log"): - self.log.warning("Cannot retrieve job details from BigQuery.Client. %s", e, exc_info=True) - exception_msg = traceback.format_exc() - # TODO: remove BigQueryErrorRunFacet in next release - run_facets.update( - { - "errorMessage": error_message_run.ErrorMessageRunFacet( - message=f"{e}: {exception_msg}", - programmingLanguage="python", - ), - "bigQuery_error": BigQueryErrorRunFacet( - clientError=f"{e}: {exception_msg}", - ), - } - ) - deduplicated_outputs = self._deduplicate_outputs(outputs) - return inputs, deduplicated_outputs, run_facets - - def _deduplicate_outputs(self, outputs: list[OutputDataset | None]) -> list[OutputDataset]: - # Sources are the same so we can compare only names - final_outputs = {} - for single_output in outputs: - if not single_output: - continue - key = single_output.name - if key not in final_outputs: - final_outputs[key] = single_output - continue - - # No OutputStatisticsOutputDatasetFacet is added to duplicated outputs as we can not determine - # if the rowCount or size can be summed together. - if single_output.outputFacets: - single_output.outputFacets.pop("outputStatistics", None) - final_outputs[key] = single_output - - return list(final_outputs.values()) - - def _get_inputs_outputs_from_job( - self, properties: dict - ) -> tuple[list[InputDataset], OutputDataset | None]: - input_tables = get_from_nullable_chain(properties, ["statistics", "query", "referencedTables"]) or [] - output_table = get_from_nullable_chain(properties, ["configuration", "query", "destinationTable"]) - inputs = [(self._get_input_dataset(input_table)) for input_table in input_tables] - if output_table: - output = self._get_output_dataset(output_table) - dataset_stat_facet = self._get_statistics_dataset_facet(properties) - output.outputFacets = output.outputFacets or {} - if dataset_stat_facet: - output.outputFacets["outputStatistics"] = dataset_stat_facet - - return inputs, output - - @staticmethod - def _get_bigquery_job_run_facet(properties: dict) -> BigQueryJobRunFacet: - if get_from_nullable_chain(properties, ["configuration", "query", "query"]): - # Exclude the query to avoid event size issues and duplicating SqlJobFacet information. 
- properties = copy.deepcopy(properties) - properties["configuration"]["query"].pop("query") - cache_hit = get_from_nullable_chain(properties, ["statistics", "query", "cacheHit"]) - billed_bytes = get_from_nullable_chain(properties, ["statistics", "query", "totalBytesBilled"]) - return BigQueryJobRunFacet( - cached=str(cache_hit).lower() == "true", - billedBytes=int(billed_bytes) if billed_bytes else None, - properties=json.dumps(properties), - ) - - @staticmethod - def _get_statistics_dataset_facet( - properties, - ) -> output_statistics_output_dataset.OutputStatisticsOutputDatasetFacet | None: - query_plan = get_from_nullable_chain(properties, chain=["statistics", "query", "queryPlan"]) - if not query_plan: - return None - - out_stage = query_plan[-1] - out_rows = out_stage.get("recordsWritten", None) - out_bytes = out_stage.get("shuffleOutputBytes", None) - if out_bytes and out_rows: - return output_statistics_output_dataset.OutputStatisticsOutputDatasetFacet( - rowCount=int(out_rows), size=int(out_bytes) - ) - return None - - def _get_input_dataset(self, table: dict) -> InputDataset: - return cast(InputDataset, self._get_dataset(table, "input")) - - def _get_output_dataset(self, table: dict) -> OutputDataset: - return cast(OutputDataset, self._get_dataset(table, "output")) - - def _get_dataset(self, table: dict, dataset_type: str) -> Dataset: - project = table.get("projectId") - dataset = table.get("datasetId") - table_name = table.get("tableId") - dataset_name = f"{project}.{dataset}.{table_name}" - - dataset_schema = self._get_table_schema_safely(dataset_name) - if dataset_type == "input": - # Logic specific to creating InputDataset (if needed) - return InputDataset( - namespace=BIGQUERY_NAMESPACE, - name=dataset_name, - facets={ - "schema": dataset_schema, - } - if dataset_schema - else {}, - ) - elif dataset_type == "output": - # Logic specific to creating OutputDataset (if needed) - return OutputDataset( - namespace=BIGQUERY_NAMESPACE, - name=dataset_name, - facets={ - "schema": dataset_schema, - } - if dataset_schema - else {}, - ) - else: - raise ValueError("Invalid dataset_type. Must be 'input' or 'output'") - - def _get_table_schema_safely(self, table_name: str) -> schema_dataset.SchemaDatasetFacet | None: - try: - return self._get_table_schema(table_name) - except Exception as e: - if hasattr(self, "log"): - self.log.warning("Could not extract output schema from bigquery. 
%s", e) - return None - - def _get_table_schema(self, table: str) -> schema_dataset.SchemaDatasetFacet | None: - bq_table = self.client.get_table(table) - - if not bq_table._properties: - return None - - fields = get_from_nullable_chain(bq_table._properties, ["schema", "fields"]) - if not fields: - return None - - return schema_dataset.SchemaDatasetFacet( - fields=[ - schema_dataset.SchemaDatasetFacetFields( - name=field.get("name"), - type=field.get("type"), - description=field.get("description"), - ) - for field in fields - ] - ) diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index 8027e3d27dd6..f18e49cbb5e7 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -336,8 +336,28 @@ def execute(self, context: Context) -> None: hook.delete(bucket_name=self.bucket_name, object_name=object_name) def get_openlineage_facets_on_start(self): - from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import lifecycle_state_change_dataset + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + except ImportError: + from openlineage.client.facet import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + LifecycleStateChangeDatasetFacetPreviousIdentifier as PreviousIdentifier, + ) + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -359,9 +379,9 @@ def get_openlineage_facets_on_start(self): namespace=bucket_url, name=object_name, facets={ - "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( - lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, - previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.DROP.value, + previousIdentifier=PreviousIdentifier( namespace=bucket_url, name=object_name, ), @@ -641,7 +661,13 @@ def execute(self, context: Context) -> None: ) def get_openlineage_facets_on_start(self): - from openlineage.client.event_v2 import Dataset + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage @@ -917,7 +943,11 @@ def execute(self, context: Context) -> list[str]: def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as execute() resolves object prefixes.""" - from openlineage.client.event_v2 import Dataset + if not TYPE_CHECKING: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 800ed952b193..5c9df82c1488 100644 --- 
a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -289,9 +289,22 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will include final BQ job id.""" from pathlib import Path - from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import external_query_run, symlinks_dataset - + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + except ImportError: + from openlineage.client.facet import ( + ExternalQueryRunFacet, + SymlinksDatasetFacet, + SymlinksDatasetFacetIdentifiers as Identifier, + ) + from openlineage.client.run import Dataset from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, @@ -330,10 +343,8 @@ def get_openlineage_facets_on_complete(self, task_instance): # If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name, # but we create a symlink to the full object path with wildcard. additional_facets = { - "symlink": symlinks_dataset.SymlinksDatasetFacet( - identifiers=[ - symlinks_dataset.Identifier(namespace=f"gs://{bucket}", name=blob, type="file") - ] + "symlink": SymlinksDatasetFacet( + identifiers=[Identifier(namespace=f"gs://{bucket}", name=blob, type="file")] ), } blob = Path(blob).parent.as_posix() @@ -351,9 +362,7 @@ def get_openlineage_facets_on_complete(self, task_instance): run_facets = {} if self.job_id: run_facets = { - "externalQuery": external_query_run.ExternalQueryRunFacet( - externalQueryId=self.job_id, source="bigquery" - ), + "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery"), } return OperatorLineage(inputs=[input_dataset], outputs=output_datasets, run_facets=run_facets) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 22ecf93d4d01..4fbb042df0a4 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -746,8 +746,22 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will include final BQ job id.""" from pathlib import Path - from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import external_query_run, symlinks_dataset + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + except ImportError: + from openlineage.client.facet import ( + ExternalQueryRunFacet, + SymlinksDatasetFacet, + SymlinksDatasetFacetIdentifiers as Identifier, + ) + from 
openlineage.client.run import Dataset from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, @@ -781,12 +795,8 @@ def get_openlineage_facets_on_complete(self, task_instance): # If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name, # but we create a symlink to the full object path with wildcard. additional_facets = { - "symlink": symlinks_dataset.SymlinksDatasetFacet( - identifiers=[ - symlinks_dataset.Identifier( - namespace=f"gs://{self.bucket}", name=blob, type="file" - ) - ] + "symlink": SymlinksDatasetFacet( + identifiers=[Identifier(namespace=f"gs://{self.bucket}", name=blob, type="file")] ), } blob = Path(blob).parent.as_posix() @@ -814,9 +824,7 @@ def get_openlineage_facets_on_complete(self, task_instance): run_facets = {} if self.job_id: run_facets = { - "externalQuery": external_query_run.ExternalQueryRunFacet( - externalQueryId=self.job_id, source="bigquery" - ), + "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery"), } return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets) diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index b9730e70c86e..c788fb71d1ac 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -552,7 +552,13 @@ def get_openlineage_facets_on_complete(self, task_instance): """ from pathlib import Path - from openlineage.client.event_v2 import Dataset + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/openlineage/plugins/adapter.py b/airflow/providers/openlineage/plugins/adapter.py index bdd34e8abac4..398ef5a8f6c5 100644 --- a/airflow/providers/openlineage/plugins/adapter.py +++ b/airflow/providers/openlineage/plugins/adapter.py @@ -330,7 +330,7 @@ def dag_started( msg: str, nominal_start_time: str, nominal_end_time: str, - job_facets: dict[str, BaseFacet] | None = None, # Custom job facets + job_facets: dict[str, JobFacet] | None = None, # Custom job facets ): try: owner = [x.strip() for x in dag_run.dag.owner.split(",")] if dag_run.dag else None diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index 171f35a77588..de51a315ca61 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -42,7 +42,6 @@ AirflowMappedTaskRunFacet, AirflowRunFacet, AirflowStateRunFacet, - BaseFacet, UnknownOperatorAttributeRunFacet, UnknownOperatorInstance, ) @@ -363,7 +362,7 @@ def get_airflow_run_facet( task_instance: TaskInstance, task: BaseOperator, task_uuid: str, -) -> dict[str, BaseFacet]: +) -> dict[str, AirflowRunFacet]: return { "airflow": AirflowRunFacet( dag=DagInfo(dag), @@ -375,7 +374,7 @@ def get_airflow_run_facet( } -def get_airflow_job_facet(dag_run: DagRun) -> dict[str, BaseFacet]: +def get_airflow_job_facet(dag_run: DagRun) -> dict[str, AirflowJobFacet]: if not dag_run.dag: return {} return { @@ -387,7 +386,7 @@ def get_airflow_job_facet(dag_run: DagRun) -> dict[str, BaseFacet]: } -def get_airflow_state_run_facet(dag_run: DagRun) -> dict[str, BaseFacet]: +def get_airflow_state_run_facet(dag_run: DagRun) -> dict[str, 
AirflowStateRunFacet]: return { "airflowState": AirflowStateRunFacet( dagRunState=dag_run.get_state(), diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index 68fc87350b4d..28cb42092dbd 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -23,7 +23,7 @@ import socket import warnings from pathlib import Path -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence import paramiko @@ -201,7 +201,13 @@ def get_openlineage_facets_on_start(self): input: file:///path output: file://:/path. """ - from openlineage.client.event_v2 import Dataset + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/airflow/providers/snowflake/transfers/copy_into_snowflake.py b/airflow/providers/snowflake/transfers/copy_into_snowflake.py index 8624a22aae1e..661d98b3a7f1 100644 --- a/airflow/providers/snowflake/transfers/copy_into_snowflake.py +++ b/airflow/providers/snowflake/transfers/copy_into_snowflake.py @@ -19,7 +19,7 @@ from __future__ import annotations -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -228,8 +228,25 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement _on_complete because we rely on return value of a query.""" import re - from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import external_query_run, extraction_error_run, sql_job + if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet + from openlineage.client.generated.sql_job import SQLJobFacet + else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet + from openlineage.client.generated.sql_job import SQLJobFacet + except ImportError: + from openlineage.client.facet import ( + ExternalQueryRunFacet, + ExtractionError as Error, + ExtractionErrorRunFacet, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser @@ -252,11 +269,11 @@ def get_openlineage_facets_on_complete(self, task_instance): "Unable to extract Dataset namespace and name for the following files: `%s`.", extraction_error_files, ) - run_facets["extractionError"] = extraction_error_run.ExtractionErrorRunFacet( + run_facets["extractionError"] = ExtractionErrorRunFacet( totalTasks=len(query_results), failedTasks=len(extraction_error_files), errors=[ - extraction_error_run.Error( + Error( errorMessage="Unable to extract Dataset namespace and name.", stackTrace=None, task=file_uri, @@ -281,13 +298,13 @@ def get_openlineage_facets_on_complete(self, task_instance): query = SQLParser.normalize_sql(self._sql) query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query)) - run_facets["externalQuery"] = 
external_query_run.ExternalQueryRunFacet( + run_facets["externalQuery"] = ExternalQueryRunFacet( externalQueryId=self.hook.query_ids[0], source=snowflake_namespace ) return OperatorLineage( inputs=input_datasets, outputs=[Dataset(namespace=snowflake_namespace, name=dest_name)], - job_facets={"sql": sql_job.SQLJobFacet(query=query)}, + job_facets={"sql": SQLJobFacet(query=query)}, run_facets=run_facets, ) diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index d603d7c5c1aa..976791851b06 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -17,26 +17,33 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING from unittest import mock import pytest -try: +if TYPE_CHECKING: from openlineage.client.event_v2 import Dataset from openlineage.client.generated.external_query_run import ExternalQueryRunFacet from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields from openlineage.client.generated.sql_job import SQLJobFacet from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet -except ImportError: - from openlineage.client.facet import ( - ExternalQueryRunFacet, - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - SqlJobFacet as SQLJobFacet, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers as Identifier, - ) - from openlineage.client.run import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + except ImportError: + from openlineage.client.facet import ( + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SqlJobFacet as SQLJobFacet, + SymlinksDatasetFacet, + SymlinksDatasetFacetIdentifiers as Identifier, + ) + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py index 94e2b5b2f397..010d807286b4 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_sql.py +++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py @@ -17,11 +17,12 @@ # under the License. 
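The Snowflake COPY INTO changes above keep the same overall shape as the other operators in this patch: input and output datasets plus run and job facets packed into an OperatorLineage. A minimal sketch of that shape (not part of this patch; it assumes the V2 modules are importable, and the namespaces, table, and query below are invented for illustration):

from openlineage.client.event_v2 import Dataset
from openlineage.client.generated.external_query_run import ExternalQueryRunFacet
from openlineage.client.generated.sql_job import SQLJobFacet

from airflow.providers.openlineage.extractors import OperatorLineage

lineage = OperatorLineage(
    inputs=[Dataset(namespace="gs://example-bucket", name="dir/file.csv")],
    outputs=[Dataset(namespace="snowflake://example-account", name="db.schema.table")],
    run_facets={
        "externalQuery": ExternalQueryRunFacet(
            externalQueryId="query-id", source="snowflake://example-account"
        )
    },
    job_facets={"sql": SQLJobFacet(query="COPY INTO db.schema.table FROM @stage")},
)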
from __future__ import annotations +from typing import TYPE_CHECKING from unittest.mock import MagicMock, PropertyMock, call, patch import pytest -try: +if TYPE_CHECKING: from openlineage.client.event_v2 import Dataset from openlineage.client.generated.column_lineage_dataset import ( ColumnLineageDatasetFacet, @@ -30,16 +31,26 @@ ) from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields from openlineage.client.generated.sql_job import SQLJobFacet -except ImportError: - from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional as Fields, - ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet + except ImportError: + from openlineage.client.facet import ( + ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional as Fields, + ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import Dataset from airflow.models.connection import Connection from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook as OriginalRedshiftSQLHook diff --git a/tests/providers/amazon/aws/operators/test_s3.py b/tests/providers/amazon/aws/operators/test_s3.py index 639b8a87932d..d4d4f016dd09 100644 --- a/tests/providers/amazon/aws/operators/test_s3.py +++ b/tests/providers/amazon/aws/operators/test_s3.py @@ -23,13 +23,35 @@ import sys from io import BytesIO from tempfile import mkdtemp +from typing import TYPE_CHECKING from unittest import mock import boto3 import pytest from moto import mock_aws -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import lifecycle_state_change_dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + except ImportError: + from openlineage.client.facet import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + LifecycleStateChangeDatasetFacetPreviousIdentifier as PreviousIdentifier, + ) + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -767,9 +789,9 @@ def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys) namespace=f"s3://{bucket}", name="path/data.txt", facets={ - "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( - lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, - previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( + "lifecycleStateChange": 
LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.DROP.value, + previousIdentifier=PreviousIdentifier( namespace=f"s3://{bucket}", name="path/data.txt", ), @@ -793,9 +815,9 @@ def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook): namespace=f"s3://{bucket}", name="path/data1.txt", facets={ - "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( - lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, - previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.DROP.value, + previousIdentifier=PreviousIdentifier( namespace=f"s3://{bucket}", name="path/data1.txt", ), @@ -806,9 +828,9 @@ def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook): namespace=f"s3://{bucket}", name="path/data2.txt", facets={ - "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( - lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, - previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.DROP.value, + previousIdentifier=PreviousIdentifier( namespace=f"s3://{bucket}", name="path/data2.txt", ), diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index b08a4eff5e6d..e4b9ee16f084 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -16,11 +16,19 @@ # under the License. 
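The GCS and S3 delete paths above both attach the same delete marker to the dataset. A standalone sketch of that facet (not part of this patch; the bucket and object names are illustrative, and the V2 modules are assumed to be importable):

from openlineage.client.event_v2 import Dataset
from openlineage.client.generated.lifecycle_state_change_dataset import (
    LifecycleStateChange,
    LifecycleStateChangeDatasetFacet,
    PreviousIdentifier,
)

deleted_object = Dataset(
    namespace="s3://example-bucket",
    name="path/data.txt",
    facets={
        "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
            lifecycleStateChange=LifecycleStateChange.DROP.value,
            previousIdentifier=PreviousIdentifier(
                namespace="s3://example-bucket", name="path/data.txt"
            ),
        )
    },
)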
from __future__ import annotations +from typing import TYPE_CHECKING from unittest import mock import pytest from botocore.exceptions import ClientError -from openlineage.client.event_v2 import Dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index 9316347de57e..448c8e967e7c 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -17,11 +17,19 @@ from __future__ import annotations from datetime import datetime +from typing import TYPE_CHECKING from unittest import mock import pytest from botocore.exceptions import ClientError -from openlineage.client.event_v2 import Dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py index 4c9a2c3a6201..dd364208a523 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py @@ -18,11 +18,19 @@ from __future__ import annotations import copy +from typing import TYPE_CHECKING from unittest import mock import pytest from botocore.exceptions import ClientError -from openlineage.client.event_v2 import Dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook diff --git a/tests/providers/common/io/operators/test_file_transfer.py b/tests/providers/common/io/operators/test_file_transfer.py index b8b9f12d3cdb..de5e7a93fc6a 100644 --- a/tests/providers/common/io/operators/test_file_transfer.py +++ b/tests/providers/common/io/operators/test_file_transfer.py @@ -17,9 +17,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING from unittest import mock -from openlineage.client.event_v2 import Dataset +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from tests.test_utils.compat import ignore_provider_compatibility_error diff --git a/tests/providers/common/sql/operators/test_sql_execute.py b/tests/providers/common/sql/operators/test_sql_execute.py index c76140065ae1..1b51f0322dd1 100644 --- a/tests/providers/common/sql/operators/test_sql_execute.py +++ b/tests/providers/common/sql/operators/test_sql_execute.py @@ -17,13 +17,28 @@ # under the License. 
from __future__ import annotations -from typing import Any, NamedTuple, Sequence +from typing import TYPE_CHECKING, Any, NamedTuple, Sequence from unittest import mock from unittest.mock import MagicMock import pytest -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import schema_dataset, sql_job + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet + except ImportError: + from openlineage.client.facet import ( + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import Dataset from airflow.models import Connection from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler @@ -338,18 +353,18 @@ def get_db_hook(self): namespace=f"sqlscheme://host:{expected_port}", name="PUBLIC.popular_orders_day_of_week", facets={ - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), - schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), - schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + SchemaDatasetFacetFields(name="orders_placed", type="int4"), ] ) }, ) ] - assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} + assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)} assert lineage.run_facets["extractionError"].failedTasks == 1 diff --git a/tests/providers/dbt/cloud/utils/test_openlineage.py b/tests/providers/dbt/cloud/utils/test_openlineage.py index be1e578731ef..c42c6e55ab25 100644 --- a/tests/providers/dbt/cloud/utils/test_openlineage.py +++ b/tests/providers/dbt/cloud/utils/test_openlineage.py @@ -21,6 +21,8 @@ from unittest.mock import MagicMock, patch import pytest +from openlineage.common import __version__ +from packaging.version import parse from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook @@ -42,18 +44,12 @@ def json(self): def emit_event(event): - run_id = TASK_UUID - name = f"{DAG_ID}.{TASK_ID}" - run_obj = event.run.facets["parent"].run - job_obj = event.run.facets["parent"].job - if isinstance(run_obj, dict): - assert run_obj["runId"] == run_id + if parse(__version__) >= parse("1.15.0"): + assert event.run.facets["parent"].run.runId == TASK_UUID + assert event.run.facets["parent"].job.name == f"{DAG_ID}.{TASK_ID}" else: - assert run_obj.runId == run_id - if isinstance(job_obj, dict): - assert job_obj["name"] == name - else: - assert job_obj.name == name + assert event.run.facets["parent"].run["runId"] == TASK_UUID + assert event.run.facets["parent"].job["name"] == f"{DAG_ID}.{TASK_ID}" assert event.job.namespace == "default" assert event.job.name.startswith("SANDBOX.TEST_SCHEMA.test_project") diff --git a/tests/providers/ftp/operators/test_ftp.py b/tests/providers/ftp/operators/test_ftp.py index 12e1d4363e9a..e443484be931 100644 --- 
a/tests/providers/ftp/operators/test_ftp.py +++ b/tests/providers/ftp/operators/test_ftp.py @@ -18,10 +18,18 @@ from __future__ import annotations import socket +from typing import TYPE_CHECKING from unittest import mock import pytest -from openlineage.client.event_v2 import Dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.models import DAG, Connection from airflow.providers.ftp.operators.ftp import ( diff --git a/tests/providers/google/cloud/openlineage/test_mixins.py b/tests/providers/google/cloud/openlineage/test_mixins.py index 50e90d29a3c9..c2c803232ff6 100644 --- a/tests/providers/google/cloud/openlineage/test_mixins.py +++ b/tests/providers/google/cloud/openlineage/test_mixins.py @@ -17,16 +17,34 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest -from openlineage.client.facet import ( - ExternalQueryRunFacet, - OutputStatisticsOutputDatasetFacet, - SchemaDatasetFacet, - SchemaField, -) -from openlineage.client.run import Dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import InputDataset, OutputDataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.output_statistics_output_dataset import ( + OutputStatisticsOutputDatasetFacet, + ) + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields +else: + try: + from openlineage.client.event_v2 import InputDataset, OutputDataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.output_statistics_output_dataset import ( + OutputStatisticsOutputDatasetFacet, + ) + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + except ImportError: + from openlineage.client.facet import ( + ExternalQueryRunFacet, + OutputStatisticsOutputDatasetFacet, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + ) + from openlineage.client.run import InputDataset, OutputDataset from airflow.providers.google.cloud.openlineage.mixins import _BigQueryOpenLineageMixin from airflow.providers.google.cloud.openlineage.utils import ( @@ -89,27 +107,29 @@ def test_bq_job_information(self): "externalQuery": ExternalQueryRunFacet(externalQueryId="job_id", source="bigquery"), } assert lineage.inputs == [ - Dataset( + InputDataset( namespace="bigquery", name="airflow-openlineage.new_dataset.test_table", facets={ "schema": SchemaDatasetFacet( fields=[ - SchemaField("state", "STRING", "2-digit state code"), - SchemaField("gender", "STRING", "Sex (M=male or F=female)"), - SchemaField("year", "INTEGER", "4-digit year of birth"), - SchemaField("name", "STRING", "Given name of a person at birth"), - SchemaField("number", "INTEGER", "Number of occurrences of the name"), + SchemaDatasetFacetFields("state", "STRING", "2-digit state code"), + SchemaDatasetFacetFields("gender", "STRING", "Sex (M=male or F=female)"), + SchemaDatasetFacetFields("year", "INTEGER", "4-digit year of birth"), + SchemaDatasetFacetFields("name", "STRING", "Given name of a person at birth"), + SchemaDatasetFacetFields( + "number", "INTEGER", "Number of occurrences of the name" + ), ] ) }, ) ] assert lineage.outputs == [ - Dataset( + OutputDataset( namespace="bigquery", 
name="airflow-openlineage.new_dataset.output_table", - facets={ + outputFacets={ "outputStatistics": OutputStatisticsOutputDatasetFacet( rowCount=20, size=321, fileCount=None ) @@ -137,27 +157,29 @@ def test_bq_script_job_information(self): "externalQuery": ExternalQueryRunFacet(externalQueryId="job_id", source="bigquery"), } assert lineage.inputs == [ - Dataset( + InputDataset( namespace="bigquery", name="airflow-openlineage.new_dataset.test_table", facets={ "schema": SchemaDatasetFacet( fields=[ - SchemaField("state", "STRING", "2-digit state code"), - SchemaField("gender", "STRING", "Sex (M=male or F=female)"), - SchemaField("year", "INTEGER", "4-digit year of birth"), - SchemaField("name", "STRING", "Given name of a person at birth"), - SchemaField("number", "INTEGER", "Number of occurrences of the name"), + SchemaDatasetFacetFields("state", "STRING", "2-digit state code"), + SchemaDatasetFacetFields("gender", "STRING", "Sex (M=male or F=female)"), + SchemaDatasetFacetFields("year", "INTEGER", "4-digit year of birth"), + SchemaDatasetFacetFields("name", "STRING", "Given name of a person at birth"), + SchemaDatasetFacetFields( + "number", "INTEGER", "Number of occurrences of the name" + ), ] ) }, ) ] assert lineage.outputs == [ - Dataset( + OutputDataset( namespace="bigquery", name="airflow-openlineage.new_dataset.output_table", - facets={ + outputFacets={ "outputStatistics": OutputStatisticsOutputDatasetFacet( rowCount=20, size=321, fileCount=None ) @@ -168,23 +190,28 @@ def test_bq_script_job_information(self): def test_deduplicate_outputs(self): outputs = [ None, - Dataset( - name="d1", namespace="", facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(3, 4)} + OutputDataset( + name="d1", + namespace="", + outputFacets={"outputStatistics": OutputStatisticsOutputDatasetFacet(3, 4)}, ), - Dataset( + OutputDataset( name="d1", namespace="", - facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(3, 4), "t1": "t1"}, + outputFacets={"outputStatistics": OutputStatisticsOutputDatasetFacet(3, 4)}, + facets={"t1": "t1"}, ), - Dataset( + OutputDataset( name="d2", namespace="", - facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(6, 7), "t2": "t2"}, + outputFacets={"outputStatistics": OutputStatisticsOutputDatasetFacet(6, 7)}, + facets={"t2": "t2"}, ), - Dataset( + OutputDataset( name="d2", namespace="", - facets={"outputStatistics": OutputStatisticsOutputDatasetFacet(60, 70), "t20": "t20"}, + outputFacets={"outputStatistics": OutputStatisticsOutputDatasetFacet(60, 70)}, + facets={"t20": "t20"}, ), ] result = self.operator._deduplicate_outputs(outputs) diff --git a/tests/providers/google/cloud/openlineage/test_utils.py b/tests/providers/google/cloud/openlineage/test_utils.py index 972337f92aac..34f5f0b2c187 100644 --- a/tests/providers/google/cloud/openlineage/test_utils.py +++ b/tests/providers/google/cloud/openlineage/test_utils.py @@ -17,18 +17,41 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest from google.cloud.bigquery.table import Table -from openlineage.client.event_v2 import Dataset, InputDataset, OutputDataset -from openlineage.client.facet_v2 import ( - column_lineage_dataset, - documentation_dataset, - external_query_run, - output_statistics_output_dataset, - schema_dataset, -) + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) 
+ from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + except ImportError: + from openlineage.client.facet import ( + ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional as Fields, + ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, + DocumentationDatasetFacet, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + ) + from openlineage.client.run import Dataset from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, @@ -75,15 +98,13 @@ def _properties(self): def test_get_facets_from_bq_table(): expected_facets = { - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( - name="field1", type="STRING", description="field1 description" - ), - schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ), - "documentation": documentation_dataset.DocumentationDatasetFacet(description="Table description."), + "documentation": DocumentationDatasetFacet(description="Table description."), } result = get_facets_from_bq_table(TEST_TABLE) assert result == expected_facets @@ -91,8 +112,8 @@ def test_get_facets_from_bq_table(): def test_get_facets_from_empty_bq_table(): expected_facets = { - "schema": schema_dataset.SchemaDatasetFacet(fields=[]), - "documentation": documentation_dataset.DocumentationDatasetFacet(description=""), + "schema": SchemaDatasetFacet(fields=[]), + "documentation": DocumentationDatasetFacet(description=""), } result = get_facets_from_bq_table(TEST_EMPTY_TABLE) assert result == expected_facets @@ -104,16 +125,16 @@ def test_get_identity_column_lineage_facet_multiple_input_datasets(): Dataset(namespace="gs://first_bucket", name="dir1"), Dataset(namespace="gs://second_bucket", name="dir2"), ] - expected_facet = column_lineage_dataset.ColumnLineageDatasetFacet( + expected_facet = ColumnLineageDatasetFacet( fields={ - "field1": column_lineage_dataset.Fields( + "field1": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace="gs://first_bucket", name="dir1", field="field1", ), - column_lineage_dataset.InputField( + InputField( namespace="gs://second_bucket", name="dir2", field="field1", @@ -122,14 +143,14 @@ def test_get_identity_column_lineage_facet_multiple_input_datasets(): transformationType="IDENTITY", transformationDescription="identical", ), - "field2": column_lineage_dataset.Fields( + "field2": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace="gs://first_bucket", name="dir1", field="field2", ), - column_lineage_dataset.InputField( + InputField( namespace="gs://second_bucket", name="dir2", field="field2", @@ -150,10 +171,8 @@ def test_get_identity_column_lineage_facet_no_field_names(): Dataset(namespace="gs://first_bucket", name="dir1"), Dataset(namespace="gs://second_bucket", 
name="dir2"), ] - expected_facet = column_lineage_dataset.ColumnLineageDatasetFacet(fields={}) - result = get_identity_column_lineage_facet( - field_names=field_names, input_datasets=input_datasets - ) + expected_facet = ColumnLineageDatasetFacet(fields={}) + result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) assert result == expected_facet diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 13e550b16ff3..d562455c63b2 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -19,6 +19,7 @@ import json from contextlib import suppress +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import ANY, MagicMock @@ -26,8 +27,25 @@ import pytest from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter from google.cloud.exceptions import Conflict -from openlineage.client.event_v2 import InputDataset -from openlineage.client.facet_v2 import error_message_run, external_query_run, sql_job + +if TYPE_CHECKING: + from openlineage.client.event_v2 import InputDataset + from openlineage.client.generated.error_message_run import ErrorMessageRunFacet + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.sql_job import SQLJobFacet +else: + try: + from openlineage.client.event_v2 import InputDataset + from openlineage.client.generated.error_message_run import ErrorMessageRunFacet + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.sql_job import SQLJobFacet + except ImportError: + from openlineage.client.facet import ( + ErrorMessageRunFacet, + ExternalQueryRunFacet, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import InputDataset from airflow.exceptions import ( AirflowException, @@ -1850,11 +1868,9 @@ def test_execute_openlineage_events(self, mock_hook): assert lineage.run_facets == { "bigQuery_job": mock.ANY, "bigQueryJob": mock.ANY, - "externalQuery": external_query_run.ExternalQueryRunFacet( - externalQueryId=mock.ANY, source="bigquery" - ), + "externalQuery": ExternalQueryRunFacet(externalQueryId=mock.ANY, source="bigquery"), } - assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query="SELECT * FROM test_table")} + assert lineage.job_facets == {"sql": SQLJobFacet(query="SELECT * FROM test_table")} @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_execute_fails_openlineage_events(self, mock_hook): @@ -1881,7 +1897,7 @@ def test_execute_fails_openlineage_events(self, mock_hook): operator.execute(MagicMock()) lineage = operator.get_openlineage_facets_on_complete(None) - assert isinstance(lineage.run_facets["errorMessage"], error_message_run.ErrorMessageRunFacet) + assert isinstance(lineage.run_facets["errorMessage"], ErrorMessageRunFacet) @pytest.mark.db_test @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index 6299b043598f..5492882b6806 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -19,11 +19,33 @@ from datetime import datetime, timedelta, timezone from pathlib import Path +from typing import TYPE_CHECKING from unittest import mock import pytest -from 
openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import lifecycle_state_change_dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + except ImportError: + from openlineage.client.facet import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + LifecycleStateChangeDatasetFacetPreviousIdentifier as PreviousIdentifier, + ) + from openlineage.client.run import Dataset from airflow.providers.google.cloud.operators.gcs import ( GCSBucketCreateAclEntryOperator, @@ -202,9 +224,9 @@ def test_get_openlineage_facets_on_start(self, objects, prefix, inputs): namespace=bucket_url, name=name, facets={ - "lifecycleStateChange": lifecycle_state_change_dataset.LifecycleStateChangeDatasetFacet( - lifecycleStateChange=lifecycle_state_change_dataset.LifecycleStateChange.DROP.value, - previousIdentifier=lifecycle_state_change_dataset.PreviousIdentifier( + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.DROP.value, + previousIdentifier=PreviousIdentifier( namespace=bucket_url, name=name, ), diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py index 8e77d505c3ae..0af142c56296 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py @@ -17,20 +17,48 @@ # under the License. 
from __future__ import annotations +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock import pytest from google.cloud.bigquery.retry import DEFAULT_RETRY from google.cloud.bigquery.table import Table -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import ( - column_lineage_dataset, - documentation_dataset, - external_query_run, - schema_dataset, - symlinks_dataset, -) + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + except ImportError: + from openlineage.client.facet import ( + ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional as Fields, + ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, + DocumentationDatasetFacet, + ExternalQueryRunFacet, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + ) + from openlineage.client.run import Dataset from airflow.exceptions import TaskDeferred from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator @@ -263,17 +291,13 @@ def test_get_openlineage_facets_on_complete_bq_dataset(self, mock_hook): source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" expected_input_dataset_facets = { - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( - name="field1", type="STRING", description="field1 description" - ), - schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ), - "documentation": documentation_dataset.DocumentationDatasetFacet( - description="Table description." 
- ), + "documentation": DocumentationDatasetFacet(description="Table description."), } mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) @@ -300,8 +324,8 @@ def test_get_openlineage_facets_on_complete_bq_dataset_empty_table(self, mock_ho source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" expected_input_dataset_facets = { - "schema": schema_dataset.SchemaDatasetFacet(fields=[]), - "documentation": documentation_dataset.DocumentationDatasetFacet(description=""), + "schema": SchemaDatasetFacet(fields=[]), + "documentation": DocumentationDatasetFacet(description=""), } mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) @@ -331,13 +355,13 @@ def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mo bq_namespace = "bigquery" expected_input_facets = { - "schema": schema_dataset.SchemaDatasetFacet(fields=[]), - "documentation": documentation_dataset.DocumentationDatasetFacet(description=""), + "schema": SchemaDatasetFacet(fields=[]), + "documentation": DocumentationDatasetFacet(description=""), } expected_output_facets = { - "schema": schema_dataset.SchemaDatasetFacet(fields=[]), - "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet(fields={}), + "schema": SchemaDatasetFacet(fields=[]), + "columnLineage": ColumnLineageDatasetFacet(fields={}), } mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) @@ -365,9 +389,7 @@ def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mo facets=expected_output_facets, ) assert lineage.run_facets == { - "externalQuery": external_query_run.ExternalQueryRunFacet( - externalQueryId=real_job_id, source=bq_namespace - ) + "externalQuery": ExternalQueryRunFacet(externalQueryId=real_job_id, source=bq_namespace) } assert lineage.job_facets == {} @@ -378,37 +400,33 @@ def test_get_openlineage_facets_on_complete_gcs_wildcard_full_table(self, mock_h real_job_id = "123456_hash" bq_namespace = "bigquery" - schema_facet = schema_dataset.SchemaDatasetFacet( + schema_facet = SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( - name="field1", type="STRING", description="field1 description" - ), - schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ) expected_input_facets = { "schema": schema_facet, - "documentation": documentation_dataset.DocumentationDatasetFacet( - description="Table description." 
- ), + "documentation": DocumentationDatasetFacet(description="Table description."), } expected_output_facets = { "schema": schema_facet, - "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( + "columnLineage": ColumnLineageDatasetFacet( fields={ - "field1": column_lineage_dataset.Fields( + "field1": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace=bq_namespace, name=source_project_dataset_table, field="field1" ) ], transformationType="IDENTITY", transformationDescription="identical", ), - "field2": column_lineage_dataset.Fields( + "field2": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace=bq_namespace, name=source_project_dataset_table, field="field2" ) ], @@ -417,9 +435,9 @@ def test_get_openlineage_facets_on_complete_gcs_wildcard_full_table(self, mock_h ), } ), - "symlink": symlinks_dataset.SymlinksDatasetFacet( + "symlink": SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace=f"gs://{TEST_BUCKET}", name=f"{TEST_FOLDER}/{TEST_OBJECT_WILDCARD}", type="file", @@ -451,8 +469,6 @@ def test_get_openlineage_facets_on_complete_gcs_wildcard_full_table(self, mock_h namespace=f"gs://{TEST_BUCKET}", name=TEST_FOLDER, facets=expected_output_facets ) assert lineage.run_facets == { - "externalQuery": external_query_run.ExternalQueryRunFacet( - externalQueryId=real_job_id, source=bq_namespace - ) + "externalQuery": ExternalQueryRunFacet(externalQueryId=real_job_id, source=bq_namespace) } assert lineage.job_facets == {} diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py index 99ad21dbb1ee..d39efde3b10e 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py @@ -18,20 +18,48 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock, call import pytest from google.cloud.bigquery import DEFAULT_RETRY, Table from google.cloud.exceptions import Conflict -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import ( - column_lineage_dataset, - documentation_dataset, - external_query_run, - schema_dataset, - symlinks_dataset, -) + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + except ImportError: + from openlineage.client.facet import ( + 
ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional as Fields, + ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, + DocumentationDatasetFacet, + ExternalQueryRunFacet, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + ) + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG @@ -1253,9 +1281,9 @@ def test_get_openlineage_facets_on_complete_gcs_dataset_name( destination_project_dataset_table=TEST_EXPLICIT_DEST, ) - expected_symlink = symlinks_dataset.SymlinksDatasetFacet( + expected_symlink = SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace=f"gs://{TEST_BUCKET}", name=source_object, type="file", @@ -1294,9 +1322,9 @@ def test_get_openlineage_facets_on_complete_gcs_multiple_uris(self, hook): assert len(lineage.inputs) == 4 assert lineage.inputs[0].name == TEST_OBJECT_NO_WILDCARD assert lineage.inputs[1].name == "/" - assert lineage.inputs[1].facets.get("symlink") == symlinks_dataset.SymlinksDatasetFacet( + assert lineage.inputs[1].facets.get("symlink") == SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_WILDCARD, type="file", @@ -1305,9 +1333,9 @@ def test_get_openlineage_facets_on_complete_gcs_multiple_uris(self, hook): ) assert lineage.inputs[2].name == f"{TEST_FOLDER}1/{TEST_OBJECT_NO_WILDCARD}" assert lineage.inputs[3].name == f"{TEST_FOLDER}2" - assert lineage.inputs[3].facets.get("symlink") == symlinks_dataset.SymlinksDatasetFacet( + assert lineage.inputs[3].facets.get("symlink") == SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace=f"gs://{TEST_BUCKET}", name=f"{TEST_FOLDER}2/{TEST_OBJECT_WILDCARD}", type="file", @@ -1322,29 +1350,27 @@ def test_get_openlineage_facets_on_complete_bq_dataset(self, hook): hook.return_value.get_client.return_value.get_table.return_value = TEST_TABLE expected_output_dataset_facets = { - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( - name="field1", type="STRING", description="field1 description" - ), - schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ), - "documentation": documentation_dataset.DocumentationDatasetFacet(description="Test Description"), - "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( + "documentation": DocumentationDatasetFacet(description="Test Description"), + "columnLineage": ColumnLineageDatasetFacet( fields={ - "field1": column_lineage_dataset.Fields( + "field1": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field1" ) ], transformationType="IDENTITY", transformationDescription="identical", ), - "field2": column_lineage_dataset.Fields( + "field2": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field2" ) ], @@ -1380,37 +1406,31 @@ def test_get_openlineage_facets_on_complete_bq_dataset_multiple_gcs_uris(self, h hook.return_value.get_client.return_value.get_table.return_value = TEST_TABLE expected_output_dataset_facets = { - "schema": schema_dataset.SchemaDatasetFacet( + "schema": 
SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( - name="field1", type="STRING", description="field1 description" - ), - schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ), - "documentation": documentation_dataset.DocumentationDatasetFacet(description="Test Description"), - "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( + "documentation": DocumentationDatasetFacet(description="Test Description"), + "columnLineage": ColumnLineageDatasetFacet( fields={ - "field1": column_lineage_dataset.Fields( + "field1": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field1" ), - column_lineage_dataset.InputField( - namespace=f"gs://{TEST_BUCKET}", name="/", field="field1" - ), + InputField(namespace=f"gs://{TEST_BUCKET}", name="/", field="field1"), ], transformationType="IDENTITY", transformationDescription="identical", ), - "field2": column_lineage_dataset.Fields( + "field2": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field2" ), - column_lineage_dataset.InputField( - namespace=f"gs://{TEST_BUCKET}", name="/", field="field2" - ), + InputField(namespace=f"gs://{TEST_BUCKET}", name="/", field="field2"), ], transformationType="IDENTITY", transformationDescription="identical", @@ -1444,9 +1464,9 @@ def test_get_openlineage_facets_on_complete_empty_table(self, hook): hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE expected_output_dataset_facets = { - "schema": schema_dataset.SchemaDatasetFacet(fields=[]), - "documentation": documentation_dataset.DocumentationDatasetFacet(description=""), - "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet(fields={}), + "schema": SchemaDatasetFacet(fields=[]), + "documentation": DocumentationDatasetFacet(description=""), + "columnLineage": ColumnLineageDatasetFacet(fields={}), } operator = GCSToBigQueryOperator( @@ -1470,16 +1490,16 @@ def test_get_openlineage_facets_on_complete_empty_table(self, hook): assert lineage.inputs[0] == Dataset( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, - facets={"schema": schema_dataset.SchemaDatasetFacet(fields=[])}, + facets={"schema": SchemaDatasetFacet(fields=[])}, ) assert lineage.inputs[1] == Dataset( namespace=f"gs://{TEST_BUCKET}", name="/", facets={ - "schema": schema_dataset.SchemaDatasetFacet(fields=[]), - "symlink": symlinks_dataset.SymlinksDatasetFacet( + "schema": SchemaDatasetFacet(fields=[]), + "symlink": SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_WILDCARD, type="file", @@ -1496,20 +1516,18 @@ def test_get_openlineage_facets_on_complete_full_table_multiple_gcs_uris(self, h hook.return_value.get_client.return_value.get_table.return_value = TEST_TABLE hook.return_value.generate_job_id.return_value = REAL_JOB_ID - schema_facet = schema_dataset.SchemaDatasetFacet( + schema_facet = SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields( - name="field1", type="STRING", description="field1 description" - ), - schema_dataset.SchemaDatasetFacetFields(name="field2", type="INTEGER"), + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 
description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), ] ) expected_input_wildcard_dataset_facets = { "schema": schema_facet, - "symlink": symlinks_dataset.SymlinksDatasetFacet( + "symlink": SymlinksDatasetFacet( identifiers=[ - symlinks_dataset.Identifier( + Identifier( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_WILDCARD, type="file", @@ -1521,29 +1539,25 @@ def test_get_openlineage_facets_on_complete_full_table_multiple_gcs_uris(self, h expected_output_dataset_facets = { "schema": schema_facet, - "documentation": documentation_dataset.DocumentationDatasetFacet(description="Test Description"), - "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( + "documentation": DocumentationDatasetFacet(description="Test Description"), + "columnLineage": ColumnLineageDatasetFacet( fields={ - "field1": column_lineage_dataset.Fields( + "field1": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field1" ), - column_lineage_dataset.InputField( - namespace=f"gs://{TEST_BUCKET}", name="/", field="field1" - ), + InputField(namespace=f"gs://{TEST_BUCKET}", name="/", field="field1"), ], transformationType="IDENTITY", transformationDescription="identical", ), - "field2": column_lineage_dataset.Fields( + "field2": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, field="field2" ), - column_lineage_dataset.InputField( - namespace=f"gs://{TEST_BUCKET}", name="/", field="field2" - ), + InputField(namespace=f"gs://{TEST_BUCKET}", name="/", field="field2"), ], transformationType="IDENTITY", transformationDescription="identical", @@ -1578,9 +1592,7 @@ def test_get_openlineage_facets_on_complete_full_table_multiple_gcs_uris(self, h namespace=f"gs://{TEST_BUCKET}", name="/", facets=expected_input_wildcard_dataset_facets ) assert lineage.run_facets == { - "externalQuery": external_query_run.ExternalQueryRunFacet( - externalQueryId=REAL_JOB_ID, source="bigquery" - ) + "externalQuery": ExternalQueryRunFacet(externalQueryId=REAL_JOB_ID, source="bigquery") } assert lineage.job_facets == {} diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py index f8fda2eb0a49..7fc33ec42cd6 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py @@ -18,10 +18,18 @@ from __future__ import annotations from datetime import datetime +from typing import TYPE_CHECKING from unittest import mock import pytest -from openlineage.client.event_v2 import Dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.transfers.gcs_to_gcs import WILDCARD, GCSToGCSOperator diff --git a/tests/providers/mysql/operators/test_mysql.py b/tests/providers/mysql/operators/test_mysql.py index a3715decfb2e..9e66c05e80c7 100644 --- a/tests/providers/mysql/operators/test_mysql.py +++ b/tests/providers/mysql/operators/test_mysql.py @@ -19,11 +19,27 @@ import os from contextlib import closing +from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest -from openlineage.client.event_v2 import Dataset -from 
openlineage.client.facet_v2 import schema_dataset, sql_job + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet + except ImportError: + from openlineage.client.facet import ( + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import Dataset from airflow.models.connection import Connection from airflow.models.dag import DAG @@ -167,17 +183,17 @@ class MySqlHookForTests(MySqlHook): namespace=f"mysql://host:{connection_port or 3306}", name="PUBLIC.popular_orders_day_of_week", facets={ - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), - schema_dataset.SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), - schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + SchemaDatasetFacetFields(name="orders_placed", type="int4"), ] ) }, ) ] - assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} + assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)} assert lineage.run_facets["extractionError"].failedTasks == 1 diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index f8c2c1d21e56..b98b7570e8ac 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -21,11 +21,19 @@ import os import socket from base64 import b64encode +from typing import TYPE_CHECKING from unittest import mock import paramiko import pytest -from openlineage.client.event_v2 import Dataset + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset +else: + try: + from openlineage.client.event_v2 import Dataset + except ImportError: + from openlineage.client.run import Dataset from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import DAG, Connection diff --git a/tests/providers/snowflake/operators/test_snowflake_sql.py b/tests/providers/snowflake/operators/test_snowflake_sql.py index 3a3562002c57..c9380bf17073 100644 --- a/tests/providers/snowflake/operators/test_snowflake_sql.py +++ b/tests/providers/snowflake/operators/test_snowflake_sql.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock, patch @@ -36,8 +37,31 @@ def Row(*args, **kwargs): return MagicMock() -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import column_lineage_dataset, sql_job +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.sql_job import SQLJobFacet +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.sql_job import SQLJobFacet + except ImportError: + from openlineage.client.facet import ( + ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional as Fields, + ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import Dataset from airflow.models.connection import Connection from airflow.providers.common.sql.hooks.sql import fetch_all_handler @@ -247,11 +271,11 @@ def get_db_hook(self): namespace="snowflake://test_account.us-east.aws", name=f"{DB_NAME}.{DB_SCHEMA_NAME}.TEST_TABLE", facets={ - "columnLineage": column_lineage_dataset.ColumnLineageDatasetFacet( + "columnLineage": ColumnLineageDatasetFacet( fields={ - "additional_constant": column_lineage_dataset.Fields( + "additional_constant": Fields( inputFields=[ - column_lineage_dataset.InputField( + InputField( namespace="snowflake://test_account.us-east.aws", name="DATABASE.PUBLIC.little_table", field="additional_constant", @@ -266,6 +290,6 @@ def get_db_hook(self): ) ] - assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} + assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)} assert lineage.run_facets["extractionError"].failedTasks == 1 diff --git a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py index 17b0d1297752..3437d85e8d7b 100644 --- a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py +++ b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py @@ -16,12 +16,31 @@ # under the License. 
from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable from unittest import mock import pytest -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import external_query_run, extraction_error_run, sql_job + +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet + from openlineage.client.generated.sql_job import SQLJobFacet +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet + from openlineage.client.generated.sql_job import SQLJobFacet + except ImportError: + from openlineage.client.facet import ( + ExternalQueryRunFacet, + ExtractionError as Error, + ExtractionErrorRunFacet, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import Dataset + from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import DatabaseInfo @@ -119,11 +138,11 @@ def test_get_openlineage_facets_on_complete(self, mock_hook): inputs=expected_inputs, outputs=expected_outputs, run_facets={ - "externalQuery": external_query_run.ExternalQueryRunFacet( + "externalQuery": ExternalQueryRunFacet( externalQueryId="query_id_123", source="snowflake_scheme://authority" ) }, - job_facets={"sql": sql_job.SQLJobFacet(query=expected_sql)}, + job_facets={"sql": SQLJobFacet(query=expected_sql)}, ) @pytest.mark.parametrize("rows", (None, [])) @@ -155,11 +174,11 @@ def test_get_openlineage_facets_on_complete_with_empty_inputs(self, mock_hook, r inputs=[], outputs=expected_outputs, run_facets={ - "externalQuery": external_query_run.ExternalQueryRunFacet( + "externalQuery": ExternalQueryRunFacet( externalQueryId="query_id_123", source="snowflake_scheme://authority" ) }, - job_facets={"sql": sql_job.SQLJobFacet(query=expected_sql)}, + job_facets={"sql": SQLJobFacet(query=expected_sql)}, ) @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook") @@ -185,17 +204,17 @@ def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hoo ] expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV""" expected_run_facets = { - "extractionError": extraction_error_run.ExtractionErrorRunFacet( + "extractionError": ExtractionErrorRunFacet( totalTasks=4, failedTasks=2, errors=[ - extraction_error_run.Error( + Error( errorMessage="Unable to extract Dataset namespace and name.", stackTrace=None, task="azure://my_account.another_weird-url.net/con/file.csv", taskNumber=None, ), - extraction_error_run.Error( + Error( errorMessage="Unable to extract Dataset namespace and name.", stackTrace=None, task="azure://my_account.weird-url.net/azure_container/dir3/file.csv", @@ -203,7 +222,7 @@ def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hoo ), ], ), - "externalQuery": external_query_run.ExternalQueryRunFacet( + "externalQuery": ExternalQueryRunFacet( externalQueryId="query_id_123", source="snowflake_scheme://authority" ), } @@ -222,5 +241,5 @@ def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hoo inputs=expected_inputs, outputs=expected_outputs, run_facets=expected_run_facets, - job_facets={"sql": 
sql_job.SQLJobFacet(query=expected_sql)}, + job_facets={"sql": SQLJobFacet(query=expected_sql)}, ) diff --git a/tests/providers/trino/operators/test_trino.py b/tests/providers/trino/operators/test_trino.py index 126b3db6335c..d3b219b58675 100644 --- a/tests/providers/trino/operators/test_trino.py +++ b/tests/providers/trino/operators/test_trino.py @@ -17,10 +17,25 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING from unittest import mock -from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import schema_dataset, sql_job +if TYPE_CHECKING: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet +else: + try: + from openlineage.client.event_v2 import Dataset + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet + except ImportError: + from openlineage.client.facet import ( + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SqlJobFacet as SQLJobFacet, + ) + from openlineage.client.run import Dataset from airflow.models.connection import Connection from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator @@ -91,14 +106,14 @@ def get_first(self, *_): namespace="trino://trino:8080", name=f"{DB_NAME}.{DB_SCHEMA_NAME}.customer", facets={ - "schema": schema_dataset.SchemaDatasetFacet( + "schema": SchemaDatasetFacet( fields=[ - schema_dataset.SchemaDatasetFacetFields(name="custkey", type="bigint"), - schema_dataset.SchemaDatasetFacetFields(name="name", type="varchar(25)"), - schema_dataset.SchemaDatasetFacetFields(name="address", type="varchar(40)"), - schema_dataset.SchemaDatasetFacetFields(name="nationkey", type="bigint"), - schema_dataset.SchemaDatasetFacetFields(name="phone", type="varchar(15)"), - schema_dataset.SchemaDatasetFacetFields(name="acctbal", type="double"), + SchemaDatasetFacetFields(name="custkey", type="bigint"), + SchemaDatasetFacetFields(name="name", type="varchar(25)"), + SchemaDatasetFacetFields(name="address", type="varchar(40)"), + SchemaDatasetFacetFields(name="nationkey", type="bigint"), + SchemaDatasetFacetFields(name="phone", type="varchar(15)"), + SchemaDatasetFacetFields(name="acctbal", type="double"), ] ) }, @@ -107,4 +122,4 @@ def get_first(self, *_): assert len(lineage.outputs) == 0 - assert lineage.job_facets == {"sql": sql_job.SQLJobFacet(query=sql)} + assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)} From 7827eefe5ed624a998431ea6faf810e303a9e472 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Wed, 17 Jul 2024 10:30:41 +0200 Subject: [PATCH 6/7] Move V2 facets code imports to `common.compat` provider. 
Signed-off-by: Jakub Dardzinski --- .../providers/amazon/aws/operators/athena.py | 69 ++------ airflow/providers/amazon/aws/operators/s3.py | 47 +----- .../amazon/aws/operators/sagemaker.py | 9 +- airflow/providers/amazon/provider.yaml | 1 + airflow/providers/common/compat/__init__.py | 2 +- .../common/compat/openlineage/__init__.py | 16 ++ .../common/compat/openlineage/facet.py | 158 ++++++++++++++++++ .../common/io/operators/file_transfer.py | 9 +- airflow/providers/ftp/operators/ftp.py | 11 +- .../google/cloud/openlineage/mixins.py | 82 +++------ .../google/cloud/openlineage/utils.py | 40 ++--- .../providers/google/cloud/operators/gcs.py | 45 +---- .../cloud/transfers/azure_blob_to_gcs.py | 3 +- .../google/cloud/transfers/bigquery_to_gcs.py | 22 +-- .../google/cloud/transfers/gcs_to_bigquery.py | 23 +-- .../google/cloud/transfers/gcs_to_gcs.py | 9 +- airflow/providers/google/provider.yaml | 1 + .../providers/openlineage/plugins/adapter.py | 2 +- airflow/providers/sftp/operators/sftp.py | 11 +- .../providers/snowflake/hooks/snowflake.py | 9 +- airflow/providers/snowflake/provider.yaml | 1 + .../transfers/copy_into_snowflake.py | 29 +--- dev/breeze/tests/test_packages.py | 3 + dev/breeze/tests/test_selective_checks.py | 22 +-- .../guides/developer.rst | 16 +- generated/provider_dependencies.json | 7 + hatch_build.py | 1 + tests/dags/test_openlineage_execution.py | 3 +- .../amazon/aws/operators/test_athena.py | 33 +--- .../amazon/aws/operators/test_redshift_sql.py | 157 ++++++++--------- .../providers/amazon/aws/operators/test_s3.py | 30 +--- .../operators/test_sagemaker_processing.py | 10 +- .../aws/operators/test_sagemaker_training.py | 10 +- .../aws/operators/test_sagemaker_transform.py | 10 +- .../common/compat/openlineage/__init__.py | 16 ++ .../common/compat/openlineage/test_facet.py | 22 +++ .../common/io/operators/test_file_transfer.py | 10 +- .../common/sql/operators/test_sql_execute.py | 25 +-- .../dbt/cloud/utils/test_openlineage.py | 1 + tests/providers/ftp/operators/test_ftp.py | 10 +- .../google/cloud/openlineage/test_mixins.py | 33 +--- .../google/cloud/openlineage/test_utils.py | 40 +---- .../google/cloud/operators/test_bigquery.py | 26 +-- .../google/cloud/operators/test_gcs.py | 30 +--- .../cloud/transfers/test_azure_blob_to_gcs.py | 2 +- .../cloud/transfers/test_bigquery_to_gcs.py | 48 ++---- .../cloud/transfers/test_gcs_to_bigquery.py | 48 ++---- .../google/cloud/transfers/test_gcs_to_gcs.py | 10 +- tests/providers/mysql/operators/test_mysql.py | 24 +-- .../openlineage/plugins/test_adapter.py | 3 + tests/providers/openlineage/test_sqlparser.py | 6 +- tests/providers/sftp/operators/test_sftp.py | 10 +- .../snowflake/operators/test_snowflake_sql.py | 34 +--- .../transfers/test_copy_into_snowflake.py | 30 +--- tests/providers/trino/operators/test_trino.py | 24 +-- 55 files changed, 527 insertions(+), 826 deletions(-) create mode 100644 airflow/providers/common/compat/openlineage/__init__.py create mode 100644 airflow/providers/common/compat/openlineage/facet.py create mode 100644 tests/providers/common/compat/openlineage/__init__.py create mode 100644 tests/providers/common/compat/openlineage/test_facet.py diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 255819b6bfc3..d48ac751d423 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -30,9 +30,7 @@ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: 
- from openlineage.client.event_v2 import Dataset - from openlineage.client.facet_v2 import BaseFacet, DatasetFacet - + from airflow.providers.common.compat.openlineage.facet import BaseFacet, Dataset, DatasetFacet from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.utils.context import Context @@ -217,38 +215,19 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: path where the results are saved (user's prefix + some UUID), we are creating a dataset with the user-provided path only. This should make it easier to match this dataset across different processes. """ - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.extraction_error_run import ( - Error, - ExtractionErrorRunFacet, - ) - from openlineage.client.generated.sql_job import SQLJobFacet as SqlJobFacet - else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.extraction_error_run import ( - Error, - ExtractionErrorRunFacet, - ) - from openlineage.client.generated.sql_job import SQLJobFacet as SqlJobFacet - except ImportError: - from openlineage.client.facet import ( - ExternalQueryRunFacet, - ExtractionError as Error, - ExtractionErrorRunFacet, - SqlJobFacet, - ) - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + Error, + ExternalQueryRunFacet, + ExtractionErrorRunFacet, + SQLJobFacet, + ) from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser sql_parser = SQLParser(dialect="generic") - job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=sql_parser.normalize_sql(self.query))} + job_facets: dict[str, BaseFacet] = {"sql": SQLJobFacet(query=sql_parser.normalize_sql(self.query))} parse_result = sql_parser.parse(sql=self.query) if not parse_result: @@ -302,29 +281,13 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: return OperatorLineage(job_facets=job_facets, run_facets=run_facets, inputs=inputs, outputs=outputs) def get_openlineage_dataset(self, database, table) -> Dataset | None: - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.schema_dataset import ( - SchemaDatasetFacet, - SchemaDatasetFacetFields, - ) - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.schema_dataset import ( - SchemaDatasetFacet, - SchemaDatasetFacetFields, - ) - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - except ImportError: - from openlineage.client.facet import ( - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers as Identifier, - ) - from openlineage.client.run import Dataset + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + Identifier, + SchemaDatasetFacet, + SchemaDatasetFacetFields, + SymlinksDatasetFacet, + ) client = self.hook.get_conn() try: diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py index f9c4b8808fe2..669a6ad25aff 100644 --- 
a/airflow/providers/amazon/aws/operators/s3.py +++ b/airflow/providers/amazon/aws/operators/s3.py @@ -324,11 +324,7 @@ def execute(self, context: Context): ) def get_openlineage_facets_on_start(self): - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key( @@ -442,11 +438,7 @@ def execute(self, context: Context): s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy) def get_openlineage_facets_on_start(self): - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage bucket, key = S3Hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key") @@ -552,29 +544,12 @@ def execute(self, context: Context): def get_openlineage_facets_on_complete(self, task_instance): """Implement _on_complete because object keys are resolved in execute().""" - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.lifecycle_state_change_dataset import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - PreviousIdentifier, - ) - else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.lifecycle_state_change_dataset import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - PreviousIdentifier, - ) - except ImportError: - from openlineage.client.facet import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - LifecycleStateChangeDatasetFacetPreviousIdentifier as PreviousIdentifier, - ) - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) from airflow.providers.openlineage.extractors import OperatorLineage if not self._keys: @@ -747,11 +722,7 @@ def execute(self, context: Context): self.log.info("Upload successful") def get_openlineage_facets_on_start(self): - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key( diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index 5e7a1dfbb641..82d208816de9 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -46,8 +46,7 @@ from airflow.utils.json import AirflowJsonEncoder if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors.base import OperatorLineage from airflow.utils.context import Context @@ -208,11 +207,7 @@ def hook(self): @staticmethod def path_to_s3_dataset(path) -> Dataset: - if not TYPE_CHECKING: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset + 
from airflow.providers.common.compat.openlineage.facet import Dataset path = path.replace("s3://", "") split_path = path.split("/") diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 309abcc23ad2..3aa3f4005959 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -89,6 +89,7 @@ versions: dependencies: - apache-airflow>=2.7.0 + - apache-airflow-providers-common-compat>=1.1.0 - apache-airflow-providers-common-sql>=1.3.1 - apache-airflow-providers-http - apache-airflow-providers-common-compat>=1.1.0 diff --git a/airflow/providers/common/compat/__init__.py b/airflow/providers/common/compat/__init__.py index 83f7bcecf15f..449005683d75 100644 --- a/airflow/providers/common/compat/__init__.py +++ b/airflow/providers/common/compat/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "1.0.0" +__version__ = "1.1.0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.7.0" diff --git a/airflow/providers/common/compat/openlineage/__init__.py b/airflow/providers/common/compat/openlineage/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/airflow/providers/common/compat/openlineage/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/common/compat/openlineage/facet.py b/airflow/providers/common/compat/openlineage/facet.py new file mode 100644 index 000000000000..e7d4ba352ef8 --- /dev/null +++ b/airflow/providers/common/compat/openlineage/facet.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + + +def create_no_op(*_, **__) -> None: + """ + Create a no-op placeholder. + + This function creates and returns a None value, used as a placeholder when the OpenLineage client + library is not available. It represents an action that has no effect.
+ """ + return None + + +if TYPE_CHECKING: + from openlineage.client.generated.base import ( + BaseFacet, + Dataset, + DatasetFacet, + InputDataset, + OutputDataset, + RunFacet, + ) + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.error_message_run import ErrorMessageRunFacet + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + from openlineage.client.generated.output_statistics_output_dataset import ( + OutputStatisticsOutputDatasetFacet, + ) + from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields + from openlineage.client.generated.sql_job import SQLJobFacet + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet +else: + try: + try: + from openlineage.client.generated.base import ( + BaseFacet, + Dataset, + DatasetFacet, + InputDataset, + OutputDataset, + RunFacet, + ) + from openlineage.client.generated.column_lineage_dataset import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet + from openlineage.client.generated.error_message_run import ErrorMessageRunFacet + from openlineage.client.generated.external_query_run import ExternalQueryRunFacet + from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet + from openlineage.client.generated.lifecycle_state_change_dataset import ( + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + from openlineage.client.generated.output_statistics_output_dataset import ( + OutputStatisticsOutputDatasetFacet, + ) + from openlineage.client.generated.schema_dataset import ( + SchemaDatasetFacet, + SchemaDatasetFacetFields, + ) + from openlineage.client.generated.sql_job import SQLJobFacet + from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet + except ImportError: + from openlineage.client.facet import ( + BaseFacet, + BaseFacet as DatasetFacet, + BaseFacet as RunFacet, + ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional as Fields, + ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, + DocumentationDatasetFacet, + ErrorMessageRunFacet, + ExternalQueryRunFacet, + ExtractionError as Error, + ExtractionErrorRunFacet, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + LifecycleStateChangeDatasetFacetPreviousIdentifier as PreviousIdentifier, + OutputStatisticsOutputDatasetFacet, + SchemaDatasetFacet, + SchemaField as SchemaDatasetFacetFields, + SqlJobFacet as SQLJobFacet, + SymlinksDatasetFacet, + SymlinksDatasetFacetIdentifiers as Identifier, + ) + from openlineage.client.run import Dataset, InputDataset, OutputDataset + except ImportError: + # When no openlineage client library installed we create no-op classes. + # This allows avoiding raising ImportError when making OL imports in top-level code + # (which shouldn't be the case anyway). 
+ BaseFacet = Dataset = DatasetFacet = InputDataset = OutputDataset = RunFacet = ( + ColumnLineageDatasetFacet + ) = Fields = InputField = DocumentationDatasetFacet = ErrorMessageRunFacet = ExternalQueryRunFacet = ( + Error + ) = ExtractionErrorRunFacet = LifecycleStateChange = LifecycleStateChangeDatasetFacet = ( + PreviousIdentifier + ) = OutputStatisticsOutputDatasetFacet = SchemaDatasetFacet = SchemaDatasetFacetFields = ( + SQLJobFacet + ) = Identifier = SymlinksDatasetFacet = create_no_op + +__all__ = [ + "BaseFacet", + "Dataset", + "DatasetFacet", + "InputDataset", + "OutputDataset", + "RunFacet", + "ColumnLineageDatasetFacet", + "Fields", + "InputField", + "DocumentationDatasetFacet", + "ErrorMessageRunFacet", + "ExternalQueryRunFacet", + "Error", + "ExtractionErrorRunFacet", + "LifecycleStateChange", + "LifecycleStateChangeDatasetFacet", + "PreviousIdentifier", + "OutputStatisticsOutputDatasetFacet", + "SchemaDatasetFacet", + "SchemaDatasetFacetFields", + "SQLJobFacet", + "Identifier", + "SymlinksDatasetFacet", +] diff --git a/airflow/providers/common/io/operators/file_transfer.py b/airflow/providers/common/io/operators/file_transfer.py index 25f5d7169f04..43957ed8aa90 100644 --- a/airflow/providers/common/io/operators/file_transfer.py +++ b/airflow/providers/common/io/operators/file_transfer.py @@ -75,14 +75,7 @@ def execute(self, context: Context) -> None: src.copy(dst) def get_openlineage_facets_on_start(self) -> OperatorLineage: - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage def _prepare_ol_dataset(path: ObjectStoragePath) -> Dataset: diff --git a/airflow/providers/ftp/operators/ftp.py b/airflow/providers/ftp/operators/ftp.py index 856d70dcd533..8a8c97667106 100644 --- a/airflow/providers/ftp/operators/ftp.py +++ b/airflow/providers/ftp/operators/ftp.py @@ -26,8 +26,6 @@ from pathlib import Path from typing import Any, Sequence -from git import TYPE_CHECKING - from airflow.models import BaseOperator from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook @@ -148,14 +146,7 @@ def get_openlineage_facets_on_start(self): input: file://hostname/path output file://:/path. 
""" - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage scheme = "file" diff --git a/airflow/providers/google/cloud/openlineage/mixins.py b/airflow/providers/google/cloud/openlineage/mixins.py index 71c41273c107..6ea744ad8636 100644 --- a/airflow/providers/google/cloud/openlineage/mixins.py +++ b/airflow/providers/google/cloud/openlineage/mixins.py @@ -23,13 +23,14 @@ from typing import TYPE_CHECKING, cast if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset, InputDataset, OutputDataset - from openlineage.client.generated.base import RunFacet - from openlineage.client.generated.output_statistics_output_dataset import ( + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + InputDataset, + OutputDataset, OutputStatisticsOutputDatasetFacet, + RunFacet, + SchemaDatasetFacet, ) - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet - from airflow.providers.google.cloud.openlineage.utils import BigQueryJobRunFacet @@ -64,16 +65,7 @@ def get_openlineage_facets_on_complete(self, _): - SchemaDatasetFacet - OutputStatisticsOutputDatasetFacet """ - if TYPE_CHECKING: - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.sql_job import SQLJobFacet - else: - try: - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ExternalQueryRunFacet, SqlJobFacet as SQLJobFacet - + from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet, SQLJobFacet from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser @@ -115,19 +107,12 @@ def get_openlineage_facets_on_complete(self, _): ) def get_facets(self, job_id: str): + from airflow.providers.common.compat.openlineage.facet import ErrorMessageRunFacet from airflow.providers.google.cloud.openlineage.utils import ( BigQueryErrorRunFacet, get_from_nullable_chain, ) - if TYPE_CHECKING: - from openlineage.client.generated.error_message_run import ErrorMessageRunFacet - else: - try: - from openlineage.client.generated.error_message_run import ErrorMessageRunFacet - except ImportError: - from openlineage.client.facet import ErrorMessageRunFacet - inputs = [] outputs = [] run_facets: dict[str, RunFacet] = {} @@ -234,16 +219,9 @@ def _get_bigquery_job_run_facet(properties: dict) -> BigQueryJobRunFacet: def _get_statistics_dataset_facet( properties, ) -> OutputStatisticsOutputDatasetFacet | None: + from airflow.providers.common.compat.openlineage.facet import OutputStatisticsOutputDatasetFacet from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain - if not TYPE_CHECKING: - try: - from openlineage.client.generated.output_statistics_output_dataset import ( - OutputStatisticsOutputDatasetFacet, - ) - except ImportError: - from openlineage.client.facet import OutputStatisticsOutputDatasetFacet - query_plan = get_from_nullable_chain(properties, chain=["statistics", "query", "queryPlan"]) if not query_plan: return None @@ -256,27 +234,18 @@ def _get_statistics_dataset_facet( return None def 
_get_input_dataset(self, table: dict) -> InputDataset: - if not TYPE_CHECKING: - try: - from openlineage.client.generated.base import InputDataset - except ImportError: - from openlineage.client.run import InputDataset + from airflow.providers.common.compat.openlineage.facet import InputDataset + return cast(InputDataset, self._get_dataset(table, "input")) def _get_output_dataset(self, table: dict) -> OutputDataset: - if not TYPE_CHECKING: - try: - from openlineage.client.generated.base import OutputDataset - except ImportError: - from openlineage.client.run import OutputDataset + from airflow.providers.common.compat.openlineage.facet import OutputDataset + return cast(OutputDataset, self._get_dataset(table, "output")) def _get_dataset(self, table: dict, dataset_type: str) -> Dataset: - if not TYPE_CHECKING: - try: - from openlineage.client.generated.base import InputDataset, OutputDataset - except ImportError: - from openlineage.client.run import InputDataset, OutputDataset + from airflow.providers.common.compat.openlineage.facet import InputDataset, OutputDataset + project = table.get("projectId") dataset = table.get("datasetId") table_name = table.get("tableId") @@ -317,25 +286,12 @@ def _get_table_schema_safely(self, table_name: str) -> SchemaDatasetFacet | None return None def _get_table_schema(self, table: str) -> SchemaDatasetFacet | None: + from airflow.providers.common.compat.openlineage.facet import ( + SchemaDatasetFacet, + SchemaDatasetFacetFields, + ) from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain - if TYPE_CHECKING: - from openlineage.client.generated.schema_dataset import ( - SchemaDatasetFacet, - SchemaDatasetFacetFields, - ) - else: - try: - from openlineage.client.generated.schema_dataset import ( - SchemaDatasetFacet, - SchemaDatasetFacetFields, - ) - except ImportError: - from openlineage.client.facet import ( - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - ) - bq_table = self.client.get_table(table) if not bq_table._properties: diff --git a/airflow/providers/google/cloud/openlineage/utils.py b/airflow/providers/google/cloud/openlineage/utils.py index 1c0624592217..82172d5d241c 100644 --- a/airflow/providers/google/cloud/openlineage/utils.py +++ b/airflow/providers/google/cloud/openlineage/utils.py @@ -23,36 +23,18 @@ if TYPE_CHECKING: from google.cloud.bigquery.table import Table - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.base import RunFacet - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields -else: - try: - from openlineage.client.generated.base import RunFacet - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - except ImportError: - from openlineage.client.facet import ( - BaseFacet as RunFacet, - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional as Fields, - ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, - DocumentationDatasetFacet, - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - ) 
+ from airflow.providers.common.compat.openlineage.facet import Dataset + +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + DocumentationDatasetFacet, + Fields, + InputField, + RunFacet, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) from airflow.providers.google import __version__ as provider_version BIGQUERY_NAMESPACE = "bigquery" diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index f18e49cbb5e7..c396e173eaad 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -336,29 +336,12 @@ def execute(self, context: Context) -> None: hook.delete(bucket_name=self.bucket_name, object_name=object_name) def get_openlineage_facets_on_start(self): - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.lifecycle_state_change_dataset import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - PreviousIdentifier, - ) - else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.lifecycle_state_change_dataset import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - PreviousIdentifier, - ) - except ImportError: - from openlineage.client.facet import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - LifecycleStateChangeDatasetFacetPreviousIdentifier as PreviousIdentifier, - ) - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) from airflow.providers.openlineage.extractors import OperatorLineage objects = [] @@ -661,14 +644,7 @@ def execute(self, context: Context) -> None: ) def get_openlineage_facets_on_start(self): - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage input_dataset = Dataset( @@ -943,12 +919,7 @@ def execute(self, context: Context) -> list[str]: def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as execute() resolves object prefixes.""" - if not TYPE_CHECKING: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage def _parse_prefix(pref): diff --git a/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py index 683bfcdda8f6..9642339ba8b1 100644 --- a/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py @@ -124,8 +124,7 @@ def execute(self, context: Context) -> str: return f"gs://{self.bucket_name}/{self.object_name}" def get_openlineage_facets_on_start(self): - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage wasb_hook = WasbHook(wasb_conn_id=self.wasb_conn_id) diff --git 
a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 5c9df82c1488..93ef12dcbd49 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -289,22 +289,12 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will include final BQ job id.""" from pathlib import Path - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - except ImportError: - from openlineage.client.facet import ( - ExternalQueryRunFacet, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers as Identifier, - ) - from openlineage.client.run import Dataset + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + ExternalQueryRunFacet, + Identifier, + SymlinksDatasetFacet, + ) from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 4fbb042df0a4..451de3fa4c07 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -746,23 +746,12 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will include final BQ job id.""" from pathlib import Path - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - except ImportError: - from openlineage.client.facet import ( - ExternalQueryRunFacet, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers as Identifier, - ) - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + ExternalQueryRunFacet, + Identifier, + SymlinksDatasetFacet, + ) from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, get_identity_column_lineage_facet, diff --git a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index c788fb71d1ac..8980436a1053 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -552,14 +552,7 @@ def get_openlineage_facets_on_complete(self, task_instance): """ from pathlib import Path - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from 
airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage def _process_prefix(pref): diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 9c186ed82489..062fc071c514 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -93,6 +93,7 @@ versions: dependencies: - apache-airflow>=2.7.0 + - apache-airflow-providers-common-compat>=1.1.0 - apache-airflow-providers-common-sql>=1.7.2 - asgiref>=3.5.2 - dill>=0.2.3 diff --git a/airflow/providers/openlineage/plugins/adapter.py b/airflow/providers/openlineage/plugins/adapter.py index 398ef5a8f6c5..1d0317228b83 100644 --- a/airflow/providers/openlineage/plugins/adapter.py +++ b/airflow/providers/openlineage/plugins/adapter.py @@ -303,7 +303,7 @@ def fail_task( stack_trace = "\\n".join(traceback.format_exception(type(error), error, error.__traceback__)) error_facet = { "errorMessage": error_message_run.ErrorMessageRunFacet( - message=error, programmingLanguage="python", stackTrace=stack_trace + message=str(error), programmingLanguage="python", stackTrace=stack_trace ) } diff --git a/airflow/providers/sftp/operators/sftp.py b/airflow/providers/sftp/operators/sftp.py index 28cb42092dbd..e04f68cb0987 100644 --- a/airflow/providers/sftp/operators/sftp.py +++ b/airflow/providers/sftp/operators/sftp.py @@ -23,7 +23,7 @@ import socket import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Sequence +from typing import Any, Sequence import paramiko @@ -201,14 +201,7 @@ def get_openlineage_facets_on_start(self): input: file:///path output: file://:/path. """ - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage scheme = "file" diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index bc6462efdaa9..a09217f9e5ca 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -480,14 +480,7 @@ def _get_openlineage_authority(self, _) -> str | None: return urlparse(uri).hostname def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None: - if TYPE_CHECKING: - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - else: - try: - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - except ImportError: - from openlineage.client.facet import ExternalQueryRunFacet - + from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser diff --git a/airflow/providers/snowflake/provider.yaml b/airflow/providers/snowflake/provider.yaml index 36427b6c8a46..100197ad6270 100644 --- a/airflow/providers/snowflake/provider.yaml +++ b/airflow/providers/snowflake/provider.yaml @@ -76,6 +76,7 @@ versions: dependencies: - apache-airflow>=2.7.0 + - apache-airflow-providers-common-compat>=1.1.0 - apache-airflow-providers-common-sql>=1.14.1 # In pandas 2.2 minimal version of the sqlalchemy is 2.0 # 
https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies diff --git a/airflow/providers/snowflake/transfers/copy_into_snowflake.py b/airflow/providers/snowflake/transfers/copy_into_snowflake.py index 661d98b3a7f1..066ebf0a4df4 100644 --- a/airflow/providers/snowflake/transfers/copy_into_snowflake.py +++ b/airflow/providers/snowflake/transfers/copy_into_snowflake.py @@ -19,7 +19,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence +from typing import Any, Sequence from airflow.models import BaseOperator from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -228,26 +228,13 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement _on_complete because we rely on return value of a query.""" import re - if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet - from openlineage.client.generated.sql_job import SQLJobFacet - else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ( - ExternalQueryRunFacet, - ExtractionError as Error, - ExtractionErrorRunFacet, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import Dataset - + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + Error, + ExternalQueryRunFacet, + ExtractionErrorRunFacet, + SQLJobFacet, + ) from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser diff --git a/dev/breeze/tests/test_packages.py b/dev/breeze/tests/test_packages.py index 8d5bc0b6791c..228a1ca0dc5e 100644 --- a/dev/breeze/tests/test_packages.py +++ b/dev/breeze/tests/test_packages.py @@ -236,6 +236,7 @@ def test_get_install_requirements(provider: str, version_suffix: str, expected: "apache.beam": ["apache-airflow-providers-apache-beam", "apache-beam[gcp]"], "apache.cassandra": ["apache-airflow-providers-apache-cassandra"], "cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.2.0"], + "common.compat": ["apache-airflow-providers-common-compat"], "common.sql": ["apache-airflow-providers-common-sql"], "facebook": ["apache-airflow-providers-facebook>=2.2.0"], "leveldb": ["plyvel"], @@ -260,6 +261,7 @@ def test_get_install_requirements(provider: str, version_suffix: str, expected: "apache.beam": ["apache-airflow-providers-apache-beam", "apache-beam[gcp]"], "apache.cassandra": ["apache-airflow-providers-apache-cassandra"], "cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.2.0.dev0"], + "common.compat": ["apache-airflow-providers-common-compat"], "common.sql": ["apache-airflow-providers-common-sql"], "facebook": ["apache-airflow-providers-facebook>=2.2.0.dev0"], "leveldb": ["plyvel"], @@ -284,6 +286,7 @@ def test_get_install_requirements(provider: str, version_suffix: str, expected: "apache.beam": ["apache-airflow-providers-apache-beam", "apache-beam[gcp]"], "apache.cassandra": ["apache-airflow-providers-apache-cassandra"], "cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.2.0b0"], + "common.compat": 
["apache-airflow-providers-common-compat"], "common.sql": ["apache-airflow-providers-common-sql"], "facebook": ["apache-airflow-providers-facebook>=2.2.0b0"], "leveldb": ["plyvel"], diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index c0c40b9be92b..812d786625d4 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -649,7 +649,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "tests/providers/common/io/operators/test_file_transfer.py", ), { - "affected-providers-list-as-string": "common.io openlineage", + "affected-providers-list-as-string": "common.compat common.io openlineage", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", "python-versions": "['3.8']", @@ -663,11 +663,11 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str): "run-kubernetes-tests": "false", "skip-pre-commits": "identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,ts-compile-format-lint-www", "upgrade-to-newer-dependencies": "false", - "parallel-test-types-list-as-string": "Always Providers[common.io,openlineage]", + "parallel-test-types-list-as-string": "Always Providers[common.compat,common.io,openlineage]", "needs-mypy": "true", "mypy-folders": "['airflow', 'providers']", }, - id="Only Always and Common.IO tests should run when only common.io and tests/always changed", + id="Only Always and common providers tests should run when only common.io and tests/always changed", ), pytest.param( ("airflow/operators/bash.py",), @@ -1126,7 +1126,7 @@ def test_expected_output_full_tests_needed( ), { "affected-providers-list-as-string": "amazon apache.beam apache.cassandra cncf.kubernetes " - "common.sql facebook google hashicorp microsoft.azure microsoft.mssql " + "common.compat common.sql facebook google hashicorp microsoft.azure microsoft.mssql " "mysql openlineage oracle postgres presto salesforce samba sftp ssh trino", "all-python-versions": "['3.8']", "all-python-versions-list-as-string": "3.8", @@ -1155,7 +1155,7 @@ def test_expected_output_full_tests_needed( ), { "affected-providers-list-as-string": "amazon apache.beam apache.cassandra " - "cncf.kubernetes common.sql facebook google " + "cncf.kubernetes common.compat common.sql facebook google " "hashicorp microsoft.azure microsoft.mssql mysql openlineage oracle postgres " "presto salesforce samba sftp ssh trino", "all-python-versions": "['3.8']", @@ -1272,7 +1272,7 @@ def test_expected_output_pull_request_v2_7( ), { "affected-providers-list-as-string": "amazon apache.beam apache.cassandra " - "cncf.kubernetes common.sql " + "cncf.kubernetes common.compat common.sql " "facebook google hashicorp microsoft.azure microsoft.mssql mysql " "openlineage oracle postgres presto salesforce samba sftp ssh trino", "all-python-versions": "['3.8']", @@ -1283,7 +1283,7 @@ def test_expected_output_pull_request_v2_7( "run-tests": "true", "docs-build": "true", "docs-list-as-string": "apache-airflow helm-chart amazon apache.beam apache.cassandra " - "cncf.kubernetes common.sql facebook google hashicorp microsoft.azure " + "cncf.kubernetes common.compat common.sql facebook google hashicorp microsoft.azure " "microsoft.mssql mysql openlineage oracle postgres " "presto salesforce samba sftp ssh trino", "skip-pre-commits": "identity,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,ts-compile-format-lint-www", @@ -1291,9 +1291,9 @@ def test_expected_output_pull_request_v2_7( 
"upgrade-to-newer-dependencies": "false", "skip-provider-tests": "false", "parallel-test-types-list-as-string": "Always CLI Providers[amazon] " - "Providers[apache.beam,apache.cassandra,cncf.kubernetes,common.sql,facebook,hashicorp," - "microsoft.azure,microsoft.mssql,mysql,openlineage,oracle,postgres,presto,salesforce," - "samba,sftp,ssh,trino] Providers[google]", + "Providers[apache.beam,apache.cassandra,cncf.kubernetes,common.compat,common.sql,facebook," + "hashicorp,microsoft.azure,microsoft.mssql,mysql,openlineage,oracle,postgres,presto," + "salesforce,samba,sftp,ssh,trino] Providers[google]", "needs-mypy": "true", "mypy-folders": "['airflow', 'providers']", }, @@ -1588,7 +1588,7 @@ def test_upgrade_to_newer_dependencies( ("docs/apache-airflow-providers-google/docs.rst",), { "docs-list-as-string": "amazon apache.beam apache.cassandra " - "cncf.kubernetes common.sql facebook google hashicorp " + "cncf.kubernetes common.compat common.sql facebook google hashicorp " "microsoft.azure microsoft.mssql mysql openlineage oracle " "postgres presto salesforce samba sftp ssh trino", }, diff --git a/docs/apache-airflow-providers-openlineage/guides/developer.rst b/docs/apache-airflow-providers-openlineage/guides/developer.rst index 5d69a1e0bac7..2582bc5fc0cc 100644 --- a/docs/apache-airflow-providers-openlineage/guides/developer.rst +++ b/docs/apache-airflow-providers-openlineage/guides/developer.rst @@ -152,7 +152,7 @@ As there is some processing made in ``execute`` method, and there is no relevant This means we won't have to normalize self.source_object and self.source_objects, destination bucket and so on. """ - from openlineage.client.event_v2 import Dataset + from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage return OperatorLineage( @@ -303,8 +303,12 @@ like extracting column level lineage and inputs/outputs from SQL query with SQL .. code-block:: python - from openlineage.client.facet_v2 import BaseFacet, external_query_run, sql_job - from openlineage.client.event_v2 import Dataset + from airflow.providers.common.compat.openlineage.facet import ( + BaseFacet, + Dataset, + ExternalQueryRunFacet, + SQLJobFacet, + ) from airflow.models.baseoperator import BaseOperator from airflow.providers.openlineage.extractors.base import BaseExtractor @@ -333,7 +337,7 @@ like extracting column level lineage and inputs/outputs from SQL query with SQL inputs=[Dataset(namespace="bigquery", name=self.bq_table_reference)], outputs=[Dataset(namespace=self.s3_path, name=self.s3_file_name)], job_facets={ - "sql": sql_job.SQLJobFacet( + "sql": SQLJobFacet( query="EXPORT INTO ... OPTIONS(FORMAT=csv, SEP=';' ...) AS SELECT * FROM ... 
" ) }, @@ -343,9 +347,7 @@ like extracting column level lineage and inputs/outputs from SQL query with SQL """Add what we received after Operator's extract call.""" lineage_metadata = self.extract() lineage_metadata.run_facets = { - "parent": external_query_run.ExternalQueryRunFacet( - externalQueryId=self._job_id, source="bigquery" - ) + "parent": ExternalQueryRunFacet(externalQueryId=self._job_id, source="bigquery") } return lineage_metadata diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 86b9b2e15b81..b8d0853bb9df 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -404,6 +404,7 @@ "devel-deps": [], "plugins": [], "cross-providers-deps": [ + "common.compat", "openlineage" ], "excluded-python-versions": [], @@ -578,6 +579,7 @@ "devel-deps": [], "plugins": [], "cross-providers-deps": [ + "common.compat", "openlineage" ], "excluded-python-versions": [], @@ -597,6 +599,7 @@ "google": { "deps": [ "PyOpenSSL>=23.0.0", + "apache-airflow-providers-common-compat>=1.1.0", "apache-airflow-providers-common-sql>=1.7.2", "apache-airflow>=2.7.0", "asgiref>=3.5.2", @@ -668,6 +671,7 @@ "apache.beam", "apache.cassandra", "cncf.kubernetes", + "common.compat", "common.sql", "facebook", "microsoft.azure", @@ -1151,6 +1155,7 @@ "devel-deps": [], "plugins": [], "cross-providers-deps": [ + "common.compat", "openlineage", "ssh" ], @@ -1194,6 +1199,7 @@ }, "snowflake": { "deps": [ + "apache-airflow-providers-common-compat>=1.1.0", "apache-airflow-providers-common-sql>=1.14.1", "apache-airflow>=2.7.0", "pandas>=1.5.3,<2.2;python_version<\"3.9\"", @@ -1205,6 +1211,7 @@ "devel-deps": [], "plugins": [], "cross-providers-deps": [ + "common.compat", "common.sql", "openlineage" ], diff --git a/hatch_build.py b/hatch_build.py index 2e233e6ba5cd..2ecd4e0aa82d 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -40,6 +40,7 @@ PROVIDER_DEPENDENCIES = json.loads(GENERATED_PROVIDERS_DEPENDENCIES_FILE.read_text()) PRE_INSTALLED_PROVIDERS = [ + "common.compat", "common.io", "common.sql", "fab>=1.0.2", diff --git a/tests/dags/test_openlineage_execution.py b/tests/dags/test_openlineage_execution.py index 29fb65cf7545..475e43ef6ac2 100644 --- a/tests/dags/test_openlineage_execution.py +++ b/tests/dags/test_openlineage_execution.py @@ -20,10 +20,9 @@ import datetime import time -from openlineage.client.generated.base import Dataset - from airflow.models.dag import DAG from airflow.models.operator import BaseOperator +from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 976791851b06..f126ec56f3e6 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -17,39 +17,24 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING from unittest import mock import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from 
openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - except ImportError: - from openlineage.client.facet import ( - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - SqlJobFacet as SQLJobFacet, - SymlinksDatasetFacet, - SymlinksDatasetFacetIdentifiers as Identifier, - ) - from openlineage.client.run import Dataset - from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance from airflow.providers.amazon.aws.hooks.athena import AthenaHook from airflow.providers.amazon.aws.operators.athena import AthenaOperator from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + ExternalQueryRunFacet, + Identifier, + SchemaDatasetFacet, + SchemaDatasetFacetFields, + SQLJobFacet, + SymlinksDatasetFacet, +) from airflow.providers.openlineage.extractors import OperatorLineage from airflow.utils import timezone from airflow.utils.timezone import datetime diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py index 010d807286b4..d813b7a7e12e 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_sql.py +++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py @@ -17,43 +17,21 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING from unittest.mock import MagicMock, PropertyMock, call, patch import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional as Fields, - ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import Dataset - from airflow.models.connection import Connection from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook as OriginalRedshiftSQLHook +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Dataset, + Fields, + InputField, + SchemaDatasetFacet, + SchemaDatasetFacetFields, + SQLJobFacet, +) from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator MOCK_REGION_NAME = "eu-north-1" @@ -230,64 +208,65 @@ def get_db_hook(self): assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [] expected_namespace = f"redshift://{expected_identity}:5439" - assert lineage.inputs == [ - Dataset( - 
namespace=expected_namespace, - name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week", - facets={ - "schema": SchemaDatasetFacet( - fields=[ - SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), - SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), - SchemaDatasetFacetFields(name="orders_placed", type="int4"), - ] - ) - }, - ), - Dataset( - namespace=expected_namespace, - name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table", - facets={ - "schema": SchemaDatasetFacet( - fields=[ - SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), - SchemaDatasetFacetFields(name="additional_constant", type="varchar"), - ] - ) - }, - ), - ] - assert lineage.outputs == [ - Dataset( - namespace=expected_namespace, - name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table", - facets={ - "schema": SchemaDatasetFacet( - fields=[ - SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), - SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), - SchemaDatasetFacetFields(name="orders_placed", type="int4"), - SchemaDatasetFacetFields(name="additional_constant", type="varchar"), - ] - ), - "columnLineage": ColumnLineageDatasetFacet( - fields={ - "additional_constant": Fields( - inputFields=[ - InputField( - namespace=expected_namespace, - name="database.public.little_table", - field="additional_constant", - ) - ], - transformationDescription="", - transformationType="", - ) - } - ), - }, - ) - ] + if is_over_210: + assert lineage.inputs == [ + Dataset( + namespace=expected_namespace, + name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + SchemaDatasetFacetFields(name="orders_placed", type="int4"), + ] + ) + }, + ), + Dataset( + namespace=expected_namespace, + name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="additional_constant", type="varchar"), + ] + ) + }, + ), + ] + assert lineage.outputs == [ + Dataset( + namespace=expected_namespace, + name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + SchemaDatasetFacetFields(name="orders_placed", type="int4"), + SchemaDatasetFacetFields(name="additional_constant", type="varchar"), + ] + ), + "columnLineage": ColumnLineageDatasetFacet( + fields={ + "additional_constant": Fields( + inputFields=[ + InputField( + namespace=expected_namespace, + name="database.public.little_table", + field="additional_constant", + ) + ], + transformationDescription="", + transformationType="", + ) + } + ), + }, + ) + ] assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)} diff --git a/tests/providers/amazon/aws/operators/test_s3.py b/tests/providers/amazon/aws/operators/test_s3.py index d4d4f016dd09..a7a38cabe81b 100644 --- a/tests/providers/amazon/aws/operators/test_s3.py +++ b/tests/providers/amazon/aws/operators/test_s3.py @@ -23,36 +23,12 @@ import sys from io import BytesIO from tempfile import mkdtemp -from typing import TYPE_CHECKING from unittest import mock import boto3 import pytest from moto import mock_aws -if TYPE_CHECKING: - from 
openlineage.client.event_v2 import Dataset - from openlineage.client.generated.lifecycle_state_change_dataset import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - PreviousIdentifier, - ) -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.lifecycle_state_change_dataset import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - PreviousIdentifier, - ) - except ImportError: - from openlineage.client.facet import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - LifecycleStateChangeDatasetFacetPreviousIdentifier as PreviousIdentifier, - ) - from openlineage.client.run import Dataset - from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.operators.s3 import ( @@ -68,6 +44,12 @@ S3ListPrefixesOperator, S3PutBucketTaggingOperator, ) +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, +) from airflow.providers.openlineage.extractors import OperatorLineage from airflow.utils.timezone import datetime, utcnow diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index e4b9ee16f084..45e90e204a5e 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -16,20 +16,11 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING from unittest import mock import pytest from botocore.exceptions import ClientError -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset -else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.operators import sagemaker @@ -38,6 +29,7 @@ SageMakerProcessingOperator, ) from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger +from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage CREATE_PROCESSING_PARAMS: dict = { diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index 448c8e967e7c..4426b4f15237 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -17,20 +17,11 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING from unittest import mock import pytest from botocore.exceptions import ClientError -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset -else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook from airflow.providers.amazon.aws.operators import sagemaker @@ -38,6 +29,7 @@ from airflow.providers.amazon.aws.triggers.sagemaker import ( SageMakerTrigger, ) +from airflow.providers.common.compat.openlineage.facet import Dataset from 
airflow.providers.openlineage.extractors import OperatorLineage EXPECTED_INTEGER_FIELDS: list[list[str]] = [ diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py index dd364208a523..314d0ba46a52 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py @@ -18,25 +18,17 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING from unittest import mock import pytest from botocore.exceptions import ClientError -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset -else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.operators import sagemaker from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger +from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage EXPECTED_INTEGER_FIELDS: list[list[str]] = [ diff --git a/tests/providers/common/compat/openlineage/__init__.py b/tests/providers/common/compat/openlineage/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/providers/common/compat/openlineage/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/common/compat/openlineage/test_facet.py b/tests/providers/common/compat/openlineage/test_facet.py new file mode 100644 index 000000000000..6bdf6b5555fc --- /dev/null +++ b/tests/providers/common/compat/openlineage/test_facet.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + + +def test_empty(): + """Initially there were tests with import mocks involved. Now they are removed and this is placeholder.""" + pass diff --git a/tests/providers/common/io/operators/test_file_transfer.py b/tests/providers/common/io/operators/test_file_transfer.py index de5e7a93fc6a..698c33582b82 100644 --- a/tests/providers/common/io/operators/test_file_transfer.py +++ b/tests/providers/common/io/operators/test_file_transfer.py @@ -17,17 +17,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING from unittest import mock -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset -else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - +from airflow.providers.common.compat.openlineage.facet import Dataset from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.8.0", __file__): diff --git a/tests/providers/common/sql/operators/test_sql_execute.py b/tests/providers/common/sql/operators/test_sql_execute.py index 1b51f0322dd1..1527f3190d5b 100644 --- a/tests/providers/common/sql/operators/test_sql_execute.py +++ b/tests/providers/common/sql/operators/test_sql_execute.py @@ -17,30 +17,19 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, NamedTuple, Sequence +from typing import Any, NamedTuple, Sequence from unittest import mock from unittest.mock import MagicMock import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ( - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import Dataset - from airflow.models import Connection +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + SchemaDatasetFacet, + SchemaDatasetFacetFields, + SQLJobFacet, +) from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.openlineage.extractors.base import OperatorLineage diff --git a/tests/providers/dbt/cloud/utils/test_openlineage.py b/tests/providers/dbt/cloud/utils/test_openlineage.py index c42c6e55ab25..d7d753cc2a46 100644 --- a/tests/providers/dbt/cloud/utils/test_openlineage.py +++ b/tests/providers/dbt/cloud/utils/test_openlineage.py @@ -44,6 +44,7 @@ def json(self): def emit_event(event): + # since 1.15.0 there was v2 facets introduced if parse(__version__) >= parse("1.15.0"): assert event.run.facets["parent"].run.runId == TASK_UUID assert event.run.facets["parent"].job.name == f"{DAG_ID}.{TASK_ID}" diff --git a/tests/providers/ftp/operators/test_ftp.py b/tests/providers/ftp/operators/test_ftp.py index e443484be931..24eaa2bf4ca6 100644 --- a/tests/providers/ftp/operators/test_ftp.py +++ b/tests/providers/ftp/operators/test_ftp.py @@ -18,20 +18,12 @@ from __future__ import annotations import socket -from typing import TYPE_CHECKING from unittest import mock import 
pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset -else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - from airflow.models import DAG, Connection +from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.ftp.operators.ftp import ( FTPFileTransmitOperator, FTPOperation, diff --git a/tests/providers/google/cloud/openlineage/test_mixins.py b/tests/providers/google/cloud/openlineage/test_mixins.py index c2c803232ff6..f7feade65d36 100644 --- a/tests/providers/google/cloud/openlineage/test_mixins.py +++ b/tests/providers/google/cloud/openlineage/test_mixins.py @@ -17,35 +17,18 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import InputDataset, OutputDataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.output_statistics_output_dataset import ( - OutputStatisticsOutputDatasetFacet, - ) - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields -else: - try: - from openlineage.client.event_v2 import InputDataset, OutputDataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.output_statistics_output_dataset import ( - OutputStatisticsOutputDatasetFacet, - ) - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - except ImportError: - from openlineage.client.facet import ( - ExternalQueryRunFacet, - OutputStatisticsOutputDatasetFacet, - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - ) - from openlineage.client.run import InputDataset, OutputDataset - +from airflow.providers.common.compat.openlineage.facet import ( + ExternalQueryRunFacet, + InputDataset, + OutputDataset, + OutputStatisticsOutputDatasetFacet, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) from airflow.providers.google.cloud.openlineage.mixins import _BigQueryOpenLineageMixin from airflow.providers.google.cloud.openlineage.utils import ( BigQueryJobRunFacet, diff --git a/tests/providers/google/cloud/openlineage/test_utils.py b/tests/providers/google/cloud/openlineage/test_utils.py index 34f5f0b2c187..e47f14332f45 100644 --- a/tests/providers/google/cloud/openlineage/test_utils.py +++ b/tests/providers/google/cloud/openlineage/test_utils.py @@ -17,42 +17,20 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest from google.cloud.bigquery.table import Table -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - except 
ImportError: - from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional as Fields, - ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, - DocumentationDatasetFacet, - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - ) - from openlineage.client.run import Dataset - +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Dataset, + DocumentationDatasetFacet, + Fields, + InputField, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) from airflow.providers.google.cloud.openlineage.utils import ( get_facets_from_bq_table, get_identity_column_lineage_facet, diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index d562455c63b2..ac314a1ec798 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -19,7 +19,6 @@ import json from contextlib import suppress -from typing import TYPE_CHECKING from unittest import mock from unittest.mock import ANY, MagicMock @@ -28,25 +27,6 @@ from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter from google.cloud.exceptions import Conflict -if TYPE_CHECKING: - from openlineage.client.event_v2 import InputDataset - from openlineage.client.generated.error_message_run import ErrorMessageRunFacet - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.sql_job import SQLJobFacet -else: - try: - from openlineage.client.event_v2 import InputDataset - from openlineage.client.generated.error_message_run import ErrorMessageRunFacet - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ( - ErrorMessageRunFacet, - ExternalQueryRunFacet, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import InputDataset - from airflow.exceptions import ( AirflowException, AirflowProviderDeprecationWarning, @@ -54,6 +34,12 @@ AirflowTaskTimeout, TaskDeferred, ) +from airflow.providers.common.compat.openlineage.facet import ( + ErrorMessageRunFacet, + ExternalQueryRunFacet, + InputDataset, + SQLJobFacet, +) from airflow.providers.google.cloud.operators.bigquery import ( BigQueryCheckOperator, BigQueryColumnCheckOperator, diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index 5492882b6806..1a5acd0bf610 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -19,34 +19,16 @@ from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import TYPE_CHECKING from unittest import mock import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.lifecycle_state_change_dataset import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - PreviousIdentifier, - ) -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.lifecycle_state_change_dataset import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - PreviousIdentifier, - ) - except ImportError: - from openlineage.client.facet import ( - LifecycleStateChange, - LifecycleStateChangeDatasetFacet, - LifecycleStateChangeDatasetFacetPreviousIdentifier 
as PreviousIdentifier, - ) - from openlineage.client.run import Dataset - +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, +) from airflow.providers.google.cloud.operators.gcs import ( GCSBucketCreateAclEntryOperator, GCSCreateBucketOperator, diff --git a/tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py b/tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py index b71d747ebf7d..aa575eb9df40 100644 --- a/tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py @@ -94,7 +94,7 @@ def test_execute(self, mock_temp, mock_hook_gcs, mock_hook_wasb): @mock.patch("airflow.providers.google.cloud.transfers.azure_blob_to_gcs.WasbHook") def test_execute_single_file_transfer_openlineage(self, mock_hook_wasb): - from openlineage.client.run import Dataset + from airflow.providers.common.compat.openlineage.facet import Dataset MOCK_AZURE_ACCOUNT_NAME = "mock_account_name" mock_hook_wasb.return_value.get_conn.return_value.account_name = MOCK_AZURE_ACCOUNT_NAME diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py index 0af142c56296..7c2a39825375 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock @@ -25,42 +24,19 @@ from google.cloud.bigquery.retry import DEFAULT_RETRY from google.cloud.bigquery.table import Table -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - except ImportError: - from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional as Fields, - ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, - DocumentationDatasetFacet, - ExternalQueryRunFacet, - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - ) - from openlineage.client.run import Dataset - from airflow.exceptions import TaskDeferred +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Dataset, + DocumentationDatasetFacet, + ExternalQueryRunFacet, + Fields, + Identifier, + InputField, + SchemaDatasetFacet, 
+ SchemaDatasetFacetFields, + SymlinksDatasetFacet, +) from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py index d39efde3b10e..24ad708db697 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock, call @@ -26,45 +25,22 @@ from google.cloud.bigquery import DEFAULT_RETRY, Table from google.cloud.exceptions import Conflict -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.documentation_dataset import DocumentationDatasetFacet - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.symlinks_dataset import Identifier, SymlinksDatasetFacet - except ImportError: - from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional as Fields, - ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, - DocumentationDatasetFacet, - ExternalQueryRunFacet, - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - ) - from openlineage.client.run import Dataset - from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Dataset, + DocumentationDatasetFacet, + ExternalQueryRunFacet, + Fields, + Identifier, + InputField, + SchemaDatasetFacet, + SchemaDatasetFacetFields, + SymlinksDatasetFacet, +) from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger from airflow.utils.timezone import datetime diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py index 7fc33ec42cd6..f33bdd8e8d28 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_gcs.py @@ -18,20 +18,12 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING from unittest import mock import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import 
Dataset -else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.google.cloud.transfers.gcs_to_gcs import WILDCARD, GCSToGCSOperator TASK_ID = "test-gcs-to-gcs-operator" diff --git a/tests/providers/mysql/operators/test_mysql.py b/tests/providers/mysql/operators/test_mysql.py index 9e66c05e80c7..719d37024c68 100644 --- a/tests/providers/mysql/operators/test_mysql.py +++ b/tests/providers/mysql/operators/test_mysql.py @@ -19,30 +19,18 @@ import os from contextlib import closing -from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ( - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import Dataset - from airflow.models.connection import Connection from airflow.models.dag import DAG +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + SchemaDatasetFacet, + SchemaDatasetFacetFields, + SQLJobFacet, +) from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.utils import timezone diff --git a/tests/providers/openlineage/plugins/test_adapter.py b/tests/providers/openlineage/plugins/test_adapter.py index 93b49e150f6e..6ba5d9d4a39e 100644 --- a/tests/providers/openlineage/plugins/test_adapter.py +++ b/tests/providers/openlineage/plugins/test_adapter.py @@ -481,6 +481,9 @@ def test_emit_failed_event_with_additional_information(mock_stats_incr, mock_sta run=parent_run.Run(runId=parent_run_id), job=parent_run.Job(namespace=namespace(), name="parent_job_name"), ), + "errorMessage": error_message_run.ErrorMessageRunFacet( + message="Error message", programmingLanguage="python", stackTrace=None + ), "externalQuery": external_query_run.ExternalQueryRunFacet( externalQueryId="123", source="source" ), diff --git a/tests/providers/openlineage/test_sqlparser.py b/tests/providers/openlineage/test_sqlparser.py index edde376fad3a..e93f90551ac8 100644 --- a/tests/providers/openlineage/test_sqlparser.py +++ b/tests/providers/openlineage/test_sqlparser.py @@ -21,7 +21,7 @@ import pytest from openlineage.client.event_v2 import Dataset -from openlineage.client.facet_v2 import column_lineage_dataset, schema_dataset, sql_job +from openlineage.client.facet_v2 import column_lineage_dataset, schema_dataset from openlineage.common.sql import DbTableMeta from airflow.providers.openlineage.sqlparser import DatabaseInfo, GetTableSchemasParams, SQLParser @@ -338,7 +338,9 @@ def test_generate_openlineage_metadata_from_sql(self, mock_parse, parser_returns schema_dataset.SchemaDatasetFacetFields(name="orders_placed", type="int4"), ] ) - assert metadata.outputs[0].facets["columnLineage"] == column_lineage_dataset.ColumnLineageDatasetFacet( + 
assert metadata.outputs[0].facets[ + "columnLineage" + ] == column_lineage_dataset.ColumnLineageDatasetFacet( fields={ "order_day_of_week": column_lineage_dataset.Fields( inputFields=[ diff --git a/tests/providers/sftp/operators/test_sftp.py b/tests/providers/sftp/operators/test_sftp.py index b98b7570e8ac..d87ae26b3abd 100644 --- a/tests/providers/sftp/operators/test_sftp.py +++ b/tests/providers/sftp/operators/test_sftp.py @@ -21,22 +21,14 @@ import os import socket from base64 import b64encode -from typing import TYPE_CHECKING from unittest import mock import paramiko import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset -else: - try: - from openlineage.client.event_v2 import Dataset - except ImportError: - from openlineage.client.run import Dataset - from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import DAG, Connection +from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator from airflow.providers.ssh.hooks.ssh import SSHHook diff --git a/tests/providers/snowflake/operators/test_snowflake_sql.py b/tests/providers/snowflake/operators/test_snowflake_sql.py index c9380bf17073..fb1bcd172635 100644 --- a/tests/providers/snowflake/operators/test_snowflake_sql.py +++ b/tests/providers/snowflake/operators/test_snowflake_sql.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock, patch @@ -37,33 +36,14 @@ def Row(*args, **kwargs): return MagicMock() -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.sql_job import SQLJobFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.column_lineage_dataset import ( - ColumnLineageDatasetFacet, - Fields, - InputField, - ) - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ( - ColumnLineageDatasetFacet, - ColumnLineageDatasetFacetFieldsAdditional as Fields, - ColumnLineageDatasetFacetFieldsAdditionalInputFields as InputField, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import Dataset - from airflow.models.connection import Connection +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Dataset, + Fields, + InputField, + SQLJobFacet, +) from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook diff --git a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py index 3437d85e8d7b..56697db8c747 100644 --- a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py +++ b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py @@ -16,32 +16,18 @@ # under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from typing import Callable from unittest import mock import pytest -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet - from openlineage.client.generated.sql_job import SQLJobFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.external_query_run import ExternalQueryRunFacet - from openlineage.client.generated.extraction_error_run import Error, ExtractionErrorRunFacet - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ( - ExternalQueryRunFacet, - ExtractionError as Error, - ExtractionErrorRunFacet, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import Dataset - - +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + Error, + ExternalQueryRunFacet, + ExtractionErrorRunFacet, + SQLJobFacet, +) from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import DatabaseInfo from airflow.providers.snowflake.transfers.copy_into_snowflake import CopyFromExternalStageToSnowflakeOperator diff --git a/tests/providers/trino/operators/test_trino.py b/tests/providers/trino/operators/test_trino.py index d3b219b58675..24d933bf0d36 100644 --- a/tests/providers/trino/operators/test_trino.py +++ b/tests/providers/trino/operators/test_trino.py @@ -17,27 +17,15 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING from unittest import mock -if TYPE_CHECKING: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet -else: - try: - from openlineage.client.event_v2 import Dataset - from openlineage.client.generated.schema_dataset import SchemaDatasetFacet, SchemaDatasetFacetFields - from openlineage.client.generated.sql_job import SQLJobFacet - except ImportError: - from openlineage.client.facet import ( - SchemaDatasetFacet, - SchemaField as SchemaDatasetFacetFields, - SqlJobFacet as SQLJobFacet, - ) - from openlineage.client.run import Dataset - from airflow.models.connection import Connection +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + SchemaDatasetFacet, + SchemaDatasetFacetFields, + SQLJobFacet, +) from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.trino.hooks.trino import TrinoHook From 1bfad6076977dc5ec3386cac9c583a1f42ca2bb8 Mon Sep 17 00:00:00 2001 From: Jakub Dardzinski Date: Mon, 22 Jul 2024 23:39:52 +0200 Subject: [PATCH 7/7] Add pre-commit hook for check on `common.compat` imports over OL v1. 
Signed-off-by: Jakub Dardzinski --- .pre-commit-config.yaml | 8 ++ airflow/providers/amazon/aws/datasets/s3.py | 4 +- airflow/providers/common/io/datasets/file.py | 4 +- airflow/providers/openlineage/utils/utils.py | 3 +- contributing-docs/08_static_code_checks.rst | 2 + .../doc/images/output_static-checks.svg | 28 +++--- .../doc/images/output_static-checks.txt | 2 +- .../src/airflow_breeze/pre_commit_ids.py | 1 + .../guides/developer.rst | 2 +- generated/provider_dependencies.json | 1 + ...heck_common_compat_used_for_openlineage.py | 88 +++++++++++++++++++ .../providers/common/io/datasets/test_file.py | 2 +- .../openlineage/utils/custom_facet_fixture.py | 3 +- 13 files changed, 125 insertions(+), 23 deletions(-) create mode 100755 scripts/ci/pre_commit/check_common_compat_used_for_openlineage.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 87118f4b5a53..536d75574b64 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -388,6 +388,14 @@ repos: exclude: ^airflow/kubernetes/ entry: ./scripts/ci/pre_commit/check_airflow_k8s_not_used.py additional_dependencies: ['rich>=12.4.4'] + - id: check-common-compat-used-for-openlineage + name: Check common.compat is used for OL deprecated classes + language: python + files: ^airflow/.*\.py$ + require_serial: true + exclude: ^airflow/openlineage/ + entry: ./scripts/ci/pre_commit/check_common_compat_used_for_openlineage.py + additional_dependencies: ['rich>=12.4.4'] - id: check-airflow-providers-bug-report-template name: Check airflow-bug-report provider list is sorted/unique language: python diff --git a/airflow/providers/amazon/aws/datasets/s3.py b/airflow/providers/amazon/aws/datasets/s3.py index e6bed6dbe3df..c42ec2bb1cc0 100644 --- a/airflow/providers/amazon/aws/datasets/s3.py +++ b/airflow/providers/amazon/aws/datasets/s3.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from urllib.parse import SplitResult - from openlineage.client.run import Dataset as OpenLineageDataset + from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset: @@ -39,7 +39,7 @@ def sanitize_uri(uri: SplitResult) -> SplitResult: def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset: """Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the hook.""" - from openlineage.client.run import Dataset as OpenLineageDataset + from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset bucket, key = S3Hook.parse_s3_url(dataset.uri) return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else "/") diff --git a/airflow/providers/common/io/datasets/file.py b/airflow/providers/common/io/datasets/file.py index aa7e8d98be7a..35d3b227e522 100644 --- a/airflow/providers/common/io/datasets/file.py +++ b/airflow/providers/common/io/datasets/file.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from urllib.parse import SplitResult - from openlineage.client.run import Dataset as OpenLineageDataset + from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset def create_dataset(*, path: str, extra=None) -> Dataset: @@ -44,7 +44,7 @@ def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLin Windows paths are not standardized and can produce unexpected behaviour. 
""" - from openlineage.client.run import Dataset as OpenLineageDataset + from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset parsed = urllib.parse.urlsplit(dataset.uri) return OpenLineageDataset(namespace=f"file://{parsed.netloc}", name=parsed.path) diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index de51a315ca61..195d14e4e752 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -55,6 +55,7 @@ from airflow.utils.module_loading import import_string if TYPE_CHECKING: + from openlineage.client.facet_v2 import RunFacet from openlineage.client.run import Dataset as OpenLineageDataset from airflow.models import DagRun, TaskInstance @@ -345,7 +346,7 @@ class TaskGroupInfo(InfoJsonEncodable): ] -def get_airflow_dag_run_facet(dag_run: DagRun) -> dict[str, BaseFacet]: +def get_airflow_dag_run_facet(dag_run: DagRun) -> dict[str, RunFacet]: if not dag_run.dag: return {} return { diff --git a/contributing-docs/08_static_code_checks.rst b/contributing-docs/08_static_code_checks.rst index 9f2f492f0947..e1312b8a45ec 100644 --- a/contributing-docs/08_static_code_checks.rst +++ b/contributing-docs/08_static_code_checks.rst @@ -150,6 +150,8 @@ require Breeze Docker image to be built locally. +-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-code-deprecations | Check deprecations categories in decorators | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ +| check-common-compat-used-for-openlineage | Check common.compat is used for OL deprecated classes | | ++-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-compat-cache-on-methods | Check that compat cache do not use on class methods | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-core-deprecation-classes | Verify usage of Airflow deprecation classes in core | | diff --git a/dev/breeze/doc/images/output_static-checks.svg b/dev/breeze/doc/images/output_static-checks.svg index d81b0c55501e..12e58b0b2da3 100644 --- a/dev/breeze/doc/images/output_static-checks.svg +++ b/dev/breeze/doc/images/output_static-checks.svg @@ -329,20 +329,20 @@ check-boring-cyborg-configuration | check-breeze-top-dependencies-limited |       check-builtin-literals | check-changelog-format |                                 check-changelog-has-no-duplicates | check-cncf-k8s-only-for-executors |           -check-code-deprecations | check-compat-cache-on-methods |                         -check-core-deprecation-classes | check-daysago-import-from-utils |                -check-decorated-operator-implements-custom-name | check-deferrable-default-value  -| check-docstring-param-types | check-example-dags-urls |                         -check-executables-have-shebangs | check-extra-packages-references |               -check-extras-order | check-fab-migrations | check-for-inclusive-language |        -check-get-lineage-collector-providers | check-google-re2-as-dependency |          -check-hatch-build-order | check-hooks-apply | check-incorrect-use-of-LoggingMixin -| check-init-decorator-arguments | check-integrations-list-consistent |           -check-lazy-logging | 
check-links-to-example-dags-do-not-use-hardcoded-versions |  -check-merge-conflict | check-newsfragments-are-valid |                            -check-no-airflow-deprecation-in-providers | check-no-providers-in-core-examples | -check-only-new-session-with-provide-session |                                     -check-persist-credentials-disabled-in-github-workflows |                          +check-code-deprecations | check-common-compat-used-for-openlineage |              +check-compat-cache-on-methods | check-core-deprecation-classes |                  +check-daysago-import-from-utils | check-decorated-operator-implements-custom-name +| check-deferrable-default-value | check-docstring-param-types |                  +check-example-dags-urls | check-executables-have-shebangs |                       +check-extra-packages-references | check-extras-order | check-fab-migrations |     +check-for-inclusive-language | check-get-lineage-collector-providers |            +check-google-re2-as-dependency | check-hatch-build-order | check-hooks-apply |    +check-incorrect-use-of-LoggingMixin | check-init-decorator-arguments |            +check-integrations-list-consistent | check-lazy-logging |                         +check-links-to-example-dags-do-not-use-hardcoded-versions | check-merge-conflict  +| check-newsfragments-are-valid | check-no-airflow-deprecation-in-providers |     +check-no-providers-in-core-examples | check-only-new-session-with-provide-session +| check-persist-credentials-disabled-in-github-workflows |                        check-pre-commit-information-consistent | check-provide-create-sessions-imports | check-provider-docs-valid | check-provider-yaml-valid |                           check-providers-init-file-missing | check-providers-subpackages-init-file-exist | diff --git a/dev/breeze/doc/images/output_static-checks.txt b/dev/breeze/doc/images/output_static-checks.txt index d86b3b4bd082..9453c31f4df3 100644 --- a/dev/breeze/doc/images/output_static-checks.txt +++ b/dev/breeze/doc/images/output_static-checks.txt @@ -1 +1 @@ -d4f928b6f07b32672c2dfd8fc334aff8 +85fff776e1e23ae5f9715b38c1f71825 diff --git a/dev/breeze/src/airflow_breeze/pre_commit_ids.py b/dev/breeze/src/airflow_breeze/pre_commit_ids.py index 96620385f1be..be599e1f7d68 100644 --- a/dev/breeze/src/airflow_breeze/pre_commit_ids.py +++ b/dev/breeze/src/airflow_breeze/pre_commit_ids.py @@ -40,6 +40,7 @@ "check-changelog-has-no-duplicates", "check-cncf-k8s-only-for-executors", "check-code-deprecations", + "check-common-compat-used-for-openlineage", "check-compat-cache-on-methods", "check-core-deprecation-classes", "check-daysago-import-from-utils", diff --git a/docs/apache-airflow-providers-openlineage/guides/developer.rst b/docs/apache-airflow-providers-openlineage/guides/developer.rst index 2582bc5fc0cc..8d66780190d2 100644 --- a/docs/apache-airflow-providers-openlineage/guides/developer.rst +++ b/docs/apache-airflow-providers-openlineage/guides/developer.rst @@ -485,7 +485,7 @@ Writing a custom facet function import attrs from airflow.models import TaskInstance - from openlineage.client.facet import BaseFacet + from airflow.providers.common.compat.openlineage.facet import BaseFacet @attrs.define(slots=False) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index b8d0853bb9df..85dbd405e8f8 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -29,6 +29,7 @@ "deps": [ "PyAthena>=3.0.10", "apache-airflow-providers-common-compat>=1.1.0", + 
"apache-airflow-providers-common-compat>=1.1.0", "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow-providers-http", "apache-airflow>=2.7.0", diff --git a/scripts/ci/pre_commit/check_common_compat_used_for_openlineage.py b/scripts/ci/pre_commit/check_common_compat_used_for_openlineage.py new file mode 100755 index 000000000000..127a2b11cd3c --- /dev/null +++ b/scripts/ci/pre_commit/check_common_compat_used_for_openlineage.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import sys +from typing import NamedTuple + +from rich.console import Console + +console = Console(color_system="standard", width=200) + + +class ImportTuple(NamedTuple): + module: list[str] + name: list[str] + alias: str + + +def get_imports(path: str): + with open(path) as fh: + root = ast.parse(fh.read(), path) + + for node in ast.iter_child_nodes(root): + if isinstance(node, ast.Import): + module: list[str] = node.names[0].name.split(".") if node.names else [] + elif isinstance(node, ast.ImportFrom) and node.module: + module = node.module.split(".") + else: + continue + + for n in node.names: # type: ignore[attr-defined] + yield ImportTuple(module=module, name=n.name.split("."), alias=n.asname) + + +errors: list[str] = [] + +EXCEPTIONS = ["airflow/providers/common/compat/openlineage/facet.py"] + + +def main() -> int: + for path in sys.argv[1:]: + import_count = 0 + local_error_count = 0 + for imp in get_imports(path): + import_count += 1 + if len(imp.module) > 2: + if imp.module[:3] == ["openlineage", "client", "facet"] or imp.module[:3] == [ + "openlineage", + "client", + "run", + ]: + if path not in EXCEPTIONS: + local_error_count += 1 + errors.append(f"{path}: ({'.'.join(imp.module)})") + console.print(f"[blue]{path}:[/] Import count: {import_count}, error_count {local_error_count}") + if errors: + console.print( + "[red]Some files imports from `openlineage.client.facet` or `openlineage.client.run`. which are deprecated.[/]\n" + "You should import from `airflow.providers.common.compat.openlineage.facet` instead." 
+ ) + console.print("Error summary:") + for error in errors: + console.print(error) + return 1 + else: + console.print("[green]All good!") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/providers/common/io/datasets/test_file.py b/tests/providers/common/io/datasets/test_file.py index b2e4fddf986f..d8d53247a679 100644 --- a/tests/providers/common/io/datasets/test_file.py +++ b/tests/providers/common/io/datasets/test_file.py @@ -19,9 +19,9 @@ from urllib.parse import urlsplit, urlunsplit import pytest -from openlineage.client.run import Dataset as OpenLineageDataset from airflow.datasets import Dataset +from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset from airflow.providers.common.io.datasets.file import ( convert_dataset_to_openlineage, create_dataset, diff --git a/tests/providers/openlineage/utils/custom_facet_fixture.py b/tests/providers/openlineage/utils/custom_facet_fixture.py index 5a051218e2e9..f2504888b420 100644 --- a/tests/providers/openlineage/utils/custom_facet_fixture.py +++ b/tests/providers/openlineage/utils/custom_facet_fixture.py @@ -19,7 +19,8 @@ from typing import TYPE_CHECKING import attrs -from openlineage.client.facet import BaseFacet + +from airflow.providers.common.compat.openlineage.facet import BaseFacet if TYPE_CHECKING: from airflow.models import TaskInstance
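
A quick illustration of the import style the new `check-common-compat-used-for-openlineage` hook enforces. This is a minimal sketch, not part of the patch: the helper name, table names, and namespaces are invented for the example, while `OperatorLineage` and the facet classes are exactly the ones the diffs above import through the compat shim.

from airflow.providers.common.compat.openlineage.facet import (
    Dataset,
    SchemaDatasetFacet,
    SchemaDatasetFacetFields,
    SQLJobFacet,
)
from airflow.providers.openlineage.extractors import OperatorLineage


def example_lineage() -> OperatorLineage:
    # Illustrative only: mirrors what an operator's get_openlineage_facets_on_complete()
    # would return, with every OpenLineage class imported from the compat facade rather
    # than the deprecated openlineage.client.facet / openlineage.client.run modules.
    schema = SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="id", type="int")])
    return OperatorLineage(
        inputs=[Dataset(namespace="mysql://host:3306", name="db.source_table", facets={"schema": schema})],
        outputs=[Dataset(namespace="mysql://host:3306", name="db.target_table")],
        job_facets={"sql": SQLJobFacet(query="INSERT INTO db.target_table SELECT * FROM db.source_table")},
    )

To exercise the hook on its own, the standard pre-commit invocation should work, e.g. `pre-commit run check-common-compat-used-for-openlineage --all-files`.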