Skip to content

Commit

Permalink
CLN: Remove unused and unnecessary functions.
Browse files Browse the repository at this point in the history
pysindy.GeneralizedLibrary.has_type - not used
pysindy.utils.ax_time_to_ax_spatial - limited use in previous refactor.
  • Loading branch information
Jacob-Stevens-Haas committed Jul 3, 2022
1 parent 40cb00a commit f9ebbbd
Show file tree
Hide file tree
Showing 9 changed files with 9 additions and 63 deletions.
10 changes: 2 additions & 8 deletions pysindy/feature_library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ..utils import comprehend_axes
from ..utils import validate_no_reshape
from ..utils import wrap_axes
from ..utils.axes import ax_time_to_ax_sample


class BaseFeatureLibrary(TransformerMixin):
Expand Down Expand Up @@ -188,16 +187,11 @@ def func(self, x, *args, **kwargs):
else:
if not sparse.issparse(x):
x = AxesArray(x, comprehend_axes(x))
x = ax_time_to_ax_sample(x)
reconstructor = np.array
else: # sparse arrays
reconstructor = type(x)
axes = comprehend_axes(x)
wrap_axes(axes, x)
# Can't use x = ax_time_to_ax_sample(x) b/c that creates
# an AxesArray
x.ax_sample = x.ax_time
x.ax_time = None
result = wrapped_func(self, [x], *args, **kwargs)
if isinstance(result, Sequence): # e.g. transform() returns x
return reconstructor(result[0])
Expand Down Expand Up @@ -332,7 +326,7 @@ def transform(self, x_full):
xp[..., start_feature_index:end_feature_index] = lib.transform([x])[0]

current_feat += lib_n_output_features
xp = ax_time_to_ax_sample(AxesArray(xp, comprehend_axes(xp)))
xp = AxesArray(xp, comprehend_axes(xp))
xp_full.append(xp)
if self.library_ensemble:
xp_full = self._ensemble(xp_full)
Expand Down Expand Up @@ -565,7 +559,7 @@ def transform(self, x_full):

current_feat += lib_i_n_output_features * lib_j_n_output_features

xp = ax_time_to_ax_sample(AxesArray(xp, comprehend_axes(xp)))
xp = AxesArray(xp, comprehend_axes(xp))
xp_full.append(xp)
if self.library_ensemble:
xp_full = self._ensemble(xp_full)
Expand Down
3 changes: 1 addition & 2 deletions pysindy/feature_library/custom_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sklearn import __version__
from sklearn.utils.validation import check_is_fitted

from ..utils import ax_time_to_ax_sample
from ..utils import AxesArray
from ..utils import comprehend_axes
from .base import BaseFeatureLibrary
Expand Down Expand Up @@ -231,7 +230,7 @@ def transform(self, x_full):
xp[..., library_idx] = f(*[x[..., j] for j in c])
library_idx += 1

xp = ax_time_to_ax_sample(AxesArray(xp, comprehend_axes(xp)))
xp = AxesArray(xp, comprehend_axes(xp))
xp_full.append(xp)
if self.library_ensemble:
xp_full = self._ensemble(xp_full)
Expand Down
3 changes: 1 addition & 2 deletions pysindy/feature_library/fourier_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ..utils import comprehend_axes
from .base import BaseFeatureLibrary
from .base import x_sequence_or_item
from pysindy.utils.axes import ax_time_to_ax_sample


class FourierLibrary(BaseFeatureLibrary):
Expand Down Expand Up @@ -175,7 +174,7 @@ def transform(self, x_full):
if self.include_cos:
xp[..., idx] = np.cos((i + 1) * x[..., j])
idx += 1
xp = ax_time_to_ax_sample(AxesArray(xp, comprehend_axes(xp)))
xp = AxesArray(xp, comprehend_axes(xp))
xp_full.append(xp)
if self.library_ensemble:
xp_full = self._ensemble(xp_full)
Expand Down
19 changes: 0 additions & 19 deletions pysindy/feature_library/generalized_library.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Type

import numpy as np
from sklearn import __version__
from sklearn.utils.validation import check_is_fitted
Expand Down Expand Up @@ -244,23 +242,6 @@ def fit(self, x_full, y=None):

return self

def has_type(self, libtype: Type, exclusively=False) -> bool:
"""Checks whether this library has a specific library type.
Parameters
----------
libtype : A type of feature library
exclusively: whether to check all libraries
Returns
-------
Bool indicating whether specific library type is present
"""
has_inst = map(lambda lib: isinstance(lib, libtype), self.libraries_)
if exclusively:
return all(has_inst)
return any(has_inst)

@x_sequence_or_item
def transform(self, x_full):
"""Transform data with libs provided below.
Expand Down
3 changes: 1 addition & 2 deletions pysindy/feature_library/pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sklearn import __version__
from sklearn.utils.validation import check_is_fitted

from ..utils import ax_time_to_ax_sample
from ..utils import AxesArray
from ..utils import comprehend_axes
from .base import BaseFeatureLibrary
Expand Down Expand Up @@ -465,7 +464,7 @@ def transform(self, x_full):
shape,
)
library_idx += n_library_terms * self.num_derivatives * n_features
xp = ax_time_to_ax_sample(AxesArray(xp, comprehend_axes(xp)))
xp = AxesArray(xp, comprehend_axes(xp))
xp_full.append(xp)
if self.library_ensemble:
xp_full = self._ensemble(xp_full)
Expand Down
6 changes: 1 addition & 5 deletions pysindy/feature_library/polynomial_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,8 @@ def transform(self, x_full):
axes = comprehend_axes(x)
x = x.asformat("csr")
wrap_axes(axes, x)
# Can't use x = ax_time_to_ax_sample(x) b/c that creates
# an AxesArray
x.ax_sample = x.ax_time
x.ax_time = None

n_samples = x.shape[x.ax_sample]
n_samples = x.shape[x.ax_time]
n_features = x.shape[x.ax_coord]
if float(__version__[:3]) >= 1.0:
if n_features != self.n_features_in_:
Expand Down
14 changes: 3 additions & 11 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from .optimizers import SINDyOptimizer
from .optimizers import SINDyPI
from .optimizers import STLSQ
from .utils import ax_time_to_ax_sample
from .utils import AxesArray
from .utils import comprehend_axes
from .utils import concat_sample_axis
Expand Down Expand Up @@ -345,20 +344,17 @@ def fit(
if u is not None:
x = [np.concatenate((xi, ui), axis=xi.ax_coord) for xi, ui in zip(x, u)]

x = [ax_time_to_ax_sample(xi) for xi in x]
x_dot = [ax_time_to_ax_sample(xdoti) for xdoti in x_dot]

if hasattr(self.optimizer, "unbias"):
unbias = self.optimizer.unbias

# backwards compatibility for ensemble options
if ensemble and n_subset is None:
n_subset = x[0].shape[x[0].ax_sample]
n_subset = x[0].shape[x[0].ax_time]
if library_ensemble:
self.feature_library.library_ensemble = False
if ensemble and not library_ensemble:
if n_subset is None:
n_sample_tot = np.sum([xi.shape[xi.ax_sample] for xi in x])
n_sample_tot = np.sum([xi.shape[xi.ax_time] for xi in x])
n_subset = int(0.6 * n_sample_tot)
optimizer = SINDyOptimizer(
EnsembleOptimizer(
Expand All @@ -382,7 +378,7 @@ def fit(
self.coef_list = optimizer.optimizer.coef_list
elif ensemble and library_ensemble:
if n_subset is None:
n_sample_tot = np.sum([xi.shape[xi.ax_sample] for xi in x])
n_sample_tot = np.sum([xi.shape[xi.ax_time] for xi in x])
n_subset = int(0.6 * n_sample_tot)
optimizer = SINDyOptimizer(
EnsembleOptimizer(
Expand Down Expand Up @@ -472,7 +468,6 @@ def predict(self, x, u=None, multiple_trajectories=False):
if u is not None:
u = validate_control_variables(x, u)
x = [np.concatenate((xi, ui), axis=xi.ax_coord) for xi, ui in zip(x, u)]
x = [ax_time_to_ax_sample(xi) for xi in x]
result = [self.model.predict([xi]) for xi in x]
result = [
self.feature_library.reshape_samples_to_spatial_grid(pred)
Expand Down Expand Up @@ -616,9 +611,6 @@ def score(
x_dot_predict = [xd[:-1] for xd in x_dot_predict]

x, x_dot = self._process_multiple_trajectories(x, t, x_dot)
x = [ax_time_to_ax_sample(xi) for xi in x]
x_dot = [ax_time_to_ax_sample(xdoti) for xdoti in x_dot]
x_dot_predict = [ax_time_to_ax_sample(xdip) for xdip in x_dot_predict]

x_dot = concat_sample_axis(x_dot)
x_dot_predict = concat_sample_axis(x_dot_predict)
Expand Down
2 changes: 0 additions & 2 deletions pysindy/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .axes import ax_time_to_ax_sample
from .axes import AxesArray
from .axes import comprehend_axes
from .axes import concat_sample_axis
Expand Down Expand Up @@ -58,7 +57,6 @@
__all__ = [
"AxesArray",
"SampleConcatter",
"ax_time_to_ax_sample",
"concat_sample_axis",
"wrap_axes",
"comprehend_axes",
Expand Down
12 changes: 0 additions & 12 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,6 @@ def comprehend_axes(x):
return axes


def ax_time_to_ax_sample(x: AxesArray) -> AxesArray:
"""Relabel the time axis as a sample axis"""
if x.ax_sample is not None:
return x # idempotence: f(x) = f(f(x))
if x.ax_time is None:
raise TypeError("Cannot reassign time axis if it doesn't exist.")
new_axes = x.__dict__
new_axes["ax_sample"] = new_axes["ax_time"]
new_axes["ax_time"] = None
return AxesArray(np.asarray(x), new_axes)


class SampleConcatter(TransformerMixin):
def __init__(self):
pass
Expand Down

0 comments on commit f9ebbbd

Please sign in to comment.