Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Averate Treatment Effect methods to all estimators #365

Merged
merged 14 commits into from
Jan 14, 2021
228 changes: 228 additions & 0 deletions econml/_cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,54 @@ def marginal_effect(self, T, X=None):
"""
pass

def ate(self, X=None, *, T0, T1):
"""
Calculate the average treatment effect :math:`E_X[\\tau(X, T0, T1)]`.

The effect is calculated between the two treatment points and is averaged over
the population of X variables.

Parameters
----------
T0: (m, d_t) matrix or vector of length m
Base treatments for each sample
T1: (m, d_t) matrix or vector of length m
Target treatments for each sample
X: optional (m, d_x) matrix
Features for each sample

Returns
-------
τ: float or (d_y,) array
Average treatment effects on each outcome
Note that when Y is a vector rather than a 2-dimensional array, the result will be a scalar
"""
return np.mean(self.effect(X=X, T0=T0, T1=T1), axis=0)

def marginal_ate(self, T, X=None):
"""
Calculate the average marginal effect :math:`E_{T, X}[\\partial\\tau(T, X)]`.

The marginal effect is calculated around a base treatment
point and averaged over the population of X.

Parameters
----------
T: (m, d_t) matrix
Base treatments for each sample
X: optional (m, d_x) matrix
Features for each sample

Returns
-------
grad_tau: (d_y, d_t) array
Average marginal effects on each outcome
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will be a scalar)
"""
return np.mean(self.marginal_effect(T, X=X), axis=0)

def _expand_treatments(self, X=None, *Ts):
"""
Given a set of features and treatments, return possibly modified features and treatments.
Expand Down Expand Up @@ -303,6 +351,101 @@ def marginal_effect_inference(self, T, X=None):
"""
pass

@_defer_to_inference
def ate_interval(self, X=None, *, T0, T1, alpha=0.1):
""" Confidence intervals for the quantity :math:`E_X[\\tau(X, T0, T1)]` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix
Features for each sample
T0: optional (m, d_t) matrix or vector of length m (Default=0)
Base treatments for each sample
T1: optional (m, d_t) matrix or vector of length m (Default=1)
Target treatments for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of :meth:`ate(X, T0, T1)<ate>`, type of :meth:`ate(X, T0, T1))<ate>` )
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass

@_defer_to_inference
def ate_inference(self, X=None, *, T0, T1):
""" Inference results for the quantity :math:`E_X[\\tau(X, T0, T1)]` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix
Features for each sample
T0: optional (m, d_t) matrix or vector of length m (Default=0)
Base treatments for each sample
T1: optional (m, d_t) matrix or vector of length m (Default=1)
Target treatments for each sample

Returns
-------
PopulationSummaryResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass

@_defer_to_inference
def marginal_ate_interval(self, T, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`E_{T,X}[\\partial \\tau(T, X)]` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
T: (m, d_t) matrix
Base treatments for each sample
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of :meth:`marginal_ate(T, X)<marginal_ate>`, \
type of :meth:`marginal_ate(T, X)<marginal_ate>` )
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass

@_defer_to_inference
def marginal_ate_inference(self, T, X=None):
""" Inference results for the quantities :math:`E_{T,X}[\\partial \\tau(T, X)]` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
T: (m, d_t) matrix
Base treatments for each sample
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample

Returns
-------
PopulationSummaryResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass


class LinearCateEstimator(BaseCateEstimator):
"""Base class for all CATE estimators with linear treatment effects in this package."""
Expand Down Expand Up @@ -457,6 +600,79 @@ def const_marginal_effect_inference(self, X=None):
"""
pass

def const_marginal_ate(self, X=None):
"""
Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.

Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample.

Returns
-------
theta: (d_y, d_t) matrix
Average constant marginal CATE of each treatment on each outcome.
Note that when Y or T is a vector rather than a 2-dimensional array,
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will be a scalar)
"""
return np.mean(self.const_marginal_effect(X=X), axis=0)

@BaseCateEstimator._defer_to_inference
def const_marginal_ate_interval(self, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`E_X[\\theta(X)]` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample
alpha: optional float in [0, 1] (Default=0.1)
The overall level of confidence of the reported interval.
The alpha/2, 1-alpha/2 confidence interval is reported.

Returns
-------
lower, upper : tuple(type of :meth:`const_marginal_ate(X)<const_marginal_ate>` ,\
type of :meth:`const_marginal_ate(X)<const_marginal_ate>` )
The lower and the upper bounds of the confidence interval for each quantity.
"""
pass

@BaseCateEstimator._defer_to_inference
def const_marginal_ate_inference(self, X=None):
""" Inference results for the quantities :math:`E_X[\\theta(X)]` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample

Returns
-------
PopulationSummaryResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass

def marginal_ate(self, T, X=None):
return self.const_marginal_ate(X=X)
marginal_ate.__doc__ = BaseCateEstimator.marginal_ate.__doc__

def marginal_ate_interval(self, T, X=None, *, alpha=0.1):
return self.const_marginal_ate_interval(X=X, alpha=alpha)
marginal_ate_interval.__doc__ = BaseCateEstimator.marginal_ate_interval.__doc__

def marginal_ate_inference(self, T, X=None):
return self.const_marginal_ate_inference(X=X)
marginal_ate_inference.__doc__ = BaseCateEstimator.marginal_ate_inference.__doc__

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
""" Shap value for the final stage models (const_marginal_effect)

Expand Down Expand Up @@ -524,6 +740,18 @@ def effect(self, X=None, *, T0=0, T1=1):
return super().effect(X, T0=T0, T1=T1)
effect.__doc__ = BaseCateEstimator.effect.__doc__

def ate(self, X=None, *, T0=0, T1=1):
return super().ate(X=X, T0=T0, T1=T1)
ate.__doc__ = BaseCateEstimator.ate.__doc__

def ate_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
return super().ate_interval(X=X, T0=T0, T1=T1, alpha=alpha)
ate_interval.__doc__ = BaseCateEstimator.ate_interval.__doc__

def ate_inference(self, X=None, *, T0=0, T1=1):
return super().ate_inference(X=X, T0=T0, T1=T1)
ate_inference.__doc__ = BaseCateEstimator.ate_inference.__doc__


class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
"""
Expand Down
7 changes: 7 additions & 0 deletions econml/cate_interpreter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from ._interpreters import SingleTreeCateInterpreter, SingleTreePolicyInterpreter

__all__ = ["SingleTreeCateInterpreter",
"SingleTreePolicyInterpreter"]
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class _SingleTreeInterpreter(metaclass=abc.ABCMeta):

tree_model = None
node_dict = None

@abc.abstractmethod
def interpret(self, cate_estimator, X):
Expand Down Expand Up @@ -156,7 +157,7 @@ def export_graphviz(self, out_file=None, feature_names=None,
exporter = self._make_dot_exporter(out_file=out_file, feature_names=feature_names, filled=filled,
leaves_parallel=leaves_parallel, rotate=rotate, rounded=rounded,
special_characters=special_characters, precision=precision)
exporter.export(self.tree_model)
exporter.export(self.tree_model, node_dict=self.node_dict)

if return_string:
return out_file.getvalue()
Expand Down Expand Up @@ -249,7 +250,7 @@ def plot(self, ax=None, title=None, feature_names=None,
check_is_fitted(self.tree_model, 'tree_')
exporter = self._make_mpl_exporter(title=title, feature_names=feature_names, filled=filled,
rounded=rounded, precision=precision, fontsize=fontsize)
exporter.export(self.tree_model, ax=ax)
exporter.export(self.tree_model, node_dict=self.node_dict, ax=ax)


class SingleTreeCateInterpreter(_SingleTreeInterpreter):
Expand All @@ -261,7 +262,7 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
include_uncertainty : bool, optional, default False
Whether to include confidence interval information when building a
simplified model of the cate model. If set to True, then
cate estimator needs to support the `effect_interval` method.
cate estimator needs to support the `const_marginal_ate_inference` method.

uncertainty_level : double, optional, default .05
The uncertainty level for the confidence intervals to be constructed
Expand All @@ -270,6 +271,11 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
in a leaf have similar target prediction but also similar alpha
confidence intervals.

uncertainty_only_on_leaves : bool, optional, default True
Whether uncertainty information should be displayed only on leaf nodes.
If False, then interpretation can be slightly slower, especially for cate
models that have a computationally expensive inference method.

splitter : string, optional, default "best"
The strategy used to choose the split at each node. Supported
strategies are "best" to choose the best split and "random" to choose
Expand Down Expand Up @@ -335,6 +341,7 @@ class SingleTreeCateInterpreter(_SingleTreeInterpreter):
def __init__(self,
include_model_uncertainty=False,
uncertainty_level=.1,
uncertainty_only_on_leaves=True,
splitter="best",
max_depth=None,
min_samples_split=2,
Expand All @@ -346,6 +353,7 @@ def __init__(self,
min_impurity_decrease=0.):
self.include_uncertainty = include_model_uncertainty
self.uncertainty_level = uncertainty_level
self.uncertainty_only_on_leaves = uncertainty_only_on_leaves
self.criterion = "mse"
self.splitter = splitter
self.max_depth = max_depth
Expand All @@ -370,20 +378,23 @@ def interpret(self, cate_estimator, X):
min_impurity_decrease=self.min_impurity_decrease)
y_pred = cate_estimator.const_marginal_effect(X)

assert all(d == 1 for d in y_pred.shape[1:]), ("Interpretation is only available for "
"single-dimensional treatments and outcomes")

if y_pred.ndim != 2:
y_pred = y_pred.reshape(-1, 1)

if self.include_uncertainty:
y_lower, y_upper = cate_estimator.const_marginal_effect_interval(X, alpha=self.uncertainty_level)
if y_lower.ndim != 2:
y_lower = y_lower.reshape(-1, 1)
y_upper = y_upper.reshape(-1, 1)
y_pred = np.hstack([y_pred, y_lower, y_upper])
self.tree_model.fit(X, y_pred)

self.tree_model.fit(X, y_pred.reshape((y_pred.shape[0], -1)))
paths = self.tree_model.decision_path(X)
node_dict = {}
for node_id in range(paths.shape[1]):
mask = paths.getcol(node_id).toarray().flatten().astype(bool)
Xsub = X[mask]
if (self.include_uncertainty and
((not self.uncertainty_only_on_leaves) or (self.tree_model.tree_.children_left[node_id] < 0))):
res = cate_estimator.const_marginal_ate_inference(Xsub)
node_dict[node_id] = {'mean': res.mean_point,
'std': res.std_point,
'ci': res.conf_int_mean(alpha=self.uncertainty_level)}
else:
cate_node = y_pred[mask]
node_dict[node_id] = {'mean': np.mean(cate_node, axis=0),
'std': np.std(cate_node, axis=0)}
self.node_dict = node_dict
return self

def _make_dot_exporter(self, *, out_file, feature_names, filled,
Expand Down
Loading