Skip to content

Commit

Permalink
Add nearest_interp/v2 int8 and uint8 support (#37985)
Browse files Browse the repository at this point in the history
  • Loading branch information
wozna authored Dec 22, 2021
1 parent abb07f3 commit 56e2a6a
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 10 deletions.
8 changes: 6 additions & 2 deletions paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,15 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;

REGISTER_OP_KERNEL(nearest_interp, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);
ops::InterpolateMKLDNNKernel<float>,
ops::InterpolateMKLDNNKernel<int8_t>,
ops::InterpolateMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);

REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);
ops::InterpolateMKLDNNKernel<float>,
ops::InterpolateMKLDNNKernel<int8_t>,
ops::InterpolateMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci

Expand Down Expand Up @@ -66,13 +63,17 @@ class TestNearestInterpMKLDNNOp(OpTest):
def init_test_case(self):
pass

def init_data_type(self):
pass

def setUp(self):
self.op_type = "nearest_interp"
self.interp_method = 'nearest'
self._cpu_only = True
self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1
self.out_w = 1
Expand All @@ -81,8 +82,15 @@ def setUp(self):
self.actual_shape = None

self.init_test_case()
self.init_data_type()

if self.dtype == np.float32:
input_np = np.random.random(self.input_shape).astype(self.dtype)
else:
init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10)
input_np = np.random.randint(init_low, init_high,
self.input_shape).astype(self.dtype)

input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
Expand Down Expand Up @@ -162,6 +170,35 @@ def init_test_case(self):
self.scale = 0.


def create_test_class(parent):
class TestFp32Case(parent):
def init_data_type(self):
self.dtype = np.float32

class TestInt8Case(parent):
def init_data_type(self):
self.dtype = np.int8

class TestUint8Case(parent):
def init_data_type(self):
self.dtype = np.uint8

TestFp32Case.__name__ = parent.__name__
TestInt8Case.__name__ = parent.__name__
TestUint8Case.__name__ = parent.__name__
globals()[parent.__name__] = TestFp32Case
globals()[parent.__name__] = TestInt8Case
globals()[parent.__name__] = TestUint8Case


create_test_class(TestNearestInterpMKLDNNOp)
create_test_class(TestNearestInterpOpMKLDNNNHWC)
create_test_class(TestNearestNeighborInterpMKLDNNCase2)
create_test_class(TestNearestNeighborInterpCase3)
create_test_class(TestNearestNeighborInterpCase4)
create_test_class(TestNearestInterpOpMKLDNNNHWC)
create_test_class(TestNearestNeighborInterpSame)

if __name__ == "__main__":
from paddle import enable_static
enable_static()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci

Expand Down Expand Up @@ -66,13 +63,17 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
def init_test_case(self):
pass

def init_data_type(self):
pass

def setUp(self):
self.op_type = "nearest_interp_v2"
self.interp_method = 'nearest'
self._cpu_only = True
self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1
self.out_w = 1
Expand All @@ -81,8 +82,15 @@ def setUp(self):
self.actual_shape = None

self.init_test_case()
self.init_data_type()

if self.dtype == np.float32:
input_np = np.random.random(self.input_shape).astype(self.dtype)
else:
init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10)
input_np = np.random.randint(init_low, init_high,
self.input_shape).astype(self.dtype)

input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
Expand Down Expand Up @@ -178,6 +186,34 @@ def init_test_case(self):
self.out_size = np.array([65, 129]).astype("int32")


def create_test_class(parent):
class TestFp32Case(parent):
def init_data_type(self):
self.dtype = np.float32

class TestInt8Case(parent):
def init_data_type(self):
self.dtype = np.int8

class TestUint8Case(parent):
def init_data_type(self):
self.dtype = np.uint8

TestFp32Case.__name__ = parent.__name__
TestInt8Case.__name__ = parent.__name__
TestUint8Case.__name__ = parent.__name__
globals()[parent.__name__] = TestFp32Case
globals()[parent.__name__] = TestInt8Case
globals()[parent.__name__] = TestUint8Case


create_test_class(TestNearestInterpV2MKLDNNOp)
create_test_class(TestNearestInterpOpV2MKLDNNNHWC)
create_test_class(TestNearestNeighborInterpV2MKLDNNCase2)
create_test_class(TestNearestNeighborInterpV2MKLDNNCase3)
create_test_class(TestNearestNeighborInterpV2MKLDNNCase4)
create_test_class(TestNearestNeighborInterpV2MKLDNNSame)

if __name__ == "__main__":
from paddle import enable_static
enable_static()
Expand Down

0 comments on commit 56e2a6a

Please sign in to comment.