Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 1477 sensitivities for solvers #1552

Merged
merged 93 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from 68 commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
a59abb5
#1100 starting to add SDAEs (odes only for now)
valentinsulzer Jul 8, 2020
f97c952
#1100 SODEs working with scipy
valentinsulzer Jul 9, 2020
54d45f9
#1100 get SODEs working for casadi solver
valentinsulzer Jul 9, 2020
d438609
#1100 flake8
valentinsulzer Jul 9, 2020
4c6bf59
#1100 starting to get SDAEs working with casadi
valentinsulzer Jul 10, 2020
b67a457
#1100 working on examples
valentinsulzer Jul 13, 2020
27475b8
#1100 merge #1082
valentinsulzer Jul 21, 2020
0cfdf6f
#1100 reformatted sensitivity API
valentinsulzer Jul 22, 2020
26447b3
#1100 remove ProcessedSymbolicVariable
valentinsulzer Jul 22, 2020
b887515
#1100 working on casadi solver
valentinsulzer Jul 22, 2020
c79f0ca
#1100 working on casadi solver sensitivities
valentinsulzer Jul 23, 2020
067c504
#1100 explicit foward sensitivity working
valentinsulzer Jul 23, 2020
94314fb
#1100 flake8
valentinsulzer Jul 23, 2020
dc69d8e
#1100 merge develop
valentinsulzer Jul 23, 2020
5b4f8e8
Merge branch 'develop' into issue-1100-sdaes
valentinsulzer Jul 27, 2020
03cc0eb
#1100 reformat Solution syntax
valentinsulzer Jul 28, 2020
1872abf
Merge branch 'develop' into issue-1100-sdaes
valentinsulzer Jul 30, 2020
df8bcea
#1100 merge develop
valentinsulzer Aug 4, 2020
789a85f
#1100 merge develop
valentinsulzer Nov 4, 2020
a6ccf32
#1100 merge develop
valentinsulzer Nov 30, 2020
4c1a7af
Merge branch 'develop' into issue-1100-sdaes
valentinsulzer Dec 28, 2020
ddb34c6
#1100 merge 1221
valentinsulzer Dec 28, 2020
2854ce5
#1100 fixing tests
valentinsulzer Dec 28, 2020
be94c59
Merge branch 'issue-1221-convert-to-casadi' into issue-1100-sdaes
valentinsulzer Dec 28, 2020
e70e057
#1100 fixed some solver tests
valentinsulzer Dec 28, 2020
41ded61
#1100 merge
valentinsulzer Dec 29, 2020
2407617
#1100 merge
valentinsulzer Dec 30, 2020
5ebed5a
#1100 merge and fix flake8
valentinsulzer Jan 20, 2021
f794956
#1100 working on tests
valentinsulzer Jan 21, 2021
595601e
Merge branch 'develop' into issue-1100-sdaes
valentinsulzer Jan 27, 2021
af295f3
#1100 merge develop
valentinsulzer Feb 5, 2021
cccefb8
#1100 merge develop
valentinsulzer Mar 27, 2021
c7ddbf5
#1477 draft out a test for idaklu and changes to base solver for sens…
martinjrobins May 8, 2021
41565da
#1477 python sensitivities seem ok, working on casadi
martinjrobins May 22, 2021
7d97c45
#1477 evaluating sensitivities ok for all convert_tos
martinjrobins May 24, 2021
7c39f3f
#1477 update test_sensitivities to use a dae
martinjrobins May 24, 2021
f10fdfc
#1477 sensitivivity calc in idas-klu is running, now need to extract …
martinjrobins Jun 10, 2021
1f3bc9d
#1477 idaklu sensitivities works and tested for python, casadi and jax
martinjrobins Jun 14, 2021
47c5345
#1477 flake8
martinjrobins Jun 14, 2021
b3a3091
#1477 only call IDAGetSens if calculating sensitivities
martinjrobins Jun 14, 2021
836e57f
#1477 merge in #1100
martinjrobins Jun 28, 2021
0cbe0a5
#1477 fix some bugs after merge
martinjrobins Jun 28, 2021
5214994
#1477 generalising 'explicit forward' option so any solver can use it
martinjrobins Jul 5, 2021
840f073
#1477 going to take out sensitivity=casadi option
martinjrobins Jul 16, 2021
3ebec30
#1477 took out sensitivity=casadi option
martinjrobins Jul 16, 2021
ac94921
#1477 took out sensitivity=casadi option, take 2
martinjrobins Jul 16, 2021
f5699c4
#1477 sorting out processed variable
martinjrobins Jul 16, 2021
72560c5
#1477 got some more casadi tests working
martinjrobins Jul 16, 2021
d9ff546
#1477 fix algebraic solver
martinjrobins Jul 16, 2021
6e91335
#1477 unit tests pass
martinjrobins Jul 16, 2021
9897494
#1477 fix for casadi manual stepper
martinjrobins Jul 16, 2021
187c8d5
#1477 fix flake8
martinjrobins Jul 16, 2021
933fdf9
#1477 flake8
martinjrobins Jul 16, 2021
44af729
#1477 merge in develop
martinjrobins Jul 17, 2021
5b0ef01
#1477 flake8
martinjrobins Jul 17, 2021
2b50282
#1477 some minor fixes
martinjrobins Jul 19, 2021
3296ec3
#1477 make sure that model is set up again if calculate sensitivities…
martinjrobins Jul 20, 2021
0b56826
#1477 fix problem related to casadi solver caching integrators
martinjrobins Jul 21, 2021
29905b7
#1477 add standard output tests to sensitivity soln to make sure it h…
martinjrobins Jul 21, 2021
e67880f
#1477 fixes after running integration tests
martinjrobins Jul 21, 2021
5fdeaa9
#1477 flake8
martinjrobins Jul 21, 2021
fc72078
#1477 update changelog
martinjrobins Jul 21, 2021
3b8a902
#1477 update changelog
martinjrobins Jul 21, 2021
6ce1d5d
#1477 fix bugs in scikits odes tests
martinjrobins Jul 22, 2021
f7a3a14
#1477 remove old notebook
martinjrobins Jul 22, 2021
c7543b1
#1477 restore tox.ini
martinjrobins Jul 22, 2021
8d361fe
#1477 fix bug in idaklu
martinjrobins Jul 22, 2021
4c7bbe5
#1477 skip sens test in base_solver if klu not installed
martinjrobins Jul 22, 2021
03528da
#1477 add some tests and remove uncovered lines not neccessary
martinjrobins Aug 2, 2021
06984c4
#1477 check sensitivities with fd in integration tests
martinjrobins Aug 2, 2021
a1cf26e
#1477 fix codacity errors
martinjrobins Aug 2, 2021
5726251
#1477 fix bug in base_solver
martinjrobins Aug 2, 2021
f9daa06
#1477 make fix better
martinjrobins Aug 2, 2021
241af9a
#1477 fix mass matrix inv bug
martinjrobins Aug 2, 2021
073eb58
#1477 integration tests work ok, pretty poor accuracy (perhaps on fd?)
martinjrobins Aug 3, 2021
d18cf81
#1477 fix idaklu unit test
martinjrobins Aug 3, 2021
313c9c3
#1477 fix bug in solution
martinjrobins Aug 3, 2021
bf59b7d
#1477 fix integration idaklu test
martinjrobins Aug 4, 2021
88ecb3f
#1477 do sensitivity integration tests using a processed variable
martinjrobins Aug 4, 2021
1b32660
#1477 fix codacity
martinjrobins Aug 4, 2021
f8bc091
#1477 improve coverage
martinjrobins Aug 5, 2021
6ca02be
#1477 fix some remaining bugs with algebraic solver bounds
martinjrobins Aug 5, 2021
9291bb9
#1477 flake8
martinjrobins Aug 5, 2021
ea59d9f
#1477 fix bug with bounds
martinjrobins Aug 5, 2021
af94506
#1477 increase coverage
martinjrobins Aug 5, 2021
11f056d
#1477 merge in develop
martinjrobins Aug 18, 2021
bec1698
#1477 update changelog
martinjrobins Aug 18, 2021
608fa1f
#1477 remove test files
martinjrobins Aug 18, 2021
1daa4ba
#1477 remove sensitivities notebook
martinjrobins Aug 18, 2021
df0ff95
#1477 swap to using current function for sens tests
martinjrobins Aug 18, 2021
2cb99ad
#1477 fix bug in jax evaluate
martinjrobins Aug 18, 2021
5857a3a
#1477 put back algebraic solver sens tests
martinjrobins Aug 18, 2021
e898c65
Merge branch 'develop' into issue-1477-idaklu-send
martinjrobins Aug 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- `pybamm.base_solver.solve` function can take a list of input parameters to calculate the sensitivities of the solution with respect to. Alternatively, it can be set to `True` to calculate the sensitivities for all input parameters ([#1552](https://github.com/pybamm-team/PyBaMM/pull/1552))
- Added fitted expressions for OCPs for the Chen2020 parameter set ([#1526](https://github.com/pybamm-team/PyBaMM/pull/1497))
- Added `initial_soc` argument to `Simualtion.solve` for specifying the initial SOC when solving a model ([#1512](https://github.com/pybamm-team/PyBaMM/pull/1512))
- Added `print_name` to some symbols ([#1495](https://github.com/pybamm-team/PyBaMM/pull/1495), [#1497](https://github.com/pybamm-team/PyBaMM/pull/1497))
Expand Down Expand Up @@ -174,12 +175,13 @@ This release adds new operators for more complex models, some basic sensitivity

## Breaking changes

- Changed sensitivity API. Removed `ProcessedSymbolicVariable`, all sensitivity now handled within the solvers and `ProcessedVariable` ()
- Renamed `quick_plot_vars` to `output_variables` in `Simulation` to be consistent with `QuickPlot`. Passing `quick_plot_vars` to `Simulation.plot()` has been deprecated and `output_variables` should be passed instead ([#1099](https://github.com/pybamm-team/PyBaMM/pull/1099))
martinjrobins marked this conversation as resolved.
Show resolved Hide resolved
- The "fast diffusion" particle option has been renamed "uniform profile" ([#1130](https://github.com/pybamm-team/PyBaMM/pull/1130))
- The modules containing standard parameters are now classes so they can take options
(e.g. `standard_parameters_lithium_ion` is now `LithiumIonParameters`) ([#1120](https://github.com/pybamm-team/PyBaMM/pull/1120))
- Renamed `quick_plot_vars` to `output_variables` in `Simulation` to be consistent with `QuickPlot`. Passing `quick_plot_vars` to `Simulation.plot()` has been deprecated and `output_variables` should be passed instead ([#1099](https://github.com/pybamm-team/PyBaMM/pull/1099))


# [v0.2.3](https://github.com/pybamm-team/PyBaMM/tree/v0.2.3) - 2020-07-01

This release enables the use of [Google Colab](https://colab.research.google.com/github/pybamm-team/PyBaMM/blob/main/) for running example notebooks, and adds some small new features and bug fixes.
Expand Down
4 changes: 2 additions & 2 deletions FindSUNDIALS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# find the SUNDIALS include directories
find_path(SUNDIALS_INCLUDE_DIR
NAMES
ida/ida.h
idas/idas.h
sundials/sundials_math.h
sundials/sundials_types.h
sunlinsol/sunlinsol_klu.h
Expand All @@ -39,7 +39,7 @@ find_path(SUNDIALS_INCLUDE_DIR
)

set(SUNDIALS_WANT_COMPONENTS
sundials_ida
sundials_idas
sundials_sunlinsolklu
sundials_sunmatrixsparse
sundials_nvecserial
Expand Down
3 changes: 0 additions & 3 deletions docs/source/solvers/processed_variable.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,3 @@ Post-Process Variables

.. autoclass:: pybamm.ProcessedVariable
:members:

.. autoclass:: pybamm.ProcessedSymbolicVariable
:members:
5 changes: 5 additions & 0 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ def process_model(self, model, inplace=True, check_model=True):
model_disc.rhs, model_disc.concatenated_rhs = rhs, concat_rhs
model_disc.algebraic, model_disc.concatenated_algebraic = alg, concat_alg

# Save length of rhs and algebraic
model_disc.len_rhs = model_disc.concatenated_rhs.size
model_disc.len_alg = model_disc.concatenated_algebraic.size
model_disc.len_rhs_and_alg = model_disc.len_rhs + model_disc.len_alg

# Process events
processed_events = []
pybamm.logger.verbose("Discretise events for {}".format(model.name))
Expand Down
12 changes: 12 additions & 0 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def __str__(self):
out = out[:-2] + ")"
return out

def _diff(self, variable):
""" See :meth:`pybamm.Symbol._diff()`. """
children_diffs = [
child.diff(variable) for child in self.cached_children
]
if len(children_diffs) == 1:
diff = children_diffs[0]
else:
diff = self.__class__(*children_diffs)

return diff

def get_children_domains(self, children):
# combine domains from children
domain = []
Expand Down
60 changes: 50 additions & 10 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def __init__(self, symbol):
constants[symbol_id] = jax.device_put(constants[symbol_id])

# get a list of constant arguments to input to the function
arg_list = [
self._arg_list = [
id_to_python_variable(symbol_id, True) for symbol_id in constants.keys()
]

Expand All @@ -580,9 +580,11 @@ def __init__(self, symbol):

# add function def to first line
args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None"
if arg_list:
args = ",".join(arg_list) + ", " + args
python_str = "def evaluate_jax({}):\n".format(args) + python_str
if self._arg_list:
args = ",".join(self._arg_list) + ", " + args
python_str = (
"def evaluate_jax({}):\n".format(args) + python_str
)

# calculate the final variable that will output the result of calling `evaluate`
# on `symbol`
Expand All @@ -606,17 +608,32 @@ def __init__(self, symbol):
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)

n = len(arg_list)
static_argnums = tuple(static_argnums)
self._jit_evaluate = jax.jit(self._evaluate_jax, static_argnums=static_argnums)
self._static_argnums = tuple(static_argnums)
self._jit_evaluate = jax.jit(self._evaluate_jax,
static_argnums=self._static_argnums)

def get_jacobian(self):
n = len(self._arg_list)

# store a jit version of evaluate_jax's jacobian
# forward mode autodiff wrt y, which is argument 1 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=1 + n)
self._jac_evaluate = jax.jit(jacobian_evaluate, static_argnums=static_argnums)

def get_jacobian(self):
self._jac_evaluate = jax.jit(jacobian_evaluate,
static_argnums=self._static_argnums)

return EvaluatorJaxJacobian(self._jac_evaluate, self._constants)

def get_sensitivities(self):
n = len(self._arg_list)

# forward mode autodiff wrt inputs, which is argument 3 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n)

self._sens_evaluate = jax.jit(jacobian_evaluate,
static_argnums=self._static_argnums)

return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants)

def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
Expand Down Expand Up @@ -673,3 +690,26 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
return result, known_evals
else:
return result


class EvaluatorJaxSensitivities:
def __init__(self, jac_evaluate, constants):
self._jac_evaluate = jac_evaluate
self._constants = constants

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
"""
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

# execute code
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)

# don't need known_evals, but need to reproduce Symbol.evaluate signature
if known_evals is not None:
return result, known_evals
else:
return result
13 changes: 13 additions & 0 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,19 @@ def timescale(self, value):
"""Set the timescale"""
self._timescale = value

@property
def length_scales(self):
"Length scales of model"
return self._length_scale

@length_scales.setter
def length_scales(self, values):
"Set the length scale, converting any numbers to pybamm.Scalar"
for domain, scale in values.items():
if isinstance(scale, numbers.Number):
values[domain] = pybamm.Scalar(scale)
self._length_scale = values

@property
def parameters(self):
"""Returns all the parameters in the model"""
Expand Down
Loading