Skip to content

Commit

Permalink
Allow specifying n_rows in const_marginal_effect
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Oct 29, 2020
1 parent 48613ea commit 8084c0a
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 36 deletions.
15 changes: 9 additions & 6 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,22 +613,25 @@ def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight
sample_weight=sample_weight,
sample_var=sample_var))

def const_marginal_effect(self, X=None):
def const_marginal_effect(self, X=None, *, n_rows=None):
self._check_fitted_dims(X)
if X is None:
return self._model_final.predict()
pred = self._model_final.predict()
return pred if n_rows is None else np.repeat(pred, n_rows, axis=0)
else:
if n_rows is not None:
assert shape(X)[0] == n_rows
return self._model_final.predict(X)
const_marginal_effect.__doc__ = LinearCateEstimator.const_marginal_effect.__doc__

def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
def const_marginal_effect_interval(self, X=None, *, alpha=0.1, n_rows=None):
self._check_fitted_dims(X)
return super().const_marginal_effect_interval(X, alpha=alpha)
return super().const_marginal_effect_interval(X, alpha=alpha, n_rows=n_rows)
const_marginal_effect_interval.__doc__ = LinearCateEstimator.const_marginal_effect_interval.__doc__

def const_marginal_effect_inference(self, X=None):
def const_marginal_effect_inference(self, X=None, *, n_rows=None):
self._check_fitted_dims(X)
return super().const_marginal_effect_inference(X)
return super().const_marginal_effect_inference(X, n_rows=n_rows)
const_marginal_effect_inference.__doc__ = LinearCateEstimator.const_marginal_effect_inference.__doc__

def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
Expand Down
40 changes: 23 additions & 17 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ class LinearCateEstimator(BaseCateEstimator):
"""Base class for all CATE estimators with linear treatment effects in this package."""

@abc.abstractmethod
def const_marginal_effect(self, X=None):
def const_marginal_effect(self, X=None, *, n_rows=None):
"""
Calculate the constant marginal CATE :math:`\\theta(·)`.
Expand All @@ -297,6 +297,10 @@ def const_marginal_effect(self, X=None):
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample.
n_rows: optional int
Number of rows to return if X is None; if no number of rows is specified and X is None
then 1 row will be returned. If X is not None and a value is provided for the number of rows,
that value must agree with the number of rows in X.
Returns
-------
Expand Down Expand Up @@ -379,28 +383,22 @@ def marginal_effect(self, T, X=None):

def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
X, T = self._expand_treatments(X, T)
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
return tuple(np.repeat(eff, shape(T)[0], axis=0) if X is None else eff
for eff in effs)
if X is not None:
return self.const_marginal_effect_interval(X=X, alpha=alpha)
else: # need to pass the number of rows of T to ensure the right shape
return self.const_marginal_effect_interval(X=X, alpha=alpha, n_rows=shape(T)[0])
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__

def marginal_effect_inference(self, T, X=None):
X, T = self._expand_treatments(X, T)
cme_inf = self.const_marginal_effect_inference(X=X)
pred = cme_inf.point_estimate
pred_stderr = cme_inf.stderr
if X is None:
pred = np.repeat(pred, shape(T)[0], axis=0)
pred_stderr = np.repeat(pred_stderr, shape(T)[0], axis=0)
# TODO: It seems wrong to return inference results based on a normal approximation
# even in the case where const_marginal_effect_inference was actually generated
# using bootstrap
return NormalInferenceResults(d_t=cme_inf.d_t, d_y=cme_inf.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect', fname_transformer=None)
if X is not None:
return self.const_marginal_effect_inference(X=X)
else: # need to pass the number of rows of T to ensure the right shape
return self.const_marginal_effect_inference(X=X, n_rows=shape(T)[0])
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__

@BaseCateEstimator._defer_to_inference
def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
def const_marginal_effect_interval(self, X=None, *, alpha=0.1, n_rows=None):
""" Confidence intervals for the quantities :math:`\\theta(X)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.
Expand All @@ -412,6 +410,10 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.
n_rows: optional int
Number of rows to return if X is None; if no number of rows is specified and X is None
then 1 row will be returned. If X is not None and a value is provided for the number of rows,
that value must agree with the number of rows in X.
Returns
-------
Expand All @@ -422,7 +424,7 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
pass

@BaseCateEstimator._defer_to_inference
def const_marginal_effect_inference(self, X=None):
def const_marginal_effect_inference(self, X=None, *, n_rows=None):
""" Inference results for the quantities :math:`\\theta(X)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.
Expand All @@ -431,6 +433,10 @@ def const_marginal_effect_inference(self, X=None):
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
n_rows: optional int
Number of rows to return if X is None; if no number of rows is specified and X is None
then 1 row will be returned. If X is not None and a value is provided for the number of rows,
that value must agree with the number of rows in X.
Returns
-------
Expand Down
46 changes: 36 additions & 10 deletions econml/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,26 @@ def fit(self, estimator, *args, **kwargs):
self.d_t = self._d_t[0] if self._d_t else 1
self.d_y = self._d_y[0] if self._d_y else 1

def const_marginal_effect_interval(self, X, *, alpha=0.1):
def const_marginal_effect_interval(self, X, *, alpha=0.1, n_rows=None):
assert X is None or n_rows is None or n_rows == shape(X)[0]
repeat_X = X is None and n_rows is not None
if X is None:
X = np.ones((1, 1))
elif self.featurizer is not None:
X = self.featurizer.transform(X)
X, T = broadcast_unit_treatments(X, self.d_t)
preds = self._predict_interval(cross_product(X, T), alpha=alpha)
return tuple(reshape_treatmentwise_effects(pred, self._d_t, self._d_y)
for pred in preds)
preds = tuple(reshape_treatmentwise_effects(pred, self._d_t, self._d_y)
for pred in preds)
if repeat_X:
preds = tuple(np.repeat(pred, n_rows, axis=0)
for pred in preds)

def const_marginal_effect_inference(self, X):
return preds

def const_marginal_effect_inference(self, X, *, n_rows=None):
assert X is None or n_rows is None or n_rows == shape(X)[0]
repeat_X = X is None and n_rows is not None
if X is None:
X = np.ones((1, 1))
elif self.featurizer is not None:
Expand All @@ -154,6 +163,9 @@ def const_marginal_effect_inference(self, X):
raise AttributeError("Final model doesn't support prediction standard eror, "
"please call const_marginal_effect_interval to get confidence interval.")
pred_stderr = reshape_treatmentwise_effects(self._prediction_stderr(cross_product(X, T)), self._d_t, self._d_y)
if repeat_X:
pred = np.repeat(pred, n_rows, axis=0)
pred_stderr = np.repeat(pred_stderr, n_rows, axis=0)
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect', fname_transformer=None)

Expand Down Expand Up @@ -370,23 +382,37 @@ def fit(self, estimator, *args, **kwargs):
if hasattr(estimator, 'fit_cate_intercept'):
self.fit_cate_intercept = estimator.fit_cate_intercept

def const_marginal_effect_interval(self, X, *, alpha=0.1):
def const_marginal_effect_interval(self, X, *, alpha=0.1, n_rows=None):
assert X is None or n_rows is None or n_rows == shape(X)[0]
repeat_X = X is None and n_rows is not None

if (X is not None) and (self.featurizer is not None):
X = self.featurizer.transform(X)
preds = np.array([mdl.predict_interval(X, alpha=alpha) for mdl in self.fitted_models_final])
return tuple(np.moveaxis(preds, [0, 1], [-1, 0])) # send treatment to the end, pull bounds to the front
preds = tuple(np.moveaxis(preds, [0, 1], [-1, 0])) # send treatment to the end, pull bounds to the front
if repeat_X:
preds = tuple(np.repeat(pred, n_rows, axis=0) for pred in preds)
return preds

def const_marginal_effect_inference(self, X, *, n_rows=None):
assert X is None or n_rows is None or n_rows == shape(X)[0]
repeat_X = X is None and n_rows is not None

def const_marginal_effect_inference(self, X):
if (X is not None) and (self.featurizer is not None):
X = self.featurizer.transform(X)
pred = np.array([mdl.predict(X) for mdl in self.fitted_models_final])
pred = np.moveaxis(pred, 0, -1) # send treatment to the end, pull bounds to the front
if not hasattr(self.fitted_models_final[0], 'prediction_stderr'):
raise AttributeError("Final model doesn't support prediction standard eror, "
"please call const_marginal_effect_interval to get confidence interval.")
pred_stderr = np.array([mdl.prediction_stderr(X) for mdl in self.fitted_models_final])
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=np.moveaxis(pred, 0, -1),
# send treatment to the end, pull bounds to the front
pred_stderr=np.moveaxis(pred_stderr, 0, -1), inf_type='effect',
pred_stderr = np.moveaxis(pred_stderr, 0, -1) # send treatment to the end, pull bounds to the front

if repeat_X:
pred = np.repeat(pred, n_rows, axis=0)
pred_stderr = np.repeat(pred_stderr, n_rows, axis=0)
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect',
fname_transformer=None)

def effect_interval(self, X, *, T0, T1, alpha=0.1):
Expand Down
7 changes: 6 additions & 1 deletion econml/metalearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def fit(self, Y, T, X=None, *, inference=None):
feat_arr = np.concatenate((X, T), axis=1)
self.overall_model.fit(feat_arr, Y)

def const_marginal_effect(self, X=None):
def const_marginal_effect(self, X=None, *, n_rows=None):
"""Calculate the constant marginal treatment effect on a vector of features for each sample.
Parameters
Expand All @@ -180,6 +180,9 @@ def const_marginal_effect(self, X=None):
Note that when Y is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
"""
assert X is None or n_rows is None or n_rows == shape(X)[0]
repeat_X = X is None and n_rows is not None

# Check inputs
if X is None:
X = np.zeros((1, 1))
Expand All @@ -192,6 +195,8 @@ def const_marginal_effect(self, X=None):
taus = (prediction - np.repeat(prediction[:, :, 0], self._d_t[0] + 1).reshape(prediction.shape))[:, :, 1:]
else:
taus = (prediction - np.repeat(prediction[:, 0], self._d_t[0] + 1).reshape(prediction.shape))[:, 1:]
if repeat_X:
taus = np.repeat(taus, n_rows, axis=0)
return taus


Expand Down
11 changes: 9 additions & 2 deletions econml/ortho_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ def fit(self, estimator, *args, **kwargs):
self._T_vec = (T0.ndim == 1)
return self

def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
def const_marginal_effect_interval(self, X=None, *, alpha=0.1, n_rows=None):
""" Confidence intervals for the quantities :math:`\\theta(X)` produced
by the model. Available only when ``inference`` is ``blb``, when
calling the fit method.
Expand All @@ -989,6 +989,9 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
type of :meth:`const_marginal_effect(X)<const_marginal_effect>` )
The lower and the upper bounds of the confidence interval for each quantity.
"""
assert X is None or n_rows is None or n_rows == shape(X)[0]
repeat_X = X is None and n_rows is not None

params_and_cov = self._predict_wrapper(X)
# Calculate confidence intervals for the parameter (marginal effect)
lower = alpha / 2
Expand All @@ -1000,7 +1003,11 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
param_lower, param_upper = np.asarray(param_lower), np.asarray(param_upper)
if self._T_vec:
# If T is a vector, preserve shape of the effect interval
return param_lower.flatten(), param_upper.flatten()
param_lower = param_lower.flatten()
param_upper = param_upper.flatten()
if repeat_X:
param_lower = np.repeat(param_lower, n_rows, axis=0)
param_upper = np.repeat(param_upper, n_rows, axis=0)
return param_lower, param_upper

def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
Expand Down

0 comments on commit 8084c0a

Please sign in to comment.