Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Average Treatment Effect methods to all estimators #365

Merged
merged 14 commits into from
Jan 14, 2021
78 changes: 60 additions & 18 deletions econml/_tree_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def __init__(self, *args, title=None, **kwargs):
self.title = title
super().__init__(*args, **kwargs)

def export(self, decision_tree, ax=None):
def export(self, decision_tree, node_dict=None, ax=None):
if ax is None:
ax = plt.gca()
self.node_dict = node_dict
anns = super().export(decision_tree, ax=ax)
if self.title is not None:
ax.set_title(self.title)
Expand All @@ -65,6 +66,10 @@ def __init__(self, *args, title=None, **kwargs):
self.title = title
super().__init__(*args, **kwargs)

def export(self, decision_tree, node_dict=None):
    """
    Export the tree after recording the per-node information to render.

    Parameters
    ----------
    decision_tree: object
        The tree handed unchanged to the parent class's ``export``.
    node_dict: optional (Default=None)
        Stored on ``self.node_dict`` before delegating.
        NOTE(review): presumably a mapping from node id to that node's
        summary statistics -- confirm against the caller.

    Returns
    -------
    Whatever the parent class's ``export`` returns.
    """
    # Must be assigned before delegating so the value is available on the
    # instance while the parent export runs.
    self.node_dict = node_dict
    parent_export = super().export
    return parent_export(decision_tree)

def tail(self):
if self.title is not None:
self.out_file.write("labelloc=\"t\"; \n")
Expand Down Expand Up @@ -102,27 +107,64 @@ def get_fill_color(self, tree, node_id):
return self.get_color(value)

def node_replacement_text(self, tree, node_id, criterion):
if tree.n_outputs == 1:
value = tree.value[node_id][0, :]
else:
value = tree.value[node_id]

# Write node mean CATE
node_string = 'CATE mean = '
value_text = np.array2string(value[0, 0] if self.include_uncertainty else value[0], precision=self.precision)
node_string += value_text + self.characters[4]
node_info = self.node_dict[node_id]
node_string = 'CATE mean' + self.characters[4]
value_text = ""
mean = node_info['mean']
if hasattr(mean, 'shape') and (len(mean.shape) > 0):
if len(mean.shape) == 1:
for i in range(mean.shape[0]):
value_text += "{}".format(np.around(mean[i], self.precision))
if 'ci' in node_info:
value_text += " ({}, {})".format(np.around(node_info['ci'][0][i], self.precision),
np.around(node_info['ci'][1][i], self.precision))
if i != mean.shape[0] - 1:
value_text += ", "
value_text += self.characters[4]
elif len(mean.shape) == 2:
for i in range(mean.shape[0]):
for j in range(mean.shape[1]):
value_text += "{}".format(np.around(mean[i, j], self.precision))
if 'ci' in node_info:
value_text += " ({}, {})".format(np.around(node_info['ci'][0][i, j], self.precision),
np.around(node_info['ci'][1][i, j], self.precision))
if j != mean.shape[1] - 1:
value_text += ", "
value_text += self.characters[4]
else:
raise ValueError("can only handle up to 2d values")
else:
value_text += "{}".format(np.around(mean, self.precision))
if 'ci' in node_info:
value_text += " ({}, {})".format(np.around(node_info['ci'][0], self.precision),
np.around(node_info['ci'][1], self.precision)) + self.characters[4]
node_string += value_text

# Write node std of CATE
node_string += "CATE std = "
value_text = np.array2string(np.sqrt(np.clip(tree.impurity[node_id], 0, np.inf)), precision=self.precision)
node_string += value_text + self.characters[4]

# Write confidence interval information if at leaf node
if (tree.children_left[node_id] == _tree.TREE_LEAF) and self.include_uncertainty:
ci_text = "Mean Endpoints of {}% CI: ({}, {})".format(int((1 - self.uncertainty_level) * 100),
np.around(value[1, 0], self.precision),
np.around(value[2, 0], self.precision))
node_string += ci_text + self.characters[4]
node_string += "CATE std" + self.characters[4]
std = node_info['std']
value_text = ""
if hasattr(std, 'shape') and (len(std.shape) > 0):
if len(std.shape) == 1:
for i in range(std.shape[0]):
value_text += "{}".format(np.around(std[i], self.precision))
if i != std.shape[0] - 1:
value_text += ", "
elif len(std.shape) == 2:
for i in range(std.shape[0]):
for j in range(std.shape[1]):
value_text += "{}".format(np.around(std[i, j], self.precision))
if j != std.shape[1] - 1:
value_text += ", "
if i != std.shape[0] - 1:
value_text += self.characters[4]
else:
raise ValueError("can only handle up to 2d values")
else:
value_text += "{}".format(np.around(std, self.precision))
node_string += value_text

return node_string

Expand Down
228 changes: 228 additions & 0 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,54 @@ def marginal_effect(self, T, X=None):
"""
pass

def ate(self, X=None, *, T0, T1):
    """
    Calculate the average treatment effect :math:`E_X[\\tau(X, T0, T1)]`.

    The per-sample effect between the two treatment points is obtained from
    ``self.effect`` and then averaged over the samples (axis 0).

    Parameters
    ----------
    X: optional (m, d_x) matrix
        Features for each sample
    T0: (m, d_t) matrix or vector of length m
        Base treatments for each sample (keyword-only)
    T1: (m, d_t) matrix or vector of length m
        Target treatments for each sample (keyword-only)

    Returns
    -------
    τ: float or (d_y,) array
        Average treatment effects on each outcome
        Note that when Y is a vector rather than a 2-dimensional array, the result will be a scalar
    """
    per_sample_effects = self.effect(X=X, T0=T0, T1=T1)
    return np.mean(per_sample_effects, axis=0)

def marginal_ate(self, T, X=None):
    """
    Calculate the average marginal effect :math:`E_{T, X}[\\partial\\tau(T, X)]`.

    The per-sample marginal effect around the base treatment points is
    obtained from ``self.marginal_effect`` and then averaged over the
    samples (axis 0).

    Parameters
    ----------
    T: (m, d_t) matrix
        Base treatments for each sample
    X: optional (m, d_x) matrix
        Features for each sample

    Returns
    -------
    grad_tau: (d_y, d_t) array
        Average marginal effects on each outcome
        Note that when Y or T is a vector rather than a 2-dimensional array,
        the corresponding singleton dimensions in the output will be collapsed
        (e.g. if both are vectors, then the output of this method will be a scalar)
    """
    per_sample_gradients = self.marginal_effect(T, X=X)
    return np.mean(per_sample_gradients, axis=0)

def _expand_treatments(self, X=None, *Ts):
"""
Given a set of features and treatments, return possibly modified features and treatments.
Expand Down Expand Up @@ -304,6 +352,101 @@ def marginal_effect_inference(self, T, X=None):
"""
pass

@_defer_to_inference
def ate_interval(self, X=None, *, T0, T1, alpha=0.1):
    """ Confidence intervals for the quantity :math:`E_X[\\tau(X, T0, T1)]` produced
    by the model. Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    X: optional (m, d_x) matrix
        Features for each sample
    T0: (m, d_t) matrix or vector of length m
        Base treatments for each sample. Keyword-only; this signature defines
        no default, although an overriding signature may supply one.
    T1: (m, d_t) matrix or vector of length m
        Target treatments for each sample. Keyword-only; this signature defines
        no default, although an overriding signature may supply one.
    alpha: optional float in [0, 1] (Default=0.1)
        The overall level of confidence of the reported interval.
        The alpha/2, 1-alpha/2 confidence interval is reported.

    Returns
    -------
    lower, upper : tuple(type of :meth:`ate(X, T0, T1)<ate>`, type of :meth:`ate(X, T0, T1)<ate>` )
        The lower and the upper bounds of the confidence interval for each quantity.
    """
    # Intentionally empty body. NOTE(review): presumably the
    # `_defer_to_inference` decorator supplies the behavior -- confirm
    # against the decorator's definition.
    pass

@_defer_to_inference
def ate_inference(self, X=None, *, T0, T1):
    """ Inference results for the quantity :math:`E_X[\\tau(X, T0, T1)]` produced
    by the model. Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    X: optional (m, d_x) matrix
        Features for each sample
    T0: (m, d_t) matrix or vector of length m
        Base treatments for each sample. Keyword-only; this signature defines
        no default, although an overriding signature may supply one.
    T1: (m, d_t) matrix or vector of length m
        Target treatments for each sample. Keyword-only; this signature defines
        no default, although an overriding signature may supply one.

    Returns
    -------
    PopulationSummaryResults: object
        The inference results instance contains prediction and prediction standard error and
        can on demand calculate confidence interval, z statistic and p value. It can also output
        a dataframe summary of these inference results.
    """
    # Intentionally empty body. NOTE(review): presumably the
    # `_defer_to_inference` decorator supplies the behavior -- confirm
    # against the decorator's definition.
    pass

@_defer_to_inference
def marginal_ate_interval(self, T, X=None, *, alpha=0.1):
    """ Confidence intervals for the average marginal effect
    :math:`E_{T,X}[\\partial \\tau(T, X)]` produced by the model.
    Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    T: (m, d_t) matrix
        Base treatments for each sample
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample
    alpha: optional float in [0, 1] (Default=0.1)
        The overall level of confidence of the reported interval.
        The alpha/2, 1-alpha/2 confidence interval is reported.
        Keyword-only.

    Returns
    -------
    lower, upper : tuple(type of :meth:`marginal_ate(T, X)<marginal_ate>`, \
type of :meth:`marginal_ate(T, X)<marginal_ate>` )
        The lower and the upper bounds of the confidence interval for each quantity.
    """
    # Intentionally empty body. NOTE(review): presumably the
    # `_defer_to_inference` decorator supplies the behavior -- confirm
    # against the decorator's definition.
    pass

@_defer_to_inference
def marginal_ate_inference(self, T, X=None):
    """ Inference results for the average marginal effect
    :math:`E_{T,X}[\\partial \\tau(T, X)]` produced by the model.
    Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    T: (m, d_t) matrix
        Base treatments for each sample
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample

    Returns
    -------
    PopulationSummaryResults: object
        The inference results instance contains prediction and prediction standard error and
        can on demand calculate confidence interval, z statistic and p value. It can also output
        a dataframe summary of these inference results.
    """
    # Intentionally empty body. NOTE(review): presumably the
    # `_defer_to_inference` decorator supplies the behavior -- confirm
    # against the decorator's definition.
    pass


class LinearCateEstimator(BaseCateEstimator):
"""Base class for all CATE estimators with linear treatment effects in this package."""
Expand Down Expand Up @@ -458,6 +601,79 @@ def const_marginal_effect_inference(self, X=None):
"""
pass

def const_marginal_ate(self, X=None):
    """
    Calculate the average constant marginal CATE :math:`E_X[\\theta(X)]`.

    The per-sample constant marginal effect is obtained from
    ``self.const_marginal_effect`` and then averaged over the samples (axis 0).

    Parameters
    ----------
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample.

    Returns
    -------
    theta: (d_y, d_t) matrix
        Average constant marginal CATE of each treatment on each outcome.
        Note that when Y or T is a vector rather than a 2-dimensional array,
        the corresponding singleton dimensions in the output will be collapsed
        (e.g. if both are vectors, then the output of this method will be a scalar)
    """
    per_sample_theta = self.const_marginal_effect(X=X)
    return np.mean(per_sample_theta, axis=0)

@BaseCateEstimator._defer_to_inference
def const_marginal_ate_interval(self, X=None, *, alpha=0.1):
    """ Confidence intervals for the average constant marginal CATE
    :math:`E_X[\\theta(X)]` produced by the model.
    Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample
    alpha: optional float in [0, 1] (Default=0.1)
        The overall level of confidence of the reported interval.
        The alpha/2, 1-alpha/2 confidence interval is reported.
        Keyword-only.

    Returns
    -------
    lower, upper : tuple(type of :meth:`const_marginal_ate(X)<const_marginal_ate>` ,\
type of :meth:`const_marginal_ate(X)<const_marginal_ate>` )
        The lower and the upper bounds of the confidence interval for each quantity.
    """
    # Intentionally empty body. NOTE(review): presumably the
    # `_defer_to_inference` decorator supplies the behavior -- confirm
    # against the decorator's definition.
    pass

@BaseCateEstimator._defer_to_inference
def const_marginal_ate_inference(self, X=None):
    """ Inference results for the average constant marginal CATE
    :math:`E_X[\\theta(X)]` produced by the model.
    Available only when ``inference`` is not ``None``, when
    calling the fit method.

    Parameters
    ----------
    X: optional (m, d_x) matrix or None (Default=None)
        Features for each sample

    Returns
    -------
    PopulationSummaryResults: object
        The inference results instance contains prediction and prediction standard error and
        can on demand calculate confidence interval, z statistic and p value. It can also output
        a dataframe summary of these inference results.
    """
    # Intentionally empty body. NOTE(review): presumably the
    # `_defer_to_inference` decorator supplies the behavior -- confirm
    # against the decorator's definition.
    pass

def marginal_ate(self, T, X=None):
    # T is accepted to keep the inherited signature but is not forwarded:
    # the result comes entirely from const_marginal_ate evaluated at X.
    average_theta = self.const_marginal_ate(X=X)
    return average_theta
marginal_ate.__doc__ = BaseCateEstimator.marginal_ate.__doc__

def marginal_ate_interval(self, T, X=None, *, alpha=0.1):
    # T is accepted to keep the inherited signature but is not forwarded:
    # the interval comes entirely from const_marginal_ate_interval at X.
    bounds = self.const_marginal_ate_interval(X=X, alpha=alpha)
    return bounds
marginal_ate_interval.__doc__ = BaseCateEstimator.marginal_ate_interval.__doc__

def marginal_ate_inference(self, T, X=None):
    # T is accepted to keep the inherited signature but is not forwarded:
    # the result comes entirely from const_marginal_ate_inference at X.
    inference_result = self.const_marginal_ate_inference(X=X)
    return inference_result
marginal_ate_inference.__doc__ = BaseCateEstimator.marginal_ate_inference.__doc__

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
""" Shap value for the final stage models (const_marginal_effect)

Expand Down Expand Up @@ -520,6 +736,18 @@ def effect(self, X=None, *, T0=0, T1=1):
return super().effect(X, T0=T0, T1=T1)
effect.__doc__ = BaseCateEstimator.effect.__doc__

def ate(self, X=None, *, T0=0, T1=1):
    # This override only adds defaults (T0=0, T1=1) for the keyword-only
    # treatments; the computation itself is the parent implementation.
    parent_ate = super().ate
    return parent_ate(X=X, T0=T0, T1=T1)
ate.__doc__ = BaseCateEstimator.ate.__doc__

def ate_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
    # This override only adds defaults (T0=0, T1=1) for the keyword-only
    # treatments; the interval itself comes from the parent implementation.
    parent_ate_interval = super().ate_interval
    return parent_ate_interval(X=X, T0=T0, T1=T1, alpha=alpha)
ate_interval.__doc__ = BaseCateEstimator.ate_interval.__doc__

def ate_inference(self, X=None, *, T0=0, T1=1):
    # This override only adds defaults (T0=0, T1=1) for the keyword-only
    # treatments; the result itself comes from the parent implementation.
    parent_ate_inference = super().ate_inference
    return parent_ate_inference(X=X, T0=T0, T1=T1)
ate_inference.__doc__ = BaseCateEstimator.ate_inference.__doc__


class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
"""
Expand Down
Loading