Skip to content

Commit

Permalink
Enable expanding output in inference classes
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Oct 30, 2020
1 parent 48613ea commit 11c9dac
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
16 changes: 5 additions & 11 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,23 +380,17 @@ def marginal_effect(self, T, X=None):
def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
X, T = self._expand_treatments(X, T)
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
return tuple(np.repeat(eff, shape(T)[0], axis=0) if X is None else eff
for eff in effs)
if X is None: # need to repeat by the number of rows of T to ensure the right shape
effs = tuple(np.repeat(eff, shape(T)[0], axis=0) for eff in effs)
return effs
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__

def marginal_effect_inference(self, T, X=None):
X, T = self._expand_treatments(X, T)
cme_inf = self.const_marginal_effect_inference(X=X)
pred = cme_inf.point_estimate
pred_stderr = cme_inf.stderr
if X is None:
pred = np.repeat(pred, shape(T)[0], axis=0)
pred_stderr = np.repeat(pred_stderr, shape(T)[0], axis=0)
# TODO: It seems wrong to return inference results based on a normal approximation
# even in the case where const_marginal_effect_inference was actually generated
# using bootstrap
return NormalInferenceResults(d_t=cme_inf.d_t, d_y=cme_inf.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect', fname_transformer=None)
cme_inf = cme_inf._expand(shape(T)[0])
return cme_inf
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__

@BaseCateEstimator._defer_to_inference
Expand Down
33 changes: 31 additions & 2 deletions econml/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from statsmodels.iolib.table import SimpleTable
from .bootstrap import BootstrapEstimator
from .utilities import (cross_product, broadcast_unit_treatments, reshape_treatmentwise_effects,
ndim, inverse_onehot, parse_final_model_params, _safe_norm_ppf, Summary,
ndim, shape, inverse_onehot, parse_final_model_params, _safe_norm_ppf, Summary,
StatsModelsLinearRegression)


Expand Down Expand Up @@ -749,6 +749,24 @@ def _array_to_frame(self, d_t, d_y, arr):
df.columns = ['T' + str(i) for i in range(d_t)]
return df

@abc.abstractmethod
def _expand_outputs(self, n_rows):
"""
Expand the inference results from 1 row to n_rows identical rows. This is used internally when
we move from constant effects when X is None to a marginal effect of a different dimension.
Parameters
----------
n_rows: positive int
The number of rows to expand to
Returns
-------
results: InferenceResults
The expanded results
"""
pass


class NormalInferenceResults(InferenceResults):
"""
Expand Down Expand Up @@ -843,6 +861,12 @@ def pvalue(self, value=0):

return norm.sf(np.abs(self.zstat(value)), loc=0, scale=1) * 2

def _expand_outputs(self, n_rows):
assert shape(self.pred)[0] == shape(self.pred_stderr)[0] == 1
pred = np.repeat(self.pred, n_rows, axis=0)
pred_stderr = np.repeat(self.pred_stderr, n_rows, axis=0)
return NormalInferenceResults(self.d_t, self.d_y, pred, pred_stderr, self.inf_type, self.fname_transformer)


class EmpiricalInferenceResults(InferenceResults):
"""
Expand Down Expand Up @@ -925,9 +949,14 @@ def pvalue(self, value=0):
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will also be a vector)
"""

return min((self.pred_dist < value).sum(), (self.pred_dist > value).sum()) / self.pred_dist.shape[0]

def _expand_outputs(self, n_rows):
assert shape(self.pred)[0] == shape(self.pred_dist)[1] == 1
pred = np.repeat(self.pred, n_rows, axis=0)
pred_dist = np.repeat(self.pred_dist, n_rows, axis=1)
return EmpiricalInferenceResults(self.d_t, self.d_y, pred, pred_dist, self.inf_type, self.fname_transformer)


class PopulationSummaryResults:
"""
Expand Down

0 comments on commit 11c9dac

Please sign in to comment.