Skip to content

Commit

Permalink
Enable stratification in bootstrap
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Mar 26, 2020
1 parent 0f5ddfe commit 7b8054c
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 11 deletions.
43 changes: 39 additions & 4 deletions econml/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,62 @@ class BootstrapEstimator:
In case a method ending in '_interval' exists on the wrapped object, whether
that should be preferred (meaning this wrapper will compute the mean of it).
This option only affects behavior if `compute_means` is set to ``True``.
stratify_treatment: bool, default False
Whether to stratify by treatment when calling fit; this will ensure that each stratum of treatment
is subsampled independently, so that each resample will have the same number of entries with each
treatment as the original sample did.
"""

def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False):
def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None,
compute_means=True, prefer_wrapped=False, stratify_treatment=False):
self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
self._n_bootstrap_samples = n_bootstrap_samples
self._n_jobs = n_jobs
self._compute_means = compute_means
self._prefer_wrapped = prefer_wrapped
self._stratify_treatment = stratify_treatment

# TODO: Add a __dir__ implementation?

def _stratified_indices(self, Y, T, *args, **kwargs):
    """
    Group the sample positions by treatment value.

    The signature accepts the same leading arguments that are forwarded from ``fit``
    so the caller can pass its arguments through unchanged; only `T` is inspected.

    Parameters
    ----------
    Y: array-like
        Outcomes; accepted only so arguments can be forwarded, never read.
    T: array-like, shape (n,) or (n, d_t)
        Treatments.  For 2D input, each distinct row is treated as its own stratum.
    args, kwargs:
        Any further arguments; ignored.

    Returns
    -------
    indices: list of 1D integer arrays
        One array per distinct treatment, in the order returned by ``np.unique``;
        each array lists, in increasing order, the positions where that treatment occurs.

    Raises
    ------
    ValueError
        If `T` is not one- or two-dimensional.
    """
    # Convert once so that plain lists (and lists of lists) compare elementwise below
    T = np.asarray(T)
    # Raise explicitly rather than assert, since asserts are stripped under ``python -O``
    if not 1 <= T.ndim <= 2:
        raise ValueError("T must be a 1D or 2D array, but has {0} dimension(s)".format(T.ndim))

    unique = np.unique(T, axis=0)
    if T.ndim == 2:
        # a row belongs to a stratum only if every treatment column matches
        return [np.where(np.all(T == el, axis=1))[0] for el in unique]
    return [np.where(T == el)[0] for el in unique]

def fit(self, *args, **named_args):
"""
Fit the model.
The full signature of this method is the same as that of the wrapped object's `fit` method.
"""
n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0]
indices = np.random.choice(n_samples, size=(self._n_bootstrap_samples, n_samples), replace=True)

if self._stratify_treatment:
index_chunks = self._stratified_indices(*args, **named_args)
else:
n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0]
index_chunks = [np.arange(n_samples)] # one chunk with all indices

indices = []
for chunk in index_chunks:
n_samples = len(chunk)
indices.append(chunk[np.random.choice(n_samples,
size=(self._n_bootstrap_samples, n_samples),
replace=True)])

indices = np.hstack(indices)

def fit(x, *args, **kwargs):
x.fit(*args, **kwargs)
return x # Explicitly return x in case fit fails to return its target

def convertArg(arg, inds):
return arg[inds] if arg is not None else None
return np.asarray(arg)[inds] if arg is not None else None

self._instances = Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)(
delayed(fit)(obj,
*[convertArg(arg, inds) for arg in args],
Expand All @@ -84,6 +114,11 @@ def __getattr__(self, name):
Additionally, the suffix "_interval" is supported for getting an interval instead of a point estimate.
"""

# don't proxy special methods
if name.startswith('__'):
raise AttributeError(name)

def proxy(make_call, name, summary):
def summarize_with(f):
return summary(np.array(Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)(
Expand Down
4 changes: 3 additions & 1 deletion econml/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(self, n_bootstrap_samples=100, n_jobs=-1):
self._n_jobs = n_jobs

def fit(self, estimator, *args, **kwargs):
est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False)
discrete_treatment = estimator._discrete_treatment if hasattr(estimator, '_discrete_treatment') else False
est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False,
stratify_treatment=discrete_treatment)
est.fit(*args, **kwargs)
self._est = est

Expand Down
17 changes: 16 additions & 1 deletion econml/tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from econml.inference import BootstrapInference
from econml.dml import LinearDMLCateEstimator
from econml.two_stage_least_squares import NonparametricTwoStageLeastSquares
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.preprocessing import PolynomialFeatures
import numpy as np
import unittest
Expand Down Expand Up @@ -265,3 +265,18 @@ def test_internal_options(self):

# TODO: test that the estimated effect is usually within the bounds
# and that the true effect is also usually within the bounds

def test_stratify(self):
    """Test that bootstrap inference can be fit when stratifying by a discrete treatment."""
    outcomes = [1, 2, 3, 4, 5, 6]
    treatments = [1, 0, 1, 2, 0, 2]
    features = np.array([1, 1, 2, 2, 1, 2]).reshape(-1, 1)
    est = LinearDMLCateEstimator(
        model_y=LinearRegression(),
        model_t=LogisticRegression(),
        discrete_treatment=True)

    # 1D treatment, no features
    est.fit(outcomes, treatments, inference='bootstrap')
    est.const_marginal_effect_interval()

    # 1D treatment, with features
    est.fit(outcomes, treatments, X=features, inference='bootstrap')
    est.const_marginal_effect_interval(features)

    # test stratifying 2D treatment (a single treatment column)
    column_treatments = np.asarray(treatments).reshape(-1, 1)
    est.fit(outcomes, column_treatments, inference='bootstrap')
    est.const_marginal_effect_interval()
6 changes: 1 addition & 5 deletions econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,7 @@ def make_random(is_discrete, d):

model_t = LogisticRegression() if is_discrete else Lasso()

# TODO: add stratification to bootstrap so that we can use it
# even with discrete treatments
all_infs = [None, 'statsmodels']
if not is_discrete:
all_infs.append(BootstrapInference(1))
all_infs = [None, 'statsmodels', BootstrapInference(1)]

for est, multi, infs in\
[(LinearDMLCateEstimator(model_y=Lasso(),
Expand Down

0 comments on commit 7b8054c

Please sign in to comment.