From 9054d992fec483681a2918a3ffcd35f769351bc4 Mon Sep 17 00:00:00 2001
From: Deepyaman Datta
Date: Tue, 17 Dec 2024 06:00:16 -0700
Subject: [PATCH] fix(datasets): make `GBQTableDataset` serializable (#961)

Signed-off-by: Deepyaman Datta
---
 kedro-datasets/RELEASE.md | 2 +
 .../kedro_datasets/pandas/gbq_dataset.py | 47 +++++++++++--------
 .../tests/pandas/test_gbq_dataset.py | 7 +--
 3 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index d7ba58b7e..518bbdbd8 100755
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -15,6 +15,8 @@
 
 ## Bug fixes and other changes
 
+- Delayed backend connection for `pandas.GBQTableDataset`. In practice, this means that a dataset's connection details aren't used (or validated) until the dataset is accessed. On the plus side, the cost of connecting is deferred until the dataset is actually used, and is never incurred if the dataset isn't used at all. Furthermore, this makes the dataset object serializable (e.g. for use with `ParallelRunner`), because the unserializable client isn't part of it.
+- Removed the unused BigQuery client created in `pandas.GBQQueryDataset`. This makes the dataset object serializable (e.g. for use with `ParallelRunner`) by removing the unserializable object.
 - Implemented Snowflake's [local testing framework](https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally) for testing purposes.
 - Improved the dependency management for Spark-based datasets by refactoring the Spark and Databricks utility functions used across the datasets.
 - Added deprecation warning for `tracking.MetricsDataset` and `tracking.JSONDataset`.
diff --git a/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py b/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py
index a38b3b82c..8b21102d3 100644
--- a/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py
+++ b/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py
@@ -6,7 +6,7 @@
 
 import copy
 from pathlib import PurePosixPath
-from typing import Any, NoReturn
+from typing import Any, ClassVar, NoReturn
 
 import fsspec
 import pandas as pd
@@ -22,8 +22,10 @@
     validate_on_forbidden_chars,
 )
 
+from kedro_datasets._utils import ConnectionMixin
 
-class GBQTableDataset(AbstractDataset[None, pd.DataFrame]):
+
+class GBQTableDataset(ConnectionMixin, AbstractDataset[None, pd.DataFrame]):
     """``GBQTableDataset`` loads and saves data from/to Google BigQuery. It uses
     pandas-gbq to read and write from/to BigQuery table.
 
@@ -68,6 +70,8 @@ class GBQTableDataset(AbstractDataset[None, pd.DataFrame]):
     DEFAULT_LOAD_ARGS: dict[str, Any] = {}
     DEFAULT_SAVE_ARGS: dict[str, Any] = {"progress_bar": False}
 
+    _CONNECTION_GROUP: ClassVar[str] = "bigquery"
+
     def __init__(  # noqa: PLR0913
         self,
         *,
@@ -114,18 +118,14 @@ def __init__(  # noqa: PLR0913
         self._validate_location()
         validate_on_forbidden_chars(dataset=dataset, table_name=table_name)
 
-        if isinstance(credentials, dict):
-            credentials = Credentials(**credentials)
-
         self._dataset = dataset
         self._table_name = table_name
         self._project_id = project
-        self._credentials = credentials
-        self._client = bigquery.Client(
-            project=self._project_id,
-            credentials=self._credentials,
-            location=self._save_args.get("location"),
-        )
+        self._connection_config = {
+            "project": self._project_id,
+            "credentials": credentials,
+            "location": self._save_args.get("location"),
+        }
 
         self.metadata = metadata
 
@@ -137,12 +137,24 @@ def _describe(self) -> dict[str, Any]:
             "save_args": self._save_args,
         }
 
+    def _connect(self) -> bigquery.Client:
+        credentials = self._connection_config["credentials"]
+        if isinstance(credentials, dict):
+            # Only create `Credentials` object once for consistent hash.
+            credentials = Credentials(**credentials)
+
+        return bigquery.Client(
+            project=self._connection_config["project"],
+            credentials=credentials,
+            location=self._connection_config["location"],
+        )
+
     def load(self) -> pd.DataFrame:
         sql = f"select * from {self._dataset}.{self._table_name}"  # nosec
         self._load_args.setdefault("query_or_table", sql)
         return pd_gbq.read_gbq(
             project_id=self._project_id,
-            credentials=self._credentials,
+            credentials=self._connection._credentials,
             **self._load_args,
         )
 
@@ -151,14 +163,14 @@ def save(self, data: pd.DataFrame) -> None:
             dataframe=data,
             destination_table=f"{self._dataset}.{self._table_name}",
             project_id=self._project_id,
-            credentials=self._credentials,
+            credentials=self._connection._credentials,
             **self._save_args,
         )
 
     def _exists(self) -> bool:
-        table_ref = self._client.dataset(self._dataset).table(self._table_name)
+        table_ref = self._connection.dataset(self._dataset).table(self._table_name)
         try:
-            self._client.get_table(table_ref)
+            self._connection.get_table(table_ref)
             return True
         except NotFound:
             return False
@@ -268,11 +280,6 @@ def __init__(  # noqa: PLR0913
             credentials = Credentials(**credentials)
 
         self._credentials = credentials
-        self._client = bigquery.Client(
-            project=self._project_id,
-            credentials=self._credentials,
-            location=self._load_args.get("location"),
-        )
 
         # load sql query from arg or from file
         if sql:
diff --git a/kedro-datasets/tests/pandas/test_gbq_dataset.py b/kedro-datasets/tests/pandas/test_gbq_dataset.py
index 19767f15b..03f7f5fab 100644
--- a/kedro-datasets/tests/pandas/test_gbq_dataset.py
+++ b/kedro-datasets/tests/pandas/test_gbq_dataset.py
@@ -141,6 +141,7 @@ def test_save_load_data(self, gbq_dataset, dummy_dataframe, mocker):
         )
         mocked_read_gbq.return_value = dummy_dataframe
         mocked_df = mocker.Mock()
+        gbq_dataset._connection._credentials = None
         gbq_dataset.save(mocked_df)
         loaded_data = gbq_dataset.load()
 
@@ -205,8 +206,8 @@ def test_credentials_propagation(self, mocker):
             credentials=credentials,
             project=PROJECT,
         )
+        dataset.exists()  # Do something to trigger the client creation.
 
-        assert dataset._credentials == credentials_obj
         mocked_credentials.assert_called_once_with(**credentials)
         mocked_bigquery.Client.assert_called_once_with(
             project=PROJECT, credentials=credentials_obj, location=None
         )
@@ -238,7 +239,6 @@ def test_credentials_propagation(self, mocker):
             "kedro_datasets.pandas.gbq_dataset.Credentials",
             return_value=credentials_obj,
         )
-        mocked_bigquery = mocker.patch("kedro_datasets.pandas.gbq_dataset.bigquery")
 
         dataset = GBQQueryDataset(
             sql=SQL_QUERY,
@@ -248,9 +248,6 @@
 
         assert dataset._credentials == credentials_obj
         mocked_credentials.assert_called_once_with(**credentials)
-        mocked_bigquery.Client.assert_called_once_with(
-            project=PROJECT, credentials=credentials_obj, location=None
-        )
 
     def test_load(self, mocker, gbq_sql_dataset, dummy_dataframe):
         """Test `load` method invocation"""
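
The mechanics behind the patch are worth spelling out. `GBQTableDataset` now stores only a plain `_connection_config` dict and defers client creation to `_connect`, which `ConnectionMixin` (imported from `kedro_datasets._utils`) invokes the first time `self._connection` is accessed. The sketch below illustrates how such a lazy-connection mixin can work. It is a minimal sketch, not the actual kedro-datasets implementation: the class-level `_clients` cache, the repr-based cache key, and the names `LazyConnectionMixin` and `DemoGBQTableDataset` are all assumptions made for illustration.

import pickle
from typing import Any, ClassVar


class LazyConnectionMixin:
    """Hypothetical stand-in for `kedro_datasets._utils.ConnectionMixin`.

    Clients are created on first use and cached at class level, keyed by
    the connection group and config, so they are never stored on the
    dataset instance and never get pickled along with it.
    """

    _CONNECTION_GROUP: ClassVar[str] = "base"
    _clients: ClassVar[dict[tuple[str, str], Any]] = {}

    _connection_config: dict[str, Any]  # set by the subclass's __init__

    def _connect(self) -> Any:
        raise NotImplementedError  # subclasses build the real client here

    @property
    def _connection(self) -> Any:
        # The key must be stable across instances with the same config,
        # which is one reason the patch converts a credentials dict into a
        # `Credentials` object only inside `_connect`: the raw dict in
        # `_connection_config` has a deterministic representation, while a
        # fresh `Credentials` object created per instance would not.
        key = (
            self._CONNECTION_GROUP,
            repr(sorted(self._connection_config.items())),
        )
        if key not in self._clients:
            self._clients[key] = self._connect()
        return self._clients[key]


class DemoGBQTableDataset(LazyConnectionMixin):
    """Hypothetical, simplified `GBQTableDataset` analogue."""

    _CONNECTION_GROUP: ClassVar[str] = "bigquery"

    def __init__(self, project: str) -> None:
        # Only plain, picklable configuration lands on the instance; no
        # client is created here, so construction never touches BigQuery.
        self._connection_config = {"project": project}

    def _connect(self) -> Any:
        return object()  # stand-in for bigquery.Client(project=...)


ds = DemoGBQTableDataset(project="my-project")
client = ds._connection  # first access creates and caches the client

# Pickling works because the client lives in the class-level cache rather
# than in `ds.__dict__`; this is what `ParallelRunner` needs.
restored = pickle.loads(pickle.dumps(ds))
assert restored._connection is client  # same cached client

Under these assumptions, two details of the patch fall out naturally: `dataset.exists()` in the test is simply the first call that touches `self._connection`, which is why the mocked `bigquery.Client` assertions only hold after it; and the "consistent hash" comment in `_connect` reflects that keeping the raw credentials dict in the stored config keeps any cache key derived from it consistent across dataset instances.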