Cannot obtain inference for DMLCateEstimator #308

Open
yqc924 opened this issue Nov 9, 2020 · 1 comment

yqc924 commented Nov 9, 2020

Hi,

I was using DMLCateEstimator (X=None, assuming homogeneous treatment effect), but couldn't obtain the inference for the estimator with const_marginal_effect_inference(). Could anyone help me figure out the problem? Thanks a lot!

1. With the default n_splits, est1.const_marginal_effect() and est1.const_marginal_effect_interval() work, but est1.const_marginal_effect_inference() raises an error:

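from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import ElasticNetCV
from econml.dml import DMLCateEstimator  # import path assumed; the report omits its imports

# Y, T1 and X1 are the outcome, treatment and control arrays (not shown in the report)
est1 = DMLCateEstimator(model_y=GradientBoostingRegressor(n_estimators=500),  # n_estimators=30, min_samples_leaf=30
                        model_t=GradientBoostingRegressor(n_estimators=500),  # n_estimators=30, min_samples_leaf=30
                        model_final=ElasticNetCV(max_iter=500))
est1.fit(Y=Y, T=T1, W=X1, inference='bootstrap')
print(est1.const_marginal_effect())
print(est1.const_marginal_effect_interval())
est1.const_marginal_effect_inference()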
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-29-b224b377fa84> in <module>
----> 1 est1.const_marginal_effect_inference()

~/anaconda3/lib/python3.7/site-packages/econml/_ortho_learner.py in const_marginal_effect_inference(self, X)
    593     def const_marginal_effect_inference(self, X=None):
    594         self._check_fitted_dims(X)
--> 595         return super().const_marginal_effect_inference(X)
    596     const_marginal_effect_inference.__doc__ = LinearCateEstimator.const_marginal_effect_inference.__doc__
    597 

~/anaconda3/lib/python3.7/site-packages/econml/cate_estimator.py in call(self, *args, **kwargs)
    168             name = m.__name__
    169             if self._inference is not None:
--> 170                 return getattr(self._inference, name)(*args, **kwargs)
    171             else:
    172                 raise AttributeError("Can't call '%s' because 'inference' is None" % name)

~/anaconda3/lib/python3.7/site-packages/econml/inference.py in __getattr__(self, name)
     61             raise AttributeError()
     62 
---> 63         m = getattr(self._est, name)
     64 
     65         def wrapped(*args, alpha=0.1, **kwargs):

~/anaconda3/lib/python3.7/site-packages/econml/bootstrap.py in __getattr__(self, name)
    139                 return get_mean()
    140 
--> 141         raise (caught if caught else AttributeError(name))

AttributeError: const_marginal_effect_inference

2. With a custom splitter:
n_splits=[(np.arange(train_count), np.arange(train_count, train_count+val_count))]
A similar error occurs when I run const_marginal_effect_inference().
What's worse, the point estimate from est1.const_marginal_effect() is not the midpoint of const_marginal_effect_interval().

3. With a different custom splitter:
n_splits=StratifiedKFold(n_splits=10).split(X=np.zeros(len(df['stratifier'])), y=np.asarray(df['stratifier']))
est1.fit() fails with the error: can't pickle generator objects

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-16-45cc22f76006> in <module>
      8                      discrete_treatment=False, #default
      9                      random_state=None) #default
---> 10 est1.fit(Y=Y, T=T1, W=X1, inference='bootstrap')
     11 #n_splits=StratifiedKFold(n_splits=10).split(X=np.zeros(train_count+val_count), y=np.asarray(df['zipcode']))
     12 #n_splits=[(np.arange(train_count),np.arange(train_count, train_count+val_count))]

~/anaconda3/lib/python3.7/site-packages/econml/_rlearner.py in fit(self, Y, T, X, W, sample_weight, sample_var, inference)
    294         """
    295         # Replacing fit from _OrthoLearner, to enforce Z=None and improve the docstring
--> 296         return super().fit(Y, T, X=X, W=W, sample_weight=sample_weight, sample_var=sample_var, inference=inference)
    297 
    298     def score(self, Y, T, X=None, W=None):

~/anaconda3/lib/python3.7/site-packages/econml/cate_estimator.py in call(self, Y, T, inference, *args, **kwargs)
     90             if inference is not None:
     91                 # NOTE: we call inference fit *after* calling the main fit method
---> 92                 inference.fit(self, Y, T, *args, **kwargs)
     93             self._inference = inference
     94             return self

~/anaconda3/lib/python3.7/site-packages/econml/inference.py in fit(self, estimator, *args, **kwargs)
     53 
     54     def fit(self, estimator, *args, **kwargs):
---> 55         est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False)
     56         est.fit(*args, **kwargs)
     57         self._est = est

~/anaconda3/lib/python3.7/site-packages/econml/bootstrap.py in __init__(self, wrapped, n_bootstrap_samples, n_jobs, compute_means, prefer_wrapped)
     48 
     49     def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False):
---> 50         self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
     51         self._n_bootstrap_samples = n_bootstrap_samples
     52         self._n_jobs = n_jobs

~/anaconda3/lib/python3.7/site-packages/econml/bootstrap.py in <listcomp>(.0)
     48 
     49     def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False):
---> 50         self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
     51         self._n_bootstrap_samples = n_bootstrap_samples
     52         self._n_jobs = n_jobs

~/anaconda3/lib/python3.7/site-packages/sklearn/base.py in clone(estimator, safe)
     53     elif not hasattr(estimator, 'get_params') or isinstance(estimator, type):
     54         if not safe:
---> 55             return copy.deepcopy(estimator)
     56         else:
     57             raise TypeError("Cannot clone object '%s' (type %s): "

~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
    178                     y = x
    179                 else:
--> 180                     y = _reconstruct(x, memo, *rv)
    181 
    182     # If is its own copy, don't memoize.

~/anaconda3/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    278     if state is not None:
    279         if deep:
--> 280             state = deepcopy(state, memo)
    281         if hasattr(y, '__setstate__'):
    282             y.__setstate__(state)

~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
    148     copier = _deepcopy_dispatch.get(cls)
    149     if copier:
--> 150         y = copier(x, memo)
    151     else:
    152         try:

~/anaconda3/lib/python3.7/copy.py in _deepcopy_dict(x, memo, deepcopy)
    238     memo[id(x)] = y
    239     for key, value in x.items():
--> 240         y[deepcopy(key, memo)] = deepcopy(value, memo)
    241     return y
    242 d[dict] = _deepcopy_dict

~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
    167                     reductor = getattr(x, "__reduce_ex__", None)
    168                     if reductor:
--> 169                         rv = reductor(4)
    170                     else:
    171                         reductor = getattr(x, "__reduce__", None)

TypeError: can't pickle generator objects

After that, est1.const_marginal_effect() works, but est1.const_marginal_effect_interval() and est1.const_marginal_effect_inference() do not:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-17-9816c0c44639> in <module>
      6 
      7 #shape(obs, 2): obtain 95% CI of marginal treatment effect for each obs
----> 8 print(np.shape(est1.const_marginal_effect_interval(alpha=.05)))
      9 print(est1.const_marginal_effect_interval(alpha=.05))
     10 

~/anaconda3/lib/python3.7/site-packages/econml/_ortho_learner.py in const_marginal_effect_interval(self, X, alpha)
    588     def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
    589         self._check_fitted_dims(X)
--> 590         return super().const_marginal_effect_interval(X, alpha=alpha)
    591     const_marginal_effect_interval.__doc__ = LinearCateEstimator.const_marginal_effect_interval.__doc__
    592 

~/anaconda3/lib/python3.7/site-packages/econml/cate_estimator.py in call(self, *args, **kwargs)
    167         def call(self, *args, **kwargs):
    168             name = m.__name__
--> 169             if self._inference is not None:
    170                 return getattr(self._inference, name)(*args, **kwargs)
    171             else:

AttributeError: 'DMLCateEstimator' object has no attribute '_inference'
kbattocchi (Collaborator) commented

Unfortunately, bootstrap inference doesn't currently support the methods whose names end in _inference, such as const_marginal_effect_inference() (this has already been addressed in this repo by #236 and #299, but those fixes haven't made it into any release yet). The fact that the intervals aren't centered on the point estimate is not necessarily surprising, depending on the nature of the data: the distribution of bootstrap estimates could be highly non-symmetric.
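
The point about non-symmetric bootstrap distributions can be illustrated with a minimal sketch. It does not use econml at all: the skewed sample, the seed, the number of resamples and the choice of the mean as the estimator are all assumptions made for illustration. It only shows that a percentile bootstrap interval need not have the point estimate as its midpoint.

```python
import random
import statistics

rng = random.Random(0)

# Hypothetical right-skewed sample (log-normal draws), standing in for real data.
sample = [rng.lognormvariate(0.0, 1.5) for _ in range(40)]
point_estimate = statistics.mean(sample)

# Percentile bootstrap: resample with replacement and recompute the estimate each time.
n_boot = 2000
boot_means = sorted(
    statistics.mean(rng.choices(sample, k=len(sample))) for _ in range(n_boot)
)

# 90% interval (alpha = 0.1): 5th and 95th percentiles of the bootstrap estimates.
lower = boot_means[n_boot * 5 // 100]
upper = boot_means[n_boot * 95 // 100 - 1]
midpoint = (lower + upper) / 2

print("point estimate:", point_estimate)
print("interval:", (lower, upper))
print("interval midpoint:", midpoint)
```

With skewed data the upper tail of the bootstrap estimates is typically longer than the lower tail, so the midpoint of the interval drifts away from the point estimate.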

Unless there's a specific reason you need an elastic net for your final stage, you might be better served by using LinearDMLCateEstimator with statsmodels inference (which uses an unregularized final model) or SparseLinearDMLCateEstimator with debiasedLasso inference (which uses a cross-validated debiased lasso final model).

I'll take a closer look at your 3rd point because that looks different from issues that I've seen before. Thanks for the report.
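
One plausible reading of the third traceback, offered as an assumption rather than a confirmed diagnosis: a splitter's split() method returns a generator, the bootstrap wrapper deep-copies the estimator (the clone(wrapped, safe=False) call in the traceback), and generators cannot be deep-copied. The sketch below reproduces that failure with the standard library only. toy_splits is a hypothetical stand-in for a real splitter, and the dict is a stand-in for the estimator's parameters. Materializing the folds with list(...) before passing them as n_splits may avoid the error, given that the list of (train, test) tuples in point 2 got past fit().

```python
import copy


def toy_splits(n_samples, n_folds):
    """Hypothetical stand-in for a splitter's split(): lazily yields (train, test) index lists."""
    indices = list(range(n_samples))
    for fold in range(n_folds):
        test = indices[fold::n_folds]
        held_out = set(test)
        train = [i for i in indices if i not in held_out]
        yield train, test


# A generator stored inside another object makes that object impossible to deep-copy.
lazy_splits = toy_splits(n_samples=6, n_folds=3)
try:
    copy.deepcopy({"n_splits": lazy_splits})
    deepcopy_failed = False
except TypeError:
    deepcopy_failed = True

# Materializing the folds into a list first makes the same structure copyable.
materialized = list(toy_splits(n_samples=6, n_folds=3))
params_copy = copy.deepcopy({"n_splits": materialized})

print("deepcopy of generator failed:", deepcopy_failed)
print("number of materialized folds:", len(params_copy["n_splits"]))
```

A list also has the advantage that it can be iterated more than once, whereas a generator is exhausted after the first pass.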
