Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vasilis/autoinference #307

Merged
merged 38 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
08d5895
changed subsampledhonest forest to not alter the entries of each tree…
vasilismsr Nov 7, 2020
4f1fe19
replaced copy with empty_like
vasilismsr Nov 7, 2020
e4aac6a
added feature importances in dr learner example notebook
vasilismsr Nov 7, 2020
7744197
added feature_importances_ to DML example notebook
vasilismsr Nov 7, 2020
ebae662
enabled feature_importances_ for forestDML and forestDRLearner as an …
vasilismsr Nov 7, 2020
7128230
fixed doctest in subsample honest forest which was producing old feat…
vasilismsr Nov 7, 2020
cdf0cf9
fixed missing .shape in new test_dml
vasilismsr Nov 7, 2020
297a2dc
Merge branch 'master' into vasilis/feature_importances
vsyrgkanis Nov 8, 2020
7567f57
Update econml/sklearn_extensions/ensemble.py
vsyrgkanis Nov 8, 2020
c9a0103
changed order of mixins
vasilismsr Nov 8, 2020
baa27c2
Merge branch 'vasilis/feature_importances' of github.com:microsoft/Ec…
vasilismsr Nov 8, 2020
5c5d262
changed feature_importances_ implementation as normalization at the t…
vasilismsr Nov 8, 2020
5860a37
fixed docstring reference
vasilismsr Nov 8, 2020
10186d2
fixed the problem with inconsistent impurities in subsampled honest f…
vasilismsr Nov 8, 2020
5e3dd95
fixed docstring doctest
vasilismsr Nov 8, 2020
92851a8
Transformed sparse matrices to dense matrices after dot product in pa…
vasilismsr Nov 9, 2020
80ad7f4
turned inference to 'auto' for many estimators. Enabled arbitrary lin…
vasilismsr Nov 9, 2020
3f71a7a
Added StatsModelsRLM and added tests for testing the auto and the sta…
vasilismsr Nov 9, 2020
23d26fe
handle column y vector in RLM
vasilismsr Nov 9, 2020
ed17a31
Merge branch 'master' into vasilis/autoinference
vsyrgkanis Nov 9, 2020
bcee64b
fixed docstrings of RLM to allow for column y
vasilismsr Nov 9, 2020
3b96e45
linting
vasilismsr Nov 9, 2020
958ad1f
added explanatory text to summary(). Changed intercept to cate_interc…
vasilismsr Nov 11, 2020
17b36b5
fixed some wording in summary text
vasilismsr Nov 11, 2020
9e616eb
fixed some wording in summary text
vasilismsr Nov 11, 2020
f73dedb
fixed some wording in summary text
vasilismsr Nov 11, 2020
eaae12b
fixed some wording in summary text
vasilismsr Nov 11, 2020
741a5bf
fixed some wording in summary text
vasilismsr Nov 11, 2020
3528fed
linting
vasilismsr Nov 11, 2020
dbe6c78
Merge branch 'master' into vasilis/autoinference
vsyrgkanis Nov 11, 2020
04bef7c
updated test_inference to check for cate_intercept key not intercept …
vasilismsr Nov 11, 2020
6cbc0fd
linting
vasilismsr Nov 11, 2020
6a13b3f
changed test_inference to check for correct zstat in degenerate cases…
vasilismsr Nov 11, 2020
d21e86f
Update econml/dml.py
vsyrgkanis Nov 11, 2020
1c6925a
Update econml/dml.py
vsyrgkanis Nov 11, 2020
74b8bbf
Update econml/dml.py
vsyrgkanis Nov 11, 2020
f98799e
dml groups
vasilismsr Nov 11, 2020
1b45f51
Merge branch 'vasilis/autoinference' of github.com:microsoft/EconML i…
vasilismsr Nov 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,11 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
than as a separate ``intercept_``
"""

def _get_inference_options(self):
    """Return the parent's inference options with ``auto`` set to `LinearModelFinalInference`."""
    # Start from whatever the next class in the MRO already supports.
    inference_options = super()._get_inference_options()
    # Register (or override) the method used when ``inference='auto'`` is requested.
    inference_options.update(auto=LinearModelFinalInference)
    return inference_options

bias_part_of_coef = False

@property
Expand Down Expand Up @@ -609,6 +614,14 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
converted to various output formats.
"""
smry = Summary()
# Explanatory footer for the summary. Every LaTeX backslash is written as an
# escaped "\\" so that no string relies on an unrecognized escape sequence
# (the original "\cdot" was one); the emitted text is unchanged.
# The last four literals are adjacent strings that concatenate into a single
# paragraph, so the list holds five entries in total.
smry.add_extra_txt(["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:",
                    "$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$",
                    "where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:",
                    "$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$",
                    "where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. "
                    "Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
                    "each outcome $i$ and treatment $j$. "
                    "Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"])
d_t = self._d_t[0] if self._d_t else 1
d_y = self._d_y[0] if self._d_y else 1
try:
Expand All @@ -630,10 +643,10 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
in intercept_table.columns] if d_t > 1 else intercept_table.columns.tolist()
intercept_stubs = [i + ' | ' + j for (i, j)
in intercept_table.index] if d_y > 1 else intercept_table.index.tolist()
intercept_title = 'Intercept Results'
intercept_title = 'CATE Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
print("Intercept Results: ", str(e))
print("CATE Intercept Results: ", str(e))
if len(smry.tables) > 0:
return smry

Expand All @@ -652,6 +665,7 @@ def _get_inference_options(self):
# add statsmodels to parent's options
options = super()._get_inference_options()
options.update(statsmodels=StatsModelsInference)
options.update(auto=StatsModelsInference)
return options


Expand All @@ -662,6 +676,7 @@ def _get_inference_options(self):
# add debiasedlasso to parent's options
options = super()._get_inference_options()
options.update(debiasedlasso=LinearModelFinalInference)
options.update(auto=LinearModelFinalInference)
kbattocchi marked this conversation as resolved.
Show resolved Hide resolved
return options


Expand All @@ -671,6 +686,7 @@ def _get_inference_options(self):
# add blb to parent's options
options = super()._get_inference_options()
options.update(blb=GenericSingleTreatmentModelFinalInference)
options.update(auto=GenericSingleTreatmentModelFinalInference)
return options

@property
Expand All @@ -687,6 +703,11 @@ class LinearModelFinalCateEstimatorDiscreteMixin(BaseCateEstimator):
returning an array of the fitted models for each non-control treatment
"""

def _get_inference_options(self):
    """Return the parent's inference options with ``auto`` set to `LinearModelFinalInferenceDiscrete`."""
    # Start from whatever the next class in the MRO already supports.
    inference_options = super()._get_inference_options()
    # Register (or override) the method used when ``inference='auto'`` is requested.
    inference_options.update(auto=LinearModelFinalInferenceDiscrete)
    return inference_options

def coef_(self, T):
""" The coefficients in the linear model of the constant marginal treatment
effect associated with treatment T.
Expand Down Expand Up @@ -826,6 +847,15 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
converted to various output formats.
"""
smry = Summary()
smry.add_extra_txt(["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:",
"$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$",
"where $T$ is the one-hot-encoding of the discrete treatment and "
"for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:",
"$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$",
"where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. "
"Coefficient Results table portrays the $coef_{ij}$ parameter vector for "
"each outcome $i$ and the designated treatment $j$ passed to summary. "
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"])
try:
coef_table = self.coef__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=feat_name)
Expand All @@ -842,10 +872,10 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
intercept_array = intercept_table.values
intercept_headers = intercept_table.columns.tolist()
intercept_stubs = intercept_table.index.tolist()
intercept_title = 'Intercept Results'
intercept_title = 'CATE Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
print("Intercept Results: ", e)
print("CATE Intercept Results: ", e)

if len(smry.tables) > 0:
return smry
Expand All @@ -866,6 +896,7 @@ def _get_inference_options(self):
# add statsmodels to parent's options
options = super()._get_inference_options()
options.update(statsmodels=StatsModelsInferenceDiscrete)
options.update(auto=StatsModelsInferenceDiscrete)
return options


Expand All @@ -876,6 +907,7 @@ def _get_inference_options(self):
# add debiasedlasso to parent's options
options = super()._get_inference_options()
options.update(debiasedlasso=LinearModelFinalInferenceDiscrete)
options.update(auto=LinearModelFinalInferenceDiscrete)
kbattocchi marked this conversation as resolved.
Show resolved Hide resolved
return options


Expand All @@ -885,6 +917,7 @@ def _get_inference_options(self):
# add blb to parent's options
options = super()._get_inference_options()
options.update(blb=GenericModelFinalInferenceDiscrete)
options.update(auto=GenericModelFinalInferenceDiscrete)
return options

def feature_importances_(self, T):
Expand Down
34 changes: 31 additions & 3 deletions econml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,34 @@ def __init__(self,
n_splits=n_splits,
random_state=random_state)

# override only so that we can update the docstring to indicate support for `LinearModelFinalInference`
def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference='auto'):
    """
    Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).

    Parameters
    ----------
    Y: (n × d_y) matrix or vector of length n
        Outcomes for each sample
    T: (n × dₜ) matrix or vector of length n
        Treatments for each sample
    X: optional (n × dₓ) matrix
        Features for each sample
    W: optional (n × d_w) matrix
        Controls for each sample
    sample_weight: optional (n,) vector
        Weights for each row
    sample_var: optional
        Forwarded unchanged to the parent class's ``fit``; see that method
        for the expected shape and meaning.
    inference: string, :class:`.Inference` instance, or None
        Method for performing inference. This estimator supports 'bootstrap'
        (or an instance of :class:`.BootstrapInference`) and 'auto'
        (or an instance of :class:`.LinearModelFinalInference`)

    Returns
    -------
    self
    """
    # Pure pass-through: this override adds no logic beyond the 'auto' default
    # for `inference` and the docstring above.
    return super().fit(Y, T, X=X, W=W, sample_weight=sample_weight, sample_var=sample_var, inference=inference)
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved


class LinearDMLCateEstimator(StatsModelsCateEstimatorMixin, DMLCateEstimator):
"""
Expand Down Expand Up @@ -527,7 +555,7 @@ def __init__(self,
random_state=random_state)

# override only so that we can update the docstring to indicate support for `StatsModelsInference`
def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference=None):
def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).

Expand Down Expand Up @@ -665,7 +693,7 @@ def __init__(self,
n_splits=n_splits,
random_state=random_state)

def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference=None):
def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).

Expand Down Expand Up @@ -1059,7 +1087,7 @@ def __init__(self,
categories=categories,
n_splits=n_crossfit_splits, random_state=random_state)

def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference=None):
def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).

Expand Down
Loading