diff --git a/docs/overview_2_features.rst b/docs/overview_2_features.rst index 3c51eac..6959a37 100644 --- a/docs/overview_2_features.rst +++ b/docs/overview_2_features.rst @@ -86,6 +86,24 @@ Refer to the :doc:`api` section to discover the methods currently available. Please note that ``xeofs`` is in its developmental phase. If there's a specific method you'd like to see included, we encourage you to open an issue on `GitHub`_. +Model Serialization +------------------- + +``xeofs`` models offer convenient ``save()`` and ``load()`` methods for serializing +fitted models to a portable format. + +.. code-block:: python + + from xeofs.models import EOF + + model = EOF() + model.fit(data, dim="time") + model.save("my_model.zarr") + + # Later, you can load the model + loaded_model = EOF.load("my_model.zarr") + + Input Data Compatibility ------------------------ diff --git a/poetry.lock b/poetry.lock index 0732ab3..e7a9222 100644 --- a/poetry.lock +++ b/poetry.lock @@ -25,6 +25,16 @@ files = [ {file = "alabaster-0.7.13.tar.gz", hash = "sha256:a27a4a084d5e690e16e01e03ad2b2e552c61a65469419b907243193de1a84ae2"}, ] +[[package]] +name = "asciitree" +version = "0.3.3" +description = "Draws ASCII trees." +optional = false +python-versions = "*" +files = [ + {file = "asciitree-0.3.3.tar.gz", hash = "sha256:4aa4b9b649f85e3fcb343363d97564aa1fb62e249677f2e18a96765145cc0f6e"}, +] + [[package]] name = "attrs" version = "23.1.0" @@ -719,6 +729,17 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "fasteners" +version = "0.19" +description = "A python package that provides useful locks" +optional = false +python-versions = ">=3.6" +files = [ + {file = "fasteners-0.19-py3-none-any.whl", hash = "sha256:758819cb5d94cdedf4e836988b74de396ceacb8e2794d21f82d131fd9ee77237"}, + {file = "fasteners-0.19.tar.gz", hash = "sha256:b4f37c3ac52d8a445af3a66bce57b33b5e90b97c696b7b984f530cf8f0ded09c"}, +] + [[package]] name = "fastjsonschema" version = "2.18.1" @@ -1436,6 +1457,46 @@ files = [ llvmlite = "==0.40.*" numpy = ">=1.21,<1.25" +[[package]] +name = "numcodecs" +version = "0.12.1" +description = "A Python package providing buffer compression and transformation codecs for use in data storage and communication applications." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "numcodecs-0.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d37f628fe92b3699e65831d5733feca74d2e33b50ef29118ffd41c13c677210e"}, + {file = "numcodecs-0.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:941b7446b68cf79f089bcfe92edaa3b154533dcbcd82474f994b28f2eedb1c60"}, + {file = "numcodecs-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e79bf9d1d37199ac00a60ff3adb64757523291d19d03116832e600cac391c51"}, + {file = "numcodecs-0.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:82d7107f80f9307235cb7e74719292d101c7ea1e393fe628817f0d635b7384f5"}, + {file = "numcodecs-0.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eeaf42768910f1c6eebf6c1bb00160728e62c9343df9e2e315dc9fe12e3f6071"}, + {file = "numcodecs-0.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:135b2d47563f7b9dc5ee6ce3d1b81b0f1397f69309e909f1a35bb0f7c553d45e"}, + {file = "numcodecs-0.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a191a8e347ecd016e5c357f2bf41fbcb026f6ffe78fff50c77ab12e96701d155"}, + {file = "numcodecs-0.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:21d8267bd4313f4d16f5b6287731d4c8ebdab236038f29ad1b0e93c9b2ca64ee"}, + {file = "numcodecs-0.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2f84df6b8693206365a5b37c005bfa9d1be486122bde683a7b6446af4b75d862"}, + {file = "numcodecs-0.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:760627780a8b6afdb7f942f2a0ddaf4e31d3d7eea1d8498cf0fd3204a33c4618"}, + {file = "numcodecs-0.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c258bd1d3dfa75a9b708540d23b2da43d63607f9df76dfa0309a7597d1de3b73"}, + {file = "numcodecs-0.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e04649ea504aff858dbe294631f098fbfd671baf58bfc04fc48d746554c05d67"}, + {file = "numcodecs-0.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:caf1a1e6678aab9c1e29d2109b299f7a467bd4d4c34235b1f0e082167846b88f"}, + {file = "numcodecs-0.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c17687b1fd1fef68af616bc83f896035d24e40e04e91e7e6dae56379eb59fe33"}, + {file = "numcodecs-0.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29dfb195f835a55c4d490fb097aac8c1bcb96c54cf1b037d9218492c95e9d8c5"}, + {file = "numcodecs-0.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:2f1ba2f4af3fd3ba65b1bcffb717fe65efe101a50a91c368f79f3101dbb1e243"}, + {file = "numcodecs-0.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2fbb12a6a1abe95926f25c65e283762d63a9bf9e43c0de2c6a1a798347dfcb40"}, + {file = "numcodecs-0.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f2207871868b2464dc11c513965fd99b958a9d7cde2629be7b2dc84fdaab013b"}, + {file = "numcodecs-0.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abff3554a6892a89aacf7b642a044e4535499edf07aeae2f2e6e8fc08c9ba07f"}, + {file = "numcodecs-0.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef964d4860d3e6b38df0633caf3e51dc850a6293fd8e93240473642681d95136"}, + {file = "numcodecs-0.12.1.tar.gz", hash = "sha256:05d91a433733e7eef268d7e80ec226a0232da244289614a8f3826901aec1098e"}, +] + +[package.dependencies] +numpy = ">=1.7" + +[package.extras] +docs = ["mock", "numpydoc", "sphinx (<7.0.0)", "sphinx-issues"] +msgpack = ["msgpack"] +test = ["coverage", "flake8", "pytest", "pytest-cov"] +test-extras = ["importlib-metadata"] +zfpy = ["zfpy (>=1.0.0)"] + [[package]] name = "numpy" version = "1.24.4" @@ -2961,6 +3022,21 @@ io = ["cftime", 
"fsspec", "h5netcdf", "netCDF4", "pooch", "pydap", "scipy", "zar parallel = ["dask[complete]"] viz = ["matplotlib", "nc-time-axis", "seaborn"] +[[package]] +name = "xarray-datatree" +version = "0.0.13" +description = "Hierarchical tree-like data structures for xarray" +optional = false +python-versions = ">=3.9" +files = [ + {file = "xarray-datatree-0.0.13.tar.gz", hash = "sha256:f42bd519cab8754eb8a98749464846893b59560318520c45212e85c46af692c9"}, + {file = "xarray_datatree-0.0.13-py3-none-any.whl", hash = "sha256:b5c92339339e58f029107fd3c50478adb1dfd1316eaa628d1e0e2e8a3e7a079a"}, +] + +[package.dependencies] +packaging = "*" +xarray = ">=2022.6.0" + [[package]] name = "xattr" version = "0.10.1" @@ -3045,6 +3121,27 @@ files = [ [package.dependencies] cffi = ">=1.0" +[[package]] +name = "zarr" +version = "2.16.1" +description = "An implementation of chunked, compressed, N-dimensional arrays for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zarr-2.16.1-py3-none-any.whl", hash = "sha256:de4882433ccb5b42cc1ec9872b95e64ca3a13581424666b28ed265ad76c7056f"}, + {file = "zarr-2.16.1.tar.gz", hash = "sha256:4276cf4b4a653431042cd53ff2282bc4d292a6842411e88529964504fb073286"}, +] + +[package.dependencies] +asciitree = "*" +fasteners = "*" +numcodecs = ">=0.10.0" +numpy = ">=1.20,<1.21.0 || >1.21.0" + +[package.extras] +docs = ["numcodecs[msgpack]", "numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx-issues", "sphinx-rtd-theme"] +jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"] + [[package]] name = "zipp" version = "3.17.0" @@ -3063,4 +3160,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "498459979f01cb42a137694dc9e99ec56fe8cc644569416dab73f21b3956c164" +content-hash = "d6a8f52cdfedd044cf44cd017d5cf8d802c180f9355b9e3fd1007e2b06f7abdc" diff --git a/pyproject.toml b/pyproject.toml index 42462a9..73101b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,14 @@ netCDF4 = "^1.5.7" numba = "^0.57" typing-extensions = "^4.8.0" + +[tool.poetry.group.io] +optional = false + +[tool.poetry.group.io.dependencies] +zarr = ">=2.0.0" +xarray-datatree = ">=0.0.5" + [tool.poetry.group.dev] optional = true diff --git a/tests/models/test_eof.py b/tests/models/test_eof.py index 37c49f6..b9734e1 100644 --- a/tests/models/test_eof.py +++ b/tests/models/test_eof.py @@ -442,3 +442,51 @@ def test_inverse_transform(dim, mock_data_array): # Check that the reconstructed data has the same dimensions as the original data assert set(reconstructed_data.dims) == set(mock_data_array.dims) + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_save_load(dim, mock_data_array, tmp_path): + """Test save/load methods in EOF class, ensuring that we can + roundtrip the model and get the same results when transforming + data.""" + original = EOF() + original.fit(mock_data_array, dim) + + # Save the EOF model + original.save(tmp_path / "eof.zarr") + + # Check that the EOF model has been saved + assert (tmp_path / "eof.zarr").exists() + + # Recreate the model from saved file + loaded = EOF.load(tmp_path / "eof.zarr") + + # Check that the params and DataContainer objects match + assert original.get_params() == loaded.get_params() + assert all([key in loaded.data for key in original.data]) + assert all( + [ + loaded.data._allow_compute[key] == original.data._allow_compute[key] 
+ for key in original.data + ] + ) + + # Test that the recreated model can be used to transform new data + assert np.allclose( + original.scores(), loaded.transform(mock_data_array), rtol=1e-3, atol=1e-3 + ) + + # Enhancement: the loaded model should also be able to inverse_transform new data + # assert np.allclose( + # original.inverse_transform(original.scores()), + # loaded.inverse_transform(loaded.scores()), + # rtol=1e-3, + # atol=1e-3, + # ) diff --git a/tests/models/test_eof_rotator.py b/tests/models/test_eof_rotator.py index 3fb26eb..e262e00 100644 --- a/tests/models/test_eof_rotator.py +++ b/tests/models/test_eof_rotator.py @@ -192,3 +192,53 @@ def test_compute(eof_model_delayed, compute): assert data_is_dask(eof_rotator.data["explained_variance"]) assert data_is_dask(eof_rotator.data["components"]) assert data_is_dask(eof_rotator.data["rotation_matrix"]) + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_save_load(dim, mock_data_array, tmp_path): + """Test save/load methods in EOF class, ensuring that we can + roundtrip the model and get the same results when transforming + data.""" + original_unrotated = EOF() + original_unrotated.fit(mock_data_array, dim) + + original = EOFRotator() + original.fit(original_unrotated) + + # Save the EOF model + original.save(tmp_path / "eof.zarr") + + # Check that the EOF model has been saved + assert (tmp_path / "eof.zarr").exists() + + # Recreate the model from saved file + loaded = EOFRotator.load(tmp_path / "eof.zarr") + + # Check that the params and DataContainer objects match + assert original.get_params() == loaded.get_params() + assert all([key in loaded.data for key in original.data]) + assert all( + [ + loaded.data._allow_compute[key] == original.data._allow_compute[key] + for key in original.data + ] + ) + + # Test that the recreated model can be used to transform new data + assert np.allclose( + original.scores(), loaded.transform(mock_data_array), rtol=1e-3, atol=1e-3 + ) + + # Enhancement: the loaded model should also be able to inverse_transform new data + # assert np.allclose( + # original.inverse_transform(original.scores()), + # loaded.inverse_transform(loaded.scores()), + # rtol=1e-2, + # ) diff --git a/tests/models/test_mca.py b/tests/models/test_mca.py index afa4c51..82ccc13 100644 --- a/tests/models/test_mca.py +++ b/tests/models/test_mca.py @@ -381,3 +381,54 @@ def test_compute(mock_dask_data_array, dim, compute): assert data_is_dask(mca_model.data["squared_covariance"]) assert data_is_dask(mca_model.data["components1"]) assert data_is_dask(mca_model.data["components2"]) + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_save_load(dim, mock_data_array, tmp_path): + """Test save/load methods in MCA class, ensuring that we can + roundtrip the model and get the same results when transforming + data.""" + original = MCA() + original.fit(mock_data_array, mock_data_array, dim) + + # Save the EOF model + original.save(tmp_path / "mca.zarr") + + # Check that the EOF model has been saved + assert (tmp_path / "mca.zarr").exists() + + # Recreate the model from saved file + loaded = MCA.load(tmp_path / "mca.zarr") + + # Check that the params and DataContainer objects match + assert original.get_params() == loaded.get_params() + assert all([key in loaded.data for key in original.data]) + assert all( + [ + loaded.data._allow_compute[key] == original.data._allow_compute[key] + for key in original.data + ] + 
) + + # Test that the recreated model can be used to transform new data + assert np.allclose( + original.scores(), + loaded.transform(mock_data_array, mock_data_array), + rtol=1e-3, + atol=1e-3, + ) + + # Enhancement: the loaded model should also be able to inverse_transform new data + # assert np.allclose( + # original.inverse_transform(original.scores()), + # loaded.inverse_transform(loaded.scores()), + # rtol=1e-3, + # atol=1e-3, + # ) diff --git a/tests/models/test_mca_rotator.py b/tests/models/test_mca_rotator.py index 0cd9c59..b0bd1b9 100644 --- a/tests/models/test_mca_rotator.py +++ b/tests/models/test_mca_rotator.py @@ -249,3 +249,57 @@ def test_compute(mca_model_delayed, compute): assert data_is_dask(mca_rotator.data["norm1"]) assert data_is_dask(mca_rotator.data["norm2"]) assert data_is_dask(mca_rotator.data["modes_sign"]) + + +@pytest.mark.parametrize( + "dim", + [ + (("time",)), + (("lat", "lon")), + (("lon", "lat")), + ], +) +def test_save_load(dim, mock_data_array, tmp_path): + """Test save/load methods in MCA class, ensuring that we can + roundtrip the model and get the same results when transforming + data.""" + original_unrotated = MCA() + original_unrotated.fit(mock_data_array, mock_data_array, dim) + + original = MCARotator() + original.fit(original_unrotated) + + # Save the EOF model + original.save(tmp_path / "mca.zarr") + + # Check that the EOF model has been saved + assert (tmp_path / "mca.zarr").exists() + + # Recreate the model from saved file + loaded = MCARotator.load(tmp_path / "mca.zarr") + + # Check that the params and DataContainer objects match + assert original.get_params() == loaded.get_params() + assert all([key in loaded.data for key in original.data]) + assert all( + [ + loaded.data._allow_compute[key] == original.data._allow_compute[key] + for key in original.data + ] + ) + + # Test that the recreated model can be used to transform new data + assert np.allclose( + original.scores(), + loaded.transform(data1=mock_data_array, data2=mock_data_array), + rtol=1e-3, + atol=1e-3, + ) + + # Enhancement: the loaded model should also be able to inverse_transform new data + # assert np.allclose( + # original.inverse_transform(original.scores()), + # loaded.inverse_transform(loaded.scores()), + # rtol=1e-3, + # atol=1e-3, + # ) diff --git a/tests/preprocessing/test_dataarray_stacker.py b/tests/preprocessing/test_dataarray_stacker.py index 5882f13..1eb531e 100644 --- a/tests/preprocessing/test_dataarray_stacker.py +++ b/tests/preprocessing/test_dataarray_stacker.py @@ -2,7 +2,7 @@ import numpy as np import xarray as xr -from xeofs.preprocessing.stacker import DataArrayStacker +from xeofs.preprocessing import Stacker from xeofs.utils.data_types import DataArray from ..conftest import generate_synthetic_dataarray from ..utilities import ( @@ -50,7 +50,7 @@ def test_fit_valid_dimension_names(sample_name, feature_name, data_params): data = generate_synthetic_dataarray(*data_params) all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataArrayStacker(sample_name=sample_name, feature_name=feature_name) + stacker = Stacker(sample_name=sample_name, feature_name=feature_name) stacker.fit(data, sample_dims, feature_dims) stacked_data = stacker.transform(data) reconstructed_data = stacker.inverse_transform_data(stacked_data) @@ -72,7 +72,7 @@ def test_fit_invalid_dimension_names(sample_name, feature_name, data_params): data = generate_synthetic_dataarray(*data_params) all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = 
DataArrayStacker(sample_name=sample_name, feature_name=feature_name) + stacker = Stacker(sample_name=sample_name, feature_name=feature_name) with pytest.raises(ValueError): stacker.fit(data, sample_dims, feature_dims) @@ -87,7 +87,7 @@ def test_fit(synthetic_dataarray): data = synthetic_dataarray all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataArrayStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) @@ -100,7 +100,7 @@ def test_transform(synthetic_dataarray): data = synthetic_dataarray all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataArrayStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) transformed_data = stacker.transform(data) transformed_data2 = stacker.transform(data) @@ -124,7 +124,7 @@ def test_transform_invalid(synthetic_dataarray): data = synthetic_dataarray all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataArrayStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) with pytest.raises(ValueError): stacker.transform(data.isel(feature0=slice(0, 2))) @@ -139,7 +139,7 @@ def test_fit_transform(synthetic_dataarray): data = synthetic_dataarray all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataArrayStacker() + stacker = Stacker() transformed_data = stacker.fit_transform(data, sample_dims, feature_dims) is_dask_before = data_is_dask(data) @@ -160,7 +160,7 @@ def test_invserse_transform_data(synthetic_dataarray): data = synthetic_dataarray all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataArrayStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) stacked_data = stacker.transform(data) unstacked_data = stacker.inverse_transform_data(stacked_data) @@ -185,7 +185,7 @@ def test_invserse_transform_components(synthetic_dataarray): data: DataArray = synthetic_dataarray all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataArrayStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) stacked_data = stacker.transform(data) @@ -212,7 +212,7 @@ def test_invserse_transform_scores(synthetic_dataarray): data: DataArray = synthetic_dataarray all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataArrayStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) stacked_data = stacker.transform(data) diff --git a/tests/preprocessing/test_dataset_stacker.py b/tests/preprocessing/test_dataset_stacker.py index e35d0e1..5eed6c2 100644 --- a/tests/preprocessing/test_dataset_stacker.py +++ b/tests/preprocessing/test_dataset_stacker.py @@ -2,7 +2,7 @@ import xarray as xr import numpy as np -from xeofs.preprocessing.stacker import DataSetStacker +from xeofs.preprocessing import Stacker from xeofs.utils.data_types import DataSet, DataArray from ..conftest import generate_synthetic_dataset from ..utilities import ( @@ -56,7 +56,7 @@ def test_fit_valid_dimension_names(sample_name, feature_name, data_params): data = generate_synthetic_dataset(*data_params) all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker(sample_name=sample_name, feature_name=feature_name) + stacker = Stacker(sample_name=sample_name, feature_name=feature_name) stacker.fit(data, sample_dims, feature_dims) stacked_data = stacker.transform(data) reconstructed_data = stacker.inverse_transform_data(stacked_data) @@ -81,7 +81,7 @@ def test_fit_invalid_dimension_names(sample_name, feature_name, 
data_params): data = generate_synthetic_dataset(*data_params) all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker(sample_name=sample_name, feature_name=feature_name) + stacker = Stacker(sample_name=sample_name, feature_name=feature_name) with pytest.raises(ValueError): stacker.fit(data, sample_dims, feature_dims) @@ -96,7 +96,7 @@ def test_fit(synthetic_dataset): data = synthetic_dataset all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) @@ -109,7 +109,7 @@ def test_transform(synthetic_dataset): data = synthetic_dataset all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) transformed_data = stacker.transform(data) transformed_data2 = stacker.transform(data) @@ -133,7 +133,7 @@ def test_transform_invalid(synthetic_dataset): data = synthetic_dataset all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) with pytest.raises(ValueError): stacker.transform(data.isel(feature0=slice(0, 2))) @@ -148,7 +148,7 @@ def test_fit_transform(synthetic_dataset): data = synthetic_dataset all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker() + stacker = Stacker() transformed_data = stacker.fit_transform(data, sample_dims, feature_dims) is_dask_before = data_is_dask(data) @@ -169,7 +169,7 @@ def test_invserse_transform_data(synthetic_dataset): data = synthetic_dataset all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) stacked_data = stacker.transform(data) unstacked_data = stacker.inverse_transform_data(stacked_data) @@ -194,7 +194,7 @@ def test_invserse_transform_components(synthetic_dataset): data = synthetic_dataset all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) stacked_data = stacker.transform(data) @@ -222,7 +222,7 @@ def test_invserse_transform_scores(synthetic_dataset): data = synthetic_dataset all_dims, sample_dims, feature_dims = get_dims_from_data(data) - stacker = DataSetStacker() + stacker = Stacker() stacker.fit(data, sample_dims, feature_dims) stacked_data = stacker.transform(data) diff --git a/xeofs/models/_base_cross_model.py b/xeofs/models/_base_cross_model.py index 5075e6c..8a0b678 100644 --- a/xeofs/models/_base_cross_model.py +++ b/xeofs/models/_base_cross_model.py @@ -3,10 +3,13 @@ from abc import ABC, abstractmethod from datetime import datetime +import xarray as xr +from datatree import DataTree, open_datatree + from .eof import EOF from ..preprocessing.preprocessor import Preprocessor from ..data_container import DataContainer -from ..utils.data_types import DataObject, DataArray +from ..utils.data_types import DataObject, DataArray, DataSet from ..utils.xarray_utils import convert_to_dim_type from ..utils.sanity_checks import validate_input_type from .._version import __version__ @@ -248,3 +251,109 @@ def compute(self, verbose: bool = False): def get_params(self) -> Dict: """Get the model parameters.""" return self._params + + def serialize(self, save_data: bool = False) -> DataTree: + """Serialize a complete model with its preprocessors.""" + data = {} + for key, x in 
self.data.items(): + if self.data._allow_compute[key] or save_data: + data[key] = x.assign_attrs( + {"allow_compute": self.data._allow_compute[key]} + ) + else: + # create an empty placeholder array + data[key] = xr.DataArray().assign_attrs( + {"allow_compute": False, "placeholder": True} + ) + + # Store the DataContainer items as data_vars, and the model parameters as global attrs + ds_model = xr.Dataset(data, attrs=dict(params=self.get_params())) + # Set as the root node of the tree + dt = DataTree(data=ds_model, name=type(self).__name__) + + # Retrieve the tree representation of the preprocessor + dt["preprocessor1"] = self.preprocessor1.serialize_all() + dt["preprocessor2"] = self.preprocessor2.serialize_all() + dt.preprocessor1.parent = dt + dt.preprocessor2.parent = dt + + return dt + + def save( + self, + path: str, + save_data: bool = False, + **kwargs, + ): + """Save the model to zarr. + + Parameters + ---------- + path : str + Path to save the model zarr store. + save_data : str + Whether or not to save the full input data along with the fitted components. + **kwargs + Additional keyword arguments to pass to `DataTree.to_zarr()`. + + """ + dt = self.serialize(save_data=save_data) + + # Handle rotator models by separately serializing the original model + # and attaching as a child of the rotator model + if hasattr(self, "model"): + dt["model"] = self.model.serialize() + + dt.to_zarr(path, **kwargs) + + @classmethod + def deserialize(cls, dt: DataTree) -> Self: + """Deserialize the model and its preprocessors from a DataTree.""" + # Recreate the model with parameters set by root level attrs + model = cls(**dt.attrs["params"]) + + # Recreate the Preprocessors from their trees + model.preprocessor1 = Preprocessor.deserialize_all(dt.preprocessor1) + model.preprocessor2 = Preprocessor.deserialize_all(dt.preprocessor2) + + # Create the model's DataContainer from the root level data_vars + model.data = DataContainer({k: dt[k] for k in dt.data_vars}) + + return model + + @classmethod + def load(cls, path: str, **kwargs) -> Self: + """Load a saved model from zarr. + + Parameters + ---------- + path : str + Path to the saved model zarr store. + **kwargs + Additional keyword arguments to pass to `open_datatree()`. + + Returns + ------- + model : _BaseCrossModel + The loaded model. 
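+
+        Examples
+        --------
+        A hypothetical round trip for a cross-decomposition model (the
+        ``mca.zarr`` path and the ``data1``/``data2`` inputs are illustrative):
+
+        >>> model = MCA()
+        >>> model.fit(data1, data2, dim="time")
+        >>> model.save("mca.zarr")
+        >>> loaded = MCA.load("mca.zarr")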
+ + """ + dt = open_datatree(path, engine="zarr", **kwargs) + + model = cls.deserialize(dt) + + # Rebuild any attached model + if dt.get("model") is not None: + # Recreate the original model from its tree, assuming we should + # use the first base class of the current model + model.model = cls.__bases__[0].deserialize(dt.model) + + for key in model.data.keys(): + model.data._allow_compute[key] = model.data[key].attrs["allow_compute"] + model._validate_loaded_data(model.data[key]) + + return model + + def _validate_loaded_data(self, data: DataArray): + """Optionally check the loaded data for placeholders.""" + pass diff --git a/xeofs/models/_base_model.py b/xeofs/models/_base_model.py index e932d10..147072f 100644 --- a/xeofs/models/_base_model.py +++ b/xeofs/models/_base_model.py @@ -1,11 +1,21 @@ import warnings -from typing import Optional, Sequence, Hashable, Dict, Any, List, TypeVar, Tuple +from typing import ( + Optional, + Sequence, + Hashable, + Dict, + Any, + List, + TypeVar, + Tuple, +) from typing_extensions import Self from abc import ABC, abstractmethod from datetime import datetime import numpy as np import xarray as xr +from datatree import DataTree, open_datatree from ..preprocessing.preprocessor import Preprocessor from ..data_container import DataContainer @@ -221,7 +231,7 @@ def fit_transform( data: List[Data] | Data, dim: Sequence[Hashable] | Hashable, weights: Optional[List[Data] | Data] = None, - **kwargs + **kwargs, ) -> DataArray: """Fit the model to the input data and project the data onto the components. @@ -322,3 +332,106 @@ def compute(self, verbose: bool = False): def get_params(self) -> Dict[str, Any]: """Get the model parameters.""" return self._params + + def serialize(self, save_data: bool = False) -> DataTree: + """Serialize a complete model with its preprocessor.""" + data = {} + for key, x in self.data.items(): + if self.data._allow_compute[key] or save_data: + data[key] = x.assign_attrs( + {"allow_compute": self.data._allow_compute[key]} + ) + else: + # create an empty placeholder array + data[key] = xr.DataArray().assign_attrs( + {"allow_compute": False, "placeholder": True} + ) + + # Store the DataContainer items as data_vars, and the model parameters as global attrs + ds_model = xr.Dataset(data, attrs=dict(params=self.get_params())) + # Set as the root node of the tree + dt = DataTree(data=ds_model, name=type(self).__name__) + + # Retrieve the tree representation of the preprocessor + dt["preprocessor"] = self.preprocessor.serialize_all() + dt.preprocessor.parent = dt + + return dt + + def save( + self, + path: str, + save_data: bool = False, + **kwargs, + ): + """Save the model to zarr. + + Parameters + ---------- + path : str + Path to save the model zarr store. + save_data : str + Whether or not to save the full input data along with the fitted components. + **kwargs + Additional keyword arguments to pass to `DataTree.to_zarr()`. 
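+
+        Examples
+        --------
+        A minimal sketch mirroring the documentation example (the ``data``
+        object and file name are illustrative):
+
+        >>> model = EOF()
+        >>> model.fit(data, dim="time")
+        >>> model.save("my_model.zarr")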
+ + """ + dt = self.serialize(save_data=save_data) + + # Handle rotator models by separately serializing the original model + # and attaching as a child of the rotator model + if hasattr(self, "model"): + dt["model"] = self.model.serialize(save_data=save_data) + + dt.to_zarr(path, **kwargs) + + @classmethod + def deserialize(cls, dt: DataTree) -> Self: + """Deserialize the model and its preprocessors from a DataTree.""" + # Recreate the model with parameters set by root level attrs + model = cls(**dt.attrs["params"]) + + # Recreate the Preprocessor from its tree + model.preprocessor = Preprocessor.deserialize_all(dt.preprocessor) + + # Create the model's DataContainer from the root level data_vars + model.data = DataContainer({k: dt[k] for k in dt.data_vars}) + + return model + + @classmethod + def load(cls, path: str, **kwargs) -> Self: + """Load a saved model from zarr. + + Parameters + ---------- + path : str + Path to the saved model zarr store. + **kwargs + Additional keyword arguments to pass to `open_datatree()`. + + Returns + ------- + model : _BaseModel + The loaded model. + + """ + dt = open_datatree(path, engine="zarr", **kwargs) + + model = cls.deserialize(dt) + + # Rebuild any attached model + if dt.get("model") is not None: + # Recreate the original model from its tree, assuming we should + # use the first base class of the current model + model.model = cls.__bases__[0].deserialize(dt.model) + + for key in model.data.keys(): + model.data._allow_compute[key] = model.data[key].attrs["allow_compute"] + model._validate_loaded_data(model.data[key]) + + return model + + def _validate_loaded_data(self, data: DataArray): + """Optionally check the loaded data for placeholders.""" + pass diff --git a/xeofs/models/mca.py b/xeofs/models/mca.py index c2ebebc..7ac0f49 100644 --- a/xeofs/models/mca.py +++ b/xeofs/models/mca.py @@ -1,3 +1,4 @@ +import warnings from typing import Tuple, Optional, Sequence, Dict from typing_extensions import Self @@ -544,6 +545,17 @@ def heterogeneous_patterns(self, correction=None, alpha=0.05): return (patterns1, patterns2), (pvals1, pvals2) + def _validate_loaded_data(self, data: xr.DataArray): + if data.attrs.get("placeholder"): + warnings.warn( + f"The input data field '{data.name}' was not saved, which will produce" + " empty results when calling `homogeneous_patterns()` or " + "`heterogeneous_patterns()`. To avoid this warning, you can save the" + " model with `save_data=True`, or add the data manually by running" + " it through the model's `preprocessor.transform()` method and then" + " attaching it with `data.add()`." + ) + class ComplexMCA(MCA): """Complex MCA. 
diff --git a/xeofs/preprocessing/__init__.py b/xeofs/preprocessing/__init__.py index 54b31f3..ac6babf 100644 --- a/xeofs/preprocessing/__init__.py +++ b/xeofs/preprocessing/__init__.py @@ -1,12 +1,15 @@ from .scaler import Scaler from .sanitizer import Sanitizer from .multi_index_converter import MultiIndexConverter -from .stacker import DataArrayStacker, DataSetStacker +from .stacker import Stacker +from .concatenator import Concatenator +from .dimension_renamer import DimensionRenamer __all__ = [ "Scaler", "Sanitizer", "MultiIndexConverter", - "DataArrayStacker", - "DataSetStacker", + "Stacker", + "Concatenator", + "DimensionRenamer", ] diff --git a/xeofs/preprocessing/concatenator.py b/xeofs/preprocessing/concatenator.py index 7110de2..0c6f514 100644 --- a/xeofs/preprocessing/concatenator.py +++ b/xeofs/preprocessing/concatenator.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Dict from typing_extensions import Self import pandas as pd @@ -26,7 +26,16 @@ class Concatenator(Transformer): def __init__(self, sample_name: str = "sample", feature_name: str = "feature"): super().__init__(sample_name, feature_name) - self.stackers = [] + self.n_data = None + self.n_features = [] + self.coords_in = {} + + def get_serialization_attrs(self) -> Dict: + return dict( + n_data=self.n_data, + n_features=self.n_features, + coords_in=self.coords_in, + ) def fit( self, @@ -48,9 +57,11 @@ def fit( self.n_data = len(X) - # Set input feature coordinates - self.coords_in = [data.coords[self.feature_name] for data in X] - self.n_features = [coord.size for coord in self.coords_in] + # Set input feature coordinates, using dict for easier serialization + self.coords_in = { + str(i): data.coords[self.feature_name] for i, data in enumerate(X) + } + self.n_features = [coord.size for coord in self.coords_in.values()] return self @@ -62,23 +73,18 @@ def transform(self, X: List[DataArray]) -> DataArray: ) reindexed_data_list: List[DataArray] = [] - dummy_feature_coords = [] idx_range = np.cumsum([0] + self.n_features) for i, data in enumerate(X): # Create dummy feature coordinates for DataArray new_coords = np.arange(idx_range[i], idx_range[i + 1]) - # Replace original feature coordiantes with dummy coordinates + # Replace original feature coordinates with dummy coordinates data = data.drop_vars(self.feature_name) reindexed = data.assign_coords({self.feature_name: new_coords}) - # Store dummy feature coordinates - dummy_feature_coords.append(new_coords) reindexed_data_list.append(reindexed) - self._dummy_feature_coords = dummy_feature_coords - X_concat: DataArray = xr.concat(reindexed_data_list, dim=self.feature_name) self.coords_out = X_concat.coords[self.feature_name] @@ -96,7 +102,10 @@ def _split_dataarray_into_list(self, data: DataArray) -> List[DataArray]: feature_name = self.feature_name data_list: List[DataArray] = [] - for coords, features in zip(self.coords_in, self._dummy_feature_coords): + idx_range = np.cumsum([0] + self.n_features) + for i, coords in enumerate(self.coords_in.values()): + # Create dummy feature coordinates for DataArray + features = np.arange(idx_range[i], idx_range[i + 1]) # Select the features corresponding to the current DataArray sub_selection = data.sel({feature_name: features}) # Replace dummy feature coordinates with original feature coordinates diff --git a/xeofs/preprocessing/dimension_renamer.py b/xeofs/preprocessing/dimension_renamer.py index a032034..7ef729d 100644 --- a/xeofs/preprocessing/dimension_renamer.py +++ 
b/xeofs/preprocessing/dimension_renamer.py @@ -1,3 +1,4 @@ +from typing import Dict from typing_extensions import Self from .transformer import Transformer @@ -22,6 +23,11 @@ def __init__(self, base="dim", start=0): self.start = start self.dim_mapping = {} + def get_serialization_attrs(self) -> Dict: + return dict( + dim_mapping=self.dim_mapping, + ) + def fit(self, X: Data, sample_dims: Dims, feature_dims: Dims, **kwargs) -> Self: self.sample_dims_before = sample_dims self.feature_dims_before = feature_dims diff --git a/xeofs/preprocessing/factory.py b/xeofs/preprocessing/factory.py deleted file mode 100644 index 19c9a0c..0000000 --- a/xeofs/preprocessing/factory.py +++ /dev/null @@ -1,57 +0,0 @@ -# import xarray as xr - -# from .scaler import DataArrayScaler, DataSetScaler, DataListScaler -# from .stacker import DataArrayStacker, DataSetStacker, DataListStacker -# from .multi_index_converter import ( -# DataArrayMultiIndexConverter, -# DataSetMultiIndexConverter, -# DataListMultiIndexConverter, -# ) -# from ..utils.data_types import DataObject - - -# class ScalerFactory: -# @staticmethod -# def create_scaler(data: DataObject, **kwargs): -# if isinstance(data, xr.DataArray): -# return DataArrayScaler(**kwargs) -# elif isinstance(data, xr.Dataset): -# return DataSetScaler(**kwargs) -# elif isinstance(data, list) and all( -# isinstance(da, xr.DataArray) for da in data -# ): -# return DataListScaler(**kwargs) -# else: -# raise ValueError("Invalid data type") - - -# class MultiIndexConverterFactory: -# @staticmethod -# def create_converter( -# data: DataObject, **kwargs -# ) -> DataArrayMultiIndexConverter | DataListMultiIndexConverter: -# if isinstance(data, xr.DataArray): -# return DataArrayMultiIndexConverter(**kwargs) -# elif isinstance(data, xr.Dataset): -# return DataSetMultiIndexConverter(**kwargs) -# elif isinstance(data, list) and all( -# isinstance(da, xr.DataArray) for da in data -# ): -# return DataListMultiIndexConverter(**kwargs) -# else: -# raise ValueError("Invalid data type") - - -# class StackerFactory: -# @staticmethod -# def create_stacker(data: DataObject, **kwargs): -# if isinstance(data, xr.DataArray): -# return DataArrayStacker(**kwargs) -# elif isinstance(data, xr.Dataset): -# return DataSetStacker(**kwargs) -# elif isinstance(data, list) and all( -# isinstance(da, xr.DataArray) for da in data -# ): -# return DataListStacker(**kwargs) -# else: -# raise ValueError("Invalid data type") diff --git a/xeofs/preprocessing/multi_index_converter.py b/xeofs/preprocessing/multi_index_converter.py index 6989a44..b27817a 100644 --- a/xeofs/preprocessing/multi_index_converter.py +++ b/xeofs/preprocessing/multi_index_converter.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Dict from typing_extensions import Self import pandas as pd @@ -15,6 +15,13 @@ def __init__(self): self.coords_from_fit = {} self.coords_from_transform = {} + def get_serialization_attrs(self) -> Dict: + return dict( + modified_dimensions=self.modified_dimensions, + coords_from_fit=self.coords_from_fit, + coords_from_transform=self.coords_from_transform, + ) + def fit( self, X: Data, diff --git a/xeofs/preprocessing/preprocessor.py b/xeofs/preprocessing/preprocessor.py index 704d655..70a86c6 100644 --- a/xeofs/preprocessing/preprocessor.py +++ b/xeofs/preprocessing/preprocessor.py @@ -1,14 +1,18 @@ -from typing import Optional, Sequence, Hashable, List, Tuple, Any, Type +from typing import Optional, List, Tuple, Dict +from typing_extensions import Self import numpy as np 
+import xarray as xr +from datatree import DataTree from .list_processor import GenericListTransformer from .dimension_renamer import DimensionRenamer from .scaler import Scaler -from .stacker import StackerFactory, Stacker +from .stacker import Stacker from .multi_index_converter import MultiIndexConverter from .sanitizer import Sanitizer from .concatenator import Concatenator +from .transformer import Transformer from ..utils.xarray_utils import ( get_dims, unwrap_singleton_list, @@ -19,9 +23,6 @@ from ..utils.data_types import ( DataArray, Data, - DataVar, - DataVarBound, - DataList, Dims, DimsList, ) @@ -52,14 +53,17 @@ def extract_new_dim_names(X: List[DimensionRenamer]) -> Tuple[Dims, DimsList]: return new_sample_dims, new_feature_dims -class Preprocessor: - """Scale and stack the data along sample dimensions. +class Preprocessor(Transformer): + """Preprocess xarray objects (DataArray, Dataset). - Scaling includes (i) removing the mean and, optionally, (ii) dividing by the standard deviation, - (iii) multiplying by the square root of cosine of latitude weights (area weighting; coslat weighting), - and (iv) multiplying by additional user-defined weights. - - Stacking includes (i) stacking the data along the sample dimensions and (ii) stacking the data along the feature dimensions. + Preprocessing includes + (i) Feature-wise scaling (e.g. removing mean, dividing by standard deviation, applying (latitude) weights + (ii) Renaming dimensions (to avoid conflicts with sample and feature dimensions) + (iii) Converting MultiIndexes to regular Indexes (MultiIndexes cannot be stacked) + (iv) Stacking the data into 2D DataArray + (v) Converting MultiIndexes introduced by stacking into regular Indexes + (vi) Removing NaNs + (vii) Concatenating the 2D DataArrays into one 2D DataArray Parameters ---------- @@ -76,7 +80,8 @@ class Preprocessor: with_weights : bool, default=False If True, the data is multiplied by additional user-defined weights. return_list : bool, default=True - If True, the output is returned as a list of DataArrays. If False, the output is returned as a single DataArray if possible. + If True, inverse_transform methods returns always a list of DataArray(s). + If False, the output is returned as a single DataArray if possible. 
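+
+    Examples
+    --------
+    A minimal sketch of the intended round trip (the input object and
+    dimension names are illustrative):
+
+    >>> preprocessor = Preprocessor(with_std=True)
+    >>> X2d = preprocessor.fit_transform(data, sample_dims=("time",))
+    >>> reconstructed = preprocessor.inverse_transform_data(X2d)
+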
""" @@ -97,76 +102,127 @@ def __init__( self.with_coslat = with_coslat self.return_list = return_list + self.n_data = None + + dim_names_as_kwargs = { + "sample_name": self.sample_name, + "feature_name": self.feature_name, + } + + # Initialize transformers + # 1 | Center, scale and weigh the data + scaler_kwargs = { + "with_center": self.with_center, + "with_std": self.with_std, + "with_coslat": self.with_coslat, + } + self.scaler = GenericListTransformer(Scaler, **scaler_kwargs) + # 2 | Rename dimensions + self.renamer = GenericListTransformer(DimensionRenamer) + # 3 | Convert MultiIndexes (before stacking) + self.preconverter = GenericListTransformer(MultiIndexConverter) + # 4 | Stack the data to 2D DataArray + self.stacker = GenericListTransformer(Stacker, **dim_names_as_kwargs) + # 5 | Convert MultiIndexes (after stacking) + self.postconverter = GenericListTransformer(MultiIndexConverter) + # 6 | Remove NaNs + self.sanitizer = GenericListTransformer(Sanitizer, **dim_names_as_kwargs) + # 7 | Concatenate into one 2D DataArray + self.concatenator = Concatenator(**dim_names_as_kwargs) + + def get_serialization_attrs(self) -> Dict: + return dict(n_data=self.n_data) + + def transformer_types(self): + """Ordered list of transformer operations.""" + return dict( + scaler=Scaler, + renamer=DimensionRenamer, + preconverter=MultiIndexConverter, + stacker=Stacker, + postconverter=MultiIndexConverter, + sanitizer=Sanitizer, + concatenator=Concatenator, + ) + + def get_transformers(self, inverse: bool = False): + transformers = [getattr(self, t) for t in self.transformer_types().keys()] + if inverse: + transformers = transformers[::-1] + return transformers + def fit( self, X: List[Data] | Data, sample_dims: Dims, weights: Optional[List[Data] | Data] = None, - ): + ) -> Self: + """Fit the preprocessor to the data. + + Parameters + ---------- + X : xarray objects or list of xarray objects + Input data. + sample_dims : tuple of str + Sample dimensions. + weights : xr.DataArray or list of xr.DataArray, optional + Weights to be applied to the data. + + Returns + ------- + self : Preprocessor + The fitted preprocessor. 
+ + """ self._set_return_list(X) X = convert_to_list(X) self.n_data = len(X) sample_dims, feature_dims = get_dims(X, sample_dims) - # Set sample and feature dimensions - self.dims = { - self.sample_name: sample_dims, - self.feature_name: feature_dims, - } - - # However, for each DataArray a list of feature dimensions must be provided + # For each DataArray a list of feature dimensions must be provided _check_parameter_number("feature_dims", feature_dims, self.n_data) # Ensure that weights are provided as a list weights = process_parameter("weights", weights, None, self.n_data) # 1 | Center, scale and weigh the data - scaler_kwargs = { - "with_center": self.with_center, - "with_std": self.with_std, - "with_coslat": self.with_coslat, - } - scaler_ikwargs = { - "weights": weights, - } - self.scaler = GenericListTransformer(Scaler, **scaler_kwargs) - X = self.scaler.fit_transform(X, sample_dims, feature_dims, scaler_ikwargs) - + scaler_iterkwargs = {"weights": weights} + X = self.scaler.fit_transform( + X=X, + sample_dims=sample_dims, + feature_dims=feature_dims, + iter_kwargs=scaler_iterkwargs, + ) # 2 | Rename dimensions - self.renamer = GenericListTransformer(DimensionRenamer) X = self.renamer.fit_transform(X, sample_dims, feature_dims) sample_dims, feature_dims = extract_new_dim_names(self.renamer.transformers) - # 3 | Convert MultiIndexes (before stacking) - self.preconverter = GenericListTransformer(MultiIndexConverter) X = self.preconverter.fit_transform(X, sample_dims, feature_dims) - # 4 | Stack the data to 2D DataArray - stacker_kwargs = { - "sample_name": self.sample_name, - "feature_name": self.feature_name, - } - stack_type: Type[Stacker] = StackerFactory.create(X[0]) - self.stacker = GenericListTransformer(stack_type, **stacker_kwargs) X = self.stacker.fit_transform(X, sample_dims, feature_dims) # 5 | Convert MultiIndexes (after stacking) - self.postconverter = GenericListTransformer(MultiIndexConverter) X = self.postconverter.fit_transform(X, sample_dims, feature_dims) # 6 | Remove NaNs - sanitizer_kwargs = { - "sample_name": self.sample_name, - "feature_name": self.feature_name, - } - self.sanitizer = GenericListTransformer(Sanitizer, **sanitizer_kwargs) X = self.sanitizer.fit_transform(X, sample_dims, feature_dims) - # 7 | Concatenate into one 2D DataArray - self.concatenator = Concatenator(self.sample_name, self.feature_name) self.concatenator.fit(X) # type: ignore return self def transform(self, X: List[Data] | Data) -> DataArray: + """Transform the data. + + Parameters + ---------- + X : xarray objects or list of xarray objects + Input data. + + Returns + ------- + xr.DataArray + The transformed data. + + """ X = convert_to_list(X) if len(X) != self.n_data: @@ -176,13 +232,11 @@ def transform(self, X: List[Data] | Data) -> DataArray: f"len(data objects used for fitting)={self.n_data}" ) - X = self.scaler.transform(X) - X = self.renamer.transform(X) - X = self.preconverter.transform(X) - X = self.stacker.transform(X) - X = self.postconverter.transform(X) - X = self.sanitizer.transform(X) - return self.concatenator.transform(X) # type: ignore + X_t = X.copy() + for transformer in self.get_transformers(): + X_t = transformer.transform(X_t) # type: ignore + + return X_t def fit_transform( self, @@ -206,14 +260,11 @@ def inverse_transform_data(self, X: DataArray) -> List[Data] | Data: The inverse transformed data. 
""" - X_list = self.concatenator.inverse_transform_data(X) - X_list = self.sanitizer.inverse_transform_data(X_list) # type: ignore - X_list = self.postconverter.inverse_transform_data(X_list) - X_list_ND = self.stacker.inverse_transform_data(X_list) - X_list_ND = self.preconverter.inverse_transform_data(X_list_ND) - X_list_ND = self.renamer.inverse_transform_data(X_list_ND) - X_list_ND = self.scaler.inverse_transform_data(X_list_ND) - return self._process_output(X_list_ND) + X_it = X.copy() + for transformer in self.get_transformers(inverse=True): + X_it = transformer.inverse_transform_data(X_it) + + return self._process_output(X_it) def inverse_transform_components(self, X: DataArray) -> List[Data] | Data: """Inverse transform the components. @@ -229,14 +280,11 @@ def inverse_transform_components(self, X: DataArray) -> List[Data] | Data: The inverse transformed components. """ - X_list = self.concatenator.inverse_transform_components(X) - X_list = self.sanitizer.inverse_transform_components(X_list) # type: ignore - X_list = self.postconverter.inverse_transform_components(X_list) - X_list_ND = self.stacker.inverse_transform_components(X_list) - X_list_ND = self.preconverter.inverse_transform_components(X_list_ND) - X_list_ND = self.renamer.inverse_transform_components(X_list_ND) - X_list_ND = self.scaler.inverse_transform_components(X_list_ND) - return self._process_output(X_list_ND) + X_it = X.copy() + for transformer in self.get_transformers(inverse=True): + X_it = transformer.inverse_transform_components(X_it) + + return self._process_output(X_it) def inverse_transform_scores(self, X: DataArray) -> DataArray: """Inverse transform the scores. @@ -254,14 +302,11 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray: The inverse transformed scores. """ - X = self.concatenator.inverse_transform_scores(X) - X = self.sanitizer.inverse_transform_scores(X) - X = self.postconverter.inverse_transform_scores(X) - X = self.stacker.inverse_transform_scores(X) - X = self.preconverter.inverse_transform_scores(X) - X = self.renamer.inverse_transform_scores(X) - X = self.scaler.inverse_transform_scores(X) - return X + X_it = X.copy() + for transformer in self.get_transformers(inverse=True): + X_it = transformer.inverse_transform_scores(X_it) + + return X_it def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: """Inverse transform the scores. @@ -279,14 +324,11 @@ def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: The inverse transformed scores. 
""" - X = self.concatenator.inverse_transform_scores_unseen(X) - X = self.sanitizer.inverse_transform_scores_unseen(X) - X = self.postconverter.inverse_transform_scores_unseen(X) - X = self.stacker.inverse_transform_scores_unseen(X) - X = self.preconverter.inverse_transform_scores_unseen(X) - X = self.renamer.inverse_transform_scores_unseen(X) - X = self.scaler.inverse_transform_scores_unseen(X) - return X + X_it = X.copy() + for transformer in self.get_transformers(inverse=True): + X_it = transformer.inverse_transform_scores_unseen(X_it) + + return X_it def _process_output(self, X: List[Data]) -> List[Data] | Data: if self.return_list: @@ -299,3 +341,56 @@ def _set_return_list(self, X): self.return_list = True else: self.return_list = False + + def serialize_all(self) -> DataTree: + """Serialize the necessary attributes of the fitted pre-processor + and all transformers to a Dataset.""" + # Serialize the preprocessor as the root node + dt = self.serialize() + dt.name = "preprocessor" + + # Serialize all transformers + names = list(self.transformer_types().keys()) + transformers = self.get_transformers() + + for name, transformer_obj in zip(names, transformers): + dt_transformer = DataTree() + if isinstance(transformer_obj, GenericListTransformer): + # Loop through list transformer objects and assign a dummy key + for i, transformer in enumerate(transformer_obj.transformers): + dt_transformer[str(i)] = transformer.serialize() + else: + dt_transformer = transformer_obj.serialize() + # Place the serialized transformer in the tree + dt[name] = dt_transformer + dt[name].parent = dt + + return dt + + @classmethod + def deserialize_all(cls, dt: DataTree) -> Self: + """Deserialize from a DataTree representation of the preprocessor + and all attached Transformers.""" + # Create the parent preprocessor + preprocessor = cls.deserialize(dt) + + # Loop through all transformers and deserialize + names = list(preprocessor.transformer_types().keys()) + transformers = preprocessor.get_transformers() + + for name, transformer_obj in zip(names, transformers): + if isinstance(transformer_obj, GenericListTransformer): + # Recreate list transformers sequentially + for transformer in dt[name].values(): + deserialized = preprocessor.transformer_types()[name].deserialize( + transformer + ) + transformer_obj.transformers.append(deserialized) + else: + # Recreate single transformer + deserialized = preprocessor.transformer_types()[name].deserialize( + dt[name] + ) + setattr(preprocessor, name, deserialized) + + return preprocessor diff --git a/xeofs/preprocessing/sanitizer.py b/xeofs/preprocessing/sanitizer.py index c0a2181..7cebe56 100644 --- a/xeofs/preprocessing/sanitizer.py +++ b/xeofs/preprocessing/sanitizer.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Dict from typing_extensions import Self import xarray as xr @@ -15,6 +15,17 @@ class Sanitizer(Transformer): def __init__(self, sample_name="sample", feature_name="feature"): super().__init__(sample_name=sample_name, feature_name=feature_name) + self.feature_coords = xr.DataArray() + self.sample_coords = xr.DataArray() + self.is_valid_feature = xr.DataArray() + + def get_serialization_attrs(self) -> Dict: + return dict( + feature_coords=self.feature_coords, + sample_coords=self.sample_coords, + is_valid_feature=self.is_valid_feature, + ) + def _check_input_type(self, X) -> None: if not isinstance(X, xr.DataArray): raise ValueError("Input must be an xarray DataArray") diff --git a/xeofs/preprocessing/scaler.py 
b/xeofs/preprocessing/scaler.py index be4fed1..a7411f6 100644 --- a/xeofs/preprocessing/scaler.py +++ b/xeofs/preprocessing/scaler.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Dict from typing_extensions import Self import numpy as np @@ -40,6 +40,19 @@ def __init__( self.with_std = with_std self.with_coslat = with_coslat + self.mean_ = xr.DataArray(name="mean_") + self.std_ = xr.DataArray(name="std_") + self.coslat_weights_ = xr.DataArray(name="coslat_weights_") + self.weights_ = xr.DataArray(name="weights_") + + def get_serialization_attrs(self) -> Dict: + return dict( + mean_=self.mean_, + std_=self.std_, + coslat_weights_=self.coslat_weights_, + weights_=self.weights_, + ) + def _verify_input(self, X, name: str): if not isinstance(X, (xr.DataArray, xr.Dataset)): raise TypeError(f"{name} must be an xarray DataArray or Dataset") @@ -173,163 +186,3 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray: def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: return X - - -# class DataListScaler(Scaler): -# """Scale a list of xr.DataArray along sample dimensions. - -# Scaling includes (i) removing the mean and, optionally, (ii) dividing by the standard deviation, -# (iii) multiplying by the square root of cosine of latitude weights (area weighting; coslat weighting), -# and (iv) multiplying by additional user-defined weights. - -# Parameters -# ---------- -# with_std : bool, default=True -# If True, the data is divided by the standard deviation. -# with_coslat : bool, default=False -# If True, the data is multiplied by the square root of cosine of latitude weights. -# with_weights : bool, default=False -# If True, the data is multiplied by additional user-defined weights. - -# """ - -# def __init__(self, with_std=False, with_coslat=False): -# super().__init__(with_std=with_std, with_coslat=with_coslat) -# self.scalers = [] - -# def _verify_input(self, data, name: str): -# """Verify that the input data is a list of DataArrays. - -# Parameters -# ---------- -# data : list of xarray.DataArray -# Data to be checked. - -# """ -# if not isinstance(data, list): -# raise TypeError(f"{name} must be a list of xarray DataArrays or Datasets") -# if not all(isinstance(da, (xr.DataArray, xr.Dataset)) for da in data): -# raise TypeError(f"{name} must be a list of xarray DataArrays or Datasets") - -# def fit( -# self, -# data: List[Data], -# sample_dims: Dims, -# feature_dims_list: DimsList, -# weights: Optional[List[Data] | Data] = None, -# ) -> Self: -# """Fit the scaler to the data. - -# Parameters -# ---------- -# data : list of xarray.DataArray -# Data to be scaled. -# sample_dims : hashable or sequence of hashable -# Dimensions along which the data is considered to be a sample. -# feature_dims_list : list of hashable or list of sequence of hashable -# List of dimensions along which the data is considered to be a feature. -# weights : list of xarray.DataArray, optional -# List of weights to be applied to the data. Must have the same dimensions as the data. - -# """ -# self._verify_input(data, "data") - -# # Check input -# if not isinstance(feature_dims_list, list): -# err_message = "feature dims must be a list of the feature dimensions of each DataArray, " -# err_message += 'e.g. 
[("lon", "lat"), ("lon")]' -# raise TypeError(err_message) - -# # Sample dimensions are the same for all data arrays -# # Feature dimensions may be different for each data array -# self.dims = {"sample": sample_dims, "feature": feature_dims_list} - -# # However, for each DataArray a list of feature dimensions must be provided -# _check_parameter_number("feature_dims", feature_dims_list, len(data)) - -# # If no weights are provided, create a list of None -# self.weights = process_parameter("weights", weights, None, len(data)) - -# params = self.get_params() - -# for da, wghts, fdims in zip(data, self.weights, feature_dims_list): -# # Create Scaler object for each data array -# scaler = Scaler(**params) -# scaler.fit(da, sample_dims=sample_dims, feature_dims=fdims, weights=wghts) -# self.scalers.append(scaler) - -# return self - -# def transform(self, da_list: List[Data]) -> List[Data]: -# """Scale the data. - -# Parameters -# ---------- -# da_list : list of xarray.DataArray -# Data to be scaled. - -# Returns -# ------- -# list of xarray.DataArray -# Scaled data. - -# """ -# self._verify_input(da_list, "da_list") - -# da_list_transformed = [] -# for scaler, da in zip(self.scalers, da_list): -# da_list_transformed.append(scaler.transform(da)) -# return da_list_transformed - -# def fit_transform( -# self, -# data: List[Data], -# sample_dims: Dims, -# feature_dims_list: DimsList, -# weights: Optional[List[Data] | Data] = None, -# ) -> List[Data]: -# """Fit the scaler to the data and scale it. - -# Parameters -# ---------- -# data : list of xr.DataArray -# Data to be scaled. -# sample_dims : hashable or sequence of hashable -# Dimensions along which the data is considered to be a sample. -# feature_dims_list : list of hashable or list of sequence of hashable -# List of dimensions along which the data is considered to be a feature. -# weights : list of xr.DataArray, optional -# List of weights to be applied to the data. Must have the same dimensions as the data. - -# Returns -# ------- -# list of xarray.DataArray -# Scaled data. - -# """ -# self.fit(data, sample_dims, feature_dims_list, weights) -# return self.transform(data) - -# def inverse_transform_data(self, da_list: List[Data]) -> List[Data]: -# """Unscale the data. - -# Parameters -# ---------- -# da_list : list of xarray.DataArray -# Data to be scaled. - -# Returns -# ------- -# list of xarray.DataArray -# Scaled data. - -# """ -# self._verify_input(da_list, "da_list") - -# da_list_transformed = [] -# for scaler, da in zip(self.scalers, da_list): -# da_list_transformed.append(scaler.inverse_transform_data(da)) -# return da_list_transformed - -# def inverse_transform_components(self, da_list: List[Data]) -> List[Data]: -# return da_list diff --git a/xeofs/preprocessing/stacker.py b/xeofs/preprocessing/stacker.py index 31ea22e..6268006 100644 --- a/xeofs/preprocessing/stacker.py +++ b/xeofs/preprocessing/stacker.py @@ -1,18 +1,18 @@ -from abc import abstractmethod -from typing import List, Optional, Type +from typing import Dict from typing_extensions import Self -import numpy as np import pandas as pd import xarray as xr from .transformer import Transformer from ..utils.data_types import Dims, DataArray, DataSet, Data, DataVar, DataVarBound -from ..utils.sanity_checks import convert_to_dim_type class Stacker(Transformer): - """Converts a DataArray of any dimensionality into a 2D structure. + """Converts an xarray DataArray or Dataset of any dimensionality into a 2D DataArray. 
+ + The new DataArray will have two dimensions: `sample` and `feature`. + The dimensions of the original data will be stacked along these two dimensions. Attributes ---------- @@ -21,9 +21,9 @@ class Stacker(Transformer): feature_dims : Sequence[Hashable] The dimensions of the data that will be stacked along the `feature` dimension. sample_name : str - The name of the sample dimension. + The name of the sample dimension (dim=0). feature_name : str - The name of the feature dimension. + The name of the feature dimension (dim=1). dims_in : Tuple[str] The dimensions of the input data. dims_out : Tuple[str] @@ -47,11 +47,82 @@ def __init__( self.dims_out = tuple((sample_name, feature_name)) self.dims_mapping = {} self.dims_mapping.update({d: tuple() for d in self.dims_out}) - self.coords_in = {} self.coords_out = {} - def _validate_matching_dimensions(self, X: Data): + def get_serialization_attrs(self) -> Dict: + return dict( + dims_in=self.dims_in, + dims_out=self.dims_out, + dims_mapping=self.dims_mapping, + coords_in=self.coords_in, + coords_out=self.coords_out, + ) + + def _validate_data_type(self, X: Data): + """Check that the data type is either DataArray or Dataset.""" + if not isinstance(X, (xr.DataArray, xr.Dataset)): + raise TypeError(f"Invalid data type {type(X)}.") + + def _validate_dimension_names(self, X, sample_dims, feature_dims): + """Check that the names of the sample and feature dimensions are not already present in the data.""" + sample_name_in_data = self.sample_name in X.dims + feature_name_in_data = self.feature_name in X.dims + + has_invalid_sample_name = ( + True if (len(sample_dims) > 1) and sample_name_in_data else False + ) + + match X: + case xr.DataArray(): + has_invalid_feature_name = ( + True if (len(feature_dims) > 1) and feature_name_in_data else False + ) + case xr.Dataset(): + has_invalid_feature_name = True if feature_name_in_data else False + case _: + raise TypeError(f"Invalid data type {type(X)}.") + + if has_invalid_sample_name: + err_msg = f"Name of sample dimension ({self.sample_name}) is already present in data. Please use another name." + raise ValueError(err_msg) + + if has_invalid_feature_name: + err_msg = f"Name of feature dimension ({self.feature_name}) is already present in data. Please use another name." + raise ValueError(err_msg) + + def _validate_dims(self, X: Data, sample_dims, feature_dims): + invalid_sample_dims = True if len(sample_dims) < 1 else False + invalid_feature_dims = True if len(feature_dims) < 1 else False + + if invalid_sample_dims: + raise ValueError(f"Sample dimension must not be empty.") + if invalid_feature_dims: + match X: + case xr.DataArray(): + raise ValueError(f"Feature dimension must not be empty.") + case xr.Dataset(): + err_msg = f"Dataset without feature dimension is currently not supported. Please convert your Dataset to a DataArray first, e.g. by using `to_array()`." 
+ raise ValueError(err_msg) + case _: + raise TypeError(f"Invalid data type {type(X)}.") + + def _validate_indices(self, X: Data): + """Check that the indices of the data are no MultiIndex""" + if any([isinstance(index, pd.MultiIndex) for index in X.indexes.values()]): + raise ValueError(f"Cannot stack data containing a MultiIndex.") + + def _sanity_check(self, X: Data, sample_dims, feature_dims): + self._validate_dims(X, sample_dims, feature_dims) + self._validate_dimension_names(X, sample_dims, feature_dims) + self._validate_indices(X) + + def _validate_transform_data_type(self, X: Data): + """Check that the data type is either DataArray or Dataset.""" + if not isinstance(X, self.data_type): + raise TypeError(f"Expected data type {self.data_type}, got {type(X)}.") + + def _validate_transform_dimensions(self, X: Data): """Verify that the dimensions of the data are consistent with the dimensions used to fit the stacker.""" # Test whether sample and feature dimensions are present in data array expected_sample_dims = set(self.dims_mapping[self.sample_name]) @@ -63,7 +134,7 @@ def _validate_matching_dimensions(self, X: Data): f"One or more dimensions in {expected_dims} are not present in data." ) - def _validate_matching_feature_coords(self, X: Data): + def _validate_transform_feature_coords(self, X: Data): """Verify that the feature coordinates of the data are consistent with the feature coordinates used to fit the stacker.""" feature_dims = self.dims_mapping[self.feature_name] coords_are_equal = [ @@ -74,34 +145,21 @@ def _validate_matching_feature_coords(self, X: Data): "Data to be transformed has different coordinates than the data used to fit." ) - def _validate_dimension_names(self, sample_dims, feature_dims): - if len(sample_dims) > 1: - if self.sample_name in sample_dims: - raise ValueError( - f"Name of sample dimension ({self.sample_name}) is already present in data. Please use another name." - ) - if len(feature_dims) > 1: - if self.feature_name in feature_dims: - raise ValueError( - f"Name of feature dimension ({self.feature_name}) is already present in data. Please use another name." - ) - - def _validate_indices(self, X: Data): - """Check that the indices of the data are no MultiIndex""" - if any([isinstance(index, pd.MultiIndex) for index in X.indexes.values()]): - raise ValueError(f"Cannot stack data containing a MultiIndex.") - - def _sanity_check(self, X: Data, sample_dims, feature_dims): - self._validate_dimension_names(sample_dims, feature_dims) - self._validate_indices(X) + def _reorder_dims(self, X: DataVarBound) -> DataVarBound: + """Reorder dimensions to original order; catch ('mode') dimensions via ellipsis""" + order_input_dims = [ + valid_dim for valid_dim in self.dims_in if valid_dim in X.dims + ] + if order_input_dims != X.dims: + X = X.transpose(..., *order_input_dims) + return X - @abstractmethod def _stack(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> DataArray: """Stack data to 2D. Parameters ---------- - data : DataArray + X : xr.DataArray | xr.Dataset The data to be reshaped. sample_dims : Hashable or Sequence[Hashable] The dimensions of the data that will be stacked along the `sample` dimension. @@ -110,19 +168,112 @@ def _stack(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> DataArray: Returns ------- - data_stacked : DataArray + DataArray The reshaped 2d-data. 
""" + sample_name = self.sample_name + feature_name = self.feature_name + + # Stack SAMPLE dimension + if len(sample_dims) > 1: + X = X.stack({sample_name: sample_dims}) + elif len(sample_dims) == 1: + if sample_dims[0] != sample_name: + X = X.rename({sample_dims[0]: sample_name}) + else: + # There's only one sample dimension and it's already named correctly + pass + else: + raise ValueError(f"Sample dimension must not be empty.") + + # Stack FEATURE dimension + match X: + case xr.DataArray(): + if len(feature_dims) > 1: + X = X.stack({feature_name: feature_dims}) + elif len(feature_dims) == 1: + if feature_dims[0] != feature_name: + X = X.rename({feature_dims[0]: feature_name}) + else: + # There's only one feature dimension and it's already named correctly + pass + else: + raise ValueError(f"Feature dimension must not be empty.") + + case xr.Dataset(): + X = X.to_stacked_array( + new_dim=feature_name, sample_dims=(self.sample_name,) + ) + case _: + raise TypeError(f"Invalid data type {type(X)}.") + + # Reorder dimensions to be always (sample, feature) + if X.dims == (feature_name, sample_name): + X = X.transpose(sample_name, feature_name) - def _reorder_dims(self, X: DataVarBound) -> DataVarBound: - """Reorder dimensions to original order; catch ('mode') dimensions via ellipsis""" - order_input_dims = [ - valid_dim for valid_dim in self.dims_in if valid_dim in X.dims - ] - if order_input_dims != X.dims: - X = X.transpose(..., *order_input_dims) return X + def _unstack_to_dataarray(self, X: DataArray) -> DataArray: + """Unstack 2D DataArray to its original dimensions. + + Parameters + ---------- + X : DataArray + The data to be unstacked. + + Returns + ------- + DataArray + The unstacked data. + """ + sample_name = self.sample_name + feature_name = self.feature_name + + has_only_one_sample_dim = len(self.dims_mapping[sample_name]) == 1 + has_only_one_feature_dim = len(self.dims_mapping[feature_name]) == 1 + + if sample_name in X.dims: + # If sample dimensions is one dimensional, rename is sufficient, otherwise unstack + if has_only_one_sample_dim: + if self.dims_mapping[sample_name][0] != sample_name: + X = X.rename({sample_name: self.dims_mapping[sample_name][0]}) + else: + X = X.unstack(sample_name) + + # pass if feature/sample dimensions do not exist in data + if feature_name in X.dims: + # If sample dimensions is one dimensional, rename is sufficient, otherwise unstack + if has_only_one_feature_dim: + if self.dims_mapping[feature_name][0] != feature_name: + X = X.rename({feature_name: self.dims_mapping[feature_name][0]}) + else: + X = X.unstack(feature_name) + + else: + pass + + X = self._reorder_dims(X) + return X + + def _unstack_to_dataset_data(self, X: DataArray) -> DataSet: + """Unstack `sample` and `feature` dimension of an DataArray to its original dimensions.""" + sample_name = self.sample_name + feature_name = self.feature_name + has_only_one_sample_dim = len(self.dims_mapping[sample_name]) == 1 + + if has_only_one_sample_dim: + X = X.rename({sample_name: self.dims_mapping[sample_name][0]}) + + ds: DataSet = X.to_unstacked_dataset(feature_name, "variable").unstack() + ds = self._reorder_dims(ds) + return ds + + def _unstack_to_dataset_components(self, data: DataArray) -> DataSet: + feature_name = self.feature_name + ds: DataSet = data.to_unstacked_dataset(feature_name, "variable").unstack() + ds = self._reorder_dims(ds) + return ds + def fit(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> Self: """Fit the stacker. 
@@ -137,6 +288,9 @@ def fit(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> Self: The fitted stacker. """ + self._sanity_check(X, sample_dims, feature_dims) + + self.data_type = type(X) self.sample_dims = sample_dims self.feature_dims = feature_dims self.dims_mapping.update( @@ -145,7 +299,6 @@ def fit(self, X: Data, sample_dims: Dims, feature_dims: Dims) -> Self: self.feature_name: feature_dims, } ) - self._sanity_check(X, sample_dims, feature_dims) # Set dimensions and coordinates self.dims_in = X.dims @@ -175,10 +328,10 @@ def transform(self, X: Data) -> DataArray: """ # Test whether sample and feature dimensions are present in data array - self._validate_matching_dimensions(X) + self._validate_transform_dimensions(X) # Check if data to be transformed has the same feature coordinates as the data used to fit the stacker - self._validate_matching_feature_coords(X) + self._validate_transform_feature_coords(X) # Stack data sample_dims = self.dims_mapping[self.sample_name] @@ -204,7 +357,6 @@ def fit_transform( ) -> DataArray: return self.fit(X, sample_dims, feature_dims).transform(X) - @abstractmethod def inverse_transform_data(self, X: DataArray) -> Data: """Reshape the 2D data (sample x feature) back into its original dimensions. @@ -219,14 +371,20 @@ def inverse_transform_data(self, X: DataArray) -> Data: The reshaped data. """ + match self.data_type: + case xr.DataArray: + return self._unstack_to_dataarray(X) + case xr.Dataset: + return self._unstack_to_dataset_data(X) + case _: + raise TypeError(f"Invalid data type {type(X)}.") - @abstractmethod def inverse_transform_components(self, X: DataArray) -> Data: - """Reshape the 2D components (sample x feature) back into its original dimensions. + """Reshape the 2D components (feature x mode) back into its original dimensions. Parameters ---------- - data : DataArray + X : DataArray The data to be reshaped. Returns @@ -235,14 +393,22 @@ def inverse_transform_components(self, X: DataArray) -> Data: The reshaped data. """ + match self.data_type: + case xr.DataArray: + return self._unstack_to_dataarray(X) + case xr.Dataset: + return self._unstack_to_dataset_components(X) + case _: + raise TypeError(f"Invalid data type {type(X)}.") - @abstractmethod - def inverse_transform_scores(self, data: DataArray) -> DataArray: - """Reshape the 2D scores (sample x feature) back into its original dimensions. + def inverse_transform_scores(self, X: DataArray) -> DataArray: + """Reshape the 2D scores (sample x mode) back into its original dimensions. + + Use this for fitted scores. Parameters ---------- - data : DataArray + X : DataArray The data to be reshaped. Returns @@ -251,266 +417,22 @@ def inverse_transform_scores(self, data: DataArray) -> DataArray: The reshaped data. """ + return self._unstack_to_dataarray(X) - @abstractmethod - def inverse_transform_scores_unseen(self, data: DataArray) -> DataArray: - pass - + def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: + """Reshape the 2D scores (sample x mode) back into its original dimensions. -class DataArrayStacker(Stacker): - def _stack( - self, data: DataArray, sample_dims: Dims, feature_dims: Dims - ) -> DataArray: - """Reshape a DataArray to 2D. + Use this for new, unseen scores. Parameters ---------- - data : DataArray + X : DataArray The data to be reshaped. - sample_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `sample` dimension. 
- feature_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `feature` dimension. Returns ------- - data_stacked : DataArray - The reshaped 2d-data. - """ - sample_name = self.sample_name - feature_name = self.feature_name - - # 3 cases: - # 1. uni-dimensional with correct feature/sample name ==> do nothing - # 2. uni-dimensional with name different from feature/sample ==> rename - # 3. multi-dimensinoal with names different from feature/sample ==> stack - - # - SAMPLE - - if len(sample_dims) == 1: - # Case 1 - if sample_dims[0] == sample_name: - pass - # Case 2 - else: - data = data.rename({sample_dims[0]: sample_name}) - # Case 3 - else: - data = data.stack({sample_name: sample_dims}) - - # - FEATURE - - if len(feature_dims) == 1: - # Case 1 - if feature_dims[0] == feature_name: - pass - # Case 2 - else: - data = data.rename({feature_dims[0]: feature_name}) - # Case 3 - else: - data = data.stack({feature_name: feature_dims}) - - # Reorder dimensions to be always (sample, feature) - if data.dims == (feature_name, sample_name): - data = data.transpose(sample_name, feature_name) - - return data - - def _unstack(self, data: DataArray) -> DataArray: - """Unstack 2D DataArray to its original dimensions. - - Parameters - ---------- - data : DataArray - The data to be unstacked. - - Returns - ------- - data_unstacked : DataArray - The unstacked data. - """ - sample_name = self.sample_name - feature_name = self.feature_name - - # pass if feature/sample dimensions do not exist in data - if feature_name in data.dims: - # If sample dimensions is one dimensional, rename is sufficient, otherwise unstack - if len(self.dims_mapping[feature_name]) == 1: - if self.dims_mapping[feature_name][0] != feature_name: - data = data.rename( - {feature_name: self.dims_mapping[feature_name][0]} - ) - else: - data = data.unstack(feature_name) - - if sample_name in data.dims: - # If sample dimensions is one dimensional, rename is sufficient, otherwise unstack - if len(self.dims_mapping[sample_name]) == 1: - if self.dims_mapping[sample_name][0] != sample_name: - data = data.rename({sample_name: self.dims_mapping[sample_name][0]}) - else: - data = data.unstack(sample_name) - - else: - pass - - return data - - def inverse_transform_data(self, X: DataArray) -> Data: - Xnd = self._unstack(X) - Xnd = self._reorder_dims(Xnd) - return Xnd - - def inverse_transform_components(self, X: DataArray) -> Data: - Xnd = self._unstack(X) - Xnd = self._reorder_dims(Xnd) - return Xnd - - def inverse_transform_scores(self, data: DataArray) -> DataArray: - data = self._unstack(data) # type: ignore - data = self._reorder_dims(data) - return data - - def inverse_transform_scores_unseen(self, data: DataArray) -> DataArray: - return self.inverse_transform_scores(data) - - -class DataSetStacker(Stacker): - """Converts a Dataset of any dimensionality into a 2D structure.""" - - def _validate_dimension_names(self, sample_dims, feature_dims): - if len(sample_dims) > 1: - if self.sample_name in sample_dims: - raise ValueError( - f"Name of sample dimension ({self.sample_name}) is already present in data. Please use another name." - ) - if len(feature_dims) >= 1: - if self.feature_name in feature_dims: - raise ValueError( - f"Name of feature dimension ({self.feature_name}) is already present in data. Please use another name." - ) - else: - raise ValueError( - f"Datasets without feature dimension are currently not supported. Please convert your Dataset to a DataArray first, e.g. by using `to_array()`." 
- ) - - def _stack(self, data: DataSet, sample_dims, feature_dims) -> DataArray: - """Reshape a Dataset to 2D. - - Parameters - ---------- - data : Dataset - The data to be reshaped. - sample_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `sample` dimension. - feature_dims : Hashable or Sequence[Hashable] - The dimensions of the data that will be stacked along the `feature` dimension. + DataArray + The reshaped data. - Returns - ------- - data_stacked : DataArray - The reshaped 2d-data. """ - sample_name = self.sample_name - feature_name = self.feature_name - - # 3 cases: - # 1. uni-dimensional with correct feature/sample name ==> do nothing - # 2. uni-dimensional with name different from feature/sample ==> rename - # 3. multi-dimensinoal with names different from feature/sample ==> stack - - # - SAMPLE - - if len(sample_dims) == 1: - # Case 1 - if sample_dims[0] == sample_name: - pass - # Case 2 - else: - data = data.rename({sample_dims[0]: sample_name}) - # Case 3 - else: - data = data.stack({sample_name: sample_dims}) - - # - FEATURE - - # Convert Dataset -> DataArray, stacking all non-sample dimensions to feature dimension, including data variables - err_msg = f"Feature dimension {feature_dims[0]} already exists in data. Please choose another feature dimension name." - # Case 2 & 3 - if (len(feature_dims) == 1) & (feature_dims[0] == feature_name): - raise ValueError(err_msg) - else: - try: - da = data.to_stacked_array( - new_dim=feature_name, sample_dims=(self.sample_name,) - ) - except ValueError: - raise ValueError(err_msg) - - # Reorder dimensions to be always (sample, feature) - if da.dims == (feature_name, sample_name): - da = da.transpose(sample_name, feature_name) - - return da - - def _unstack_data(self, data: DataArray) -> DataSet: - """Unstack `sample` and `feature` dimension of an DataArray to its original dimensions.""" - sample_name = self.sample_name - feature_name = self.feature_name - has_only_one_sample_dim = len(self.dims_mapping[sample_name]) == 1 - - if has_only_one_sample_dim: - data = data.rename({sample_name: self.dims_mapping[sample_name][0]}) - - ds: DataSet = data.to_unstacked_dataset(feature_name, "variable").unstack() - ds = self._reorder_dims(ds) - return ds - - def _unstack_components(self, data: DataArray) -> DataSet: - feature_name = self.feature_name - ds: DataSet = data.to_unstacked_dataset(feature_name, "variable").unstack() - ds = self._reorder_dims(ds) - return ds - - def _unstack_scores(self, data: DataArray) -> DataArray: - sample_name = self.sample_name - has_only_one_sample_dim = len(self.dims_mapping[sample_name]) == 1 - - if has_only_one_sample_dim: - data = data.rename({sample_name: self.dims_mapping[sample_name][0]}) - - data = data.unstack() - data = self._reorder_dims(data) - return data - - def inverse_transform_data(self, X: DataArray) -> DataSet: - """Reshape the 2D data (sample x feature) back into its original shape.""" - X_ds: DataSet = self._unstack_data(X) - return X_ds - - def inverse_transform_components(self, X: DataArray) -> DataSet: - """Reshape the 2D components (sample x feature) back into its original shape.""" - X_ds: DataSet = self._unstack_components(X) - return X_ds - - def inverse_transform_scores(self, X: DataArray) -> DataArray: - """Reshape the 2D scores (sample x feature) back into its original shape.""" - X = self._unstack_scores(X) - return X - - def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: return self.inverse_transform_scores(X) - - 
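# The Dataset branch now lives in the same Stacker class (the separate
# DataSetStacker removed here is no longer needed).  Illustrative sketch of the
# underlying xarray round trip with a hypothetical two-variable Dataset:
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "t2m": (("time", "lat"), np.random.rand(4, 3)),
        "precip": (("time", "lat"), np.random.rand(4, 3)),
    },
    coords={"time": np.arange(4), "lat": [10.0, 20.0, 30.0]},
)

# _stack: data variables and remaining dims are folded into a "feature" index.
da2d = ds.rename({"time": "sample"}).to_stacked_array(
    new_dim="feature", sample_dims=("sample",)
)
print(dict(da2d.sizes))  # sample: 4, feature: 6 (= 2 variables x 3 lat points)

# Inverse path used by _unstack_to_dataset_data (which additionally renames
# "sample" back to the original dimension name).
ds_back = da2d.to_unstacked_dataset("feature", "variable").unstack()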
-class StackerFactory: - """Factory class for creating stackers.""" - - def __init__(self): - pass - - @staticmethod - def create(data: Data) -> Type[DataArrayStacker] | Type[DataSetStacker]: - """Create a stacker for the given data.""" - if isinstance(data, xr.DataArray): - return DataArrayStacker - elif isinstance(data, xr.Dataset): - return DataSetStacker - else: - raise TypeError(f"Invalid data type {type(data)}.") diff --git a/xeofs/preprocessing/transformer.py b/xeofs/preprocessing/transformer.py index 5bd9f1a..731f108 100644 --- a/xeofs/preprocessing/transformer.py +++ b/xeofs/preprocessing/transformer.py @@ -1,8 +1,11 @@ from abc import ABC -from typing import Optional +from typing import Optional, Dict from typing_extensions import Self from abc import abstractmethod +import pandas as pd +import xarray as xr +from datatree import DataTree from sklearn.base import BaseEstimator, TransformerMixin from ..utils.data_types import Dims, DataVar, DataArray, DataSet, Data, DataVarBound @@ -22,13 +25,26 @@ def __init__( self.sample_name = sample_name self.feature_name = feature_name + @abstractmethod + def get_serialization_attrs(self) -> Dict: + """Return a dictionary containing the attributes that need to be serialized + as part of a saved transformer. + + There are limitations on the types of attributes that can be serialized. + Most simple types (e.g. int, float, str, bool, None) can be, as well as + DataArrays and dicts of DataArrays. Other nested types (e.g. lists of + DataArrays) will likely fail. + + """ + return dict() + @abstractmethod def fit( self, X: Data, sample_dims: Optional[Dims] = None, feature_dims: Optional[Dims] = None, - **kwargs + **kwargs, ) -> Self: """Fit transformer to data. @@ -52,7 +68,7 @@ def fit_transform( X: Data, sample_dims: Optional[Dims] = None, feature_dims: Optional[Dims] = None, - **kwargs + **kwargs, ) -> Data: return self.fit(X, sample_dims, feature_dims, **kwargs).transform(X) @@ -71,3 +87,87 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray: @abstractmethod def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: return X + + def _serialize_data(self, key: str, data: DataArray) -> DataSet: + # Make sure the DataArray has some name so we can create a string mapping + if data.name is None: + data.name = key + + multiindexes = {} + if data.name in data.coords: + # Create coords-based datasets and note multiindexes + if isinstance(data.to_index(), pd.MultiIndex): + multiindexes[data.name] = [n for n in data.to_index().names] + ds = xr.Dataset(coords={data.name: data}) + else: + # Create data-based datasets + ds = xr.Dataset(data_vars={data.name: data}) + + # Drop multiindexes and record for later + ds = ds.reset_index(list(multiindexes.keys())) + ds.attrs["multiindexes"] = multiindexes + ds.attrs["name_map"] = {key: data.name} + + return ds + + def serialize(self) -> DataTree: + """Serialize a transformer to a DataTree.""" + dt = DataTree() + params = self.get_params() + attrs = self.get_serialization_attrs() + + # Set initialization params as tree level attrs + dt.attrs["params"] = params + + # Serialize each transformer attribute + for key, attr in attrs.items(): + if isinstance(attr, xr.DataArray): + # attach data to data_vars or coords + ds = self._serialize_data(key, attr) + dt[key] = DataTree(name=key, data=ds) + dt.attrs[key] = "_is_node" + elif isinstance(attr, dict) and any( + [isinstance(val, xr.DataArray) for val in attr.values()] + ): + # attach dict of data as branching tree + dt_attr = DataTree() + for k, v 
in attr.items(): + ds = self._serialize_data(k, v) + dt_attr[k] = DataTree(name=k, data=ds) + dt[key] = dt_attr + dt.attrs[key] = "_is_tree" + else: + # attach simple types as dataset attrs + dt.attrs[key] = attr + + return dt + + def _deserialize_data_node(self, key: str, ds: xr.Dataset) -> DataArray: + # Rebuild multiindexes + ds = ds.set_index(ds.attrs.get("multiindexes", {})) + # Extract the DataArray or coord from the Dataset + data_key = ds.attrs["name_map"][key] + data = ds[data_key] + return data + + @classmethod + def deserialize(cls, dt: DataTree) -> Self: + """Deserialize a saved transformer from a DataTree.""" + # Create the object from params + params = dt.attrs.pop("params") + transformer = cls(**params) + + # Set attributes + for key, attr in dt.attrs.items(): + if attr == "_is_node": + data = transformer._deserialize_data_node(key, dt[key]) + setattr(transformer, key, data) + elif attr == "_is_tree": + data = {} + for k, v in dt[key].items(): + data[k] = transformer._deserialize_data_node(k, dt[key][k]) + setattr(transformer, key, data) + else: + setattr(transformer, key, attr) + + return transformer
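# End-to-end sketch of the new (de)serialization hooks, using the Scaler as a
# concrete Transformer.  Hypothetical data; the fit() argument names are
# inferred from the surrounding code and should be treated as assumptions.
import numpy as np
import xarray as xr
from xeofs.preprocessing.scaler import Scaler

da = xr.DataArray(
    np.random.rand(6, 3, 4),
    dims=("time", "lat", "lon"),
    coords={
        "time": np.arange(6),
        "lat": [10.0, 20.0, 30.0],
        "lon": [0.0, 90.0, 180.0, 270.0],
    },
)

scaler = Scaler(with_std=True)
scaler.fit(da, sample_dims=("time",), feature_dims=("lat", "lon"))

# serialize() packs get_serialization_attrs() plus the init params into a
# DataTree; deserialize() rebuilds an equivalent transformer from it.
dt = scaler.serialize()
restored = Scaler.deserialize(dt)
xr.testing.assert_allclose(scaler.mean_, restored.mean_)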