From af1b625db1f76ab7e583cf0192ffa2ef04b94188 Mon Sep 17 00:00:00 2001
From: Jarek Potiuk
Date: Wed, 7 Jun 2023 21:48:51 +0200
Subject: [PATCH] Fix Databricks SQL operator serialization

The Databricks SQL operator returned Databricks Row objects, which
were not serializable because they are a special extension of tuples
that also acts as a dict. In the case of SQLExecuteQueryOperator, we
return a different format of output: the descriptions of the rows
separately from the rows of values, which are regular tuples.

This PR converts the Databricks Rows into regular tuples on the fly
while processing the output.

Fixes: #31753
Fixes: #31499
---
 .../databricks/operators/databricks_sql.py    |  9 ++++++--
 .../operators/test_databricks_sql.py          | 23 +++++++++++++++++--
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index a14c083fa7326..ea2f1d0915030 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -22,6 +22,7 @@
 import json
 from typing import TYPE_CHECKING, Any, Sequence
 
+from databricks.sql.types import Row
 from databricks.sql.utils import ParamEscaper
 
 from airflow.exceptions import AirflowException
@@ -33,6 +34,10 @@
     from airflow.utils.context import Context
 
 
+def make_serializable(val: Row):
+    return tuple(val)
+
+
 class DatabricksSqlOperator(SQLExecuteQueryOperator):
     """
     Executes SQL code in a Databricks SQL endpoint or a Databricks cluster.
@@ -125,7 +130,7 @@ def _should_run_output_processing(self) -> bool:
 
     def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
         if not self._output_path:
-            return list(zip(descriptions, results))
+            return list(zip(descriptions, [[make_serializable(row) for row in res] for res in results]))
         if not self._output_format:
             raise AirflowException("Output format should be specified!")
         # Output to a file only the result of last query
@@ -158,7 +163,7 @@ def _process_output(self, results: list[Any], descriptions: list[Sequen
                     file.write("\n")
         else:
             raise AirflowException(f"Unsupported output format: '{self._output_format}'")
-        return list(zip(descriptions, results))
+        return list(zip(descriptions, [[make_serializable(row) for row in res] for res in results]))
 
 
 COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index dd0c9b01870ea..5c5d4c92af676 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -22,10 +22,10 @@
 from unittest.mock import patch
 
 import pytest
-from databricks.sql.types import Row
 
 from airflow.providers.common.sql.hooks.sql import fetch_all_handler
-from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator
+from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator, Row
+from airflow.serialization.serde import serialize
 
 DATE = "2017-04-20"
 TASK_ID = "databricks-sql-operator"
@@ -151,6 +151,25 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
     )
 
 
+def test_return_value_serialization():
+    hook_descriptions = [[("id",), ("value",)]]
+    hook_results = [Row(id=1, value="value1"), Row(id=2, value="value2")]
+
+    with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class:
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="select * from dummy2",
+            do_xcom_push=True,
+            return_last=True,
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.run.return_value = hook_results
+        db_mock.descriptions = hook_descriptions
+        result = op.execute({})
+        serialized_result = serialize(result)
+        assert serialized_result == serialize(([("id",), ("value",)], [(1, "value1"), (2, "value2")]))
+
+
 @pytest.mark.parametrize(
     "return_last, split_statements, sql, descriptions, hook_results, do_xcom_push",
     [