Skip to content

Commit

Permalink
update unittest for interpolate
Browse files Browse the repository at this point in the history
  • Loading branch information
tink2123 committed Jan 20, 2023
1 parent d706fe7 commit 324fccb
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 42 deletions.
38 changes: 22 additions & 16 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1419,16 +1419,18 @@ static void Interpolate1DInferShapeCheck(
if (scale_tensor) {
auto scale_tensor_dim = scale_tensor.dims();
PADDLE_ENFORCE_EQ(
scale_tensor_dim.size(),
1,
scale_tensor_dim.size() == 1 || scale_tensor_dim.size() == 0,
true,
phi::errors::InvalidArgument(
"Scale's dimension size must be 1, but got dimension = %d .",
"Scale's dimension size must be 1 or 0, but got dimension = %d .",
scale_tensor_dim.size()));
PADDLE_ENFORCE_EQ(scale_tensor_dim[0],
1,
phi::errors::InvalidArgument(
"Scale's shape must be 1, but got shape = %d .",
scale_tensor_dim[0]));
if (scale_tensor_dim.size()) {
PADDLE_ENFORCE_EQ(scale_tensor_dim[0],
1,
phi::errors::InvalidArgument(
"Scale's shape must be 1, but got shape = %d .",
scale_tensor_dim[0]));
}
out_w_tmp = -1;
} else {
if (scale.size() > 0) {
Expand Down Expand Up @@ -1555,11 +1557,15 @@ static void Interpolate2DInferShapeCheck(
"Scale's dimension size must be 1 or 0, but got dimension = %d .",
scale_tensor_dim.size()));

PADDLE_ENFORCE_EQ(scale_tensor_dim[0] == 2 || scale_tensor_dim[0] == 1,
true,
phi::errors::InvalidArgument(
"Scale's shape must be 2 or 1, but got shape = %d .",
scale_tensor_dim[0]));
if (scale_tensor_dim.size() == 1) {
PADDLE_ENFORCE_EQ(
scale_tensor_dim[0] == 2 || scale_tensor_dim[0] == 1,
true,
phi::errors::InvalidArgument(
"Scale's shape must be 2 or 1, but got shape = %d .",
scale_tensor_dim[0]));
}

out_h_tmp = -1;
out_w_tmp = -1;
} else {
Expand Down Expand Up @@ -1692,10 +1698,10 @@ static void Interpolate3DInferShapeCheck(
if (scale_tensor) {
auto scale_tensor_dim = scale_tensor.dims();
PADDLE_ENFORCE_EQ(
scale_tensor_dim.size(),
1,
scale_tensor_dim.size() == 1 || scale_tensor_dim.size() == 0,
true,
phi::errors::InvalidArgument(
"Scale's dimension size must be 1, but got size = %d .",
"Scale's dimension size must be 1 or 0, but got size = %d .",
scale_tensor_dim.size()));
PADDLE_ENFORCE_EQ(scale_tensor_dim[0] == 3 || scale_tensor_dim[0] == 1,
true,
Expand Down
14 changes: 7 additions & 7 deletions paddle/phi/kernels/funcs/interpolate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ inline std::vector<int> get_new_shape(
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(
tensor->dims() == phi::make_ddim({1}) || tensor->dims() == phi::make_ddim({}),
true,
errors::InvalidArgument("The shape of dimension tensor should be [1] or [],"
"but received d%.",
tensor->dims()));

PADDLE_ENFORCE_EQ(tensor->dims() == phi::make_ddim({1}) ||
tensor->dims() == phi::make_ddim({}),
true,
errors::InvalidArgument(
"The shape of dimension tensor should be [1] or [],"
"but received d%.",
tensor->dims()));

#ifdef PADDLE_WITH_XPU
if (tensor->place().GetType() == phi::AllocationType::XPU) {
Expand Down
38 changes: 23 additions & 15 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,12 +1174,10 @@ def test_atan2(self):
self.assertEqual(x2.grad.numpy(), 0)

def test_interpolate(self):
import numpy as np

from paddle.nn.functional import interpolate

input_data = np.random.random((2, 3, 6, 6)).astype("float32")
input_x = paddle.to_tensor(input_data)
input_x = paddle.rand([2, 3, 6, 6])
input_x.stop_gradient = False
origin_result = interpolate(
x=input_x, size=[12, 12], mode="bilinear", align_corners=False
)
Expand All @@ -1191,6 +1189,10 @@ def test_interpolate(self):
out1 = interpolate(
x=input_x, size=output_size, mode="bilinear", align_corners=False
)
out1.backward()

self.assertEqual(out1.shape, [2, 3, 12, 12])
self.assertEqual(input_x.grad.shape, [2, 3, 6, 6])

scale_1 = [paddle.full([], 2), paddle.full([], 2)]
out2 = interpolate(
Expand All @@ -1199,6 +1201,10 @@ def test_interpolate(self):
mode="bilinear",
align_corners=False,
)
out2.backward()

self.assertEqual(out2.shape, [2, 3, 12, 12])
self.assertEqual(input_x.grad.shape, [2, 3, 6, 6])

scale_2 = paddle.full([], 2)
out3 = interpolate(
Expand All @@ -1207,8 +1213,10 @@ def test_interpolate(self):
mode="bilinear",
align_corners=False,
)
out3.backward()

out1.backward()
self.assertEqual(out3.shape, [2, 3, 12, 12])
self.assertEqual(input_x.grad.shape, [2, 3, 6, 6])

np.testing.assert_allclose(
origin_result.numpy(), out1.numpy(), rtol=1e-05
Expand Down Expand Up @@ -1852,38 +1860,38 @@ def test_atan2(self):

@prog_scope()
def test_interpolate(self):
import numpy as np

from paddle.nn.functional import interpolate

input_data = np.random.random((2, 3, 6, 6)).astype("float32")
input_x = paddle.to_tensor(input_data)
input_x = paddle.rand([2, 3, 6, 6])
input_x.stop_gradient = False

output_size = [
paddle.full([], 12, dtype="int32"),
paddle.full([], 12, dtype="int32"),
]

out1 = interpolate(
x=input_x, size=output_size, mode="bilinear", align_corners=False
)

paddle.static.append_backward(out1)
paddle.static.append_backward(out1.sum())
prog = paddle.static.default_main_program()
res1 = self.exe.run(prog, feed={}, fetch_list=[out1])
res1 = self.exe.run(prog, feed={}, fetch_list=[out1, input_x.grad_name])

scale_1 = np.array(2)
scale_1 = paddle.full([], 2)
out2 = interpolate(
x=input_x,
scale_factor=scale_1,
mode="bilinear",
align_corners=False,
)
paddle.static.append_backward(out2)
paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program()
res2 = self.exe.run(prog, feed={}, fetch_list=[out2])
res2 = self.exe.run(prog, feed={}, fetch_list=[out2, input_x.grad_name])

self.assertEqual(res1[0].shape, (2, 3, 12, 12))
self.assertEqual(res1[1].shape, (2, 3, 6, 6))
self.assertEqual(res2[0].shape, (2, 3, 12, 12))
self.assertEqual(res2[1].shape, (2, 3, 6, 6))


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
Expand Down
7 changes: 3 additions & 4 deletions python/paddle/nn/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,10 @@ def _is_list_or_turple_(data):
out_shape = list(out_shape.numpy())
else:
out_shape = list(out_shape)

for i, dim in enumerate(out_shape):
if isinstance(dim, Variable) and dim.shape != []:
out_shape[i] = dim.numpy()[0]
elif isinstance(dim, Variable) and dim.shape == []:
out_shape[i] = dim.numpy()
if isinstance(dim, Variable):
out_shape[i] = dim.numpy().item()
if not (_is_list_or_turple_(out_shape)):
raise TypeError("size should be a list or tuple or Variable.")
# Validate the shape
Expand Down

0 comments on commit 324fccb

Please sign in to comment.