How to define the API for parameter initialization, regularization (L2, weight dropout, etc), maybe updater opts per-param #59

Closed
albertz opened this issue Nov 4, 2021 · 28 comments

albertz commented Nov 4, 2021

It is maybe not such a nice idea if every new Module gets such an option (e.g. L2) and then explicitly passes it on to all its submodules. Some modules might also not implement this. And maybe there are other options which are not handled yet, e.g. param noise, etc.

How should this be handled? This mostly concerns arguments which every layer can potentially accept, mostly arguments about certain behavior of the parameters (variables).

In TF, the natural way would be to use a tf.variable_scope context and attach some custom getter or other custom logic there.

So maybe also some context manager here?
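
For context, this is roughly what that mechanism looks like with the TF1 API (tf.compat.v1); a minimal sketch, where the param-noise getter is just an illustration:

import tensorflow as tf

def noisy_getter(getter, name, *args, **kwargs):
    # wraps every variable created inside the scope, e.g. to add param noise
    var = getter(name, *args, **kwargs)
    return var + tf.random.normal(tf.shape(var), stddev=0.01)

with tf.compat.v1.variable_scope(
        "layer1",
        initializer=tf.compat.v1.glorot_uniform_initializer(),
        regularizer=tf.keras.regularizers.l2(1e-4),
        custom_getter=noisy_getter):
    w = tf.compat.v1.get_variable("w", shape=(128, 128))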

albertz added this to the first-release milestone Nov 4, 2021
This was referenced Nov 4, 2021

albertz commented Nov 5, 2021

Also Lstm has rec_weight_dropout. That option should be removed, and we should have a generic weight_dropout instead.
Although such generic weight dropout still needs to be implemented on RETURNN side (rwth-i6/returnn#735).

albertz commented Nov 5, 2021

The same also applies to the parameter initialization options like forward_weights_init and bias_init.
Maybe we should also extend RETURNN in some way to make this more generic. The LayerBase.var_creation_scope could be extended by some generic mechanism.

albertz commented Nov 5, 2021

LinearLayer also has filter_grad. This is applied after the matrix multiplication, before the optional activation function.
This should be more generic as well. Basically this could be applied after every layer.
There could be a RETURNN layer for this, like FilterGradLayer, which would be somewhat similar to ScaledGradientLayer.
Or this could be a generic option for LayerBase. And then applied in LayerBase.post_init.

Then this also becomes relevant here.

albertz commented Nov 5, 2021

LayerBase also has trainable. Should we also handle that here?

albertz commented Nov 5, 2021

LayerBase also has updater_opts. Should we also handle that here?

albertz commented Nov 5, 2021

LayerBase also has param_variational_noise, which should be handled here. However, this applies to all params of a layer. Maybe we want to have more fine-grained support for each param?

albertz added a commit that referenced this issue Nov 5, 2021: "We will handle that in a more generic way (#59)."

albertz commented Nov 10, 2021

There are some other potential scope-based or global options (all potential LayerBase options) like param_device.
PyTorch has similar options, such as device or dtype (e.g. to use bfloat16), for example:

with torch.cuda.device(1):
  ...

or device and dtype are also arguments to most modules in torch.nn.
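
For illustration, a minimal standard-PyTorch sketch (the device/dtype factory kwargs are available on most torch.nn modules since PyTorch 1.9):

import torch
import torch.nn as nn

# create the parameters directly on the desired device and in bfloat16
layer = nn.Linear(2, 2, device="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.bfloat16)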

albertz changed the title from "How to handle L2 and similar options" to "How to handle L2, parameter initialization, weight dropout, and other things" on Nov 12, 2021
albertz changed the title from "How to handle L2, parameter initialization, weight dropout, and other things" to "How to handle parameter initialization, regularization (L2, weight dropout, etc), maybe updater opts per-param" on Jan 1, 2022

albertz commented Jan 1, 2022

Param init and regularization (per param) are, in most frameworks, options to the corresponding module or layer:

  • PyTorch
  • Keras
  • RETURNN

In lower-level TF, param init and regularization are done via a context manager (tf.variable_scope).

  • JAX/Flax: Param init per module. Regularization is separate.
  • Sonnet: Param init per module. Regularization is separate.

albertz commented Jan 2, 2022

For param init: no matter whether via a direct argument, a context scope or some other way, there needs to be a sensible default, and sometimes also an explicit overwrite.

The question is, should there be an easy way to overwrite the defaults globally? This could make sense to easily try out some new kind of param initialization. This could be via a parent context scope, or maybe like the PyTorch Module.apply logic.

The problem is, such global overwrite might not make sense for every parameter. Some parameters have other custom default init logic.

How to differentiate those from the standard ones?

One possibility is maybe some clustering of types of init, e.g. for linear matrix, bias, and maybe other types. And then you could globally overwrite the default per type. Although this is tricky and maybe hard to get right and will likely never fit all cases.

Also, if there are multiple (maybe nested) settings of defaults, which one takes precedence: the most recent one, or the first (outermost) one? There are reasons to give priority to the outermost one, so that it can globally overwrite everything; otherwise such outer settings would never have an effect. But there are maybe also cases where some inner setting should take precedence. Do we need a priority system for this? That would make it way too complicated, though.
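
To make the precedence question concrete, a purely hypothetical sketch (neither nn.default_param_init nor these initializer helpers exist; they are only placeholders):

with nn.default_param_init(nn.init.Glorot()):         # outer, "global" overwrite
    with nn.default_param_init(nn.init.Normal(0.1)):  # inner, more local default
        linear = nn.Linear(512, 512)
# Which of the two defaults should linear.weight end up with?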

albertz commented Jan 3, 2022

PyTorch Module.apply even has an example in its documentation on using it to overwrite the param init:

import torch
import torch.nn as nn

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        # overwrite the default init: fill all weights with 1.0
        m.weight.fill_(1.0)
        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)

Module.apply together with a forward pre-hook is also used to implement weight normalization (see the torch.nn.utils.weight_norm implementation). The forward pre-hook recomputes the effective weight from the underlying params before each forward call. In TF/RETURNN, such a pre-hook would not really be needed.

So maybe such apply API is a nice and simple solution for us as well.

One question for e.g. weight norm and other reparameterizations is where and how the new underlying params are added. In PyTorch weight norm, they are added directly to the module, with the postfixes _g and _v. This is maybe a bit ugly, although it probably should never cause problems. What would be a cleaner solution, though? Add them to the layer ref of the replaced param? This is currently not really supported because named_parameters only searches through named_children. But maybe this can be extended for layer refs. Not sure...

Recently there has been some discussion in PyTorch that the apply API is actually not so nice for reparameterization, including weight norm, and a new API for reparameterization was proposed, discussed and implemented: pytorch/pytorch#7313, pytorch/pytorch#28937, pytorch/pytorch#33344, parametrizations tutorial.

I'm not sure though that the new PyTorch parametrization API makes it really cleaner and simpler. While it looks nice from a high-level viewpoint, it hides many things.
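
For reference, a minimal PyTorch sketch contrasting the two mechanisms (standard torch.nn.utils APIs; the Scale parametrization is just a toy example):

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import torch.nn.utils.parametrize as parametrize

# Classic hook-based API: replaces "weight" by the params "weight_g"/"weight_v" and
# registers a forward pre-hook which recomputes weight = g * v / ||v|| before each call.
lin = weight_norm(nn.Linear(4, 4), name="weight")
print(sorted(name for name, _ in lin.named_parameters()))  # ['bias', 'weight_g', 'weight_v']

# Newer parametrization API: registers a module which computes the effective weight on the fly.
class Scale(nn.Module):  # toy reparameterization
    def forward(self, w):
        return 2.0 * w

lin2 = nn.Linear(4, 4)
parametrize.register_parametrization(lin2, "weight", Scale())
# The underlying tensor now lives at lin2.parametrizations.weight.original.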

albertz commented Jan 3, 2022

Note, our parameters are now always nn.Parameter instances, which derive from Layer, and thus LayerRef.

Maybe they should have some initial attribute.
(The implementation behind that would require some extra effort. We might want to map this to the init layer option of the underlying VariableLayer. However, we cannot pass other layer refs there! So this needs some extension on the RETURNN side. I'm not exactly sure how difficult this would be.)

Then overwriting the initial attribute on parameters would be straight-forward via such apply API.
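
A hypothetical sketch of how that could then look via apply (the weight/bias naming and the nn.init.Glorot helper are assumptions, not an existing API):

def set_default_init(module: nn.Module):
    if isinstance(module, nn.Linear):
        module.weight.initial = nn.init.Glorot()  # hypothetical initializer helper
        module.bias.initial = 0.0

net.apply(set_default_init)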

albertz commented Jan 3, 2022

L2 on weights would probably also be done via apply. And this should be straightforward using mark_as_loss (#56). Maybe mark_as_loss needs some extension or a flag to indicate that this is just for regularization, for nicer logs.
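
A hypothetical sketch (the nn.reduce/mark_as_loss usage follows the snippets discussed in this issue; the scale and the module/param naming are placeholders):

def add_weight_l2(module: nn.Module):
    if isinstance(module, nn.Linear):
        l2 = nn.reduce(module.weight ** 2.0, mode="sum", axis=module.weight.shape)
        (1e-4 * l2).mark_as_loss()  # maybe with a flag marking it as pure regularization

net.apply(add_weight_l2)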

albertz commented Jan 3, 2022

Weight dropout could be done just like the weight norm approach explained above.

albertz mentioned this issue Jan 5, 2022

albertz commented Jan 12, 2022

I wonder whether we need the param init as an argument to the modules at all. With such an argument, there would be two ways to overwrite the default:

linear = Linear(..., weight_init=...)

Vs:

linear = Linear(...)
linear.weight.initial = ...

But the second variant would be possible in any case. So, do we really need the first way then?

albertz commented Jan 12, 2022

Things like param init, weight dropout, weight norm, L2 on weights, etc. can all easily be handled in such a post-processing way, and recursively via such an apply API. So this doesn't need to be an argument of the module.

I also don't like it when there are too many arguments. So I would prefer to keep the argument list short.

Context managers are an alternative to arguments. Context managers might be added for param init or other regularization at some point (not really needed, though; and I would only do that if the behavior is simple and straightforward). Context managers are probably the way to go for dtype and device.

One remaining question is how to configure model-based regularization such as dropout (on activations). Dropout is by far the most common thing, but there might be other variants, and a complex model might also have various places where dropout could be added. If this is configured via module arguments, the argument lists could easily become bloated once the model becomes complex or other types of regularization come up in the future. So I don't like this either. Context managers might be somewhat tricky here as well, to get the behavior right and straightforward. However, we can maybe again do this via post-processing. Maybe modules could provide functions like set_dropout. Or maybe nn.scoped could return an object which allows for functions like register_forward_hook or register_forward_pre_hook where dropout could be added (see the sketch below).
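
A purely hypothetical sketch of the hook-based variant (register_forward_hook mirrors the PyTorch naming and does not exist here yet; the nn.dropout signature is also only assumed):

def output_dropout_hook(module, inputs, output):
    # applied to the module's output after each forward pass
    return nn.dropout(output, 0.1, axis=output.feature_dim)

encoder_layer.register_forward_hook(output_dropout_hook)  # hypothetical API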

albertz commented Jan 12, 2022

Another model-based regularization is stochastic depth (#99, #106). This can be implemented efficiently using Cond (#24) for train_flag() & random_uniform() > 0.5 as condition or so.

But the question here for this issue is again: what would the option look like? An argument to the module? Or some module method like set_stochastic_depth? Or via register_forward_hook on individual submodules?
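
A hypothetical sketch of the Cond-based variant from the first paragraph (the Cond/random_uniform/train_flag API shown here is only assumed, not final):

with nn.Cond(nn.train_flag() & (nn.random_uniform(()) > 0.5)) as cond:  # hypothetical API
    cond.true = x          # skip the block entirely (identity)
    cond.false = block(x)  # regular computation
y = cond.result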

albertz changed the title from "How to handle parameter initialization, regularization (L2, weight dropout, etc), maybe updater opts per-param" to "How to define the API for parameter initialization, regularization (L2, weight dropout, etc), maybe updater opts per-param" on Jan 17, 2022

albertz commented Jan 17, 2022

Note that param init also requires some further thought on the technical side: #92

albertz commented Jan 18, 2022

A lot of relevant discussion on the API for defining the param init can be found here: pytorch/pytorch#18182

albertz commented Jan 18, 2022

Related is also what param init defaults we should use. See #94.

albertz commented Feb 1, 2022

Related is TF arg_scope, which provides a context scope to overwrite the defaults of kwargs like initializer or regularizer; see e.g. the example from OpenAI weightnorm. But I have not really seen much adoption of this, and I think it was removed in TF 2.

albertz commented Feb 1, 2022

Related is also #96 on an explicit training loop and stages. When we implement that, the param init would live outside of the training loop, and probably it all becomes clearer.

albertz commented Feb 11, 2022

I think this is settled now. For param init, we can use Module.apply to apply it recursively to a whole model or submodel, or otherwise set it individually per param. For regularization on params, it is just the same.

albertz closed this as completed Feb 11, 2022

albertz commented Feb 28, 2022

I think we should leave this open until this is really implemented, or until we maybe have a separate issue specifically on the technical aspect.

albertz reopened this Feb 28, 2022

albertz commented Mar 20, 2022

Module.apply is implemented now. This allows for custom param init and many other things mentioned here.

L2 on params is still another aspect. We could just do sum(nn.reduce(param ** 2.0, mode="sum", axis=param.shape) for param in net.parameters()).mark_as_loss(). However, we know from the literature (e.g. AdamW) that it is better to have this decoupled from the loss optimization, and not use momentum for this part. Because we don't really want L2, we want weight decay. Weight decay is equivalent to L2 only with SGD. The cleanest/simplest approach would probably be something like explicit custom updates (#90).

albertz commented Mar 21, 2022

For L2 on params, we could also use the RETURNN mechanism directly, as an option to the VariableLayer. Specifically, this would also allow using the RETURNN option decouple_constraints directly.

albertz added a commit that referenced this issue Mar 22, 2022

albertz commented May 2, 2022

The Parameter.weight_decay property is implemented now.
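
Usage would then look roughly like this (a sketch; the value is just an example):

for param in net.parameters():
    param.weight_decay = 1e-4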

I think we can close this here for now.
