Skip to content

Commit

Permalink
#944 add automatic discretisation
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Apr 7, 2020
1 parent 96e7811 commit 316f75a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
3 changes: 3 additions & 0 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ def process_model(self, model, inplace=True, check_model=True):

pybamm.logger.info("Finish discretising {}".format(model.name))

# Record that the model has been discretised
model_disc.is_discretised = True

return model_disc

def set_variable_slices(self, variables):
Expand Down
3 changes: 3 additions & 0 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ def __init__(self, name="Unnamed model"):
self.use_simplify = True
self.convert_to_format = "casadi"

# Model is not initially discretised
self.is_discretised = False

# Default timescale is 1 second
self.timescale = pybamm.Scalar(1)

Expand Down
15 changes: 10 additions & 5 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def set_up(self, model, inputs=None):
raise pybamm.SolverError(
"""Cannot use algebraic solver to solve model with time derivatives"""
)
# Discretise model if it isn't already discretised
# This only works with purely 0D models, as otherwise the mesh and spatial
# method should be specified by the user
if model.is_discretised is False:
disc = pybamm.Discretisation()
disc.process_model(model)
# try:
# except error as e:
# raise ValueError(e)

inputs = inputs or {}
y0 = model.concatenated_initial_conditions.evaluate(0, None, inputs=inputs)
Expand Down Expand Up @@ -564,11 +573,7 @@ def solve(self, model, t_eval=None, external_variables=None, inputs=None):
]

# remove any discontinuities after end of t_eval
discontinuities = [
v
for v in discontinuities
if v < t_eval_dimensionless[-1]
]
discontinuities = [v for v in discontinuities if v < t_eval_dimensionless[-1]]

if len(discontinuities) > 0:
pybamm.logger.info(
Expand Down

0 comments on commit 316f75a

Please sign in to comment.