Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Implement support for fetchone() in the ODBCHook and the Databricks SQL Hook #36161

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions airflow/providers/databricks/hooks/databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload

from databricks import sql # type: ignore[attr-defined]
from databricks.sql.types import Row

from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
Expand Down Expand Up @@ -242,9 +243,11 @@ def run(

@staticmethod
def _make_serializable(result):
"""Transform the databricks Row objects into a JSON-serializable list of rows."""
if result is not None:
"""Transform the databricks Row objects into JSON-serializable lists."""
if isinstance(result, list):
Copy link
Contributor

@utkarsharma2 utkarsharma2 Dec 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

statement = f"DESCRIBE TABLE {table.name};"
hook.run(statement, parameters=parameters, handler=lambda x: x.fetchall())

With this code, hook.run() will return [[val1, val2, val3], [val1, val2, val3], [val1, val2, val3]] instead of [row(), row(), row()] — isn't this a breaking change?

Copy link
Contributor

@utkarsharma2 utkarsharma2 Dec 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One possible solution is to inherit from the Row class and introduce serialize and deserialize methods; that way we can deal with this in a similar way as dataframes.

return [list(row) for row in result]
elif isinstance(result, Row):
return list(result)
return result

def bulk_dump(self, table, tmp_file):
Expand Down
14 changes: 9 additions & 5 deletions airflow/providers/odbc/hooks/odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,16 @@ def get_sqlalchemy_connection(
return cnx

@staticmethod
def _make_serializable(result: list[pyodbc.Row] | None) -> list[NamedTuple] | None:
def _make_serializable(result: list[pyodbc.Row] | pyodbc.Row | None) -> list[NamedTuple] | None:
"""Transform the pyodbc.Row objects returned from an SQL command into JSON-serializable NamedTuple."""
if result is not None:
columns: list[tuple[str, type]] = [col[:2] for col in result[0].cursor_description]
# Below line respects NamedTuple docstring, but mypy do not support dynamically
# instantiated Namedtuple, and will never do: https://github.com/python/mypy/issues/848
# Below ignored lines respect NamedTuple docstring, but mypy do not support dynamically
# instantiated Namedtuple, and will never do: https://github.com/python/mypy/issues/848
columns: list[tuple[str, type]] | None = None
if isinstance(result, list):
columns = [col[:2] for col in result[0].cursor_description]
row_object = NamedTuple("Row", columns) # type: ignore[misc]
return [row_object(*row) for row in result]
elif isinstance(result, pyodbc.Row):
columns = [col[:2] for col in result.cursor_description]
return NamedTuple("Row", columns)(*result) # type: ignore[misc, operator]
return result
11 changes: 11 additions & 0 deletions tests/providers/databricks/hooks/test_databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,17 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
[[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
id="The return_last not set on multiple queries not set",
),
pytest.param(
True,
False,
"select * from test.test",
["select * from test.test"],
[["id", "value"]],
(Row(id=1, value=2),),
[[("id",), ("value",)]],
[1, 2],
id="The return_last set and no split statements set on single query in string",
),
],
)
def test_query(
Expand Down
56 changes: 48 additions & 8 deletions tests/providers/odbc/hooks/test_odbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@


@pytest.fixture
def mock_row():
"""
Mock a pyodbc.Row object - This is a C object that can only be created from C API of pyodbc.
def pyodbc_row_mock():
"""Mock a pyodbc.Row instantiated object.

This object is used in the tests to replace the real pyodbc.Row object.
pyodbc.Row is a C object that can only be created from C API of pyodbc.

This mock implements the two features used by the hook:
- cursor_description: which return column names and type
- __iter__: which allows exploding a row instance (*row)
Expand All @@ -59,6 +62,20 @@ def cursor_description(self):
return Row


@pytest.fixture
def pyodbc_instancecheck():
"""Mock a pyodbc.Row class which returns True to any isinstance() checks."""

class PyodbcRowMeta(type):
def __instancecheck__(self, instance):
return True

class PyodbcRow(metaclass=PyodbcRowMeta):
pass

return PyodbcRow


class TestOdbcHook:
def get_hook(self=None, hook_params=None, conn_params=None):
hook_params = hook_params or {}
Expand Down Expand Up @@ -282,14 +299,18 @@ def test_sqlalchemy_scheme_extra(self):
def test_pyodbc_mock(self):
"""Ensure that pyodbc.Row object has a `cursor_description` method.

In subsequent tests, pyodbc.Row is replaced by pure Python mock object, which implements the above
method. We want to detect any breaking change in the pyodbc object. If it fails, the 'mock_row'
needs to be updated.
In subsequent tests, pyodbc.Row is replaced by the 'pyodbc_row_mock' fixture, which implements the
`cursor_description` method. We want to detect any breaking change in the pyodbc object. If this test
fails, the 'pyodbc_row_mock' fixture needs to be updated.
"""
assert hasattr(pyodbc.Row, "cursor_description")

def test_query_return_serializable_result(self, mock_row):
pyodbc_result = [mock_row(key=1, column="value1"), mock_row(key=2, column="value2")]
def test_query_return_serializable_result_with_fetchall(self, pyodbc_row_mock):
"""
Simulate a cursor.fetchall which returns an iterable of pyodbc.Row object, and check if this iterable
get converted into a list of tuples.
"""
pyodbc_result = [pyodbc_row_mock(key=1, column="value1"), pyodbc_row_mock(key=2, column="value2")]
hook_result = [(1, "value1"), (2, "value2")]

def mock_handler(*_):
Expand All @@ -299,6 +320,25 @@ def mock_handler(*_):
result = hook.run("SQL", handler=mock_handler)
assert hook_result == result

def test_query_return_serializable_result_with_fetchone(
self, pyodbc_row_mock, monkeypatch, pyodbc_instancecheck
):
"""
Simulate a cursor.fetchone which returns one single pyodbc.Row object, and check if this object gets
converted into a tuple.
"""
pyodbc_result = pyodbc_row_mock(key=1, column="value1")
hook_result = (1, "value1")

def mock_handler(*_):
return pyodbc_result

hook = self.get_hook()
with monkeypatch.context() as patcher:
patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
result = hook.run("SQL", handler=mock_handler)
assert hook_result == result

def test_query_no_handler_return_none(self):
hook = self.get_hook()
result = hook.run("SQL")
Expand Down
Loading