Stabilize autoguide scale parameters via SoftplusTransform #2767
Conversation
Just curious, what makes you think this will be more stable? (is it numerical or theoretical?)
I don't understand why this PR would be more stable 🤷 I'm just trying to implement @vitkl's point 2 in pyro-ppl/numpyro#855 (comment), which he claimed was more stable. It's plausible I've misunderstood something, and also plausible that a better solution would be to use … @vitkl, can you confirm (1) this PR implements roughly what you're requesting, (2) that you've shown it is more numerically stable, and (3) that you tried simpler solutions like …?
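(A minimal illustration of the numerical argument, not from the thread itself: with an exp parameterization, the scale and its gradient both grow exponentially in the unconstrained parameter, whereas softplus is asymptotically linear, so a bad optimizer step cannot blow the scale up.)

```python
import torch
import torch.nn.functional as F

# Unconstrained values an optimizer might reach after a bad step.
u = torch.tensor([-50.0, 0.0, 50.0])

# exp parameterization: scale = exp(u) spans ~43 orders of magnitude here,
# and d(scale)/du = exp(u) explodes for large u.
print(torch.exp(u))    # tensor([1.9287e-22, 1.0000e+00, 5.1847e+21])

# softplus parameterization: scale = log1p(exp(u)) is ~u for large u,
# and its gradient sigmoid(u) stays bounded in (0, 1).
print(F.softplus(u))   # tensor([1.9287e-22, 6.9315e-01, 5.0000e+01])
```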
@fehiepsi do you think we should name these constraints `softplus_positive` and `softplus_lower_cholesky`?
@fritzo I don't have a better solution. Those `softplus_foo` names look good to me. In the TFP wrapper, because distributions there do not have constraints, I created a class named …
Thanks for this. I am planning to add a benchmark (to https://github.com/pyro-ppl/sandbox) that shows the impact of this modification on stability and accuracy for our model. To simplify the analysis, it would help to have a switch option, e.g.:

```python
class AutoNormal(..., use_softplus=False):
    ...
    if use_softplus:
        _deep_setattr(self.scales, name,
                      PyroParam(init_scale, constraints.stable_positive, event_dim))
    else:
        _deep_setattr(self.scales, name,
                      PyroParam(init_scale, constraints.positive, event_dim))
```
Hi @vitkl, I believe this should be ready to go.
Hmm, I'm hesitant to add interface complexity to our so-called AutoGuides, especially since I suspect your experiments will show … What if we instead expose this as a hackable class attribute:

```python
class AutoNormal(AutoGuide):
    scale_constraint = constraints.softplus_positive  # <--- hackable but not documented

    def __init__(self, model, *,
                 init_loc_fn=init_to_feasible,
                 init_scale=0.1,
                 create_plates=None):
        ...
```

Then in your experiments you can override this via

```python
guide_exp = AutoNormal(model)
guide_exp.scale_constraint = constraints.positive

guide_softplus = AutoNormal(model)
guide_softplus.scale_constraint = constraints.softplus_positive
```

Hope I'm not overthinking this, I'd just like to avoid interface bloat 😄 @fehiepsi does this seem ok to you?
Sure! This only changes how we optimize the scale parameters. Probably this will affect the current inference code of some users, but I think it is easy to fix... Btw, we should mention why we made this change in the next release notes. :)
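(For users who want the old behavior back, a one-line sketch using the class-attribute pattern proposed above, assuming it lands as described:)

```python
from pyro.infer.autoguide import AutoNormal
from pyro.distributions import constraints

# Revert all AutoNormal instances to the previous exp-style parameterization.
AutoNormal.scale_constraint = constraints.positive
```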
LGTM with passing tests and exposing those new transforms in docs. Thanks for addressing this issue, @fritzo!
Thanks a lot for implementing this. Class variable works great. We are in the middle of resubmitting the paper revisions this week, but I will try to do the testing soon!
Resolves #2766
Blocked by #2753
Ports pytorch/pytorch#52300
This switches to using softplus transforms for autoguide scale parameters (point 2 in pyro-ppl/numpyro#855 (comment)), and adds relevant machinery:

- `constraints.softplus_positive` and `SoftplusTransform`
- `constraints.softplus_lower_cholesky` and `SoftplusLowerCholeskyTransform`
This does not use softplus transforms for latent variables that are scales (point 1 in pyro-ppl/numpyro#855 (comment)). While this PR adds machinery to declare and perform those transforms, it remains to detect which positive latent variables should be softplus-transformed rather than exp-transformed.
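(A rough usage sketch of the new machinery, assuming the constraint-registry wiring described above; exact import paths may differ across versions:)

```python
import torch
from torch.distributions import transform_to
from pyro.distributions import constraints

# The new constraint maps to a softplus-based bijection in the
# transform_to registry, so optimization happens in unconstrained space.
t = transform_to(constraints.softplus_positive)
u = torch.randn(4)        # unconstrained parameter
scale = t(u)              # strictly positive via softplus
assert (scale > 0).all()
u_round_trip = t.inv(scale)
```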
cc @vitkl @fehiepsi
Tested