Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry pick #43088 #40664] Add float16 to fake quantize/dequantize OP #43689

Merged
merged 2 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions paddle/fluid/operators/fake_dequantize_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ limitations under the License. */

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
using float16 = paddle::platform::float16;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
ops::FakeDequantizeMaxAbsKernel<CUDA, double>,
ops::FakeDequantizeMaxAbsKernel<CUDA, float16>);
REGISTER_OP_CUDA_KERNEL(
fake_channel_wise_dequantize_max_abs,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>);
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float16>);
15 changes: 10 additions & 5 deletions paddle/fluid/operators/fake_quantize_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@ namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
using float16 = paddle::platform::float16;
REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
ops::FakeQuantizeAbsMaxKernel<CUDA, float>,
ops::FakeQuantizeAbsMaxKernel<CUDA, float16>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max,
ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>,
ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float16>);
REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float16>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>,
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float16>);
REGISTER_OP_CUDA_KERNEL(
fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float16>);
REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleKernel<CUDA, float>,
ops::MovingAverageAbsMaxScaleKernel<CUDA, float16>);
Expand Down
86 changes: 57 additions & 29 deletions paddle/fluid/operators/fake_quantize_op.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,24 @@ limitations under the License. */
#endif // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_

#include <string>

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

template <typename T>
struct QuantizeDataType {
using type = T;
};

template <>
struct QuantizeDataType<paddle::platform::float16> {
using type = float;
};

template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x;
Expand Down Expand Up @@ -87,10 +98,12 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
extern __shared__ T shared_max_data[];
extern __shared__ char* shared_max_data_tmp[];
auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
T local_max_data = T(0);
for (int i = tid; i < channel_size; i += blockDim.x) {
T tmp = fabs(in_c[i]);
T tmp = static_cast<T>(
fabs(static_cast<typename QuantizeDataType<T>::type>(in_c[i])));
if (tmp > local_max_data) {
local_max_data = tmp;
}
Expand All @@ -112,7 +125,8 @@ template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
const int cin, const int cout,
T* out) {
extern __shared__ T shared_max_data[];
extern __shared__ char* shared_max_data_tmp[];
auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
int cout_wh_size = n / cin;
int wh_size = n / (cin * cout);

Expand All @@ -121,7 +135,8 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
const T* in_current = in + tid * cout_wh_size + bid * wh_size;
T local_max_data = T(0);
for (int i = 0; i < wh_size; i++) {
T tmp = fabs(in_current[i]);
T tmp = static_cast<T>(
fabs(static_cast<typename QuantizeDataType<T>::type>(in_current[i])));
if (tmp > local_max_data) {
local_max_data = tmp;
}
Expand Down Expand Up @@ -203,14 +218,18 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;

T s = scale[0];
T inv_s = inverse(s);
using ComputeDataType = typename QuantizeDataType<T>::type;

ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
ComputeDataType inv_s = inverse(s);
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T x = in[i];
T v = x > s ? s : x;
ComputeDataType x = static_cast<ComputeDataType>(in[i]);
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out[i] = round(v);
v = bin_cnt_t * inv_s * v;
out[i] = static_cast<T>(round(v));
}
}

Expand All @@ -221,17 +240,19 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;

T s = scale[0];
T inv_s = inverse(s);
T bin_cnt_t = static_cast<T>(bin_cnt);
using ComputeDataType = typename QuantizeDataType<T>::type;

ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
ComputeDataType inv_s = inverse(s);
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T x = in[i];
ComputeDataType x = static_cast<ComputeDataType>(in[i]);
x = x > s ? s : x;
x = x < -s ? -s : x;
x = bin_cnt_t * inv_s * x;
x = static_cast<T>(round(static_cast<float>(x)));
out[i] = (x * s) / bin_cnt_t;
x = round(x);
out[i] = static_cast<T>((x * s) / bin_cnt_t);
}
}

Expand Down Expand Up @@ -285,15 +306,18 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;

T s = scale[blockIdx.x];
T inv_s = inverse(s);
using ComputeDataType = typename QuantizeDataType<T>::type;

ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
ComputeDataType inv_s = inverse(s);
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

for (int64_t i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v);
v = bin_cnt_t * inv_s * v;
out_c[i] = static_cast<T>(round(v));
}
}

Expand All @@ -303,14 +327,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
const T* in, const T* scale, const int bin_cnt, const int64_t n,
const int nScale, const int quant_stride, T* out) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using ComputeDataType = typename QuantizeDataType<T>::type;
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
T s = scale[(i / quant_stride) % nScale];
T inv_s = inverse(s);
T x = in[i];
T v = x > s ? s : x;
ComputeDataType s =
static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
ComputeDataType inv_s = inverse(s);
ComputeDataType x = static_cast<ComputeDataType>(in[i]);
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out[i] = round(v);
v = bin_cnt_t * inv_s * v;
out[i] = static_cast<T>(round(v));
}
}

Expand Down Expand Up @@ -376,7 +403,8 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
scale_arr[idx] = cur;
T max = last_scale[0];
out_scale[0] = max < cur ? cur : max;
if (fabs(removed - max) < 1e-6) {
if (fabs(static_cast<typename QuantizeDataType<T>::type>(removed - max)) <
1e-6) {
need_find_max[0] = 1;
out_size[0] = it > window_size ? window_size : it;
} else {
Expand Down
75 changes: 58 additions & 17 deletions python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import math
from op_test import OpTest
import paddle.fluid.core as core


def quantize_max_abs(x, max_range):
Expand Down Expand Up @@ -76,22 +77,25 @@ def channel_wise_dequantize_max_abs(x,
class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
def set_args(self):
self.quant_bits = [8, 8]
self.data_type = "float32"
self.activation_scale = 0.7861

def set_dtype(self):
self.dtype = np.float32

def setUp(self):
self.set_args()
self.set_dtype()
self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
x = np.random.randn(4, 3, 64, 64).astype(self.dtype)
yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], 1)
ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, 1,
self.activation_scale)

self.inputs = {
'X': yq,
'Scales': [("scales0", np.array(scales).astype(self.data_type)),
("scales1", np.array(
[self.activation_scale]).astype(self.data_type))]
'Scales':
[("scales0", np.array(scales).astype(self.dtype)),
("scales1", np.array([self.activation_scale]).astype(self.dtype))]
}
self.attrs = {'quant_bits': self.quant_bits}
self.outputs = {'Out': ydq}
Expand All @@ -100,24 +104,36 @@ def test_check_output(self):
self.check_output()


class TestFakeChannelWiseDequantizeMaxAbsOpTwoScalesFloat16(
TestFakeChannelWiseDequantizeMaxAbsOpTwoScales):
def set_dtype(self):
self.dtype = np.float16

def test_check_output(self):
self.check_output(atol=1e-2)


class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
def set_args(self):
self.quant_bits = [8]
self.data_type = "float32"
self.quant_axis = 0

def set_dtype(self):
self.dtype = np.float32

def setUp(self):
self.set_args()
self.set_dtype()
self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
x = np.random.randn(4, 3, 64, 64).astype(self.dtype)
yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0],
self.quant_axis)
ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
self.quant_axis)

self.inputs = {
'X': yq,
'Scales': [("scales0", np.array(scales).astype(self.data_type))]
'Scales': [("scales0", np.array(scales).astype(self.dtype))]
}
self.attrs = {
'quant_bits': self.quant_bits,
Expand All @@ -133,24 +149,44 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1(
TestFakeChannelWiseDequantizeMaxAbsOpOneScale):
def set_args(self):
self.quant_bits = [8]
self.data_type = "float32"
self.quant_axis = 1


class TestFakeChannelWiseDequantizeMaxAbsOpOneScaleFloat16(
TestFakeChannelWiseDequantizeMaxAbsOpOneScale):
def set_dtype(self):
self.dtype = np.float16

def test_check_output(self):
self.check_output(atol=1e-2)


class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1Float16(
TestFakeChannelWiseDequantizeMaxAbsOpOneScale1):
def set_dtype(self):
self.dtype = np.float16

def test_check_output(self):
self.check_output(atol=1e-2)


class TestFakeDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.num_bits = 8
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "float32"

def set_dtype(self):
self.dtype = np.float32

def setUp(self):
self.set_args()
self.set_dtype()
self.op_type = "fake_dequantize_max_abs"
x = np.random.randn(31, 65).astype(self.data_type)
x = np.random.randn(31, 65).astype(self.dtype)
yq, scale = quantize_max_abs(x, self.max_range)
ydq = dequantize_max_abs(yq, scale, self.max_range)

self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.data_type)}
self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.dtype)}
self.attrs = {'max_range': self.max_range}
self.outputs = {'Out': ydq}

Expand All @@ -159,17 +195,22 @@ def test_check_output(self):


class TestFakeDequantizeMaxAbsOpDouble(TestFakeDequantizeMaxAbsOp):
def set_args(self):
self.num_bits = 8
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "float64"
def set_dtype(self):
self.dtype = np.float64


class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp):
def set_args(self):
self.num_bits = 5
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "float32"


class TestFakeDequantizeMaxAbsOpFloat16(TestFakeDequantizeMaxAbsOp):
def set_dtype(self):
self.dtype = np.float16

def test_check_output(self):
self.check_output(atol=1e-2)


class TestChannelWiseDequantizeOp(OpTest):
Expand Down
Loading