Skip to content

Commit

Permalink
fix(datasets): Refactor TensorFlowModelDataset to DataSet (kedro-org#186
Browse files Browse the repository at this point in the history
)

* refactor TensorFlowModelDataset to Set

matching consistency of all other kedro-datasets, DataSet should be camelcase. will be reverted in 0.19.0

Signed-off-by: BrianCechmanek <brian@hazy.com>

* Introdcuing .gitpod.yml to kedro-plugins (kedro-org#185)

Currently opening gitpod will installed a Python 3.11 which breaks everything because we don't support it set. This PR introduce a simple .gitpod.yml to get it started.

Signed-off-by: BrianCechmanek <brian@hazy.com>

* sync APIDataSet  from kedro's `develop` (kedro-org#184)

* Update APIDataSet

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Sync ParquetDataSet

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Sync Test

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Linting

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Revert Unnecessary ParquetDataSet Changes

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Sync release notes

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

---------

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>
Signed-off-by: BrianCechmanek <brian@hazy.com>

* [kedro-datasets] Bump version of `tables` in `test_requirements.txt`  (kedro-org#182)

* bump tables version and remove step in workflow

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* revert version for linux

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* change version to 3.7

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* remove extra line

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

---------

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
Signed-off-by: BrianCechmanek <brian@hazy.com>

* refactor tensorflowModelDataset casing in datasets setup.py

Signed-off-by: BrianCechmanek <brian@hazy.com>

* add tensorflowmodeldataset bugfix to release.md

Signed-off-by: BrianCechmanek <brian@hazy.com>

* Update all the doc reference with TensorFlowModelDataSet

Signed-off-by: Nok <nok.lam.chan@quantumblack.com>

---------

Signed-off-by: BrianCechmanek <brian@hazy.com>
Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>
Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
Signed-off-by: Nok <nok.lam.chan@quantumblack.com>
Co-authored-by: Nok Lam Chan <mediumnok@gmail.com>
Co-authored-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com>
Co-authored-by: Nok <nok.lam.chan@quantumblack.com>
Signed-off-by: Danny Farah <danny_farah@mckinsey.com>
  • Loading branch information
4 people authored and dannyrfar committed May 3, 2023
1 parent fdd205c commit c0dd796
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 29 deletions.
7 changes: 7 additions & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
## Bug fixes and other changes
* Relaxed `delta-spark` upper bound to allow compatibility with Spark 3.1.x and 3.2.x.

# Release 1.2.1:

## Major features and improvements:

## Bug fixes and other changes
* Renamed `TensorFlowModelDataset` to `TensorFlowModelDataSet` to be consistent with all other plugins in kedro-datasets.

# Release 1.2.0:

## Major features and improvements:
Expand Down
8 changes: 4 additions & 4 deletions kedro-datasets/kedro_datasets/tensorflow/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# TensorFlowModelDataset
# TensorFlowModelDataSet

``TensorflowModelDataset`` loads and saves TensorFlow models.
The underlying functionality is supported by, and passes input arguments to TensorFlow 2.X load_model and save_model methods. Only TF2 is currently supported for saving and loading, V1 requires HDF5 and serialises differently.
Expand All @@ -8,9 +8,9 @@ The underlying functionality is supported by, and passes input arguments to Tens
import numpy as np
import tensorflow as tf

from kedro_datasets.tensorflow import TensorFlowModelDataset
from kedro_datasets.tensorflow import TensorFlowModelDataSet

data_set = TensorFlowModelDataset("tf_model_dirname")
data_set = TensorFlowModelDataSet("tf_model_dirname")

model = tf.keras.Model()
predictions = model.predict([...])
Expand All @@ -25,7 +25,7 @@ np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
#### Example catalog.yml:
```yaml
example_tensorflow_data:
type: tensorflow.TensorFlowModelDataset
type: tensorflow.TensorFlowModelDataSet
filepath: data/08_reporting/tf_model_dirname
load_args:
tf_device: "/CPU:0" # optional
Expand Down
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Provides I/O for TensorFlow Models."""

__all__ = ["TensorFlowModelDataset"]
__all__ = ["TensorFlowModelDataSet"]

from contextlib import suppress

with suppress(ImportError):
from .tensorflow_model_dataset import TensorFlowModelDataset
from .tensorflow_model_dataset import TensorFlowModelDataSet
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""``TensorflowModelDataset`` is a data set implementation which can save and load
"""``TensorFlowModelDataSet`` is a data set implementation which can save and load
TensorFlow models.
"""
import copy
Expand All @@ -19,8 +19,8 @@
TEMPORARY_H5_FILE = "tmp_tensorflow_model.h5"


class TensorFlowModelDataset(AbstractVersionedDataSet[tf.keras.Model, tf.keras.Model]):
"""``TensorflowModelDataset`` loads and saves TensorFlow models.
class TensorFlowModelDataSet(AbstractVersionedDataSet[tf.keras.Model, tf.keras.Model]):
"""``TensorFlowModelDataSet`` loads and saves TensorFlow models.
The underlying functionality is supported by, and passes input arguments through to,
TensorFlow 2.X load_model and save_model methods.
Expand All @@ -31,7 +31,7 @@ class TensorFlowModelDataset(AbstractVersionedDataSet[tf.keras.Model, tf.keras.M
.. code-block:: yaml
tensorflow_model:
type: tensorflow.TensorFlowModelDataset
type: tensorflow.TensorFlowModelDataSet
filepath: data/06_models/tensorflow_model.h5
load_args:
compile: False
Expand All @@ -45,11 +45,11 @@ class TensorFlowModelDataset(AbstractVersionedDataSet[tf.keras.Model, tf.keras.M
data_catalog.html#use-the-data-catalog-with-the-code-api>`_:
::
>>> from kedro_datasets.tensorflow import TensorFlowModelDataset
>>> from kedro_datasets.tensorflow import TensorFlowModelDataSet
>>> import tensorflow as tf
>>> import numpy as np
>>>
>>> data_set = TensorFlowModelDataset("data/06_models/tensorflow_model.h5")
>>> data_set = TensorFlowModelDataSet("data/06_models/tensorflow_model.h5")
>>> model = tf.keras.Model()
>>> predictions = model.predict([...])
>>>
Expand All @@ -73,7 +73,7 @@ def __init__(
credentials: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
) -> None:
"""Creates a new instance of ``TensorFlowModelDataset``.
"""Creates a new instance of ``TensorFlowModelDataSet``.
Args:
filepath: Filepath in POSIX format to a TensorFlow model directory prefixed with a
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _collect_requirements(requires):
}
svmlight_require = {"svmlight.SVMLightDataSet": ["scikit-learn~=1.0.2", "scipy~=1.7.3"]}
tensorflow_require = {
"tensorflow.TensorflowModelDataset": [
"tensorflow.TensorFlowModelDataSet": [
# currently only TensorFlow V2 supported for saving and loading.
# V1 requires HDF5 and serialises differently
"tensorflow~=2.0; platform_system != 'Darwin' or platform_machine != 'arm64'",
Expand Down
30 changes: 15 additions & 15 deletions kedro-datasets/tests/tensorflow/test_tensorflow_model_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from s3fs import S3FileSystem


# In this test module, we wrap tensorflow and TensorFlowModelDataset imports into a module-scoped
# In this test module, we wrap tensorflow and TensorFlowModelDataSet imports into a module-scoped
# fixtures to avoid them being evaluated immediately when a new test process is spawned.
# Specifically:
# - ParallelRunner spawns a new subprocess.
Expand All @@ -34,9 +34,9 @@ def tf():

@pytest.fixture(scope="module")
def tensorflow_model_dataset():
from kedro_datasets.tensorflow import TensorFlowModelDataset
from kedro_datasets.tensorflow import TensorFlowModelDataSet

return TensorFlowModelDataset
return TensorFlowModelDataSet


@pytest.fixture
Expand Down Expand Up @@ -134,7 +134,7 @@ def call(self, inputs, training=None, mask=None): # pragma: no cover
return model


class TestTensorFlowModelDataset:
class TestTensorFlowModelDataSet:
"""No versioning passed to creator"""

def test_save_and_load(self, tf_model_dataset, dummy_tf_base_model, dummy_x_test):
Expand All @@ -152,7 +152,7 @@ def test_save_and_load(self, tf_model_dataset, dummy_tf_base_model, dummy_x_test
def test_load_missing_model(self, tf_model_dataset):
"""Test error message when trying to load missing model."""
pattern = (
r"Failed while loading data from data set TensorFlowModelDataset\(.*\)"
r"Failed while loading data from data set TensorFlowModelDataSet\(.*\)"
)
with pytest.raises(DataSetError, match=pattern):
tf_model_dataset.load()
Expand All @@ -166,7 +166,7 @@ def test_exists(self, tf_model_dataset, dummy_tf_base_model):
def test_hdf5_save_format(
self, dummy_tf_base_model, dummy_x_test, filepath, tensorflow_model_dataset
):
"""Test TensorflowModelDataset can save TF graph models in HDF5 format"""
"""Test TensorFlowModelDataSet can save TF graph models in HDF5 format"""
hdf5_dataset = tensorflow_model_dataset(
filepath=filepath, save_args={"save_format": "h5"}
)
Expand All @@ -187,7 +187,7 @@ def test_unused_subclass_model_hdf5_save_format(
filepath,
tensorflow_model_dataset,
):
"""Test TensorflowModelDataset cannot save subclassed user models in HDF5 format
"""Test TensorFlowModelDataSet cannot save subclassed user models in HDF5 format
Subclassed model
Expand Down Expand Up @@ -277,8 +277,8 @@ def test_save_and_overwrite_existing_model(
assert len(dummy_tf_base_model_new.layers) == len(reloaded.layers)


class TestTensorFlowModelDatasetVersioned:
"""Test suite with versioning argument passed into TensorFlowModelDataset creator"""
class TestTensorFlowModelDataSetVersioned:
"""Test suite with versioning argument passed into TensorFlowModelDataSet creator"""

@pytest.mark.parametrize(
"load_version,save_version",
Expand Down Expand Up @@ -320,7 +320,7 @@ def test_hdf5_save_format(
load_version,
save_version,
):
"""Test versioned TensorflowModelDataset can save TF graph models in
"""Test versioned TensorFlowModelDataSet can save TF graph models in
HDF5 format"""
hdf5_dataset = tensorflow_model_dataset(
filepath=filepath,
Expand All @@ -340,7 +340,7 @@ def test_prevent_overwrite(self, dummy_tf_base_model, versioned_tf_model_dataset
corresponding file for a given save version already exists."""
versioned_tf_model_dataset.save(dummy_tf_base_model)
pattern = (
r"Save path \'.+\' for TensorFlowModelDataset\(.+\) must "
r"Save path \'.+\' for TensorFlowModelDataSet\(.+\) must "
r"not exist if versioning is enabled\."
)
with pytest.raises(DataSetError, match=pattern):
Expand All @@ -362,7 +362,7 @@ def test_save_version_warning(
the subsequent load path."""
pattern = (
rf"Save version '{save_version}' did not match load version '{load_version}' "
rf"for TensorFlowModelDataset\(.+\)"
rf"for TensorFlowModelDataSet\(.+\)"
)
with pytest.warns(UserWarning, match=pattern):
versioned_tf_model_dataset.save(dummy_tf_base_model)
Expand All @@ -383,7 +383,7 @@ def test_exists(self, versioned_tf_model_dataset, dummy_tf_base_model):

def test_no_versions(self, versioned_tf_model_dataset):
"""Check the error if no versions are available for load."""
pattern = r"Did not find any versions for TensorFlowModelDataset\(.+\)"
pattern = r"Did not find any versions for TensorFlowModelDataSet\(.+\)"
with pytest.raises(DataSetError, match=pattern):
versioned_tf_model_dataset.load()

Expand All @@ -408,7 +408,7 @@ def test_versioning_existing_dataset(
self, tf_model_dataset, versioned_tf_model_dataset, dummy_tf_base_model
):
"""Check behavior when attempting to save a versioned dataset on top of an
already existing (non-versioned) dataset. Note: because TensorFlowModelDataset
already existing (non-versioned) dataset. Note: because TensorFlowModelDataSet
saves to a directory even if non-versioned, an error is not expected."""
tf_model_dataset.save(dummy_tf_base_model)
assert tf_model_dataset.exists()
Expand All @@ -425,7 +425,7 @@ def test_save_and_load_with_device(
load_version,
save_version,
):
"""Test versioned TensorflowModelDataset can load models using an explicit tf_device"""
"""Test versioned TensorFlowModelDataSet can load models using an explicit tf_device"""
hdf5_dataset = tensorflow_model_dataset(
filepath=filepath,
load_args={"tf_device": "/CPU:0"},
Expand Down

0 comments on commit c0dd796

Please sign in to comment.