
Stabilize autoguide scale parameters via SoftplusTransform #2767

Merged
merged 25 commits on Feb 28, 2021

Conversation

@fritzo (Member) commented Feb 19, 2021

Resolves #2766
Blocked by #2753
Ports pytorch/pytorch#52300

This switches to using softplus transforms for autoguide scale parameters (point 2 in pyro-ppl/numpyro#855 (comment)), and adds relevant machinery:

  • constraints.softplus_positive and SoftplusTransform
  • constraints.softplus_lower_cholesky and SoftplusLowerCholeskyTransform

This does not use softplus transforms for latent variables that are scales (point 1 in pyro-ppl/numpyro#855 (comment)). While this PR adds machinery to declare and perform those transforms, it remains to detect which positive latent variables should be softplus-transformed rather than exp-transformed.
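
For readers unfamiliar with this machinery, here is a minimal usage sketch (not code from this PR; only the name constraints.softplus_positive comes from the PR, the parameter name and shape are illustrative) of declaring a positive parameter that is optimized through a softplus rather than an exp bijection:

# Hypothetical usage sketch: a positive parameter whose unconstrained value
# is transformed by softplus instead of exp during optimization.
import torch
import pyro
import pyro.distributions.constraints as constraints

scale = pyro.param("my_scale", torch.ones(3),
                   constraint=constraints.softplus_positive)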

cc @vitkl @fehiepsi

Tested

  • needed to tweak parameters in a couple of inference tests (tests now actually run a bit faster)

@fritzo changed the title Stable autoguide scale → Stablize autoguide scale Feb 19, 2021
@fritzo changed the title Stablize autoguide scale → Stabilize autoguide scale Feb 19, 2021
@fritzo changed the title Stabilize autoguide scale → Stabilize autoguide scale parameters Feb 19, 2021
@fehiepsi (Member) commented Feb 19, 2021

Just curious, what makes you think this will be more stable? (is it numerical or theoretical?)

@fritzo (Member, Author) commented Feb 19, 2021

Just curious, what makes you think this will be more stable?

I don't understand why this PR would be more stable 🤷 I'm just trying to implement @vitkl's point 2 in pyro-ppl/numpyro#855 (comment), which he claimed was more stable. It's plausible I've misunderstood something, and also plausible that a better solution would be to use ClippedAdam or a lower learning rate.

@vitkl can you confirm (1) this PR implements roughly what you're requesting, (2) that you've shown it is more numerically stable, and (3) that you tried simpler solutions like ClippedAdam or tweaking optimizer parameters?
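
For context on the stability claim under discussion, here is a small standalone illustration (not part of this PR; it assumes only plain PyTorch) of how the two positive bijections behave at a large unconstrained value: exp and its gradient blow up together, while softplus grows roughly linearly with a gradient bounded by 1.

# Illustration only: compare exp vs softplus as positive-constraining bijections.
import torch
import torch.nn.functional as F

x = torch.tensor(20.0, requires_grad=True)
y_exp = torch.exp(x)   # ~4.9e8; d/dx exp(x) = exp(x), equally huge
y_sp = F.softplus(x)   # ~20.0;  d/dx softplus(x) = sigmoid(x) < 1

(g_exp,) = torch.autograd.grad(y_exp, x)
(g_sp,) = torch.autograd.grad(y_sp, x)
print(float(g_exp), float(g_sp))  # huge gradient vs. gradient just under 1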

@fritzo fritzo changed the title Stabilize autoguide scale parameters Stabilize autoguide scale parameters via SoftplusTransform Feb 21, 2021
@fritzo (Member, Author) commented Feb 21, 2021

@fehiepsi do you think we should name these constraints softplus_positive and softplus_lower_cholesky? That seemed like mixing metaphors to me, but I suppose it's better than introducing yet another word 'stable'. I appreciate your help in choosing good names.

@fehiepsi (Member) commented:

@fritzo I don't have a better solution. Those softplus_foo names look good to me. In the TFP wrapper, because distributions there do not have constraints, I created a class named BijectorConstraint(bijector), which is similar to what we have here (:D): Bijector <-> softplus, Constraint <-> positive.

@vitkl (Contributor) commented Feb 21, 2021

Hi @fehiepsi @fritzo

Thanks for this.

  1. This implements exactly what I requested - however in the numpyro issue (Softplus transform as a more numerically stable way to enforce positive constraint numpyro#855 (comment)) you also discussed making that an option rather than the default:
class AutoNormal(..., use_softplus=False):

    if use_softplus:
        _deep_setattr(self.scales, name,
                      PyroParam(init_scale, constraints.stable_positive, event_dim))
    else:
        _deep_setattr(self.scales, name,
                      PyroParam(init_scale, constraints.positive, event_dim))
  2. It is more numerically stable for the cell2location model (tested in NumPyro). Softplus is also used by PyMC3.

  3. We see that ClippedAdam and a reduced learning rate (0.002 -> 0.0002) do not help (again tested in NumPyro). We have not yet done an exhaustive search of training hyperparameters.

I am planning to add a benchmark (to https://github.com/pyro-ppl/sandbox) that shows the impact of this modification on stability and accuracy for our model. To simplify the analysis, it would help to have the switch option AutoNormal(..., use_softplus=False) in both Pyro and NumPyro.

@fritzo removed the discussion label Feb 22, 2021
@fritzo marked this pull request as ready for review February 22, 2021 20:09
@fritzo changed the base branch from pytorch-nightly to dev February 22, 2021 20:09
@fritzo (Member, Author) commented Feb 22, 2021

Hi @vitkl, I believe this should be ready to go.

making that an option rather than the default

Hmm, I'm hesitant to add interface complexity to our so-called AutoGuides, especially since I suspect your experiments will show softplus should be the default. At the same time I'd like to make it easy for you to run experiments. WDYT of this compromise: we make the constraint hackable but not publicly configurable? That is, we create a class-level variable

class AutoNormal(AutoGuide):
    scale_constraint = constraints.softplus_positive    # <--- hackable but not documented
    def __init__(self, model, *,
                 init_loc_fn=init_to_feasible,
                 init_scale=0.1,
                 create_plates=None):
        ...

Then in your experiments you can override this via

guide_exp = AutoNormal(model)
guide_exp.scale_constraint = constraints.positive

guide_softplus = AutoNormal(model)
guide_softplus.scale_constraint = constraints.softplus_positive

Hope I'm not overthinking this, I'd just like to avoid interface bloat 😄 @fehiepsi does this seem ok to you?

@fehiepsi (Member) commented:

Sure! This only changes how we optimize the scale parameters. It will probably affect some users' current inference code, but I think that is easy to fix... Btw, we should mention why we made this change in the next release notes. :)

@fehiepsi (Member) previously approved these changes Feb 22, 2021 and left a review comment:

LGTM, pending passing tests and exposing those new transforms in the docs. Thanks for addressing this issue, @fritzo!

@vitkl (Contributor) commented Feb 22, 2021

Thanks a lot for implementing this. The class variable approach works great. We are in the middle of resubmitting the paper revisions this week, but I will try to do the testing soon!

@fehiepsi (Member) previously approved these changes Feb 28, 2021 and left a review comment:

LGTM, thanks @fritzo! I'll port this to NumPyro soon. Do you want to expose those transforms in docs?

Successfully merging this pull request may close these issues:

Softplus transform for AutoNormal scales [feature request]