From 0f95d0444a6e9d93dba30bd0c4211fe9000ea6bd Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Tue, 7 Nov 2023 20:25:07 -0500 Subject: [PATCH] fix: safe default zarr write mode --- xeofs/models/_base_cross_model.py | 7 ++++++- xeofs/models/_base_model.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/xeofs/models/_base_cross_model.py b/xeofs/models/_base_cross_model.py index 295c8de..17c4a6f 100644 --- a/xeofs/models/_base_cross_model.py +++ b/xeofs/models/_base_cross_model.py @@ -284,6 +284,7 @@ def serialize(self, save_data: bool = False) -> DataTree: def save( self, path: str, + overwrite: bool = False, save_data: bool = False, **kwargs, ): @@ -293,6 +294,8 @@ def save( ---------- path : str Path to save the model zarr store. + overwrite: bool, default=False + Whether or not to overwrite the existing path if it already exists. save_data : str Whether or not to save the full input data along with the fitted components. **kwargs @@ -306,7 +309,9 @@ def save( if hasattr(self, "model"): dt["model"] = self.model.serialize() - dt.to_zarr(path, **kwargs) + write_mode = "w" if overwrite else "w-" + + dt.to_zarr(path, mode=write_mode, **kwargs) @classmethod def deserialize(cls, dt: DataTree) -> Self: diff --git a/xeofs/models/_base_model.py b/xeofs/models/_base_model.py index 9a2d954..1683312 100644 --- a/xeofs/models/_base_model.py +++ b/xeofs/models/_base_model.py @@ -357,6 +357,7 @@ def serialize(self, save_data: bool = False) -> DataTree: def save( self, path: str, + overwrite: bool = False, save_data: bool = False, **kwargs, ): @@ -366,6 +367,8 @@ def save( ---------- path : str Path to save the model zarr store. + overwrite: bool, default=False + Whether or not to overwrite the existing path if it already exists. save_data : str Whether or not to save the full input data along with the fitted components. **kwargs @@ -379,7 +382,9 @@ def save( if hasattr(self, "model"): dt["model"] = self.model.serialize(save_data=save_data) - dt.to_zarr(path, **kwargs) + write_mode = "w" if overwrite else "w-" + + dt.to_zarr(path, mode=write_mode, **kwargs) @classmethod def deserialize(cls, dt: DataTree) -> Self: