Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: serialization methods #103

Merged
merged 10 commits into from
Nov 7, 2023
18 changes: 18 additions & 0 deletions docs/overview_2_features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,24 @@ Refer to the :doc:`api` section to discover the methods currently available.
Please note that ``xeofs`` is in its developmental phase. If there's a specific method
you'd like to see included, we encourage you to open an issue on `GitHub`_.

Model Serialization
-------------------

``xeofs`` models offer convenient ``save()`` and ``load()`` methods for serializing
fitted models to a portable format.

.. code-block:: python

from xeofs.models import EOF

model = EOF()
model.fit(data, dim="time")
model.save("my_model.zarr")

# Later, you can load the model
loaded_model = EOF.load("my_model.zarr")


Input Data Compatibility
------------------------

Expand Down
99 changes: 98 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ netCDF4 = "^1.5.7"
numba = "^0.57"
typing-extensions = "^4.8.0"


[tool.poetry.group.io]
optional = false

[tool.poetry.group.io.dependencies]
zarr = ">=2.0.0"
xarray-datatree = ">=0.0.5"

[tool.poetry.group.dev]
optional = true

Expand Down
48 changes: 48 additions & 0 deletions tests/models/test_eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,51 @@ def test_inverse_transform(dim, mock_data_array):

# Check that the reconstructed data has the same dimensions as the original data
assert set(reconstructed_data.dims) == set(mock_data_array.dims)


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_save_load(dim, mock_data_array, tmp_path):
"""Test save/load methods in EOF class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
original = EOF()
original.fit(mock_data_array, dim)

# Save the EOF model
original.save(tmp_path / "eof.zarr")

# Check that the EOF model has been saved
assert (tmp_path / "eof.zarr").exists()

# Recreate the model from saved file
loaded = EOF.load(tmp_path / "eof.zarr")

# Check that the params and DataContainer objects match
assert original.get_params() == loaded.get_params()
assert all([key in loaded.data for key in original.data])
assert all(
[
loaded.data._allow_compute[key] == original.data._allow_compute[key]
for key in original.data
]
)

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.scores(), loaded.transform(mock_data_array), rtol=1e-3, atol=1e-3
)

# Enhancement: the loaded model should also be able to inverse_transform new data
# assert np.allclose(
# original.inverse_transform(original.scores()),
# loaded.inverse_transform(loaded.scores()),
# rtol=1e-3,
# atol=1e-3,
# )
50 changes: 50 additions & 0 deletions tests/models/test_eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,53 @@ def test_compute(eof_model_delayed, compute):
assert data_is_dask(eof_rotator.data["explained_variance"])
assert data_is_dask(eof_rotator.data["components"])
assert data_is_dask(eof_rotator.data["rotation_matrix"])


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_save_load(dim, mock_data_array, tmp_path):
"""Test save/load methods in EOF class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
original_unrotated = EOF()
original_unrotated.fit(mock_data_array, dim)

original = EOFRotator()
original.fit(original_unrotated)

# Save the EOF model
original.save(tmp_path / "eof.zarr")

# Check that the EOF model has been saved
assert (tmp_path / "eof.zarr").exists()

# Recreate the model from saved file
loaded = EOFRotator.load(tmp_path / "eof.zarr")

# Check that the params and DataContainer objects match
assert original.get_params() == loaded.get_params()
assert all([key in loaded.data for key in original.data])
assert all(
[
loaded.data._allow_compute[key] == original.data._allow_compute[key]
for key in original.data
]
)

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.scores(), loaded.transform(mock_data_array), rtol=1e-3, atol=1e-3
)

# Enhancement: the loaded model should also be able to inverse_transform new data
# assert np.allclose(
# original.inverse_transform(original.scores()),
# loaded.inverse_transform(loaded.scores()),
# rtol=1e-2,
# )
51 changes: 51 additions & 0 deletions tests/models/test_mca.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,54 @@ def test_compute(mock_dask_data_array, dim, compute):
assert data_is_dask(mca_model.data["squared_covariance"])
assert data_is_dask(mca_model.data["components1"])
assert data_is_dask(mca_model.data["components2"])


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_save_load(dim, mock_data_array, tmp_path):
"""Test save/load methods in MCA class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
original = MCA()
original.fit(mock_data_array, mock_data_array, dim)

# Save the EOF model
original.save(tmp_path / "mca.zarr")

# Check that the EOF model has been saved
assert (tmp_path / "mca.zarr").exists()

# Recreate the model from saved file
loaded = MCA.load(tmp_path / "mca.zarr")

# Check that the params and DataContainer objects match
assert original.get_params() == loaded.get_params()
assert all([key in loaded.data for key in original.data])
assert all(
[
loaded.data._allow_compute[key] == original.data._allow_compute[key]
for key in original.data
]
)

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.scores(),
loaded.transform(mock_data_array, mock_data_array),
rtol=1e-3,
atol=1e-3,
)

# Enhancement: the loaded model should also be able to inverse_transform new data
# assert np.allclose(
# original.inverse_transform(original.scores()),
# loaded.inverse_transform(loaded.scores()),
# rtol=1e-3,
# atol=1e-3,
# )
54 changes: 54 additions & 0 deletions tests/models/test_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,57 @@ def test_compute(mca_model_delayed, compute):
assert data_is_dask(mca_rotator.data["norm1"])
assert data_is_dask(mca_rotator.data["norm2"])
assert data_is_dask(mca_rotator.data["modes_sign"])


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
def test_save_load(dim, mock_data_array, tmp_path):
"""Test save/load methods in MCA class, ensuring that we can
roundtrip the model and get the same results when transforming
data."""
original_unrotated = MCA()
original_unrotated.fit(mock_data_array, mock_data_array, dim)

original = MCARotator()
original.fit(original_unrotated)

# Save the EOF model
original.save(tmp_path / "mca.zarr")

# Check that the EOF model has been saved
assert (tmp_path / "mca.zarr").exists()

# Recreate the model from saved file
loaded = MCARotator.load(tmp_path / "mca.zarr")

# Check that the params and DataContainer objects match
assert original.get_params() == loaded.get_params()
assert all([key in loaded.data for key in original.data])
assert all(
[
loaded.data._allow_compute[key] == original.data._allow_compute[key]
for key in original.data
]
)

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.scores(),
loaded.transform(data1=mock_data_array, data2=mock_data_array),
rtol=1e-3,
atol=1e-3,
)

# Enhancement: the loaded model should also be able to inverse_transform new data
# assert np.allclose(
# original.inverse_transform(original.scores()),
# loaded.inverse_transform(loaded.scores()),
# rtol=1e-3,
# atol=1e-3,
# )
Loading