diff --git a/xeofs/data_container/data_container.py b/xeofs/data_container/data_container.py
index fdd64ee..9b8f948 100644
--- a/xeofs/data_container/data_container.py
+++ b/xeofs/data_container/data_container.py
@@ -1,5 +1,3 @@
-from typing import Dict
-
 import dask
 from typing_extensions import Self
 
@@ -66,7 +64,7 @@ def _validate_attrs_values(self, value):
         else:
             return value
 
-    def _validate_attrs(self, attrs: Dict) -> Dict:
+    def _validate_attrs(self, attrs: dict) -> dict:
         """Convert any boolean and None values to strings"""
         for key, value in attrs.items():
             if isinstance(value, bool):
@@ -78,7 +76,7 @@ def _validate_attrs(self, attrs: Dict) -> Dict:
 
         return attrs
 
-    def set_attrs(self, attrs: Dict):
+    def set_attrs(self, attrs: dict):
         attrs = self._validate_attrs(attrs)
         for key in self.keys():
             self[key].attrs = attrs
diff --git a/xeofs/models/cca.py b/xeofs/models/cca.py
index 658bfbb..a378505 100644
--- a/xeofs/models/cca.py
+++ b/xeofs/models/cca.py
@@ -9,7 +9,7 @@
 
 from abc import abstractmethod
 from datetime import datetime
-from typing import Hashable, List, Sequence
+from typing import Hashable, Sequence
 
 import dask.array as da
 import numpy as np
@@ -143,7 +143,7 @@ def fit(
             Preprocessor(with_coslat=self.use_coslat[i], **self._preprocessor_kwargs)
             for i in range(self.n_views_)
         ]
-        views2D: List[DataArray] = [
+        views2D: list[DataArray] = [
             preprocessor.fit_transform(data, dim)
             for preprocessor, data in zip(self.preprocessors, views)
         ]
@@ -215,7 +215,7 @@ def _apply_pca(self, views: DataList):
         return view_transformed
 
     @abstractmethod
-    def _fit_algorithm(self, views: List[DataArray]) -> Self:
+    def _fit_algorithm(self, views: list[DataArray]) -> Self:
         raise NotImplementedError
 
 
@@ -304,7 +304,7 @@ def __init__(
         self.c = c
         self.eps = eps
 
-    def _fit_algorithm(self, views: List[DataArray]) -> Self:
+    def _fit_algorithm(self, views: list[DataArray]) -> Self:
         # Check input data
         [assert_not_complex(view) for view in views]
 
@@ -620,26 +620,26 @@ def _apply_smallest_eigval(self, D, dims):
     def _smallest_eigval(self, D):
         return min(0, np.linalg.eigvalsh(D).min())
 
-    def weights(self) -> List[DataObject]:
+    def weights(self) -> list[DataObject]:
         weights = [
             prep.inverse_transform_components(wghts)
             for prep, wghts in zip(self.preprocessors, self.data["weights"])
         ]
         return weights
 
-    def _transform(self, views: Sequence[DataArray]) -> List[DataArray]:
+    def _transform(self, views: Sequence[DataArray]) -> list[DataArray]:
         transformed_views = []
         for i, view in enumerate(views):
             transformed_view = xr.dot(view, self.data["weights"][i], dims="feature")
             transformed_views.append(transformed_view)
         return transformed_views
 
-    def transform(self, views: Sequence[DataObject]) -> List[DataArray]:
+    def transform(self, views: Sequence[DataObject]) -> list[DataArray]:
         """Transform the input data into the canonical space.
 
         Parameters
         ----------
-        views : List[DataArray | Dataset]
+        views : list[DataArray | Dataset]
             Input data to transform
 
         """
 
@@ -655,7 +655,7 @@ def transform(self, views: Sequence[DataObject]) -> List[DataArray]:
             unstacked_transformed_views.append(unstacked_view)
         return unstacked_transformed_views
 
-    def components(self, normalize: bool = True) -> List[DataObject]:
+    def components(self, normalize: bool = True) -> list[DataObject]:
         """Get the canonical loadings for each view."""
         can_loads = self.data["canonical_loadings"]
         input_data = self.data["input_data"]
@@ -681,7 +681,7 @@ def components(self, normalize: bool = True) -> List[DataObject]:
         ]
         return loadings
 
-    def scores(self) -> List[DataArray]:
+    def scores(self) -> list[DataArray]:
         """Get the canonical variates for each view."""
         variates = []
         for i, view in enumerate(self.data["variates"]):
             variates.append(vari)
         return variates
 
-    def explained_variance(self) -> List[DataArray]:
+    def explained_variance(self) -> list[DataArray]:
         """Get the explained variance for each view."""
         return self.data["explained_variance"]
 
-    def explained_variance_ratio(self) -> List[DataArray]:
+    def explained_variance_ratio(self) -> list[DataArray]:
         """Get the explained variance ratio for each view."""
         return self.data["explained_variance_ratio"]
diff --git a/xeofs/models/cpcca.py b/xeofs/models/cpcca.py
index 2c292c5..9bf767e 100644
--- a/xeofs/models/cpcca.py
+++ b/xeofs/models/cpcca.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Dict, Optional, Sequence, Tuple
+from typing import Sequence
 
 import numpy as np
 import xarray as xr
@@ -225,10 +225,10 @@ def _fit_algorithm(
 
     def _transform_algorithm(
         self,
-        X: Optional[DataArray] = None,
-        Y: Optional[DataArray] = None,
+        X: DataArray | None = None,
+        Y: DataArray | None = None,
         normalized=False,
-    ) -> Dict[str, DataArray]:
+    ) -> dict[str, DataArray]:
         results = {}
         if X is not None:
             # Project data onto singular vectors
@@ -1159,7 +1159,7 @@ def _fit_algorithm(self, X: DataArray, Y: DataArray) -> Self:
 
         return super()._fit_algorithm(X, Y)
 
-    def components_amplitude(self, normalized=True) -> Tuple[DataObject, DataObject]:
+    def components_amplitude(self, normalized=True) -> tuple[DataObject, DataObject]:
         """Get the amplitude of the components.
 
         The amplitudes of the components are defined as
@@ -1229,7 +1229,7 @@ def components_phase(self, normalized=True) -> tuple[DataObject, DataObject]:
 
         return Px, Py
 
-    def scores_amplitude(self, normalized=False) -> Tuple[DataArray, DataArray]:
+    def scores_amplitude(self, normalized=False) -> tuple[DataArray, DataArray]:
         """Get the amplitude of the scores.
 
         The amplitudes of the scores are defined as
@@ -1264,7 +1264,7 @@ def scores_amplitude(self, normalized=False) -> Tuple[DataArray, DataArray]:
 
         return Rx, Ry
 
-    def scores_phase(self, normalized=False) -> Tuple[DataArray, DataArray]:
+    def scores_phase(self, normalized=False) -> tuple[DataArray, DataArray]:
         """Get the phase of the scores.
 
         The phases of the scores are defined as
diff --git a/xeofs/models/cpcca_rotator.py b/xeofs/models/cpcca_rotator.py
index b6853b8..c0157aa 100644
--- a/xeofs/models/cpcca_rotator.py
+++ b/xeofs/models/cpcca_rotator.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Sequence
+from typing import Sequence
 
 import numpy as np
 import xarray as xr
@@ -103,7 +103,7 @@ def __init__(
 
         self.sorted = False
 
-    def get_serialization_attrs(self) -> Dict:
+    def get_serialization_attrs(self) -> dict:
         return dict(
             data=self.data,
             preprocessor1=self.preprocessor1,
@@ -309,7 +309,7 @@ def transform(
         X: DataObject | None = None,
         Y: DataObject | None = None,
         normalized: bool = False,
-    ) -> DataArray | List[DataArray]:
+    ) -> DataArray | list[DataArray]:
         """Transform the data.
 
         Parameters
diff --git a/xeofs/models/eeof.py b/xeofs/models/eeof.py
index 77e6859..6f08122 100644
--- a/xeofs/models/eeof.py
+++ b/xeofs/models/eeof.py
@@ -1,12 +1,10 @@
-from typing import Optional
-from typing_extensions import Self
-
 import numpy as np
 import xarray as xr
+from typing_extensions import Self
 
-from .eof import EOF
-from ..utils.data_types import DataArray
 from ..data_container import DataContainer
+from ..utils.data_types import DataArray
+from .eof import EOF
 
 
 class ExtendedEOF(EOF):
@@ -29,7 +27,7 @@ class ExtendedEOF(EOF):
         Embedding dimension is the number of dimensions in the delay-coordinate
         space used to represent the dynamics of the system. It determines the
        number of delayed copies of the time series that are used to construct
        the delay-coordinate space.
-    n_pca_modes : Optional[int]
+    n_pca_modes : int, optional
         If provided, the input data is first preprocessed using PCA with the
         specified number of modes. The EEOF analysis is then performed on the
         resulting PCA scores. This approach can lead to important computational
@@ -64,7 +62,7 @@ def __init__(
         n_modes: int,
         tau: int,
         embedding: int,
-        n_pca_modes: Optional[int] = None,
+        n_pca_modes: int | None = None,
         center: bool = True,
         standardize: bool = False,
         use_coslat: bool = False,
@@ -73,7 +71,7 @@ def __init__(
         feature_name: str = "feature",
         compute: bool = True,
         solver: str = "auto",
-        random_state: Optional[int] = None,
+        random_state: int | None = None,
         solver_kwargs: dict = {},
         **kwargs,
     ):
diff --git a/xeofs/models/eof.py b/xeofs/models/eof.py
index 4b234ab..7d89f82 100644
--- a/xeofs/models/eof.py
+++ b/xeofs/models/eof.py
@@ -1,5 +1,4 @@
 import warnings
-from typing import Dict, Optional
 
 import numpy as np
 import xarray as xr
@@ -37,7 +36,7 @@ class EOF(_BaseModelSingleSet):
         If True, four pieces of the fit will be computed sequentially: 1) the
         preprocessor scaler, 2) optional NaN checks, 3) SVD decomposition,
         4) scores and components.
-    random_state : Optional[int], default=None
+    random_state : int, optional
         Seed for the random number generator.
     solver: {"auto", "full", "randomized"}, default="auto"
         Solver to use for the SVD computation.
@@ -62,9 +61,9 @@ def __init__(
         sample_name: str = "sample",
         feature_name: str = "feature",
         compute: bool = True,
-        random_state: Optional[int] = None,
+        random_state: int | None = None,
         solver: str = "auto",
-        solver_kwargs: Dict = {},
+        solver_kwargs: dict = {},
         **kwargs,
     ):
         super().__init__(
@@ -144,7 +143,7 @@ def _inverse_transform_algorithm(self, scores: DataArray) -> DataArray:
 
         Returns
         -------
-        data: DataArray | Dataset | List[DataArray]
+        data: DataArray | Dataset | list[DataArray]
             Reconstructed data.
""" @@ -163,7 +162,7 @@ def components(self, normalized: bool = True) -> DataObject: Returns ------- - components: DataArray | Dataset | List[DataArray] + components: DataArray | Dataset | list[DataArray] Components of the fitted model. """ @@ -183,7 +182,7 @@ def scores(self, normalized: bool = False) -> DataArray: Returns ------- - components: DataArray | Dataset | List[DataArray] + components: DataArray | Dataset | list[DataArray] Scores of the fitted model. """ @@ -273,7 +272,7 @@ class ComplexEOF(EOF): computation. If True, four pieces of the fit will be computed sequentially: 1) the preprocessor scaler, 2) optional NaN checks, 3) SVD decomposition, 4) scores and components. - random_state : Optional[int], default=None + random_state : int, optional Seed for the random number generator. solver: {"auto", "full", "randomized"}, default="auto" Solver to use for the SVD computation. @@ -311,9 +310,9 @@ def __init__( sample_name: str = "sample", feature_name: str = "feature", compute: bool = True, - random_state: Optional[int] = None, + random_state: int | None = None, solver: str = "auto", - solver_kwargs: Dict = {}, + solver_kwargs: dict = {}, **kwargs, ): super().__init__( @@ -358,7 +357,7 @@ def components_amplitude(self, normalized=True) -> DataObject: Returns ------- - components_amplitude: DataArray | Dataset | List[DataArray] + components_amplitude: DataArray | Dataset | list[DataArray] Amplitude of the components of the fitted model. """ @@ -383,7 +382,7 @@ def components_phase(self) -> DataObject: Returns ------- - components_phase: DataArray | Dataset | List[DataArray] + components_phase: DataArray | Dataset | list[DataArray] Phase of the components of the fitted model. """ @@ -410,7 +409,7 @@ def scores_amplitude(self, normalized=True) -> DataArray: Returns ------- - scores_amplitude: DataArray | Dataset | List[DataArray] + scores_amplitude: DataArray | Dataset | list[DataArray] Amplitude of the scores of the fitted model. """ @@ -435,7 +434,7 @@ def scores_phase(self) -> DataArray: Returns ------- - scores_phase: DataArray | Dataset | List[DataArray] + scores_phase: DataArray | Dataset | list[DataArray] Phase of the scores of the fitted model. """ @@ -485,7 +484,7 @@ class HilbertEOF(ComplexEOF): If True, four pieces of the fit will be computed sequentially: 1) the preprocessor scaler, 2) optional NaN checks, 3) SVD decomposition, 4) scores and components. - random_state : Optional[int], default=None + random_state : int, optional Seed for the random number generator. solver: {"auto", "full", "randomized"}, default="auto" Solver to use for the SVD computation. 
@@ -520,9 +519,9 @@ def __init__(
         sample_name: str = "sample",
         feature_name: str = "feature",
         compute: bool = True,
-        random_state: Optional[int] = None,
+        random_state: int | None = None,
         solver: str = "auto",
-        solver_kwargs: Dict = {},
+        solver_kwargs: dict = {},
         **kwargs,
     ):
         super().__init__(
diff --git a/xeofs/models/eof_rotator.py b/xeofs/models/eof_rotator.py
index 169515f..bdfcf4e 100644
--- a/xeofs/models/eof_rotator.py
+++ b/xeofs/models/eof_rotator.py
@@ -1,5 +1,4 @@
 from datetime import datetime
-from typing import Dict
 
 import numpy as np
 import xarray as xr
@@ -93,7 +92,7 @@ def __init__(
 
         self.sorted = False
 
-    def get_serialization_attrs(self) -> Dict:
+    def get_serialization_attrs(self) -> dict:
         return dict(
             data=self.data,
             preprocessor=self.preprocessor,
diff --git a/xeofs/models/opa.py b/xeofs/models/opa.py
index 7db7f45..8561196 100644
--- a/xeofs/models/opa.py
+++ b/xeofs/models/opa.py
@@ -1,5 +1,3 @@
-from typing import Dict, Optional
-
 import numpy as np
 import xarray as xr
 from typing_extensions import Self
@@ -80,8 +78,8 @@ def __init__(
         sample_name: str = "sample",
         feature_name: str = "feature",
         solver: str = "auto",
-        random_state: Optional[int] = None,
-        solver_kwargs: Dict = {},
+        random_state: int | None = None,
+        solver_kwargs: dict = {},
     ):
         if n_modes > n_pca_modes:
             raise ValueError(
diff --git a/xeofs/models/sparse_pca.py b/xeofs/models/sparse_pca.py
index 4a5c473..e6dbcf2 100644
--- a/xeofs/models/sparse_pca.py
+++ b/xeofs/models/sparse_pca.py
@@ -1,5 +1,4 @@
 # %%
-from typing import Dict, Optional
 
 import numpy as np
 import xarray as xr
@@ -68,7 +67,7 @@ class SparsePCA(_BaseModelSingleSet):
         2) optional NaN checks,
         3) SVD decomposition,
         4) scores and components.
-    random_state : Optional[int], default=None
+    random_state : int, optional
         Seed for the random number generator.
     solver : {"auto", "full", "randomized"}, default="randomized"
         Solver to use for the SVD computation.
@@ -109,9 +108,9 @@ def __init__(
         feature_name: str = "feature",
         check_nans=True,
         compute: bool = True,
-        random_state: Optional[int] = None,
+        random_state: int | None = None,
         solver: str = "auto",
-        solver_kwargs: Dict = {},
+        solver_kwargs: dict = {},
         **kwargs,
     ):
         super().__init__(
@@ -291,7 +290,7 @@ def components(self) -> DataObject:
 
         Returns
         -------
-        components: DataArray | Dataset | List[DataArray]
+        components: DataArray | Dataset | list[DataArray]
             Components of the fitted model.
 
         """
@@ -314,7 +313,7 @@ def scores(self, normalized: bool = False) -> DataArray:
 
         Returns
         -------
-        components: DataArray | Dataset | List[DataArray]
+        scores: DataArray | Dataset | list[DataArray]
             Scores of the fitted model.
""" diff --git a/xeofs/preprocessing/concatenator.py b/xeofs/preprocessing/concatenator.py index e8e25c6..94368c5 100644 --- a/xeofs/preprocessing/concatenator.py +++ b/xeofs/preprocessing/concatenator.py @@ -1,15 +1,13 @@ -from typing import List, Optional, Dict -from typing_extensions import Self - import numpy as np import xarray as xr +from typing_extensions import Self -from .transformer import Transformer from ..utils.data_types import ( + DataArray, Dims, DimsList, - DataArray, ) +from .transformer import Transformer class Concatenator(Transformer): @@ -22,7 +20,7 @@ def __init__(self, sample_name: str = "sample", feature_name: str = "feature"): self.n_features = [] self.coords_in = {} - def get_serialization_attrs(self) -> Dict: + def get_serialization_attrs(self) -> dict: return dict( n_data=self.n_data, n_features=self.n_features, @@ -31,9 +29,9 @@ def get_serialization_attrs(self) -> Dict: def fit( self, - X: List[DataArray], - sample_dims: Optional[Dims] = None, - feature_dims: Optional[DimsList] = None, + X: list[DataArray], + sample_dims: Dims | None = None, + feature_dims: DimsList | None = None, ) -> Self: # Check that all inputs are DataArrays if not all([isinstance(data, DataArray) for data in X]): @@ -57,14 +55,14 @@ def fit( return self - def transform(self, X: List[DataArray]) -> DataArray: + def transform(self, X: list[DataArray]) -> DataArray: # Test whether the input list has same length as the number of stackers if len(X) != self.n_data: raise ValueError( f"Invalid input. Number of DataArrays ({len(X)}) does not match the number of fitted DataArrays ({self.n_data})." ) - reindexed_data_list: List[DataArray] = [] + reindexed_data_list: list[DataArray] = [] idx_range = np.cumsum([0] + self.n_features) for i, data in enumerate(X): @@ -84,15 +82,15 @@ def transform(self, X: List[DataArray]) -> DataArray: def fit_transform( self, - X: List[DataArray], - sample_dims: Optional[Dims] = None, - feature_dims: Optional[DimsList] = None, + X: list[DataArray], + sample_dims: Dims | None = None, + feature_dims: DimsList | None = None, ) -> DataArray: return self.fit(X, sample_dims, feature_dims).transform(X) - def _split_dataarray_into_list(self, data: DataArray) -> List[DataArray]: + def _split_dataarray_into_list(self, data: DataArray) -> list[DataArray]: feature_name = self.feature_name - data_list: List[DataArray] = [] + data_list: list[DataArray] = [] idx_range = np.cumsum([0] + self.n_features) for i, coords in enumerate(self.coords_in.values()): @@ -106,11 +104,11 @@ def _split_dataarray_into_list(self, data: DataArray) -> List[DataArray]: return data_list - def inverse_transform_data(self, X: DataArray) -> List[DataArray]: + def inverse_transform_data(self, X: DataArray) -> list[DataArray]: """Reshape the 2D data (sample x feature) back into its original shape.""" return self._split_dataarray_into_list(X) - def inverse_transform_components(self, X: DataArray) -> List[DataArray]: + def inverse_transform_components(self, X: DataArray) -> list[DataArray]: """Reshape the 2D components (sample x feature) back into its original shape.""" return self._split_dataarray_into_list(X) diff --git a/xeofs/preprocessing/dimension_renamer.py b/xeofs/preprocessing/dimension_renamer.py index 537c581..52e4389 100644 --- a/xeofs/preprocessing/dimension_renamer.py +++ b/xeofs/preprocessing/dimension_renamer.py @@ -1,8 +1,7 @@ -from typing import Dict from typing_extensions import Self +from ..utils.data_types import Data, DataArray, DataVarBound, Dims from .transformer import Transformer 
-from ..utils.data_types import Dims, DataArray, Data, DataVarBound
 
 
 class DimensionRenamer(Transformer):
@@ -23,7 +22,7 @@ def __init__(self, base="dim", start=0):
         self.start = start
         self.dim_mapping = {}
 
-    def get_serialization_attrs(self) -> Dict:
+    def get_serialization_attrs(self) -> dict:
         return dict(
             dim_mapping=self.dim_mapping,
         )
diff --git a/xeofs/preprocessing/list_processor.py b/xeofs/preprocessing/list_processor.py
index 45fed8b..6dd1972 100644
--- a/xeofs/preprocessing/list_processor.py
+++ b/xeofs/preprocessing/list_processor.py
@@ -1,18 +1,19 @@
-from typing import List, TypeVar, Generic, Type, Dict, Any
+from typing import Any, Generic, Type, TypeVar
+
 from typing_extensions import Self
 
-from .dimension_renamer import DimensionRenamer
-from .scaler import Scaler
-from .sanitizer import Sanitizer
-from .multi_index_converter import MultiIndexConverter
-from .stacker import Stacker
 from ..utils.data_types import (
     Data,
-    DataVar,
     DataArray,
+    DataVar,
     Dims,
     DimsList,
 )
+from .dimension_renamer import DimensionRenamer
+from .multi_index_converter import MultiIndexConverter
+from .sanitizer import Sanitizer
+from .scaler import Scaler
+from .stacker import Stacker
 
 T = TypeVar(
     "T",
@@ -33,27 +34,27 @@ class GenericListTransformer(Generic[T]):
 
     def __init__(self, transformer: Type[T], **kwargs):
         self.transformer_class = transformer
-        self.transformers: List[T] = []
+        self.transformers: list[T] = []
         self.init_kwargs = kwargs
 
     def fit(
         self,
-        X: List[DataVar],
+        X: list[DataVar],
         sample_dims: Dims,
         feature_dims: DimsList,
-        iter_kwargs: Dict[str, List[Any]] = {},
+        iter_kwargs: dict[str, list[Any]] = {},
     ) -> Self:
         """Fit transformer to each data element in the list.
 
         Parameters
         ----------
-        X: List[Data]
+        X: list[Data]
             List of data elements.
         sample_dims: Dims
             Sample dimensions.
         feature_dims: DimsList
             Feature dimensions.
-        iter_kwargs: Dict[str, List[Any]]
+        iter_kwargs: dict[str, list[Any]]
             Keyword arguments for the transformer that should be iterated over.
""" @@ -70,30 +71,30 @@ def fit( self.transformers.append(proc) return self - def transform(self, X: List[Data]) -> List[Data]: - X_transformed: List[Data] = [] + def transform(self, X: list[Data]) -> list[Data]: + X_transformed: list[Data] = [] for x, proc in zip(X, self.transformers): X_transformed.append(proc.transform(x)) # type: ignore return X_transformed def fit_transform( self, - X: List[Data], + X: list[Data], sample_dims: Dims, feature_dims: DimsList, - iter_kwargs: Dict[str, List[Any]] = {}, - ) -> List[Data]: + iter_kwargs: dict[str, list[Any]] = {}, + ) -> list[Data]: return self.fit(X, sample_dims, feature_dims, iter_kwargs).transform(X) # type: ignore - def inverse_transform_data(self, X: List[Data]) -> List[Data]: - X_inverse_transformed: List[Data] = [] + def inverse_transform_data(self, X: list[Data]) -> list[Data]: + X_inverse_transformed: list[Data] = [] for x, proc in zip(X, self.transformers): x_inv_trans = proc.inverse_transform_data(x) # type: ignore X_inverse_transformed.append(x_inv_trans) return X_inverse_transformed - def inverse_transform_components(self, X: List[Data]) -> List[Data]: - X_inverse_transformed: List[Data] = [] + def inverse_transform_components(self, X: list[Data]) -> list[Data]: + X_inverse_transformed: list[Data] = [] for x, proc in zip(X, self.transformers): x_inv_trans = proc.inverse_transform_components(x) # type: ignore X_inverse_transformed.append(x_inv_trans) diff --git a/xeofs/preprocessing/multi_index_converter.py b/xeofs/preprocessing/multi_index_converter.py index 65c01db..5a20cc3 100644 --- a/xeofs/preprocessing/multi_index_converter.py +++ b/xeofs/preprocessing/multi_index_converter.py @@ -1,9 +1,8 @@ -from typing import Optional, Dict -from typing_extensions import Self import pandas as pd +from typing_extensions import Self +from ..utils.data_types import Data, DataArray, DataVar, DataVarBound, Dims from .transformer import Transformer -from ..utils.data_types import Dims, DataArray, Data, DataVar, DataVarBound class MultiIndexConverter(Transformer): @@ -15,7 +14,7 @@ def __init__(self): self.coords_from_fit = {} self.coords_from_transform = {} - def get_serialization_attrs(self) -> Dict: + def get_serialization_attrs(self) -> dict: return dict( modified_dimensions=self.modified_dimensions, coords_from_fit=self.coords_from_fit, @@ -25,8 +24,8 @@ def get_serialization_attrs(self) -> Dict: def fit( self, X: Data, - sample_dims: Optional[Dims] = None, - feature_dims: Optional[Dims] = None, + sample_dims: Dims | None = None, + feature_dims: Dims | None = None, **kwargs, ) -> Self: # Store original MultiIndexes @@ -82,44 +81,3 @@ def inverse_transform_scores(self, X: DataArray) -> DataArray: def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray: return self._inverse_transform(X, reference="transform") - - -# class DataListMultiIndexConverter(BaseEstimator, TransformerMixin): -# """Converts MultiIndexes to simple indexes and vice versa.""" - -# def __init__(self): -# self.converters: List[MultiIndexConverter] = [] - -# def fit(self, X: List[Data], y=None): -# for x in X: -# converter = MultiIndexConverter() -# converter.fit(x) -# self.converters.append(converter) - -# return self - -# def transform(self, X: List[Data]) -> List[Data]: -# X_transformed: List[Data] = [] -# for x, converter in zip(X, self.converters): -# X_transformed.append(converter.transform(x)) - -# return X_transformed - -# def fit_transform(self, X: List[Data], y=None) -> List[Data]: -# return self.fit(X, y).transform(X) - -# def 
-#         X_inverse_transformed: List[Data] = []
-#         for x, converter in zip(X, self.converters):
-#             X_inverse_transformed.append(converter._inverse_transform(x))
-
-#         return X_inverse_transformed
-
-#     def inverse_transform_data(self, X: List[Data]) -> List[Data]:
-#         return self._inverse_transform(X)
-
-#     def inverse_transform_components(self, X: List[Data]) -> List[Data]:
-#         return self._inverse_transform(X)
-
-#     def inverse_transform_scores(self, X: DataArray) -> DataArray:
-#         return self.converters[0].inverse_transform_scores(X)
diff --git a/xeofs/preprocessing/preprocessor.py b/xeofs/preprocessing/preprocessor.py
index 0cbe138..d258323 100644
--- a/xeofs/preprocessing/preprocessor.py
+++ b/xeofs/preprocessing/preprocessor.py
@@ -1,5 +1,3 @@
-from typing import Dict, List, Optional, Tuple
-
 import numpy as np
 from typing_extensions import Self
 
@@ -31,13 +29,13 @@
     from datatree import DataTree
 
 
-def extract_new_dim_names(X: List[DimensionRenamer]) -> Tuple[Dims, DimsList]:
+def extract_new_dim_names(X: list[DimensionRenamer]) -> tuple[Dims, DimsList]:
     """Extract the new dimension names from a list of DimensionRenamer objects.
 
     Parameters
     ----------
     X : list of DimensionRenamer
         List of DimensionRenamer objects.
 
     Returns
     -------
@@ -147,7 +145,7 @@ def __init__(
         # 7 | Concatenate into one 2D DataArray
         self.concatenator = Concatenator(**dim_names_as_kwargs)
 
-    def get_serialization_attrs(self) -> Dict:
+    def get_serialization_attrs(self) -> dict:
         return dict(n_data=self.n_data)
 
     def transformer_types(self):
@@ -170,9 +168,9 @@ def get_transformers(self, inverse: bool = False):
 
     def fit(
         self,
-        X: List[Data] | Data,
+        X: list[Data] | Data,
         sample_dims: Dims,
-        weights: Optional[List[Data] | Data] = None,
+        weights: list[Data] | Data | None = None,
     ) -> Self:
         """Fit the preprocessor to the data.
 
@@ -196,10 +194,10 @@ def fit(
 
     def _fit_algorithm(
         self,
-        X: List[Data] | Data,
+        X: list[Data] | Data,
         sample_dims: Dims,
-        weights: Optional[List[Data] | Data] = None,
-    ) -> Tuple[Self, Data]:
+        weights: list[Data] | Data | None = None,
+    ) -> tuple[Self, Data]:
         self._set_return_list(X)
         X = convert_to_list(X)
         self.n_data = len(X)
@@ -235,7 +233,7 @@ def _fit_algorithm(
 
         return self, X
 
-    def transform(self, X: List[Data] | Data) -> DataArray:
+    def transform(self, X: list[Data] | Data) -> DataArray:
         """Transform the data.
 
         Parameters
@@ -266,16 +264,16 @@ def transform(self, X: List[Data] | Data) -> DataArray:
 
     def fit_transform(
         self,
-        X: List[Data] | Data,
+        X: list[Data] | Data,
         sample_dims: Dims,
-        weights: Optional[List[Data] | Data] = None,
+        weights: list[Data] | Data | None = None,
     ) -> DataArray:
         # Take advantage of the fact that `.fit()` already transforms the data
         # to avoid duplicate computation
         self, X = self._fit_algorithm(X, sample_dims, weights)
         return X
 
-    def inverse_transform_data(self, X: DataArray) -> List[Data] | Data:
+    def inverse_transform_data(self, X: DataArray) -> list[Data] | Data:
         """Inverse transform the data.
 
         Parameters:
@@ -295,7 +293,7 @@ def inverse_transform_data(self, X: DataArray) -> List[Data] | Data:
 
         return self._process_output(X_it)
 
-    def inverse_transform_components(self, X: DataArray) -> List[Data] | Data:
+    def inverse_transform_components(self, X: DataArray) -> list[Data] | Data:
         """Inverse transform the components.
 
         Parameters:
@@ -359,7 +357,7 @@ def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
 
         return X_it
 
-    def _process_output(self, X: List[Data]) -> List[Data] | Data:
+    def _process_output(self, X: list[Data]) -> list[Data] | Data:
         if self.return_list:
             return X
         else:
diff --git a/xeofs/preprocessing/sanitizer.py b/xeofs/preprocessing/sanitizer.py
index 1611b17..674ff9c 100644
--- a/xeofs/preprocessing/sanitizer.py
+++ b/xeofs/preprocessing/sanitizer.py
@@ -1,5 +1,3 @@
-from typing import Dict, Optional
-
 import xarray as xr
 from dask.base import compute
 from typing_extensions import Self
@@ -24,7 +22,7 @@ def __init__(self, sample_name="sample", feature_name="feature", check_nans=True
         self.sample_coords = xr.DataArray()
         self.is_valid_feature = xr.DataArray()
 
-    def get_serialization_attrs(self) -> Dict:
+    def get_serialization_attrs(self) -> dict:
         return dict(
             feature_coords=self.feature_coords,
             sample_coords=self.sample_coords,
@@ -60,8 +58,8 @@ def _get_valid_features_per_sample(self, X: Data) -> Data:
     def fit(
         self,
         X: Data,
-        sample_dims: Optional[Dims] = None,
-        feature_dims: Optional[Dims] = None,
+        sample_dims: Dims | None = None,
+        feature_dims: Dims | None = None,
         **kwargs,
     ) -> Self:
         # Check if input is a DataArray
diff --git a/xeofs/preprocessing/scaler.py b/xeofs/preprocessing/scaler.py
index c2d4284..ed194bc 100644
--- a/xeofs/preprocessing/scaler.py
+++ b/xeofs/preprocessing/scaler.py
@@ -1,13 +1,11 @@
-from typing import Optional, Dict
-from typing_extensions import Self
-
 import dask
 import numpy as np
 import xarray as xr
+from typing_extensions import Self
 
-from .transformer import Transformer
-from ..utils.data_types import Dims, DataArray, DataVar, DataVarBound
+from ..utils.data_types import DataArray, DataVar, DataVarBound, Dims
 from ..utils.xarray_utils import compute_sqrt_cos_lat_weights, feature_ones_like
+from .transformer import Transformer
 
 
 class Scaler(Transformer):
@@ -48,7 +46,7 @@ def __init__(
         self.coslat_weights_ = xr.DataArray(name="coslat_weights_")
         self.weights_ = xr.DataArray(name="weights_")
 
-    def get_serialization_attrs(self) -> Dict:
+    def get_serialization_attrs(self) -> dict:
         return dict(
             mean_=self.mean_,
             std_=self.std_,
@@ -73,7 +71,7 @@ def fit(
         X: DataVar,
         sample_dims: Dims,
         feature_dims: Dims,
-        weights: Optional[DataVar] = None,
+        weights: DataVar | None = None,
     ) -> Self:
         """Fit the scaler to the data.
 
@@ -160,7 +158,7 @@ def fit_transform(
         X: DataVarBound,
         sample_dims: Dims,
         feature_dims: Dims,
-        weights: Optional[DataVarBound] = None,
+        weights: DataVarBound | None = None,
     ) -> DataVarBound:
         return self.fit(X, sample_dims, feature_dims, weights).transform(X)
 
diff --git a/xeofs/preprocessing/stacker.py b/xeofs/preprocessing/stacker.py
index 76838ba..7a07243 100644
--- a/xeofs/preprocessing/stacker.py
+++ b/xeofs/preprocessing/stacker.py
@@ -1,11 +1,9 @@
-from typing import Dict
-from typing_extensions import Self
-
 import pandas as pd
 import xarray as xr
+from typing_extensions import Self
 
+from ..utils.data_types import Data, DataArray, DataSet, DataVar, DataVarBound, Dims
 from .transformer import Transformer
-from ..utils.data_types import Dims, DataArray, DataSet, Data, DataVar, DataVarBound
 
 
 class Stacker(Transformer):
@@ -24,15 +22,15 @@ class Stacker(Transformer):
         The name of the sample dimension (dim=0).
     feature_name : str
         The name of the feature dimension (dim=1).
-    dims_in : Tuple[str]
+    dims_in : tuple[str]
         The dimensions of the input data.
-    dims_out : Tuple[str]
+    dims_out : tuple[str]
         The dimensions of the output data.
-    dims_mapping : Dict[str, Tuple[str]]
+    dims_mapping : dict[str, tuple[str]]
         The mapping between the input and output dimensions.
-    coords_in : Dict[str, xr.Coordinates]
+    coords_in : dict[str, xr.Coordinates]
         The coordinates of the input data.
-    coords_out : Dict[str, xr.Coordinates]
+    coords_out : dict[str, xr.Coordinates]
         The coordinates of the output data.
     """
 
@@ -51,7 +49,7 @@ def __init__(
         self.coords_out = {}
         self.data_type = None
 
-    def get_serialization_attrs(self) -> Dict:
+    def get_serialization_attrs(self) -> dict:
         return dict(
             dims_in=self.dims_in,
             dims_out=self.dims_out,
diff --git a/xeofs/preprocessing/transformer.py b/xeofs/preprocessing/transformer.py
index 5ff3982..6c7935a 100644
--- a/xeofs/preprocessing/transformer.py
+++ b/xeofs/preprocessing/transformer.py
@@ -1,18 +1,16 @@
-from abc import ABC
-from typing import Optional, Dict
-from typing_extensions import Self
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 
 import pandas as pd
 import xarray as xr
 from sklearn.base import BaseEstimator, TransformerMixin
+from typing_extensions import Self
 
 try:
     from xarray.core.datatree import DataTree
 except ImportError:
     from datatree import DataTree
 
-from ..utils.data_types import Dims, DataArray, DataSet, Data
+from ..utils.data_types import Data, DataArray, DataSet, Dims
 
 
 class Transformer(BaseEstimator, TransformerMixin, ABC):
@@ -30,7 +28,7 @@ def __init__(
         self.feature_name = feature_name
 
     @abstractmethod
-    def get_serialization_attrs(self) -> Dict:
+    def get_serialization_attrs(self) -> dict:
         """Return a dictionary containing the attributes that need to be serialized
         as part of a saved transformer.
 
@@ -46,8 +44,8 @@ def get_serialization_attrs(self) -> Dict:
     def fit(
         self,
         X: Data,
-        sample_dims: Optional[Dims] = None,
-        feature_dims: Optional[Dims] = None,
+        sample_dims: Dims | None = None,
+        feature_dims: Dims | None = None,
         **kwargs,
     ) -> Self:
         """Fit transformer to data.
@@ -70,8 +68,8 @@ def transform(self, X: Data) -> Data:
     def fit_transform(
         self,
         X: Data,
-        sample_dims: Optional[Dims] = None,
-        feature_dims: Optional[Dims] = None,
+        sample_dims: Dims | None = None,
+        feature_dims: Dims | None = None,
         **kwargs,
     ) -> Data:
         return self.fit(X, sample_dims, feature_dims, **kwargs).transform(X)
diff --git a/xeofs/preprocessing/whitener.py b/xeofs/preprocessing/whitener.py
index 6a106da..952f44a 100644
--- a/xeofs/preprocessing/whitener.py
+++ b/xeofs/preprocessing/whitener.py
@@ -1,5 +1,4 @@
 import warnings
-from typing import Dict, Optional
 
 import numpy as np
 import xarray as xr
@@ -41,7 +40,7 @@ class Whitener(Transformer):
         Name of the feature dimension.
     random_state: np.random.Generator | int | None, default=None
         Random seed for reproducibility.
-    solver_kwargs: Dict
+    solver_kwargs: dict
         Additional keyword arguments for the SVD solver.
""" @@ -56,7 +55,7 @@ def __init__( sample_name: str = "sample", feature_name: str = "feature", random_state: np.random.Generator | int | None = None, - solver_kwargs: Dict = {}, + solver_kwargs: dict = {}, ): super().__init__(sample_name, feature_name) @@ -107,7 +106,7 @@ def _get_n_modes(self, X: DataArray) -> int | float: else: return self.n_modes - def get_serialization_attrs(self) -> Dict: + def get_serialization_attrs(self) -> dict: return dict( alpha=self.alpha, n_modes=self.n_modes, @@ -123,8 +122,8 @@ def get_serialization_attrs(self) -> Dict: def fit( self, X: DataArray, - sample_dims: Optional[Dims] = None, - feature_dims: Optional[DimsList] = None, + sample_dims: Dims | None = None, + feature_dims: DimsList | None = None, ) -> Self: self._sanity_check_input(X) n_samples, n_features = X.shape @@ -223,8 +222,8 @@ def transform(self, X: DataArray) -> DataArray: def fit_transform( self, X: DataArray, - sample_dims: Optional[Dims] = None, - feature_dims: Optional[DimsList] = None, + sample_dims: Dims | None = None, + feature_dims: DimsList | None = None, ) -> DataArray: return self.fit(X, sample_dims, feature_dims).transform(X) diff --git a/xeofs/utils/data_types.py b/xeofs/utils/data_types.py index 262b822..031653c 100644 --- a/xeofs/utils/data_types.py +++ b/xeofs/utils/data_types.py @@ -1,11 +1,4 @@ -from typing import ( - Hashable, - List, - Sequence, - Tuple, - TypeAlias, - TypeVar, -) +from typing import Hashable, Sequence, TypeAlias, TypeVar import dask.array as da from xarray.core import dataarray as xr_dataarray @@ -17,10 +10,10 @@ DataVar = TypeVar("DataVar", DataArray, DataSet) DataVarBound = TypeVar("DataVarBound", bound=Data) -DataArrayList: TypeAlias = List[DataArray] -DataSetList: TypeAlias = List[DataSet] -DataList: TypeAlias = List[Data] -DataVarList: TypeAlias = List[DataVar] +DataArrayList: TypeAlias = list[DataArray] +DataSetList: TypeAlias = list[DataSet] +DataList: TypeAlias = list[Data] +DataVarList: TypeAlias = list[DataVar] GenericType = TypeVar("GenericType") @@ -28,6 +21,6 @@ DataObject: TypeAlias = DataArray | DataSet | DataList Dims: TypeAlias = Sequence[Hashable] -DimsTuple: TypeAlias = Tuple[Dims, ...] -DimsList: TypeAlias = List[Dims] -DimsListTuple: TypeAlias = Tuple[DimsList, ...] +DimsTuple: TypeAlias = tuple[Dims, ...] +DimsList: TypeAlias = list[Dims] +DimsListTuple: TypeAlias = tuple[DimsList, ...] 
diff --git a/xeofs/utils/xarray_utils.py b/xeofs/utils/xarray_utils.py
index fcbb2e1..e1775f7 100644
--- a/xeofs/utils/xarray_utils.py
+++ b/xeofs/utils/xarray_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Hashable, List, Sequence, Tuple, TypeVar
+from typing import Any, Hashable, Sequence, TypeVar
 
 import numpy as np
 import xarray as xr
@@ -19,7 +19,7 @@
 T = TypeVar("T")
 
 
-def unwrap_singleton_list(input_list: List[T]) -> T | List[T]:
+def unwrap_singleton_list(input_list: list[T]) -> T | list[T]:
     if len(input_list) == 1:
         return input_list[0]
     else:
@@ -47,7 +47,7 @@ def data_is_dask(data: DataArray | DataSet | DataList) -> bool:
 
 def process_parameter(
     parameter_name: str, parameter, default, n_data: int
-) -> List[Any]:
+) -> list[Any]:
     if parameter is None:
         return convert_to_list(default) * n_data
     elif isinstance(parameter, (list, tuple)):
@@ -57,7 +57,7 @@ def process_parameter(
     return convert_to_list(parameter) * n_data
 
 
-def convert_to_list(data: T | List[T] | Tuple[T]) -> List[T]:
+def convert_to_list(data: T | list[T] | tuple[T]) -> list[T]:
     if isinstance(data, list):
         return data
     elif isinstance(data, tuple):
@@ -162,21 +162,21 @@ def extract_latitude_dimension(feature_dims: Dims) -> Hashable:
 def get_dims(
     data: DataList,
     sample_dims: Hashable | Sequence[Hashable],
-) -> Tuple[Dims, DimsList]:
+) -> tuple[Dims, DimsList]:
     """Extracts the dimensions of a DataArray or Dataset that are not included in the sample dimensions.
 
     Parameters:
     ------------
     data: xr.DataArray or xr.Dataset or list of xr.DataArray
         Input data.
-    sample_dims: Hashable or Sequence[Hashable] or List[Sequence[Hashable]]
+    sample_dims: Hashable or Sequence[Hashable] or list[Sequence[Hashable]]
         Sample dimensions.
 
     Returns:
     ---------
-    sample_dims: Tuple[Hashable]
+    sample_dims: tuple[Hashable]
         Sample dimensions.
-    feature_dims: Tuple[Hashable]
+    feature_dims: tuple[Hashable]
         Feature dimensions.
 
     """
@@ -199,12 +199,12 @@ def _get_feature_dims(data: DataArray | DataSet, sample_dims: Dims) -> Dims:
     ------------
     data: xr.DataArray or xr.Dataset
         Input data.
-    sample_dims: Tuple[str]
+    sample_dims: tuple[str]
         Sample dimensions.
 
     Returns:
     ---------
-    feature_dims: Tuple[str]
+    feature_dims: tuple[str]
         Feature dimensions.
 
     """
diff --git a/xeofs/validation/bootstrapper.py b/xeofs/validation/bootstrapper.py
index 8a1c78a..34c4acb 100644
--- a/xeofs/validation/bootstrapper.py
+++ b/xeofs/validation/bootstrapper.py
@@ -1,15 +1,15 @@
 from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import Any, Dict
+from typing import Any
 
 import numpy as np
 import xarray as xr
 from tqdm import trange
 
-from ..models import EOF
+from .._version import __version__
 from ..data_container import DataContainer
+from ..models import EOF
 from ..utils.data_types import DataArray
-from .._version import __version__
 
 
 class _BaseBootstrapper(ABC):
@@ -22,7 +22,7 @@ def __init__(self, n_bootstraps=20, seed=None):
         }
 
         # Define analysis-relevant meta data
-        self.attrs: Dict[str, Any] = {"model": "BaseBootstrapper"}
+        self.attrs: dict[str, Any] = {"model": "BaseBootstrapper"}
         self.attrs.update(self._params)
         self.attrs.update(
             {
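
One pattern this changeset leaves untouched is the mutable default argument solver_kwargs: dict = {}, a single dict object shared across all calls that rely on the default, which is a well-known Python pitfall if the dict is ever mutated. The sketch below shows the conventional None-default guard for comparison only; fit_svd and its n_iter option are hypothetical, not xeofs API.

def fit_svd(solver_kwargs: dict | None = None) -> dict:
    # Copy so neither the caller's dict nor a shared default is mutated.
    solver_kwargs = {} if solver_kwargs is None else dict(solver_kwargs)
    solver_kwargs.setdefault("n_iter", 5)  # hypothetical solver option
    return solver_kwargs


assert fit_svd() is not fit_svd()  # each call builds a fresh dict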