Design/Handling of dimension tags #17
We should clarify this because it is very fundamental to how all our code will look. The main question is to decide between these two options:

I think when we require them, we should also require them to be unique. I tend to prefer having them mandatory. This might make some parts a bit more verbose at first glance, but it should resolve any possible ambiguity on axes and make the code very explicit and clear. We might however introduce some shorter aliases, or maybe restructure things a bit. For example:

So in your config, it could look like:

(Just a draft.)
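To make this concrete, here is a purely hypothetical sketch of what mandatory dim tags could look like in a model definition. The `nn.FeatureDim` constructor, the `out_dim` argument, and the `nn.Module`/`@nn.scoped` style are taken from later in this thread; the import path and the overall structure of the example are my assumptions, not the draft referred to above.

```python
# Hypothetical sketch only; not the original draft from this comment.
from returnn_common import nn  # import path assumed

# Every dim is an explicit, named object instead of a plain int.
input_dim = nn.FeatureDim("input-feature", 40)
hidden_dim = nn.FeatureDim("hidden", 512)


class MyModel(nn.Module):
  def __init__(self):
    super().__init__()
    # out_dim instead of n_out, as discussed in this issue.
    self.layer = nn.Linear(out_dim=hidden_dim)

  @nn.scoped
  def __call__(self, x: nn.LayerRef) -> nn.Layer:
    return self.layer(x)
```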
Would this imply that "F" or "T" are also not allowed anymore?
Sorry, maybe I was being confusing here. This is a design decision we should take. We can either still allow … Or maybe we should just change the behavior (on RETURNN side or here in returnn-common): whenever …
Also, we would need to think about how …
Ah yes, good point. Maybe for the … There could be an … Or, when we do #47, the user could explicitly do it afterwards like:

Maybe that is cleaner?
That would be even better, yes. Then every module checking shapes (which ideally every module should do, I think) would have something like this on top:
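For illustration, a minimal sketch of what such a check at the top of a module could look like, assuming the module knows its expected `in_dim` and that the incoming tensor exposes its dim tags; the attribute used for that (`source.shape` as a set of dim tags) is an assumption, not an API confirmed by this thread.

```python
@nn.scoped
def __call__(self, source: nn.LayerRef) -> nn.Layer:
  # Hypothetical shape check: the expected input feature dim must be present.
  # `source.shape` as a set of dim tags is an assumed attribute.
  assert self.in_dim in source.shape
  ...
```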
This wouldn't work for intermediate states, right? (Also, when nesting modules, the outer module would also need to know the common dims to assert the output of the inner module.)
The … However, if you are the author of … We should be careful to not make it too verbose or too complicated. When the code looks complex at first glance, this is also bad. Another aspect: probably many modules will be lazy w.r.t. the input dim, just like it is standard in RETURNN and other frameworks. E.g. the user would specify …
So, I currently have this:

```python
from typing import Optional
# `nn` here is the returnn-common nn namespace.


class Linear(nn.Module):
  """
  Linear transformation.
  """

  def __init__(self, out_dim: nn.Dim, *, in_dim: Optional[nn.Dim] = None):
    super().__init__()
    self.out_dim = out_dim
    self.out_dim_inner = out_dim
    self.in_dim = in_dim
    self.weight = None  # type: Optional[nn.Parameter]
    self.bias = None  # type: Optional[nn.Parameter]
    if in_dim:
      self._lazy_init(in_dim)

  def _lazy_init(self, in_dim: nn.Dim):
    if self.in_dim:
      assert self.in_dim == in_dim
    else:
      self.in_dim = in_dim
      if in_dim == self.out_dim:
        self.out_dim_inner = self.out_dim.copy(same_as_self=False, description=f"{self}:out-dim-inner")
      self.weight = nn.Parameter((self.in_dim, self.out_dim_inner))
      self.bias = nn.Parameter((self.out_dim_inner,))

  @nn.scoped
  def __call__(self, source: nn.LayerRef) -> nn.Layer:
    self._lazy_init(source.dim)
    out = nn.dot(source, self.weight, reduce=self.in_dim) + self.bias
    if self.out_dim_inner != self.out_dim:
      out = nn.reinterpret_data(out, set_dim_tags={self.out_dim_inner: self.out_dim})
    return out
```

You see that I needed to introduce this `out_dim_inner`. I'm not really sure about the best way to solve this. I'm not really happy with this solution now. @Zettelkasten any suggestions or ideas?
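For context, a hypothetical usage sketch of the `Linear` above, showing both the lazy `in_dim` case and the `in_dim == out_dim` case that triggers `out_dim_inner`; the dim names and the inputs `x`, `y` are made up.

```python
feat_dim = nn.FeatureDim("feature", 512)

# x: some existing nn.LayerRef input.
# Lazy in_dim: weight and bias are only created on the first call, from source.dim.
linear = Linear(out_dim=feat_dim)
y = linear(x)

# in_dim == out_dim (square matrix): without out_dim_inner, the parameter shape
# (feat_dim, feat_dim) would contain the same dim tag twice, which is ambiguous.
square = Linear(out_dim=feat_dim, in_dim=feat_dim)
z = square(y)
```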
While further thinking about this, some thoughts and options:
Via: rwth-i6/returnn_common#17 (comment) Also: Data.get_axes_from_description, only unique matching for dim tags
Another thing: I wonder whether passing …
This is for dynamic sizes. For static dims, it is even more complicated, as you need to calculate the dim by hand (although maybe you could argue it is a good thing to have it explicit). Maybe the module should instead just create such a dim itself, and then return it? One question is how it should be returned. Just …? For example, it would then look like:
I think this is nicer. The case where it is known beforehand (or expected to match some other existing dim) is maybe valid, but the user could easily still do that by calling …
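As a hedged sketch of the two call styles being contrasted here, using `pool1d` (which this issue later points to as an example); the exact signature shown is an assumption.

```python
# Variant A: the caller creates the output spatial dim and passes it in.
out_spatial_dim = nn.SpatialDim("pooled-time")
out = nn.pool1d(
  x, in_spatial_dim=time_dim, out_spatial_dim=out_spatial_dim,
  pool_size=2, padding='same', mode='max')

# Variant B (the suggestion): the function creates the dim itself and returns it.
out, out_spatial_dim = nn.pool1d(
  x, in_spatial_dim=time_dim,
  pool_size=2, padding='same', mode='max')
```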
Or consider the … example:

```python
@nn.scoped
def __call__(self, inp: nn.LayerRef, *, in_spatial_dim: nn.Dim, out_spatial_dim: nn.Dim) -> nn.LayerRef:
  """forward"""
  in_spatial_dims = [in_spatial_dim, inp.feature_dim]
  in_dim = nn.FeatureDim("dummy-input-feature-dim", 1)
  x = nn.expand_dim(inp, dim=in_dim)
  for i, conv_layer in enumerate(self.conv_layers):
    out_spatial_dims = [nn.SpatialDim(f"conv-{i}-1"), nn.SpatialDim(f"conv-{i}-2")]
    x = conv_layer(x, in_dim=in_dim, in_spatial_dims=in_spatial_dims, out_spatial_dims=out_spatial_dims)
    in_spatial_dims = out_spatial_dims
    in_dim = conv_layer.out_dim
    x = self.activation(x)
    if self.pool_sizes and i < len(self.pool_sizes):
      x = nn.pool(
        x, in_dim=in_dim, in_spatial_dims=in_spatial_dims,
        pool_size=self.pool_sizes[i], padding='same', mode='max')
    if self.dropout:
      x = nn.dropout(x, axis=in_dim, dropout=self.dropout)
  out = nn.merge_dims(x, axes=in_spatial_dims, out_dim=out_spatial_dim)
  return out
```

With the suggestion, it would look like:

```python
@nn.scoped
def __call__(self, inp: nn.LayerRef, *, in_spatial_dim: nn.Dim) -> Tuple[nn.LayerRef, nn.Dim]:
  """forward"""
  in_spatial_dims = [in_spatial_dim, inp.feature_dim]
  in_dim = nn.FeatureDim("dummy-input-feature-dim", 1)
  x = nn.expand_dim(inp, dim=in_dim)
  for i, conv_layer in enumerate(self.conv_layers):
    x, in_spatial_dims = conv_layer(x, in_dim=in_dim, in_spatial_dims=in_spatial_dims)
    in_dim = conv_layer.out_dim
    x = self.activation(x)
    if self.pool_sizes and i < len(self.pool_sizes):
      x = nn.pool(
        x, in_dim=in_dim, in_spatial_dims=in_spatial_dims,
        pool_size=self.pool_sizes[i], padding='same', mode='max')
    if self.dropout:
      x = nn.dropout(x, axis=in_dim, dropout=self.dropout)
  out, out_spatial_dim = nn.merge_dims(x, axes=in_spatial_dims)
  return out, out_spatial_dim
```

@Zettelkasten opinions?

Edit: This is implemented like this now.
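A brief caller-side sketch of the second variant (the module and dim names here are placeholders): the conv frontend now hands back the downsampled spatial dim it created, so the caller never has to construct it by hand.

```python
# `frontend` stands for a module with the __call__ shown above.
out, out_spatial_dim = frontend(inp, in_spatial_dim=time_dim)
# out_spatial_dim is the merged/downsampled spatial dim created inside the module;
# it can be passed on to whatever consumes `out` next.
```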
This was wrapping reinterpret_data with set_axes. There should be no layer in returnn-common which would depend on this, so this should never be needed. If there are cases where this is needed, we probably could and should fix them such that it is not needed. #17
The last suggestion was implemented now, namely that functions and modules return a potentially new dim tag, and the argument (e.g. …
Remaining is the issue with …
Another external data point is Mesh TensorFlow, and the gpt-neo implementation. Examples: …

Note that a `Dimension` in MTF is simply `Dimension = collections.namedtuple("Dimension", ["name", "size"])`, where MTF also requires unique dimensions in a shape (see here).
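To make the comparison concrete, a small sketch using exactly that namedtuple; the dimension names and sizes are arbitrary.

```python
import collections

Dimension = collections.namedtuple("Dimension", ["name", "size"])

batch = Dimension("batch", 32)
length = Dimension("length", 128)
model = Dimension("model", 512)

# An MTF shape is an ordered collection of such named dimensions, and the names
# within one shape must be unique, so e.g. `model` could not appear twice.
shape = [batch, length, model]
```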
Solution for ambiguous dim tags, e.g. in VariableLayer for square matrix. Via: rwth-i6/returnn_common#17 (comment)
Now we need to adapt the code here accordingly (get rid of `out_dim_inner`).
This is all implemented now.
Like batch dim, spatial dims (with dynamic lengths), or static dims (named, or also unnamed).

There is `Dim` (earlier `DimensionTag`) in RETURNN. Directly use that, or wrap it somehow?

Should this (the batch dim) include beam information, or be separate from this?

Relevant for all layers which define some shape or dimension (e.g. `Const`, `Variable`).

Should this be enforced, i.e. no simple `int` allowed in `n_out` or so, but always a `Dim` object?

And maybe better use `out_dim` instead of `n_out` (consistent with rwth-i6/returnn#597).

Edit: It was decided to make `nn.Dim` mandatory, and to use `out_dim` instead of `n_out`.

Very related is this issue on the RETURNN side on explicit dim tags: rwth-i6/returnn#597

Related is also whether we want unique dim tags (#48, rwth-i6/returnn#632).

This issue covers multiple aspects:

- We use RETURNN's `Dim`. We inherit all its logic on equality etc. We also have `FeatureDim`, `SpatialDim` from RETURNN.
- Dim tags (`Dim` instances) are mandatory for any shape or size.
- `Tensor.verify_out_shape` for easy verification (see the sketch below).
- `in_dim == out_dim`, square matrices: Design/Handling of dimension tags #17 (comment), Solution for ambiguous dim tags returnn#871.
- Instead of an `out_spatial_dim` argument, a new spatial dim gets returned. See `pool1d` for example.
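As referenced in the list above, a hedged sketch of `Tensor.verify_out_shape`-style verification; the dim names and the `linear(x)` call are placeholders, and the exact call convention (a set of dim tags, mirroring RETURNN's `out_shape` check) is an assumption.

```python
time_dim = nn.SpatialDim("time")
feat_dim = nn.FeatureDim("feature", 512)

# `linear` and `x` stand for some module and input, e.g. the Linear module above.
out = linear(x)
# Verify that the output carries exactly the expected set of dim tags.
# nn.batch_dim is assumed to be the global batch dim tag.
out.verify_out_shape({nn.batch_dim, time_dim, feat_dim})
```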