Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow registering of custom exception handlers for potential_fn computations #3168

Merged
merged 8 commits into from
Jan 3, 2023

Conversation

Balandat
Copy link
Contributor

In some cases evaluation the potential funciton may result in numerical issues. Currently the code hard-codes the handling of a RuntimeError raised when matrices are (numerically) singular. This PR adds the ability to register custom exception handlers. This allows other code depending on pyro to register custom exception handlers without having to modify core pyro code.

There are some other places in which potential_fn is called that could benefit from being guarded by these handlers (one is HMC._find_reasonable_step_size). I'm not sure what the right thing to do there is when encountering numerical isssues, but happy to add this in as needed.

…putations

In some cases evaluation the potential funciton may result in numerical issues. Currently the code hard-codes the handling of a RuntimeError raised when matrices are (numerically) singular. This PR adds the ability to register custom exception handlers. This allows other code depending on pyro to register custom exception handlers without having to modify core pyro code.

There are some other places in which `potential_fn` is called that could benefit from being guarded by these handlers (one is `HMC._find_reasonable_step_size`). I'm not sure what the right thing to do there is when encountering numerical isssues, but happy to add this in as needed.
@Balandat
Copy link
Contributor Author

@fritzo is this (or something like this) something you'd be willing to merge in? This will be quite useful for folks who may use some more numerically challenging setups (e.g. GPs) that rely on implementations that throw some other errors.

cc @dme65, @saitcakmak

fehiepsi
fehiepsi previously approved these changes Dec 29, 2022
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me. Thanks, @Balandat!

The utility will be exposed in docs around this path: https://docs.pyro.ai/en/stable/ops.html#pyro.ops.integrator.potential_grad

@fehiepsi
Copy link
Member

Could you fix the lint issue?

@Balandat
Copy link
Contributor Author

Sure, will do.

Any thoughts how to handle numerical issues of this kind in the following:

potential_energy = self.potential_fn(z)

@fehiepsi
Copy link
Member

I think you can do similarly

try:
    potential_energy = self.potential_fn(z)
except ...:
    # skip finding reasonable step size
    return step_size

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems reasonable. Since we may also want to use this or something like it in NumPyro, could we discuss the design advantages of the proposed solution over alternative solutions?

1. A higher-level solution: In case of error, MH reject the trajectory and sample a different velocity. I'm unsure this makes sense, but it also seems unclear whether error-as-zero-grad makes sense 🤷

2. A lower-level solution: Wrap the potential function outside of MCMC to silence exceptions:

def safe_potential(fn: Callable[[dict], torch.Tensor]) -> Callable[[dict], torch.Tensor]:
    def safe_fn(z: dict) -> torch.Tensor:
        try:
            return fn(z)
        except RuntimeError as e:
            if "singular" in str(e) or "input is not positive-definite" in str(e):
                return next(iter(z.values())).new_tensor(math.nan)
            raise e
    return safe_fn

which I guess would also require a change to potential_grad()

- grads = grad(potential_energy, z_nodes)
+ grads = grad(potential_energy, z_nodes, allow_unused=True)

pyro/ops/integrator.py Outdated Show resolved Hide resolved
@fehiepsi
Copy link
Member

Currently, when errors happen, we set potential_energy to nan, which will lead to the rejection of the proposal. grad can be zeros, ones, or empty - IIUC it won't affect the logic. In NumPyro, jax won't throw errors for numerical computations - rather than that, it throws some invalid outputs like nan or inf, so probably we don't need those exceptions.

@Balandat Balandat changed the title [RFC] Allow registering exception handlers forfunction_fn computations [RFC] Allow registering exception handlers for potential_fn computations Dec 29, 2022
@Balandat
Copy link
Contributor Author

This seems reasonable. Since we may also want to use this or something like it in NumPyro, could we discuss the design advantages of the proposed solution over alternative solutions?

I'm not an expert on what the proper thing to do algorithmically is. Ideally the handling of this can happen at a higher level as part of the algorithm. If the proposal is rejected by setting things to nan and this is the right thing to do, then great.

The downside of the lower-level solution you proposed is that it will require the user to understand the details of the algorithm to implement this kind of wrapper, rather than "I want to handle these kinds of exceptions and have the algorithm deal with properly rejecting the proposal".

@Balandat Balandat changed the title [RFC] Allow registering exception handlers for potential_fn computations Allow registering of custom exception handlers for potential_fn computations Dec 29, 2022
@fritzo
Copy link
Member

fritzo commented Jan 1, 2023

@Balandat sorry linting seems to be a hassle. We usually use the make targets make format and make lint to do all the linting and formatting automatically.

@Balandat
Copy link
Contributor Author

Balandat commented Jan 3, 2023

@Balandat sorry linting seems to be a hassle. We usually use the make targets make format and make lint to do all the linting and formatting automatically.

No worries - I should have run make format. Note though that make lint currently spits out hundreds if not thousands of errors in files unrelated to this PR.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@fritzo fritzo merged commit 1ec2c39 into pyro-ppl:dev Jan 3, 2023
@Balandat Balandat deleted the exception_handlers branch January 3, 2023 15:48
Balandat added a commit to Balandat/botorch that referenced this pull request Jan 3, 2023
This includes two major changes of interest to botorch:
1. Support for pytorch 2.0 (allowing us to not require pyro dev branch when testing against pytorch nightlies): pyro-ppl/pyro#3164
2. The ability to registering of custom exception handlers for numerical issues: pyro-ppl/pyro#3168
Balandat added a commit to Balandat/botorch that referenced this pull request Jan 4, 2023
Summary:
This includes two major changes of interest to botorch:
1. Support for pytorch 2.0 (allowing us to not require pyro dev branch when testing against pytorch nightlies): pyro-ppl/pyro#3164
2. The ability to registering of custom exception handlers for numerical issues: pyro-ppl/pyro#3168

Pull Request resolved: pytorch#1606

Reviewed By: saitcakmak

Differential Revision: D42330914

Pulled By: Balandat

fbshipit-source-id: 440d9ff00df60f4ba88acc82e323f9e675caf254
Balandat added a commit to Balandat/botorch that referenced this pull request Jan 4, 2023
Summary: Uses draft changes from pyro-ppl/pyro#3168 (part of pyro 1.8.4 pulled in via D42331876) to register handling of `torch.linalg.LinAlgError` and the `ValueError` that can be raised in the torch distribution's `__init__()`

Differential Revision: D42159791

fbshipit-source-id: ef1f53c9471ba8b1d62ca5ad347e5a471d9bb7a1
facebook-github-bot pushed a commit to pytorch/botorch that referenced this pull request Jan 4, 2023
Summary:
This includes two major changes of interest to botorch:
1. Support for pytorch 2.0 (allowing us to not require pyro dev branch when testing against pytorch nightlies): pyro-ppl/pyro#3164
2. The ability to registering of custom exception handlers for numerical issues: pyro-ppl/pyro#3168

Pull Request resolved: #1606

Reviewed By: saitcakmak

Differential Revision: D42330914

Pulled By: Balandat

fbshipit-source-id: 1117e7713b99819b4f1297de2d49aeb1bd2b9c9b
facebook-github-bot pushed a commit to pytorch/botorch that referenced this pull request Jan 4, 2023
Summary:
Pull Request resolved: #1607

Uses draft changes from pyro-ppl/pyro#3168 (part of pyro 1.8.4 pulled in via D42331876) to register handling of `torch.linalg.LinAlgError` and the `ValueError` that can be raised in the torch distribution's `__init__()`

Reviewed By: saitcakmak

Differential Revision: D42159791

fbshipit-source-id: 3bbe2433b83bd114edd277e42f0017010ac9199f
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants