Skip to content

Commit

Permalink
[Relay][Pytorch] Add aten::new_ones, aten::new_full, aten::fill_, ate…
Browse files Browse the repository at this point in the history
…n::pad, aten::reshape_as and atem::empty_like (#11896)

* add new ops

* fix pad

* fix pad

* remove pad

* fix CI

* remove doc

* fix fill_

* add tests
  • Loading branch information
Yuanjing Shi authored Jun 29, 2022
1 parent 7ef6811 commit 41c94b2
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 0 deletions.
55 changes: 55 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,21 @@ def ones_like(self, inputs, input_types):

return out

def new_ones(self, inputs, input_types):
size = inputs[1]

import torch

if not isinstance(size, (_expr.Expr, list, tuple, torch.Size, np.ndarray)):
msg = "Data type %s could not be parsed in ones op" % (type(size))
raise AssertionError(msg)

if inputs[2] is not None:
dtype = _convert_dtype_value(inputs[2])
else:
dtype = input_types[0]
return self.full_impl(size, 1, dtype)

def zeros(self, inputs, input_types):
data = inputs[0]

Expand Down Expand Up @@ -765,6 +780,28 @@ def full_like(self, inputs, input_types):

return out

def new_full(self, inputs, input_types):
data = inputs[1]
fill_value = inputs[2]
import torch

if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)):
msg = "Data type %s could not be parsed in full op" % (type(data))
raise AssertionError(msg)

if inputs[3] is not None: # dtype given
dtype = _convert_dtype_value(inputs[3])
else:
# if dtype is None, use the dtype of the input tensor
dtype = self.infer_type(input[0])

return self.full_impl(data, fill_value, dtype)

def fill_(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
return self.full_impl(self.infer_shape(data), fill_value, input_types[0])

def linspace(self, inputs, input_types):
start = inputs[0]
stop = inputs[1]
Expand Down Expand Up @@ -1425,6 +1462,11 @@ def reshape(self, inputs, input_types):
new_shape = tmp_shape
return _op.transform.reshape(data, new_shape)

def reshape_as(self, inputs, input_types):
data = inputs[0]
new_shape = self.infer_shape(inputs[1])
return _op.transform.reshape(data, new_shape)

def pixel_shuffle(self, inputs, input_types):
data = inputs[0]
upscale_factor = inputs[1]
Expand Down Expand Up @@ -2400,6 +2442,14 @@ def empty(self, inputs, input_types):
shape = inputs[0]
return _op.zeros(shape, _convert_dtype_value(inputs[1]))

def empty_like(self, inputs, input_types):
shape = self.infer_shape(inputs[0])
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = input_types[0]
return _op.zeros(shape, dtype)

def bincount(self, inputs, input_types):
data = inputs[0]
weights = inputs[1]
Expand Down Expand Up @@ -3119,8 +3169,11 @@ def create_convert_map(self):
"aten::ones_like": self.ones_like,
"aten::zeros": self.zeros,
"aten::zeros_like": self.zeros_like,
"aten::new_ones": self.new_ones,
"aten::full": self.full,
"aten::full_like": self.full_like,
"aten::new_full": self.new_full,
"aten::fill_": self.fill_,
"aten::linspace": self.linspace,
"aten::reciprocal": self.reciprocal,
"aten::repeat": self.repeat,
Expand Down Expand Up @@ -3186,6 +3239,7 @@ def create_convert_map(self):
"aten::size": self.size,
"aten::view": self.view,
"aten::reshape": self.reshape,
"aten::reshape_as": self.reshape_as,
"aten::clone": self.clone,
"aten::log_softmax": self.log_softmax,
"aten::sigmoid": self.sigmoid,
Expand Down Expand Up @@ -3305,6 +3359,7 @@ def create_convert_map(self):
"aten::tensor": self.identity, # used for example in tensor(1.0)
"aten::numel": self.numel,
"aten::empty": self.empty,
"aten::empty_like": self.empty_like,
"aten::bincount": self.bincount,
"aten::scatter_add": self.scatter_add,
"aten::__not__": self.logical_not,
Expand Down
75 changes: 75 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,28 @@ def visit(op):
torch.cuda.empty_cache()


def verify_model_with_input(test_func, input_data, input_dict={}):
baseline_outputs = test_func(*input_data)
trace = torch.jit.trace(test_func, [input.clone() for input in input_data])
input_names = ["input{}".format(idx) for idx, inp in enumerate(input_data)]
input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes, {})
with tvm.transform.PassContext(opt_level=3):
for target in ["llvm", "cuda"]:
if not tvm.runtime.enabled(target):
continue
dev = tvm.device(target, 0)
lib = relay.build(mod, target=target, params=params)
relay_model = graph_executor.GraphModule(lib["default"](dev))
for name, value in input_dict.items():
relay_model.set_input(name, value)
relay_model.run()

compiled_output = relay_model.get_output(0).numpy()
assert_shapes_match(baseline_outputs, compiled_output)
tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=1e-5, atol=1e-5)


# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
Expand Down Expand Up @@ -1275,6 +1297,16 @@ def forward(self, x):
verify_model(Reshape3(), input_data=torch.randn(2, 3, 4))


@tvm.testing.uses_gpu
def test_forward_reshape_as():
def test_func(input_tensor, other_tensor):
return input_tensor.reshape_as(other_tensor)

input_data = [torch.rand([2, 1, 10, 1, 10]), torch.rand([2, 1, 10, 10])]

verify_model_with_input(test_func, input_data, {"input0": input_data[0]})


@tvm.testing.uses_gpu
def test_flatten():
def _test_flatten(start_dim, end_dim):
Expand Down Expand Up @@ -2961,6 +2993,17 @@ def forward(self, *args):
verify_model(OnesLike3().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_new_ones():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

def test_func(input_tensor):
return input_tensor.new_ones([3, 10, 10])

verify_model_with_input(test_func, [torch.rand(input_shape).float()])


@tvm.testing.uses_gpu
def test_forward_zeros():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -3034,6 +3077,24 @@ def forward(self, *args):
verify_model(FullLike3().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_new_full():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

def test_func(input_tensor):
return input_tensor.new_full([2, 3], 1)

verify_model_with_input(test_func, [torch.rand(input_shape).float()])


def test_forward_fill_():
def test_func(x):
return x.fill_(3)

verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])


@tvm.testing.uses_gpu
def test_forward_linspace():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -3752,6 +3813,20 @@ def forward(self, data):
verify_script_model(Numel(), [(3, 5, 8)], targets)


def test_empty():
def test_func():
return torch.empty([1, 3, 10, 10])

verify_model_with_input(test_func, [])


def test_empty_like():
def test_func(data):
return torch.empty_like(data)

verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])


def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
Expand Down

0 comments on commit 41c94b2

Please sign in to comment.