Skip to content

Commit

Permalink
Merge pull request #185 from Jacob-Stevens-Haas/simplify-fitting
Browse files Browse the repository at this point in the history
Simplify derivatives in pysindy.py
  • Loading branch information
Jacob-Stevens-Haas authored Jul 3, 2022
2 parents a7f94bf + f9ebbbd commit 632585a
Show file tree
Hide file tree
Showing 38 changed files with 2,842 additions and 1,875 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
build
dist
*.egg-info
commit-msg
todo

# Environments
.env
Expand Down
63 changes: 31 additions & 32 deletions examples/10_PDEFIND_examples.ipynb

Large diffs are not rendered by default.

231 changes: 45 additions & 186 deletions examples/12_weakform_SINDy_examples.ipynb

Large diffs are not rendered by default.

149 changes: 91 additions & 58 deletions examples/13_ensembling.ipynb

Large diffs are not rendered by default.

158 changes: 92 additions & 66 deletions examples/15_pysindy_lectures.ipynb

Large diffs are not rendered by default.

216 changes: 176 additions & 40 deletions examples/1_feature_overview.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions examples/7_plasma_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
" \n",
" if feature_names is None:\n",
" feature_names = [f\"f{k}\" for k in range(coefficients.shape[1])]\n",
"\n",
" max_ind = len(feature_names)\n",
" \n",
" with sns.axes_style(style=\"white\", rc={\"axes.facecolor\": (0, 0, 0, 0)}): \n",
" if ax is None:\n",
Expand All @@ -269,7 +269,7 @@
" heatmap_args.update(**heatmap_kws)\n",
"\n",
" sns.heatmap(\n",
" coefficients.T,\n",
" coefficients[:, :max_ind].T,\n",
" **heatmap_args\n",
" )\n",
"\n",
Expand Down Expand Up @@ -1609,7 +1609,7 @@
"fig.show()\n",
"\n",
"X, Y = np.meshgrid(\n",
" np.linspace(r_scan[0], r_scan[-1] + 1, r_scan[-1] - 1) - 0.5,\n",
" np.linspace(r_scan[0], r_scan[-1] + 1, r_length) - 0.5,\n",
" thresholds,\n",
" indexing=\"ij\",\n",
")\n",
Expand Down Expand Up @@ -1853,7 +1853,7 @@
"\n",
"fig, axs = plt.subplots(2, 1, figsize=(14, 10))\n",
"X, Y = np.meshgrid(\n",
" np.linspace(r_scan[0], r_scan[-1] + 1, r_scan[-1] - 1) - 0.5,\n",
" np.linspace(r_scan[0], r_scan[-1] + 1, r_length) - 0.5,\n",
" thresholds,\n",
" indexing=\"ij\",\n",
")\n",
Expand Down
1,033 changes: 738 additions & 295 deletions examples/9_sindypi_with_sympy.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ exclude = '''
\.git
| \.mypy_cache
| \.venv
| .vscode
| version.py
| build
| dist
| env
)/
'''

Expand Down
4 changes: 3 additions & 1 deletion pysindy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from . import deeptime
from . import utils
from .pysindy import SINDy
from .pysindy import AxesArray
from .differentiation import BaseDifferentiation
from .differentiation import FiniteDifference
from .differentiation import SpectralDerivative
Expand All @@ -36,9 +37,10 @@
from .optimizers import STLSQ
from .optimizers import SINDyPI
from .optimizers import TrappingSR3
from .optimizers import EnsembleOptimizer


__all__ = ["SINDy"]
__all__ = ["SINDy", "AxesArray"]
__all__.extend(differentiation.__all__)
__all__.extend(feature_library.__all__)
__all__.extend(optimizers.__all__)
Expand Down
14 changes: 10 additions & 4 deletions pysindy/deeptime/deeptime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sklearn.utils.validation import check_is_fitted

from ..pysindy import SINDy
from ..utils import SampleConcatter


class SINDyEstimator(SINDy):
Expand Down Expand Up @@ -100,7 +101,7 @@ def fit(self, x, **kwargs):
super(SINDyEstimator, self).fit(x, **kwargs)
self._model = SINDyModel(
feature_library=self.model.steps[0][1],
optimizer=self.model.steps[1][1],
optimizer=self.model.steps[-1][1],
feature_names=self.feature_names,
t_default=self.t_default,
discrete_time=self.discrete_time,
Expand Down Expand Up @@ -134,10 +135,11 @@ class SINDyModel(SINDy):
Implementation of SINDy conforming to the API of a Deeptime
`Model <https://deeptime-ml.github.io/api/generated/deeptime.base.Model.html>`_.
The model is represented as a Scikit-learn pipeline object with two steps:
The model is represented as a Scikit-learn pipeline object with three steps:
1. Map the raw input data to nonlinear features according to the selected
``feature_library``
2. Multiply the nonlinear features with a coefficient matrix encapuslated
2. Reshape the data from input shape to an optimization problem
3. Multiply the nonlinear features with a coefficient matrix encapuslated
in ``optimizer``.
This class expects the feature library and optimizer to already be fit
Expand Down Expand Up @@ -208,7 +210,11 @@ def __init__(
check_is_fitted(feature_library)
check_is_fitted(optimizer)

steps = [("features", feature_library), ("model", optimizer)]
steps = [
("features", feature_library),
("shaping", SampleConcatter()),
("model", optimizer),
]
self.model = Pipeline(steps)

if float(__version__[:3]) >= 1.0:
Expand Down
3 changes: 0 additions & 3 deletions pysindy/differentiation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from sklearn.base import BaseEstimator

from ..utils.base import validate_input


class BaseDifferentiation(BaseEstimator):
"""
Expand Down Expand Up @@ -48,5 +46,4 @@ def _differentiate(self, x, t=1):
raise NotImplementedError

def __call__(self, x, t=1):
x = validate_input(x)
return self._differentiate(x, t)
9 changes: 2 additions & 7 deletions pysindy/differentiation/sindy_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
"""
from derivative import dxdt
from numpy import arange
from sklearn.base import BaseEstimator

from ..utils.base import validate_input
from .base import BaseDifferentiation


class SINDyDerivative(BaseEstimator):
class SINDyDerivative(BaseDifferentiation):
"""
Wrapper class for differentiation classes from the :doc:`derivative:index` package.
This class is meant to provide all the same functionality as the
Expand Down Expand Up @@ -72,7 +71,3 @@ def _differentiate(self, x, t=1):
t = arange(x.shape[0]) * t

return dxdt(x, t, axis=0, **self.kwargs)

def __call__(self, x, t=1):
x = validate_input(x, t=t)
return self._differentiate(x, t)
Loading

0 comments on commit 632585a

Please sign in to comment.