-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow registering of custom exception handlers for potential_fn computations #3168
Conversation
…putations In some cases evaluation the potential funciton may result in numerical issues. Currently the code hard-codes the handling of a RuntimeError raised when matrices are (numerically) singular. This PR adds the ability to register custom exception handlers. This allows other code depending on pyro to register custom exception handlers without having to modify core pyro code. There are some other places in which `potential_fn` is called that could benefit from being guarded by these handlers (one is `HMC._find_reasonable_step_size`). I'm not sure what the right thing to do there is when encountering numerical isssues, but happy to add this in as needed.
@fritzo is this (or something like this) something you'd be willing to merge in? This will be quite useful for folks who may use some more numerically challenging setups (e.g. GPs) that rely on implementations that throw some other errors. cc @dme65, @saitcakmak |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great to me. Thanks, @Balandat!
The utility will be exposed in docs around this path: https://docs.pyro.ai/en/stable/ops.html#pyro.ops.integrator.potential_grad
Could you fix the lint issue? |
Sure, will do. Any thoughts how to handle numerical issues of this kind in the following: Line 176 in 0b1818c
|
I think you can do similarly
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems reasonable. Since we may also want to use this or something like it in NumPyro, could we discuss the design advantages of the proposed solution over alternative solutions?
1. A higher-level solution: In case of error, MH reject the trajectory and sample a different velocity. I'm unsure this makes sense, but it also seems unclear whether error-as-zero-grad makes sense 🤷
2. A lower-level solution: Wrap the potential function outside of MCMC to silence exceptions:
def safe_potential(fn: Callable[[dict], torch.Tensor]) -> Callable[[dict], torch.Tensor]:
def safe_fn(z: dict) -> torch.Tensor:
try:
return fn(z)
except RuntimeError as e:
if "singular" in str(e) or "input is not positive-definite" in str(e):
return next(iter(z.values())).new_tensor(math.nan)
raise e
return safe_fn
which I guess would also require a change to potential_grad()
- grads = grad(potential_energy, z_nodes)
+ grads = grad(potential_energy, z_nodes, allow_unused=True)
Currently, when errors happen, we set |
I'm not an expert on what the proper thing to do algorithmically is. Ideally the handling of this can happen at a higher level as part of the algorithm. If the proposal is rejected by setting things to The downside of the lower-level solution you proposed is that it will require the user to understand the details of the algorithm to implement this kind of wrapper, rather than "I want to handle these kinds of exceptions and have the algorithm deal with properly rejecting the proposal". |
@Balandat sorry linting seems to be a hassle. We usually use the make targets |
No worries - I should have run |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This includes two major changes of interest to botorch: 1. Support for pytorch 2.0 (allowing us to not require pyro dev branch when testing against pytorch nightlies): pyro-ppl/pyro#3164 2. The ability to registering of custom exception handlers for numerical issues: pyro-ppl/pyro#3168
Summary: This includes two major changes of interest to botorch: 1. Support for pytorch 2.0 (allowing us to not require pyro dev branch when testing against pytorch nightlies): pyro-ppl/pyro#3164 2. The ability to registering of custom exception handlers for numerical issues: pyro-ppl/pyro#3168 Pull Request resolved: pytorch#1606 Reviewed By: saitcakmak Differential Revision: D42330914 Pulled By: Balandat fbshipit-source-id: 440d9ff00df60f4ba88acc82e323f9e675caf254
Summary: Uses draft changes from pyro-ppl/pyro#3168 (part of pyro 1.8.4 pulled in via D42331876) to register handling of `torch.linalg.LinAlgError` and the `ValueError` that can be raised in the torch distribution's `__init__()` Differential Revision: D42159791 fbshipit-source-id: ef1f53c9471ba8b1d62ca5ad347e5a471d9bb7a1
Summary: This includes two major changes of interest to botorch: 1. Support for pytorch 2.0 (allowing us to not require pyro dev branch when testing against pytorch nightlies): pyro-ppl/pyro#3164 2. The ability to registering of custom exception handlers for numerical issues: pyro-ppl/pyro#3168 Pull Request resolved: #1606 Reviewed By: saitcakmak Differential Revision: D42330914 Pulled By: Balandat fbshipit-source-id: 1117e7713b99819b4f1297de2d49aeb1bd2b9c9b
Summary: Pull Request resolved: #1607 Uses draft changes from pyro-ppl/pyro#3168 (part of pyro 1.8.4 pulled in via D42331876) to register handling of `torch.linalg.LinAlgError` and the `ValueError` that can be raised in the torch distribution's `__init__()` Reviewed By: saitcakmak Differential Revision: D42159791 fbshipit-source-id: 3bbe2433b83bd114edd277e42f0017010ac9199f
In some cases evaluation the potential funciton may result in numerical issues. Currently the code hard-codes the handling of a RuntimeError raised when matrices are (numerically) singular. This PR adds the ability to register custom exception handlers. This allows other code depending on pyro to register custom exception handlers without having to modify core pyro code.
There are some other places in which
potential_fn
is called that could benefit from being guarded by these handlers (one isHMC._find_reasonable_step_size
). I'm not sure what the right thing to do there is when encountering numerical isssues, but happy to add this in as needed.