
feat: serialization methods #103

Merged · 10 commits · Nov 7, 2023

Conversation

@slevang (Contributor) commented Nov 7, 2023

Closes #95.

This ended up being a lot harder than I expected, mostly due to the nested structure of the preprocessor and its attributes, plus the mixture of basic data types and DataArrays. The only workable solution I could come up with was to add xarray-datatree and zarr as optional dependencies and build the whole structure as a tree. This was my first foray into DataTree, which is super flexible and worked great for this task. I bailed on netcdf only because it has serious limitations representing any sort of nested structure in attributes, i.e. dicts.

save()/load() are currently implemented and tested for EOF, EOFRotator, MCA, and MCARotator. I didn't bother with any of the models that don't yet implement a transform() method, since a saved model would be less useful there. Should be pretty easy to extend in the future if desired.
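The user-facing flow this enables can be sketched with plain dicts and JSON standing in for DataTree and zarr (all class and method names below are illustrative, not the actual xeofs internals):

```python
import json

# Conceptual sketch of tree-structured serialization: each nested
# component of the model becomes a child node of the tree, mirroring
# the model's own structure. The real PR builds the tree with
# xarray-datatree and writes it with zarr.
class Transformer:
    def __init__(self, params):
        self.params = params

    def serialize(self):
        return {"params": self.params}

class Model:
    def __init__(self, params, preprocessor):
        self.params = params
        self.preprocessor = preprocessor

    def serialize(self):
        # Child components serialize themselves into sub-nodes.
        return {
            "params": self.params,
            "preprocessor": self.preprocessor.serialize(),
        }

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.serialize(), f)

model = Model({"n_modes": 2}, Transformer({"with_center": True}))
tree = model.serialize()
# tree == {'params': {'n_modes': 2},
#          'preprocessor': {'params': {'with_center': True}}}
```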

slevang and others added 4 commits November 3, 2023 16:15
Previously, the Preprocessor was fully initialized within the fit method,
as the decision between using a DataArrayStacker or a DataSetStacker
depended on the data being fitted. This commit consolidates both
stackers into a unified Stacker class, enabling the complete
initialization of the Preprocessor at the start of the model
initialization process.
@nicrie (Contributor) commented Nov 7, 2023

Hey, thanks Sam, that's really awesome! I've only skimmed through it, but it looks super clean. I'll try to get the CI up and running, and then I'll take a closer look at the code this evening. I haven't used `xarray-datatree` before either, but it sounds like the perfect application.

Just as a side note, I'll change the target branch of the PR to develop so we can merge both this and #102 into main at the same time. Otherwise, we'd have two releases back-to-back.

@nicrie nicrie changed the base branch from main to develop November 7, 2023 12:51
@slevang (Contributor, Author) commented Nov 7, 2023

Perfect, thanks, didn't notice the main/develop gitflow.

@slevang (Contributor, Author) commented Nov 7, 2023

For reference, here's the DataTree for a serialized EOF model.

At most of the branches we could probably get by with shallower nesting, but there were a few cases where I think the only safe route was to give each serialized attribute, or even individual values (usually coords) within a dict-like attribute, its own isolated dataset/node.

The rotator classes are even heavier since we have to serialize the original model as well, but that was quite easy to add with DataTree.

DataTree('EOF', parent=None)
│   Dimensions:             (feature: 20, mode: 2, sample: 25)
│   Coordinates:
│     * feature             (feature) int64 0 1 2 3 4 5 6 7 ... 13 14 15 16 17 18 19
│     * mode                (mode) int64 1 2
│     * sample              (sample) datetime64[ns] 2001-01-01 ... 2025-01-01
│   Data variables:
│       input_data          float64 nan
│       components          (feature, mode) float64 0.1595 -0.02274 ... -0.1336
│       scores              (sample, mode) float64 -1.006 -5.303 ... -2.691 11.94
│       norms               (mode) float64 33.44 24.17
│       explained_variance  (mode) float64 46.58 24.35
│       total_variance      float64 195.8
│   Attributes:
│       params:   {'n_modes': 2, 'center': True, 'standardize': False, 'use_cosla...
└── DataTree('preprocessor')
    │   Dimensions:  ()
    │   Data variables:
    │       *empty*
    │   Attributes:
    │       params:   {'feature_name': 'feature', 'return_list': False, 'sample_name'...
    │       n_data:   1
    ├── DataTree('scaler')
    │   └── DataTree('0')
    │       │   Dimensions:  ()
    │       │   Data variables:
    │       │       *empty*
    │       │   Attributes:
    │       │       params:           {'with_center': True, 'with_coslat': False, 'with_std':...
    │       │       mean_:            _is_node
    │       │       std_:             _is_node
    │       │       coslat_weights_:  _is_node
    │       │       weights_:         _is_node
    │       ├── DataTree('mean_')
    │       │       Dimensions:  (lat: 5, lon: 4)
    │       │       Coordinates:
    │       │         * lat      (lat) float64 20.0 30.0 40.0 50.0 60.0
    │       │         * lon      (lon) float64 -10.0 0.0 10.0 20.0
    │       │       Data variables:
    │       │           t2m      (lat, lon) float64 4.711 4.587 5.164 4.597 ... 4.126 4.879 4.056
    │       │       Attributes:
    │       │           multiindexes:  {}
    │       │           name_map:      {'mean_': 't2m'}
    │       ├── DataTree('std_')
    │       │       Dimensions:  ()
    │       │       Data variables:
    │       │           std_     float64 nan
    │       │       Attributes:
    │       │           multiindexes:  {}
    │       │           name_map:      {'std_': 'std_'}
    │       ├── DataTree('coslat_weights_')
    │       │       Dimensions:          ()
    │       │       Data variables:
    │       │           coslat_weights_  float64 nan
    │       │       Attributes:
    │       │           multiindexes:  {}
    │       │           name_map:      {'coslat_weights_': 'coslat_weights_'}
    │       └── DataTree('weights_')
    │               Dimensions:   (lon: 4, lat: 5)
    │               Coordinates:
    │                 * lon       (lon) float64 -10.0 0.0 10.0 20.0
    │                 * lat       (lat) float64 20.0 30.0 40.0 50.0 60.0
    │               Data variables:
    │                   weights_  (lon, lat) float64 1.0 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0 1.0
    │               Attributes:
    │                   multiindexes:  {}
    │                   name_map:      {'weights_': 'weights_'}
    ├── DataTree('renamer')
    │   └── DataTree('0')
    │           Dimensions:  ()
    │           Data variables:
    │               *empty*
    │           Attributes:
    │               params:       {'base': 'dim', 'start': 0}
    │               dim_mapping:  {'time': 'dim0', 'lat': 'dim1', 'lon': 'dim2'}
    ├── DataTree('preconverter')
    │   └── DataTree('0')
    │           Dimensions:  ()
    │           Data variables:
    │               *empty*
    │           Attributes:
    │               params:                 {}
    │               modified_dimensions:    []
    │               coords_from_fit:        {}
    │               coords_from_transform:  {}
    ├── DataTree('stacker')
    │   └── DataTree('0')
    │       │   Dimensions:  ()
    │       │   Data variables:
    │       │       *empty*
    │       │   Attributes:
    │       │       params:        {'feature_name': 'feature', 'sample_name': 'sample'}
    │       │       dims_in:       ('dim0', 'dim1', 'dim2')
    │       │       dims_out:      ('sample', 'feature')
    │       │       dims_mapping:  {'sample': ('dim0',), 'feature': ('dim1', 'dim2')}
    │       │       coords_in:     _is_tree
    │       │       coords_out:    _is_tree
    │       ├── DataTree('coords_in')
    │       │   ├── DataTree('dim0')
    │       │   │       Dimensions:  (dim0: 25)
    │       │   │       Coordinates:
    │       │   │         * dim0     (dim0) datetime64[ns] 2001-01-01 2002-01-01 ... 2025-01-01
    │       │   │       Data variables:
    │       │   │           *empty*
    │       │   │       Attributes:
    │       │   │           multiindexes:  {}
    │       │   │           name_map:      {'dim0': 'dim0'}
    │       │   ├── DataTree('dim1')
    │       │   │       Dimensions:  (dim1: 5)
    │       │   │       Coordinates:
    │       │   │         * dim1     (dim1) float64 20.0 30.0 40.0 50.0 60.0
    │       │   │       Data variables:
    │       │   │           *empty*
    │       │   │       Attributes:
    │       │   │           multiindexes:  {}
    │       │   │           name_map:      {'dim1': 'dim1'}
    │       │   └── DataTree('dim2')
    │       │           Dimensions:  (dim2: 4)
    │       │           Coordinates:
    │       │             * dim2     (dim2) float64 -10.0 0.0 10.0 20.0
    │       │           Data variables:
    │       │               *empty*
    │       │           Attributes:
    │       │               multiindexes:  {}
    │       │               name_map:      {'dim2': 'dim2'}
    │       └── DataTree('coords_out')
    │           ├── DataTree('sample')
    │           │       Dimensions:  (sample: 25)
    │           │       Coordinates:
    │           │         * sample   (sample) datetime64[ns] 2001-01-01 2002-01-01 ... 2025-01-01
    │           │       Data variables:
    │           │           *empty*
    │           │       Attributes:
    │           │           multiindexes:  {}
    │           │           name_map:      {'sample': 'sample'}
    │           └── DataTree('feature')
    │                   Dimensions:  (feature: 20)
    │                   Coordinates:
    │                       dim1     (feature) float64 20.0 20.0 20.0 20.0 30.0 ... 60.0 60.0 60.0 60.0
    │                       dim2     (feature) float64 -10.0 0.0 10.0 20.0 -10.0 ... -10.0 0.0 10.0 20.0
    │                   Dimensions without coordinates: feature
    │                   Data variables:
    │                       *empty*
    │                   Attributes:
    │                       multiindexes:  {'feature': ['dim1', 'dim2']}
    │                       name_map:      {'feature': 'feature'}
    ├── DataTree('postconverter')
    │   └── DataTree('0')
    │       │   Dimensions:  ()
    │       │   Data variables:
    │       │       *empty*
    │       │   Attributes:
    │       │       params:                 {}
    │       │       modified_dimensions:    ['feature']
    │       │       coords_from_fit:        _is_tree
    │       │       coords_from_transform:  _is_tree
    │       ├── DataTree('coords_from_fit')
    │       │   └── DataTree('feature')
    │       │           Dimensions:  (feature: 20)
    │       │           Coordinates:
    │       │               dim1     (feature) float64 20.0 20.0 20.0 20.0 30.0 ... 60.0 60.0 60.0 60.0
    │       │               dim2     (feature) float64 -10.0 0.0 10.0 20.0 -10.0 ... -10.0 0.0 10.0 20.0
    │       │           Dimensions without coordinates: feature
    │       │           Data variables:
    │       │               *empty*
    │       │           Attributes:
    │       │               multiindexes:  {'feature': ['dim1', 'dim2']}
    │       │               name_map:      {'feature': 'feature'}
    │       └── DataTree('coords_from_transform')
    │           └── DataTree('feature')
    │                   Dimensions:  (feature: 20)
    │                   Coordinates:
    │                       dim1     (feature) float64 20.0 20.0 20.0 20.0 30.0 ... 60.0 60.0 60.0 60.0
    │                       dim2     (feature) float64 -10.0 0.0 10.0 20.0 -10.0 ... -10.0 0.0 10.0 20.0
    │                   Dimensions without coordinates: feature
    │                   Data variables:
    │                       *empty*
    │                   Attributes:
    │                       multiindexes:  {'feature': ['dim1', 'dim2']}
    │                       name_map:      {'feature': 'feature'}
    ├── DataTree('sanitizer')
    │   └── DataTree('0')
    │       │   Dimensions:  ()
    │       │   Data variables:
    │       │       *empty*
    │       │   Attributes:
    │       │       params:            {'feature_name': 'feature', 'sample_name': 'sample'}
    │       │       feature_coords:    _is_node
    │       │       sample_coords:     _is_node
    │       │       is_valid_feature:  _is_node
    │       ├── DataTree('feature_coords')
    │       │       Dimensions:  (feature: 20)
    │       │       Coordinates:
    │       │         * feature  (feature) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
    │       │       Data variables:
    │       │           *empty*
    │       │       Attributes:
    │       │           multiindexes:  {}
    │       │           name_map:      {'feature_coords': 'feature'}
    │       ├── DataTree('sample_coords')
    │       │       Dimensions:  (sample: 25)
    │       │       Coordinates:
    │       │         * sample   (sample) datetime64[ns] 2001-01-01 2002-01-01 ... 2025-01-01
    │       │       Data variables:
    │       │           *empty*
    │       │       Attributes:
    │       │           multiindexes:  {}
    │       │           name_map:      {'sample_coords': 'sample'}
    │       └── DataTree('is_valid_feature')
    │               Dimensions:           (feature: 20)
    │               Coordinates:
    │                 * feature           (feature) int64 0 1 2 3 4 5 6 7 ... 13 14 15 16 17 18 19
    │               Data variables:
    │                   is_valid_feature  (feature) bool True True True True ... True True True True
    │               Attributes:
    │                   multiindexes:  {}
    │                   name_map:      {'is_valid_feature': 'is_valid_feature'}
    └── DataTree('concatenator')
        │   Dimensions:  ()
        │   Data variables:
        │       *empty*
        │   Attributes:
        │       params:      {'feature_name': 'feature', 'sample_name': 'sample'}
        │       n_data:      1
        │       n_features:  [20]
        │       coords_in:   _is_tree
        └── DataTree('coords_in')
            └── DataTree('0')
                    Dimensions:  (feature: 20)
                    Coordinates:
                      * feature  (feature) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
                    Data variables:
                        *empty*
                    Attributes:
                        multiindexes:  {}
                        name_map:      {'0': 'feature'}

@codecov-commenter commented Nov 7, 2023

Codecov Report

Attention: 28 lines in your changes are missing coverage. Please review.

Comparison is base (8e199ef) 91.46% compared to head (e08b408) 94.17%.
Report is 170 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop     #103      +/-   ##
===========================================
+ Coverage    91.46%   94.17%   +2.71%     
===========================================
  Files           39       71      +32     
  Lines         2600     5807    +3207     
===========================================
+ Hits          2378     5469    +3091     
- Misses         222      338     +116     
Files Coverage Δ
tests/models/test_cca.py 100.00% <100.00%> (ø)
tests/models/test_complex_eof.py 100.00% <100.00%> (ø)
tests/models/test_complex_mca.py 100.00% <100.00%> (ø)
tests/models/test_eeof.py 100.00% <100.00%> (ø)
tests/models/test_eof.py 100.00% <100.00%> (ø)
tests/models/test_eof_rotator.py 100.00% <100.00%> (+2.05%) ⬆️
tests/models/test_gwpca.py 100.00% <100.00%> (ø)
tests/models/test_mca.py 100.00% <ø> (ø)
tests/models/test_mca_rotator.py 100.00% <ø> (ø)
tests/models/test_opa.py 100.00% <ø> (ø)
... and 37 more

... and 25 files with indirect coverage changes


@slevang (Contributor, Author) left a comment

Some additional notes on various details of the implementation.

@@ -1,12 +1,26 @@
from __future__ import annotations

I had to add this and the TYPE_CHECKING conditional to get tests to work with DataTree typed as an optional dependency. Not sure exactly what the best practice is here. We could also consider adding zarr and xarray-datatree as non-optional dependencies, since they're pretty lightweight libs.
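The pattern described here is roughly the standard one for typing against an optional dependency; a hedged sketch (the function name and error message are illustrative, not the PR's actual code):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so the
    # annotation below works even when the optional dep is not installed.
    from datatree import DataTree

def serialize(model) -> DataTree:
    # Defer the real import to call time and fail with a clear message
    # if the optional dependency is missing.
    try:
        from datatree import DataTree
    except ImportError as err:
        raise ImportError(
            "xarray-datatree is required for serialization"
        ) from err
    return DataTree(name=type(model).__name__)
```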

xeofs/models/_base_model.py (discussion resolved)
reindexed_data_list.append(reindexed)

self._dummy_feature_coords = dummy_feature_coords

This _dummy_feature_coords was a little troublesome to serialize as dict[np.ndarray], so I instead recreate it in the inverse_transform pathway so we don't need to save it. This passes tests, but I don't fully understand this class, so let me know if this isn't right.

X = self.sanitizer.transform(X)
return self.concatenator.transform(X) # type: ignore
X_t = X.copy()
for transformer in self.get_transformers():

I reduced these lines by defining a single source of truth for the order of transformer ops and then looping through them. I'm not sure the copy is necessary in all these cases, but I didn't want to accidentally modify the original object. We could go a step further and use the transformer_types to instantiate the transformers as well, but we'd have to pass some unique kwargs to each there, so I didn't bother.
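The single-source-of-truth pattern can be sketched like this (illustrative names, not the actual xeofs classes):

```python
# One canonical list of transformers drives the pipeline, so the
# operation order is defined exactly once and every method that walks
# the pipeline (transform, inverse_transform, serialize, ...) stays
# consistent automatically.
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def transform(self, xs):
        return [x * self.factor for x in xs]

class Shift:
    def __init__(self, offset):
        self.offset = offset

    def transform(self, xs):
        return [x + self.offset for x in xs]

class Pipeline:
    def __init__(self, transformers):
        self.transformers = transformers

    def get_transformers(self):
        # Single source of truth for operation order.
        return self.transformers

    def transform(self, X):
        X_t = list(X)  # copy so the caller's data is never mutated
        for transformer in self.get_transformers():
            X_t = transformer.transform(X_t)
        return X_t

pipe = Pipeline([Scale(2), Shift(1)])
print(pipe.transform([1, 2, 3]))  # [3, 5, 7]
```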

if isinstance(transformer_obj, GenericListTransformer):
# Loop through list transformer objects and assign a dummy key
for i, transformer in enumerate(transformer_obj.transformers):
dt_transformer[str(i)] = transformer.serialize()

In general, lists containing other complex objects are not very conducive to serialization with a dict-like model such as DataTree, so I've added dummy keys like this in a few places.
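A minimal sketch of the dummy-key trick (pure Python; `Leaf` and the helper names are illustrative):

```python
# A dict-like tree such as DataTree has no native list node, so list
# elements are stored under stringified index keys ("0", "1", ...) and
# re-ordered numerically on load.
class Leaf:
    def __init__(self, value):
        self.value = value

    def serialize(self):
        return {"value": self.value}

def serialize_list(items):
    return {str(i): item.serialize() for i, item in enumerate(items)}

def deserialize_list(node):
    # Sort by int so "10" does not sort before "2".
    return [node[key] for key in sorted(node, key=int)]

node = serialize_list([Leaf("a"), Leaf("b")])
# node == {'0': {'value': 'a'}, '1': {'value': 'b'}}
```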

ds = xr.Dataset(data_vars={data.name: data})

# Drop multiindexes and record for later
ds = ds.reset_index(list(multiindexes.keys()))

MultiIndexes are not currently serializable to zarr or netcdf, but I was able to get around this by dropping them and rebuilding on deserialization.
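A small sketch of the drop-and-rebuild workaround, assuming current xarray (`set_index`/`reset_index`); variable names are illustrative:

```python
import pandas as pd
import xarray as xr

# Build a small Dataset with a 'feature' MultiIndex over (lat, lon).
da = xr.DataArray(
    [0, 1, 2, 3],
    dims="feature",
    coords={
        "lat": ("feature", [10.0, 10.0, 20.0, 20.0]),
        "lon": ("feature", [0.0, 1.0, 0.0, 1.0]),
    },
    name="data",
)
ds = da.to_dataset().set_index(feature=["lat", "lon"])

# Before writing: zarr/netcdf can't store a MultiIndex, so flatten it
# into plain coordinates and record the level names for later.
multiindexes = {"feature": ["lat", "lon"]}
ds = ds.reset_index(list(multiindexes.keys()))
ds.attrs["multiindexes"] = multiindexes
# ds.to_zarr(store) would now succeed.

# After reading: rebuild each MultiIndex from the recorded attribute.
for dim, levels in ds.attrs["multiindexes"].items():
    ds = ds.set_index({dim: levels})
```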

# Drop multiindexes and record for later
ds = ds.reset_index(list(multiindexes.keys()))
ds.attrs["multiindexes"] = multiindexes
ds.attrs["name_map"] = {key: data.name}

In some cases we have an attribute like feature_coords=xr.DataArray(name="feature"), where it is important to hold on to the original name, so this allows us to map attribute names to xarray names.
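A pure-Python sketch of the name_map idea (xeofs records these as Dataset attributes; every name here is illustrative):

```python
from types import SimpleNamespace

# The attribute name on the model ("feature_coords") can differ from
# the stored object's own name ("feature"), so the mapping between the
# two is recorded alongside the data and used on deserialization.
def serialize_attr(attr_name, obj):
    return {
        "data": {obj.name: obj.value},
        "attrs": {"name_map": {attr_name: obj.name}},
    }

def deserialize_attr(node):
    ((attr_name, var_name),) = node["attrs"]["name_map"].items()
    return attr_name, node["data"][var_name]

coords = SimpleNamespace(name="feature", value=[0, 1, 2])
node = serialize_attr("feature_coords", coords)
attr_name, value = deserialize_attr(node)
# attr_name == 'feature_coords', value == [0, 1, 2]
```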

xeofs/preprocessing/transformer.py (two discussions resolved)
@slevang slevang marked this pull request as ready for review November 7, 2023 15:09
@nicrie (Contributor) left a comment

Truly brilliant work! Yeah, the Rotator classes are pretty bloated, but it's cool that xarray-datatree takes over so much of the grunt work. I'm learning a few things myself while reviewing the code :)

I like how cleanly the data structure is split up. It's probably the more robust solution for the future, and from a user perspective it doesn't really matter anyway.

@nicrie nicrie self-requested a review November 7, 2023 15:10
@nicrie (Contributor) left a comment

> I had to add this and the TYPE_CHECKING conditional to get tests to work with DataTree typed as an optional dependency. Not sure exactly what the best practice is here. We could also consider adding zarr and xarray-datatree as non-optional dependencies, since they're pretty lightweight libs.

I'm in favour of just including them as non-optional.

@nicrie (Contributor) commented Nov 7, 2023

> This _dummy_feature_coords was a little troublesome to serialize as dict[np.ndarray], so I instead recreate it in the inverse_transform pathway so we don't need to save it. This passes tests, but I don't fully understand this class, so let me know if this isn't right.

Everything looks good! With your approach, we indeed no longer need to save the dummy feature coordinate separately.

The sole purpose of the class is to concatenate a list of 2D DataArrays along the feature dimension. If the original feature dimension is only one-dimensional, the coordinates are preserved (the MultiIndexConverter ignores these). This can result in an attempt to concatenate multiple DataArrays with potentially incompatible coordinates at this point. In the case of already converted coordinates (which are always given by range(len(coords))), there's a risk of combining non-unique coordinates. Although xarray allows this in principle, we might run into trouble when performing mathematical operations along the coordinates. I can't remember exactly which mathematical operation causes the issue (it might even be .dot), but it seems to me that going into the individual models with streamlined coordinates is definitely the safer route.

@nicrie (Contributor) commented Nov 7, 2023

> I reduced these lines by defining a single source of truth for the order of transformer ops and then looping through them. I'm not sure the copy is necessary in all these cases, but I didn't want to accidentally modify the original object. We could go a step further and use the transformer_types to instantiate the transformers as well, but we'd have to pass some unique kwargs to each there, so I didn't bother.

That's beautiful!

@nicrie (Contributor) commented Nov 7, 2023

Do you have a use case in mind where the user would require access to the model's serialize/deserialize methods? Or would providing just the save/load methods be sufficient?

@slevang (Contributor, Author) commented Nov 7, 2023

> Do you have a use case in mind where the user would require access to the model's serialize/deserialize methods? Or would providing just the save/load methods be sufficient?

Nope, nothing in mind; I only abstracted those out because it let me reuse code for the non-rotated model here. They could definitely be private methods if you prefer.

@nicrie (Contributor) commented Nov 7, 2023

> Do you have a use case in mind where the user would require access to the model's serialize/deserialize methods? Or would providing just the save/load methods be sufficient?

> Nope, nothing in mind; I only abstracted those out because it let me reuse code for the non-rotated model here. They could definitely be private methods if you prefer.

Alright, that makes sense; we can keep it then. I think we're ready to go once all the checks have passed. :)

@nicrie nicrie merged commit 69df4d9 into xarray-contrib:develop Nov 7, 2023
6 checks passed
@slevang (Contributor, Author) commented Nov 7, 2023

> Nope, nothing in mind; I only abstracted those out because it let me reuse code for the non-rotated model here. They could definitely be private methods if you prefer.

Actually, I take this back; I've already thought of a potential use case for myself.

Say you want to fit a separate set of EOFs for each month in a dataset (seasonally varying). You end up with 12 separate models, and for tidy storage we could, instead of calling save() on each one, call serialize(), assemble the DataTree for each month as a node on a parent tree, and save the whole thing with to_zarr(). To reassemble the models, load the DataTree manually, split out each node, and call deserialize().

You could of course save 12 separate zarr stores too, but I can imagine applications where this might be preferred.
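That workflow can be sketched with plain dicts standing in for DataTree nodes (Model and the helper names are hypothetical, not xeofs API):

```python
# Multiple serialized models assembled under one parent tree; with
# DataTree the parent would then be written once via to_zarr().
class Model:
    def __init__(self, params):
        self.params = params

    def serialize(self):
        return {"params": self.params}

    @classmethod
    def deserialize(cls, node):
        return cls(node["params"])

def save_monthly(models):
    # One child node per month under a single parent tree.
    return {month: model.serialize() for month, model in models.items()}

def load_monthly(tree):
    # Split out each node and rebuild the corresponding model.
    return {month: Model.deserialize(node) for month, node in tree.items()}

models = {"jan": Model({"n_modes": 2}), "feb": Model({"n_modes": 3})}
restored = load_monthly(save_monthly(models))
# restored["feb"].params == {'n_modes': 3}
```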
