Skip to content

Commit

Permalink
Average Treatment Effect methods to all estimators (#365)
Browse files Browse the repository at this point in the history
* added ate inference methods
  • Loading branch information
vsyrgkanis authored Jan 14, 2021
1 parent f0b0e5b commit 35c5418
Show file tree
Hide file tree
Showing 9 changed files with 2,427 additions and 1,904 deletions.
228 changes: 228 additions & 0 deletions econml/_cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,54 @@ def marginal_effect(self, T, X=None):
"""
pass

def ate(self, X=None, *, T0, T1):
    """
    Calculate the average treatment effect :math:`E_X[\\tau(X, T0, T1)]`.

    The per-sample effect between the base treatment ``T0`` and the target
    treatment ``T1`` is obtained from :meth:`effect` and then averaged over
    the first axis of the result, i.e. over the population of X variables.

    Parameters
    ----------
    T0: (m, d_t) matrix or vector of length m
        Base treatments for each sample
    T1: (m, d_t) matrix or vector of length m
        Target treatments for each sample
    X: optional (m, d_x) matrix
        Features for each sample

    Returns
    -------
    τ: float or (d_y,) array
        Average treatment effects on each outcome
        Note that when Y is a vector rather than a 2-dimensional array, the result will be a scalar
    """
    per_sample_effects = self.effect(X=X, T0=T0, T1=T1)
    # Collapse only the sample axis so any outcome dimension is preserved.
    return np.mean(per_sample_effects, axis=0)

def marginal_ate(self, T, X=None):
    """
    Calculate the average marginal effect :math:`E_{T, X}[\\partial\\tau(T, X)]`.

    The per-sample marginal effect is obtained from :meth:`marginal_effect`
    around the base treatment points ``T`` and then averaged over the first
    axis of the result, i.e. over the population of X.

    Parameters
    ----------
    T: (m, d_t) matrix
        Base treatments for each sample
    X: optional (m, d_x) matrix
        Features for each sample

    Returns
    -------
    grad_tau: (d_y, d_t) array
        Average marginal effects on each outcome
        Note that when Y or T is a vector rather than a 2-dimensional array,
        the corresponding singleton dimensions in the output will be collapsed
        (e.g. if both are vectors, then the output of this method will be a scalar)
    """
    per_sample_gradients = self.marginal_effect(T, X=X)
    # Collapse only the sample axis; outcome/treatment axes are kept.
    return np.mean(per_sample_gradients, axis=0)

def _expand_treatments(self, X=None, *Ts):
"""
Given a set of features and treatments, return possibly modified features and treatments.
Expand Down Expand Up @@ -303,6 +351,101 @@ def marginal_effect_inference(self, T, X=None):
"""
pass

@_defer_to_inference
def ate_interval(self, X=None, *, T0, T1, alpha=0.1):
    """ Confidence intervals for the quantity :math:`E_X[\\tau(X, T0, T1)]` produced
    by the model. Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    X: optional (m, d_x) matrix
        Features for each sample
    T0: (m, d_t) matrix or vector of length m
        Base treatments for each sample. Keyword-only and required by this
        signature; overrides that reuse this docstring may default it to 0.
    T1: (m, d_t) matrix or vector of length m
        Target treatments for each sample. Keyword-only and required by this
        signature; overrides that reuse this docstring may default it to 1.
    alpha: optional float in [0, 1] (Default=0.1)
        The overall level of confidence of the reported interval.
        The alpha/2, 1-alpha/2 confidence interval is reported.

    Returns
    -------
    lower, upper : tuple(type of :meth:`ate(X, T0, T1)<ate>`, type of :meth:`ate(X, T0, T1)<ate>` )
        The lower and the upper bounds of the confidence interval for each quantity.
    """
    # Intentionally empty. NOTE(review): the decorator presumably forwards the
    # call to the inference object chosen at fit time - confirm in
    # _defer_to_inference.
    pass

@_defer_to_inference
def ate_inference(self, X=None, *, T0, T1):
    """ Inference results for the quantity :math:`E_X[\\tau(X, T0, T1)]` produced
    by the model. Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    X: optional (m, d_x) matrix
        Features for each sample
    T0: (m, d_t) matrix or vector of length m
        Base treatments for each sample. Keyword-only and required by this
        signature; overrides that reuse this docstring may default it to 0.
    T1: (m, d_t) matrix or vector of length m
        Target treatments for each sample. Keyword-only and required by this
        signature; overrides that reuse this docstring may default it to 1.

    Returns
    -------
    PopulationSummaryResults: object
        The inference results instance contains prediction and prediction standard error and
        can on demand calculate confidence interval, z statistic and p value. It can also output
        a dataframe summary of these inference results.
    """
    # Intentionally empty. NOTE(review): the decorator presumably forwards the
    # call to the inference object chosen at fit time - confirm in
    # _defer_to_inference.
    pass

@_defer_to_inference
def marginal_ate_interval(self, T, X=None, *, alpha=0.1):
    """ Confidence intervals for the quantities :math:`E_{T,X}[\\partial \\tau(T, X)]` produced
    by the model. Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    T: (m, d_t) matrix
        Base treatments for each sample
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample
    alpha: optional float in [0, 1] (Default=0.1)
        The overall level of confidence of the reported interval.
        The alpha/2, 1-alpha/2 confidence interval is reported.

    Returns
    -------
    lower, upper : tuple(type of :meth:`marginal_ate(T, X)<marginal_ate>`, \
        type of :meth:`marginal_ate(T, X)<marginal_ate>` )
        The lower and the upper bounds of the confidence interval for each quantity.
    """
    # Intentionally empty. NOTE(review): the decorator presumably forwards the
    # call to the inference object chosen at fit time - confirm in
    # _defer_to_inference.
    pass

@_defer_to_inference
def marginal_ate_inference(self, T, X=None):
    """ Inference results for the quantities :math:`E_{T,X}[\\partial \\tau(T, X)]` produced
    by the model. Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    T: (m, d_t) matrix
        Base treatments for each sample
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample

    Returns
    -------
    PopulationSummaryResults: object
        The inference results instance contains prediction and prediction standard error and
        can on demand calculate confidence interval, z statistic and p value. It can also output
        a dataframe summary of these inference results.
    """
    # Intentionally empty. NOTE(review): the decorator presumably forwards the
    # call to the inference object chosen at fit time - confirm in
    # _defer_to_inference.
    pass


class LinearCateEstimator(BaseCateEstimator):
"""Base class for all CATE estimators with linear treatment effects in this package."""
Expand Down Expand Up @@ -457,6 +600,79 @@ def const_marginal_effect_inference(self, X=None):
"""
pass

def const_marginal_ate(self, X=None):
    """
    Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.

    The per-sample constant marginal effect is obtained from
    :meth:`const_marginal_effect` and then averaged over the first axis of
    the result.

    Parameters
    ----------
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample.

    Returns
    -------
    theta: (d_y, d_t) matrix
        Average constant marginal CATE of each treatment on each outcome.
        Note that when Y or T is a vector rather than a 2-dimensional array,
        the corresponding singleton dimensions in the output will be collapsed
        (e.g. if both are vectors, then the output of this method will be a scalar)
    """
    per_sample_theta = self.const_marginal_effect(X=X)
    # Collapse only the sample axis; outcome/treatment axes are kept.
    return np.mean(per_sample_theta, axis=0)

@BaseCateEstimator._defer_to_inference
def const_marginal_ate_interval(self, X=None, *, alpha=0.1):
    """ Confidence intervals for the quantities :math:`E_X[\\theta(X)]` produced
    by the model. Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample
    alpha: optional float in [0, 1] (Default=0.1)
        The overall level of confidence of the reported interval.
        The alpha/2, 1-alpha/2 confidence interval is reported.

    Returns
    -------
    lower, upper : tuple(type of :meth:`const_marginal_ate(X)<const_marginal_ate>` ,\
        type of :meth:`const_marginal_ate(X)<const_marginal_ate>` )
        The lower and the upper bounds of the confidence interval for each quantity.
    """
    # Intentionally empty. NOTE(review): the decorator presumably forwards the
    # call to the inference object chosen at fit time - confirm in
    # BaseCateEstimator._defer_to_inference.
    pass

@BaseCateEstimator._defer_to_inference
def const_marginal_ate_inference(self, X=None):
    """ Inference results for the quantities :math:`E_X[\\theta(X)]` produced
    by the model. Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample

    Returns
    -------
    PopulationSummaryResults: object
        The inference results instance contains prediction and prediction standard error and
        can on demand calculate confidence interval, z statistic and p value. It can also output
        a dataframe summary of these inference results.
    """
    # Intentionally empty. NOTE(review): the decorator presumably forwards the
    # call to the inference object chosen at fit time - confirm in
    # BaseCateEstimator._defer_to_inference.
    pass

def marginal_ate(self, T, X=None):
    # T is part of the signature but is not used here; the result is
    # delegated entirely to const_marginal_ate.
    averaged_theta = self.const_marginal_ate(X=X)
    return averaged_theta
marginal_ate.__doc__ = BaseCateEstimator.marginal_ate.__doc__

def marginal_ate_interval(self, T, X=None, *, alpha=0.1):
    # T is part of the signature but is not used here; the interval is
    # delegated entirely to const_marginal_ate_interval.
    bounds = self.const_marginal_ate_interval(X=X, alpha=alpha)
    return bounds
marginal_ate_interval.__doc__ = BaseCateEstimator.marginal_ate_interval.__doc__

def marginal_ate_inference(self, T, X=None):
    # T is part of the signature but is not used here; the inference
    # results are delegated entirely to const_marginal_ate_inference.
    inference_results = self.const_marginal_ate_inference(X=X)
    return inference_results
marginal_ate_inference.__doc__ = BaseCateEstimator.marginal_ate_inference.__doc__

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
""" Shap value for the final stage models (const_marginal_effect)
Expand Down Expand Up @@ -524,6 +740,18 @@ def effect(self, X=None, *, T0=0, T1=1):
return super().effect(X, T0=T0, T1=T1)
effect.__doc__ = BaseCateEstimator.effect.__doc__

def ate(self, X=None, *, T0=0, T1=1):
    # Same computation as the parent implementation; this override only
    # supplies the default treatment points T0=0 and T1=1.
    averaged_effect = super().ate(X=X, T0=T0, T1=T1)
    return averaged_effect
ate.__doc__ = BaseCateEstimator.ate.__doc__

def ate_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
    # Same computation as the parent implementation; this override only
    # supplies the default treatment points T0=0 and T1=1.
    bounds = super().ate_interval(X=X, T0=T0, T1=T1, alpha=alpha)
    return bounds
ate_interval.__doc__ = BaseCateEstimator.ate_interval.__doc__

def ate_inference(self, X=None, *, T0=0, T1=1):
    # Same computation as the parent implementation; this override only
    # supplies the default treatment points T0=0 and T1=1.
    inference_results = super().ate_inference(X=X, T0=T0, T1=T1)
    return inference_results
ate_inference.__doc__ = BaseCateEstimator.ate_inference.__doc__


class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
"""
Expand Down
7 changes: 7 additions & 0 deletions econml/cate_interpreter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from ._interpreters import SingleTreeCateInterpreter, SingleTreePolicyInterpreter

# Names re-exported from ._interpreters as this package's public interface.
__all__ = [
    "SingleTreeCateInterpreter",
    "SingleTreePolicyInterpreter",
]
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class _SingleTreeInterpreter(metaclass=abc.ABCMeta):

tree_model = None
node_dict = None

@abc.abstractmethod
def interpret(self, cate_estimator, X):
Expand Down Expand Up @@ -156,7 +157,7 @@ def export_graphviz(self, out_file=None, feature_names=None,
exporter = self._make_dot_exporter(out_file=out_file, feature_names=feature_names, filled=filled,
leaves_parallel=leaves_parallel, rotate=rotate, rounded=rounded,
special_characters=special_characters, precision=precision)
exporter.export(self.tree_model)
exporter.export(self.tree_model, node_dict=self.node_dict)

if return_string:
return out_file.getvalue()
Expand Down Expand Up @@ -249,7 +250,7 @@ def plot(self, ax=None, title=None, feature_names=None,
check_is_fitted(self.tree_model, 'tree_')
exporter = self._make_mpl_exporter(title=title, feature_names=feature_names, filled=filled,
rounded=rounded, precision=precision, fontsize=fontsize)
exporter.export(self.tree_model, ax=ax)
exporter.export(self.tree_model, node_dict=self.node_dict, ax=ax)


class SingleTreeCateInterpreter(_SingleTreeInterpreter):
Expand All @@ -261,7 +262,7 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
include_uncertainty : bool, optional, default False
Whether to include confidence interval information when building a
simplified model of the cate model. If set to True, then
cate estimator needs to support the `effect_interval` method.
cate estimator needs to support the `const_marginal_ate_inference` method.
uncertainty_level : double, optional, default .05
The uncertainty level for the confidence intervals to be constructed
Expand All @@ -270,6 +271,11 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
in a leaf have similar target prediction but also similar alpha
confidence intervals.
uncertainty_only_on_leaves : bool, optional, default True
Whether uncertainty information should be displayed only on leaf nodes.
If False, then interpretation can be slightly slower, especially for cate
models that have a computationally expensive inference method.
splitter : string, optional, default "best"
The strategy used to choose the split at each node. Supported
strategies are "best" to choose the best split and "random" to choose
Expand Down Expand Up @@ -335,6 +341,7 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
def __init__(self,
include_model_uncertainty=False,
uncertainty_level=.1,
uncertainty_only_on_leaves=True,
splitter="best",
max_depth=None,
min_samples_split=2,
Expand All @@ -346,6 +353,7 @@ def __init__(self,
min_impurity_decrease=0.):
self.include_uncertainty = include_model_uncertainty
self.uncertainty_level = uncertainty_level
self.uncertainty_only_on_leaves = uncertainty_only_on_leaves
self.criterion = "mse"
self.splitter = splitter
self.max_depth = max_depth
Expand All @@ -370,20 +378,23 @@ def interpret(self, cate_estimator, X):
min_impurity_decrease=self.min_impurity_decrease)
y_pred = cate_estimator.const_marginal_effect(X)

assert all(d == 1 for d in y_pred.shape[1:]), ("Interpretation is only available for "
"single-dimensional treatments and outcomes")

if y_pred.ndim != 2:
y_pred = y_pred.reshape(-1, 1)

if self.include_uncertainty:
y_lower, y_upper = cate_estimator.const_marginal_effect_interval(X, alpha=self.uncertainty_level)
if y_lower.ndim != 2:
y_lower = y_lower.reshape(-1, 1)
y_upper = y_upper.reshape(-1, 1)
y_pred = np.hstack([y_pred, y_lower, y_upper])
self.tree_model.fit(X, y_pred)

self.tree_model.fit(X, y_pred.reshape((y_pred.shape[0], -1)))
paths = self.tree_model.decision_path(X)
node_dict = {}
for node_id in range(paths.shape[1]):
mask = paths.getcol(node_id).toarray().flatten().astype(bool)
Xsub = X[mask]
if (self.include_uncertainty and
((not self.uncertainty_only_on_leaves) or (self.tree_model.tree_.children_left[node_id] < 0))):
res = cate_estimator.const_marginal_ate_inference(Xsub)
node_dict[node_id] = {'mean': res.mean_point,
'std': res.std_point,
'ci': res.conf_int_mean(alpha=self.uncertainty_level)}
else:
cate_node = y_pred[mask]
node_dict[node_id] = {'mean': np.mean(cate_node, axis=0),
'std': np.std(cate_node, axis=0)}
self.node_dict = node_dict
return self

def _make_dot_exporter(self, *, out_file, feature_names, filled,
Expand Down
Loading

0 comments on commit 35c5418

Please sign in to comment.