[proposal] [discussion] Refactor pruning/weight_norm using new Reparametrization functionality + actually deprecate old impl of SpectralNorm #7313

Open
vadimkantorov opened this issue May 4, 2018 · 70 comments
Labels: module: nn (related to torch.nn); triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@vadimkantorov
Contributor

vadimkantorov commented May 4, 2018

Currently, weight_norm and spectral_norm patch the module passed to them and implement special functions for adding them to and removing them from a module.

Some ideas for refactoring to make it less tricky:

  • provide a stable signature for getting weight, then they can be cleanly used with methods such as torch.matmul and F.conv2d
  • if module patching (adding some new buffers as parameters and registering a hook) is needed and is a reasonable pattern, provide a user-facing stable abstraction for it (especially adding and removal of parameters). It seems we have a chain of decorators-hooks, and each of them may have some extra buffers, and currently they are all patched into the passed module object.

cc @jerryzh168 @jianyuh @dzhulgakov

@fmassa
Member

fmassa commented May 5, 2018

Could you illustrate your proposal with some code snippets? It will be easier to understand

@vadimkantorov
Contributor Author

vadimkantorov commented May 5, 2018

For the first point, spectral norm has a compute_weight method, but it’s defined in an internal structure. I think a global function like constrain_weight_with_spectral_norm would be more discoverable.

For the second point, to come up with a good abstraction, one needs to know why the current module object patching was chosen over defining a new wrapping module. I don’t know the answer.

@t-vi may have some ideas, he wondered about a better abstraction also in #7298 (comment)

@zou3519
Contributor

zou3519 commented May 7, 2018

We're trying to figure out what this would look like -- code snippets would definitely help

@t-vi
Collaborator

t-vi commented May 7, 2018

So what it looks like to me as a user is that I would want to have "computed parameters".

  • They have raw parameters as inputs.
  • They are updated when the inputs change.
  • They are cached after they are updated and can be used just like other tensors.

I would propose to

  • Have a ParameterModel class that actually is a Module subclass and is initialized with the name(s) of the parameter(s) that it computes. It can have parameters just like a model. Its forward computes the parameters. It should optionally also have a way to set parameters (i.e. doing the inverse transform), e.g. for initialization.
  • Have a "parameters were updated" hook that is called by the optimizer and invalidates the computed parameters, e.g. by setting them to None.
  • Have a thing in Module.__getitem__ that calls the ParameterModel instance when a computed parameter is None and is accessed.

This does not solve how to call the hook when someone else overwrites the parameter, but hey. I'm not set on any of the specifics. I can whip up a prototype if that helps.

So the main functional difference to how spectral normalization is currently implemented is that it replaces the forward_pre hook with an update hook. The advantage to the user is that she can see this as a modification of parameters rather than the modules having the parameters.
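A minimal pure-Python sketch of the "computed parameter" idea described above: raw inputs, a cached result, and an update hook that invalidates the cache. This is not the prototype discussed in the thread; the names RawParam, ComputedValue, update and invalidate are hypothetical and chosen only for illustration.

```python
class RawParam:
    """A raw input value that notifies its dependents when it changes."""

    def __init__(self, value):
        self._value = value
        self._listeners = []

    @property
    def value(self):
        return self._value

    def update(self, value):
        # Plays the role of the "parameters were updated" hook that an
        # optimizer would trigger after a step.
        self._value = value
        for listener in self._listeners:
            listener.invalidate()


class ComputedValue:
    """Caches fn(*inputs) and recomputes lazily after an invalidation."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs
        self._cache = None
        self.n_computations = 0
        for p in inputs:
            p._listeners.append(self)

    def invalidate(self):
        self._cache = None

    def __call__(self):
        if self._cache is None:
            self._cache = self.fn(*(p.value for p in self.inputs))
            self.n_computations += 1
        return self._cache


g = RawParam(2.0)
v = RawParam(3.0)
w = ComputedValue(lambda g_val, v_val: g_val * v_val, g, v)

first = w()    # computed
second = w()   # served from the cache
g.update(5.0)  # invalidates the cache
third = w()    # recomputed from the new input
```

The weakness mentioned above is visible here too: a change that bypasses update() would leave a stale cache.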

@t-vi
Collaborator

t-vi commented May 11, 2018

So to put some code here (not remotely tested enough):
https://github.com/t-vi/pytorch/tree/calculated_param
has the adaptation of Module to introduce CalculatedParameter (and Parameter, optim.SGD, optim.Adam are adapted).
The biggest weakness (in my opinion) of the implementation is that it "needs help" to detect parameter updates. The typical case will be the optimizer doing it, so the optimizer needs to call param.updated().

Note how the spectral norm class now looks like a regular module - to the point where you would not need to have the utility spectral_norm and remove_spectral_norm functions (the latter could just be replaced by m.weight = nn.Parameter(m.weight.data) if you don't need to programmatically control the name and check whether weight was a parameter).

While it does add some trickery in module.py, I think the user of calculated parameters gets a much better interface.

@t-vi
Collaborator

t-vi commented May 11, 2018

An alternative to trying to chase parameter changes might be to use a forward_pre_hook within the module. Whether you want a new result at every invocation even when the inputs didn't change, or whether that is wasteful, would seem to depend on the application. This could be mitigated by only doing the update in training mode or some such. One could also simplify the logic between the module and the CalculatedParam by keeping a calculation counter and checking for None or a stale counter value to detect out-of-date results, instead of having a hook that sets the parameter to None.
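The counter-based variant can be sketched in plain Python as follows. Instead of a hook that resets the cache, the consumer compares a version number it remembered with the current one. All names here (Versioned, CounterCached, n_computations) are hypothetical.

```python
class Versioned:
    """A value with a counter that is bumped on every change."""

    def __init__(self, value):
        self.value = value
        self.version = 0

    def set(self, value):
        self.value = value
        self.version += 1


class CounterCached:
    """Recomputes fn(source.value) only when the source version has moved on."""

    def __init__(self, fn, source):
        self.fn = fn
        self.source = source
        self._cache = None
        self._seen_version = None  # nothing computed yet
        self.n_computations = 0

    def get(self):
        if self._seen_version != self.source.version:
            self._cache = self.fn(self.source.value)
            self._seen_version = self.source.version
            self.n_computations += 1
        return self._cache


src = Versioned(3.0)
squared = CounterCached(lambda x: x * x, src)
a = squared.get()  # computed
b = squared.get()  # cached, version unchanged
src.set(4.0)       # bumps the version, no callback needed
c = squared.get()  # recomputed
```

The source does not need to know who depends on it, which is the simplification the comment is after.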

@ngimel
Collaborator

ngimel commented May 13, 2018

Incidentally, we have also been working on abstracting reparameterization: https://github.com/NVIDIA/apex/tree/master/apex/reparameterization. The current implementation puts the boilerplate code in reparameterization.py, with weight norm defining just the particular weight recomputation functions. The recomputation of the weight is controlled by self.evaluated, which is set to true at evaluation and toggled to false in backward (which may be crude, but was enough to avoid recomputing the weight at each time step of an RNN).
It is still a hook-based implementation, and conceptually I agree that all the same things can be achieved by module wrappers.

@t-vi
Collaborator

t-vi commented May 13, 2018

@ngimel Nice, that looks like a clean variant of how weight_norm is currently done in master.

One thing I am wondering (and I think that is the main difference to my proposed implementation strategy): Is this something that works at the module level or at the parameter level? I see it more naturally at the parameter level, while the other solutions seem to look more at the module level.

One thing that isn't easily achieved with my demo implementation is having multiple parameters that need to be calculated jointly. This could be achieved by returning a dict and having __getattr__ check for that.

In terms of usability ("less boilerplate code") I think that allowing the parameter computation to be in a nn.Module subclass is a more natural solution than wrappers.

@ngimel
Collaborator

ngimel commented May 13, 2018

@t-vi I've asked the author of that implementation to come here to give his perspective. The CalculatedParameters in nn.Module that you have look nice, but they do require modifications to the core and optimizers, whereas the apex approach is more self-contained. As you note, joint recalculation of multiple parameters does not come naturally, neither to apex nor to CalculatedParameters. Also, I think all current approaches break parameter sharing, and this is something that has to be figured out going forward.

@t-vi
Collaborator

t-vi commented May 13, 2018

Actually, I could move the caching to the __call__ of CalculatedParameter. That would reduce the changes to Module and enable shared parameters.

(Edit: I did now, and also provided a tentative interface for multiple parameters by returning a dict from forward and the caching matching the suffix - not 100% clean, but it might be a start.)

@vadimkantorov
Contributor Author

Do we want parameter calculation to be able to provide gradients by itself?

@t-vi
Collaborator

t-vi commented May 13, 2018

If you want to manipulate gradients, you probably want to do that using Function or a backward hook.

@raulpuric

Chiming in on this discussion. I'm the author of the reparameterization @ngimel shared earlier.
I've actually thought of 3 of these solutions in the past, and I'll be providing my thoughts on them in addition to the thoughts on my implementation.

My/hooking implementation:
Pros:

  • Abstract interface for reparameterization
  • wrapper to add hooks to all parameters in a module recursively
  • handles some degree of parameter sharing.
  • Works behind the scenes. No modifications to module aside from adding new parameters. The module can be used in the exact same way as before.

Cons:

  • can be hard/hacky to use with parameter sharing.
  • saving/reloading model is not user-friendly. One can't load the weight_norm (g/v) parameters into a model without weight norm.

Optimizer handles reparameterization:
Pros:

  • Saving/reloading model is user friendly as it does not modify anything in the model
    • works behind the scenes disentangled from the module
  • Offloading responsibility for doing reparameterizations makes sense in the context of weight norm, as one of the goals is to change how the parameters are optimized.

Cons:

  • Unusual optimizer usage is hard to allow (skipping optimization if gradients/loss blowup)

Calculated/Lazy Parameter:
Personally I'm a fan of this option, and had thought about it in the past, but didn't have time to implement it in a good fashion. If someone can implement a general class for this in a very safe PyTorch-y way, that also makes saving/reloading parameters easy, I would be in support of this option.
Pros:

  • Works at the parameter level and does not add new parameters to the module
    • Easy to swap param with a non-reparameterized version of the parameter (helps with save/load)
    • should be able to handle weird optimization strategies so long as .data and .grad are managed well

Cons:

  • Works at the parameter level so managing these might be hard.
    • Dictionary is a good first step, but needs something more. A hook might be a good solution here, since no new parameters/functionality are getting added to the module itself

Module wrapper:
Pros:

  • It can handle everything that all the other options can (fprop, bprop, param updates, etc)

Cons:

  • Module wrappers can complicate code. In order to access attributes/functionality of the wrapped module you have to add an additional .module everywhere

The ideal sort of setup I envision is either one module wrapper class that handles everything for the user, or a combination of CalculatedParameters and hooks/module wrapper. (Preferably the latter)

I think though in general PyTorch needs some sort of module wrapper nn.Module subclass that absorbs most if not all the functionality/attributes of the wrapped module instance. This would alleviate a lot of code considerations that have to be done when using things like (Distributed)DataParallel which forces us to add model.module all over our code.
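A minimal sketch of what such an attribute-absorbing wrapper could look like. TransparentWrapper is a hypothetical name, not an existing PyTorch class; the sketch relies only on nn.Module.__getattr__ raising AttributeError for names it does not know.

```python
import torch
from torch import nn


class TransparentWrapper(nn.Module):
    """Hypothetical wrapper that forwards unknown attributes to the wrapped module."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ resolves this wrapper's own parameters,
            # buffers and submodules (including "module" itself).
            return super().__getattr__(name)
        except AttributeError:
            # Anything else is looked up on the wrapped module.
            return getattr(self.module, name)


inner = nn.Linear(3, 2)
wrapped = TransparentWrapper(inner)
out = wrapped(torch.zeros(1, 3))
n_in = wrapped.in_features  # no ".module" needed
```

Attribute access becomes transparent, but the state_dict keys still carry the "module." prefix, so this only addresses part of the save/load concern raised above.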

@t-vi
Collaborator

t-vi commented May 15, 2018

@raulpuric Thank you for commenting. If you have a moment to look at it, what do you think are the key things missing from the calculated param prototype ( https://github.com/t-vi/pytorch/tree/calculated_param ) relative to your vision?

@vadimkantorov
Contributor Author

vadimkantorov commented May 16, 2018

Imagine we have a parameter precalc for an op (pseudo-code):

weight, bias, other_param = ..., ..., ...
weight, bias = apply_spectral_norm_or_something_else(weight, other_param)
y = F.conv2d(x, weight, bias)

  1. Are we supporting this functional scenario easily? (I think that we should)
  2. Should we support differentiating with respect to other_param? (I think the user should have a choice. The hooks-based solution doesn't allow it, does it? And the CalculatedParameter?)
  3. Do we want to support this kind of joint param pre-calculation?

For the question of wrapper vs patching, I think the wrapper is still cleaner (even if the user has to type .module to get to the original module and it may obstruct model inspection a bit), because for the case of chained parameter precalc (imagine applying weight-norm and spectral-norm together, if it made sense) the order of the precalc ops is clear to the user.
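A minimal sketch of the functional scenario from the pseudo-code above, using weight normalization (w = g * v / ||v||) as the precalc and F.linear as the op. The helper name weight_norm_weight is hypothetical; only standard torch calls are used. Because the precalc is ordinary differentiable code, gradients reach both raw tensors, which is one possible answer to question 2.

```python
import torch
import torch.nn.functional as F


def weight_norm_weight(v, g):
    """Hypothetical helper: rescale each row of v to have norm g."""
    return g * v / v.norm(dim=1, keepdim=True)


v = torch.randn(4, 3, requires_grad=True)  # direction
g = torch.ones(4, 1, requires_grad=True)   # per-row magnitude
x = torch.randn(2, 3)

weight = weight_norm_weight(v, g)
y = F.linear(x, weight)
y.sum().backward()  # populates v.grad and g.grad
```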

@t-vi
Collaborator

t-vi commented May 16, 2018

For the calculated parameter, you just have a module that uses other_param, so I think it covers that easily. The multiple outputs require cooperation between the calculated params and the module in naming the parameters, but that seems unavoidable regardless of the solution.

I'm not sure I quite understand the wrapper vs. patching argument. I see CalculatedParameter as doing neither (you just assign the parameter). If you wanted to stack things, you would assign the parameters of one CalculatedParameter to be those of another (propagating the call to updated needs a look, but doesn't seem too hard).

In a different area: I looked a bit into what to do about calling the Parameter.updated hook and think it can be pushed to TensorBase in C and be called automatically in in-place operations. That only works when using with torch.no_grad(): p.add_(...) or whatever instead of p.data.add_ as is currently done, but I think that would be good enough. When working with a getter/setter for the definition of the updated function, the performance impact on other Tensors should be negligible.

@vadimkantorov
Contributor Author

@t-vi For the module wrapper vs patching, I think the proposed CalculatedParameter stands somewhere in the middle, and I agree it makes stacking clearer than hooks do. My argument was meant to address @raulpuric's Cons section for the module wrapper.

@vadimkantorov
Contributor Author

@apaszke @soumith Do you have opinions regarding questions in #7313 (comment)? (especially about ease of use of params precalc together with functional interface).

@t-vi
Collaborator

t-vi commented May 29, 2018

So to give an update from my side: I implemented inplace hooks in Variables and am updating CalculatedParameter to use that. This would enable the idiom

with torch.no_grad():
    p.add_(d_p, alpha=-group['lr'])

(in SGD) to correctly invalidate the CalculatedParameter.
It would then be a bug to use .data or .detach() in such cases.

From my POV this removes the last major obstacle to using CalculatedParameter to cleanly implement cached calculated parameters.

@t-vi
Collaborator

t-vi commented May 31, 2018

So I pushed an implementation using inplace hooks on tensor base to
https://github.com/t-vi/pytorch/tree/calculated_param
I think this addresses most things.

For multiple parameters, I think the main alternative to passing the parameter to the CalculatedParam call would be to specify the name when assigning. So it would be

cp = SubClassOfCalculatedParamWithMultipleReturns(somestuff)
model.weight = cp.get_output("weight")
model.bias = cp.get_output("bias")

but you could equally do

cp = SubClassOfCalculatedParamWithMultipleReturns(somestuff)
model.weight = cp.get_output("wouldyouknow")
model.bias = cp.get_output("totallydifferent")

if your forward returns those in a dict.

Regarding the functional interface, I think it is a bit of a red herring, but without a Module subclass instance and its __getattr__ method, one would have to do

F.linear(input, mycalculatedparam())

i.e. add the "()" to make it a call. Given that it gives you an error message rather than failing silently, I would think it is OK.

As discussed above, it would need deprecation of using p.data.do_() instead of with torch.no_grad(): p.do_(). In theory, one could try to move the hooks to the storage level, but why would we? It adds complications for not really that much gain.

I look forward to your comments.

@vadimkantorov
Contributor Author

@t-vi kudos for this prototype supporting functional interface and multiple parameter pre-calc!

@t-vi
Collaborator

t-vi commented Jun 7, 2018

Thanks. So in your opinion, would this satisfy the requirements for the desired abstraction?

@t-vi
Collaborator

t-vi commented Jun 12, 2018

@apaszke Any chance that we get some sort of input from you or one of the other devs on this?
My impression is that this could have a rather far-reaching impact, both in terms of applications and w.r.t. inplace hooks vs. .data vs. with torch.no_grad():.

@vadimkantorov
Contributor Author

@t-vi What does usage look like with a regular Linear module? Something like self.fc8 = CalculatedParamSubClass(self.fc8)?

Also what's your opinion on replacing get_output with just __getattr__ ? i.e:

cp = SubClassOfCalculatedParamWithMultipleReturns(somestuff)
model.weight = cp.weight
model.bias = cp.bias

@t-vi
Collaborator

t-vi commented Jun 28, 2018

My reason for suggesting get_output is the "explicit is better than implicit" Python-Zen.
Personally, I think the expected time spent debugging when having dynamic output accessors as attributes is higher than the time saved by not making it explicit. But other people might get this type of stuff right all the time and I'm just imagining the problem.

@t-vi
Collaborator

t-vi commented Jun 28, 2018

Given the effort in #8806, maybe moving the hook to the storage is a good idea, too.

@fritzo
Collaborator

fritzo commented Feb 18, 2021

This could also standardize the names for flow/distribution kind of models

@vadimkantorov @lezcano this proposal looks a lot like Pyro's abstraction for transformed parameters, which mostly relies on machinery in torch.distributions.transforms and torch.distributions.constraint_registry.

Summary of torch Transforms for constrained parameters

Our pattern is to use a Transform object that implements ._call() for the forward transform and ._inverse() for the backward transform. These are wrapped in public .__call__() and .inv.__call__() methods. We allow non-bijective transforms (similar to your proposal) and require the weak pseudoinverse relationships

t(t.inv(t(x))) == t(x)
t.inv(t(t.inv(y))) == t.inv(y)

Note that PyTorch's Transform objects do not currently distinguish between left- and right-inverse; they only distinguish between bijective and non-bijective pairs via an .is_bijective class attribute on Transform subclasses.
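The two pseudoinverse identities can be checked on a toy non-bijective map. The sketch below is plain Python, not torch.distributions: Square is a hypothetical class, and inv is a regular method here, whereas on a real Transform .inv is a property that returns another transform.

```python
import math


class Square:
    """Non-bijective forward map x -> x**2 with sqrt as a pseudoinverse."""

    def __call__(self, x):
        return x * x

    def inv(self, y):
        return math.sqrt(y)


t = Square()
x, y = -3.0, 2.0

# t(t.inv(t(x))) == t(x)
lhs1, rhs1 = t(t.inv(t(x))), t(x)
# t.inv(t(t.inv(y))) == t.inv(y)
lhs2, rhs2 = t.inv(t(t.inv(y))), t.inv(y)

# The map is not bijective: the round trip does not recover a negative x.
round_trip = t.inv(t(x))
```

Both identities hold (the second up to floating-point rounding) even though round_trip differs from x.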

The transform machinery also plays well with Constraint machinery for constrained parameters (e.g. learnable correlation matrices). The usual syntax uses transform_to(-)(-) as in

y = transform_to(constraints.lower_cholesky)(torch.randn(4,4))
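A short runnable version of the line above, assuming a PyTorch build that ships torch.distributions. It checks that the output lands in the constrained set (lower triangular with a positive diagonal) and that going back through .inv and forward again reproduces it.

```python
import torch
from torch.distributions import constraints, transform_to

t = transform_to(constraints.lower_cholesky)

x = torch.randn(4, 4)  # unconstrained input
y = t(x)               # lower-triangular, positive diagonal
x_back = t.inv(y)      # an unconstrained preimage of y
y_again = t(x_back)    # maps to the same constrained value

is_lower_triangular = torch.equal(y, y.tril())
has_positive_diagonal = bool((y.diagonal() > 0).all())
```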

We've put a lot of work into tracking shapes for static shape computations in transforms and constraints, since shapes are needed by TransformedDistribution in PyTorch, Pyro, and flowtorch (cc @stefanwebb, @feynmanliang).

On top of PyTorch's Transform and Constraint machinery, Pyro uses a PyroParam object to accomplish initialization of constrained parameters. As in your proposal, setting a PyroParam calls some transform's ._inverse() method under the hood to transform from a constrained initializer to an unconstrained parameter space; then later reading that param calls the same transform's ._call() method to transform back to constrained space. These transformations are implemented inside .__getitem__() and .__setitem__() methods of a dict-like ParamStoreDict object. That ParamStoreDict object seems very similar to your proposal (if you squint and ignore Pyro's legacy cruft 😆).

@lezcano
Collaborator

lezcano commented Feb 18, 2021

Thank you for the very nice summary @fritzo !

So, the PyroParam solution, although very neat, is what we were trying to avoid at a PyTorch level for the reasons that I outlined above, that's the reason why we went with properties rather than __getitem__ / __setitem__.

When it comes to Transform, it looks like it provides a superset of the functionality that this proposal offers. In particular, one of the parts of the design of this proposal was for it to be able to use plain nn.Modules. What I find weird is that Transform does not inherit from nn.Module. Why is this?

If it did, then it would be straightforward either to adapt its API directly in the definition or via a function so that it could be used as a parametrisation for register_parametrization. Otherwise, it would be possible to have an adapter class that wraps a Transform and makes it into an nn.Module that can be fed into register_parametrization. But yeah, the fact that Transforms are not nn.Modules really puzzles me. Why is it so?

Having this, it looks like it could be reasonable to have a hierarchy of the form Module > InvertibleModule > Transform, where they implement: forward > forward + inverse > forward + inverse + log_det +...

@vadimkantorov
Contributor Author

Hehe, a recent request for function/module inverses: #52553 :)

@fritzo
Collaborator

fritzo commented Feb 21, 2021

@vadimkantorov Transforms are not nn.Modules for the same reasons that Distributions are not nn.Modules. @stefanwebb has found it useful to derive from both classes in a TransformModule class.

@lezcano
Collaborator

lezcano commented Feb 22, 2021

Interestingly enough, it looks like parametrisations would allow to solve issues 1-3 pointed out in the distributions issue, wouldn't they? I think that that would be quite a good synergy really, and a direct application of constrained optimisation via parametrisations, which is what #33344 brings to the table. I think it'd be worth looking into that in the future.

@fritzo
Collaborator

fritzo commented Feb 22, 2021

it looks like parametrisations would allow to solve issues 1-3 pointed out in the distributions issue

@lezcano Yes, I can see how point 1 of the distribution issue might be solved by parametrisations. Could you elaborate on your proposed strategies for addressing issue 2 (cache flushing on parameter mutation) and issue 3 (optional constructor args) in the distributions issue? Issue 2 on caching seems especially complex and warrants detailed discussion.

@vadimkantorov
Contributor Author

vadimkantorov commented Feb 22, 2021

Potential implementation clash with attribute caching or whatever's going on in #50431

@lezcano
Collaborator

lezcano commented Feb 23, 2021

@fritzo Issue 2: the current parametrisation PR offers an opt-in caching system. We opted to go for explicit rather than implicit when it comes to activating it. At the beginning, I implemented an automatic one of sorts, that would work in most situations, but, of course, making it work in all situations is remarkably tricky with the current API. As such, you can activate this caching system via a context manager as:

with torch.cached():  # it won't be in the torch namespace, but still
    outputs = model(inputs)

or you could go with a finer grain and simply wrap the loop of your RNN, if your RNN happens to be parametrised:

with torch.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)
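For reference, in released PyTorch the caching context manager is torch.nn.utils.parametrize.cached(). A minimal sketch of its effect, assuming a recent PyTorch release; the Symmetric parametrization and its calls counter are made up for illustration.

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


class Symmetric(nn.Module):
    """Illustrative parametrization that counts how often it is evaluated."""

    def __init__(self):
        super().__init__()
        self.calls = 0

    def forward(self, X):
        self.calls += 1
        # Build a symmetric matrix from the upper triangle of X
        return X.triu() + X.triu(1).transpose(-1, -2)


layer = nn.Linear(3, 3)
sym = Symmetric()
parametrize.register_parametrization(layer, "weight", sym)

sym.calls = 0  # ignore any evaluation done during registration
_ = layer.weight
_ = layer.weight
uncached_calls = sym.calls  # every access recomputes the weight

sym.calls = 0
with parametrize.cached():
    _ = layer.weight
    _ = layer.weight
cached_calls = sym.calls  # computed once, then reused inside the block
```

The cache only lives for the duration of the with block, which is why wrapping a single forward pass or an RNN loop is the intended usage.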

Issue 3 (how to deal with the optional constructor args). One solution would be to go as described in section "High Level API" on the summary in #7313 (comment). This would be exactly the same idea as having functions that do the dispatching of several classes, as you described in the issue that you linked.

Another idea that better expresses the intent would be to implement the relation between these constraints directly. When there is a diffeomorphic relation between two spaces, as in the case of logits vs probabilities via the exponential, we can express these isomorphisms explicitly. In these cases, it may be reasonable to have two parametrisations parametrising the same unconstrained argument. As such, an object distr of class Bernoulli would have two differentiable properties, distr.logits and distr.prob, that point to the same unconstrained parameter. We would then implement one in terms of the other via the diffeomorphism that connects them, so that assigning to one is equivalent to assigning to the other in the natural way. Which one to put in terms of the other could be selected on the basis of performance / numerical stability. Putting two different parametrisations pointing to the same Parameter is not supported at the moment, but it would be straightforward to implement by adding an optional parameter parametrized_name to register_parametrization. An implementation of Bernoulli would then go along the lines of:

import torch
from torch import nn
from torch.nn.utils.parametrize import register_parametrization

class Exp(nn.Module):
    def forward(self, x):
        return x.exp()

    def right_inverse(self, x):
        return x.log()

class Bernoulli(nn.Module):
    def __init__(self, prob=None, logits=None):
        super().__init__()
        if (prob is None) == (logits is None):
            raise ValueError("Provide one and only one!")
        # We omit the checks that prob / logits are in the correct domains
        t = prob if prob is not None else logits
        # Assume that we have decided to implement prob in terms of logits
        if isinstance(t, nn.Parameter):
            self.register_parameter("logits", t)
        else:
            self.register_buffer("logits", t)
        # Parametrise logits (LogitParametrization is not defined in this sketch)
        register_parametrization(self, "logits", LogitParametrization())
        # We now put a parametrisation on top of the first parametrisation
        # parametrized_name is not currently supported, but it is straightforward to implement
        register_parametrization(self, "logits", Exp(), parametrized_name="prob")
        # Initialise using the parameter that we have
        with torch.no_grad():
            if prob is not None:
                self.prob = prob
            else:
                self.logits = logits

distr = Bernoulli(prob=...)
distr2 = Bernoulli(logits=...)
# From here on, working with `distr.prob` or `distr.logits` is equivalent
# If we wrap the forward pass in a `with cached()`, both the results are cached the first time that they are needed, and reused throughout all the forward pass
# Assigning to `distr.prob` or `distr.logits` also works as expected

If there are distribution classes that admit two very different implementations, it would still be reasonable to have this solution together with the first one via the factory pattern.

This last case also happens more generally in the setting of constraints, as there might be some constraints that do not have a parametrisation that's better than the others. For example, there are plenty of ways to perform optimisation with orthogonal constraints (consider any differentiable map that maps an unconstrained matrix to an orthogonal matrix, e.g. the matrix exponential, Cayley map, Householder reflections...). In this case, there is no "one size fits all" parametrisation, so we will implement several of these and we will dispatch them with an orthogonal(module, tensor, *args) interface. Regardless of the one that the user chooses, we will expose the same API via the right_inverse. This way, the user can interact with the layer in a uniform way (get its value and initialise it to an orthogonal matrix) irrespective of the unconstrained representation being used.
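To illustrate the role of right_inverse with the API as it exists in released PyTorch (torch.nn.utils.parametrize), here is a minimal sketch with a simpler constraint than orthogonality: positivity via the exponential. The class name Positive is made up for this example. Assigning to the parametrized attribute stores right_inverse(value) in the unconstrained tensor, and reading it back applies forward.

```python
import math

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


class Positive(nn.Module):
    """Illustrative parametrization: constrained = exp(unconstrained)."""

    def forward(self, x):
        return x.exp()

    def right_inverse(self, y):
        return y.log()


layer = nn.Linear(2, 2)
with torch.no_grad():
    layer.weight.fill_(1.0)  # start from a value inside the constrained set

parametrize.register_parametrization(layer, "weight", Positive())

# Assignment goes through right_inverse; the user only sees constrained values.
layer.weight = torch.full((2, 2), 2.0)
constrained = layer.weight
unconstrained = layer.parametrizations.weight.original
```

The user never touches the unconstrained representation directly, which is the uniformity argued for above.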

@Chillee
Contributor

Chillee commented Feb 24, 2021

@vadimkantorov the solution we're probably going to go with (#52576) won't affect this, so no need to worry about this :P

facebook-github-bot pushed a commit that referenced this issue Mar 4, 2021
Summary:
Provides the implementation for feature request issue #28937.

Adds the `Parametrization` functionality and implements `Pruning` on top of it.
It adds the `auto` mode, in which the parametrization is computed just once per forward pass. The previous implementation computed the pruning on every forward, which is not optimal when pruning RNNs, for example.

It implements a caching mechanism for parameters. This is implemented through the mechanism proposed at the end of the discussion in #7313. In particular, it assumes that the user will not manually change the updated parameters between the call to `backward()` and the call to `optimizer.step()`. If they do so, they would need to manually call the `.invalidate()` function provided in the implementation. This could be made into a function that takes a model and invalidates all the parameters in it. It might be the case that this function has to be called in `.cuda()`, `.to()` and related functions.

As described in #7313, this could be used to implement the `weight_norm` and `spectral_norm` functions in a cleaner way. It also allows, as described in #28937, for the implementation of constrained optimization on manifolds (e.g. orthogonal constraints, positive definite matrices, invertible matrices, weights on the sphere or the hyperbolic space...)

TODO (when implementation is validated):
- More thorough test
- Documentation

Resolves  #28937

albanD

Pull Request resolved: #33344

Reviewed By: zhangguanheng66

Differential Revision: D26816708

Pulled By: albanD

fbshipit-source-id: 07c8f0da661f74e919767eae31335a9c60d9e8fe
@vadimkantorov
Contributor Author

@lezcano given that #33344 is merged in, should I rename this issue to just "rewrite weight_norm / spectral_norm using new Reparametrization functionality"?

@vadimkantorov vadimkantorov changed the title [proposal] [discussion] An abstraction to facilitate building weight_norm/spectral_norm-like things aka Reparametrization [proposal] [discussion] Refeactor weight_norm/spectral_norm-like using new Reparametrization functionality Mar 7, 2021
@vadimkantorov vadimkantorov changed the title [proposal] [discussion] Refeactor weight_norm/spectral_norm-like using new Reparametrization functionality [proposal] [discussion] Refactor weight_norm/spectral_norm-like using new Reparametrization functionality Mar 7, 2021
@vadimkantorov vadimkantorov changed the title [proposal] [discussion] Refactor weight_norm/spectral_norm-like using new Reparametrization functionality [proposal] [discussion] Refactor weight_norm/spectral_norm using new Reparametrization functionality Mar 7, 2021
@lezcano
Collaborator

lezcano commented Mar 8, 2021

The only detail is that the name is "parametrizations". We went for this one over "reparametrizations" as it is a less verbose version of an already very long word.

aocsa pushed a commit to Quansight/pytorch that referenced this issue Mar 15, 2021
xsacha pushed a commit to xsacha/pytorch that referenced this issue Mar 31, 2021
@vadimkantorov
Contributor Author

Related: #55368

@vadimkantorov vadimkantorov changed the title [proposal] [discussion] Refactor weight_norm/spectral_norm using new Reparametrization functionality [proposal] [discussion] Refactor pruning/weight_norm/spectral_norm using new Reparametrization functionality May 4, 2021
@vadimkantorov vadimkantorov changed the title [proposal] [discussion] Refactor pruning/weight_norm/spectral_norm using new Reparametrization functionality [proposal] [discussion] Refactor pruning/weight_norm using new Reparametrization functionality Mar 27, 2023
@vadimkantorov vadimkantorov changed the title [proposal] [discussion] Refactor pruning/weight_norm using new Reparametrization functionality [proposal] [discussion] Refactor pruning/weight_norm using new Reparametrization functionality + actually deprecate old impl of SpectralNorm Mar 27, 2023
@ezyang
Contributor

ezyang commented Jun 5, 2023

#103001 for weight_norm
