feat: serialization methods #103
Conversation
Previously, the `Preprocessor` was fully initialized within the `fit` method, as the decision between using a `DataArrayStacker` or a `DataSetStacker` depended on the data being fitted. This commit consolidates both stackers into a unified `Stacker` class, enabling complete initialization of the `Preprocessor` at the start of the model initialization process.
Hey, thanks Sam, that's really awesome! I've only skimmed through it, but it looks super clean. I'll try to get the CI up and running, and then I'll take a closer look at the code this evening. I haven't used `xarray-datatree` before either, but it sounds like the perfect application. Just as a side note, I'll change the target branch of the PR to `develop`.
Perfect, thanks, I didn't notice the target branch.
For reference, here's the resulting tree structure. For most of the branches, we could probably get by with less deep nesting, but there were a few cases where I think the only safe route was to make each serialized attribute, or even individual values (usually coords) within a dict-like attribute, an isolated dataset/node. The rotator classes are even heavier, since we have to serialize the original model as well, but that was quite easy to add with DataTree.
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           develop     #103      +/-  ##
===========================================
+ Coverage    91.46%   94.17%    +2.71%
===========================================
  Files           39       71      +32
  Lines         2600     5807    +3207
===========================================
+ Hits          2378     5469    +3091
- Misses         222      338     +116
===========================================
```

☔ View full report in Codecov by Sentry.
Some additional notes on various details of the implementation.
xeofs/models/_base_model.py (outdated)

```diff
@@ -1,12 +1,26 @@
from __future__ import annotations
```
I had to add this and the `TYPE_CHECKING` conditional to get tests to work with typing `DataTree` as an optional dependency. Not sure exactly what the best practice is here. We could also consider adding `zarr` and `xarray-datatree` as non-optional dependencies, since they are pretty lightweight libs.
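The idiom could look roughly like this. This is a sketch: the `serialize` helper and its signature are made up for illustration; only the `TYPE_CHECKING` guard and the deferred runtime import reflect the pattern being discussed.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, never executed at runtime,
    # so xarray-datatree does not have to be installed to import this module.
    from datatree import DataTree


def serialize(model) -> "DataTree":
    """Hypothetical helper: build a DataTree from a fitted model."""
    try:
        # Deferred import: the optional dependency is only required
        # when serialization is actually requested.
        from datatree import DataTree
    except ImportError as err:
        raise ImportError("Serialization requires xarray-datatree.") from err
    return DataTree(name=type(model).__name__)
```

The annotation stays valid for type checkers, while users who never call `serialize` never need the library installed.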
```python
reindexed_data_list.append(reindexed)

self._dummy_feature_coords = dummy_feature_coords
```
This `_dummy_feature_coords` was a little troublesome to serialize as `dict[np.array]`, so I instead tried to recreate it in the `inverse_transform` pathway so we don't need to save it. This passes tests, but I don't fully understand this class, so let me know if this isn't right.
```python
X = self.sanitizer.transform(X)
return self.concatenator.transform(X)  # type: ignore
X_t = X.copy()
for transformer in self.get_transformers():
```
I reduced these lines by defining a single source of truth for the order of transformer ops and then looping through them. Not sure if the `copy` is necessary in all these cases, but I didn't want to accidentally modify the original object. We could go a step further and use the `transformer_types` to instantiate the transformers as well, but we have to pass some unique kwargs to each there, so I didn't bother.
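The idea can be sketched like this. The class and transformer names below are toy examples, not the actual xeofs API; the point is that one ordered list drives every transform loop.

```python
class AddOne:
    """Toy transformer used only to demonstrate the pattern."""
    def transform(self, X):
        return [x + 1 for x in X]


class Double:
    """Toy transformer used only to demonstrate the pattern."""
    def transform(self, X):
        return [x * 2 for x in X]


class Pipeline:
    """One ordered list is the single source of truth for transformer order."""

    def __init__(self, transformers):
        self.transformers = list(transformers)

    def get_transformers(self, inverse: bool = False):
        # The same list drives both transform and (reversed) inverse_transform.
        return self.transformers[::-1] if inverse else self.transformers

    def transform(self, X):
        X_t = X.copy()  # defensive copy so the caller's object is untouched
        for transformer in self.get_transformers():
            X_t = transformer.transform(X_t)
        return X_t


print(Pipeline([AddOne(), Double()]).transform([1, 2]))  # → [4, 6]
```

Reordering or adding a step then only requires editing the list, not every method that loops over the transformers.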
```python
if isinstance(transformer_obj, GenericListTransformer):
    # Loop through list transformer objects and assign a dummy key
    for i, transformer in enumerate(transformer_obj.transformers):
        dt_transformer[str(i)] = transformer.serialize()
```
In general, lists containing other complex objects are not very conducive to serialization with a dict-like model such as DataTree, so I've added dummy keys like this in a few places.
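A minimal sketch of the dummy-key round trip (the function names here are illustrative, not from the codebase):

```python
def list_to_tree_children(items):
    # DataTree children form a mapping, so each list element gets a
    # stringified index as a dummy key.
    return {str(i): item for i, item in enumerate(items)}


def tree_children_to_list(children):
    # Restore the original order by sorting the dummy keys numerically
    # (plain string sort would put "10" before "2").
    return [children[k] for k in sorted(children, key=int)]


children = list_to_tree_children(["scaler", "stacker", "sanitizer"])
print(children)                        # {'0': 'scaler', '1': 'stacker', '2': 'sanitizer'}
print(tree_children_to_list(children))  # ['scaler', 'stacker', 'sanitizer']
```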
```python
ds = xr.Dataset(data_vars={data.name: data})

# Drop multiindexes and record for later
ds = ds.reset_index(list(multiindexes.keys()))
```
MultiIndexes are not currently serializable to zarr or netcdf, but I was able to get around this by dropping them before writing and rebuilding them on deserialization.
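The round trip can be sketched like this, assuming xarray and pandas; the variable and coord names are illustrative:

```python
import pandas as pd
import xarray as xr

# A small dataset whose "feature" dimension carries a MultiIndex.
ds = xr.Dataset(
    {"x": ("feature", [0.0, 1.0, 2.0, 3.0])},
    coords={
        "var": ("feature", ["a", "a", "b", "b"]),
        "level": ("feature", [1, 2, 1, 2]),
    },
).set_index(feature=["var", "level"])

# Record which coords make up each MultiIndex, then flatten before writing.
multiindexes = {"feature": ["var", "level"]}
flat = ds.reset_index("feature")  # zarr/netcdf can serialize this

# ... write `flat` to disk, read it back ...

# Rebuild the MultiIndex on deserialization.
restored = flat.set_index(feature=multiindexes["feature"])
assert isinstance(restored.indexes["feature"], pd.MultiIndex)
```

Storing the level names (here in the `multiindexes` dict) is what makes the rebuild possible, since the flattened dataset no longer knows which coords belonged together.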
```python
# Drop multiindexes and record for later
ds = ds.reset_index(list(multiindexes.keys()))
ds.attrs["multiindexes"] = multiindexes
ds.attrs["name_map"] = {key: data.name}
```
In some cases we have an attribute like `feature_coords = xr.DataArray(..., name="feature")`, where it is important to hold on to the original name, so this allows us to map attribute names to xarray names.
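A sketch of that mapping (the `lat` name and the surrounding code are made up for illustration):

```python
import xarray as xr

# A model attribute whose xarray name differs from its Python attribute name.
feature_coords = xr.DataArray([10.0, 20.0], dims="feature", name="lat")

ds = xr.Dataset({feature_coords.name: feature_coords})
# Map Python attribute names to the names used inside the serialized dataset.
ds.attrs["name_map"] = {"feature_coords": feature_coords.name}

# On deserialization, the map restores each attribute under its Python name
# while the DataArray keeps its original xarray name.
attrs = {attr: ds[xr_name] for attr, xr_name in ds.attrs["name_map"].items()}
print(attrs["feature_coords"].name)  # lat
```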
Truly brilliant work! Yeah, the Rotator classes are pretty bloated, but it's cool that `xarray-datatree` takes over so much of the grunt work. I'm learning a few things myself while reviewing the code :)
I like how cleanly the data structure is split up. It's probably the more robust solution for the future. Plus, from a user perspective it doesn't really matter.
> I had to add this and the `TYPE_CHECKING` conditional to get tests to work with typing `DataTree` as an optional dependency. Not sure exactly what the best practice is here. We could also consider adding `zarr` and `xarray-datatree` as non-optional dependencies, since they are pretty lightweight libs.
I'm in favour of just including them as non-optional.
Everything looks good! With your approach, we indeed no longer need to save the dummy feature coordinate separately. The sole purpose of the class is to concatenate a list of 2D DataArrays along the feature dimension. If the original feature dimension is only one-dimensional, the coordinates are preserved.
That's beautiful!
Do you have a use case in mind where the user would require access to the model's serialization methods?
Nope, nothing in mind; I only abstracted those out because it allowed me to reuse code for the non-rotated model here. They could definitely be private methods if you prefer.
Alright, that makes sense; we can keep it then. I think we're ready to go once all the checks have passed. :)
Actually, I take this back; I've already thought of a potential use case for myself. Say you want to fit a separate set of EOFs for each month in a dataset (seasonally varying). You end up with 12 separate models, and for tidy storage we could, instead of calling `save()` on each one, combine them all into a single tree. You could of course save 12 separate zarr stores too, but I can imagine applications where this might be preferred.
Closes #95.
This ended up being a lot harder than I expected, mostly due to the nested structure of the preprocessor and its attributes, plus the mixture of basic data types and DataArrays. The only workable solution I could come up with was to add `xarray-datatree` and `zarr` as optional dependencies and build the whole structure as a tree. This was my first foray into DataTree, which is super flexible and worked great for this task. I bailed on `netcdf` only because there are serious limitations with representing any sort of nested structures in attributes, i.e. dicts.

`save()`/`load()` are currently implemented and tested for `EOF`, `EOFRotator`, `MCA`, and `MCARotator`. I didn't bother with any of the models that don't yet implement a `transform()` method, because there would be less use for a saved model in this case. Should be pretty easy to extend in the future though, if desired.