[Relay] [PyTorch] Add aten::broadcast_tensors #11863
Conversation
```python
import torch

res_shape = list(torch.broadcast_shapes(*[self.infer_shape(t) for t in tensor_list]))
return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list]
```
oh interesting. Have you verified that this is what `broadcast_tensors` does? It's interesting that we are actually invoking a PyTorch function for conversion purposes.
Thanks for the review! This is the definition of `broadcast_shapes` in PyTorch's docs:

> This is equivalent to `torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape` but avoids the need to create intermediate tensors. This is useful for broadcasting tensors of common batch shape but different rightmost shape, e.g. to broadcast mean vectors with covariance matrices.

This workaround avoids the need to create intermediate tensors on our end. `broadcast_tensors` works like the following:
```python
a.shape = [1, 1, 2]
b.shape = [1, 3, 1]
c.shape = [4, 1, 1]
# x, y, z = broadcast_tensors(a, b, c)
x = a.broadcast_to(shape=[4, 3, 2])
y = b.broadcast_to(shape=[4, 3, 2])
z = c.broadcast_to(shape=[4, 3, 2])
```
meaning the resulting tensors of `broadcast_tensors` will all have the same shape `[4, 3, 2]`. This is a bit different from what we discussed yesterday, but it matches the standard broadcasting semantics.
reference: https://numpy.org/doc/stable/user/basics.broadcasting.html
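For reference, the behavior described above can be checked directly in PyTorch with a minimal standalone snippet (independent of the TVM converter code; the shapes are the ones from the example):

```python
import torch

a = torch.zeros(1, 1, 2)
b = torch.zeros(1, 3, 1)
c = torch.zeros(4, 1, 1)

# broadcast_tensors materializes all inputs expanded to the common shape
x, y, z = torch.broadcast_tensors(a, b, c)

# broadcast_shapes computes the same common shape without creating
# intermediate tensors, which is what the converter relies on
shape = torch.broadcast_shapes(a.shape, b.shape, c.shape)
```

Here `x`, `y`, and `z` all come out with shape `[4, 3, 2]`, and `broadcast_shapes` returns the same shape directly from the input shapes.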
Force-pushed from 5f2b03a to 12d8be3
* add aten::broadcast_tensors
* add entry
* fix test
This PR intends to add the `aten::broadcast_tensors` op to the PyTorch frontend.

@masahi @junrushao1994 @zxybazh