Fix Databricks SQL operator serialization #31780

Merged
9 changes: 7 additions & 2 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -22,6 +22,7 @@
import json
from typing import TYPE_CHECKING, Any, Sequence

+from databricks.sql.types import Row
from databricks.sql.utils import ParamEscaper

from airflow.exceptions import AirflowException
@@ -33,6 +34,10 @@
from airflow.utils.context import Context


+def make_serializable(val: Row):
+    return tuple(val)


class DatabricksSqlOperator(SQLExecuteQueryOperator):
"""
Executes SQL code in a Databricks SQL endpoint or a Databricks cluster.
@@ -125,7 +130,7 @@ def _should_run_output_processing(self) -> bool:

    def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
        if not self._output_path:
-            return list(zip(descriptions, results))
+            return list(zip(descriptions, [[make_serializable(row) for row in res] for res in results]))
        if not self._output_format:
            raise AirflowException("Output format should be specified!")
        # Output to a file only the result of last query
@@ -158,7 +163,7 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
file.write("\n")
else:
raise AirflowException(f"Unsupported output format: '{self._output_format}'")
return list(zip(descriptions, results))
return list(zip(descriptions, [[make_serializable(row) for row in res] for res in results]))


COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]
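For context, a minimal sketch of what this change addresses (assuming the databricks-sql-connector package is installed): the hook returns result rows as databricks.sql.types.Row objects, which Airflow's XCom serialization does not round-trip, so _process_output now downcasts every row to a plain tuple before returning.

from databricks.sql.types import Row

# Row behaves like a named tuple, but its concrete type is specific to the
# databricks-sql-connector and is not known to Airflow's serializer.
row = Row(id=1, value="value1")

# make_serializable() reduces each row to a plain tuple, which Airflow
# serializes natively; the field values keep their original order.
assert tuple(row) == (1, "value1")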
23 changes: 21 additions & 2 deletions tests/providers/databricks/operators/test_databricks_sql.py
@@ -22,10 +22,10 @@
from unittest.mock import patch

import pytest
-from databricks.sql.types import Row
Member: Why change this import?

Member Author: Sorry, I had not seen this before merging. It's a left-over from some other trials - not a big harm, since it's the same Row, just imported from databricks_sql, so it might actually make more sense to import it from there.

from airflow.providers.common.sql.hooks.sql import fetch_all_handler
-from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator
+from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator, Row
+from airflow.serialization.serde import serialize

DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"
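A quick sketch of the point made in the review exchange above: the operator module itself contains "from databricks.sql.types import Row", so both import paths name the same class.

from databricks.sql.types import Row as DirectRow

from airflow.providers.databricks.operators.databricks_sql import Row as ReexportedRow

# databricks_sql re-exports the connector's Row, so both names are bound
# to the same class object; the test behaves identically either way.
assert DirectRow is ReexportedRow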
@@ -151,6 +151,25 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions):
)


+def test_return_value_serialization():
+    hook_descriptions = [[("id",), ("value",)]]
+    hook_results = [Row(id=1, value="value1"), Row(id=2, value="value2")]
+
+    with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as db_mock_class:
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="select * from dummy2",
+            do_xcom_push=True,
+            return_last=True,
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.run.return_value = hook_results
+        db_mock.descriptions = hook_descriptions
+        result = op.execute({})
+        serialized_result = serialize(result)
+        assert serialized_result == serialize(([("id",), ("value",)], [(1, "value1"), (2, "value2")]))


@pytest.mark.parametrize(
"return_last, split_statements, sql, descriptions, hook_results, do_xcom_push",
[
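As a usage sketch (the DAG id and connection id here are hypothetical, not part of this PR), the operator can now push its query results to XCom without tripping over Row serialization:

import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

with DAG(
    dag_id="databricks_sql_example",  # hypothetical DAG id
    start_date=datetime.datetime(2023, 6, 1),
    schedule=None,
) as dag:
    # With the fix, the pushed value is (descriptions, rows) where each row
    # is a plain tuple rather than a databricks.sql.types.Row.
    select_data = DatabricksSqlOperator(
        task_id="select_data",
        databricks_conn_id="databricks_default",  # hypothetical connection id
        sql="select * from dummy2",
        do_xcom_push=True,
        return_last=True,
    )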