Skip to content

Commit

Permalink
refactor(typing): use built-in types for annotations (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie authored Sep 1, 2024
1 parent 278d2ee commit 7505350
Show file tree
Hide file tree
Showing 22 changed files with 168 additions and 239 deletions.
6 changes: 2 additions & 4 deletions xeofs/data_container/data_container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Dict

import dask
from typing_extensions import Self

Expand Down Expand Up @@ -66,7 +64,7 @@ def _validate_attrs_values(self, value):
else:
return value

def _validate_attrs(self, attrs: Dict) -> Dict:
def _validate_attrs(self, attrs: dict) -> dict:
"""Convert any boolean and None values to strings"""
for key, value in attrs.items():
if isinstance(value, bool):
Expand All @@ -78,7 +76,7 @@ def _validate_attrs(self, attrs: Dict) -> Dict:

return attrs

def set_attrs(self, attrs: Dict):
def set_attrs(self, attrs: dict):
attrs = self._validate_attrs(attrs)
for key in self.keys():
self[key].attrs = attrs
24 changes: 12 additions & 12 deletions xeofs/models/cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from abc import abstractmethod
from datetime import datetime
from typing import Hashable, List, Sequence
from typing import Hashable, Sequence

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -143,7 +143,7 @@ def fit(
Preprocessor(with_coslat=self.use_coslat[i], **self._preprocessor_kwargs)
for i in range(self.n_views_)
]
views2D: List[DataArray] = [
views2D: list[DataArray] = [
preprocessor.fit_transform(data, dim)
for preprocessor, data in zip(self.preprocessors, views)
]
Expand Down Expand Up @@ -215,7 +215,7 @@ def _apply_pca(self, views: DataList):
return view_transformed

@abstractmethod
def _fit_algorithm(self, views: List[DataArray]) -> Self:
def _fit_algorithm(self, views: list[DataArray]) -> Self:
raise NotImplementedError


Expand Down Expand Up @@ -304,7 +304,7 @@ def __init__(
self.c = c
self.eps = eps

def _fit_algorithm(self, views: List[DataArray]) -> Self:
def _fit_algorithm(self, views: list[DataArray]) -> Self:
# Check input data
[assert_not_complex(view) for view in views]

Expand Down Expand Up @@ -620,26 +620,26 @@ def _apply_smallest_eigval(self, D, dims):
def _smallest_eigval(self, D):
return min(0, np.linalg.eigvalsh(D).min())

def weights(self) -> List[DataObject]:
def weights(self) -> list[DataObject]:
weights = [
prep.inverse_transform_components(wghts)
for prep, wghts in zip(self.preprocessors, self.data["weights"])
]
return weights

def _transform(self, views: Sequence[DataArray]) -> List[DataArray]:
def _transform(self, views: Sequence[DataArray]) -> list[DataArray]:
transformed_views = []
for i, view in enumerate(views):
transformed_view = xr.dot(view, self.data["weights"][i], dims="feature")
transformed_views.append(transformed_view)
return transformed_views

def transform(self, views: Sequence[DataObject]) -> List[DataArray]:
def transform(self, views: Sequence[DataObject]) -> list[DataArray]:
"""Transform the input data into the canonical space.
Parameters
----------
views : List[DataArray | Dataset]
views : list[DataArray | Dataset]
Input data to transform
"""
Expand All @@ -655,7 +655,7 @@ def transform(self, views: Sequence[DataObject]) -> List[DataArray]:
unstacked_transformed_views.append(unstacked_view)
return unstacked_transformed_views

def components(self, normalize: bool = True) -> List[DataObject]:
def components(self, normalize: bool = True) -> list[DataObject]:
"""Get the canonical loadings for each view."""
can_loads = self.data["canonical_loadings"]
input_data = self.data["input_data"]
Expand All @@ -681,19 +681,19 @@ def components(self, normalize: bool = True) -> List[DataObject]:
]
return loadings

def scores(self) -> List[DataArray]:
def scores(self) -> list[DataArray]:
"""Get the canonical variates for each view."""
variates = []
for i, view in enumerate(self.data["variates"]):
vari = self.preprocessors[i].inverse_transform_scores(view)
variates.append(vari)
return variates

def explained_variance(self) -> List[DataArray]:
def explained_variance(self) -> list[DataArray]:
"""Get the explained variance for each view."""
return self.data["explained_variance"]

def explained_variance_ratio(self) -> List[DataArray]:
def explained_variance_ratio(self) -> list[DataArray]:
"""Get the explained variance ratio for each view."""
return self.data["explained_variance_ratio"]

Expand Down
14 changes: 7 additions & 7 deletions xeofs/models/cpcca.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Dict, Optional, Sequence, Tuple
from typing import Sequence

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -225,10 +225,10 @@ def _fit_algorithm(

def _transform_algorithm(
self,
X: Optional[DataArray] = None,
Y: Optional[DataArray] = None,
X: DataArray | None = None,
Y: DataArray | None = None,
normalized=False,
) -> Dict[str, DataArray]:
) -> dict[str, DataArray]:
results = {}
if X is not None:
# Project data onto singular vectors
Expand Down Expand Up @@ -1159,7 +1159,7 @@ def _fit_algorithm(self, X: DataArray, Y: DataArray) -> Self:

return super()._fit_algorithm(X, Y)

def components_amplitude(self, normalized=True) -> Tuple[DataObject, DataObject]:
def components_amplitude(self, normalized=True) -> tuple[DataObject, DataObject]:
"""Get the amplitude of the components.
The amplitudes of the components are defined as
Expand Down Expand Up @@ -1229,7 +1229,7 @@ def components_phase(self, normalized=True) -> tuple[DataObject, DataObject]:

return Px, Py

def scores_amplitude(self, normalized=False) -> Tuple[DataArray, DataArray]:
def scores_amplitude(self, normalized=False) -> tuple[DataArray, DataArray]:
"""Get the amplitude of the scores.
The amplitudes of the scores are defined as
Expand Down Expand Up @@ -1264,7 +1264,7 @@ def scores_amplitude(self, normalized=False) -> Tuple[DataArray, DataArray]:

return Rx, Ry

def scores_phase(self, normalized=False) -> Tuple[DataArray, DataArray]:
def scores_phase(self, normalized=False) -> tuple[DataArray, DataArray]:
"""Get the phase of the scores.
The phases of the scores are defined as
Expand Down
6 changes: 3 additions & 3 deletions xeofs/models/cpcca_rotator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Sequence
from typing import Sequence

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(

self.sorted = False

def get_serialization_attrs(self) -> Dict:
def get_serialization_attrs(self) -> dict:
return dict(
data=self.data,
preprocessor1=self.preprocessor1,
Expand Down Expand Up @@ -309,7 +309,7 @@ def transform(
X: DataObject | None = None,
Y: DataObject | None = None,
normalized: bool = False,
) -> DataArray | List[DataArray]:
) -> DataArray | list[DataArray]:
"""Transform the data.
Parameters
Expand Down
14 changes: 6 additions & 8 deletions xeofs/models/eeof.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Optional
from typing_extensions import Self

import numpy as np
import xarray as xr
from typing_extensions import Self

from .eof import EOF
from ..utils.data_types import DataArray
from ..data_container import DataContainer
from ..utils.data_types import DataArray
from .eof import EOF


class ExtendedEOF(EOF):
Expand All @@ -29,7 +27,7 @@ class ExtendedEOF(EOF):
Embedding dimension is the number of dimensions in the delay-coordinate space used to represent
the dynamics of the system. It determines the number of delayed copies
of the time series that are used to construct the delay-coordinate space.
n_pca_modes : Optional[int]
n_pca_modes : int, optional
If provided, the input data is first preprocessed using PCA with the
specified number of modes. The EEOF analysis is then performed on the
resulting PCA scores. This approach can lead to important computational
Expand Down Expand Up @@ -64,7 +62,7 @@ def __init__(
n_modes: int,
tau: int,
embedding: int,
n_pca_modes: Optional[int] = None,
n_pca_modes: int | None = None,
center: bool = True,
standardize: bool = False,
use_coslat: bool = False,
Expand All @@ -73,7 +71,7 @@ def __init__(
feature_name: str = "feature",
compute: bool = True,
solver: str = "auto",
random_state: Optional[int] = None,
random_state: int | None = None,
solver_kwargs: dict = {},
**kwargs,
):
Expand Down
33 changes: 16 additions & 17 deletions xeofs/models/eof.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import warnings
from typing import Dict, Optional

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -37,7 +36,7 @@ class EOF(_BaseModelSingleSet):
If True, four pieces of the fit will be computed sequentially: 1) the
preprocessor scaler, 2) optional NaN checks, 3) SVD decomposition, 4) scores
and components.
random_state : Optional[int], default=None
random_state : int, optional
Seed for the random number generator.
solver: {"auto", "full", "randomized"}, default="auto"
Solver to use for the SVD computation.
Expand All @@ -62,9 +61,9 @@ def __init__(
sample_name: str = "sample",
feature_name: str = "feature",
compute: bool = True,
random_state: Optional[int] = None,
random_state: int | None = None,
solver: str = "auto",
solver_kwargs: Dict = {},
solver_kwargs: dict = {},
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -144,7 +143,7 @@ def _inverse_transform_algorithm(self, scores: DataArray) -> DataArray:
Returns
-------
data: DataArray | Dataset | List[DataArray]
data: DataArray | Dataset | list[DataArray]
Reconstructed data.
"""
Expand All @@ -163,7 +162,7 @@ def components(self, normalized: bool = True) -> DataObject:
Returns
-------
components: DataArray | Dataset | List[DataArray]
components: DataArray | Dataset | list[DataArray]
Components of the fitted model.
"""
Expand All @@ -183,7 +182,7 @@ def scores(self, normalized: bool = False) -> DataArray:
Returns
-------
components: DataArray | Dataset | List[DataArray]
components: DataArray | Dataset | list[DataArray]
Scores of the fitted model.
"""
Expand Down Expand Up @@ -273,7 +272,7 @@ class ComplexEOF(EOF):
computation. If True, four pieces of the fit will be computed
sequentially: 1) the preprocessor scaler, 2) optional NaN checks, 3) SVD
decomposition, 4) scores and components.
random_state : Optional[int], default=None
random_state : int, optional
Seed for the random number generator.
solver: {"auto", "full", "randomized"}, default="auto"
Solver to use for the SVD computation.
Expand Down Expand Up @@ -311,9 +310,9 @@ def __init__(
sample_name: str = "sample",
feature_name: str = "feature",
compute: bool = True,
random_state: Optional[int] = None,
random_state: int | None = None,
solver: str = "auto",
solver_kwargs: Dict = {},
solver_kwargs: dict = {},
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -358,7 +357,7 @@ def components_amplitude(self, normalized=True) -> DataObject:
Returns
-------
components_amplitude: DataArray | Dataset | List[DataArray]
components_amplitude: DataArray | Dataset | list[DataArray]
Amplitude of the components of the fitted model.
"""
Expand All @@ -383,7 +382,7 @@ def components_phase(self) -> DataObject:
Returns
-------
components_phase: DataArray | Dataset | List[DataArray]
components_phase: DataArray | Dataset | list[DataArray]
Phase of the components of the fitted model.
"""
Expand All @@ -410,7 +409,7 @@ def scores_amplitude(self, normalized=True) -> DataArray:
Returns
-------
scores_amplitude: DataArray | Dataset | List[DataArray]
scores_amplitude: DataArray | Dataset | list[DataArray]
Amplitude of the scores of the fitted model.
"""
Expand All @@ -435,7 +434,7 @@ def scores_phase(self) -> DataArray:
Returns
-------
scores_phase: DataArray | Dataset | List[DataArray]
scores_phase: DataArray | Dataset | list[DataArray]
Phase of the scores of the fitted model.
"""
Expand Down Expand Up @@ -485,7 +484,7 @@ class HilbertEOF(ComplexEOF):
If True, four pieces of the fit will be computed sequentially: 1) the
preprocessor scaler, 2) optional NaN checks, 3) SVD decomposition, 4) scores
and components.
random_state : Optional[int], default=None
random_state : int, optional
Seed for the random number generator.
solver: {"auto", "full", "randomized"}, default="auto"
Solver to use for the SVD computation.
Expand Down Expand Up @@ -520,9 +519,9 @@ def __init__(
sample_name: str = "sample",
feature_name: str = "feature",
compute: bool = True,
random_state: Optional[int] = None,
random_state: int | None = None,
solver: str = "auto",
solver_kwargs: Dict = {},
solver_kwargs: dict = {},
**kwargs,
):
super().__init__(
Expand Down
3 changes: 1 addition & 2 deletions xeofs/models/eof_rotator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Dict

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -93,7 +92,7 @@ def __init__(

self.sorted = False

def get_serialization_attrs(self) -> Dict:
def get_serialization_attrs(self) -> dict:
return dict(
data=self.data,
preprocessor=self.preprocessor,
Expand Down
Loading

0 comments on commit 7505350

Please sign in to comment.