From e8224d88fe4ed47dd199dbb2b5ad7ac56db578ff Mon Sep 17 00:00:00 2001 From: stijnvanhoey Date: Mon, 8 Jun 2020 22:45:08 +0200 Subject: [PATCH 1/5] Add refactored sse function --- .../optimization/objective_fcns.py | 77 ++++++++++++++++++- 1 file changed, 74 insertions(+), 3 deletions(-) diff --git a/src/covid19model/optimization/objective_fcns.py b/src/covid19model/optimization/objective_fcns.py index 16b286d98..fa7a61d47 100644 --- a/src/covid19model/optimization/objective_fcns.py +++ b/src/covid19model/optimization/objective_fcns.py @@ -1,5 +1,76 @@ import numpy as np + +def sse(theta, model, data, model_parameters, lag_time=None): + """Calculate the sum of squared errors from data and model instance + + The function assumes the number of values in the theta array corresponds + to the number of defined parameters, but can be extended with a lag_time + parameter to make this adjustable as well by an optimization algorithm. + + + Parameters + ---------- + theta : np.array + Array with N (if lag_time is defined) or N+1 (if lag_time is not + defined) values. + model : covidmodel.model instance + Covid model instance + data : xarray.DataArray + Xarray DataArray with time as a dimension and the name corresponding + to an existing model output state. + model_parameters : list of str + Parameter names of the model parameters to adjust in order to get the + SSE. + lag_time : int + Warming up period before comparing the data with the model output. + e.g. if 40; comparison of data only starts after 40 days. + + + Notes + ----- + Assumes daily time step(!) 
# TODO: need to generalize this + + Examples + -------- + >>> data_to_fit = xr.DataArray([ 54, 79, 100, 131, 165, 228, 290], + coords={"time": range(1, 8)}, name="ICU", + dims=['time']) + >>> sse([1.6, 0.025, 42], sir_model, data_to_fit, ['sigma', 'beta'], lag_time=None) + >>> # previous is equivalent to + >>> sse([1.6, 0.025], sir_model, data_to_fit, ['sigma', 'beta'], lag_time=42) + >>> # but the latter will use a fixed lag_time, + >>> # whereas the first one can be used inside optimization + """ + + if data_to_fit.name not in sir_model.state_names: + raise Exception("Data variable to fit is not available as model output") + + # define new parameters in model + for i, parameter in enumerate(model_parameters): + model.parameters.update({parameter: theta[i]}) + + # extract additional parameter # TODO - check alternatives for generalisation here + if not lag_time: + lag_time = int(round(theta[-1])) + else: + lag_time = int(round(lag_time)) + + # run model + time = len(data.time) + lag_time # at least this length + output = model.sim(time) + + # extract the variable of interest + subset_output = output.sum(dim="stratification")[data.name] + # adjust to fix lag time to extract start of comparison + output_to_compare = subset_output.sel(time=slice(lag_time, lag_time - 1 + len(data))) + + # calculate sse -> we could maybe enable more function options on this level? 
+ sse = np.sum((data.values - output_to_compare.values)**2) + + return sse + + def SSE(thetas,BaseModel,data,states,parNames,weights,checkpoints=None): """ @@ -30,7 +101,7 @@ def SSE(thetas,BaseModel,data,states,parNames,weights,checkpoints=None): ----------- SSE = SSE(model,thetas,data,parNames,positions,weights) """ - + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # assign estimates to correct variable # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -46,7 +117,7 @@ def SSE(thetas,BaseModel,data,states,parNames,weights,checkpoints=None): # Run simulation # ~~~~~~~~~~~~~~ # number of dataseries - n = len(data) + n = len(data) # Compute simulation time data_length =[] for i in range(n): @@ -129,7 +200,7 @@ def MLE(thetas,BaseModel,data,states,parNames,checkpoints=None): # Run simulation # ~~~~~~~~~~~~~~ # number of dataseries - n = len(data) + n = len(data) # Compute simulation time data_length =[] for i in range(n): From 7d6fde227f0345ccccc8e2ac796e1eacf14c569b Mon Sep 17 00:00:00 2001 From: stijnvanhoey Date: Mon, 8 Jun 2020 23:03:56 +0200 Subject: [PATCH 2/5] Switch name and add refactored fit function --- .../optimization/{MCMC.py => optimize.py} | 57 ++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) rename src/covid19model/optimization/{MCMC.py => optimize.py} (50%) diff --git a/src/covid19model/optimization/MCMC.py b/src/covid19model/optimization/optimize.py similarity index 50% rename from src/covid19model/optimization/MCMC.py rename to src/covid19model/optimization/optimize.py index c11c2f8f9..c98f2eee3 100644 --- a/src/covid19model/optimization/MCMC.py +++ b/src/covid19model/optimization/optimize.py @@ -3,6 +3,59 @@ from covid19model.optimization import objective_fcns from covid19model.optimization import pso + +def fit(func, model, data, model_parameters, bounds, + disp=True, maxiter=30, popsize=10): + """Fit a model using a defined SSE function + + Parameters + ---------- + func : objective function + The arguments ``model``, ``data`` + and 
``model_parameters`` should be the + other arguments of the optimization func + model : covidmodel.model instance + Covid19model instance + data : xarray.DataArray + Xarray DataArray with time as a dimension and the name corresponding + to an existing model output state. + model_parameters : list of str + Parameter names of the model parameters to adjust in order to get the + SSE. + bounds : list of (min, max) + For each parameter (model parameter or others), define the min and + max boundaries as tuples. If func has non-model parameters to + optimize as well, len(bounds) > len(model_parameters). + disp: boolean + display the pso output stream + maxiter: float or int + maximum number of pso iterations + popsize: float or int + population size of particle swarm + increasing this variable lowers the chance of getting trapped in a local minimum but slows down calculations + """ + # TODO - add checks on inputs,... + + original_parameters = model.parameters.copy() + + (theta_hat, obj_fun_val, pars_final_swarm, + obj_fun_val_final_swarm) = pso.optim(func, + bounds, + args=(model, data, model_parameters), + swarmsize=popsize, + maxiter=maxiter, + processes=mp.cpu_count()-1, + minfunc=1e-9, + minstep=1e-9, + debug=True, + particle_output=True) + + # reset the model parameters to original values + model.parameters = original_parameters + + return theta_hat + + def fit_pso(BaseModel,data,parNames,states,bounds,checkpoints=None,disp=True,maxiter=30,popsize=10): """ A function to compute the mimimum of the absolute value of the maximum likelihood estimator using a particle swarm optimization @@ -12,7 +65,7 @@ def fit_pso(BaseModel,data,parNames,states,bounds,checkpoints=None,disp=True,max BaseModel: model object correctly initialised model to be fitted to the dataset data: array - list containing dataseries + list containing dataseries parNames: array list containing the names of the parameters to be fitted states: array @@ -38,7 +91,7 @@ def 
fit_pso(BaseModel,data,parNames,states,bounds,checkpoints=None,disp=True,max Notes ----------- - Use all available cores minus one by default (optimal number of processors for 2-,4- or 6-core PC's with an OS). + Use all available cores minus one by default (optimal number of processors for 2-,4- or 6-core PC's with an OS). Example use ----------- From a75168061eebabd277b7122e0eb1c638c0454f66 Mon Sep 17 00:00:00 2001 From: stijnvanhoey Date: Mon, 8 Jun 2020 23:20:51 +0200 Subject: [PATCH 3/5] Fix path handling data functions --- src/covid19model/data/parameters.py | 11 +++++++---- src/covid19model/data/polymod.py | 28 +++++++++++++++------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/covid19model/data/parameters.py b/src/covid19model/data/parameters.py index f1997b5e9..8dbb4c15c 100644 --- a/src/covid19model/data/parameters.py +++ b/src/covid19model/data/parameters.py @@ -54,27 +54,30 @@ def get_COVID19_SEIRD_parameters(): ----------- parameters = get_COVID19_SEIRD_parameters() """ + abs_dir = os.path.dirname(__file__) + rel_dir = os.path.join(abs_dir, "../../../data/raw/model_parameters") + # Initialize parameters dictionary parameters = {} # Assign Nc_total from the Polymod study to the parameters dictionary initN, Nc_home, Nc_work, Nc_schools, Nc_transport, Nc_leisure, Nc_others, Nc_total = polymod.get_interaction_matrices() parameters['Nc'] = Nc_total - + # Verity_etal - df = pd.read_csv("../../data/raw/model_parameters/verity_etal.csv", sep=',',header='infer') + df = pd.read_csv(os.path.join(rel_dir, "verity_etal.csv"), sep=',',header='infer') parameters['h'] = np.array(df.loc[:,'symptomatic_hospitalized'].astype(float).tolist())/100 parameters['icu'] = np.array(df.loc[:,'hospitalized_ICU'].astype(float).tolist())/100 parameters['c'] = 1-parameters['icu'] parameters['m0'] = np.array(df.loc[:,'CFR'].astype(float).tolist())/100/parameters['h']/parameters['icu'] # Abrams_etal - df_asymp = 
pd.read_csv("../../data/raw/model_parameters/abrams_etal.csv", sep=',',header='infer') + df_asymp = pd.read_csv(os.path.join(rel_dir, "wu_etal.csv"), sep=',',header='infer') parameters['a'] = np.array(df_asymp.loc[:,'fraction asymptomatic'].astype(float).tolist()) parameters['m'] = 1-parameters['a'] # Other parameters - df_other_pars = pd.read_csv("../../data/raw/model_parameters/others.csv", sep=',',header='infer') + df_other_pars = pd.read_csv(os.path.join(rel_dir, "others.csv"), sep=',',header='infer') parameters.update(df_other_pars.T.to_dict()[0]) # Fitted parameters diff --git a/src/covid19model/data/polymod.py b/src/covid19model/data/polymod.py index 6bac67058..a6fe7dae9 100644 --- a/src/covid19model/data/polymod.py +++ b/src/covid19model/data/polymod.py @@ -4,7 +4,7 @@ import numpy as np def get_interaction_matrices(): - """Extract interaction matrices and demographic data from `data/raw/polymod` folder + """Extract interaction matrices and demographic data from `data/raw/polymod` folder This function returns the total number of individuals in ten year age bins in the Belgian population and the interaction matrices Nc at home, at work, in schools, on public transport, during leisure activities and during other activities. 
@@ -15,15 +15,15 @@ def get_interaction_matrices(): Nc_home : np.array (9x9) number of daily contacts at home of individuals in age group X with individuals in age group Y Nc_work : np.array (9x9) - number of daily contacts in the workplace of individuals in age group X with individuals in age group Y + number of daily contacts in the workplace of individuals in age group X with individuals in age group Y Nc_schools : np.array (9x9) - number of daily contacts in schools of individuals in age group X with individuals in age group Y + number of daily contacts in schools of individuals in age group X with individuals in age group Y Nc_transport : np.array (9x9) - number of daily contacts on public transport of individuals in age group X with individuals in age group Y + number of daily contacts on public transport of individuals in age group X with individuals in age group Y Nc_leisure : np.array (9x9) number of daily contacts during leisure activities of individuals in age group X with individuals in age group Y Nc_others : np.array (9x9) - number of daily contacts in other places of individuals in age group X with individuals in age group Y + number of daily contacts in other places of individuals in age group X with individuals in age group Y Nc_total : np.array (9x9) total number of daily contacts of individuals in age group X with individuals in age group Y, calculated as the sum of all the above interaction @@ -36,15 +36,17 @@ def get_interaction_matrices(): ----------- initN, Nc_home, Nc_work, Nc_schools, Nc_transport, Nc_leisure, Nc_others, Nc_total = get_interaction_matrices() """ + abs_dir = os.path.dirname(__file__) + rel_dir = os.path.join(abs_dir, "../../../data/raw/polymod/interaction_matrices/Belgium") # Data source - Nc_home = np.loadtxt("../../data/raw/polymod/interaction_matrices/Belgium/BELhome.txt", dtype='f', delimiter='\t') - Nc_work = np.loadtxt("../../data/raw/polymod/interaction_matrices/Belgium/BELwork.txt", dtype='f', delimiter='\t') - 
Nc_schools = np.loadtxt("../../data/raw/polymod/interaction_matrices/Belgium/BELschools.txt", dtype='f', delimiter='\t') - Nc_transport = np.loadtxt("../../data/raw/polymod/interaction_matrices/Belgium/BELtransport.txt", dtype='f', delimiter='\t') - Nc_leisure = np.loadtxt("../../data/raw/polymod/interaction_matrices/Belgium/BELleisure.txt", dtype='f', delimiter='\t') - Nc_others = np.loadtxt("../../data/raw/polymod/interaction_matrices/Belgium/BELothers.txt", dtype='f', delimiter='\t') - Nc_total = np.loadtxt("../../data/raw/polymod/interaction_matrices/Belgium/BELtotal.txt", dtype='f', delimiter='\t') - initN = np.loadtxt("../../data/raw/polymod/demographic/BELagedist_10year.txt", dtype='f', delimiter='\t') + Nc_home = np.loadtxt(os.path.join(rel_dir, "BELhome.txt"), dtype='f', delimiter='\t') + Nc_work = np.loadtxt(os.path.join(rel_dir, "BELwork.txt"), dtype='f', delimiter='\t') + Nc_schools = np.loadtxt(os.path.join(rel_dir, "BELschools.txt"), dtype='f', delimiter='\t') + Nc_transport = np.loadtxt(os.path.join(rel_dir, "BELtransport.txt"), dtype='f', delimiter='\t') + Nc_leisure = np.loadtxt(os.path.join(rel_dir, "BELleisure.txt"), dtype='f', delimiter='\t') + Nc_others = np.loadtxt(os.path.join(rel_dir, "BELothers.txt"), dtype='f', delimiter='\t') + Nc_total = np.loadtxt(os.path.join(rel_dir, "BELtotal.txt"), dtype='f', delimiter='\t') + initN = np.loadtxt(os.path.join(abs_dir, "../../../data/raw/polymod/demographic/BELagedist_10year.txt"), dtype='f', delimiter='\t') return initN, Nc_home, Nc_work, Nc_schools, Nc_transport, Nc_leisure, Nc_others, Nc_total From de6b74cf1eb429edec403f5ddd974ac4361e0a3f Mon Sep 17 00:00:00 2001 From: stijnvanhoey Date: Mon, 8 Jun 2020 23:21:18 +0200 Subject: [PATCH 4/5] Fix object names --- src/covid19model/optimization/objective_fcns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/covid19model/optimization/objective_fcns.py b/src/covid19model/optimization/objective_fcns.py index fa7a61d47..03ab4aab7 
100644 --- a/src/covid19model/optimization/objective_fcns.py +++ b/src/covid19model/optimization/objective_fcns.py @@ -43,7 +43,7 @@ def sse(theta, model, data, model_parameters, lag_time=None): >>> # whereas the first one can be used inside optimization """ - if data_to_fit.name not in sir_model.state_names: + if data.name not in model.state_names: raise Exception("Data variable to fit is not available as model output") # define new parameters in model From b6866f9d86f465838e079cf457af8a2f37c37dff Mon Sep 17 00:00:00 2001 From: stijnvanhoey Date: Mon, 8 Jun 2020 23:21:54 +0200 Subject: [PATCH 5/5] Add example of refactored optimization --- ...2-stijnvanhoey-refactor-optimization.ipynb | 1033 +++++++++++++++++ 1 file changed, 1033 insertions(+) create mode 100644 notebooks/0.2-stijnvanhoey-refactor-optimization.ipynb diff --git a/notebooks/0.2-stijnvanhoey-refactor-optimization.ipynb b/notebooks/0.2-stijnvanhoey-refactor-optimization.ipynb new file mode 100644 index 000000000..44fdc25d1 --- /dev/null +++ b/notebooks/0.2-stijnvanhoey-refactor-optimization.ipynb @@ -0,0 +1,1033 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-03T07:21:18.467181Z", + "start_time": "2020-06-03T07:21:17.408492Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "from scipy.integrate import solve_ivp\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import itertools" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# OPTIONAL: Load the \"autoreload\" extension so that package code can change\n", + "%load_ext autoreload\n", + "# OPTIONAL: always reload modules so that as you change code in src, it gets loaded\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Age based SEIRS" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "### Import model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-03T07:21:35.683658Z", + "start_time": "2020-06-03T07:21:33.312826Z" + } + }, + "outputs": [], + "source": [ + "from covid19model.models import models\n", + "from covid19model.data import google\n", + "from covid19model.data import sciensano\n", + "from covid19model.data import polymod\n", + "from covid19model.data import parameters\n", + "from covid19model.visualization.output import population_status, infected" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "heading_collapsed": true + }, + "source": [ + "### Define model locally" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Test model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "initN, Nc_home, Nc_work, Nc_schools, Nc_transport, Nc_leisure, Nc_others, Nc_total = polymod.get_interaction_matrices()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-03T07:21:30.050406Z", + "start_time": "2020-06-03T07:21:30.041794Z" + } + }, + "outputs": [], + "source": [ + "h = np.array([0.0205,0.0205,0.1755,0.1755,0.2115,0.2503,0.3066,0.4033,0.4770])\n", + "icu = np.array([0,0,0.0310,0.0310,0.055,0.077,0.107,0.1685,0.1895])\n", + "r = icu/h" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-03T07:21:30.550918Z", + "start_time": "2020-06-03T07:21:30.544110Z" + } + }, + "outputs": [], + "source": [ + "# ... 
parameters and initial conditions\n", + "levels = initN.size\n", + "nc = Nc_total\n", + "params = parameters.get_COVID19_SEIRD_parameters()\n", + "\n", + "initial_states = {'S': initN, 'E': np.ones(levels)}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-03T07:21:36.603310Z", + "start_time": "2020-06-03T07:21:36.480161Z" + } + }, + "outputs": [], + "source": [ + "# -> user initiates the model\n", + "sir_model = models.COVID19_SEIRD(initial_states, params)\n", + "\n", + "# -> user runs a simulation for a defined time period\n", + "time = [0, 200]\n", + "output = sir_model.sim(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-03T07:25:36.445570Z", + "start_time": "2020-06-03T07:25:36.190335Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "infected(output)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-03T07:21:53.877975Z", + "start_time": "2020-06-03T07:21:53.658738Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "population_status(output)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2020-06-03T07:26:37.433173Z", + "start_time": "2020-06-03T07:26:37.213467Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(output.coords['time'], output['S'].sum(dim=\"stratification\"), label=\"S\")\n", + "plt.plot(output.coords['time'], output['E'].sum(dim=\"stratification\"), label=\"E\")\n", + "\n", + "plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from covid19model.data import sciensano" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
H_totICU_totH_inH_outH_tot_cumsum
DATE
2020-03-1526654711853
2020-03-16370799014129
2020-03-1749710012331221
2020-03-1865013118348356
2020-03-1984416521249519
\n", + "
" + ], + "text/plain": [ + " H_tot ICU_tot H_in H_out H_tot_cumsum\n", + "DATE \n", + "2020-03-15 266 54 71 18 53\n", + "2020-03-16 370 79 90 14 129\n", + "2020-03-17 497 100 123 31 221\n", + "2020-03-18 650 131 183 48 356\n", + "2020-03-19 844 165 212 49 519" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_sciensano = sciensano.get_sciensano_COVID19_data(update=False)\n", + "df_sciensano.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Convert the data to an xarray DataArray (this can later be extended to an xarray Dataset if multiple states are compared)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "Show/Hide data repr\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Show/Hide attributes\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
xarray.DataArray
'ICU'
  • time: 7
  • 54 79 100 131 165 228 290
    array([ 54,  79, 100, 131, 165, 228, 290])
    • time
      (time)
      int64
      1 2 3 4 5 6 7
      array([1, 2, 3, 4, 5, 6, 7])
" + ], + "text/plain": [ + "\n", + "array([ 54, 79, 100, 131, 165, 228, 290])\n", + "Coordinates:\n", + " * time (time) int64 1 2 3 4 5 6 7" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_to_fit = xr.DataArray([ 54, 79, 100, 131, 165, 228, 290], \n", + " coords={\"time\": range(1, 8)}, name=\"ICU\",\n", + " dims=['time'])\n", + "data_to_fit" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 54, 79, 100, 131, 165, 228, 290])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_to_fit.values" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from covid19model.optimization.objective_fcns import sse" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Try out the sse function:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "122202.00118319865" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sse([0.3, 0.04, 12], sir_model, data_to_fit, ['sigma', 'beta'], lag_time=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6437.4700475273885" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sse([1.6, 0.025, 42], sir_model, data_to_fit, ['sigma', 'beta'], lag_time=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Due to the setup of the function, this is similar to" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6437.4700475273885" + ] + }, + "execution_count": 18, + 
"metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sse([1.6, 0.025], sir_model, data_to_fit, ['sigma', 'beta'], lag_time=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply with fixed lag time:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "13328.779260868365" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sse([1.6, 0.025], sir_model, data_to_fit, ['sigma', 'beta'], lag_time=40)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Apply the optimization pso functionality:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "import multiprocessing as mp\n", + "from covid19model.optimization import pso" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "from covid19model.optimization.optimize import fit" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "bounds=[(1, 100), (0.02, 0.06), (20, 80)]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No constraints given.\n", + "New best for swarm at iteration 1: [ 1. 0.03304964 24.70556363] 10807.146214163115\n", + "Best after iteration 1: [ 1. 0.03304964 24.70556363] 10807.146214163115\n", + "Best after iteration 2: [ 1. 0.03304964 24.70556363] 10807.146214163115\n", + "Best after iteration 3: [ 1. 0.03304964 24.70556363] 10807.146214163115\n", + "Best after iteration 4: [ 1. 0.03304964 24.70556363] 10807.146214163115\n", + "New best for swarm at iteration 5: [ 1. 0.03285257 25.5069565 ] 1016.6487860669702\n", + "Best after iteration 5: [ 1. 
0.03285257 25.5069565 ] 1016.6487860669702\n", + "Best after iteration 6: [ 1. 0.03285257 25.5069565 ] 1016.6487860669702\n", + "Best after iteration 7: [ 1. 0.03285257 25.5069565 ] 1016.6487860669702\n", + "Best after iteration 8: [ 1. 0.03285257 25.5069565 ] 1016.6487860669702\n", + "Best after iteration 9: [ 1. 0.03285257 25.5069565 ] 1016.6487860669702\n", + "Best after iteration 10: [ 1. 0.03285257 25.5069565 ] 1016.6487860669702\n", + "Best after iteration 11: [ 1. 0.03285257 25.5069565 ] 1016.6487860669702\n", + "New best for swarm at iteration 12: [ 1. 0.03281443 25.61753133] 968.2384358003329\n", + "Best after iteration 12: [ 1. 0.03281443 25.61753133] 968.2384358003329\n", + "Best after iteration 13: [ 1. 0.03281443 25.61753133] 968.2384358003329\n", + "New best for swarm at iteration 14: [ 1. 0.0328097 25.73390775] 965.4411400238192\n", + "Best after iteration 14: [ 1. 0.0328097 25.73390775] 965.4411400238192\n", + "Best after iteration 15: [ 1. 0.0328097 25.73390775] 965.4411400238192\n", + "New best for swarm at iteration 16: [ 1. 0.03278046 25.80294168] 963.6008804391676\n", + "Best after iteration 16: [ 1. 0.03278046 25.80294168] 963.6008804391676\n", + "New best for swarm at iteration 17: [ 1. 0.03278706 25.73100244] 961.7118511730044\n", + "Best after iteration 17: [ 1. 0.03278706 25.73100244] 961.7118511730044\n", + "New best for swarm at iteration 18: [ 1. 0.03279043 25.73246628] 961.263391696539\n", + "Best after iteration 18: [ 1. 0.03279043 25.73246628] 961.263391696539\n", + "New best for swarm at iteration 19: [ 1. 0.03279278 25.79734207] 961.1574566149003\n", + "Best after iteration 19: [ 1. 0.03279278 25.79734207] 961.1574566149003\n", + "New best for swarm at iteration 20: [ 1. 0.03279307 25.75317367] 961.1561728429244\n", + "Best after iteration 20: [ 1. 0.03279307 25.75317367] 961.1561728429244\n", + "Best after iteration 21: [ 1. 0.03279307 25.75317367] 961.1561728429244\n", + "Best after iteration 22: [ 1. 
0.03279307 25.75317367] 961.1561728429244\n", + "Best after iteration 23: [ 1. 0.03279307 25.75317367] 961.1561728429244\n", + "Best after iteration 24: [ 1. 0.03279307 25.75317367] 961.1561728429244\n", + "Best after iteration 25: [ 1. 0.03279307 25.75317367] 961.1561728429244\n", + "Best after iteration 26: [ 1. 0.03279307 25.75317367] 961.1561728429244\n", + "New best for swarm at iteration 27: [ 1. 0.03279307 25.75214352] 961.1561725663297\n", + "Best after iteration 27: [ 1. 0.03279307 25.75214352] 961.1561725663297\n", + "Best after iteration 28: [ 1. 0.03279307 25.75214352] 961.1561725663297\n", + "Best after iteration 29: [ 1. 0.03279307 25.75214352] 961.1561725663297\n", + "Best after iteration 30: [ 1. 0.03279307 25.75214352] 961.1561725663297\n", + "Stopping search: maximum iterations reached --> 30\n" + ] + }, + { + "data": { + "text/plain": [ + "array([ 1. , 0.03279307, 25.75214352])" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fit(sse, sir_model, data_to_fit, ['sigma', 'beta'], bounds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:COVID_MODEL]", + "language": "python", + "name": "conda-env-COVID_MODEL-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + 
"toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}