using `sum()` causes dtype mismatch #135
Comments
Yea. So, obviously we cannot change how the builtin `sum()` behaves.

We could add an optimization on the returnn-common side which automatically reduces `x + 0` to just `x`.

We could extend RETURNN to handle such mixed int/float combinations with automatic upcasting.

Not sure which way is best or which you would prefer. I think the returnn-common optimization would be nice anyway and would already solve this.
Note that such upcasting (promotion) rules are well defined in PyTorch/JAX/NumPy, see https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype. I think TF in its normal behavior actually does not have this.
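For illustration, a minimal check of those promotion rules in PyTorch (assuming a local `torch` install; this snippet is not part of the issue itself): combining an int32 tensor with a float32 tensor yields float32, which is exactly the promotion RETURNN does not perform here.

```python
import torch

# Pure dtype promotion, no values involved:
assert torch.promote_types(torch.int32, torch.float32) == torch.float32

# The same rule applies when actually combining tensors, so an
# integer literal 0 would not silently force an int result:
x = torch.zeros(3, dtype=torch.int32)
y = torch.ones(3, dtype=torch.float32)
print((x + y).dtype)  # torch.float32
```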
Okay, then let's start with the optimizations like `x + 0 -> x`, `x - 0 -> x`, `x * 1 -> x` and so on... I think those are a good idea anyway.
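A rough sketch of what such an identity-elimination rule could look like, in plain Python with a hypothetical `simplify_combine` helper (this is not the actual returnn-common implementation, just an illustration of the `x + 0` / `x * 1` rules):

```python
def simplify_combine(kind, a, b):
    """Return the surviving operand if the combination is an identity,
    otherwise None (meaning: build the real combine operation).

    `kind` is one of "add", "sub", "mul"; plain Python numbers
    represent constants, anything else stands for a tensor/layer.
    """
    def is_const(x, value):
        return isinstance(x, (int, float)) and x == value

    if kind == "add":
        if is_const(a, 0):
            return b  # 0 + x -> x
        if is_const(b, 0):
            return a  # x + 0 -> x
    elif kind == "sub":
        if is_const(b, 0):
            return a  # x - 0 -> x
    elif kind == "mul":
        if is_const(a, 1):
            return b  # 1 * x -> x
        if is_const(b, 1):
            return a  # x * 1 -> x
    return None  # no simplification applies


layer = object()  # stands in for a tensor/layer
assert simplify_combine("add", layer, 0) is layer      # x + 0 -> x
assert simplify_combine("mul", 1, layer) is layer      # 1 * x -> x
assert simplify_combine("add", layer, layer) is None   # nothing to simplify
```

With such a rule in place, the implicit `0` that Python's builtin `sum()` uses as its start value never reaches RETURNN, so no mixed int/float combine layer is created in the first place.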
Ok, done.
Btw, if you want to have an L2 loss, I would not recommend doing it this way. Rather, see #59. And you should also enable …
Writing e.g.
```python
# nn is returnn_common.nn; net is the model (an nn.Module)
l2loss = sum(nn.reduce(param ** 2.0, axis=list(param.shape), mode="sum") for param in net.parameters())
```
This causes a dtype mismatch when running the network (surprisingly, not during construction with `returnn_common`). The error on the RETURNN side is justified, given the generated layers.
As the value is `0` and not `0.0`, RETURNN assumes int32 instead of float32 (float32 is the correct dtype of the `reduce` output). The `0` is probably added by the first internal call inside `sum()`.

I am not sure how to deal with the correct dtypes here, as when calling `sum()` we may not know exactly which dtype we want to have. One thing I could imagine is that we infer the dtype from `self` if `other` is a zero, and extend `nn.convert_to_tensor` to take a dtype parameter.
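To make the source of the stray integer explicit: Python's builtin `sum()` starts from the int literal `0` unless an explicit start value is given, so the very first combination is `0 + <first tensor>`. Below is a minimal, self-contained sketch (toy `Tensor` class and a hypothetical `_convert_to_tensor` with a dtype hint, not the actual returnn-common code) of the idea described above: infer the dtype from `self` when the other operand is a plain Python zero.

```python
# Python's builtin sum() is roughly:
#   result = 0
#   for x in iterable:
#       result = result + x
# so the very first addition is `0 + <tensor>`, which is where the
# int32 constant in this issue comes from.

class Tensor:
    """Toy stand-in for a returnn-common tensor; only tracks the dtype."""

    def __init__(self, dtype):
        self.dtype = dtype

    def __add__(self, other):
        other = _convert_to_tensor(other, dtype_hint=self.dtype)
        # The real library would create a combine layer here;
        # we only check that the dtypes agree.
        assert other.dtype == self.dtype, "dtype mismatch"
        return Tensor(self.dtype)

    __radd__ = __add__  # handles `0 + tensor` coming from builtin sum()


def _convert_to_tensor(value, *, dtype_hint=None):
    """Hypothetical converter with a dtype parameter, as suggested in the
    issue: a plain Python constant takes the hinted dtype instead of
    defaulting to int32 for integer literals."""
    if isinstance(value, Tensor):
        return value
    if isinstance(value, (int, float)):
        default = "int32" if isinstance(value, int) else "float32"
        return Tensor(dtype_hint or default)
    raise TypeError(f"cannot convert {type(value)}")


# With the dtype hint in place, summing float32 tensors no longer trips
# over the implicit integer start value of sum():
parts = [Tensor("float32") for _ in range(3)]
total = sum(parts)
print(total.dtype)  # float32
```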