Enforce `pred_var` is always greater than zero on GRF #480
Looks great! Weird that the Bayesian debiasing was not ensuring that, but maybe here you were getting exact zeros, which would also be problematic.
I was surprised also when I was looking through the
Thanks for this contribution! Looks good, although I added one minor suggestion.
Also, would it be possible to add a simple test where the current code is failing but this change works so that we can make sure not to regress in the future?
The new changes look good, thanks for contributing! I'm approving the PR because it looks good code-wise, but before we can merge it there are two issues:
- A minor line-too-long linting problem
- We have a real test failure when running the `notebooks/Generalized Random Forests.ipynb` notebook, where the assertion is being triggered. Could you check if this is just a case where we should be using a slightly looser tolerance, or if we're really getting big negative values there for some reason?
(there are also a couple of other random test failures that I suspect are sporadic and could be fixed by just rerunning)
econml/grf/_base_grf.py
Outdated
@@ -793,10 +793,13 @@ def predict_full(self, X, interval=False, alpha=0.05):
        """
        if interval:
            point, pred_var = self._predict_point_and_var(X, full=True, point=True, var=True)
            assert np.isclose(pred_var[pred_var < 0], 0, atol=1e-8).all(), '`pred_var` should not produce large negative values'
Unfortunately this line is failing our linting step because it's too long, so you'll need to break it up over two lines instead before we can merge.
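One way the long assertion could be wrapped to satisfy a line-length limit is sketched below. The function name `check_pred_var` and its `atol` parameter are hypothetical and used only for illustration; in the PR the check lives inline in `predict_full`.

```python
import numpy as np


def check_pred_var(pred_var, atol=1e-8):
    """Hypothetical helper: fail if any negative variance is not ~0."""
    # Select only the negative entries; an empty selection passes trivially.
    negative = pred_var[pred_var < 0]
    # Splitting the condition and the message keeps each line short.
    assert np.isclose(negative, 0, atol=atol).all(), \
        '`pred_var` should not produce large negative values'


# Tiny negative values caused by floating-point error are tolerated.
check_pred_var(np.array([0.5, -1e-12, 0.0]))
```

A clearly negative entry such as `-1e-3` would trigger the `AssertionError` instead.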
Some of the failed tests appear to be caused by an import/
@arose13 The transient test failures are gone; the remaining notebook failures are due to triggering the assert within
This is the list of negative numbers that the GRF is generating in cell 5 of the
Let me know what you think.
I'm confused: the cell below plots the confidence interval and seems to have no problem. Is this triggered by some change here?
LGTM
Ensure that `pred_var` is always greater than zero. This prevents NaNs from being created for some output values when creating the confidence interval.
The NaNs were previously being created when the variance was converted to the standard deviation for scipy's distribution models, since the square root of a negative variance is NaN.
PS: I also removed duplicated code that used to appear on lines 798 and 799.
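The NaN behaviour described above can be reproduced with NumPy alone. This is a minimal sketch with made-up variance values, and clipping with `np.clip` is an assumed way of enforcing non-negativity, not necessarily the exact change made in this PR.

```python
import numpy as np

# Made-up variances: the middle entry is slightly negative,
# as can happen through floating-point error.
pred_var = np.array([0.04, -1e-12, 0.0])

# Taking the square root directly yields NaN for the negative entry.
with np.errstate(invalid="ignore"):
    sd_raw = np.sqrt(pred_var)

# Assumed fix: clip negative variances to zero before the square root.
pred_var_clipped = np.clip(pred_var, 0, None)
sd = np.sqrt(pred_var_clipped)

print(np.isnan(sd_raw))  # only the negative entry is NaN
print(np.isnan(sd))      # no NaNs after clipping
```

Clipping removes the NaNs, but it leaves standard deviations of exactly zero, which the first review comment notes can also be problematic downstream.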