-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimization #92
base: master
Are you sure you want to change the base?
Optimization #92
Changes from all commits
e8224d8
7d6fde2
a751680
de6b74c
b6866f9
542f100
f17c0c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,76 @@ | ||
import numpy as np | ||
|
||
|
||
def sse(theta, model, data, model_parameters, lag_time=None): | ||
"""Calculate the sum of squared errors from data and model instance | ||
|
||
The function assumes the number of values in the theta array corresponds | ||
to the number of defined parameters, but can be extended with a lag_time | ||
parameters to make this adjustable as well by an optimization algorithm. | ||
|
||
|
||
Parameters | ||
---------- | ||
theta : np.array | ||
Array with N (if lag_time is defined) or N+1 (if lag_time is not | ||
defined) values. | ||
model : covidmodel.model instance | ||
Covid model instance | ||
data : xarray.DataArray | ||
Xarray DataArray with time as a dimension and the name corresponding | ||
to an existing model output state. | ||
model_parameters : list of str | ||
Parameter names of the model parameters to adjust in order to get the | ||
SSE. | ||
lag_time : int | ||
Warming up period before comparing the data with the model output. | ||
e.g. if 40; comparison of data only starts after 40 days. | ||
|
||
|
||
Notes | ||
----- | ||
Assumes daily time step(!) # TODO: need to generalize this | ||
|
||
Examples | ||
-------- | ||
>>> data_to_fit = xr.DataArray([ 54, 79, 100, 131, 165, 228, 290], | ||
coords={"time": range(1, 8)}, name="ICU", | ||
dims=['time']) | ||
>>> sse([1.6, 0.025, 42], sir_model, data_to_fit, ['sigma', 'beta'], lag_time=None) | ||
>>> # previous is equivalent to | ||
>>> sse([1.6, 0.025], sir_model, data_to_fit, ['sigma', 'beta'], lag_time=42) | ||
>>> # but the latter will use a fixed lag_time, | ||
>>> # whereas the first one can be used inside optimization | ||
""" | ||
|
||
if data.name not in model.state_names: | ||
raise Exception("Data variable to fit is not available as model output") | ||
|
||
# define new parameters in model | ||
for i, parameter in enumerate(model_parameters): | ||
model.parameters.update({parameter: theta[i]}) | ||
|
||
# extract additional parameter # TODO - check alternatives for generalisation here | ||
if not lag_time: | ||
lag_time = int(round(theta[-1])) | ||
else: | ||
lag_time = int(round(lag_time)) | ||
|
||
# run model | ||
time = len(data.time) + lag_time # at least this length | ||
output = model.sim(time) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could also run with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The default will be to only align those values where the time values match exactly, but you can indeed make this more flexible (eg http://xarray.pydata.org/en/stable/generated/xarray.DataArray.reindex.html also allows a simple "nearest" or a certain tolerance that should be allowed) |
||
|
||
# extract the variable of interest | ||
subset_output = output.sum(dim="stratification")[data.name] | ||
# adjust to fix lag time to extract start of comparison | ||
output_to_compare = subset_output.sel(time=slice(lag_time, lag_time - 1 + len(data))) | ||
|
||
# calculate sse -> we could maybe enable more function options on this level? | ||
sse = np.sum((data.values - output_to_compare.values)**2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to check that the "time" actually matched? (with the provided data) So if you adjust the time of the model output with lag_time, this calculation could be done with the xarray object (so without |
||
|
||
return sse | ||
|
||
|
||
def SSE(thetas,BaseModel,data,states,parNames,weights,checkpoints=None): | ||
|
||
""" | ||
|
@@ -30,7 +101,7 @@ def SSE(thetas,BaseModel,data,states,parNames,weights,checkpoints=None): | |
----------- | ||
SSE = SSE(model,thetas,data,parNames,positions,weights) | ||
""" | ||
|
||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
# assign estimates to correct variable | ||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
@@ -46,7 +117,7 @@ def SSE(thetas,BaseModel,data,states,parNames,weights,checkpoints=None): | |
# Run simulation | ||
# ~~~~~~~~~~~~~~ | ||
# number of dataseries | ||
n = len(data) | ||
n = len(data) | ||
# Compute simulation time | ||
data_length =[] | ||
for i in range(n): | ||
|
@@ -124,7 +195,7 @@ def MLE(thetas,BaseModel,data,states,parNames,checkpoints=None): | |
# Run simulation | ||
# ~~~~~~~~~~~~~~ | ||
# number of dataseries | ||
n = len(data) | ||
n = len(data) | ||
# Compute simulation time | ||
data_length =[] | ||
for i in range(n): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering if we should introduce a concept of "cloning" a model, to avoid changing the original one when setting parameters like this for an optimization run
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I propose to define this in a separate issue/PR to provide a more generic solution.