Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mehei/otherinferences #203

Merged
merged 28 commits into from
Feb 15, 2020
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
613947b
add other analytical inferences for const_marignal_effect
heimengqi Dec 3, 2019
1811f3b
add effect inference and population summary of inference
heimengqi Dec 17, 2019
20cf21b
add docstring
heimengqi Dec 17, 2019
4dfc5d2
Update setup.cfg to avoid Linux segfault issue
kbattocchi Dec 18, 2019
dfa0b3d
linting error
heimengqi Dec 18, 2019
4a6db8f
Merge branch 'mehei/otherinferences' of https://github.com/microsoft/…
heimengqi Dec 18, 2019
9be90a1
Panel() function has been deprecated, change to alternatives
heimengqi Dec 19, 2019
506aa91
fix debiased lasso prediction shape when y is a vector
heimengqi Dec 19, 2019
dec5424
improve population summary, update notebook and test
heimengqi Dec 26, 2019
3075c57
Merge branch 'master' into mehei/otherinferences
heimengqi Dec 26, 2019
01f5be5
linting error
heimengqi Dec 27, 2019
eee32da
Merge branch 'mehei/otherinferences' of https://github.com/microsoft/…
heimengqi Dec 27, 2019
e8da90d
syntax error
heimengqi Dec 27, 2019
7bd9bbb
update docstring
heimengqi Dec 27, 2019
6040702
support inferences for drlearner, update notebook and add test, chang…
heimengqi Jan 2, 2020
359eeba
Add coef__inference and intercept__inference, improvement on output, …
heimengqi Jan 7, 2020
6baa4b2
solve review comment
heimengqi Jan 9, 2020
ed81d96
Merge branch 'master' into mehei/otherinferences
heimengqi Jan 9, 2020
b7bb95a
add test to check whether the CI from inference class equals to CI fr…
heimengqi Jan 9, 2020
253c255
delete the test notebook
heimengqi Feb 5, 2020
40eba63
check fit_cate_intercept for drlearner and throw an error for effect_…
heimengqi Feb 5, 2020
3b22021
Merge branch 'mehei/otherinferences' of https://github.com/microsoft/…
heimengqi Feb 5, 2020
a794c12
Merge branch 'master' into mehei/otherinferences
heimengqi Feb 6, 2020
83552c9
fix linting error and drlearner inference support 2d response array
heimengqi Feb 6, 2020
c034308
add summary function for coef and intercept, fix shape inconsistence …
heimengqi Feb 11, 2020
dde15ce
solve conflict
heimengqi Feb 11, 2020
f3eb706
solve review comment
heimengqi Feb 14, 2020
101140c
Merge branch 'master' into mehei/otherinferences
heimengqi Feb 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9,395 changes: 9,395 additions & 0 deletions Other inferences test.ipynb

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,11 +543,21 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
return super().const_marginal_effect_interval(X, alpha=alpha)
const_marginal_effect_interval.__doc__ = LinearCateEstimator.const_marginal_effect_interval.__doc__

def const_marginal_effect_inference(self, X=None):
self._check_fitted_dims(X)
return super().const_marginal_effect_inference(X)
const_marginal_effect_inference.__doc__ = LinearCateEstimator.const_marginal_effect_inference.__doc__

def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
self._check_fitted_dims(X)
return super().effect_interval(X, T0=T0, T1=T1, alpha=alpha)
effect_interval.__doc__ = LinearCateEstimator.effect_interval.__doc__

def effect_inference(self, X=None, *, T0=0, T1=1):
self._check_fitted_dims(X)
return super().effect_inference(X, T0=T0, T1=T1)
effect_inference.__doc__ = LinearCateEstimator.effect_inference.__doc__

def score(self, Y, T, X=None, W=None, Z=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
Expand Down
139 changes: 138 additions & 1 deletion econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete
LinearModelFinalInferenceDiscrete, InferenceResults


class BaseCateEstimator(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -197,6 +197,30 @@ def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):
"""
pass

@_defer_to_inference
def effect_inference(self, X=None, *, T0=0, T1=1):
""" Inference results for the quantities :math:`\\tau(X, T0, T1)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix
Features for each sample
T0: optional (m, d_t) matrix or vector of length m (Default=0)
Base treatments for each sample
T1: optional (m, d_t) matrix or vector of length m (Default=1)
Target treatments for each sample

Returns
-------
InferenceResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass

@_defer_to_inference
def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\partial \\tau(T, X)` produced
Expand All @@ -221,6 +245,28 @@ def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
"""
pass

@_defer_to_inference
def marginal_effect_inference(self, T, X=None):
""" Inference results for the quantities :math:`\\partial \\tau(T, X)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
T: (m, d_t) matrix
Base treatments for each sample
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample

Returns
-------
InferenceResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass


class LinearCateEstimator(BaseCateEstimator):
"""Base class for all CATE estimators with linear treatment effects in this package."""
Expand Down Expand Up @@ -324,6 +370,18 @@ def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
for eff in effs)
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__

def marginal_effect_inference(self, T, X=None):
X, T = self._expand_treatments(X, T)
cme_inf = self.const_marginal_effect_inference(X=X)
pred = cme_inf.point_estimate
pred_stderr = cme_inf.stderr
if X is None:
pred = np.repeat(pred, shape(T)[0], axis=0)
pred_stderr = np.repeat(pred_stderr, shape(T)[0], axis=0)
return InferenceResults(d_t=cme_inf.d_t, d_y=cme_inf.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect', pred_dist=None, fn_transformer=None)
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__

@BaseCateEstimator._defer_to_inference
def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
""" Confidence intervals for the quantities :math:`\\theta(X)` produced
Expand All @@ -346,6 +404,26 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def const_marginal_effect_inference(self, X=None):
""" Inference results for the quantities :math:`\\theta(X)` produced
by the model. Available only when ``inference`` is not ``None``, when
calling the fit method.

Parameters
----------
X: optional (m, d_x) matrix or None (Default=None)
Features for each sample

Returns
-------
InferenceResults: object
The inference results instance contains prediction and prediction standard error and
can on demand calculate confidence interval, z statistic and p value. It can also output
a dataframe summary of these inference results.
"""
pass


class TreatmentExpansionMixin(BaseCateEstimator):
"""Mixin which automatically handles promotions of scalar treatments to the appropriate shape."""
Expand Down Expand Up @@ -454,6 +532,18 @@ def coef__interval(self, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def coef__inference(self):
""" The inference of coefficients in the linear model of the constant marginal treatment
effect.

Returns
-------
InferenceResults: object
The inference of the coefficients in the final linear model
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__interval(self, *, alpha=0.1):
""" The intercept in the linear model of the constant marginal treatment
Expand All @@ -472,6 +562,18 @@ def intercept__interval(self, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__inference(self):
""" The inference of intercept in the linear model of the constant marginal treatment
effect.

Returns
-------
InferenceResults: object
The inference of the intercept in the final linear model
"""
pass


class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
"""
Expand Down Expand Up @@ -568,6 +670,23 @@ def coef__interval(self, T, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def coef__inference(self, T):
""" The inference for the coefficients in the linear model of the
constant marginal treatment effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the coefficients.

Returns
-------
InferenceResults: object
The inference of the coefficients in the final linear model
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__interval(self, T, *, alpha=0.1):
""" The intercept in the linear model of the constant marginal treatment
Expand All @@ -588,6 +707,24 @@ def intercept__interval(self, T, *, alpha=0.1):
"""
pass

@BaseCateEstimator._defer_to_inference
def intercept__inference(self, T):
""" The inference of the intercept in the linear model of the constant marginal treatment
effect associated with treatment T.

Parameters
----------
T: alphanumeric
The input treatment for which we want the coefficients.

Returns
-------
InferenceResults: object
The inference of the intercept in the final linear model

"""
pass


class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscreteMixin):
"""
Expand Down
Loading