Enable stratification in bootstrap #236

Merged · 5 commits · Nov 2, 2020
31 changes: 21 additions & 10 deletions econml/_ortho_learner.py
@@ -494,6 +494,22 @@ def _filter_none_kwargs(self, **kwargs):
non_none_kwargs[key] = value
return non_none_kwargs

def _strata(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None):
if self._discrete_instrument:
Z = LabelEncoder().fit_transform(np.ravel(Z))

if self._discrete_treatment:
enc = LabelEncoder()
T = enc.fit_transform(np.ravel(T))
if self._discrete_instrument:
return T + Z * len(enc.classes_)
else:
return T
elif self._discrete_instrument:
return Z
else:
return None
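
A quick illustration of the combined encoding above: since `LabelEncoder` maps each variable to `0..k-1`, the expression `T + Z * len(enc.classes_)` gives every (treatment, instrument) cell its own integer label. A minimal sketch with made-up toy data:

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

T_raw = np.array(['a', 'b', 'a', 'b'])  # toy treatment, 2 levels
Z_raw = np.array([0, 0, 1, 1])          # toy instrument, 2 levels

Z = LabelEncoder().fit_transform(np.ravel(Z_raw))  # [0, 0, 1, 1]
enc = LabelEncoder()
T = enc.fit_transform(np.ravel(T_raw))             # [0, 1, 0, 1]

print(T + Z * len(enc.classes_))  # [0 1 2 3]: one label per (T, Z) cell
```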

@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None, *, inference=None):
"""
@@ -540,27 +556,22 @@ def fit(self, Y, T, X=None, W=None, Z=None, sample_weight=None, sample_var=None,
def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
# use a binary array to get stratified split in case of discrete treatment
stratify = self._discrete_treatment or self._discrete_instrument
strata = self._strata(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight)
if strata is None:
strata = T # always safe to pass T as second arg to split even if we're not actually stratifying

if self._discrete_treatment:
T = self._one_hot_encoder.fit_transform(reshape(T, (-1, 1)))

if self._discrete_instrument:
z_enc = LabelEncoder()
Z = z_enc.fit_transform(Z.ravel())

if self._discrete_treatment: # need to stratify on combination of Z and T
to_split = inverse_onehot(T) + Z * len(self._one_hot_encoder.categories_[0])
else:
to_split = Z # just stratify on Z

z_ohe = OneHotEncoder(categories='auto', sparse=False, drop='first')
Z = z_ohe.fit_transform(reshape(Z, (-1, 1)))
self.z_transformer = FunctionTransformer(
func=_EncoderWrapper(z_ohe, z_enc).encode,
validate=False)
else:
# stratify on T if discrete, and fine to pass T as second arg to KFold.split even when not
to_split = inverse_onehot(T) if self._discrete_treatment else T
self.z_transformer = None

if self._n_splits == 1: # special case, no cross validation
@@ -575,9 +586,9 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
all_vars = [var if np.ndim(var) == 2 else var.reshape(-1, 1) for var in [Z, W, X] if var is not None]
if all_vars:
all_vars = np.hstack(all_vars)
folds = splitter.split(all_vars, to_split)
folds = splitter.split(all_vars, strata)
else:
folds = splitter.split(np.ones((T.shape[0], 1)), to_split)
folds = splitter.split(np.ones((T.shape[0], 1)), strata)

if self._discrete_treatment:
self._d_t = shape(T)[1:]
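One background note on the hunks above: passing `strata` (or `T`) as the second argument to `split` is safe in both branches because `KFold.split(X, y)` ignores `y` entirely, while `StratifiedKFold.split(X, y)` stratifies on it. A small sketch of that sklearn contract (toy data, not from this PR):

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.arange(12).reshape(-1, 1)
strata = np.repeat([0, 1], 6)   # two strata, six rows each

list(KFold(n_splits=3).split(X, strata))  # the second argument is ignored
for _, test in StratifiedKFold(n_splits=3).split(X, strata):
    print(strata[test])         # every test fold contains both strata
```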
144 changes: 130 additions & 14 deletions econml/bootstrap.py
@@ -5,6 +5,9 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import clone
from scipy.stats import norm
from collections import OrderedDict
import pandas as pd


class BootstrapEstimator:
@@ -44,32 +47,69 @@ class BootstrapEstimator:
In case a method ending in '_interval' exists on the wrapped object, whether
that should be preferred (meaning this wrapper will compute the mean of it).
This option only affects behavior if `compute_means` is set to ``True``.

bootstrap_type: 'percentile', 'pivot', or 'normal', default 'pivot'
Bootstrap method used to compute results. 'percentile' will result in using the empirical CDF of
the replicated computations of the statistics. 'pivot' will also use the replicates but create a pivot
interval that also relies on the estimate over the entire dataset. 'normal' will instead compute an interval
assuming the replicates are normally distributed.
"""

def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False):
def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False,
bootstrap_type='pivot'):
self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
self._n_bootstrap_samples = n_bootstrap_samples
self._n_jobs = n_jobs
self._compute_means = compute_means
self._prefer_wrapped = prefer_wrapped
self._bootstrap_type = bootstrap_type
self._wrapped = wrapped

# TODO: Add a __dir__ implementation?

@staticmethod
def __stratified_indices(arr):
assert 1 <= np.ndim(arr) <= 2
unique = np.unique(arr, axis=0)
indices = []
for el in unique:
ind, = np.where(np.all(arr == el, axis=1) if np.ndim(arr) == 2 else arr == el)
indices.append(ind)
return indices
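
To make the grouping concrete, here is what `__stratified_indices` computes for a toy label array, and how `fit` (below) resamples within each group:

```python
import numpy as np

strata = np.array([0, 1, 0, 1, 1])  # one stratum label per row
# __stratified_indices(strata) groups row indices by unique value:
#   [array([0, 2]), array([1, 3, 4])]

chunk = np.array([1, 3, 4])         # resampling within one stratum
draws = chunk[np.random.choice(len(chunk), size=(2, len(chunk)), replace=True)]
# hstack-ing the per-stratum draws yields bootstrap samples in which each
# stratum keeps exactly its original number of rows
```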

def fit(self, *args, **named_args):
"""
Fit the model.

The full signature of this method is the same as that of the wrapped object's `fit` method.
"""
n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0]
indices = np.random.choice(n_samples, size=(self._n_bootstrap_samples, n_samples), replace=True)
from .cate_estimator import BaseCateEstimator # need to nest this here to avoid circular import

index_chunks = None
if isinstance(self._instances[0], BaseCateEstimator):
index_chunks = self._instances[0]._strata(*args, **named_args)
if index_chunks is not None:
index_chunks = self.__stratified_indices(index_chunks)
if index_chunks is None:
n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0]
index_chunks = [np.arange(n_samples)] # one chunk with all indices

indices = []
for chunk in index_chunks:
n_samples = len(chunk)
indices.append(chunk[np.random.choice(n_samples,
size=(self._n_bootstrap_samples, n_samples),
replace=True)])

indices = np.hstack(indices)

def fit(x, *args, **kwargs):
x.fit(*args, **kwargs)
return x # Explicitly return x in case fit fails to return its target

def convertArg(arg, inds):
return arg[inds] if arg is not None else None
return np.asarray(arg)[inds] if arg is not None else None

self._instances = Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)(
delayed(fit)(obj,
*[convertArg(arg, inds) for arg in args],
@@ -84,10 +124,16 @@ def __getattr__(self, name):

Additionally, the suffix "_interval" is supported for getting an interval instead of a point estimate.
"""

# don't proxy special methods
if name.startswith('__'):
raise AttributeError(name)

def proxy(make_call, name, summary):
def summarize_with(f):
return summary(np.array(Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)(
(f, (obj, name), {}) for obj in self._instances)))
results = np.array(Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)(
(f, (obj, name), {}) for obj in self._instances)), f(self._wrapped, name)
return summary(*results)
if make_call:
def call(*args, **kwargs):
return summarize_with(lambda obj, name: getattr(obj, name)(*args, **kwargs))
@@ -97,16 +143,36 @@ def call(*args, **kwargs):

def get_mean():
# for attributes that exist on the wrapped object, just compute the mean of the wrapped calls
return proxy(callable(getattr(self._instances[0], name)), name, lambda arr: np.mean(arr, axis=0))
return proxy(callable(getattr(self._instances[0], name)), name, lambda arr, _: np.mean(arr, axis=0))

def get_std():
prefix = name[: - len('_std')]
return proxy(callable(getattr(self._instances[0], prefix)), prefix,
lambda arr, _: np.std(arr, axis=0))

def get_interval():
# if the attribute exists on the wrapped object once we remove the suffix,
# then we should be computing a confidence interval for the wrapped calls
prefix = name[: - len("_interval")]

def call_with_bounds(can_call, lower, upper):
return proxy(can_call, prefix,
lambda arr: (np.percentile(arr, lower, axis=0), np.percentile(arr, upper, axis=0)))
def percentile_bootstrap(arr, _):
return np.percentile(arr, lower, axis=0), np.percentile(arr, upper, axis=0)

def pivot_bootstrap(arr, est):
return 2 * est - np.percentile(arr, upper, axis=0), 2 * est - np.percentile(arr, lower, axis=0)

def normal_bootstrap(arr, est):
std = np.std(arr, axis=0)
return est - norm.ppf(upper / 100) * std, est - norm.ppf(lower / 100) * std

# TODO: studentized bootstrap? this would be more accurate in most cases but can we avoid
# second level bootstrap which would be prohibitive computationally?

fn = {'percentile': percentile_bootstrap,
'normal': normal_bootstrap,
'pivot': pivot_bootstrap}[self._bootstrap_type]
return proxy(can_call, prefix, fn)

can_call = callable(getattr(self._instances[0], prefix))
if can_call:
@@ -120,19 +186,69 @@ def call(lower=5, upper=95):
return call_with_bounds(can_call, lower, upper)
return call
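
For concreteness, the three interval types applied to the same replicates; the numbers are made up for illustration:

```python
import numpy as np
from scipy.stats import norm

arr = np.array([0.8, 0.9, 1.0, 1.1, 1.4])  # bootstrap replicates (toy values)
est = 1.0                                   # estimate on the full dataset
lower, upper = 5, 95

percentile = (np.percentile(arr, lower), np.percentile(arr, upper))
pivot = (2 * est - np.percentile(arr, upper), 2 * est - np.percentile(arr, lower))
std = np.std(arr)
normal = (est - norm.ppf(upper / 100) * std, est - norm.ppf(lower / 100) * std)
# 'pivot' reflects the replicate distribution around the full-data estimate,
# removing bootstrap bias that a plain percentile interval would retain
```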

def get_inference():
# can't import from econml.inference at top level without creating mutual dependencies
from .inference import EmpiricalInferenceResults, NormalInferenceResults

prefix = name[: - len("_inference")]
if prefix in ['const_marginal_effect', 'marginal_effect', 'effect']:
inf_type = 'effect'
elif prefix == 'coef_':
inf_type = 'coefficient'
elif prefix == 'intercept_':
inf_type = 'intercept'
else:
raise AttributeError("Unsupported inference: " + name)

d_t = self._wrapped._d_t[0] if self._wrapped._d_t else 1
d_t = 1 if prefix == 'effect' else d_t
d_y = self._wrapped._d_y[0] if self._wrapped._d_y else 1

def get_inference_nonparametric(kind):
def get_dist(est, arr):
if kind == 'percentile':
return arr
elif kind == 'pivot':
return 2 * est - arr
else:
raise ValueError("Invalid kind, must be either 'percentile' or 'pivot'")
return proxy(callable(getattr(self._instances[0], prefix)), prefix,
lambda arr, est: EmpiricalInferenceResults(d_t=d_t, d_y=d_y,
pred=est, pred_dist=get_dist(est, arr),
inf_type=inf_type, fname_transformer=None))

def get_inference_parametric():
pred = getattr(self._wrapped, prefix)
stderr = getattr(self, prefix + '_std')
return NormalInferenceResults(d_t=d_t, d_y=d_y, pred=pred,
pred_stderr=stderr, inf_type=inf_type,
pred_dist=None, fname_transformer=None)

return {'normal': get_inference_parametric,
'percentile': lambda: get_inference_nonparametric('percentile'),
'pivot': lambda: get_inference_nonparametric('pivot')}[self._bootstrap_type]()
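
Continuing the earlier usage sketch (still assuming a fitted `boot` and held-out `X_test`), the new `_inference` suffix returns a results object rather than a bare tuple:

```python
res = boot.effect_inference(X_test)  # EmpiricalInferenceResults for 'percentile'/'pivot',
                                     # NormalInferenceResults when bootstrap_type='normal'
print(res.point_estimate)            # the estimate from the full-data fit
```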

caught = None
m = None
if name.endswith("_interval"):
m = get_interval
elif name.endswith("_std"):
m = get_std
elif name.endswith("_inference"):
m = get_inference
if self._compute_means and self._prefer_wrapped:
try:
return get_mean()
except AttributeError as err:
caught = err
if name.endswith("_interval"):
return get_interval()
if m is not None:
return m()
else:
# try to get interval first if appropriate, since we don't prefer a wrapped method with this name
if name.endswith("_interval"):
# try to get interval/std first if appropriate,
# since we don't prefer a wrapped method with this name
if m is not None:
try:
return get_interval()
return m()
except AttributeError as err:
caught = err
if self._compute_means:
31 changes: 21 additions & 10 deletions econml/cate_estimator.py
@@ -8,11 +8,10 @@
from functools import wraps
from copy import deepcopy
from warnings import warn
from .bootstrap import BootstrapEstimator
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, InferenceResults
LinearModelFinalInferenceDiscrete, NormalInferenceResults


class BaseCateEstimator(metaclass=abc.ABCMeta):
@@ -42,6 +41,21 @@ def _get_inference(self, inference):
# because inf now stores state from fitting est2
return deepcopy(inference)

def _strata(self, Y, T, *args, **kwargs):
"""
Get an array of values representing strata that should be preserved by bootstrapping. For example,
if treatment is discrete, then each bootstrapped estimator needs to be given at least one instance
with each treatment type. For estimators like DRIV, the same is true of the combination of
treatment and instrument. The arguments to this method will match those to fit.

Returns
-------
strata : array or None
A vector with the same number of rows as the inputs, where the unique values represent
the strata that need to be preserved by bootstrapping, or None if no preservation is necessary.
"""
return None
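
As an illustration of overriding this hook (a hypothetical subclass, not part of this diff), an IV-style estimator could stratify on the (treatment, instrument) pair:

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

class MyIVEstimator(BaseCateEstimator):  # hypothetical subclass
    def _strata(self, Y, T, X=None, W=None, Z=None, *args, **kwargs):
        # keep every (treatment, instrument) combination in each bootstrap draw
        t = LabelEncoder().fit_transform(np.ravel(T))
        z = LabelEncoder().fit_transform(np.ravel(Z))
        return t + z * len(np.unique(t))
```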

def _prefit(self, Y, T, *args, **kwargs):
self._d_y = np.shape(Y)[1:]
self._d_t = np.shape(T)[1:]
@@ -366,20 +380,17 @@ def marginal_effect(self, T, X=None):
def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
X, T = self._expand_treatments(X, T)
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
return tuple(np.repeat(eff, shape(T)[0], axis=0) if X is None else eff
for eff in effs)
if X is None: # need to repeat by the number of rows of T to ensure the right shape
effs = tuple(np.repeat(eff, shape(T)[0], axis=0) for eff in effs)
return effs
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__

def marginal_effect_inference(self, T, X=None):
X, T = self._expand_treatments(X, T)
cme_inf = self.const_marginal_effect_inference(X=X)
pred = cme_inf.point_estimate
pred_stderr = cme_inf.stderr
if X is None:
pred = np.repeat(pred, shape(T)[0], axis=0)
pred_stderr = np.repeat(pred_stderr, shape(T)[0], axis=0)
return InferenceResults(d_t=cme_inf.d_t, d_y=cme_inf.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect', pred_dist=None, fname_transformer=None)
cme_inf = cme_inf._expand_outputs(shape(T)[0])
return cme_inf
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__

@BaseCateEstimator._defer_to_inference
1 change: 1 addition & 0 deletions econml/cate_interpreter.py
@@ -543,6 +543,7 @@ def interpret(self, cate_estimator, X, sample_treatment_costs=None, treatment_na
splitter=self.splitter,
max_depth=self.max_depth,
min_samples_split=self.min_samples_split,
min_samples_leaf=self.min_samples_leaf,
min_weight_fraction_leaf=self.min_weight_fraction_leaf,
max_features=self.max_features,
random_state=self.random_state,
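Finally, the one-line cate_interpreter.py change means a user-supplied `min_samples_leaf` now actually reaches the fitted tree instead of being silently dropped. A hedged sketch (the interpreter class is inferred from the hunk's signature; `est` and `X` are assumed):

```python
from econml.cate_interpreter import SingleTreePolicyInterpreter

intrp = SingleTreePolicyInterpreter(max_depth=2, min_samples_leaf=10)
intrp.interpret(est, X)  # min_samples_leaf=10 is now forwarded to the tree
```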