diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ba0d025026f9..3e0bf64e4c1c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1971,6 +1971,13 @@ def expand_as(self, inputs, input_types): target = _op.cast(target, t0) return _op.broadcast_to_like(inputs[0], target) + def broadcast_tensors(self, inputs, input_types): + tensor_list = inputs[0] + import torch + + res_shape = list(torch.broadcast_shapes(*[self.infer_shape(t) for t in tensor_list])) + return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list] + def Bool(self, inputs, input_types): assert len(inputs) == 1 return inputs[0] @@ -3189,6 +3196,7 @@ def create_convert_map(self): "aten::upsample_trilinear3d": self.make_upsample3d("linear"), "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), "aten::expand_as": self.expand_as, + "aten::broadcast_tensors": self.broadcast_tensors, "aten::lt": self.make_elemwise("less"), "aten::gt": self.make_elemwise("greater"), "aten::le": self.make_elemwise("less_equal"), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9609008c9969..e4cb6354c017 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1791,6 +1791,28 @@ def forward(self, *args): verify_model(Expand2().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_broadcast_tensors(): + torch.set_grad_enabled(False) + + class BroadCastTensors1(Module): + def forward(self, x, y): + return torch.broadcast_tensors(x, y) + + x = torch.arange(3).view(1, 1, 3) + y = torch.arange(2).view(1, 2, 1) + verify_model(BroadCastTensors1().float().eval(), input_data=[x, y]) + + class BroadCastTensors2(Module): + def forward(self, x, y, z): + return torch.broadcast_tensors(x, y, z) + + x = torch.arange(3).view(1, 1, 3) + y = torch.arange(2).view(1, 2, 1) + z = torch.arange(4).view(4, 1, 1) + verify_model(BroadCastTensors2().float().eval(), input_data=[x, y, z]) + + @tvm.testing.uses_gpu def test_forward_pow(): torch.set_grad_enabled(False)