From 3c4601f290528ab8e6fc4874cc9f92f5db73c744 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Mon, 28 Mar 2022 11:41:00 +0000 Subject: [PATCH 01/14] add new format of quantization --- paddle/fluid/operators/CMakeLists.txt | 3 +- paddle/fluid/operators/quantize_linear_op.cc | 323 +++++ paddle/fluid/operators/quantize_linear_op.cu | 429 ++++++ paddle/fluid/operators/quantize_linear_op.h | 163 +++ paddle/phi/kernels/cpu/cast_kernel.cc | 1 + paddle/phi/kernels/gpu/cast_kernel.cu | 1 + .../slim/quantization/imperative/utils.py | 5 +- .../post_training_quantization.py | 195 +-- .../slim/quantization/quantization_pass.py | 1258 ++++++++++++----- .../fluid/contrib/slim/quantization/utils.py | 281 ++++ 10 files changed, 2182 insertions(+), 477 deletions(-) create mode 100644 paddle/fluid/operators/quantize_linear_op.cc create mode 100644 paddle/fluid/operators/quantize_linear_op.cu create mode 100644 paddle/fluid/operators/quantize_linear_op.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e77be832c0cc8..bf2a00ea74a01 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -102,10 +102,11 @@ endif() set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel) -register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op +register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS}) +op_library(quantize_linear_op SRCS quantize_linear_op.cc quantize_linear_op.cu DEPS cast_kernel ${OP_HEADER_DEPS}) op_library(save_combine_op DEPS string_array) op_library(load_combine_op DEPS string_array) diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc new file mode 100644 index 0000000000000..72bb7ed61c3a4 --- /dev/null +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -0,0 +1,323 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/quantize_linear_op.h" +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/platform/transform.h" + +namespace paddle { +namespace operators { + +template +struct Compare { + public: + bool operator()(const T a, const T b) { return (std::abs(a) < std::abs(b)); } +}; + +template +struct FindAbsMaxFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const T* in, + const int num, T* out) { + *out = std::abs(*(std::max_element(in + 0, in + num, Compare()))); + } +}; + +template struct FindAbsMaxFunctor; + +template +struct FindChannelAbsMaxFunctor { + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in_tensor, const int quant_axis, + T* out_abs_max) { + // At present, channelwise quantization supports conv2d, depthwise_conv2d + // conv2d_transpose and mul + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + auto* in_data = in_tensor.data(); + auto in_dims = in_tensor.dims(); + const int64_t channel = in_dims[quant_axis]; + if (quant_axis == 0) { + const int64_t channel_size = in_tensor.numel() / channel; + for (int64_t i = 0; i < channel; i++) { + auto* start = in_data + i * channel_size; + auto* end = in_data + (i + 1) * channel_size; + out_abs_max[i] = + std::abs(*(std::max_element(start, end, Compare()))); + } + } else if (quant_axis == 1) { + for (int64_t i = 0; i < channel; i++) { + out_abs_max[i] = 0; + } + const int64_t step_i = in_tensor.numel() / in_dims[0]; + const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]); + for (int64_t i = 0; i < in_dims[0]; i++) { + for (int64_t j = 0; j < in_dims[1]; j++) { + auto* start = in_data + i * step_i + j * step_j; + auto* end = in_data + i * step_i + (j + 1) * step_j; + T abs_max = std::abs(*(std::max_element(start, end, Compare()))); + out_abs_max[j] = std::max(out_abs_max[j], abs_max); + } + } + } + } +}; + +template struct FindChannelAbsMaxFunctor; + +template +struct ClipAndFakeQuantFunctor { + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + T s = scale.data()[0]; + T inv_s = inverse(s); + platform::Transform trans; + trans(ctx, in.data(), in.data() + in.numel(), + out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); + auto out_e = framework::EigenVector::Flatten(*out); + out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + } +}; + +template struct ClipAndFakeQuantFunctor; + +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, const int quant_axis, + framework::Tensor* out) { + // At present, channelwise quantization supports conv2d, depthwise_conv2d + // conv2d_transpose and mul + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + auto* scale_data = scale.data(); + auto* in_data = in.data(); + auto* out_data = out->mutable_data(ctx.GetPlace()); + auto in_dims = in.dims(); + const int64_t channel = in_dims[quant_axis]; + platform::Transform trans; + if (quant_axis == 0) { + 
const int64_t channel_size = in.numel() / channel; + for (int64_t i = 0; i < channel; i++) { + T s = scale_data[i]; + auto* start = in_data + i * channel_size; + auto* end = in_data + (i + 1) * channel_size; + trans(ctx, start, end, out_data + i * channel_size, + ClipFunctor(-s, s)); + } + for (int64_t i = 0; i < channel; i++) { + T s = scale_data[i]; + T inv_s = inverse(s); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + } + } else if (quant_axis == 1) { + const int64_t step_i = in.numel() / in_dims[0]; + const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]); + for (int i = 0; i < in_dims[0]; i++) { + for (int j = 0; j < in_dims[1]; j++) { + T s = scale_data[j]; + T inv_s = inverse(s); + auto* start = in_data + i * step_i + j * step_j; + auto* end = in_data + i * step_i + (j + 1) * step_j; + auto* cur_out_data = out_data + i * step_i + j * step_j; + trans(ctx, start, end, cur_out_data, ClipFunctor(-s, s)); + for (int k = 0; k < step_j; k++) { + cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); + } + } + } + } + } +}; + +template struct ChannelClipAndFakeQuantFunctor; + +template +struct DequantizeFunctor { + void operator()(const platform::CPUDeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + T max_range, framework::Tensor* out) { + auto in_e = framework::EigenVector::Flatten(*in); + const T* scale_factor = scale->data(); + auto out_e = framework::EigenVector::Flatten(*out); + + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * scale_factor[0] / max_range; + } +}; + +template +struct ChannelDequantizeFunctor { + void operator()(const platform::CPUDeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + const int scale_num, T max_range, const int quant_axis, + const int x_num_col_dims, framework::Tensor* out) { + if (scale_num == 1) { + // Dequant op is before quantized op + // Dequantize the weight of quantized op + auto in_dims = in->dims(); + const int64_t channel = in_dims[quant_axis]; + const T* scale_factor = scale->data(); + if (quant_axis == 0) { + for (int64_t i = 0; i < channel; i++) { + T s = scale_factor[i]; + framework::Tensor one_channel_in = in->Slice(i, i + 1); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * s / max_range; + } + } else if (quant_axis == 1) { + int64_t out_iter = 1; + for (int i = 0; i < quant_axis; i++) { + out_iter *= in_dims[i]; + } + int64_t step_i = in->numel() / out_iter; + int64_t step_j = in->numel() / (out_iter * channel); + auto* in_data = in->data(); + auto* out_data = out->mutable_data(dev_ctx.GetPlace()); + for (int64_t i = 0; i < out_iter; i++) { + for (int64_t j = 0; j < channel; j++) { + auto* cur_in = in_data + i * step_i + j * step_j; + auto* cur_out = out_data + i * step_i + j * step_j; + T s = scale_factor[j]; + for (int64_t k = 0; k < step_j; k++) { + *cur_out = (*cur_in) * s / max_range; + ++cur_in; + ++cur_out; + } + } + } + } + } + } +}; + +template struct DequantizeFunctor; +template struct DequantizeFunctor; +template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctor; + +class QuantizeLinearOp : public framework::OperatorWithKernel { + public: + 
using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear"); + OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear"); + OP_INOUT_CHECK(ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", + "QuantizeLinear"); + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear"); + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); + int quant_axis = ctx->Attrs().Get("quant_axis"); + if (ctx->HasOutput("OutScale")) { + if (quant_axis < 0) { + ctx->SetOutputDim("OutScale", {1}); + } else { + ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]}); + } + } + ctx->ShareLoD("X", /*->*/ "Y"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) Input is float data type."); + AddInput("Scale", "(Tensor) Input is float data type."); + AddInput("ZeroPoint", "(Tensor) Input is float data type."); + AddOutput("Y", + "(Tensor) Output of quantized low level tensor, " + "but also saved as float data type."); + AddOutput("OutScale", "(Tensor) Current scale").AsDispensable().AsExtra(); + AddAttr("quant_axis", + "(int, default 0) The axis for quantization. " + "For conv2d, depthwise_conv2d, conv2d_transpose " + "and mul, the quant_axis is equal to the cout axis.") + .SetDefault(0) + .AddCustomChecker([](const int& quant_axis) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1 || quant_axis == -1, true, + platform::errors::InvalidArgument( + "'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + }); + AddAttr("bit_length", "(int, default 8)") + .SetDefault(8) + .AddCustomChecker([](const int& bit_length) { + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + platform::errors::InvalidArgument( + "'bit_length' should be between 1 and 16, but " + "the received is %d", + bit_length)); + }); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(true); + AddComment(R"DOC( +The scale of QuantizeLinear operator is a vector. +In detail, each channel of the input X has a scale value. 
+$$scale_c = max(abs(X_c))$$ +$$range = 2^{bit\_length - 1} - 1$$ +$$Out_c = round(\frac{X_c * range} {scale_c})$$ +In above three formulas, the range value of c is as follow: +$$0 \leq c \lt \ the\ channel\ number\ of\ X$$ +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR( + quantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel); + +REGISTER_OPERATOR( + dequantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(dequantize_linear, + ops::DeQuantizeLinearKernel, + ops::DeQuantizeLinearKernel, + ops::DeQuantizeLinearKernel); diff --git a/paddle/fluid/operators/quantize_linear_op.cu b/paddle/fluid/operators/quantize_linear_op.cu new file mode 100644 index 0000000000000..42efe3a988a22 --- /dev/null +++ b/paddle/fluid/operators/quantize_linear_op.cu @@ -0,0 +1,429 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/quantize_linear_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { + +template +__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num, + T* out) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < num) { + out[idx] = in[idx] * scale[0] / max_range; + } +} + +template +struct DequantizeFunctor { + void operator()(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + T max_range, framework::Tensor* out) { + const T* in_data = in->data(); + const T* scale_factor = scale->data(); + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + + int num = in->numel(); + int block = 512; + int grid = (num + block - 1) / block; + + KeDequantize<<>>( + in_data, scale_factor, max_range, num, out_data); + } +}; + +template +__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, + T max_range, int num, int channel, + T* out) { + int tid = threadIdx.x; + int channel_size = num / channel; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + for (int i = tid; i < channel_size; i += blockDim.x) { + out_c[i] = in_c[i] * scale[blockIdx.x] / max_range; + } +} + +template +__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale, + const T max_range, + const int64_t num, + const int n_scales, + const int quant_stride, T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + T s = scale[(i / quant_stride) % n_scales]; + out[i] = in[i] * s / max_range; + } +} + +template 
+struct ChannelDequantizeFunctor { + void operator()(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + const int scale_num, T max_range, const int quant_axis, + const int x_num_col_dims, framework::Tensor* out) { + auto in_dims = in->dims(); + const T* in_data = in->data(); + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + if (scale_num == 1) { + int64_t num = in->numel(); + const T* scale_factor = scale->data(); + if (quant_axis == 0) { + int grid = in_dims[0]; + int block = 1024; + DequantizeOneScaleQuantAxis0<<>>( + in_data, scale_factor, max_range, num, in_dims[0], out_data); + } else { + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = std::max( + ((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + DequantizeOneScaleQuantAxisN< + T><<>>( + in_data, scale_factor, max_range, num, in_dims[quant_axis], + quant_stride, out_data); + } + } + } +}; + +template struct DequantizeFunctor; +template struct DequantizeFunctor; +template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctor; + +template +__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { + int bid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + + extern __shared__ char* shared_max_data_tmp[]; + auto shared_max_data = reinterpret_cast(shared_max_data_tmp); + if (gridDim.x > 1) { + shared_max_data[tid] = T(0); + for (int i = bid; i < n; i += blockDim.x * gridDim.x) { + T tmp = abs(in[i]); + if (tmp > shared_max_data[tid]) { + shared_max_data[tid] = tmp; + } + } + } else { + if (bid < n) { + shared_max_data[tid] = abs(in[bid]); + } else { + shared_max_data[tid] = T(0); + } + } + __syncthreads(); + + for (int i = blockDim.x / 2; i > 0; i >>= 1) { + if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + __syncthreads(); + } + if (tid == 0) { + out[blockIdx.x] = shared_max_data[0]; + } +} + +template +struct FindAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const T* in, + const int num, T* out) { + int block = 1024; + int grid = (block - 1 + num) / block; + grid = (grid > block) ? 
block : grid; + + framework::Tensor max; + T* max_data = max.mutable_data(phi::make_ddim({grid}), ctx.GetPlace()); + FindAbsMaxKernel<<>>( + in, num, max_data); + FindAbsMaxKernel<<<1, block, 1024 * sizeof(T), ctx.stream()>>>( + max_data, grid, out); + } +}; + +template struct FindAbsMaxFunctor; +template struct FindAbsMaxFunctor; + +template +__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, + const int c, T* out) { + int tid = threadIdx.x; + int channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + extern __shared__ T shared_max_data[]; + shared_max_data[tid] = T(0); + for (int i = tid; i < channel_size; i += blockDim.x) { + T tmp = fabs(in_c[i]); + if (tmp > shared_max_data[tid]) { + shared_max_data[tid] = tmp; + } + } + __syncthreads(); + for (int i = blockDim.x / 2; i > 0; i >>= 1) { + if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + __syncthreads(); + } + if (tid == 0) { + out[blockIdx.x] = shared_max_data[0]; + } +} + +template +__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, + const int cin, const int cout, + T* out) { + extern __shared__ T shared_max_data[]; + int cout_wh_size = n / cin; + int wh_size = n / (cin * cout); + + int tid = threadIdx.x; + int bid = blockIdx.x; + const T* in_current = in + tid * cout_wh_size + bid * wh_size; + shared_max_data[tid] = T(0); + for (int i = 0; i < wh_size; i++) { + T tmp = fabs(in_current[i]); + if (tmp > shared_max_data[tid]) { + shared_max_data[tid] = tmp; + } + } + __syncthreads(); + + int len = blockDim.x; + for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) { + if (tid < i && tid + i < len && + shared_max_data[tid] < shared_max_data[tid + i]) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + if (i == 1) { + i = 0; // break the loop + } + __syncthreads(); + } + if (tid == 0 && shared_max_data[0] > out[bid]) { + out[bid] = shared_max_data[0]; + } +} + +template +struct FindChannelAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in_tensor, const int quant_axis, + T* out_abs_max) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + const int num = in_tensor.numel(); + auto in_dims = in_tensor.dims(); + const T* in_data = in_tensor.data(); + if (quant_axis == 0) { + int cout = in_dims[0]; + int grid = cout; + int block = 1024; + FindChannelAbsMaxKernelQuantAxis0< + T><<>>( + in_data, num, cout, out_abs_max); + } else if (quant_axis == 1) { + int cin = in_dims[0]; + int cout = in_dims[1]; + int grid = cout; + int max_threads = 1024; + +#ifdef PADDLE_WITH_HIP + hipMemset(out_abs_max, 0, sizeof(T) * cout); +#else + cudaMemset(out_abs_max, 0, sizeof(T) * cout); +#endif + + for (int i = 0; i < cin / max_threads; i++) { + int block = max_threads; + FindChannelAbsMaxKernelQuantAxis1< + T><<>>( + in_data, num, cin, cout, out_abs_max); + in_data += num / cin; + } + + int block = cin % max_threads; + if (block > 0) { + FindChannelAbsMaxKernelQuantAxis1< + T><<>>( + in_data, num, in_dims[0], in_dims[1], out_abs_max); + } + } + } +}; + +template struct FindChannelAbsMaxFunctor; + +template +__global__ void ClipAndQuantKernel(const T* in, const T* scale, + const int bin_cnt, const int n, T* out) { + int bid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + + T s = scale[0]; + T inv_s = 
inverse(s); + for (int i = bid; i < n; i += blockDim.x * gridDim.x) { + T x = in[i]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt * inv_s * v; + out[i] = round(v); + } +} + +template +struct ClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + int num = in.numel(); + int block = 1024; + int grid = (block - 1 + num) / block; + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + ClipAndQuantKernel<<>>( + in_data, scale_data, bin_cnt, num, out_data); + } +}; + +template struct ClipAndFakeQuantFunctor; + +// ChannelClipAndQuantKernel for quant_axis is 0 +template +__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, + const int bin_cnt, + const int64_t n, + const int c, T* out) { + int tid = threadIdx.x; + + int64_t channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + + T s = scale[blockIdx.x]; + T inv_s = inverse(s); + + for (int64_t i = tid; i < channel_size; i += blockDim.x) { + T x = in_c[i]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt * inv_s * v; + out_c[i] = round(v); + } +} + +// ChannelClipAndQuantKernel for quant_axis is N +template +__global__ void ChannelClipAndQuantKernelQuantAxisN( + const T* in, const T* scale, const int bin_cnt, const int64_t n, + const int nScale, const int quant_stride, T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) { + T s = scale[(i / quant_stride) % nScale]; + T inv_s = inverse(s); + T x = in[i]; + T v = x > s ? s : x; + v = v < -s ? 
-s : v; + v = bin_cnt * inv_s * v; + out[i] = round(v); + } +} + +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, const int quant_axis, + framework::Tensor* out) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + + int64_t num = in.numel(); + auto in_dims = in.dims(); + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + if (quant_axis == 0) { + int grid = in_dims[0]; + int block = 1024; + ChannelClipAndQuantKernelQuantAxis0<<>>( + in_data, scale_data, bin_cnt, num, in_dims[0], out_data); + } else { + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + int64_t block_size = + std::min(num, static_cast(ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), + static_cast(1)); + + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + ChannelClipAndQuantKernelQuantAxisN<<>>( + in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride, + out_data); + } + } +}; + +template struct ChannelClipAndFakeQuantFunctor; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(dequantize_linear, + ops::DeQuantizeLinearKernel, + ops::DeQuantizeLinearKernel, + ops::DeQuantizeLinearKernel); + +REGISTER_OP_CUDA_KERNEL(quantize_linear, + ops::QuantizeLinearKernel); diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h new file mode 100644 index 0000000000000..2f2d9ef9a5100 --- /dev/null +++ b/paddle/fluid/operators/quantize_linear_op.h @@ -0,0 +1,163 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace paddle { +namespace operators { + +template +inline HOSTDEVICE T inverse(T s) { + T eps = static_cast(1e-6); + T one = static_cast(1.0); + return s <= static_cast(1e-30) ? 
one / (s + eps) : one / s; +} + +template +struct FindAbsMaxFunctor { + void operator()(const DeviceContext& ctx, const T* in, const int num, T* out); +}; + +template +struct ClipAndFakeQuantFunctor { + void operator()(const DeviceContext& ctx, const framework::Tensor& in, + const framework::Tensor& scale, const int bin_cnt, + framework::Tensor* out); +}; + +template +struct FindChannelAbsMaxFunctor { + void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor, + const int quant_axis, T* out_abs_max); +}; + +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const DeviceContext& ctx, const framework::Tensor& in, + const framework::Tensor& scale, const int bin_cnt, + const int quant_axis, framework::Tensor* out); +}; + +template +struct DequantizeFunctor { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, + const framework::Tensor* scale, T max_range, + framework::Tensor* out); +}; + +template +struct ChannelDequantizeFunctor { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, + const framework::Tensor** scales, const int scale_num, + T max_range, const int quant_axis, const int x_num_col_dims, + framework::Tensor* out); +}; + +template +class QuantizeLinearKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* in_scale = context.Input("Scale"); + + auto* out = context.Output("Y"); + out->mutable_data(context.GetPlace()); + int bit_length = context.Attr("bit_length"); + int bin_cnt = std::pow(2, bit_length - 1) - 1; + int quant_axis = context.Attr("quant_axis"); + bool is_test = context.Attr("is_test"); + auto& dev_ctx = context.template device_context(); + + if (quant_axis < 0) { + if (!is_test) { + auto* out_scale = context.Output("OutScale"); + T* out_s = out_scale->mutable_data(context.GetPlace()); + FindAbsMaxFunctor()(dev_ctx, in->data(), + in->numel(), out_s); + ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, + bin_cnt, out); + } else { + ClipAndFakeQuantFunctor()(dev_ctx, *in, *in_scale, + bin_cnt, out); + } + } else { + if (!is_test) { + auto* out_scale = context.Output("OutScale"); + T* out_scale_data = out_scale->mutable_data(context.GetPlace()); + FindChannelAbsMaxFunctor()(dev_ctx, *in, quant_axis, + out_scale_data); + } + ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out); + } + } +}; + +template +class DeQuantizeLinearKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& dev_ctx = context.template device_context(); + auto* in = context.Input("X"); + + auto in_tmp = phi::Cast( + static_cast::TYPE&>(dev_ctx), + *in, experimental::CppTypeToDataType::Type()); + + auto* scale = context.Input("Scale"); + auto* out = context.Output("Y"); + int bit_length = context.Attr("bit_length"); + auto quant_axis = context.Attr("quant_axis"); + out->mutable_data(dev_ctx.GetPlace()); + + if (quant_axis < 0) { + float max_range = (std::pow(2, bit_length - 1) - 1); + DequantizeFunctor()(dev_ctx, &in_tmp, scale, + static_cast(max_range), out); + } else { + auto x_num_col_dims = 1; + int max_range = 1; + + out->mutable_data(dev_ctx.GetPlace()); + // Now only support scale_num = 1 + int scale_num = 1; + PADDLE_ENFORCE_EQ( + scale->numel(), in_tmp.dims()[quant_axis], + platform::errors::PreconditionNotMet( + "The number of first scale values must be the same with " + "quant_axis dimension 
value of Input(X) when the `scale` has " + "only one element, but %ld != %ld here.", + scale->numel(), in_tmp.dims()[quant_axis])); + max_range *= (std::pow(2, bit_length - 1) - 1); + + ChannelDequantizeFunctor()( + dev_ctx, &in_tmp, scale, scale_num, static_cast(max_range), + quant_axis, x_num_col_dims, out); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/phi/kernels/cpu/cast_kernel.cc b/paddle/phi/kernels/cpu/cast_kernel.cc index 800962544c73e..b53c94eb4cae2 100644 --- a/paddle/phi/kernels/cpu/cast_kernel.cc +++ b/paddle/phi/kernels/cpu/cast_kernel.cc @@ -41,6 +41,7 @@ PD_REGISTER_KERNEL(cast, int64_t, int16_t, bool, + int8_t, uint8_t, phi::dtype::float16, phi::dtype::bfloat16, diff --git a/paddle/phi/kernels/gpu/cast_kernel.cu b/paddle/phi/kernels/gpu/cast_kernel.cu index 7c4cadbc90ac6..40a84648e4b16 100644 --- a/paddle/phi/kernels/gpu/cast_kernel.cu +++ b/paddle/phi/kernels/gpu/cast_kernel.cu @@ -41,6 +41,7 @@ void CastKernel(const Context& dev_ctx, int64_t, \ int16_t, \ bool, \ + int8_t, \ uint8_t, \ phi::dtype::float16, \ phi::dtype::complex, \ diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index 009ce372b4f29..758928f8dafe8 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -18,10 +18,7 @@ import paddle import paddle.nn.quant.quant_layers as quant_layers -from ..quantization_pass import _get_op_input_var_names -from ..quantization_pass import _get_op_output_var_names -from ..quantization_pass import _get_output_name_index -from ..quantization_pass import _get_input_name_index +from ..utils import _get_op_input_var_names, _get_op_output_var_names, _get_output_name_index, _get_input_name_index layer_name_map = { 'Conv2DTranspose': paddle.nn.Conv2DTranspose, diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index b1b645e85e75d..febb4170b4440 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -25,18 +25,10 @@ from ....executor import global_scope, Executor from ....framework import IrGraph from ....log_helper import get_logger -from .quantization_pass import QuantizationTransformPass -from .quantization_pass import QuantizationFreezePass -from .quantization_pass import AddQuantDequantPass -from .quantization_pass import _out_scale_op_list -from .quantization_pass import _get_op_input_var_names -from .quantization_pass import _get_op_output_var_names -from .quantization_pass import _get_output_name_index -from .quantization_pass import _get_input_name_index -from .quantization_pass import _channelwise_quant_axis1_ops +from .quantization_pass import QuantizationTransformPass, QuantizationTransformPassV2, QuantizationFreezePass, QuantWeightPass, AddQuantDequantPass, AddQuantDequantPassV2 from .cal_kl_threshold import cal_kl_threshold from .adaround import run_adaround -from .utils import load_variable_data, set_variable_data +from . 
import utils __all__ = ['PostTrainingQuantization', 'WeightQuantization'] @@ -131,6 +123,7 @@ def __init__(self, weight_bits=8, activation_quantize_type='range_abs_max', weight_quantize_type='channel_wise_abs_max', + onnx_format=False, optimize_model=False, is_use_cache_file=False, cache_dir=None): @@ -203,6 +196,8 @@ def __init__(self, the fake ops in saving quantized model, and we save the scale obtained by post training quantization in fake ops. Compared to 'abs_max', the model accuracy is usually higher when it is 'channel_wise_abs_max'. + onnx_format(bool): Whether to export the quantized model with format of onnx. + Default is False. optimize_model(bool, optional): If set optimize_model as True, it applies some passes to the model before quantization, and it supports `conv2d/depthwise_conv2d + bn` pass so far. Some targets require the @@ -265,8 +260,8 @@ def __init__(self, self._learning_rate = learning_rate self._dynamic_quantize_op_type = ['lstm'] self._support_quantize_op_type = \ - list(set(QuantizationTransformPass._supported_quantizable_op_type + - AddQuantDequantPass._supported_quantizable_op_type + + list(set(utils._weight_supported_quantizable_op_type + + utils._act_supported_quantizable_op_type + self._dynamic_quantize_op_type)) # Check inputs @@ -305,6 +300,7 @@ def __init__(self, self._weight_bits = weight_bits self._activation_quantize_type = activation_quantize_type self._weight_quantize_type = weight_quantize_type + self._onnx_format = onnx_format self._is_full_quantize = is_full_quantize if is_full_quantize: self._quantizable_op_type = self._support_quantize_op_type @@ -322,7 +318,7 @@ def __init__(self, self._fetch_list = None self._data_loader = data_loader - self._out_scale_op_list = _out_scale_op_list + self._out_scale_op_list = utils._out_scale_op_list self._quantized_weight_var_name = set() self._quantized_act_var_name = set() self._weight_op_pairs = {} @@ -406,7 +402,8 @@ def quantize(self): else: self._save_input_threhold() - self._save_output_threshold() + if not self._onnx_format: + self._save_output_threshold() if any(op_type in self._quantizable_op_type for op_type in self._dynamic_quantize_op_type): self._collect_dynamic_quantize_op_threshold( @@ -466,6 +463,7 @@ def save_quantized_model(self, Returns: None ''' + clip_extra = True if self._onnx_format else False io.save_inference_model( dirname=save_model_path, model_filename=model_filename, @@ -473,7 +471,8 @@ def save_quantized_model(self, feeded_var_names=self._feed_list, target_vars=self._fetch_list, executor=self._executor, - main_program=self._program) + main_program=self._program, + clip_extra=clip_extra) _logger.info("The quantized model is saved in " + save_model_path) def _load_model_data(self): @@ -551,22 +550,22 @@ def collect_var_name(var_name_list, persistable_var_names, op_type): # For quantized ops, sample inputs and outputs if op_type in self._quantizable_op_type: collect_var_name( - _get_op_input_var_names(op), persistable_var_names, - op_type) + utils._get_op_input_var_names(op), + persistable_var_names, op_type) collect_var_name( - _get_op_output_var_names(op), persistable_var_names, - op_type) + utils._get_op_output_var_names(op), + persistable_var_names, op_type) # collect quanted op output var name - for out_var_name in _get_op_output_var_names(op): - for in_var_name in _get_op_input_var_names(op): + for out_var_name in utils._get_op_output_var_names(op): + for in_var_name in utils._get_op_input_var_names(op): if in_var_name in persistable_var_names: self._quantized_op_pairs[ 
in_var_name] = out_var_name # For other op, only sample output scale elif op_type in self._out_scale_op_list: collect_var_name( - _get_op_output_var_names(op), persistable_var_names, - op_type) + utils._get_op_output_var_names(op), + persistable_var_names, op_type) def _set_activation_persistable(self): ''' @@ -608,13 +607,14 @@ def _sampling(self): def _sample_mse(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.utils.load_variable_data(self._scope, + var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": abs_max_value = [] if self._weight_op_pairs[ - var_name] in _channelwise_quant_axis1_ops: + var_name] in utils._channelwise_quant_axis1_ops: for i in range(var_tensor.shape[1]): abs_max_value.append( float(np.max(np.abs(var_tensor[:, i])))) @@ -625,7 +625,7 @@ def _sample_mse(self): self._quantized_threshold[var_name] = abs_max_value _logger.info("MSE searching stage ...") for var_name in self._quantized_act_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = var_tensor.flatten() abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value @@ -647,13 +647,13 @@ def _sample_mse(self): def _sample_emd(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": abs_max_value = [] if self._weight_op_pairs[ - var_name] in _channelwise_quant_axis1_ops: + var_name] in utils._channelwise_quant_axis1_ops: for i in range(var_tensor.shape[1]): abs_max_value.append( float(np.max(np.abs(var_tensor[:, i])))) @@ -664,7 +664,7 @@ def _sample_emd(self): self._quantized_threshold[var_name] = abs_max_value _logger.info("EMD searching stage ...") for var_name in self._quantized_act_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = var_tensor.flatten() abs_max_value = float(np.max(np.abs(var_tensor))) abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value @@ -688,13 +688,13 @@ def _sample_emd(self): def _sample_avg(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": abs_max_value = [] if self._weight_op_pairs[ - var_name] in _channelwise_quant_axis1_ops: + var_name] in utils._channelwise_quant_axis1_ops: for i in range(var_tensor.shape[1]): abs_max_value.append( float(np.max(np.abs(var_tensor[:, i])))) @@ -705,7 +705,7 @@ def _sample_avg(self): self._quantized_threshold[var_name] = abs_max_value for var_name in self._quantized_act_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) abs_max_value = float(np.max(np.abs(var_tensor))) if (var_name not in 
self._quantized_var_avg): self._quantized_var_avg[var_name] = [] @@ -717,13 +717,13 @@ def _sample_avg(self): def _sample_abs_max(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": abs_max_value = [] if self._weight_op_pairs[ - var_name] in _channelwise_quant_axis1_ops: + var_name] in utils._channelwise_quant_axis1_ops: for i in range(var_tensor.shape[1]): abs_max_value.append( float(np.max(np.abs(var_tensor[:, i])))) @@ -734,7 +734,7 @@ def _sample_abs_max(self): self._quantized_threshold[var_name] = abs_max_value for var_name in self._quantized_act_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) abs_max_value = float(np.max(np.abs(var_tensor))) if (var_name not in self._quantized_threshold) or \ (abs_max_value > self._quantized_threshold[var_name]): @@ -743,7 +743,7 @@ def _sample_abs_max(self): def _sample_min_max(self): if self._quantized_var_min == {} and self._quantized_var_max == {}: for var_name in self._quantized_weight_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": min_value = float(np.min(var_tensor)) max_value = float(np.max(var_tensor)) @@ -751,7 +751,7 @@ def _sample_min_max(self): min_value = [] max_value = [] if self._weight_op_pairs[ - var_name] in _channelwise_quant_axis1_ops: + var_name] in utils._channelwise_quant_axis1_ops: for i in range(var_tensor.shape[1]): min_value.append(float(np.min(var_tensor[:, i]))) max_value.append(float(np.max(var_tensor[:, i]))) @@ -763,7 +763,7 @@ def _sample_min_max(self): self._quantized_var_max[var_name] = max_value for var_name in self._quantized_act_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) min_value = float(np.min(var_tensor)) max_value = float(np.max(var_tensor)) if (var_name not in self._quantized_var_min) or \ @@ -775,7 +775,7 @@ def _sample_min_max(self): def _sample_histogram(self): for var_name in self._quantized_act_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor_abs = np.abs(var_tensor) bins = self._sampling_act_histogram[var_name][1] hist, _ = np.histogram(var_tensor_abs, bins=bins) @@ -790,7 +790,7 @@ def _save_input_threhold(self): for block_id in range(len(self._program.blocks)): for op in self._program.blocks[block_id].ops: if op.type in self._quantizable_op_type: - for var_name in _get_op_input_var_names(op): + for var_name in utils._get_op_input_var_names(op): assert var_name in self._quantized_var_min assert var_name in self._quantized_var_max op._set_attr(var_name + ".min", @@ -805,7 +805,7 @@ def _collect_activation_abs_min_max(self): get the min and max value, and then calculate the threshold. 
''' for var_name in self._quantized_act_var_name: - var_tensor = load_variable_data(self._scope, var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = np.abs(var_tensor) min_value = float(np.min(var_tensor)) max_value = float(np.max(var_tensor)) @@ -839,13 +839,13 @@ def _calculate_kl_hist_threshold(self): # Abs_max threshold for weights for var_name in self._quantized_weight_var_name: - weight_data = load_variable_data(self._scope, var_name) + weight_data = utils.load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": weight_threshold = float(np.max(np.abs(weight_data))) elif self._weight_quantize_type == "channel_wise_abs_max": weight_threshold = [] if self._weight_op_pairs[ - var_name] in _channelwise_quant_axis1_ops: + var_name] in utils._channelwise_quant_axis1_ops: for i in range(weight_data.shape[1]): weight_threshold.append( float(np.max(np.abs(weight_data[:, i])))) @@ -876,17 +876,27 @@ def _update_program(self): # use QuantizationTransformPass to insert fake_quant/fake_dequantize op major_quantizable_op_types = [] - for op_type in QuantizationTransformPass._supported_quantizable_op_type: + for op_type in utils._weight_supported_quantizable_op_type: if op_type in self._quantizable_op_type: major_quantizable_op_types.append(op_type) - transform_pass = QuantizationTransformPass( - scope=self._scope, - place=self._place, - weight_bits=self._weight_bits, - activation_bits=self._activation_bits, - activation_quantize_type=self._activation_quantize_type, - weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types) + if not self._onnx_format: + transform_pass = QuantizationTransformPass( + scope=self._scope, + place=self._place, + weight_bits=self._weight_bits, + activation_bits=self._activation_bits, + activation_quantize_type=self._activation_quantize_type, + weight_quantize_type=self._weight_quantize_type, + quantizable_op_type=major_quantizable_op_types) + else: + transform_pass = QuantizationTransformPassV2( + scope=self._scope, + place=self._place, + weight_bits=self._weight_bits, + activation_bits=self._activation_bits, + activation_quantize_type=self._activation_quantize_type, + weight_quantize_type=self._weight_quantize_type, + quantizable_op_type=major_quantizable_op_types) for sub_graph in graph.all_sub_graphs(): # Insert fake_quant/fake_dequantize op must in test graph, so @@ -896,13 +906,20 @@ def _update_program(self): # use AddQuantDequantPass to insert fake_quant_dequant op minor_quantizable_op_types = [] - for op_type in AddQuantDequantPass._supported_quantizable_op_type: + for op_type in utils._act_supported_quantizable_op_type: if op_type in self._quantizable_op_type: minor_quantizable_op_types.append(op_type) - add_quant_dequant_pass = AddQuantDequantPass( - scope=self._scope, - place=self._place, - quantizable_op_type=minor_quantizable_op_types) + if not self._onnx_format: + add_quant_dequant_pass = AddQuantDequantPass( + scope=self._scope, + place=self._place, + quantizable_op_type=minor_quantizable_op_types) + else: + add_quant_dequant_pass = AddQuantDequantPassV2( + scope=self._scope, + place=self._place, + quantizable_op_type=minor_quantizable_op_types, + is_full_quantized=self._is_full_quantize) for sub_graph in graph.all_sub_graphs(): sub_graph._for_test = True @@ -914,33 +931,39 @@ def _update_program(self): else: scale_dict = self._quantized_threshold for key, val in scale_dict.items(): - set_variable_data( + utils.set_variable_data( self._scope, self._place, 
key + ".scale", np.array( [val], dtype=np.float32)) - set_variable_data( + utils.set_variable_data( self._scope, self._place, key + ".quant_dequant.scale", np.array( [val], dtype=np.float32)) - # apply QuantizationFreezePass, and obtain the final quant model - freeze_pass = QuantizationFreezePass( - scope=self._scope, - place=self._place, - bias_correction=self._bias_correction, - weight_bits=self._weight_bits, - round_type=self._round_type, - activation_bits=self._activation_bits, - weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types) - - for sub_graph in graph.all_sub_graphs(): - sub_graph._for_test = True - freeze_pass.apply(sub_graph) + if not self._onnx_format: + # apply QuantizationFreezePass, and obtain the final quant model + freeze_pass = QuantizationFreezePass( + scope=self._scope, + place=self._place, + bias_correction=self._bias_correction, + weight_bits=self._weight_bits, + round_type=self._round_type, + activation_bits=self._activation_bits, + weight_quantize_type=self._weight_quantize_type, + quantizable_op_type=major_quantizable_op_types) + + for sub_graph in graph.all_sub_graphs(): + sub_graph._for_test = True + freeze_pass.apply(sub_graph) + else: + quant_weight_pass = QuantWeightPass(self._scope, self._place) + for sub_graph in graph.all_sub_graphs(): + sub_graph._for_test = True + quant_weight_pass.apply(sub_graph) self._program = graph.to_program() @@ -960,7 +983,7 @@ def save_info(op_node, out_var_name, threshold_map, out_info_name, op._set_attr("quantization_type", quantized_type) def analysis_and_save_info(op_node, out_var_name): - argname_index = _get_output_name_index(op_node, out_var_name) + argname_index = utils._get_output_name_index(op_node, out_var_name) assert argname_index is not None, \ out_var_name + " is not the output of the op" if self._algo == "KL": @@ -997,7 +1020,7 @@ def analysis_and_save_info(op_node, out_var_name): for op in self._program.blocks[block_id].ops: if op.type in ( self._quantizable_op_type + self._out_scale_op_list): - out_var_names = _get_op_output_var_names(op) + out_var_names = utils._get_op_output_var_names(op) for var_name in out_var_names: analysis_and_save_info(op, var_name) @@ -1020,11 +1043,11 @@ def _collect_dynamic_quantize_op_threshold(self, target_ops_type): quantization_type = str("post_" + self._algo).lower() persistable_var_names = _all_persistable_var_names(self._program) for op in target_ops: - for var_name in _get_op_input_var_names(op): + for var_name in utils._get_op_input_var_names(op): if var_name in persistable_var_names: - var_data = load_variable_data(self._scope, var_name) + var_data = utils.load_variable_data(self._scope, var_name) threshold = float(np.max(np.abs(var_data))) - argname, index = _get_input_name_index(op, var_name) + argname, index = utils._get_input_name_index(op, var_name) op._set_attr(argname + str(index) + "_threshold", threshold) op._set_attr("quantization_type", quantization_type) op._set_attr("bit_length", self._weight_bits) @@ -1268,7 +1291,7 @@ def _weight_abs_max_quantization(self, scope, place, weight_bits, save_weight_dtype = np.int8 if weight_bits == 8 else np.int16 # Get quantized scale and weight data - weight_data = load_variable_data(scope, var_name) + weight_data = utils.load_variable_data(scope, var_name) if abs(threshold_rate) < 1e-10: threshold_value = np.max(np.abs(weight_data)) else: @@ -1282,11 +1305,13 @@ def _weight_abs_max_quantization(self, scope, place, weight_bits, # Set weight data if not for_test: - 
set_variable_data(scope, place, var_name, quantized_weight_data) + utils.set_variable_data(scope, place, var_name, + quantized_weight_data) else: dequantized_weight_data = \ (quantized_weight_data * scale).astype(np.float32) - set_variable_data(scope, place, var_name, dequantized_weight_data) + utils.set_variable_data(scope, place, var_name, + dequantized_weight_data) # Save info op._set_attr('quantization_type', 'post_weight_abs_max') @@ -1303,7 +1328,7 @@ def _weight_channel_wise_abs_max_quantization( save_weight_dtype = np.int8 if weight_bits == 8 else np.int16 # Get quantized scale and weight data - weight_data = load_variable_data(scope, var_name) + weight_data = utils.load_variable_data(scope, var_name) if op.type == "mul": scales, quantized_weight_data = \ self._mul_channel_wise_quantization(weight_data, @@ -1317,7 +1342,8 @@ def _weight_channel_wise_abs_max_quantization( # Set weight data if not for_test: - set_variable_data(scope, place, var_name, quantized_weight_data) + utils.set_variable_data(scope, place, var_name, + quantized_weight_data) else: if op.type == "mul": dequantized_weight_data = \ @@ -1328,7 +1354,8 @@ def _weight_channel_wise_abs_max_quantization( else: _logger.error(op.type + " is not supported by weight quantization") - set_variable_data(scope, place, var_name, dequantized_weight_data) + utils.set_variable_data(scope, place, var_name, + dequantized_weight_data) # Save info op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max') diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 6d7c91fddeb77..40c9698638b3a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -26,12 +26,20 @@ from ....layers import mean from ....executor import scope_guard from ....framework import _get_paddle_place -from .utils import _channelwise_quant_axis1_ops, quant_tensor +from . 
import utils __all__ = [ - 'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass', - 'TransformForMobilePass', 'OutScaleForTrainingPass', - 'OutScaleForInferencePass', 'AddQuantDequantPass' + 'QuantizationTransformPass', + 'QuantizationFreezePass', + 'ConvertToInt8Pass', + 'TransformForMobilePass', + 'OutScaleForTrainingPass', + 'OutScaleForInferencePass', + 'AddQuantDequantPass', + 'QuantizationTransformPassV2', + 'AddQuantDequantPassV2', + 'ReplaceFakeQuantDequantPass', + 'QuantWeightPass', ] _fake_quant_op_list = [ @@ -44,280 +52,14 @@ ] _fake_quant_dequant_op_list = [ - 'fake_quantize_dequantize_moving_average_abs_max' + 'fake_quantize_dequantize_moving_average_abs_max', + "fake_channel_wise_quantize_dequantize_abs_max", + "fake_quantize_dequantize_moving_average_abs_max" ] -_out_scale_op_list = [ - "conv2d", - "depthwise_conv2d", - "mul", - "matmul", - "matmul_v2", - "relu", - "leaky_relu", - "relu6", - "sigmoid", - "tanh", - "prelu", - "swish", - "dropout", - "softmax", - "batch_norm", - "layer_norm", - "elementwise_add", - "pool2d", - "reshape2", - "transpose2", - "concat", - "elementwise_mul", - "elementwise_pow", - "elementwise_sub", - "scale", - "slice", - "hard_swish", - "hard_sigmoid", - "conv2d_transpose", - "gru", - "bilinear_interp", - "nearest_interp", - "trilinear_interp", - "flatten", - "flatten2", - "transpose", - "pad2d", - "pad3d", - "reshape", - "split", - "flatten_contiguous_range", - "squeeze", - "squeeze2", - "nearest_interp_v2", - "fill_constant_batch_size_like", - "bilinear_interp", - "bilinear_interp_v2", - "arg_max", - "abs", - "assign", - "cast", - "clip", - "box_coder", - "crop", - "cumsum", - "equal", - "expand_v2", - "fill_any_like", - "fill_constant", - "gelu", - "instance_norm", - "lookup_table", - "lookup_table_v2", - "norm", - "p_norm", - "pow", - "reduce_mean", - "stack", - "top_k_v2", - "unsqueeze", - "unsqueeze2", - "logical_and", - "logical_not", - "meshgrid", - "roi_align", - "strided_slice", - "where", - "grid_sampler", - "tile", - "group_norm", - "reduce_sum", - "square", - "softplus", - "gather", - "shuffle_channel", -] - -# list op real input and output names, to avoid processing input such as AxisTensor. 
-_op_real_in_out_name = { - "conv2d": [["Input", "Filter"], ["Output"]], - "depthwise_conv2d": [["Input", "Filter"], ["Output"]], - "conv2d_transpose": [["Input", "Filter"], ["Output"]], - "mul": [["X", "Y"], ["Out"]], - "matmul": [["X", "Y"], ["Out"]], - "matmul_v2": [["X", "Y"], ["Out"]], - "pool2d": [["X"], ["Out"]], - "elementwise_add": [["X", "Y"], ["Out"]], - "concat": [["X"], ["Out"]], - "softmax": [["X"], ["Out"]], - "argmax": [["X"], ["Out"]], - "transpose": [["X"], ["Out"]], - "equal": [["X", "Y"], ["Out"]], - "gather": [["X"], ["Out"]], - "greater_equal": [["X", "Y"], ["Out"]], - "greater_than": [["X", "Y"], ["Out"]], - "less_equal": [["X", "Y"], ["Out"]], - "less_than": [["X", "Y"], ["Out"]], - "mean": [["X"], ["Out"]], - "not_equal": [["X", "Y"], ["Out"]], - "reshape": [["X"], ["Out"]], - "reshape2": [["X"], ["Out"]], - "transpose2": [["X"], ["Out"]], - "bilinear_interp": [["X"], ["Out"]], - "nearest_interp": [["X"], ["Out"]], - "trilinear_interp": [["X"], ["Out"]], - "slice": [["Input"], ["Out"]], - "squeeze": [["X"], ["Out"]], - "elementwise_sub": [["X", "Y"], ["Out"]], - "relu": [["X"], ["Out"]], - "relu6": [["X"], ["Out"]], - "leaky_relu": [["X"], ["Out"]], - "prelu": [["X", "Alpha"], ["Out"]], - "tanh": [["X"], ["Out"]], - "swish": [["X"], ["Out"]], - "dropout": [["X"], ["Out"]], - "batch_norm": [["X"], ["Y"]], - "layer_norm": [["X"], ["Y"]], - "sigmoid": [["X"], ["Out"]], - "elementwise_mul": [["X", "Y"], ["Out"]], - "elementwise_pow": [["X", "Y"], ["Out"]], - "scale": [["X"], ["Out"]], - "hard_swish": [["X"], ["Out"]], - "hard_sigmoid": [["X"], ["Out"]], - "gru": [["Input", "Weight"], ["Hidden"]], - "lstm": [["Input", "Weight"], ["Hidden"]], - "pad2d": [["X"], ["Out"]], - "pad3d": [["X"], ["Out"]], - "flatten": [["X"], ["Out"]], - "flatten2": [["X"], ["Out"]], - "unsqueeze2": [["X"], ["Out"]], - "unsqueeze2": [["X"], ["Out"]], - "flatten_contiguous_range": [["X"], ["Out"]], - "split": [["X"], ["Out"]], - "squeeze2": [["X"], ["Out"]], - "nearest_interp_v2": [["X"], ["Out"]], - "bilinear_interp": [["X"], ["Out"]], - "bilinear_interp_v2": [["X"], ["Out"]], - "fill_constant_batch_size_like": [["Input"], ["Out"]], - "arg_max": [["X"], ["Out"]], - "abs": [["X"], ["Out"]], - "assign": [["X"], ["Out"]], - "cast": [["X"], ["Out"]], - "clip": [["X"], ["Out"]], - "box_coder": [["PriorBox"], ["OutputBox"]], - "crop": [["X"], ["Out"]], - "cumsum": [["X"], ["Out"]], - "expand_v2": [["X"], ["Out"]], - "fill_any_like": [["X"], ["Out"]], - "fill_constant": [[], ["Out"]], - "gelu": [["X"], ["Out"]], - "instance_norm": [["X"], ["Out"]], - "lookup_table": [["W", "Ids"], ["Out"]], - "lookup_table_v2": [["W", "Ids"], ["Out"]], - "norm": [["X"], ["Norm"]], - "p_norm": [["X"], ["Out"]], - "pow": [["X"], ["Out"]], - "reduce_mean": [["X"], ["Out"]], - "stack": [["X"], ["Y"]], - "top_k_v2": [["X"], ["Out", "Indices"]], - "logical_and": [["X", "Y"], ["Out"]], - "logical_not": [["X"], ["Out"]], - "meshgrid": [["X"], ["Out"]], - "roi_align": [["X", "ROIs"], ["Out"]], - "strided_slice": [["Input"], ["Out"]], - "where": [["Condition", "X", "Y"], ["Out"]], - "grid_sampler": [["X", "Grid"], ["Output"]], - "tile": [["X"], ["Out"]], - "group_norm": [["X"], ["Y", "Mean", "Variance"]], - "reduce_sum": [["X"], ["Out"]], - "square": [["X"], ["Out"]], - "softplus": [["X"], ["Out"]], - "shuffle_channel": [["X"], ["Out"]], -} - _conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose'] -def _get_op_input_var_names(op): - """ - Get the input var names of the op. - Args: - op(IrNode, Operator): the input op. 
- Returns: - input_var_names or None. - """ - assert isinstance(op, (IrNode, Operator)), \ - "The input op should be IrNode or Operator." - var_names = [] - op_name = op.name() if isinstance(op, IrNode) \ - else op.type - if op_name not in _op_real_in_out_name: - return [] - - name_list = _op_real_in_out_name[op_name][0] - for name in name_list: - var_name = op.input(name) - if isinstance(var_name, list): - var_names.extend(var_name) - else: - var_names.append(var_name) - return var_names - - -def _get_input_name_index(op, input_var_name): - """Get the input name and index of the var_name in the op""" - assert isinstance(op, (IrNode, Operator)), \ - "The input op should be IrNode or Operator." - op_name = op.name() if isinstance(op, IrNode) \ - else op.type - if op_name not in _op_real_in_out_name: - return None - - res = None - for argname in _op_real_in_out_name[op_name][0]: - var_names = op.input(argname) - for index, name in enumerate(var_names): - if name == input_var_name: - res = (argname, index) - return res - - -def _get_op_output_var_names(op): - """ """ - assert isinstance(op, (IrNode, Operator)), \ - "The input op should be IrNode or Operator." - var_names = [] - op_name = op.name() if isinstance(op, IrNode) \ - else op.type - if op_name not in _op_real_in_out_name: - return [] - - name_list = _op_real_in_out_name[op_name][1] - for name in name_list: - var_name = op.output(name) - if isinstance(var_name, list): - var_names.extend(var_name) - else: - var_names.append(var_name) - return var_names - - -def _get_output_name_index(op, output_var_name): - """Get the output name and index of the var_name in the op""" - assert isinstance(op, (IrNode, Operator)), \ - "The input op should be IrNode or Operator." - op_name = op.name() if isinstance(op, IrNode) \ - else op.type - if op_name not in _op_real_in_out_name: - return None - - name_list = _op_real_in_out_name[op_name][1] - res = None - for name in name_list: - var_name = op.output(name) - for index, val in enumerate(var_name): - if val == output_var_name: - res = (name, index) - return res - - def _init_var_node(var_node, value, scope, place): assert isinstance(value, np.ndarray), 'The type of value should be numpy array.' @@ -334,7 +76,7 @@ def _is_input_all_not_persistable(graph, op_node): Analyse the real inputs of the op node are all not persistable. ''' is_input_all_not_persistable = True - for var_name in _get_op_input_var_names(op_node): + for var_name in utils._get_op_input_var_names(op_node): in_node = graph._find_node_by_name(op_node.inputs, var_name) is_input_all_not_persistable = (is_input_all_not_persistable and \ (not in_node.persistable())) @@ -360,10 +102,6 @@ class QuantizationTransformPass(object): Quantize the ops that have weights. Add quant and dequant ops for the quantized ops's inputs. """ - _supported_quantizable_op_type = [ - 'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul', - 'matmul_v2' - ] def __init__(self, scope=None, @@ -493,7 +231,7 @@ def __init__(self, self._quantizable_ops = quantizable_op_type for op in self._quantizable_ops: - assert op in QuantizationTransformPass._supported_quantizable_op_type, \ + assert op in utils._weight_supported_quantizable_op_type, \ op + " is not supported for quantization." 
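For orientation, here is a minimal numpy sketch of the channel-wise abs-max weight quantization that these passes delegate to utils.quant_tensor; the helper name and shapes below are illustrative only, not the utils implementation. quant_axis selects the output-channel axis: 0 for conv2d and depthwise_conv2d filters, 1 for the ops listed in utils._channelwise_quant_axis1_ops.

import numpy as np

def channel_wise_quant(weight, quant_axis=0, weight_bits=8):
    # One abs-max scale per channel along quant_axis; bnt is 127 for int8.
    bnt = (1 << (weight_bits - 1)) - 1
    reduce_axes = tuple(i for i in range(weight.ndim) if i != quant_axis)
    scales = np.abs(weight).max(axis=reduce_axes)
    shape = [1] * weight.ndim
    shape[quant_axis] = -1
    q = np.clip(np.round(weight / scales.reshape(shape) * bnt), -bnt, bnt)
    return q.astype(np.int8), scales

w = np.random.randn(16, 3, 3, 3).astype('float32')  # a conv2d filter, quant_axis=0
q, s = channel_wise_quant(w, quant_axis=0)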
self._quantizable_grad_ops = [ '%s_grad' % (op) for op in self._quantizable_ops @@ -588,7 +326,7 @@ def _transform_forward(graph, op): else self._activation_quantize_type if quant_type == 'channel_wise_abs_max': # Weight quantization quant_axis = 1 if op.name() in \ - _channelwise_quant_axis1_ops else 0 + utils._channelwise_quant_axis1_ops else 0 quant_var_node, scale_var_node = self._insert_channel_quant_op( graph, var_node, name, quant_bits, quant_axis) dequant_var_node = self._insert_channel_dequant_op( @@ -1289,13 +1027,13 @@ def apply(self, graph): if self._round_type == 'round': if any( _check_grandchild_op_node(op_node, op) - for op in _channelwise_quant_axis1_ops): + for op in utils._channelwise_quant_axis1_ops): quant_axis = 1 else: quant_axis = 0 - quantized_param_v = quant_tensor(param_v.copy(), - scale_v, quant_axis, - self._weight_bits) + quantized_param_v = utils.quant_tensor( + param_v.copy(), scale_v, quant_axis, + self._weight_bits) quantized_param_v = np.round(quantized_param_v) if self._bias_correction == True: quantized_param_v = self._bias_correction_w( @@ -1319,7 +1057,7 @@ def apply(self, graph): op_node_desc.attr("quantization_type") == "qat_with_weight": if self._weight_quantize_type == 'channel_wise_abs_max': quant_axis = 1 if op_node.name() in \ - _channelwise_quant_axis1_ops else 0 + utils._channelwise_quant_axis1_ops else 0 self._insert_post_channel_dequant_op(graph, op_node, quant_axis) else: @@ -1555,8 +1293,8 @@ def _bias_correction_w(self, x, x_quant, scale_v, quant_axis): mean_bias = np.resize(mean_bias, x.shape) x_dequant = (mean_bias + x_dequant) * std_bias - quantized_param_v = quant_tensor(x_dequant, scale_v, quant_axis, - self._weight_bits) + quantized_param_v = utils.quant_tensor(x_dequant, scale_v, quant_axis, + self._weight_bits) return quantized_param_v @@ -1707,7 +1445,7 @@ def __init__(self, scope=None, place=None, moving_rate=0.9): self._place = _get_paddle_place(place) self._moving_rate = moving_rate self._is_test = None - self._teller_set = _out_scale_op_list + self._teller_set = utils._out_scale_op_list def apply(self, graph): """ @@ -1725,7 +1463,7 @@ def apply(self, graph): if op.name() in self._teller_set: target_ops.append(op) for op in target_ops: - for output_var_name in _get_op_output_var_names(op): + for output_var_name in utils._get_op_output_var_names(op): in_node = graph._find_node_by_name(op.outputs, output_var_name) if in_node.dtype() not in \ [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: @@ -1796,14 +1534,13 @@ def apply(self, graph): graph.link_to(accum_in_node, scale_op_node) graph.link_to(scale_op_node, state_out_node) graph.link_to(scale_op_node, accum_out_node) - graph.resolve_hazard() return graph def _scale_name(self, var_name): """ Return the scale name for the var named `var_name`. """ - return "%s@scale" % (var_name) + return "%s.scale" % (var_name) class OutScaleForInferencePass(object): @@ -1816,7 +1553,7 @@ def __init__(self, scope=None): scope(fluid.Scope): The scope is used to initialize these new parameters. 
""" self._scope = scope - self._teller_set = _out_scale_op_list + self._teller_set = utils._out_scale_op_list def apply(self, graph): """ @@ -1831,7 +1568,7 @@ def apply(self, graph): op_nodes = graph.all_op_nodes() for op_node in op_nodes: if op_node.name() in self._teller_set: - var_names = _get_op_output_var_names(op_node) + var_names = utils._get_op_output_var_names(op_node) for var_name in var_names: in_node = graph._find_node_by_name(op_node.outputs, var_name) @@ -1848,7 +1585,8 @@ def apply(self, graph): # For compatibility, we save output threshold by two methods. op_node.op()._set_attr("out_threshold", float(scale_value)) - argname_index = _get_output_name_index(op_node, var_name) + argname_index = utils._get_output_name_index(op_node, + var_name) assert argname_index is not None, \ var_name + " is not the output of the op" op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \ @@ -1861,7 +1599,7 @@ def _scale_name(self, var_name): """ Return the scale name for the var named `var_name`. """ - return "%s@scale" % (var_name) + return "%s.scale" % (var_name) class AddQuantDequantPass(object): @@ -1869,95 +1607,6 @@ class AddQuantDequantPass(object): Quantize the ops that do not have weights, and add quant_dequant op for the quantized ops's inputs. """ - _supported_quantizable_op_type = [ - "pool2d", - "elementwise_add", - "concat", - "softmax", - "argmax", - "transpose", - "equal", - "gather", - "greater_equal", - "greater_than", - "less_equal", - "less_than", - "mean", - "not_equal", - "reshape", - "reshape2", - "dropout", - "bilinear_interp", - "nearest_interp", - "trilinear_interp", - "slice", - "squeeze", - "elementwise_sub", - "mul", - "matmul", - "relu", - "relu6", - "leaky_relu", - "tanh", - "swish", - "scale", - "transpose", - "transpose2", - "sigmoid", - "pad2d", - "flatten", - "flatten2", - "batch_norm", - "layer_norm", - "matmul_v2", - "split", - "flatten_contiguous_range", - "squeeze2", - "nearest_interp_v2", - "bilinear_interp", - "bilinear_interp_v2", - "fill_constant_batch_size_like", - "arg_max", - "abs", - "assign", - "cast", - "clip", - "box_coder", - "crop", - "cumsum", - "elementwise_mul", - "elementwise_pow", - "expand_v2", - "fill_any_like", - "fill_constant", - "gelu", - "hard_sigmoid", - "hard_swish", - "instance_norm", - "lookup_table", - "lookup_table_v2", - "norm", - "p_norm", - "pad3d", - "pow", - "prelu", - "reduce_mean", - "unsqueeze", - "unsqueeze2", - "logical_and", - "logical_not", - "meshgrid", - "roi_align", - "strided_slice", - "where", - "grid_sampler", - "tile", - "group_norm", - "reduce_sum", - "square", - "softplus", - "shuffle_channel", - ] # To be compatible with PaddleSlim, not remove _activation_type for now _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"] @@ -2000,12 +1649,11 @@ def __init__(self, self._skip_pattern = skip_pattern if is_full_quantized: - self._quantizable_op_type = \ - AddQuantDequantPass._supported_quantizable_op_type + self._quantizable_op_type = utils._act_supported_quantizable_op_type else: self._quantizable_op_type = quantizable_op_type for op_type in quantizable_op_type: - assert op_type in AddQuantDequantPass._supported_quantizable_op_type, \ + assert op_type in utils._act_supported_quantizable_op_type, \ op_type + " is not supported for quantization." 
self._quantizable_grad_op_type = [ '%s_grad' % (op) for op in self._quantizable_op_type @@ -2050,7 +1698,7 @@ def apply(self, graph): "qat_without_weight") op_node.op()._set_attr("activation_bits", self._quant_bits) op_node.op()._set_attr("with_quant_attr", True) - arg_names = _get_op_input_var_names(op_node) + arg_names = utils._get_op_input_var_names(op_node) for arg_name in arg_names: in_node = graph._find_node_by_name(op_node.inputs, arg_name) if arg_name in dequantized_vars_map: @@ -2162,3 +1810,837 @@ def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node, graph.link_to(quant_op_node, accum_out_node) return quant_var_node, scale_out_node + + +class InsertQuantizeLinear(object): + """ + Insert quantize_linear and dequantize_linear op before ops. + """ + + def __init__(self, + place, + scope, + quant_bits=8, + quant_axis=-1, + channel_wise=False, + is_test=True): + self._place = place + self._scope = scope + self.quant_bits = quant_bits + self.quant_axis = quant_axis + self.channel_wise = channel_wise + self._is_test = is_test + + def insert_quant_op(self, graph, var_node): + assert var_node.is_var(), '{} is not a var'.format(var_node.name()) + + quant_var_node = graph.create_var_node( + name=self._quantized_var_name(var_node.name()), + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + if self.channel_wise: + scale_var_shape = var_node.shape()[self.quant_axis] + scale_var_type = core.VarDesc.VarType.LOD_TENSOR + init_scale_value = np.zeros(scale_var_shape, dtype=data_type) + else: + scale_var_shape = 1 + scale_var_type = var_node.type() + init_scale_value = np.array([0.001], dtype=data_type) + scale_var_node = graph.create_persistable_node( + name=self._quantized_scale_name(var_node.name()), + var_type=scale_var_type, + shape=[scale_var_shape], + var_dtype=var_node.dtype()) + _init_var_node(scale_var_node, init_scale_value, self._scope, + self._place) + + zero_point_node = None + if zero_point_node is None: + zero_point_node = graph.create_persistable_node( + name=self._zero_point_name(quant_var_node.name()), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=scale_var_node.shape(), + var_dtype=core.VarDesc.VarType.INT32) + _init_var_node( + zero_point_node, + np.zeros( + scale_var_node.shape(), dtype="int32"), + self._scope, + self._place) + + inputs = {"X": var_node, "Scale": scale_var_node} + if zero_point_node is not None: + inputs["ZeroPoint"] = zero_point_node + + attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} + outputs = {"Y": quant_var_node} + if not self._is_test: + attrs["is_test"] = self._is_test + attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward + outputs["OutScale"] = scale_var_node + + quant_op_node = graph.create_op_node( + op_type="quantize_linear", + attrs=attrs, + inputs=inputs, + outputs=outputs) + + graph.link_to(var_node, quant_op_node) + graph.link_to(scale_var_node, quant_op_node) + if zero_point_node is not None: + graph.link_to(zero_point_node, quant_op_node) + graph.link_to(quant_op_node, quant_var_node) + return quant_var_node, scale_var_node + + def insert_dequant_op(self, graph, var_node, scale_var_node): + assert var_node.is_var(), '{} is not a var'.format(var_node.name()) + + dequant_var_node = graph.create_var_node( + name=self._dequantized_var_name(var_node.name()), + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) + + zero_point_node = 
None + if zero_point_node is None: + zero_point_node = graph.create_persistable_node( + name=self._zero_point_name(dequant_var_node.name()), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=scale_var_node.shape(), + var_dtype=core.VarDesc.VarType.INT32) + _init_var_node( + zero_point_node, + np.zeros( + scale_var_node.shape(), dtype="int32"), + self._scope, + self._place) + + inputs = {"X": var_node, "Scale": scale_var_node} + if zero_point_node is not None: + inputs["ZeroPoint"] = zero_point_node + + attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} + if not self._is_test: + attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward + + quant_op_node = graph.create_op_node( + op_type="dequantize_linear", + attrs=attrs, + inputs=inputs, + outputs={"Y": dequant_var_node}) + + graph.link_to(var_node, quant_op_node) + graph.link_to(scale_var_node, quant_op_node) + if zero_point_node is not None: + graph.link_to(zero_point_node, quant_op_node) + graph.link_to(quant_op_node, dequant_var_node) + return dequant_var_node + + def _quantized_var_name(self, var_name): + """ + Return quantized variable name for the input `var_name`. + """ + return "%s.quantized" % (var_name) + + def _dequantized_var_name(self, var_name): + """ + Return dequantized variable name for the input `var_name`. + """ + return "%s.dequantized" % (var_name) + + def _quantized_scale_name(self, var_name): + """ + Return the scale name of quantized variable for the input `var_name`. + """ + return "%s.scale" % (var_name) + + def _zero_point_name(self, var_name): + """ + Return the scale name for the var named `var_name`. + """ + return "%s@zero_point" % (var_name) + + +class QuantizationTransformPassV2(object): + """ + Quantize the ops that have weights. Add quant and dequant ops for + the quantized ops's inputs. + """ + + def __init__(self, + scope=None, + place=None, + weight_bits=8, + activation_bits=8, + activation_quantize_type='abs_max', + weight_quantize_type='abs_max', + window_size=10000, + moving_rate=0.9, + skip_pattern=['skip_quant'], + quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'], + weight_quantize_func=None, + act_quantize_func=None, + weight_preprocess_func=None, + act_preprocess_func=None, + optimizer_func=None, + executor=None): + + self._scope = scope + self._place = _get_paddle_place(place) + self._weight_bits = weight_bits + self._activation_bits = activation_bits + self._skip_pattern = skip_pattern + self._weight_quantize_func = weight_quantize_func + self._act_quantize_func = act_quantize_func + self._weight_preprocess_func = weight_preprocess_func + self._act_preprocess_func = act_preprocess_func + self._optimizer = optimizer_func + self._exe = executor + quant_type = [ + 'abs_max', 'channel_wise_abs_max', 'range_abs_max', + 'moving_average_abs_max' + ] + assert activation_quantize_type != 'channel_wise_abs_max', \ + "The activation quantization type does not support 'channel_wise_abs_max'." + if activation_quantize_type not in quant_type: + raise ValueError( + "Unknown activation_quantize_type : '%s'. It can only be " + "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'." % + (str(activation_quantize_type))) + if weight_quantize_type not in quant_type: + raise ValueError( + "Unknown weight_quantize_type: '%s'. It can only be " + "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' " + "or 'moving_average_abs_max'." 
% (str(weight_quantize_type))) + + self._activation_quantize_type = activation_quantize_type + self._weight_quantize_type = weight_quantize_type + self._window_size = window_size + self._moving_rate = moving_rate + + self._quantizable_ops = quantizable_op_type + for op in self._quantizable_ops: + assert op in utils._weight_supported_quantizable_op_type, \ + op + " is not supported for quantization." + self._quantizable_grad_ops = [ + '%s_grad' % (op) for op in self._quantizable_ops + ] + self._is_test = None + self._global_step = None + + self.create_var_map = {} + self.create_op_map = {} + + # marked the variable which has been dequantized. + self.dequantized_vars = collections.OrderedDict() + self.persistable_vars = [] + self.processed_vars = [] + + def _quant_preprocess(self, op_node): + user_skipped = False + if isinstance(self._skip_pattern, list): + user_skipped = op_node.op().has_attr("op_namescope") and \ + any(pattern in op_node.op().attr("op_namescope") \ + for pattern in self._skip_pattern) + elif isinstance(self._skip_pattern, str): + user_skipped = op_node.op().has_attr("op_namescope") and \ + op_node.op().attr("op_namescope").find( + self._skip_pattern) != -1 + + if user_skipped: + op_node.op()._set_attr("skip_quant", True) + op_node.op()._set_attr("with_quant_attr", True) + + def _transform_forward(self, graph, op): + op.op()._set_attr("quantization_type", "qat_with_weight") + #op.op()._set_attr("with_quant_attr", True) + inputs = op.inputs + for var_node in inputs: + if var_node.name() not in op.input_arg_names(): + continue + if var_node.name() in self.dequantized_vars: + dequant_var_node = self.dequantized_vars[var_node.name()] + else: + name = var_node.name() + if name in self.processed_vars: + continue + is_weight = True if var_node.name() in self.persistable_vars \ + else False + + # if var node is weight and weight_preprocess_func is not None, + # will insert weight preprocess func + # to preorocess weight before quantization + # if var node is activation and act_preprocess_func is not None, + # will insert activation preprocess func + # to preorocess activation before quantization + if is_weight and self._weight_preprocess_func is not None: + var_node = self._insert_func( + graph, self._weight_preprocess_func, var_node, op) + elif not is_weight and self._act_preprocess_func is not None: + var_node = self._insert_func( + graph, self._act_preprocess_func, var_node, op) + + # if var node is weight and weight_quantize_func is not None, + # will insert weight quantize func to quantize and dequantize weight + # if var node is activation and act_quantize_func is not None, + # will insert act quantize func to quantize and dequantize activation + if is_weight and self._weight_quantize_func is not None: + target_out_node = self._insert_func( + graph, self._weight_quantize_func, var_node, op) + processed_vars.append(name) + continue + elif not is_weight and self._act_quantize_func is not None: + target_out_node = self._insert_func( + graph, self._act_quantize_func, var_node, op) + processed_vars.append(name) + continue + + quant_bits = self._weight_bits if var_node.name() in self.persistable_vars \ + else self._activation_bits + quant_type = self._weight_quantize_type if is_weight \ + else self._activation_quantize_type + quant_axis = -1 + channel_wise = False + if quant_type == 'channel_wise_abs_max': # Weight quantization + channel_wise = True + quant_axis = 1 if op.name() in \ + utils._channelwise_quant_axis1_ops else 0 + insert_quant_pass = InsertQuantizeLinear( + 
self._place, + self._scope, + quant_bits=quant_bits, + quant_axis=quant_axis, + channel_wise=channel_wise, + is_test=self._is_test) + quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( + graph, var_node) + dequant_var_node = insert_quant_pass.insert_dequant_op( + graph, quant_var_node, scale_var_node) + + self.dequantized_vars[name] = dequant_var_node + graph.update_input_link(var_node, dequant_var_node, op) + + def _transform_backward(self, graph, op): + for var_node in op.inputs: + if var_node.name() not in op.input_arg_names(): + continue + if var_node.name() in self.dequantized_vars: + dequant_var_node = self.dequantized_vars[var_node.name()] + graph.update_input_link(var_node, dequant_var_node, op) + + def _has_weight(self, op): + has_weight = False + for var_node in op.inputs: + if var_node.name() not in op.input_arg_names(): + continue + name = var_node.name() + if var_node.name() in self.persistable_vars: + has_weight = True + return has_weight + + def _create_global_step(self, graph): + if self._weight_quantize_type == 'range_abs_max' or \ + self._activation_quantize_type == 'range_abs_max': + counter_name = cpt.to_text('@STEP_COUNTER@') + for node in graph.all_var_nodes(): + if node.name() == counter_name: + self._global_step = node + if self._global_step is None: + global_step_in = graph.create_persistable_node( + name=counter_name, + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + var_dtype=core.VarDesc.VarType.INT64) + _init_var_node( + global_step_in, + np.zeros( + [1], dtype='int64'), + self._scope, + self._place) + global_step_out = graph.create_var_node_from_desc( + global_step_in.var()) + # The attribute of `op_role` is needed by ParallelExecutor. + increment_op = graph.create_op_node( + op_type='increment', + attrs={ + 'step': 1.0, + 'op_role': + core.op_proto_and_checker_maker.OpRole.Forward + }, + inputs={'X': global_step_in}, + outputs={'Out': global_step_out}) + graph.link_to(global_step_in, increment_op) + graph.link_to(increment_op, global_step_out) + self._global_step = global_step_out + + def _is_skip_quant(self, graph, op_node): + """ + Analyse whether the op node skips quantization. + """ + is_skip = False + if op_node.op().has_attr("skip_quant") and \ + op_node.op().attr("skip_quant"): + is_skip = True + # if the inputs of mul and matmul are not all persistable, use + # AddQuantDequantPass to quantize them. + if op_node.name() in ["mul", "matmul"] and \ + _is_input_all_not_persistable(graph, op_node): + is_skip = True + if op_node.op().has_attr("quantization_type") and \ + op_node.op().attr("quantization_type") == "qat_without_weight": + is_skip = True + return is_skip + + def apply(self, graph): + """ + Quantize the graph for training process. According to weight and + activation quantization type, the graph will be added some fake + quantize operators and fake dequantize operators. + + Args: + graph(IrGraph): the applied graph. + Returns: + None + """ + assert isinstance(graph, + IrGraph), 'graph must be the instance of IrGraph.' + self._is_test = graph.is_test() + + self.persistable_vars = [ + p.name() for p in graph.all_persistable_nodes() + ] + + if not self._is_test: + self._create_global_step(graph) + ops = graph.all_op_nodes() + # Do the preproccess of quantization, such as skipping some ops + # for not being quantized. + for op in ops: + if op.name() in self._quantizable_ops or \ + op.name() in self._quantizable_grad_ops: + self._quant_preprocess(op) + # Insert mapping table to solve the problem in saving inference model. 
+ graph.out_node_mapping_table = dict() + # The process of _transform_forward and _transform_backward is needed in two for loops. + # The loop for transforming the forward graph: + for op in ops: + if op.name() in self._quantizable_ops: + if not self._is_skip_quant(graph, op) and self._has_weight(op): + self._transform_forward(graph, op) + # The loop for renaming the inputs of backward op. + for op in ops: + if op.name() in self._quantizable_grad_ops and self._has_weight(op): + self._transform_backward(graph, op) + #graph.resolve_hazard() + return graph + + +class AddQuantDequantPassV2(object): + """ + Quantize the ops that do not have weights, and add quant_linear and dequant_linear + op for the quantized ops's inputs. + """ + + # To be compatible with PaddleSlim, not remove _activation_type for now + _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"] + + def __init__(self, + scope=None, + place=None, + moving_rate=0.9, + quant_bits=8, + skip_pattern=["skip_quant"], + quantizable_op_type=["elementwise_add", "pool2d"], + is_full_quantized=False): + """ + Constructor. + + Args: + scope(fluid.Scope): The scope is used to initialize these new parameters. + place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to initialize new + parameters described above. If ``place`` is string, it can be It can be ``cpu`` + or ``gpu:x``, where ``x`` is the index of the GPUs. + moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max' + quantization. Default is 0.9. + quant_bits(int, optional): quantization bit number for activation. Default is 8. + skip_pattern(str, optional): The user-defined quantization skip pattern, which + will be presented in the name scope of an op. When the skip pattern is + detected in an op's name scope, the corresponding op will not be quantized. + Default is 'skip_quant'. + quantizable_op_type(list[str], optional): List the type of ops that will be + quantized. Default is ["elementwise_add", "pool2d"]. + is_full_quantized(bool, optional): If set is_full_quantized as True, apply + quantization to all supported quantizable op type. If set is_full_quantized + as False, only apply quantization to the op type according to the input + quantizable_op_type. + """ + self._scope = scope + self._place = _get_paddle_place(place) + self._moving_rate = moving_rate + self._quant_bits = quant_bits + self._is_test = None + self._skip_pattern = skip_pattern + + if is_full_quantized: + self._quantizable_op_type = utils._act_supported_quantizable_op_type + else: + self._quantizable_op_type = quantizable_op_type + for op_type in quantizable_op_type: + assert op_type in utils._act_supported_quantizable_op_type, \ + op_type + " is not supported for quantization." + self._quantizable_grad_op_type = [ + '%s_grad' % (op) for op in self._quantizable_op_type + ] + + assert self._scope != None, "scope must not be None." + assert self._place != None, "place must not be None." + self.persistable_vars = [] + + def apply(self, graph): + """ + Add quant_dequant before some ops, such as the 'elementwise_add' and + 'pool2d' op. + + Args: + graph(IrGraph): the target graph. + Returns: + None + """ + assert isinstance(graph, + IrGraph), 'graph must be the instance of IrGraph.' 
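The pass below relies on InsertQuantizeLinear to place an explicit quantize_linear plus dequantize_linear pair in front of each eligible activation input. As a rough sketch of the assumed per-tensor semantics (quant_axis=-1, zero point 0; the real kernels live in quantize_linear_op.cc/.cu, so treat this as illustration only):

import numpy as np

def quantize_linear(x, scale, bit_length=8, zero_point=0):
    bnt = (1 << (bit_length - 1)) - 1
    return np.clip(np.round(x / scale * bnt) + zero_point, -bnt, bnt)

def dequantize_linear(y, scale, bit_length=8, zero_point=0):
    bnt = (1 << (bit_length - 1)) - 1
    return (y - zero_point) * scale / bnt

x = np.random.randn(4, 8).astype('float32')
scale = float(np.abs(x).max())
x_qdq = dequantize_linear(quantize_linear(x, scale), scale)  # fake-quantized activation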
+ self._is_test = graph.is_test() + dequantized_vars_map = collections.OrderedDict() + + self.persistable_vars = [ + p.name() for p in graph.all_persistable_nodes() + ] + + # Forward stage, insert quant_dequant op + all_op_nodes = graph.all_op_nodes() + for op_node in all_op_nodes: + if op_node.name() in self._quantizable_op_type: + is_skip = False + if isinstance(self._skip_pattern, list): + is_skip = op_node.op().has_attr("op_namescope") and \ + any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) + elif isinstance(self._skip_pattern, str): + is_skip = op_node.op().has_attr("op_namescope") and \ + op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 + is_quantized = op_node.op().has_attr("quantization_type") and \ + op_node.op().attr("quantization_type") == "qat_with_weight" + if is_skip or is_quantized: + continue + + op_node.op()._set_attr("quantization_type", + "qat_without_weight") + arg_names = utils._get_op_input_var_names(op_node) + for arg_name in arg_names: + in_node = graph._find_node_by_name(op_node.inputs, arg_name) + if in_node.persistable(): + continue + if arg_name in dequantized_vars_map: + dequant_var_node = dequantized_vars_map[arg_name] + else: + insert_quant_pass = InsertQuantizeLinear( + self._place, + self._scope, + quant_bits=self._quant_bits, + quant_axis=-1, + channel_wise=False, + is_test=self._is_test) + quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( + graph, in_node) + dequant_var_node = insert_quant_pass.insert_dequant_op( + graph, quant_var_node, scale_var_node) + dequantized_vars_map[arg_name] = dequant_var_node + graph.update_input_link(in_node, dequant_var_node, op_node) + + # Backward stage, update input link + for op_node in all_op_nodes: + if op_node.name() in self._quantizable_grad_op_type: + for input_name in op_node.input_arg_names(): + if input_name in dequantized_vars_map: + in_node = graph._find_node_by_name(op_node.inputs, + input_name) + dequant_var_node = dequantized_vars_map[input_name] + graph.update_input_link(in_node, dequant_var_node, + op_node) + + return graph + + +class ReplaceFakeQuantDequantPass(object): + def __init__(self, scope, place): + self._place = _get_paddle_place(place) + self._scope = scope + assert self._scope != None, "scope must not be None." + assert self._place != None, "place must not be None." + + def apply(self, graph): + assert isinstance(graph, + IrGraph), 'graph must be the instance of IrGraph.' 
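The replacement performed by this pass is meant to be a representation change, not a numeric one: a fused fake_quantize_dequantize op and the explicit quantize_linear followed by dequantize_linear should produce the same values. A quick illustrative check in numpy (assumes a per-tensor scale and zero point 0):

import numpy as np

x = np.random.randn(8).astype('float32')
scale, bnt = float(np.abs(x).max()), 127

fused = np.round(x / scale * bnt) * scale / bnt   # fake_quantize_dequantize_* output
q = np.round(x / scale * bnt)                     # quantize_linear output
split = q * scale / bnt                           # dequantize_linear output
assert np.allclose(fused, split)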
+ fake_quant_ops = [] + fake_dequant_ops = [] + fake_quant_dequant_ops = [] + + for op in graph.all_op_nodes(): + if op.name() in _fake_quant_op_list: + fake_quant_ops.append(op) + elif op.name() in _fake_dequant_op_list: + fake_dequant_ops.append(op) + elif op.name() in _fake_quant_dequant_op_list: + fake_quant_dequant_ops.append(op) + + for _op in fake_quant_ops: + print(_op.name()) + self._replace_op(graph, _op, "quantize_linear") + graph.safe_remove_nodes(_op) + + for _op in fake_dequant_ops: + self._replace_op(graph, _op, "dequantize_linear") + graph.safe_remove_nodes(_op) + + for _op in fake_quant_dequant_ops: + self._replace_op(graph, _op, "quantize_dequantize") + graph.safe_remove_nodes(_op) + + graph.resolve_hazard() + return graph + + def _replace_op(self, graph, op, target_op_name): + assert target_op_name in [ + "quantize_linear", "dequantize_linear", "quantize_dequantize" + ] + x_node = graph._find_node_by_name(op.inputs, op.input("X")[0]) + out_node = graph._find_node_by_name(op.outputs, op.output("Out")[0]) + if target_op_name == "quantize_linear" or target_op_name == "quantize_dequantize": + scale_node = graph._find_node_by_name(op.outputs, + op.output("OutScale")[0]) + else: + scale_name = "Scales" if op.op().has_attr("quant_axis") else "Scale" + scale_node = graph._find_node_by_name(op.inputs, + op.input(scale_name)[0]) + + quant_axis = op.op().attr("quant_axis") if op.op().has_attr( + "quant_axis") else -1 + bit_length = op.op().attr("bit_length") if op.op().has_attr( + "bit_length") else 8 + + zero_point_node = None + quanted_node = out_node if target_op_name == "quantize_linear" else x_node + if zero_point_node is None: + zero_point_node = graph.create_persistable_node( + name=self._zero_point_name(quanted_node.name()), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=scale_node.shape(), + var_dtype=core.VarDesc.VarType.INT32) + _init_var_node( + zero_point_node, + np.zeros( + scale_node.shape(), dtype="int32"), + self._scope, + self._place) + + if target_op_name != "quantize_dequantize": + inputs = {"X": x_node, "Scale": scale_node} + if zero_point_node is not None: + inputs["ZeroPoint"] = zero_point_node + quant_op_node = graph.create_op_node( + op_type=target_op_name, + attrs={"quant_axis": quant_axis, + "bit_length": bit_length}, + inputs=inputs, + outputs={"Y": out_node}) + + graph.link_to(x_node, quant_op_node) + graph.link_to(scale_node, quant_op_node) + if zero_point_node is not None: + graph.link_to(zero_point_node, quant_op_node) + graph.link_to(quant_op_node, out_node) + else: + quant_var_node = graph.create_var_node( + name=self._quantized_var_name(x_node.name()), + var_type=x_node.type(), + shape=x_node.shape(), + var_dtype=x_node.dtype()) + quant_op_node = graph.create_op_node( + op_type="quantize_linear", + attrs={"quant_axis": quant_axis, + "bit_length": bit_length}, + inputs={ + "X": x_node, + "Scale": scale_node, + "ZeroPoint": zero_point_node + }, + outputs={"Y": quant_var_node}) + graph.link_to(x_node, quant_op_node) + graph.link_to(scale_node, quant_op_node) + if zero_point_node is not None: + graph.link_to(zero_point_node, quant_op_node) + graph.link_to(quant_op_node, quant_var_node) + dequant_op_node = graph.create_op_node( + op_type="dequantize_linear", + attrs={"quant_axis": quant_axis, + "bit_length": bit_length}, + inputs={ + "X": quant_var_node, + "Scale": scale_node, + "ZeroPoint": zero_point_node + }, + outputs={"Y": out_node}) + graph.link_to(quant_var_node, dequant_op_node) + graph.link_to(scale_node, dequant_op_node) + if 
zero_point_node is not None: + graph.link_to(zero_point_node, dequant_op_node) + graph.link_to(dequant_op_node, out_node) + + def _quantized_var_name(self, var_name): + """ + Return quantized variable name for the input `var_name`. + """ + return "%s.quantized" % (var_name) + + def _zero_point_name(self, var_name): + """ + Return the scale name for the var named `var_name`. + """ + return "%s@zero_point" % (var_name) + + +class QuantWeightPass(object): + """ + quant weights and remove weights input quantize_linear node. for example: + `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> dequant -> conv2d`, + and weight will be scaled offline. + + Args: + scope(fluid.Scope): scope is used to get the weight tensor values. + place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to restore the weight tensors. + If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs. + bias_correction(bool): whether use bias correction for post-training quantization. + https://arxiv.org/abs/1810.05723. + """ + + def __init__(self, scope, place, bias_correction=False, quant_bits=8): + self._place = _get_paddle_place(place) + self._scope = scope + self._bias_correction = bias_correction + self._quant_bits = quant_bits + #self.scale_dict = scale_dict + assert self._scope != None, "scope must not be None." + assert self._place != None, "place must not be None." + + def apply(self, graph): + assert isinstance(graph, + IrGraph), 'graph must be the instance of IrGraph.' + fake_quant_ops_for_weight = [] + + fake_quant_ops = [ + op for op in graph.all_op_nodes() if op.name() == "quantize_linear" + ] + for _op in fake_quant_ops: + x_node = graph._find_node_by_name(_op.inputs, _op.input("X")[0]) + if x_node.persistable(): + scale_node = graph._find_node_by_name(_op.inputs, + _op.input("Scale")[0]) + zero_point_node = graph._find_node_by_name( + _op.inputs, _op.input("ZeroPoint")[0]) + out_node = graph._find_node_by_name(_op.outputs, + _op.output("Y")[0]) + + scale_v = self._load_var(scale_node.name()) + assert scale_v.ndim in [1, 2 + ], "the dim of scale_v should be 1 or 2" + if scale_v.ndim == 2: + scale_v = scale_v[0] + if scale_v.size == 1 and _op.name() == 'abs_max': + scale_v = scale_v[0] + else: + scale_v = scale_v.tolist() + param_v = self._load_var(x_node.name()) + quant_axis = _op.op().attr("quant_axis") + bits_length = _op.op().attr("bit_length") + quantized_param_v = utils.quant_tensor(param_v.copy(), scale_v, + quant_axis, bits_length) + if self._bias_correction == True: + quantized_param_v = self._bias_correction_w( + param_v, quantized_param_v, scale_v, quant_axis) + if self._quant_bits == 8: + save_weight_dtype = np.int8 + elif self._quant_bits == 4: + save_weight_dtype = np.int4 + # cast weight type + quantized_param_v = quantized_param_v.astype(save_weight_dtype) + self._restore_var(x_node.name(), quantized_param_v) + + for next_op_node in out_node.outputs: + graph.update_input_link(out_node, x_node, next_op_node) + graph.safe_remove_nodes(out_node) + self._remove_unused_var_nodes(graph) + + def _remove_unused_var_nodes(self, graph): + all_used_vars = set() + ops = graph.all_op_nodes() + for op_node in ops: + for input_node in op_node.inputs: + all_used_vars.add(input_node) + for output_node in op_node.outputs: + all_used_vars.add(output_node) + + all_used_vars = {n.node for n in all_used_vars} + all_unused_vars = { + n + for n in filter(lambda node: node.node not in all_used_vars, + graph.all_var_nodes()) + } + graph.safe_remove_nodes(all_unused_vars) + 
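In effect, QuantWeightPass folds the quantize_linear that fed a persistable weight into the stored parameter itself, so only the dequantize_linear remains in the graph. A minimal per-tensor sketch of that folding (the helper name is illustrative; the real pass also handles per-channel scales and optional bias correction):

import numpy as np

def freeze_weight(weight_fp32, scale, bits=8):
    # Quantize offline and store the integer tensor in place of the fp32 one.
    bnt = (1 << (bits - 1)) - 1
    q = np.clip(np.round(weight_fp32 / scale * bnt), -bnt, bnt)
    return q.astype(np.int8)

# At inference the remaining dequantize_linear recovers an approximation:
#   weight_fp32 ~ freeze_weight(w, s).astype('float32') * s / bnt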
+ def _load_var(self, name): + return np.array(self._scope.find_var(name).get_tensor()) + + def _restore_var(self, name, array): + tensor = self._scope.find_var(name).get_tensor() + tensor.set(array, self._place) + + def _bias_correction_w(self, x, x_quant, scale_v, quant_axis): + ''' + Bias correction for weight + ''' + eps = 1e-8 + bnt = (1 << (self._weight_bits - 1)) - 1 + x_dequant = x_quant.copy() + if isinstance(scale_v, list): + if quant_axis == 0: + for i, s in enumerate(scale_v): + x_dequant[i] = x_dequant[i] * s / bnt + quant_bias = x - x_dequant + mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1) + std_orig = x.reshape(x.shape[0], -1).std(-1) + std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1) + std_bias = std_orig / (std_quant + eps) + else: + for i, s in enumerate(scale_v): + x_dequant[:, i] = x_quant[:, i] * s / bnt + quant_bias = x - x_dequant + mean_bias = np.array([ + quant_bias[:, i].mean() for i in range(quant_bias.shape[1]) + ]) + std_orig = np.array([x[:, i].std() for i in range(x.shape[1])]) + std_quant = np.array( + [x_dequant[:, i].std() for i in range(x_dequant.shape[1])]) + std_bias = std_orig / (std_quant + eps) + else: + x_dequant = x_quant * scale_v / bnt + mean_bias = (x - x_dequant).mean() + std_bias = x.std() / (x_dequant.std() + eps) + if mean_bias.ndim == 1: + std_bias = np.resize(std_bias, x.shape) + mean_bias = np.resize(mean_bias, x.shape) + + x_dequant = (mean_bias + x_dequant) * std_bias + quantized_param_v = self._quant(x_dequant, scale_v, self._weight_bits, + quant_axis) + return quantized_param_v diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index 43f33f33c3138..3bea6dd5954e9 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -13,11 +13,292 @@ # limitations under the License. 
import numpy as np +from ....framework import IrNode +from ....framework import Operator + +_weight_supported_quantizable_op_type = [ + 'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul', + 'matmul_v2' +] + +_act_supported_quantizable_op_type = [ + "pool2d", + "elementwise_add", + "concat", + "softmax", + "argmax", + "transpose", + "equal", + "gather", + "greater_equal", + "greater_than", + "less_equal", + "less_than", + "mean", + "not_equal", + "reshape", + "reshape2", + "dropout", + "bilinear_interp", + "nearest_interp", + "trilinear_interp", + "slice", + "squeeze", + "elementwise_sub", + "mul", + "matmul", + "relu", + "relu6", + "leaky_relu", + "tanh", + "swish", + "scale", + "transpose", + "transpose2", + "sigmoid", + "pad2d", + "flatten", + "flatten2", + "batch_norm", + "layer_norm", + "matmul_v2", + "split", + "flatten_contiguous_range", + "squeeze2", + "nearest_interp_v2", + "bilinear_interp", + "bilinear_interp_v2", + "fill_constant_batch_size_like", + "arg_max", + "abs", + "assign", + "cast", + "clip", + "box_coder", + "crop", + "cumsum", + "elementwise_mul", + "elementwise_pow", + "expand_v2", + "fill_any_like", + "fill_constant", + "gelu", + "hard_sigmoid", + "hard_swish", + "instance_norm", + "lookup_table", + "lookup_table_v2", + "norm", + "p_norm", + "pad3d", + "pow", + "prelu", + "reduce_mean", + "unsqueeze", + "unsqueeze2", + "logical_and", + "logical_not", + "meshgrid", + "roi_align", + "strided_slice", + "where", + "grid_sampler", + "tile", + "group_norm", + "reduce_sum", + "square", + "softplus", + "shuffle_channel", +] + +_out_scale_op_list = list( + set(_weight_supported_quantizable_op_type + + _act_supported_quantizable_op_type)) _channelwise_quant_axis1_ops = [ 'conv2d_transpose', 'mul', 'matmul', 'matmul_v2' ] +# list op real input and output names, to avoid processing input such as AxisTensor. 
+_op_real_in_out_name = { + "conv2d": [["Input", "Filter"], ["Output"]], + "depthwise_conv2d": [["Input", "Filter"], ["Output"]], + "conv2d_transpose": [["Input", "Filter"], ["Output"]], + "mul": [["X", "Y"], ["Out"]], + "matmul": [["X", "Y"], ["Out"]], + "matmul_v2": [["X", "Y"], ["Out"]], + "pool2d": [["X"], ["Out"]], + "elementwise_add": [["X", "Y"], ["Out"]], + "concat": [["X"], ["Out"]], + "softmax": [["X"], ["Out"]], + "argmax": [["X"], ["Out"]], + "transpose": [["X"], ["Out"]], + "equal": [["X", "Y"], ["Out"]], + "gather": [["X"], ["Out"]], + "greater_equal": [["X", "Y"], ["Out"]], + "greater_than": [["X", "Y"], ["Out"]], + "less_equal": [["X", "Y"], ["Out"]], + "less_than": [["X", "Y"], ["Out"]], + "mean": [["X"], ["Out"]], + "not_equal": [["X", "Y"], ["Out"]], + "reshape": [["X"], ["Out"]], + "reshape2": [["X"], ["Out"]], + "transpose2": [["X"], ["Out"]], + "bilinear_interp": [["X"], ["Out"]], + "nearest_interp": [["X"], ["Out"]], + "trilinear_interp": [["X"], ["Out"]], + "slice": [["Input"], ["Out"]], + "squeeze": [["X"], ["Out"]], + "elementwise_sub": [["X", "Y"], ["Out"]], + "relu": [["X"], ["Out"]], + "relu6": [["X"], ["Out"]], + "leaky_relu": [["X"], ["Out"]], + "prelu": [["X", "Alpha"], ["Out"]], + "tanh": [["X"], ["Out"]], + "swish": [["X"], ["Out"]], + "dropout": [["X"], ["Out"]], + "batch_norm": [["X"], ["Y"]], + "layer_norm": [["X"], ["Y"]], + "sigmoid": [["X"], ["Out"]], + "elementwise_mul": [["X", "Y"], ["Out"]], + "elementwise_pow": [["X", "Y"], ["Out"]], + "scale": [["X"], ["Out"]], + "hard_swish": [["X"], ["Out"]], + "hard_sigmoid": [["X"], ["Out"]], + "gru": [["Input", "Weight"], ["Hidden"]], + "lstm": [["Input", "Weight"], ["Hidden"]], + "pad2d": [["X"], ["Out"]], + "pad3d": [["X"], ["Out"]], + "flatten": [["X"], ["Out"]], + "flatten2": [["X"], ["Out"]], + "unsqueeze2": [["X"], ["Out"]], + "unsqueeze2": [["X"], ["Out"]], + "flatten_contiguous_range": [["X"], ["Out"]], + "split": [["X"], ["Out"]], + "squeeze2": [["X"], ["Out"]], + "nearest_interp_v2": [["X"], ["Out"]], + "bilinear_interp": [["X"], ["Out"]], + "bilinear_interp_v2": [["X"], ["Out"]], + "fill_constant_batch_size_like": [["Input"], ["Out"]], + "arg_max": [["X"], ["Out"]], + "abs": [["X"], ["Out"]], + "assign": [["X"], ["Out"]], + "cast": [["X"], ["Out"]], + "clip": [["X"], ["Out"]], + "box_coder": [["PriorBox"], ["OutputBox"]], + "crop": [["X"], ["Out"]], + "cumsum": [["X"], ["Out"]], + "expand_v2": [["X"], ["Out"]], + "fill_any_like": [["X"], ["Out"]], + "fill_constant": [[], ["Out"]], + "gelu": [["X"], ["Out"]], + "instance_norm": [["X"], ["Out"]], + "lookup_table": [["W", "Ids"], ["Out"]], + "lookup_table_v2": [["W", "Ids"], ["Out"]], + "norm": [["X"], ["Norm"]], + "p_norm": [["X"], ["Out"]], + "pow": [["X"], ["Out"]], + "reduce_mean": [["X"], ["Out"]], + "stack": [["X"], ["Y"]], + "top_k_v2": [["X"], ["Out", "Indices"]], + "logical_and": [["X", "Y"], ["Out"]], + "logical_not": [["X"], ["Out"]], + "meshgrid": [["X"], ["Out"]], + "roi_align": [["X", "ROIs"], ["Out"]], + "strided_slice": [["Input"], ["Out"]], + "where": [["Condition", "X", "Y"], ["Out"]], + "grid_sampler": [["X", "Grid"], ["Output"]], + "tile": [["X"], ["Out"]], + "group_norm": [["X"], ["Y", "Mean", "Variance"]], + "reduce_sum": [["X"], ["Out"]], + "square": [["X"], ["Out"]], + "softplus": [["X"], ["Out"]], + "shuffle_channel": [["X"], ["Out"]], +} + + +def _get_op_input_var_names(op): + """ + Get the input var names of the op. + Args: + op(IrNode, Operator): the input op. + Returns: + input_var_names or None. 
+ """ + assert isinstance(op, (IrNode, Operator)), \ + "The input op should be IrNode or Operator." + var_names = [] + op_name = op.name() if isinstance(op, IrNode) \ + else op.type + if op_name not in _op_real_in_out_name: + return [] + + name_list = _op_real_in_out_name[op_name][0] + for name in name_list: + var_name = op.input(name) + if isinstance(var_name, list): + var_names.extend(var_name) + else: + var_names.append(var_name) + return var_names + + +def _get_op_output_var_names(op): + """ """ + assert isinstance(op, (IrNode, Operator)), \ + "The input op should be IrNode or Operator." + var_names = [] + op_name = op.name() if isinstance(op, IrNode) \ + else op.type + if op_name not in _op_real_in_out_name: + return [] + + name_list = _op_real_in_out_name[op_name][1] + for name in name_list: + var_name = op.output(name) + if isinstance(var_name, list): + var_names.extend(var_name) + else: + var_names.append(var_name) + return var_names + + +def _get_input_name_index(op, input_var_name): + """Get the input name and index of the var_name in the op""" + assert isinstance(op, (IrNode, Operator)), \ + "The input op should be IrNode or Operator." + op_name = op.name() if isinstance(op, IrNode) \ + else op.type + if op_name not in _op_real_in_out_name: + return None + + res = None + for argname in _op_real_in_out_name[op_name][0]: + var_names = op.input(argname) + for index, name in enumerate(var_names): + if name == input_var_name: + res = (argname, index) + return res + + +def _get_output_name_index(op, output_var_name): + """Get the output name and index of the var_name in the op""" + assert isinstance(op, (IrNode, Operator)), \ + "The input op should be IrNode or Operator." + op_name = op.name() if isinstance(op, IrNode) \ + else op.type + if op_name not in _op_real_in_out_name: + return None + + name_list = _op_real_in_out_name[op_name][1] + res = None + for name in name_list: + var_name = op.output(name) + for index, val in enumerate(var_name): + if val == output_var_name: + res = (name, index) + return res + def load_variable_data(scope, var_name): ''' From c4d378aeac151605a9182f02ca7847fdb4453fa3 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Tue, 29 Mar 2022 08:57:18 +0000 Subject: [PATCH 02/14] add unittest --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/quantize_linear_op.cc | 148 ------------------ paddle/fluid/operators/quantize_linear_op.h | 49 +----- .../slim/quantization/imperative/qat.py | 22 ++- .../post_training_quantization.py | 3 +- .../slim/quantization/quantization_pass.py | 96 ++---------- .../fluid/contrib/slim/quantization/utils.py | 40 +++++ .../contrib/slim/tests/test_imperative_qat.py | 10 +- ..._post_training_quantization_mobilenetv1.py | 59 ++++++- 9 files changed, 136 insertions(+), 293 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index bf2a00ea74a01..c539504047acc 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -106,7 +106,7 @@ register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combin recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS}) -op_library(quantize_linear_op SRCS quantize_linear_op.cc quantize_linear_op.cu DEPS cast_kernel ${OP_HEADER_DEPS}) +op_library(quantize_linear_op DEPS cast_kernel) 
op_library(save_combine_op DEPS string_array) op_library(load_combine_op DEPS string_array) diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index 72bb7ed61c3a4..04bddb7f772af 100644 --- a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -21,154 +21,6 @@ limitations under the License. */ namespace paddle { namespace operators { -template -struct Compare { - public: - bool operator()(const T a, const T b) { return (std::abs(a) < std::abs(b)); } -}; - -template -struct FindAbsMaxFunctor { - void operator()(const platform::CPUDeviceContext& ctx, const T* in, - const int num, T* out) { - *out = std::abs(*(std::max_element(in + 0, in + num, Compare()))); - } -}; - -template struct FindAbsMaxFunctor; - -template -struct FindChannelAbsMaxFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in_tensor, const int quant_axis, - T* out_abs_max) { - // At present, channelwise quantization supports conv2d, depthwise_conv2d - // conv2d_transpose and mul - PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, - platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - auto* in_data = in_tensor.data(); - auto in_dims = in_tensor.dims(); - const int64_t channel = in_dims[quant_axis]; - if (quant_axis == 0) { - const int64_t channel_size = in_tensor.numel() / channel; - for (int64_t i = 0; i < channel; i++) { - auto* start = in_data + i * channel_size; - auto* end = in_data + (i + 1) * channel_size; - out_abs_max[i] = - std::abs(*(std::max_element(start, end, Compare()))); - } - } else if (quant_axis == 1) { - for (int64_t i = 0; i < channel; i++) { - out_abs_max[i] = 0; - } - const int64_t step_i = in_tensor.numel() / in_dims[0]; - const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]); - for (int64_t i = 0; i < in_dims[0]; i++) { - for (int64_t j = 0; j < in_dims[1]; j++) { - auto* start = in_data + i * step_i + j * step_j; - auto* end = in_data + i * step_i + (j + 1) * step_j; - T abs_max = std::abs(*(std::max_element(start, end, Compare()))); - out_abs_max[j] = std::max(out_abs_max[j], abs_max); - } - } - } - } -}; - -template struct FindChannelAbsMaxFunctor; - -template -struct ClipAndFakeQuantFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { - T s = scale.data()[0]; - T inv_s = inverse(s); - platform::Transform trans; - trans(ctx, in.data(), in.data() + in.numel(), - out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); - auto out_e = framework::EigenVector::Flatten(*out); - out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); - } -}; - -template struct ClipAndFakeQuantFunctor; - -template -struct ChannelClipAndFakeQuantFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, - framework::Tensor* out) { - // At present, channelwise quantization supports conv2d, depthwise_conv2d - // conv2d_transpose and mul - PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, - platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - auto* scale_data = scale.data(); - auto* in_data = in.data(); - auto* out_data = out->mutable_data(ctx.GetPlace()); - auto in_dims = 
in.dims(); - const int64_t channel = in_dims[quant_axis]; - platform::Transform trans; - if (quant_axis == 0) { - const int64_t channel_size = in.numel() / channel; - for (int64_t i = 0; i < channel; i++) { - T s = scale_data[i]; - auto* start = in_data + i * channel_size; - auto* end = in_data + (i + 1) * channel_size; - trans(ctx, start, end, out_data + i * channel_size, - ClipFunctor(-s, s)); - } - for (int64_t i = 0; i < channel; i++) { - T s = scale_data[i]; - T inv_s = inverse(s); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); - } - } else if (quant_axis == 1) { - const int64_t step_i = in.numel() / in_dims[0]; - const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]); - for (int i = 0; i < in_dims[0]; i++) { - for (int j = 0; j < in_dims[1]; j++) { - T s = scale_data[j]; - T inv_s = inverse(s); - auto* start = in_data + i * step_i + j * step_j; - auto* end = in_data + i * step_i + (j + 1) * step_j; - auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, ClipFunctor(-s, s)); - for (int k = 0; k < step_j; k++) { - cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); - } - } - } - } - } -}; - -template struct ChannelClipAndFakeQuantFunctor; - -template -struct DequantizeFunctor { - void operator()(const platform::CPUDeviceContext& dev_ctx, - const framework::Tensor* in, const framework::Tensor* scale, - T max_range, framework::Tensor* out) { - auto in_e = framework::EigenVector::Flatten(*in); - const T* scale_factor = scale->data(); - auto out_e = framework::EigenVector::Flatten(*out); - - auto& dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * scale_factor[0] / max_range; - } -}; - template struct ChannelDequantizeFunctor { void operator()(const platform::CPUDeviceContext& dev_ctx, diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h index 2f2d9ef9a5100..3776ce95bf644 100644 --- a/paddle/fluid/operators/quantize_linear_op.h +++ b/paddle/fluid/operators/quantize_linear_op.h @@ -17,6 +17,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/fake_dequantize_op.h" +#include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/platform/transform.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/ddim.h" @@ -27,53 +29,6 @@ limitations under the License. */ namespace paddle { namespace operators { -template -inline HOSTDEVICE T inverse(T s) { - T eps = static_cast(1e-6); - T one = static_cast(1.0); - return s <= static_cast(1e-30) ? 
one / (s + eps) : one / s; -} - -template -struct FindAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const T* in, const int num, T* out); -}; - -template -struct ClipAndFakeQuantFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor& scale, const int bin_cnt, - framework::Tensor* out); -}; - -template -struct FindChannelAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor, - const int quant_axis, T* out_abs_max); -}; - -template -struct ChannelClipAndFakeQuantFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor& scale, const int bin_cnt, - const int quant_axis, framework::Tensor* out); -}; - -template -struct DequantizeFunctor { - void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, - const framework::Tensor* scale, T max_range, - framework::Tensor* out); -}; - -template -struct ChannelDequantizeFunctor { - void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, - const framework::Tensor** scales, const int scale_num, - T max_range, const int quant_axis, const int x_num_col_dims, - framework::Tensor* out); -}; - template class QuantizeLinearKernel : public framework::OpKernel { public: diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index a3fdca5e40669..5a9f192d15499 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -28,6 +28,7 @@ from paddle.fluid.initializer import Constant from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from paddle.fluid.io import load_inference_model, save_inference_model +from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass from paddle.fluid.log_helper import get_logger from .. import quantization_pass from . import utils @@ -431,7 +432,12 @@ def apply(self, model): setattr(parent_layer, sub_name, cur_quant_layer) - def save_quantized_model(self, model, path, input_spec=None, **config): + def save_quantized_model(self, + model, + path, + input_spec=None, + onnx_format=False, + **config): """ Save the quantized model for the inference. 
@@ -498,6 +504,18 @@ def save_quantized_model(self, model, path, input_spec=None, **config): self._set_skip_quant_attr(infer_program) + clip_extra = False + if onnx_format: + graph = IrGraph(core.Graph(infer_program.desc), for_test=False) + transform_pass = ReplaceFakeQuantDequantPass(scope, place) + transform_pass.apply(graph) + + quant_weight_pass = QuantWeightPass(scope, place) + quant_weight_pass.apply(graph) + infer_program = graph.to_program() + + clip_extra = True + save_inference_model( dirname=dirname, feeded_var_names=feed_target_names, @@ -506,7 +524,7 @@ def save_quantized_model(self, model, path, input_spec=None, **config): main_program=infer_program.clone(), model_filename=model_filename, params_filename=params_filename, - clip_extra=False) + clip_extra=clip_extra) if is_dynamic_mode: paddle.disable_static() diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index febb4170b4440..c528d35a7ae8c 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -607,8 +607,7 @@ def _sampling(self): def _sample_mse(self): if self._quantized_threshold == {}: for var_name in self._quantized_weight_var_name: - var_tensor = utils.utils.load_variable_data(self._scope, - var_name) + var_tensor = utils.load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": abs_max_value = float(np.max(np.abs(var_tensor))) elif self._weight_quantize_type == "channel_wise_abs_max": diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 40c9698638b3a..b6adc822e1329 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1036,8 +1036,12 @@ def apply(self, graph): self._weight_bits) quantized_param_v = np.round(quantized_param_v) if self._bias_correction == True: - quantized_param_v = self._bias_correction_w( - param_v, quantized_param_v, scale_v, quant_axis) + quantized_param_v = utils.bias_correction_w( + param_v, + quantized_param_v, + scale_v, + quant_axis, + weight_bits=self._weight_bits) quantized_param_v = np.round(quantized_param_v) self._restore_var(input_arg_name, quantized_param_v) self._remove_fake_quant_and_dequant_op(graph, op_node) @@ -1257,46 +1261,6 @@ def _is_float(self, v): return isinstance(v, float) or isinstance(v, np.float32) \ or isinstance(v, np.float64) - def _bias_correction_w(self, x, x_quant, scale_v, quant_axis): - ''' - Bias correction for weight - ''' - eps = 1e-8 - bnt = (1 << (self._weight_bits - 1)) - 1 - x_dequant = x_quant.copy() - if isinstance(scale_v, list): - if quant_axis == 0: - for i, s in enumerate(scale_v): - x_dequant[i] = x_dequant[i] * s / bnt - quant_bias = x - x_dequant - mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1) - std_orig = x.reshape(x.shape[0], -1).std(-1) - std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1) - std_bias = std_orig / (std_quant + eps) - else: - for i, s in enumerate(scale_v): - x_dequant[:, i] = x_quant[:, i] * s / bnt - quant_bias = x - x_dequant - mean_bias = np.array([ - quant_bias[:, i].mean() for i in range(quant_bias.shape[1]) - ]) - std_orig = np.array([x[:, i].std() for i in range(x.shape[1])]) - std_quant = np.array( - [x_dequant[:, i].std() for i in 
range(x_dequant.shape[1])]) - std_bias = std_orig / (std_quant + eps) - else: - x_dequant = x_quant * scale_v / bnt - mean_bias = (x - x_dequant).mean() - std_bias = x.std() / (x_dequant.std() + eps) - if mean_bias.ndim == 1: - std_bias = np.resize(std_bias, x.shape) - mean_bias = np.resize(mean_bias, x.shape) - - x_dequant = (mean_bias + x_dequant) * std_bias - quantized_param_v = utils.quant_tensor(x_dequant, scale_v, quant_axis, - self._weight_bits) - return quantized_param_v - class ConvertToInt8Pass(object): def __init__(self, scope, place, quantizable_op_type=None): @@ -2566,8 +2530,12 @@ def apply(self, graph): quantized_param_v = utils.quant_tensor(param_v.copy(), scale_v, quant_axis, bits_length) if self._bias_correction == True: - quantized_param_v = self._bias_correction_w( - param_v, quantized_param_v, scale_v, quant_axis) + quantized_param_v = utils.bias_correction_w( + param_v, + quantized_param_v, + scale_v, + quant_axis, + weight_bits=bits_length) if self._quant_bits == 8: save_weight_dtype = np.int8 elif self._quant_bits == 4: @@ -2604,43 +2572,3 @@ def _load_var(self, name): def _restore_var(self, name, array): tensor = self._scope.find_var(name).get_tensor() tensor.set(array, self._place) - - def _bias_correction_w(self, x, x_quant, scale_v, quant_axis): - ''' - Bias correction for weight - ''' - eps = 1e-8 - bnt = (1 << (self._weight_bits - 1)) - 1 - x_dequant = x_quant.copy() - if isinstance(scale_v, list): - if quant_axis == 0: - for i, s in enumerate(scale_v): - x_dequant[i] = x_dequant[i] * s / bnt - quant_bias = x - x_dequant - mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1) - std_orig = x.reshape(x.shape[0], -1).std(-1) - std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1) - std_bias = std_orig / (std_quant + eps) - else: - for i, s in enumerate(scale_v): - x_dequant[:, i] = x_quant[:, i] * s / bnt - quant_bias = x - x_dequant - mean_bias = np.array([ - quant_bias[:, i].mean() for i in range(quant_bias.shape[1]) - ]) - std_orig = np.array([x[:, i].std() for i in range(x.shape[1])]) - std_quant = np.array( - [x_dequant[:, i].std() for i in range(x_dequant.shape[1])]) - std_bias = std_orig / (std_quant + eps) - else: - x_dequant = x_quant * scale_v / bnt - mean_bias = (x - x_dequant).mean() - std_bias = x.std() / (x_dequant.std() + eps) - if mean_bias.ndim == 1: - std_bias = np.resize(std_bias, x.shape) - mean_bias = np.resize(mean_bias, x.shape) - - x_dequant = (mean_bias + x_dequant) * std_bias - quantized_param_v = self._quant(x_dequant, scale_v, self._weight_bits, - quant_axis) - return quantized_param_v diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index 3bea6dd5954e9..9881e6ec02b54 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -365,6 +365,46 @@ def dequant_tensor(x, scale, quant_axis=0, weight_bits=8): return x +def bias_correction_w(self, x, x_quant, scale_v, quant_axis, weight_bits=8): + ''' + Bias correction for weight + ''' + eps = 1e-8 + bnt = (1 << (weight_bits - 1)) - 1 + x_dequant = x_quant.copy() + if isinstance(scale_v, list): + if quant_axis == 0: + for i, s in enumerate(scale_v): + x_dequant[i] = x_dequant[i] * s / bnt + quant_bias = x - x_dequant + mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1) + std_orig = x.reshape(x.shape[0], -1).std(-1) + std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1) + std_bias = std_orig / (std_quant + 
eps) + else: + for i, s in enumerate(scale_v): + x_dequant[:, i] = x_quant[:, i] * s / bnt + quant_bias = x - x_dequant + mean_bias = np.array( + [quant_bias[:, i].mean() for i in range(quant_bias.shape[1])]) + std_orig = np.array([x[:, i].std() for i in range(x.shape[1])]) + std_quant = np.array( + [x_dequant[:, i].std() for i in range(x_dequant.shape[1])]) + std_bias = std_orig / (std_quant + eps) + else: + x_dequant = x_quant * scale_v / bnt + mean_bias = (x - x_dequant).mean() + std_bias = x.std() / (x_dequant.std() + eps) + if mean_bias.ndim == 1: + std_bias = np.resize(std_bias, x.shape) + mean_bias = np.resize(mean_bias, x.shape) + + x_dequant = (mean_bias + x_dequant) * std_bias + quantized_param_v = quant_tensor(x_dequant, scale_v, quant_axis, + weight_bits) + return quantized_param_v + + def stable_sigmoid(x): sig = np.where(x < 0, np.exp(x) / (1 + np.exp(x)), 1 / (1 + np.exp(-x))) return sig diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py index 5db720b028ffe..bbe13bde5ad87 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py @@ -53,7 +53,7 @@ class TestImperativeQat(unittest.TestCase): def set_vars(self): self.weight_quantize_type = 'abs_max' self.activation_quantize_type = 'moving_average_abs_max' - print('weight_quantize_type', self.weight_quantize_type) + self.onnx_format = False def func_qat(self): self.set_vars() @@ -171,7 +171,8 @@ def func_qat(self): input_spec=[ paddle.static.InputSpec( shape=[None, 1, 28, 28], dtype='float32') - ]) + ], + onnx_format=self.onnx_format) print('Quantized model saved in %s' % tmpdir) if core.is_compiled_with_cuda(): @@ -199,5 +200,10 @@ def test_qat(self): self.func_qat() +class TestImperativeQatONNXFormat(unittest.TestCase): + def set_vars(self): + self.onnx_format = True + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 312a0c9e4b40e..448c96d8be341 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -243,7 +243,8 @@ def generate_quantized_model(self, round_type="round", is_full_quantize=False, is_use_cache_file=False, - is_optimize_model=False): + is_optimize_model=False, + onnx_format=False): try: os.system("mkdir " + self.int8_model) except Exception as e: @@ -265,13 +266,23 @@ def generate_quantized_model(self, round_type=round_type, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, + onnx_format=onnx_format, is_use_cache_file=is_use_cache_file) ptq.quantize() ptq.save_quantized_model(self.int8_model) - def run_test(self, model, algo, round_type, data_urls, data_md5s, - quantizable_op_type, is_full_quantize, is_use_cache_file, - is_optimize_model, diff_threshold): + def run_test(self, + model, + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=False): infer_iterations = self.infer_iterations batch_size = self.batch_size sample_iterations = self.sample_iterations @@ -285,9 +296,10 @@ def run_test(self, model, algo, round_type, data_urls, data_md5s, print("Start INT8 post training quantization for {0} on 
{1} images ...". format(model, sample_iterations * batch_size)) - self.generate_quantized_model( - model_cache_folder + "/model", quantizable_op_type, algo, - round_type, is_full_quantize, is_use_cache_file, is_optimize_model) + self.generate_quantized_model(model_cache_folder + "/model", + quantizable_op_type, algo, round_type, + is_full_quantize, is_use_cache_file, + is_optimize_model, onnx_format) print("Start INT8 inference for {0} on {1} images ...".format( model, infer_iterations * batch_size)) @@ -517,5 +529,38 @@ def test_post_training_avg_mobilenetv1(self): is_optimize_model, diff_threshold) +class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_onnx_format_mobilenetv1(self): + model = "MobileNet-V1" + algo = "avg" + round_type = "round" + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' + ] + data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] + quantizable_op_type = [ + "conv2d", + "depthwise_conv2d", + "mul", + ] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + onnx_format = True + diff_threshold = 0.025 + self.run_test( + model, + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=onnx_format) + + if __name__ == '__main__': unittest.main() From b81f3e2f6ef3be222d45ec0dcc9f643ee2f27572 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Wed, 30 Mar 2022 07:26:19 +0000 Subject: [PATCH 03/14] fix unittest --- .../fluid/contrib/slim/tests/test_imperative_qat.py | 2 ++ .../slim/tests/test_imperative_qat_channelwise.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py index bbe13bde5ad87..df1a1df709978 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py @@ -202,6 +202,8 @@ def test_qat(self): class TestImperativeQatONNXFormat(unittest.TestCase): def set_vars(self): + self.weight_quantize_type = 'abs_max' + self.activation_quantize_type = 'moving_average_abs_max' self.onnx_format = True diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py index 1a6c9c41638db..f396b1c1093bb 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py @@ -41,6 +41,15 @@ class TestImperativeQatChannelWise(TestImperativeQat): def set_vars(self): self.weight_quantize_type = 'channel_wise_abs_max' self.activation_quantize_type = 'moving_average_abs_max' + self.onnx_format = False + print('weight_quantize_type', self.weight_quantize_type) + + +class TestImperativeQatChannelWiseONNXFormat(TestImperativeQat): + def set_vars(self): + self.weight_quantize_type = 'channel_wise_abs_max' + self.activation_quantize_type = 'moving_average_abs_max' + self.onnx_format = True print('weight_quantize_type', self.weight_quantize_type) From 641d549f0f58c56b0f05a43321af21ce1f1c54e4 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Fri, 1 Apr 2022 07:03:50 +0000 Subject: [PATCH 04/14] fix some interface --- paddle/fluid/operators/fake_dequantize_op.cu | 66 ++++++----- paddle/fluid/operators/quantize_linear_op.cc | 78 
++++++------- paddle/fluid/operators/quantize_linear_op.cu | 108 ++++++++---------- paddle/fluid/operators/quantize_linear_op.h | 29 ++--- .../slim/quantization/imperative/qat.py | 2 + .../post_training_quantization.py | 2 +- .../slim/quantization/quantization_pass.py | 47 +++++--- .../contrib/slim/tests/test_imperative_qat.py | 25 ++-- .../tests/test_imperative_qat_channelwise.py | 2 + ..._post_training_quantization_mobilenetv1.py | 2 +- ...est_post_training_quantization_resnet50.py | 29 +++++ 11 files changed, 213 insertions(+), 177 deletions(-) diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index c0ec44909a5f3..56b4cc792ce10 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -18,11 +18,11 @@ namespace paddle { namespace operators { template -__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num, - T* out) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < num) { - out[idx] = in[idx] * scale[0] / max_range; +__global__ void KeDequantize(const T* in, const T* scale, T max_range, + int64_t num, T* out) { + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + out[i] = in[i] * scale[0] / max_range; } } @@ -35,11 +35,16 @@ struct DequantizeFunctor { const T* scale_factor = scale->data(); T* out_data = out->mutable_data(dev_ctx.GetPlace()); - int num = in->numel(); - int block = 512; - int grid = (num + block - 1) / block; - - KeDequantize<<>>( + int64_t num = in->numel(); + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + KeDequantize<<>>( in_data, scale_factor, max_range, num, out_data); } }; @@ -96,31 +101,24 @@ struct ChannelDequantizeFunctor { if (scale_num == 1) { int64_t num = in->numel(); const T* scale_factor = scales[0]->data(); - if (quant_axis == 0) { - int grid = in_dims[0]; - int block = 1024; - DequantizeOneScaleQuantAxis0<<>>( - in_data, scale_factor, max_range, num, in_dims[0], out_data); - } else { - int quant_stride = 1; - for (int i = quant_axis + 1; i < in_dims.size(); i++) { - quant_stride *= in_dims[i]; - } - - int64_t block_size = std::min( - num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = std::max( - ((max_threads - 1) / block_size + 1), static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - - DequantizeOneScaleQuantAxisN< - T><<>>( - in_data, scale_factor, max_range, num, in_dims[quant_axis], - quant_stride, out_data); + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), + static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; } + + DequantizeOneScaleQuantAxisN< + T><<>>( + in_data, scale_factor, max_range, num, 
in_dims[quant_axis], + quant_stride, out_data); } else if (scale_num == 2) { // Not need to consider quant_axis int num = in->numel(); diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index 04bddb7f772af..0085740500689 100644 --- a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -17,51 +17,49 @@ limitations under the License. */ #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { template -struct ChannelDequantizeFunctor { +struct ChannelDequantizeFunctorV2 { void operator()(const platform::CPUDeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor* scale, - const int scale_num, T max_range, const int quant_axis, - const int x_num_col_dims, framework::Tensor* out) { - if (scale_num == 1) { - // Dequant op is before quantized op - // Dequantize the weight of quantized op - auto in_dims = in->dims(); - const int64_t channel = in_dims[quant_axis]; - const T* scale_factor = scale->data(); - if (quant_axis == 0) { - for (int64_t i = 0; i < channel; i++) { - T s = scale_factor[i]; - framework::Tensor one_channel_in = in->Slice(i, i + 1); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - auto in_e = framework::EigenVector::Flatten(one_channel_in); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - auto& dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * s / max_range; - } - } else if (quant_axis == 1) { - int64_t out_iter = 1; - for (int i = 0; i < quant_axis; i++) { - out_iter *= in_dims[i]; - } - int64_t step_i = in->numel() / out_iter; - int64_t step_j = in->numel() / (out_iter * channel); - auto* in_data = in->data(); - auto* out_data = out->mutable_data(dev_ctx.GetPlace()); - for (int64_t i = 0; i < out_iter; i++) { - for (int64_t j = 0; j < channel; j++) { - auto* cur_in = in_data + i * step_i + j * step_j; - auto* cur_out = out_data + i * step_i + j * step_j; - T s = scale_factor[j]; - for (int64_t k = 0; k < step_j; k++) { - *cur_out = (*cur_in) * s / max_range; - ++cur_in; - ++cur_out; - } + T max_range, const int quant_axis, framework::Tensor* out) { + // Dequant op is before quantized op + // Dequantize the weight of quantized op + auto in_dims = in->dims(); + const int64_t channel = in_dims[quant_axis]; + const T* scale_factor = scale->data(); + if (quant_axis == 0) { + for (int64_t i = 0; i < channel; i++) { + T s = scale_factor[i]; + framework::Tensor one_channel_in = in->Slice(i, i + 1); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * s / max_range; + } + } else if (quant_axis == 1) { + int64_t out_iter = 1; + for (int i = 0; i < quant_axis; i++) { + out_iter *= in_dims[i]; + } + int64_t step_i = in->numel() / out_iter; + int64_t step_j = in->numel() / (out_iter * channel); + auto* in_data = in->data(); + auto* out_data = out->mutable_data(dev_ctx.GetPlace()); + for (int64_t i = 0; i < out_iter; i++) { + for (int64_t j = 0; j < channel; j++) { + auto* cur_in = in_data + i * step_i + j * step_j; + auto* cur_out = out_data + i * step_i + j * step_j; + T s = scale_factor[j]; + for (int64_t k = 0; k < step_j; k++) { + *cur_out = (*cur_in) * 
s / max_range; + ++cur_in; + ++cur_out; } } } @@ -71,8 +69,8 @@ struct ChannelDequantizeFunctor { template struct DequantizeFunctor; template struct DequantizeFunctor; -template struct ChannelDequantizeFunctor; -template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctorV2; +template struct ChannelDequantizeFunctorV2; class QuantizeLinearOp : public framework::OperatorWithKernel { public: diff --git a/paddle/fluid/operators/quantize_linear_op.cu b/paddle/fluid/operators/quantize_linear_op.cu index 42efe3a988a22..bce67a3d4a18b 100644 --- a/paddle/fluid/operators/quantize_linear_op.cu +++ b/paddle/fluid/operators/quantize_linear_op.cu @@ -21,11 +21,11 @@ namespace paddle { namespace operators { template -__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num, - T* out) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < num) { - out[idx] = in[idx] * scale[0] / max_range; +__global__ void KeDequantize(const T* in, const T* scale, T max_range, + int64_t num, T* out) { + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + out[i] = in[i] * scale[0] / max_range; } } @@ -38,28 +38,20 @@ struct DequantizeFunctor { const T* scale_factor = scale->data(); T* out_data = out->mutable_data(dev_ctx.GetPlace()); - int num = in->numel(); - int block = 512; - int grid = (num + block - 1) / block; - - KeDequantize<<>>( + int64_t num = in->numel(); + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + KeDequantize<<>>( in_data, scale_factor, max_range, num, out_data); } }; -template -__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, - T max_range, int num, int channel, - T* out) { - int tid = threadIdx.x; - int channel_size = num / channel; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; - for (int i = tid; i < channel_size; i += blockDim.x) { - out_c[i] = in_c[i] * scale[blockIdx.x] / max_range; - } -} - template __global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale, const T max_range, @@ -74,50 +66,40 @@ __global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale, } template -struct ChannelDequantizeFunctor { +struct ChannelDequantizeFunctorV2 { void operator()(const platform::CUDADeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor* scale, - const int scale_num, T max_range, const int quant_axis, - const int x_num_col_dims, framework::Tensor* out) { + T max_range, const int quant_axis, framework::Tensor* out) { auto in_dims = in->dims(); const T* in_data = in->data(); T* out_data = out->mutable_data(dev_ctx.GetPlace()); - if (scale_num == 1) { - int64_t num = in->numel(); - const T* scale_factor = scale->data(); - if (quant_axis == 0) { - int grid = in_dims[0]; - int block = 1024; - DequantizeOneScaleQuantAxis0<<>>( - in_data, scale_factor, max_range, num, in_dims[0], out_data); - } else { - int quant_stride = 1; - for (int i = quant_axis + 1; i < in_dims.size(); i++) { - quant_stride *= in_dims[i]; - } - - int64_t block_size = std::min( - num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - 
dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = std::max( - ((max_threads - 1) / block_size + 1), static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - - DequantizeOneScaleQuantAxisN< - T><<>>( - in_data, scale_factor, max_range, num, in_dims[quant_axis], - quant_stride, out_data); - } + int64_t num = in->numel(); + const T* scale_factor = scale->data(); + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; } + + DequantizeOneScaleQuantAxisN< + T><<>>( + in_data, scale_factor, max_range, num, in_dims[quant_axis], + quant_stride, out_data); } }; template struct DequantizeFunctor; template struct DequantizeFunctor; -template struct ChannelDequantizeFunctor; -template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctorV2; +template struct ChannelDequantizeFunctorV2; template __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { @@ -127,13 +109,14 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { extern __shared__ char* shared_max_data_tmp[]; auto shared_max_data = reinterpret_cast(shared_max_data_tmp); if (gridDim.x > 1) { - shared_max_data[tid] = T(0); + T local_max_data = T(0); for (int i = bid; i < n; i += blockDim.x * gridDim.x) { T tmp = abs(in[i]); - if (tmp > shared_max_data[tid]) { - shared_max_data[tid] = tmp; + if (tmp > local_max_data) { + local_max_data = tmp; } } + shared_max_data[tid] = local_max_data; } else { if (bid < n) { shared_max_data[tid] = abs(in[bid]); @@ -212,13 +195,14 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, int tid = threadIdx.x; int bid = blockIdx.x; const T* in_current = in + tid * cout_wh_size + bid * wh_size; - shared_max_data[tid] = T(0); + T local_max_data = T(0); for (int i = 0; i < wh_size; i++) { T tmp = fabs(in_current[i]); - if (tmp > shared_max_data[tid]) { - shared_max_data[tid] = tmp; + if (tmp > local_max_data) { + local_max_data = tmp; } } + shared_max_data[tid] = local_max_data; __syncthreads(); int len = blockDim.x; diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h index 3776ce95bf644..e20b99e85f0b3 100644 --- a/paddle/fluid/operators/quantize_linear_op.h +++ b/paddle/fluid/operators/quantize_linear_op.h @@ -13,7 +13,6 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/malloc.h" @@ -24,11 +23,17 @@ limitations under the License. 
*/ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/cast_kernel.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" namespace paddle { namespace operators { +template +struct ChannelDequantizeFunctorV2 { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, + const framework::Tensor** scales, const int scale_num, + T max_range, const int quant_axis, framework::Tensor* out); +}; + template class QuantizeLinearKernel : public framework::OpKernel { public: @@ -62,9 +67,12 @@ class QuantizeLinearKernel : public framework::OpKernel { T* out_scale_data = out_scale->mutable_data(context.GetPlace()); FindChannelAbsMaxFunctor()(dev_ctx, *in, quant_axis, out_scale_data); + ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); + } else { + ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out); } - ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out); } } }; @@ -92,12 +100,6 @@ class DeQuantizeLinearKernel : public framework::OpKernel { DequantizeFunctor()(dev_ctx, &in_tmp, scale, static_cast(max_range), out); } else { - auto x_num_col_dims = 1; - int max_range = 1; - - out->mutable_data(dev_ctx.GetPlace()); - // Now only support scale_num = 1 - int scale_num = 1; PADDLE_ENFORCE_EQ( scale->numel(), in_tmp.dims()[quant_axis], platform::errors::PreconditionNotMet( @@ -105,11 +107,10 @@ class DeQuantizeLinearKernel : public framework::OpKernel { "quant_axis dimension value of Input(X) when the `scale` has " "only one element, but %ld != %ld here.", scale->numel(), in_tmp.dims()[quant_axis])); - max_range *= (std::pow(2, bit_length - 1) - 1); + int max_range = (std::pow(2, bit_length - 1) - 1); - ChannelDequantizeFunctor()( - dev_ctx, &in_tmp, scale, scale_num, static_cast(max_range), - quant_axis, x_num_col_dims, out); + ChannelDequantizeFunctorV2()( + dev_ctx, &in_tmp, scale, static_cast(max_range), quant_axis, out); } } }; diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 5a9f192d15499..059cb7b0dd1bf 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -450,6 +450,8 @@ def save_quantized_model(self, InputSpec or example Tensor. If None, all input variables of the original Layer's forward method would be the inputs of the saved model. Default None. + onnx_format (bool, optional): Whether to export the quantized model + with format of ONNX. Default is False. **configs (dict, optional): Other save configuration options for compatibility. We do not recommend using these configurations, they may be removed in the future. If not necessary, DO NOT use diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index c528d35a7ae8c..a7b3dd5792a02 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -196,7 +196,7 @@ def __init__(self, the fake ops in saving quantized model, and we save the scale obtained by post training quantization in fake ops. Compared to 'abs_max', the model accuracy is usually higher when it is 'channel_wise_abs_max'. - onnx_format(bool): Whether to export the quantized model with format of onnx. 
+ onnx_format(bool): Whether to export the quantized model with format of ONNX. Default is False. optimize_model(bool, optional): If set optimize_model as True, it applies some passes to the model before quantization, and it supports diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index b6adc822e1329..d4247f2c2a0ee 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -54,11 +54,12 @@ _fake_quant_dequant_op_list = [ 'fake_quantize_dequantize_moving_average_abs_max', "fake_channel_wise_quantize_dequantize_abs_max", - "fake_quantize_dequantize_moving_average_abs_max" ] _conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose'] +_SCALE_DEFAULT_VALUE = 0.001 + def _init_var_node(var_node, value, scope, place): assert isinstance(value, @@ -491,7 +492,7 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, name, quant_bits): _init_var_node( scale_in_node, np.array( - [0.001], dtype=data_type), + [_SCALE_DEFAULT_VALUE], dtype=data_type), self._scope, self._place) @@ -559,7 +560,7 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node, name, _init_var_node( scale_in_node, np.array( - [0.001], dtype=data_type), + [_SCALE_DEFAULT_VALUE], dtype=data_type), self._scope, self._place) @@ -1707,7 +1708,7 @@ def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node, _init_var_node( scale_in_node, np.array( - [0.001], dtype=data_type), + [_SCALE_DEFAULT_VALUE], dtype=data_type), self._scope, self._place) @@ -1812,7 +1813,7 @@ def insert_quant_op(self, graph, var_node): else: scale_var_shape = 1 scale_var_type = var_node.type() - init_scale_value = np.array([0.001], dtype=data_type) + init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type) scale_var_node = graph.create_persistable_node( name=self._quantized_scale_name(var_node.name()), var_type=scale_var_type, @@ -1844,7 +1845,9 @@ def insert_quant_op(self, graph, var_node): if not self._is_test: attrs["is_test"] = self._is_test attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward - outputs["OutScale"] = scale_var_node + scale_out_node = graph.create_var_node_from_desc(scale_var_node.var( + )) + outputs["OutScale"] = scale_out_node quant_op_node = graph.create_op_node( op_type="quantize_linear", @@ -1857,6 +1860,8 @@ def insert_quant_op(self, graph, var_node): if zero_point_node is not None: graph.link_to(zero_point_node, quant_op_node) graph.link_to(quant_op_node, quant_var_node) + if not self._is_test: + graph.link_to(quant_op_node, scale_out_node) return quant_var_node, scale_var_node def insert_dequant_op(self, graph, var_node, scale_var_node): @@ -2020,7 +2025,6 @@ def _quant_preprocess(self, op_node): def _transform_forward(self, graph, op): op.op()._set_attr("quantization_type", "qat_with_weight") - #op.op()._set_attr("with_quant_attr", True) inputs = op.inputs for var_node in inputs: if var_node.name() not in op.input_arg_names(): @@ -2149,8 +2153,8 @@ def _is_skip_quant(self, graph, op_node): op_node.op().attr("skip_quant"): is_skip = True # if the inputs of mul and matmul are not all persistable, use - # AddQuantDequantPass to quantize them. - if op_node.name() in ["mul", "matmul"] and \ + # AddQuantDequantPassV2 to quantize them. 
+ if op_node.name() in ["mul", "matmul", "matmul_v2"] and \ _is_input_all_not_persistable(graph, op_node): is_skip = True if op_node.op().has_attr("quantization_type") and \ @@ -2486,14 +2490,21 @@ class QuantWeightPass(object): If it's string, It can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs. bias_correction(bool): whether use bias correction for post-training quantization. https://arxiv.org/abs/1810.05723. + quant_bits(int, optional): quantization bit number for weight. Default is 8. + save_int_weight(bool, optional): Whether the type saving the weight is int. Default is True. """ - def __init__(self, scope, place, bias_correction=False, quant_bits=8): + def __init__(self, + scope, + place, + bias_correction=False, + quant_bits=8, + save_int_weight=True): self._place = _get_paddle_place(place) self._scope = scope self._bias_correction = bias_correction self._quant_bits = quant_bits - #self.scale_dict = scale_dict + self._save_int_weight = save_int_weight assert self._scope != None, "scope must not be None." assert self._place != None, "place must not be None." @@ -2536,12 +2547,14 @@ def apply(self, graph): scale_v, quant_axis, weight_bits=bits_length) - if self._quant_bits == 8: - save_weight_dtype = np.int8 - elif self._quant_bits == 4: - save_weight_dtype = np.int4 - # cast weight type - quantized_param_v = quantized_param_v.astype(save_weight_dtype) + if self._save_int_weight: + # cast weight type to int + if self._quant_bits == 8: + save_weight_dtype = np.int8 + elif self._quant_bits == 4: + save_weight_dtype = np.int4 + quantized_param_v = quantized_param_v.astype( + save_weight_dtype) self._restore_var(x_node.name(), quantized_param_v) for next_op_node in out_node.outputs: diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py index df1a1df709978..015ecb3d4a4e9 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py @@ -54,6 +54,8 @@ def set_vars(self): self.weight_quantize_type = 'abs_max' self.activation_quantize_type = 'moving_average_abs_max' self.onnx_format = False + self.check_export_model_accuracy = True + self.diff_threshold = 0.01 def func_qat(self): self.set_vars() @@ -159,9 +161,13 @@ def func_qat(self): data = next(test_reader()) test_data = np.array([x[0].reshape(1, 28, 28) for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) test_img = fluid.dygraph.to_variable(test_data) + label = fluid.dygraph.to_variable(y_data) lenet.eval() - before_save = lenet(test_img) + fp32_out = lenet(test_img) + fp32_acc = fluid.layers.accuracy(fp32_out, label).numpy() with tempfile.TemporaryDirectory(prefix="qat_save_path_") as tmpdir: # save inference quantized model @@ -186,13 +192,15 @@ def func_qat(self): executor=exe, model_filename="lenet" + INFER_MODEL_SUFFIX, params_filename="lenet" + INFER_PARAMS_SUFFIX) - after_save, = exe.run(inference_program, - feed={feed_target_names[0]: test_data}, - fetch_list=fetch_targets) - # check - self.assertTrue( - np.allclose(after_save, before_save.numpy()), - msg='Failed to save the inference quantized model.') + quant_out, = exe.run(inference_program, + feed={feed_target_names[0]: test_data}, + fetch_list=fetch_targets) + paddle.disable_static() + quant_out = fluid.dygraph.to_variable(quant_out) + quant_acc = fluid.layers.accuracy(quant_out, label).numpy() + paddle.enable_static() + delta_value = 
fp32_acc - quant_acc + self.assertLess(delta_value, self.diff_threshold) def test_qat(self): with _test_eager_guard(): @@ -205,6 +213,7 @@ def set_vars(self): self.weight_quantize_type = 'abs_max' self.activation_quantize_type = 'moving_average_abs_max' self.onnx_format = True + self.diff_threshold = 0.025 if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py index f396b1c1093bb..ff40b170345a8 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py @@ -41,6 +41,7 @@ class TestImperativeQatChannelWise(TestImperativeQat): def set_vars(self): self.weight_quantize_type = 'channel_wise_abs_max' self.activation_quantize_type = 'moving_average_abs_max' + self.diff_threshold = 0.01 self.onnx_format = False print('weight_quantize_type', self.weight_quantize_type) @@ -50,6 +51,7 @@ def set_vars(self): self.weight_quantize_type = 'channel_wise_abs_max' self.activation_quantize_type = 'moving_average_abs_max' self.onnx_format = True + self.diff_threshold = 0.025 print('weight_quantize_type', self.weight_quantize_type) diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 448c96d8be341..498a1ec46cacd 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -547,7 +547,7 @@ def test_post_training_onnx_format_mobilenetv1(self): is_use_cache_file = False is_optimize_model = True onnx_format = True - diff_threshold = 0.025 + diff_threshold = 0.05 self.run_test( model, algo, diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py index a26dcb51c724a..dc12026a21ab1 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py @@ -39,5 +39,34 @@ def test_post_training_resnet50(self): is_optimize_model, diff_threshold) +class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): + def test_post_training_resnet50(self): + model = "ResNet-50" + algo = "min_max" + round_type = "round" + data_urls = [ + 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' + ] + data_md5s = ['4a5194524823d9b76da6e738e1367881'] + quantizable_op_type = ["conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = False + diff_threshold = 0.025 + onnx_format = True + self.run_test( + model, + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=onnx_format) + + if __name__ == '__main__': unittest.main() From 80910a47c54f9cf1856af9963a39cab46854c6cf Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Fri, 1 Apr 2022 09:28:42 +0000 Subject: [PATCH 05/14] fix cuda kernel --- paddle/fluid/operators/fake_dequantize_op.cu | 133 +---- .../fluid/operators/fake_dequantize_op.cu.h | 151 +++++ paddle/fluid/operators/fake_quantize_op.cu | 525 +---------------- 
paddle/fluid/operators/fake_quantize_op.cu.h | 543 ++++++++++++++++++ paddle/fluid/operators/quantize_linear_op.cu | 347 +---------- 5 files changed, 698 insertions(+), 1001 deletions(-) create mode 100644 paddle/fluid/operators/fake_dequantize_op.cu.h create mode 100644 paddle/fluid/operators/fake_quantize_op.cu.h diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index 56b4cc792ce10..582f0627b2044 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -12,140 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/fake_dequantize_op.cu.h" #include "paddle/fluid/operators/fake_dequantize_op.h" -namespace paddle { -namespace operators { - -template -__global__ void KeDequantize(const T* in, const T* scale, T max_range, - int64_t num, T* out) { - int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; - for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { - out[i] = in[i] * scale[0] / max_range; - } -} - -template -struct DequantizeFunctor { - void operator()(const platform::CUDADeviceContext& dev_ctx, - const framework::Tensor* in, const framework::Tensor* scale, - T max_range, framework::Tensor* out) { - const T* in_data = in->data(); - const T* scale_factor = scale->data(); - T* out_data = out->mutable_data(dev_ctx.GetPlace()); - - int64_t num = in->numel(); - int64_t block_size = std::min( - num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = - std::max(((max_threads - 1) / block_size + 1), static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - KeDequantize<<>>( - in_data, scale_factor, max_range, num, out_data); - } -}; - -template -__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, - T max_range, int num, int channel, - T* out) { - int tid = threadIdx.x; - int channel_size = num / channel; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; - for (int i = tid; i < channel_size; i += blockDim.x) { - out_c[i] = in_c[i] * scale[blockIdx.x] / max_range; - } -} - -template -__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale, - const T max_range, - const int64_t num, - const int n_scales, - const int quant_stride, T* out) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { - T s = scale[(i / quant_stride) % n_scales]; - out[i] = in[i] * s / max_range; - } -} - -template -__global__ void DequantizeTwoScale(const T* in, const T* scale_one, - const T* scale_two, T max_range, int num, - int iter_size, int channel, T* out) { - int tid = threadIdx.x; - int channel_size = num / (iter_size * channel); - int scale_index = blockIdx.x % channel; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; - for (int i = tid; i < channel_size; i += blockDim.x) { - out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range; - } -} - -template -struct ChannelDequantizeFunctor { - void operator()(const platform::CUDADeviceContext& dev_ctx, - const framework::Tensor* in, const framework::Tensor** scales, - const int scale_num, T max_range, const int quant_axis, - 
const int x_num_col_dims, framework::Tensor* out) { - auto in_dims = in->dims(); - const T* in_data = in->data(); - T* out_data = out->mutable_data(dev_ctx.GetPlace()); - if (scale_num == 1) { - int64_t num = in->numel(); - const T* scale_factor = scales[0]->data(); - int64_t block_size = std::min( - num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), - static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - - int quant_stride = 1; - for (int i = quant_axis + 1; i < in_dims.size(); i++) { - quant_stride *= in_dims[i]; - } - - DequantizeOneScaleQuantAxisN< - T><<>>( - in_data, scale_factor, max_range, num, in_dims[quant_axis], - quant_stride, out_data); - } else if (scale_num == 2) { - // Not need to consider quant_axis - int num = in->numel(); - int iter_size = 1; - for (int i = 0; i < x_num_col_dims; i++) { - iter_size *= in->dims()[i]; - } - int channel = in->dims()[x_num_col_dims]; - const T* scale_one = scales[0]->data(); - const T* scale_two = scales[1]->data(); - int block = 1024; - int grid = iter_size * channel; - DequantizeTwoScale<<>>( - in_data, scale_one, scale_two, max_range, num, iter_size, channel, - out_data); - } - } -}; - -template struct DequantizeFunctor; -template struct DequantizeFunctor; -template struct ChannelDequantizeFunctor; -template struct ChannelDequantizeFunctor; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; using CUDA = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs, diff --git a/paddle/fluid/operators/fake_dequantize_op.cu.h b/paddle/fluid/operators/fake_dequantize_op.cu.h new file mode 100644 index 0000000000000..9859dd4607c15 --- /dev/null +++ b/paddle/fluid/operators/fake_dequantize_op.cu.h @@ -0,0 +1,151 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_ +#define PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_ +#endif // PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_ + +#include "paddle/fluid/operators/fake_dequantize_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void KeDequantize(const T* in, const T* scale, T max_range, + int64_t num, T* out) { + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + out[i] = in[i] * scale[0] / max_range; + } +} + +template +struct DequantizeFunctor { + void operator()(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor* scale, + T max_range, framework::Tensor* out) { + const T* in_data = in->data(); + const T* scale_factor = scale->data(); + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + + int64_t num = in->numel(); + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + KeDequantize<<>>( + in_data, scale_factor, max_range, num, out_data); + } +}; + +template +__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, + T max_range, int num, int channel, + T* out) { + int tid = threadIdx.x; + int channel_size = num / channel; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + for (int i = tid; i < channel_size; i += blockDim.x) { + out_c[i] = in_c[i] * scale[blockIdx.x] / max_range; + } +} + +template +__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale, + const T max_range, + const int64_t num, + const int n_scales, + const int quant_stride, T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { + T s = scale[(i / quant_stride) % n_scales]; + out[i] = in[i] * s / max_range; + } +} + +template +__global__ void DequantizeTwoScale(const T* in, const T* scale_one, + const T* scale_two, T max_range, int num, + int iter_size, int channel, T* out) { + int tid = threadIdx.x; + int channel_size = num / (iter_size * channel); + int scale_index = blockIdx.x % channel; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + for (int i = tid; i < channel_size; i += blockDim.x) { + out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range; + } +} + +template +struct ChannelDequantizeFunctor { + void operator()(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor** scales, + const int scale_num, T max_range, const int quant_axis, + const int x_num_col_dims, framework::Tensor* out) { + auto in_dims = in->dims(); + const T* in_data = in->data(); + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + if (scale_num == 1) { + int64_t num = in->numel(); + const T* scale_factor = scales[0]->data(); + int64_t block_size = std::min( + num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), + static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + int 
quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + + DequantizeOneScaleQuantAxisN< + T><<>>( + in_data, scale_factor, max_range, num, in_dims[quant_axis], + quant_stride, out_data); + } else if (scale_num == 2) { + // Not need to consider quant_axis + int num = in->numel(); + int iter_size = 1; + for (int i = 0; i < x_num_col_dims; i++) { + iter_size *= in->dims()[i]; + } + int channel = in->dims()[x_num_col_dims]; + const T* scale_one = scales[0]->data(); + const T* scale_two = scales[1]->data(); + int block = 1024; + int grid = iter_size * channel; + DequantizeTwoScale<<>>( + in_data, scale_one, scale_two, max_range, num, iter_size, channel, + out_data); + } + } +}; + +template struct DequantizeFunctor; +template struct DequantizeFunctor; +template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctor; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 01384a6cafef9..5416ae11c2b56 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -12,531 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/fake_quantize_op.cu.h" #include "paddle/fluid/operators/fake_quantize_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -template -__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { - int bid = threadIdx.x + blockIdx.x * blockDim.x; - int tid = threadIdx.x; - - extern __shared__ char* shared_max_data_tmp[]; - auto shared_max_data = reinterpret_cast(shared_max_data_tmp); - if (gridDim.x > 1) { - T local_max_data = T(0); - for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T tmp = abs(in[i]); - if (tmp > local_max_data) { - local_max_data = tmp; - } - } - shared_max_data[tid] = local_max_data; - } else { - if (bid < n) { - shared_max_data[tid] = abs(in[bid]); - } else { - shared_max_data[tid] = T(0); - } - } - __syncthreads(); - - for (int i = blockDim.x / 2; i > 0; i >>= 1) { - if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { - shared_max_data[tid] = shared_max_data[tid + i]; - } - __syncthreads(); - } - if (tid == 0) { - out[blockIdx.x] = shared_max_data[0]; - } -} - -template -struct FindAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, const T* in, - const int num, T* out) { - int block = 1024; - int grid = (block - 1 + num) / block; - grid = (grid > block) ? 
block : grid; - - framework::Tensor max; - T* max_data = max.mutable_data(phi::make_ddim({grid}), ctx.GetPlace()); - FindAbsMaxKernel<<>>( - in, num, max_data); - FindAbsMaxKernel<<<1, block, 1024 * sizeof(T), ctx.stream()>>>( - max_data, grid, out); - } -}; - -template struct FindAbsMaxFunctor; -template struct FindAbsMaxFunctor; - -template -__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, - const int c, T* out) { - int tid = threadIdx.x; - int channel_size = n / c; - const T* in_c = in + blockIdx.x * channel_size; - extern __shared__ T shared_max_data[]; - T local_max_data = T(0); - for (int i = tid; i < channel_size; i += blockDim.x) { - T tmp = fabs(in_c[i]); - if (tmp > local_max_data) { - local_max_data = tmp; - } - } - shared_max_data[tid] = local_max_data; - __syncthreads(); - for (int i = blockDim.x / 2; i > 0; i >>= 1) { - if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { - shared_max_data[tid] = shared_max_data[tid + i]; - } - __syncthreads(); - } - if (tid == 0) { - out[blockIdx.x] = shared_max_data[0]; - } -} - -template -__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, - const int cin, const int cout, - T* out) { - extern __shared__ T shared_max_data[]; - int cout_wh_size = n / cin; - int wh_size = n / (cin * cout); - - int tid = threadIdx.x; - int bid = blockIdx.x; - const T* in_current = in + tid * cout_wh_size + bid * wh_size; - T local_max_data = T(0); - for (int i = 0; i < wh_size; i++) { - T tmp = fabs(in_current[i]); - if (tmp > local_max_data) { - local_max_data = tmp; - } - } - shared_max_data[tid] = local_max_data; - __syncthreads(); - - int len = blockDim.x; - for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) { - if (tid < i && tid + i < len && - shared_max_data[tid] < shared_max_data[tid + i]) { - shared_max_data[tid] = shared_max_data[tid + i]; - } - if (i == 1) { - i = 0; // break the loop - } - __syncthreads(); - } - if (tid == 0 && shared_max_data[0] > out[bid]) { - out[bid] = shared_max_data[0]; - } -} - -template -struct FindChannelAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in_tensor, const int quant_axis, - T* out_abs_max) { - PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, - platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - const int num = in_tensor.numel(); - auto in_dims = in_tensor.dims(); - const T* in_data = in_tensor.data(); - if (quant_axis == 0) { - int cout = in_dims[0]; - int grid = cout; - int block = 1024; - FindChannelAbsMaxKernelQuantAxis0< - T><<>>( - in_data, num, cout, out_abs_max); - } else if (quant_axis == 1) { - int cin = in_dims[0]; - int cout = in_dims[1]; - int grid = cout; - int max_threads = 1024; - -#ifdef PADDLE_WITH_HIP - hipMemset(out_abs_max, 0, sizeof(T) * cout); -#else - cudaMemset(out_abs_max, 0, sizeof(T) * cout); -#endif - - for (int i = 0; i < cin / max_threads; i++) { - int block = max_threads; - FindChannelAbsMaxKernelQuantAxis1< - T><<>>( - in_data, num, cin, cout, out_abs_max); - in_data += num / cin; - } - - int block = cin % max_threads; - if (block > 0) { - FindChannelAbsMaxKernelQuantAxis1< - T><<>>( - in_data, num, in_dims[0], in_dims[1], out_abs_max); - } - } - } -}; - -template struct FindChannelAbsMaxFunctor; - -template -__global__ void ClipAndQuantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, T* out) { - int bid = threadIdx.x + blockIdx.x * blockDim.x; - int tid = 
threadIdx.x; - - T s = scale[0]; - T inv_s = inverse(s); - for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T x = in[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out[i] = round(v); - } -} - -template -__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, - T* out) { - int bid = threadIdx.x + blockIdx.x * blockDim.x; - int tid = threadIdx.x; - - T s = scale[0]; - T inv_s = inverse(s); - T bin_cnt_t = static_cast(bin_cnt); - - for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T x = in[i]; - x = x > s ? s : x; - x = x < -s ? -s : x; - x = bin_cnt_t * inv_s * x; - x = static_cast(round(static_cast(x))); - out[i] = (x * s) / bin_cnt_t; - } -} - -template -struct ClipAndFakeQuantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { - int num = in.numel(); - int block = 1024; - int grid = (block - 1 + num) / block; - - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - ClipAndQuantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); - } -}; - -template struct ClipAndFakeQuantFunctor; - -template -struct ClipAndFakeQuantDequantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { - int num = in.numel(); - int block = 1024; - int grid = (block - 1 + num) / block; - - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - ClipAndQuantDequantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); - } -}; - -// ChannelClipAndQuantKernel for quant_axis is 0 -template -__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, - const int bin_cnt, - const int64_t n, - const int c, T* out) { - int tid = threadIdx.x; - - int64_t channel_size = n / c; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; - - T s = scale[blockIdx.x]; - T inv_s = inverse(s); - - for (int64_t i = tid; i < channel_size; i += blockDim.x) { - T x = in_c[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v); - } -} - -// ChannelClipAndQuantKernel for quant_axis is N -template -__global__ void ChannelClipAndQuantKernelQuantAxisN( - const T* in, const T* scale, const int bin_cnt, const int64_t n, - const int nScale, const int quant_stride, T* out) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) { - T s = scale[(i / quant_stride) % nScale]; - T inv_s = 1.0 / s; - T x = in[i]; - T v = x > s ? s : x; - v = v < -s ? 
-s : v; - v = bin_cnt * inv_s * v; - out[i] = round(v); - } -} - -template -struct ChannelClipAndFakeQuantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, - framework::Tensor* out) { - PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, - platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - - int64_t num = in.numel(); - auto in_dims = in.dims(); - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - if (quant_axis == 0) { - int grid = in_dims[0]; - int block = 1024; - ChannelClipAndQuantKernelQuantAxis0<<>>( - in_data, scale_data, bin_cnt, num, in_dims[0], out_data); - } else { - int quant_stride = 1; - for (int i = quant_axis + 1; i < in_dims.size(); i++) { - quant_stride *= in_dims[i]; - } - int64_t block_size = - std::min(num, static_cast(ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), - static_cast(1)); - - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - - ChannelClipAndQuantKernelQuantAxisN<<>>( - in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride, - out_data); - } - } -}; - -template struct ChannelClipAndFakeQuantFunctor; - -template -__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, - const T* last_scale, - const int64_t* iter, - const int window_size, T* scale_arr, - T* out_scale, int* need_find_max, - int* out_size) { - int it = iter[0]; - int idx = it % window_size; - T removed = scale_arr[idx]; - T cur = cur_scale[0]; - scale_arr[idx] = cur; - T max = last_scale[0]; - out_scale[0] = max < cur ? cur : max; - if (fabs(removed - max) < 1e-6) { - need_find_max[0] = 1; - out_size[0] = it > window_size ? 
window_size : it; - } else { - need_find_max[0] = 0; - } -} - -template -struct FindRangeAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& cur_scale, - const framework::Tensor& last_scale, - const framework::Tensor& iter, const int window_size, - framework::Tensor* scales_arr, framework::Tensor* out_scale) { - const auto gpu_place = ctx.GetPlace(); - - T* scale_arr = scales_arr->mutable_data(gpu_place); - T* out_scale_data = out_scale->mutable_data(gpu_place); - - framework::Tensor need_find_max, out_size; - int* find_max = need_find_max.mutable_data({1}, gpu_place); - int* out_size_data = out_size.mutable_data({1}, gpu_place); - - FindRangeAbsMaxAndFillArray<<<1, 1, 0, ctx.stream()>>>( - cur_scale.data(), last_scale.data(), iter.data(), - window_size, scale_arr, out_scale_data, find_max, out_size_data); - - int g_find_max; - memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, - sizeof(int), ctx.stream()); - ctx.Wait(); - if (g_find_max) { - int len; - memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, - sizeof(int), ctx.stream()); - ctx.Wait(); - FindAbsMaxFunctor()(ctx, scale_arr, len, - out_scale_data); - } - } -}; - -template -__global__ void FindMovingAverageAbsMaxKernel(const T* in_state, - const T* in_accum, - const T* cur_scale, const T rate, - T* out_state, T* out_accum, - T* out_scale) { - T state = rate * (*in_state) + T(1.0f); - T accum = rate * (*in_accum) + (*cur_scale); - *out_state = state; - *out_accum = accum; - *out_scale = accum / state; -} - -template struct FindRangeAbsMaxFunctor; - -template -struct FindMovingAverageAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in_accum, - const framework::Tensor& in_state, const T* cur_scale, - const float rate, framework::Tensor* out_state, - framework::Tensor* out_accum, framework::Tensor* out_scale) { - const auto gpu_place = ctx.GetPlace(); - - T rate_t = static_cast(rate); - T* out_state_data = out_state->mutable_data(gpu_place); - T* out_accum_data = out_accum->mutable_data(gpu_place); - T* out_scale_data = out_scale->mutable_data(gpu_place); - - FindMovingAverageAbsMaxKernel<<<1, 1, 0, ctx.stream()>>>( - in_state.data(), in_accum.data(), cur_scale, rate_t, - out_state_data, out_accum_data, out_scale_data); - } -}; - -// ChannelClipAndQuantDequantKernel for quant_axis is 0 -template -__global__ void ChannelClipAndQuantDequantKernelQuantAxis0( - const T* in, const T* scale, const int bin_cnt, const int n, const int c, - T* out) { - int tid = threadIdx.x; - - int channel_size = n / c; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; - - T s = scale[blockIdx.x]; - T inv_s = inverse(s); - - for (int i = tid; i < channel_size; i += blockDim.x) { - T x = in_c[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v) * s / bin_cnt; - } -} - -// ChannelClipAndQuantDequantKernel for quant_axis is 1 -template -__global__ void ChannelClipAndQuantDequantKernelQuantAxis1( - const T* in, const T* scale, const int bin_cnt, const int n, const int cin, - const int cout, T* out) { - T s = scale[blockIdx.x % cout]; - T inv_s = inverse(s); - - int wh_size = n / (cin * cout); - const T* in_c = in + blockIdx.x * wh_size; - T* out_c = out + blockIdx.x * wh_size; - - for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { - T x = in_c[i]; - T v = x > s ? s : x; - v = v < -s ? 
-s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v) * s / bin_cnt; - } -} - -template -struct ChannelClipFakeQuantDequantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, - framework::Tensor* out) { - // At present, channelwise quantization supports conv2d, depthwise_conv2d - // conv2d_transpose and mul - PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, - platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - - int num = in.numel(); - auto in_dims = in.dims(); - - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - if (quant_axis == 0) { - int grid = in_dims[0]; - int block = 1024; - ChannelClipAndQuantDequantKernelQuantAxis0< - T><<>>(in_data, scale_data, bin_cnt, - num, in_dims[0], out_data); - } else if (quant_axis == 1) { - int grid = in_dims[0] * in_dims[1]; - int block = 1024; - - ChannelClipAndQuantDequantKernelQuantAxis1< - T><<>>( - in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data); - } - } -}; - -template struct ChannelClipFakeQuantDequantFunctor; - -} // namespace operators -} // namespace paddle namespace ops = paddle::operators; using CUDA = paddle::platform::CUDADeviceContext; diff --git a/paddle/fluid/operators/fake_quantize_op.cu.h b/paddle/fluid/operators/fake_quantize_op.cu.h new file mode 100644 index 0000000000000..d85d47f546131 --- /dev/null +++ b/paddle/fluid/operators/fake_quantize_op.cu.h @@ -0,0 +1,543 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_ +#define PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_ +#endif // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_ + +#include +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/fake_quantize_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { + +template +__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { + int bid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + + extern __shared__ char* shared_max_data_tmp[]; + auto shared_max_data = reinterpret_cast(shared_max_data_tmp); + if (gridDim.x > 1) { + T local_max_data = T(0); + for (int i = bid; i < n; i += blockDim.x * gridDim.x) { + T tmp = abs(in[i]); + if (tmp > local_max_data) { + local_max_data = tmp; + } + } + shared_max_data[tid] = local_max_data; + } else { + if (bid < n) { + shared_max_data[tid] = abs(in[bid]); + } else { + shared_max_data[tid] = T(0); + } + } + __syncthreads(); + + for (int i = blockDim.x / 2; i > 0; i >>= 1) { + if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + __syncthreads(); + } + if (tid == 0) { + out[blockIdx.x] = shared_max_data[0]; + } +} + +template +struct FindAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const T* in, + const int num, T* out) { + int block = 1024; + int grid = (block - 1 + num) / block; + grid = (grid > block) ? block : grid; + + framework::Tensor max; + T* max_data = max.mutable_data(phi::make_ddim({grid}), ctx.GetPlace()); + FindAbsMaxKernel<<>>( + in, num, max_data); + FindAbsMaxKernel<<<1, block, 1024 * sizeof(T), ctx.stream()>>>( + max_data, grid, out); + } +}; + +template struct FindAbsMaxFunctor; +template struct FindAbsMaxFunctor; + +template +__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, + const int c, T* out) { + int tid = threadIdx.x; + int channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + extern __shared__ T shared_max_data[]; + T local_max_data = T(0); + for (int i = tid; i < channel_size; i += blockDim.x) { + T tmp = fabs(in_c[i]); + if (tmp > local_max_data) { + local_max_data = tmp; + } + } + shared_max_data[tid] = local_max_data; + __syncthreads(); + for (int i = blockDim.x / 2; i > 0; i >>= 1) { + if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + __syncthreads(); + } + if (tid == 0) { + out[blockIdx.x] = shared_max_data[0]; + } +} + +template +__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, + const int cin, const int cout, + T* out) { + extern __shared__ T shared_max_data[]; + int cout_wh_size = n / cin; + int wh_size = n / (cin * cout); + + int tid = threadIdx.x; + int bid = blockIdx.x; + const T* in_current = in + tid * cout_wh_size + bid * wh_size; + T local_max_data = T(0); + for (int i = 0; i < wh_size; i++) { + T tmp = fabs(in_current[i]); + if (tmp > local_max_data) { + local_max_data = tmp; + } + } + shared_max_data[tid] = local_max_data; + __syncthreads(); + + int len = blockDim.x; + for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) { + if (tid < i && tid + i < len && + shared_max_data[tid] < shared_max_data[tid + i]) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + if (i == 1) { + i = 0; // break the loop + } + __syncthreads(); + } + if (tid == 0 && shared_max_data[0] > out[bid]) { + out[bid] = 
shared_max_data[0]; + } +} + +template +struct FindChannelAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in_tensor, const int quant_axis, + T* out_abs_max) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + const int num = in_tensor.numel(); + auto in_dims = in_tensor.dims(); + const T* in_data = in_tensor.data(); + if (quant_axis == 0) { + int cout = in_dims[0]; + int grid = cout; + int block = 1024; + FindChannelAbsMaxKernelQuantAxis0< + T><<>>( + in_data, num, cout, out_abs_max); + } else if (quant_axis == 1) { + int cin = in_dims[0]; + int cout = in_dims[1]; + int grid = cout; + int max_threads = 1024; + +#ifdef PADDLE_WITH_HIP + hipMemset(out_abs_max, 0, sizeof(T) * cout); +#else + cudaMemset(out_abs_max, 0, sizeof(T) * cout); +#endif // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_ + + for (int i = 0; i < cin / max_threads; i++) { + int block = max_threads; + FindChannelAbsMaxKernelQuantAxis1< + T><<>>( + in_data, num, cin, cout, out_abs_max); + in_data += num / cin; + } + + int block = cin % max_threads; + if (block > 0) { + FindChannelAbsMaxKernelQuantAxis1< + T><<>>( + in_data, num, in_dims[0], in_dims[1], out_abs_max); + } + } + } +}; + +template struct FindChannelAbsMaxFunctor; + +template +__global__ void ClipAndQuantKernel(const T* in, const T* scale, + const int bin_cnt, const int n, T* out) { + int bid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + + T s = scale[0]; + T inv_s = inverse(s); + for (int i = bid; i < n; i += blockDim.x * gridDim.x) { + T x = in[i]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt * inv_s * v; + out[i] = round(v); + } +} + +template +__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, + const int bin_cnt, const int n, + T* out) { + int bid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + + T s = scale[0]; + T inv_s = inverse(s); + T bin_cnt_t = static_cast(bin_cnt); + + for (int i = bid; i < n; i += blockDim.x * gridDim.x) { + T x = in[i]; + x = x > s ? s : x; + x = x < -s ? 
-s : x; + x = bin_cnt_t * inv_s * x; + x = static_cast(round(static_cast(x))); + out[i] = (x * s) / bin_cnt_t; + } +} + +template +struct ClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + int num = in.numel(); + int block = 1024; + int grid = (block - 1 + num) / block; + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + ClipAndQuantKernel<<>>( + in_data, scale_data, bin_cnt, num, out_data); + } +}; + +template struct ClipAndFakeQuantFunctor; + +template +struct ClipAndFakeQuantDequantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + int num = in.numel(); + int block = 1024; + int grid = (block - 1 + num) / block; + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + ClipAndQuantDequantKernel<<>>( + in_data, scale_data, bin_cnt, num, out_data); + } +}; + +// ChannelClipAndQuantKernel for quant_axis is 0 +template +__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, + const int bin_cnt, + const int64_t n, + const int c, T* out) { + int tid = threadIdx.x; + + int64_t channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + + T s = scale[blockIdx.x]; + T inv_s = inverse(s); + + for (int64_t i = tid; i < channel_size; i += blockDim.x) { + T x = in_c[i]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt * inv_s * v; + out_c[i] = round(v); + } +} + +// ChannelClipAndQuantKernel for quant_axis is N +template +__global__ void ChannelClipAndQuantKernelQuantAxisN( + const T* in, const T* scale, const int bin_cnt, const int64_t n, + const int nScale, const int quant_stride, T* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) { + T s = scale[(i / quant_stride) % nScale]; + T inv_s = 1.0 / s; + T x = in[i]; + T v = x > s ? s : x; + v = v < -s ? 
-s : v; + v = bin_cnt * inv_s * v; + out[i] = round(v); + } +} + +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, const int quant_axis, + framework::Tensor* out) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + + int64_t num = in.numel(); + auto in_dims = in.dims(); + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + if (quant_axis == 0) { + int grid = in_dims[0]; + int block = 1024; + ChannelClipAndQuantKernelQuantAxis0<<>>( + in_data, scale_data, bin_cnt, num, in_dims[0], out_data); + } else { + int quant_stride = 1; + for (int i = quant_axis + 1; i < in_dims.size(); i++) { + quant_stride *= in_dims[i]; + } + int64_t block_size = + std::min(num, static_cast(ctx.GetMaxThreadsPerBlock() / 4)); + int64_t max_threads = + ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM + const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), + static_cast(1)); + + const int64_t grid_size = + std::min(max_blocks, (num + block_size - 1) / block_size); + + ChannelClipAndQuantKernelQuantAxisN<<>>( + in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride, + out_data); + } + } +}; + +template struct ChannelClipAndFakeQuantFunctor; + +template +__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, + const T* last_scale, + const int64_t* iter, + const int window_size, T* scale_arr, + T* out_scale, int* need_find_max, + int* out_size) { + int it = iter[0]; + int idx = it % window_size; + T removed = scale_arr[idx]; + T cur = cur_scale[0]; + scale_arr[idx] = cur; + T max = last_scale[0]; + out_scale[0] = max < cur ? cur : max; + if (fabs(removed - max) < 1e-6) { + need_find_max[0] = 1; + out_size[0] = it > window_size ? 
window_size : it; + } else { + need_find_max[0] = 0; + } +} + +template +struct FindRangeAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& cur_scale, + const framework::Tensor& last_scale, + const framework::Tensor& iter, const int window_size, + framework::Tensor* scales_arr, framework::Tensor* out_scale) { + const auto gpu_place = ctx.GetPlace(); + + T* scale_arr = scales_arr->mutable_data(gpu_place); + T* out_scale_data = out_scale->mutable_data(gpu_place); + + framework::Tensor need_find_max, out_size; + int* find_max = need_find_max.mutable_data({1}, gpu_place); + int* out_size_data = out_size.mutable_data({1}, gpu_place); + + FindRangeAbsMaxAndFillArray<<<1, 1, 0, ctx.stream()>>>( + cur_scale.data(), last_scale.data(), iter.data(), + window_size, scale_arr, out_scale_data, find_max, out_size_data); + + int g_find_max; + memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, + sizeof(int), ctx.stream()); + ctx.Wait(); + if (g_find_max) { + int len; + memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, + sizeof(int), ctx.stream()); + ctx.Wait(); + FindAbsMaxFunctor()(ctx, scale_arr, len, + out_scale_data); + } + } +}; + +template +__global__ void FindMovingAverageAbsMaxKernel(const T* in_state, + const T* in_accum, + const T* cur_scale, const T rate, + T* out_state, T* out_accum, + T* out_scale) { + T state = rate * (*in_state) + T(1.0f); + T accum = rate * (*in_accum) + (*cur_scale); + *out_state = state; + *out_accum = accum; + *out_scale = accum / state; +} + +template struct FindRangeAbsMaxFunctor; + +template +struct FindMovingAverageAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in_accum, + const framework::Tensor& in_state, const T* cur_scale, + const float rate, framework::Tensor* out_state, + framework::Tensor* out_accum, framework::Tensor* out_scale) { + const auto gpu_place = ctx.GetPlace(); + + T rate_t = static_cast(rate); + T* out_state_data = out_state->mutable_data(gpu_place); + T* out_accum_data = out_accum->mutable_data(gpu_place); + T* out_scale_data = out_scale->mutable_data(gpu_place); + + FindMovingAverageAbsMaxKernel<<<1, 1, 0, ctx.stream()>>>( + in_state.data(), in_accum.data(), cur_scale, rate_t, + out_state_data, out_accum_data, out_scale_data); + } +}; + +// ChannelClipAndQuantDequantKernel for quant_axis is 0 +template +__global__ void ChannelClipAndQuantDequantKernelQuantAxis0( + const T* in, const T* scale, const int bin_cnt, const int n, const int c, + T* out) { + int tid = threadIdx.x; + + int channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + + T s = scale[blockIdx.x]; + T inv_s = inverse(s); + + for (int i = tid; i < channel_size; i += blockDim.x) { + T x = in_c[i]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt * inv_s * v; + out_c[i] = round(v) * s / bin_cnt; + } +} + +// ChannelClipAndQuantDequantKernel for quant_axis is 1 +template +__global__ void ChannelClipAndQuantDequantKernelQuantAxis1( + const T* in, const T* scale, const int bin_cnt, const int n, const int cin, + const int cout, T* out) { + T s = scale[blockIdx.x % cout]; + T inv_s = inverse(s); + + int wh_size = n / (cin * cout); + const T* in_c = in + blockIdx.x * wh_size; + T* out_c = out + blockIdx.x * wh_size; + + for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { + T x = in_c[i]; + T v = x > s ? s : x; + v = v < -s ? 
-s : v; + v = bin_cnt * inv_s * v; + out_c[i] = round(v) * s / bin_cnt; + } +} + +template +struct ChannelClipFakeQuantDequantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, const int quant_axis, + framework::Tensor* out) { + // At present, channelwise quantization supports conv2d, depthwise_conv2d + // conv2d_transpose and mul + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + + int num = in.numel(); + auto in_dims = in.dims(); + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + if (quant_axis == 0) { + int grid = in_dims[0]; + int block = 1024; + ChannelClipAndQuantDequantKernelQuantAxis0< + T><<>>(in_data, scale_data, bin_cnt, + num, in_dims[0], out_data); + } else if (quant_axis == 1) { + int grid = in_dims[0] * in_dims[1]; + int block = 1024; + + ChannelClipAndQuantDequantKernelQuantAxis1< + T><<>>( + in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data); + } + } +}; + +template struct ChannelClipFakeQuantDequantFunctor; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/quantize_linear_op.cu b/paddle/fluid/operators/quantize_linear_op.cu index bce67a3d4a18b..6c7e430f51126 100644 --- a/paddle/fluid/operators/quantize_linear_op.cu +++ b/paddle/fluid/operators/quantize_linear_op.cu @@ -14,57 +14,14 @@ limitations under the License. */ #include #include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/fake_dequantize_op.cu.h" +#include "paddle/fluid/operators/fake_quantize_op.cu.h" #include "paddle/fluid/operators/quantize_linear_op.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" namespace paddle { namespace operators { -template -__global__ void KeDequantize(const T* in, const T* scale, T max_range, - int64_t num, T* out) { - int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; - for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { - out[i] = in[i] * scale[0] / max_range; - } -} - -template -struct DequantizeFunctor { - void operator()(const platform::CUDADeviceContext& dev_ctx, - const framework::Tensor* in, const framework::Tensor* scale, - T max_range, framework::Tensor* out) { - const T* in_data = in->data(); - const T* scale_factor = scale->data(); - T* out_data = out->mutable_data(dev_ctx.GetPlace()); - - int64_t num = in->numel(); - int64_t block_size = std::min( - num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = - std::max(((max_threads - 1) / block_size + 1), static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - KeDequantize<<>>( - in_data, scale_factor, max_range, num, out_data); - } -}; - -template -__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale, - const T max_range, - const int64_t num, - const int n_scales, - const int quant_stride, T* out) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { - T s = scale[(i / quant_stride) % n_scales]; - out[i] = in[i] * s / max_range; - } -} - template struct ChannelDequantizeFunctorV2 { void operator()(const platform::CUDADeviceContext& dev_ctx, @@ -96,309 +53,9 @@ struct 
ChannelDequantizeFunctorV2 { } }; -template struct DequantizeFunctor; -template struct DequantizeFunctor; template struct ChannelDequantizeFunctorV2; template struct ChannelDequantizeFunctorV2; -template -__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { - int bid = threadIdx.x + blockIdx.x * blockDim.x; - int tid = threadIdx.x; - - extern __shared__ char* shared_max_data_tmp[]; - auto shared_max_data = reinterpret_cast(shared_max_data_tmp); - if (gridDim.x > 1) { - T local_max_data = T(0); - for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T tmp = abs(in[i]); - if (tmp > local_max_data) { - local_max_data = tmp; - } - } - shared_max_data[tid] = local_max_data; - } else { - if (bid < n) { - shared_max_data[tid] = abs(in[bid]); - } else { - shared_max_data[tid] = T(0); - } - } - __syncthreads(); - - for (int i = blockDim.x / 2; i > 0; i >>= 1) { - if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { - shared_max_data[tid] = shared_max_data[tid + i]; - } - __syncthreads(); - } - if (tid == 0) { - out[blockIdx.x] = shared_max_data[0]; - } -} - -template -struct FindAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, const T* in, - const int num, T* out) { - int block = 1024; - int grid = (block - 1 + num) / block; - grid = (grid > block) ? block : grid; - - framework::Tensor max; - T* max_data = max.mutable_data(phi::make_ddim({grid}), ctx.GetPlace()); - FindAbsMaxKernel<<>>( - in, num, max_data); - FindAbsMaxKernel<<<1, block, 1024 * sizeof(T), ctx.stream()>>>( - max_data, grid, out); - } -}; - -template struct FindAbsMaxFunctor; -template struct FindAbsMaxFunctor; - -template -__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, - const int c, T* out) { - int tid = threadIdx.x; - int channel_size = n / c; - const T* in_c = in + blockIdx.x * channel_size; - extern __shared__ T shared_max_data[]; - shared_max_data[tid] = T(0); - for (int i = tid; i < channel_size; i += blockDim.x) { - T tmp = fabs(in_c[i]); - if (tmp > shared_max_data[tid]) { - shared_max_data[tid] = tmp; - } - } - __syncthreads(); - for (int i = blockDim.x / 2; i > 0; i >>= 1) { - if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { - shared_max_data[tid] = shared_max_data[tid + i]; - } - __syncthreads(); - } - if (tid == 0) { - out[blockIdx.x] = shared_max_data[0]; - } -} - -template -__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, - const int cin, const int cout, - T* out) { - extern __shared__ T shared_max_data[]; - int cout_wh_size = n / cin; - int wh_size = n / (cin * cout); - - int tid = threadIdx.x; - int bid = blockIdx.x; - const T* in_current = in + tid * cout_wh_size + bid * wh_size; - T local_max_data = T(0); - for (int i = 0; i < wh_size; i++) { - T tmp = fabs(in_current[i]); - if (tmp > local_max_data) { - local_max_data = tmp; - } - } - shared_max_data[tid] = local_max_data; - __syncthreads(); - - int len = blockDim.x; - for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) { - if (tid < i && tid + i < len && - shared_max_data[tid] < shared_max_data[tid + i]) { - shared_max_data[tid] = shared_max_data[tid + i]; - } - if (i == 1) { - i = 0; // break the loop - } - __syncthreads(); - } - if (tid == 0 && shared_max_data[0] > out[bid]) { - out[bid] = shared_max_data[0]; - } -} - -template -struct FindChannelAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in_tensor, const int quant_axis, - T* out_abs_max) { - PADDLE_ENFORCE_EQ( 
- quant_axis == 0 || quant_axis == 1, true, - platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - const int num = in_tensor.numel(); - auto in_dims = in_tensor.dims(); - const T* in_data = in_tensor.data(); - if (quant_axis == 0) { - int cout = in_dims[0]; - int grid = cout; - int block = 1024; - FindChannelAbsMaxKernelQuantAxis0< - T><<>>( - in_data, num, cout, out_abs_max); - } else if (quant_axis == 1) { - int cin = in_dims[0]; - int cout = in_dims[1]; - int grid = cout; - int max_threads = 1024; - -#ifdef PADDLE_WITH_HIP - hipMemset(out_abs_max, 0, sizeof(T) * cout); -#else - cudaMemset(out_abs_max, 0, sizeof(T) * cout); -#endif - - for (int i = 0; i < cin / max_threads; i++) { - int block = max_threads; - FindChannelAbsMaxKernelQuantAxis1< - T><<>>( - in_data, num, cin, cout, out_abs_max); - in_data += num / cin; - } - - int block = cin % max_threads; - if (block > 0) { - FindChannelAbsMaxKernelQuantAxis1< - T><<>>( - in_data, num, in_dims[0], in_dims[1], out_abs_max); - } - } - } -}; - -template struct FindChannelAbsMaxFunctor; - -template -__global__ void ClipAndQuantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, T* out) { - int bid = threadIdx.x + blockIdx.x * blockDim.x; - int tid = threadIdx.x; - - T s = scale[0]; - T inv_s = inverse(s); - for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T x = in[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out[i] = round(v); - } -} - -template -struct ClipAndFakeQuantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { - int num = in.numel(); - int block = 1024; - int grid = (block - 1 + num) / block; - - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - ClipAndQuantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); - } -}; - -template struct ClipAndFakeQuantFunctor; - -// ChannelClipAndQuantKernel for quant_axis is 0 -template -__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, - const int bin_cnt, - const int64_t n, - const int c, T* out) { - int tid = threadIdx.x; - - int64_t channel_size = n / c; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; - - T s = scale[blockIdx.x]; - T inv_s = inverse(s); - - for (int64_t i = tid; i < channel_size; i += blockDim.x) { - T x = in_c[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v); - } -} - -// ChannelClipAndQuantKernel for quant_axis is N -template -__global__ void ChannelClipAndQuantKernelQuantAxisN( - const T* in, const T* scale, const int bin_cnt, const int64_t n, - const int nScale, const int quant_stride, T* out) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) { - T s = scale[(i / quant_stride) % nScale]; - T inv_s = inverse(s); - T x = in[i]; - T v = x > s ? s : x; - v = v < -s ? 
-s : v; - v = bin_cnt * inv_s * v; - out[i] = round(v); - } -} - -template -struct ChannelClipAndFakeQuantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, - framework::Tensor* out) { - PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, - platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - - int64_t num = in.numel(); - auto in_dims = in.dims(); - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - if (quant_axis == 0) { - int grid = in_dims[0]; - int block = 1024; - ChannelClipAndQuantKernelQuantAxis0<<>>( - in_data, scale_data, bin_cnt, num, in_dims[0], out_data); - } else { - int quant_stride = 1; - for (int i = quant_axis + 1; i < in_dims.size(); i++) { - quant_stride *= in_dims[i]; - } - int64_t block_size = - std::min(num, static_cast(ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1), - static_cast(1)); - - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - - ChannelClipAndQuantKernelQuantAxisN<<>>( - in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride, - out_data); - } - } -}; - -template struct ChannelClipAndFakeQuantFunctor; - } // namespace operators } // namespace paddle From da5ddf780461dacf800d283a4b75ba274eadf5ff Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Sat, 2 Apr 2022 01:59:00 +0000 Subject: [PATCH 06/14] fix unittest --- python/paddle/fluid/contrib/slim/quantization/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index 9881e6ec02b54..608844dd55da7 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -365,10 +365,10 @@ def dequant_tensor(x, scale, quant_axis=0, weight_bits=8): return x -def bias_correction_w(self, x, x_quant, scale_v, quant_axis, weight_bits=8): +def bias_correction_w(x, x_quant, scale_v, quant_axis, weight_bits=8): + ''' + Bias correction for weight ''' - Bias correction for weight - ''' eps = 1e-8 bnt = (1 << (weight_bits - 1)) - 1 x_dequant = x_quant.copy() From 82f2b71fb6e4a3cfe70b57901c1d2e350a8ed443 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Sat, 2 Apr 2022 13:03:29 +0000 Subject: [PATCH 07/14] fix coverage --- .../post_training_quantization.py | 19 +- ...t_post_training_quantization_lstm_model.py | 70 +++++- .../test_post_training_quantization_mnist.py | 76 +++++- .../unittests/test_quantize_linear_op.py | 230 ++++++++++++++++++ 4 files changed, 376 insertions(+), 19 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_quantize_linear_op.py diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index a7b3dd5792a02..a4c7a2a2bf8df 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -387,23 +387,27 @@ def quantize(self): break _logger.info("Finish sampling stage, all batch: " + 
str(batch_id)) - if self._round_type == 'adaround': - self._adaround_apply() - - self._reset_activation_persistable() if self._algo == 'avg': for var_name in self._quantized_act_var_name: self._quantized_threshold[var_name] = \ np.array(self._quantized_var_avg[var_name]).mean() if self._algo in ["KL", "hist"]: self._calculate_kl_hist_threshold() - if self._algo in ["KL", "abs_max", "hist", "avg", "mse", "emd"]: - self._update_program() - else: + + if self._round_type == 'adaround': + self._adaround_apply() + + self._reset_activation_persistable() + + if self._algo is 'min_max': self._save_input_threhold() + else: + self._update_program() + # save out_threshold for quantized ops. if not self._onnx_format: self._save_output_threshold() + if any(op_type in self._quantizable_op_type for op_type in self._dynamic_quantize_op_type): self._collect_dynamic_quantize_op_threshold( @@ -428,6 +432,7 @@ def quantize(self): return self._program def _adaround_apply(self): + assert self._algo != "min_max", "The algo should not be min_max." if self._algo in ["KL", "hist"]: scale_dict = self._quantized_var_threshold else: diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py index 58a430eb96406..85cabb6b5e9b7 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py @@ -173,7 +173,8 @@ def generate_quantized_model(self, is_use_cache_file=False, is_optimize_model=False, batch_size=10, - batch_nums=10): + batch_nums=10, + onnx_format=False): place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -190,14 +191,28 @@ def generate_quantized_model(self, round_type=round_type, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, + onnx_format=onnx_format, is_use_cache_file=is_use_cache_file) ptq.quantize() ptq.save_quantized_model(self.int8_model_path) - def run_test(self, model_name, model_url, model_md5, data_name, data_url, - data_md5, algo, round_type, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - diff_threshold, infer_iterations, quant_iterations): + def run_test(self, + model_name, + model_url, + model_md5, + data_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + infer_iterations, + quant_iterations, + onnx_format=False): fp32_model_path = self.download_model(model_url, model_md5, model_name) fp32_model_path = os.path.join(fp32_model_path, model_name) @@ -211,10 +226,10 @@ def run_test(self, model_name, model_url, model_md5, data_name, data_url, print("Start post training quantization for {0} on {1} samples ...". 
format(model_name, quant_iterations)) - self.generate_quantized_model(fp32_model_path, data_path, algo, - round_type, quantizable_op_type, - is_full_quantize, is_use_cache_file, - is_optimize_model, quant_iterations) + self.generate_quantized_model( + fp32_model_path, data_path, algo, round_type, quantizable_op_type, + is_full_quantize, is_use_cache_file, is_optimize_model, + quant_iterations, onnx_format) print("Start INT8 inference for {0} on {1} samples ...".format( model_name, infer_iterations)) @@ -278,5 +293,42 @@ def test_post_training_kl(self): diff_threshold, infer_iterations, quant_iterations) +class TestPostTrainingKLForMnistONNXFormat(TestPostTrainingQuantization): + def test_post_training_kl_onnx_format(self): + model_name = "nlp_lstm_fp32_model" + model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" + model_md5 = "519b8eeac756e7b4b7bcb2868e880452" + data_name = "quant_lstm_input_data" + data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz" + data_md5 = "add84c754e9b792fea1fbd728d134ab7" + algo = "KL" + round_type = "round" + quantizable_op_type = ["mul", "lstm"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = False + diff_threshold = 0.01 + infer_iterations = 100 + quant_iterations = 10 + onnx_format = True + self.run_test( + model_name, + model_url, + model_md5, + data_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py index 74198da11fb2c..c219d2fbf89a9 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py @@ -116,7 +116,8 @@ def generate_quantized_model(self, is_use_cache_file=False, is_optimize_model=False, batch_size=10, - batch_nums=10): + batch_nums=10, + onnx_format=False): place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -134,6 +135,7 @@ def generate_quantized_model(self, round_type=round_type, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, + onnx_format=onnx_format, is_use_cache_file=is_use_cache_file) ptq.quantize() ptq.save_quantized_model(self.int8_model_path) @@ -151,7 +153,8 @@ def run_test(self, diff_threshold, batch_size=10, infer_iterations=10, - quant_iterations=5): + quant_iterations=5, + onnx_format=False): origin_model_path = self.download_model(data_url, data_md5, model_name) origin_model_path = os.path.join(origin_model_path, model_name) @@ -166,7 +169,7 @@ def run_test(self, self.generate_quantized_model(origin_model_path, algo, round_type, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, - batch_size, quant_iterations) + batch_size, quant_iterations, onnx_format) print("Start INT8 inference for {0} on {1} images ...".format( model_name, infer_iterations * batch_size)) @@ -335,5 +338,72 @@ def test_post_training_mse(self): infer_iterations, quant_iterations) +class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): + def test_post_training_mse_onnx_format(self): + model_name = "mnist_model" + data_url = 
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" + data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" + algo = "mse" + round_type = "round" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + onnx_format = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test( + model_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) + + +class TestPostTrainingmseForMnistONNXFormatFullQuant( + TestPostTrainingQuantization): + def test_post_training_mse_onnx_format_full_quant(self): + model_name = "mnist_model" + data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" + data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" + algo = "mse" + round_type = "round" + quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] + is_full_quantize = True + is_use_cache_file = False + is_optimize_model = False + onnx_format = True + diff_threshold = 0.01 + batch_size = 10 + infer_iterations = 50 + quant_iterations = 5 + self.run_test( + model_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py b/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py new file mode 100644 index 0000000000000..99b00bc0c6897 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +from op_test import OpTest + + +def quantize_max_abs(x, max_range): + scale = np.max(np.abs(x).flatten()) + y = np.round(x / scale * max_range) + return y, scale + + +def dequantize_max_abs(x, scale, max_range): + y = (scale / max_range) * x + return y + + +def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0): + assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." 
+ scales = [] + y = x.copy() + max_range = math.pow(2, quant_bit - 1) - 1 + if quant_axis == 0: + for i in range(x.shape[0]): + scale = np.max(np.abs(x[i])).astype("float32") + scales.append(scale) + y[i] = np.round(x[i] * max_range / scale) + elif quant_axis == 1: + for i in range(x.shape[1]): + scale = np.max(np.abs(x[:, i])).astype("float32") + scales.append(scale) + y[:, i] = np.round(x[:, i] * max_range / scale) + return y, scales + + +def channel_wise_dequantize_max_abs(x, + scales, + quant_bits, + quant_axis, + activation_scale=None): + assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." + + if isinstance(quant_bits, list): + max_range = math.pow(2, quant_bits[0] - 1) - 1 + else: + max_range = math.pow(2, quant_bits - 1) - 1 + y = x.copy() + if quant_axis == 0: + for i in range(x.shape[0]): + y[i] = x[i] * scales[i] / max_range + elif quant_axis == 1: + for i in range(x.shape[1]): + y[:, i] = x[:, i] * scales[i] / max_range + + if activation_scale is not None: + y = y * activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) + return y + + +class TestChannelWiseDequantizeOp(OpTest): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 0 + self.zero_point = 0. + + def setUp(self): + self.set_args() + self.op_type = "dequantize_linear" + x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + yq, scale = channel_wise_quantize_max_abs(x, self.quant_bits[0], + self.quant_axis) + ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, + self.quant_axis) + + self.inputs = { + 'X': yq, + 'Scale': np.array(scale).astype(self.data_type), + 'ZeroPoint': self.zero_point + } + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis + } + self.outputs = {'Y': ydq} + + def test_check_output(self): + self.check_output() + + +class TestChannelWiseDequantizeOp1(TestChannelWiseDequantizeOp): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 1 + self.zero_point = 0. + + +class TestDequantizeOp(OpTest): + def set_args(self): + self.num_bits = 8 + self.quant_axis = -1 + self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.data_type = "float32" + self.zero_point = 0. + + def setUp(self): + self.set_args() + self.op_type = "dequantize_linear" + x = np.random.randn(31, 65).astype(self.data_type) + yq, scale = quantize_max_abs(x, self.max_range) + ydq = dequantize_max_abs(yq, scale, self.max_range) + + self.inputs = { + 'X': yq, + 'Scale': np.array(scale).astype(self.data_type), + 'ZeroPoint': self.zero_point + } + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis + } + self.outputs = {'Y': ydq} + + def test_check_output(self): + self.check_output() + + +class TestDequantizeOpDouble(TestDequantizeOp): + def set_args(self): + self.num_bits = 8 + self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.data_type = "float64" + self.zero_point = 0. + self.quant_axis = -1 + + +class TestFakeDequantizeMaxAbsOp5Bits(TestDequantizeOp): + def set_args(self): + self.num_bits = 5 + self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.data_type = "float32" + self.zero_point = 0. + self.quant_axis = -1 + + +class TestChannelWisequantizeOp(OpTest): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 0 + self.zero_point = 0. 
+ + def setUp(self): + self.set_args() + self.op_type = "quantize_linear" + x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + yq, scale = channel_wise_quantize_max_abs(x, self.quant_bits[0], + self.quant_axis) + + self.inputs = { + 'X': x, + 'Scale': np.array(scale).astype(self.data_type), + 'ZeroPoint': self.zero_point + } + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis + } + self.outputs = {'Y': yq} + + def test_check_output(self): + self.check_output() + + +class TestChannelWisequantizeOp1(TestChannelWisequantizeOp): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 1 + self.zero_point = 0. + + +class TestquantizeOp(OpTest): + def set_args(self): + self.num_bits = 8 + self.quant_axis = -1 + self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.data_type = "float32" + self.zero_point = 0. + + def setUp(self): + self.set_args() + self.op_type = "dequantize_linear" + x = np.random.randn(31, 65).astype(self.data_type) + yq, scale = quantize_max_abs(x, self.max_range) + + self.inputs = { + 'X': x, + 'Scale': np.array(scale).astype(self.data_type), + 'ZeroPoint': self.zero_point + } + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis + } + self.outputs = {'Y': yq} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() From de75349981579d760b49591f26eef9f16ad118b7 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Sat, 2 Apr 2022 14:31:01 +0000 Subject: [PATCH 08/14] fix clip include --- paddle/fluid/operators/quantize_linear_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index 0085740500689..4039f0e9d07e1 100644 --- a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -15,9 +15,9 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/platform/transform.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/impl/clip_kernel_impl.h" namespace paddle { namespace operators { From 7bd3abb70782a46ffa9e481fd3e89a40b6c5409b Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Sat, 2 Apr 2022 15:25:57 +0000 Subject: [PATCH 09/14] fix test_quantize_linear_op --- .../unittests/test_quantize_linear_op.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py b/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py index 99b00bc0c6897..b09a6a3506869 100644 --- a/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py +++ b/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py @@ -84,9 +84,9 @@ def setUp(self): self.set_args() self.op_type = "dequantize_linear" x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scale = channel_wise_quantize_max_abs(x, self.quant_bits[0], + yq, scale = channel_wise_quantize_max_abs(x, self.bit_length, self.quant_axis) - ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, + ydq = channel_wise_dequantize_max_abs(yq, scale, self.bit_length, self.quant_axis) self.inputs = { @@ -114,9 +114,9 @@ def set_args(self): class TestDequantizeOp(OpTest): def set_args(self): - self.num_bits = 8 + self.bit_length = 8 self.quant_axis = -1 - self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.max_range = math.pow(2, self.bit_length - 1) - 1 self.data_type = "float32" self.zero_point = 0. @@ -144,8 +144,8 @@ def test_check_output(self): class TestDequantizeOpDouble(TestDequantizeOp): def set_args(self): - self.num_bits = 8 - self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.bit_length = 8 + self.max_range = math.pow(2, self.bit_length - 1) - 1 self.data_type = "float64" self.zero_point = 0. self.quant_axis = -1 @@ -153,8 +153,8 @@ def set_args(self): class TestFakeDequantizeMaxAbsOp5Bits(TestDequantizeOp): def set_args(self): - self.num_bits = 5 - self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.bit_length = 5 + self.max_range = math.pow(2, self.bit_length - 1) - 1 self.data_type = "float32" self.zero_point = 0. self.quant_axis = -1 @@ -171,7 +171,7 @@ def setUp(self): self.set_args() self.op_type = "quantize_linear" x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scale = channel_wise_quantize_max_abs(x, self.quant_bits[0], + yq, scale = channel_wise_quantize_max_abs(x, self.bit_length, self.quant_axis) self.inputs = { @@ -199,9 +199,9 @@ def set_args(self): class TestquantizeOp(OpTest): def set_args(self): - self.num_bits = 8 + self.bit_length = 8 self.quant_axis = -1 - self.max_range = math.pow(2, self.num_bits - 1) - 1 + self.max_range = math.pow(2, self.bit_length - 1) - 1 self.data_type = "float32" self.zero_point = 0. 
From 94e7d0e3cf1415b662725a9821dd68e4812f5e7b Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Sun, 3 Apr 2022 02:32:28 +0000 Subject: [PATCH 10/14] fix unittest --- .../unittests/test_fake_dequantize_op.py | 82 +++++++ .../tests/unittests/test_fake_quantize_op.py | 140 +++++++++++ .../unittests/test_quantize_linear_op.py | 230 ------------------ 3 files changed, 222 insertions(+), 230 deletions(-) delete mode 100644 python/paddle/fluid/tests/unittests/test_quantize_linear_op.py diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py index b30e0a6775ea9..ee40c10fed3e5 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py @@ -172,5 +172,87 @@ def set_args(self): self.data_type = "float32" +class TestChannelWiseDequantizeOp(OpTest): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 0 + + def setUp(self): + self.set_args() + self.op_type = "dequantize_linear" + x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + yq, scale = channel_wise_quantize_max_abs(x, self.bit_length, + self.quant_axis) + ydq = channel_wise_dequantize_max_abs(yq, scale, self.bit_length, + self.quant_axis) + scale = np.array(scale).astype(self.data_type) + zero_point = np.zeros(scale.shape, dtype="int32") + print('TestChannelWiseDequantizeOp:') + self.inputs = {'X': yq, 'Scale': scale, 'ZeroPoint': zero_point} + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis + } + self.outputs = {'Y': ydq} + + def test_check_output(self): + self.check_output() + + +class TestChannelWiseDequantizeOp1(TestChannelWiseDequantizeOp): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 1 + print('TestChannelWiseDequantizeOp1:') + + +class TestDequantizeOp(OpTest): + def set_args(self): + self.bit_length = 8 + self.quant_axis = -1 + self.max_range = math.pow(2, self.bit_length - 1) - 1 + self.data_type = "float32" + print('TestDequantizeOp:') + + def setUp(self): + self.set_args() + self.op_type = "dequantize_linear" + x = np.random.randn(31, 65).astype(self.data_type) + yq, scale = quantize_max_abs(x, self.max_range) + ydq = dequantize_max_abs(yq, scale, self.max_range) + scale = np.array(scale).astype(self.data_type) + zero_point = np.zeros(scale.shape, dtype="int32") + + self.inputs = {'X': yq, 'Scale': scale, 'ZeroPoint': zero_point} + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis + } + self.outputs = {'Y': ydq} + + def test_check_output(self): + self.check_output() + + +class TestDequantizeOpDouble(TestDequantizeOp): + def set_args(self): + self.bit_length = 8 + self.max_range = math.pow(2, self.bit_length - 1) - 1 + self.data_type = "float64" + self.quant_axis = -1 + print('TestDequantizeOpDouble:') + + +class TestFakeDequantizeMaxAbsOp5Bits(TestDequantizeOp): + def set_args(self): + self.bit_length = 5 + self.max_range = math.pow(2, self.bit_length - 1) - 1 + self.data_type = "float32" + self.quant_axis = -1 + print('TestFakeDequantizeMaxAbsOp5Bits:') + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index 1d7bfc9f6963c..2be61d1218560 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -16,6 +16,7 @@ import unittest import numpy as np +import math from op_test import OpTest import paddle.fluid.core as core @@ -374,5 +375,144 @@ def set_arg(self): self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } +def quantize_max_abs(x, max_range): + scale = np.max(np.abs(x).flatten()) + y = np.round(x / scale * max_range) + return y, scale + + +def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0): + assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." + scales = [] + y = x.copy() + max_range = math.pow(2, quant_bit - 1) - 1 + if quant_axis == 0: + for i in range(x.shape[0]): + scale = np.max(np.abs(x[i])).astype("float32") + scales.append(scale) + y[i] = np.round(x[i] * max_range / scale) + elif quant_axis == 1: + for i in range(x.shape[1]): + scale = np.max(np.abs(x[:, i])).astype("float32") + scales.append(scale) + y[:, i] = np.round(x[:, i] * max_range / scale) + return y, scales + + +class TestChannelWiseQuantizeOp(OpTest): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 0 + + def setUp(self): + self.set_args() + self.op_type = "quantize_linear" + x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + yq, scale = channel_wise_quantize_max_abs(x, self.bit_length, + self.quant_axis) + scale = np.array(scale).astype(self.data_type) + zero_point = np.zeros(scale.shape, dtype="int32") + + self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point} + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis + } + self.outputs = {'Y': yq} + + def test_check_output(self): + self.check_output() + + +class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 1 + + +class TestChannelWiseQuantizeOpTrain(OpTest): + def set_args(self): + self.bit_length = 8 + self.data_type = "float32" + self.quant_axis = 0 + self.is_test = False + + def setUp(self): + self.set_args() + self.op_type = "quantize_linear" + x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + yq, scale = channel_wise_quantize_max_abs(x, self.bit_length, + self.quant_axis) + scale = np.array(scale).astype(self.data_type) + zero_point = np.zeros(scale.shape, dtype="int32") + + self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point} + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis, + 'is_test': self.is_test + } + self.outputs = {'Y': yq, 'OutScale': scale} + + def test_check_output(self): + self.check_output() + + +class TestquantizeOp(OpTest): + def set_args(self): + self.bit_length = 8 + self.quant_axis = -1 + self.max_range = math.pow(2, self.bit_length - 1) - 1 + self.data_type = "float32" + + def setUp(self): + self.set_args() + self.op_type = "quantize_linear" + x = np.random.randn(31, 65).astype(self.data_type) + yq, scale = quantize_max_abs(x, self.max_range) + scale = np.array(scale).astype(self.data_type) + zero_point = np.zeros(scale.shape, dtype="int32") + + self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point} + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis, + } + self.outputs = {'Y': yq} + + def test_check_output(self): + self.check_output() + + +class TestquantizeOpTrain(TestquantizeOp): + def set_args(self): + self.bit_length = 8 + self.quant_axis = -1 + self.max_range = math.pow(2, self.bit_length - 1) - 1 + self.data_type = "float32" + self.is_test = False + 
+ def setUp(self): + self.set_args() + self.op_type = "quantize_linear" + x = np.random.randn(31, 65).astype(self.data_type) + yq, scale = quantize_max_abs(x, self.max_range) + scale = np.array(scale).astype(self.data_type) + zero_point = np.zeros(scale.shape, dtype="int32") + + self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point} + self.attrs = { + 'bit_length': self.bit_length, + 'quant_axis': self.quant_axis, + 'is_test': self.is_test + } + self.outputs = {'Y': yq, 'OutScale': scale} + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py b/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py deleted file mode 100644 index b09a6a3506869..0000000000000 --- a/python/paddle/fluid/tests/unittests/test_quantize_linear_op.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -import math -from op_test import OpTest - - -def quantize_max_abs(x, max_range): - scale = np.max(np.abs(x).flatten()) - y = np.round(x / scale * max_range) - return y, scale - - -def dequantize_max_abs(x, scale, max_range): - y = (scale / max_range) * x - return y - - -def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0): - assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." - scales = [] - y = x.copy() - max_range = math.pow(2, quant_bit - 1) - 1 - if quant_axis == 0: - for i in range(x.shape[0]): - scale = np.max(np.abs(x[i])).astype("float32") - scales.append(scale) - y[i] = np.round(x[i] * max_range / scale) - elif quant_axis == 1: - for i in range(x.shape[1]): - scale = np.max(np.abs(x[:, i])).astype("float32") - scales.append(scale) - y[:, i] = np.round(x[:, i] * max_range / scale) - return y, scales - - -def channel_wise_dequantize_max_abs(x, - scales, - quant_bits, - quant_axis, - activation_scale=None): - assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." - - if isinstance(quant_bits, list): - max_range = math.pow(2, quant_bits[0] - 1) - 1 - else: - max_range = math.pow(2, quant_bits - 1) - 1 - y = x.copy() - if quant_axis == 0: - for i in range(x.shape[0]): - y[i] = x[i] * scales[i] / max_range - elif quant_axis == 1: - for i in range(x.shape[1]): - y[:, i] = x[:, i] * scales[i] / max_range - - if activation_scale is not None: - y = y * activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) - return y - - -class TestChannelWiseDequantizeOp(OpTest): - def set_args(self): - self.bit_length = 8 - self.data_type = "float32" - self.quant_axis = 0 - self.zero_point = 0. 
- - def setUp(self): - self.set_args() - self.op_type = "dequantize_linear" - x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scale = channel_wise_quantize_max_abs(x, self.bit_length, - self.quant_axis) - ydq = channel_wise_dequantize_max_abs(yq, scale, self.bit_length, - self.quant_axis) - - self.inputs = { - 'X': yq, - 'Scale': np.array(scale).astype(self.data_type), - 'ZeroPoint': self.zero_point - } - self.attrs = { - 'bit_length': self.bit_length, - 'quant_axis': self.quant_axis - } - self.outputs = {'Y': ydq} - - def test_check_output(self): - self.check_output() - - -class TestChannelWiseDequantizeOp1(TestChannelWiseDequantizeOp): - def set_args(self): - self.bit_length = 8 - self.data_type = "float32" - self.quant_axis = 1 - self.zero_point = 0. - - -class TestDequantizeOp(OpTest): - def set_args(self): - self.bit_length = 8 - self.quant_axis = -1 - self.max_range = math.pow(2, self.bit_length - 1) - 1 - self.data_type = "float32" - self.zero_point = 0. - - def setUp(self): - self.set_args() - self.op_type = "dequantize_linear" - x = np.random.randn(31, 65).astype(self.data_type) - yq, scale = quantize_max_abs(x, self.max_range) - ydq = dequantize_max_abs(yq, scale, self.max_range) - - self.inputs = { - 'X': yq, - 'Scale': np.array(scale).astype(self.data_type), - 'ZeroPoint': self.zero_point - } - self.attrs = { - 'bit_length': self.bit_length, - 'quant_axis': self.quant_axis - } - self.outputs = {'Y': ydq} - - def test_check_output(self): - self.check_output() - - -class TestDequantizeOpDouble(TestDequantizeOp): - def set_args(self): - self.bit_length = 8 - self.max_range = math.pow(2, self.bit_length - 1) - 1 - self.data_type = "float64" - self.zero_point = 0. - self.quant_axis = -1 - - -class TestFakeDequantizeMaxAbsOp5Bits(TestDequantizeOp): - def set_args(self): - self.bit_length = 5 - self.max_range = math.pow(2, self.bit_length - 1) - 1 - self.data_type = "float32" - self.zero_point = 0. - self.quant_axis = -1 - - -class TestChannelWisequantizeOp(OpTest): - def set_args(self): - self.bit_length = 8 - self.data_type = "float32" - self.quant_axis = 0 - self.zero_point = 0. - - def setUp(self): - self.set_args() - self.op_type = "quantize_linear" - x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scale = channel_wise_quantize_max_abs(x, self.bit_length, - self.quant_axis) - - self.inputs = { - 'X': x, - 'Scale': np.array(scale).astype(self.data_type), - 'ZeroPoint': self.zero_point - } - self.attrs = { - 'bit_length': self.bit_length, - 'quant_axis': self.quant_axis - } - self.outputs = {'Y': yq} - - def test_check_output(self): - self.check_output() - - -class TestChannelWisequantizeOp1(TestChannelWisequantizeOp): - def set_args(self): - self.bit_length = 8 - self.data_type = "float32" - self.quant_axis = 1 - self.zero_point = 0. - - -class TestquantizeOp(OpTest): - def set_args(self): - self.bit_length = 8 - self.quant_axis = -1 - self.max_range = math.pow(2, self.bit_length - 1) - 1 - self.data_type = "float32" - self.zero_point = 0. 
- - def setUp(self): - self.set_args() - self.op_type = "dequantize_linear" - x = np.random.randn(31, 65).astype(self.data_type) - yq, scale = quantize_max_abs(x, self.max_range) - - self.inputs = { - 'X': x, - 'Scale': np.array(scale).astype(self.data_type), - 'ZeroPoint': self.zero_point - } - self.attrs = { - 'bit_length': self.bit_length, - 'quant_axis': self.quant_axis - } - self.outputs = {'Y': yq} - - def test_check_output(self): - self.check_output() - - -if __name__ == "__main__": - unittest.main() From fd29c7242c5aa971a83d95b6315aab3ba3c63daf Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Sun, 3 Apr 2022 07:44:23 +0000 Subject: [PATCH 11/14] fix coverage --- .../slim/quantization/quantization_pass.py | 162 +++++------------- 1 file changed, 41 insertions(+), 121 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index d4247f2c2a0ee..9b589d1549806 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -2109,41 +2109,6 @@ def _has_weight(self, op): has_weight = True return has_weight - def _create_global_step(self, graph): - if self._weight_quantize_type == 'range_abs_max' or \ - self._activation_quantize_type == 'range_abs_max': - counter_name = cpt.to_text('@STEP_COUNTER@') - for node in graph.all_var_nodes(): - if node.name() == counter_name: - self._global_step = node - if self._global_step is None: - global_step_in = graph.create_persistable_node( - name=counter_name, - var_type=core.VarDesc.VarType.LOD_TENSOR, - shape=[1], - var_dtype=core.VarDesc.VarType.INT64) - _init_var_node( - global_step_in, - np.zeros( - [1], dtype='int64'), - self._scope, - self._place) - global_step_out = graph.create_var_node_from_desc( - global_step_in.var()) - # The attribute of `op_role` is needed by ParallelExecutor. - increment_op = graph.create_op_node( - op_type='increment', - attrs={ - 'step': 1.0, - 'op_role': - core.op_proto_and_checker_maker.OpRole.Forward - }, - inputs={'X': global_step_in}, - outputs={'Out': global_step_out}) - graph.link_to(global_step_in, increment_op) - graph.link_to(increment_op, global_step_out) - self._global_step = global_step_out - def _is_skip_quant(self, graph, op_node): """ Analyse whether the op node skips quantization. @@ -2181,8 +2146,6 @@ def apply(self, graph): p.name() for p in graph.all_persistable_nodes() ] - if not self._is_test: - self._create_global_step(graph) ops = graph.all_op_nodes() # Do the preproccess of quantization, such as skipping some ops # for not being quantized. @@ -2202,7 +2165,6 @@ def apply(self, graph): for op in ops: if op.name() in self._quantizable_grad_ops and self._has_weight(op): self._transform_backward(graph, op) - #graph.resolve_hazard() return graph @@ -2350,47 +2312,24 @@ def __init__(self, scope, place): def apply(self, graph): assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' 
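The _replace_op rewrite in the hunk below swaps each fused fake quant-dequant op for a quantize_linear op followed by a dequantize_linear op that share the same scale and a zero-filled ZeroPoint. A rough numpy sketch of why that pair is numerically equivalent to the fused op (assuming symmetric 8-bit quantization with a zero zero-point; these helpers are illustrative, not the real kernels):

    import numpy as np

    def quantize_linear(x, scale, bit_length=8):
        bnt = (1 << (bit_length - 1)) - 1         # 127
        return np.round(np.clip(x, -scale, scale) / scale * bnt)

    def dequantize_linear(q, scale, bit_length=8):
        bnt = (1 << (bit_length - 1)) - 1
        return q * scale / bnt

    def fake_quantize_dequantize(x, scale, bit_length=8):
        bnt = (1 << (bit_length - 1)) - 1
        return np.round(np.clip(x, -scale, scale) / scale * bnt) * scale / bnt

    x = np.random.randn(4, 16).astype("float32")
    scale = np.max(np.abs(x))
    paired = dequantize_linear(quantize_linear(x, scale), scale)
    np.testing.assert_allclose(paired, fake_quantize_dequantize(x, scale), rtol=1e-6)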
- fake_quant_ops = [] - fake_dequant_ops = [] fake_quant_dequant_ops = [] for op in graph.all_op_nodes(): - if op.name() in _fake_quant_op_list: - fake_quant_ops.append(op) - elif op.name() in _fake_dequant_op_list: - fake_dequant_ops.append(op) - elif op.name() in _fake_quant_dequant_op_list: + if op.name() in _fake_quant_dequant_op_list: fake_quant_dequant_ops.append(op) - for _op in fake_quant_ops: - print(_op.name()) - self._replace_op(graph, _op, "quantize_linear") - graph.safe_remove_nodes(_op) - - for _op in fake_dequant_ops: - self._replace_op(graph, _op, "dequantize_linear") - graph.safe_remove_nodes(_op) - for _op in fake_quant_dequant_ops: - self._replace_op(graph, _op, "quantize_dequantize") + self._replace_op(graph, _op) graph.safe_remove_nodes(_op) graph.resolve_hazard() return graph - def _replace_op(self, graph, op, target_op_name): - assert target_op_name in [ - "quantize_linear", "dequantize_linear", "quantize_dequantize" - ] + def _replace_op(self, graph, op): x_node = graph._find_node_by_name(op.inputs, op.input("X")[0]) out_node = graph._find_node_by_name(op.outputs, op.output("Out")[0]) - if target_op_name == "quantize_linear" or target_op_name == "quantize_dequantize": - scale_node = graph._find_node_by_name(op.outputs, - op.output("OutScale")[0]) - else: - scale_name = "Scales" if op.op().has_attr("quant_axis") else "Scale" - scale_node = graph._find_node_by_name(op.inputs, - op.input(scale_name)[0]) + scale_node = graph._find_node_by_name(op.outputs, + op.output("OutScale")[0]) quant_axis = op.op().attr("quant_axis") if op.op().has_attr( "quant_axis") else -1 @@ -2398,7 +2337,7 @@ def _replace_op(self, graph, op, target_op_name): "bit_length") else 8 zero_point_node = None - quanted_node = out_node if target_op_name == "quantize_linear" else x_node + quanted_node = x_node if zero_point_node is None: zero_point_node = graph.create_persistable_node( name=self._zero_point_name(quanted_node.name()), @@ -2412,58 +2351,41 @@ def _replace_op(self, graph, op, target_op_name): self._scope, self._place) - if target_op_name != "quantize_dequantize": - inputs = {"X": x_node, "Scale": scale_node} - if zero_point_node is not None: - inputs["ZeroPoint"] = zero_point_node - quant_op_node = graph.create_op_node( - op_type=target_op_name, - attrs={"quant_axis": quant_axis, - "bit_length": bit_length}, - inputs=inputs, - outputs={"Y": out_node}) - - graph.link_to(x_node, quant_op_node) - graph.link_to(scale_node, quant_op_node) - if zero_point_node is not None: - graph.link_to(zero_point_node, quant_op_node) - graph.link_to(quant_op_node, out_node) - else: - quant_var_node = graph.create_var_node( - name=self._quantized_var_name(x_node.name()), - var_type=x_node.type(), - shape=x_node.shape(), - var_dtype=x_node.dtype()) - quant_op_node = graph.create_op_node( - op_type="quantize_linear", - attrs={"quant_axis": quant_axis, - "bit_length": bit_length}, - inputs={ - "X": x_node, - "Scale": scale_node, - "ZeroPoint": zero_point_node - }, - outputs={"Y": quant_var_node}) - graph.link_to(x_node, quant_op_node) - graph.link_to(scale_node, quant_op_node) - if zero_point_node is not None: - graph.link_to(zero_point_node, quant_op_node) - graph.link_to(quant_op_node, quant_var_node) - dequant_op_node = graph.create_op_node( - op_type="dequantize_linear", - attrs={"quant_axis": quant_axis, - "bit_length": bit_length}, - inputs={ - "X": quant_var_node, - "Scale": scale_node, - "ZeroPoint": zero_point_node - }, - outputs={"Y": out_node}) - graph.link_to(quant_var_node, dequant_op_node) - 
graph.link_to(scale_node, dequant_op_node) - if zero_point_node is not None: - graph.link_to(zero_point_node, dequant_op_node) - graph.link_to(dequant_op_node, out_node) + quant_var_node = graph.create_var_node( + name=self._quantized_var_name(x_node.name()), + var_type=x_node.type(), + shape=x_node.shape(), + var_dtype=x_node.dtype()) + quant_op_node = graph.create_op_node( + op_type="quantize_linear", + attrs={"quant_axis": quant_axis, + "bit_length": bit_length}, + inputs={ + "X": x_node, + "Scale": scale_node, + "ZeroPoint": zero_point_node + }, + outputs={"Y": quant_var_node}) + graph.link_to(x_node, quant_op_node) + graph.link_to(scale_node, quant_op_node) + if zero_point_node is not None: + graph.link_to(zero_point_node, quant_op_node) + graph.link_to(quant_op_node, quant_var_node) + dequant_op_node = graph.create_op_node( + op_type="dequantize_linear", + attrs={"quant_axis": quant_axis, + "bit_length": bit_length}, + inputs={ + "X": quant_var_node, + "Scale": scale_node, + "ZeroPoint": zero_point_node + }, + outputs={"Y": out_node}) + graph.link_to(quant_var_node, dequant_op_node) + graph.link_to(scale_node, dequant_op_node) + if zero_point_node is not None: + graph.link_to(zero_point_node, dequant_op_node) + graph.link_to(dequant_op_node, out_node) def _quantized_var_name(self, var_name): """ @@ -2551,8 +2473,6 @@ def apply(self, graph): # cast weight type to int if self._quant_bits == 8: save_weight_dtype = np.int8 - elif self._quant_bits == 4: - save_weight_dtype = np.int4 quantized_param_v = quantized_param_v.astype( save_weight_dtype) self._restore_var(x_node.name(), quantized_param_v) From 536838d3491eb92947d2403ba5dfb88f329f5d68 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Sun, 3 Apr 2022 09:31:56 +0000 Subject: [PATCH 12/14] fix coverage --- .../slim/tests/test_quantization_pass.py | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py index ccd23485c3d9a..fe261237f1227 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -21,6 +21,7 @@ import paddle from paddle.fluid.framework import IrGraph from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass from paddle.fluid.contrib.slim.quantization import TransformForMobilePass @@ -686,5 +687,129 @@ def test_residual_block_skip_pattern_1(self): for_ci=True) +class TestQuantizationTransformPassV2(unittest.TestCase): + def setUp(self): + self.quantizable_op_and_inputs = { + 'conv2d': ['Input', 'Filter'], + 'depthwise_conv2d': ['Input', 'Filter'], + 'mul': ['X', 'Y'] + } + self.quantizable_grad_op_inputs = { + 'conv2d_grad': ['Input', 'Filter'], + 'depthwise_conv2d_grad': ['Input', 'Filter'], + 'mul_grad': ['X', 'Y'] + } + + def check_program(self, program): + quantized_ops = set() + for block in program.blocks: + for op in block.ops: + # check forward + if op.type in self.quantizable_op_and_inputs: + for arg_name in op.input_arg_names: + self.assertTrue( + arg_name.endswith('.quantized.dequantized')) + quantized_ops.add(arg_name) + + for op in block.ops: + # check backward + if op.type in self.quantizable_grad_op_inputs: + for 
pname in self.quantizable_grad_op_inputs[op.type]: + arg_name = op.input(pname)[0] + self.assertTrue( + arg_name.endswith('.quantized.dequantized')) + self.assertTrue(arg_name in quantized_ops) + + def linear_fc_quant(self, + activation_quant_type, + weight_quantize_type, + for_ci=True): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = linear_fc(3) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + place = fluid.CPUPlace() + graph = IrGraph(core.Graph(main.desc), for_test=False) + transform_pass = QuantizationTransformPassV2( + scope=fluid.global_scope(), + place=place, + activation_quantize_type=activation_quant_type, + weight_quantize_type=weight_quantize_type) + transform_pass.apply(graph) + if not for_ci: + marked_nodes = set() + for op in graph.all_op_nodes(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + graph.draw('.', 'quantize_fc_' + activation_quant_type, + marked_nodes) + program = graph.to_program() + self.check_program(program) + val_graph = IrGraph(core.Graph(program.desc), for_test=False) + if not for_ci: + val_marked_nodes = set() + for op in val_graph.all_op_nodes(): + if op.name().find('quantize') > -1: + val_marked_nodes.add(op) + val_graph.draw('.', 'val_fc_' + activation_quant_type, + val_marked_nodes) + + def test_linear_fc_quant_abs_max(self): + self.linear_fc_quant('abs_max', 'abs_max', for_ci=True) + + def test_linear_fc_quant_channel_wise_abs_max(self): + self.linear_fc_quant('abs_max', 'channel_wise_abs_max', for_ci=True) + + def residual_block_quant(self, + activation_quant_type, + weight_quantize_type, + quantizable_op_type, + for_ci=True): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = residual_block(2) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + place = fluid.CPUPlace() + graph = IrGraph(core.Graph(main.desc), for_test=False) + transform_pass = QuantizationTransformPass( + scope=fluid.global_scope(), + place=place, + activation_quantize_type=activation_quant_type, + weight_quantize_type=weight_quantize_type, + quantizable_op_type=quantizable_op_type) + transform_pass.apply(graph) + if not for_ci: + marked_nodes = set() + for op in graph.all_op_nodes(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + graph.draw('.', 'quantize_residual_' + activation_quant_type, + marked_nodes) + program = graph.to_program() + self.check_program(program) + val_graph = IrGraph(core.Graph(program.desc), for_test=False) + if not for_ci: + val_marked_nodes = set() + for op in val_graph.all_op_nodes(): + if op.name().find('quantize') > -1: + val_marked_nodes.add(op) + val_graph.draw('.', 'val_residual_' + activation_quant_type, + val_marked_nodes) + + def test_residual_block_abs_max(self): + quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul'] + self.residual_block_quant( + 'abs_max', 'abs_max', quantizable_op_type, for_ci=True) + + def test_residual_block_channel_wise_abs_max(self): + quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul'] + self.residual_block_quant( + 'abs_max', 'channel_wise_abs_max', quantizable_op_type, for_ci=True) + + if __name__ == '__main__': unittest.main() From 7c2bec35e718a9001e6788824184119d146f8e59 Mon Sep 17 00:00:00 2001 From: yghstill <742925032@qq.com> Date: Mon, 4 Apr 2022 01:37:45 +0000 Subject: [PATCH 13/14] fix CI-iScan-Python --- .../paddle/fluid/tests/unittests/test_fake_dequantize_op.py | 6 +----- 1 file changed, 1 
insertion(+), 5 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
index ee40c10fed3e5..728e178845c9b 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
@@ -205,7 +205,6 @@ def set_args(self):
         self.bit_length = 8
         self.data_type = "float32"
         self.quant_axis = 1
-        print('TestChannelWiseDequantizeOp1:')
 
 
 class TestDequantizeOp(OpTest):
@@ -214,7 +213,6 @@ def set_args(self):
         self.quant_axis = -1
         self.max_range = math.pow(2, self.bit_length - 1) - 1
         self.data_type = "float32"
-        print('TestDequantizeOp:')
 
     def setUp(self):
         self.set_args()
@@ -242,16 +240,14 @@ def set_args(self):
         self.max_range = math.pow(2, self.bit_length - 1) - 1
         self.data_type = "float64"
         self.quant_axis = -1
-        print('TestDequantizeOpDouble:')
 
 
-class TestFakeDequantizeMaxAbsOp5Bits(TestDequantizeOp):
+class TestDequantizeOp5Bits(TestDequantizeOp):
     def set_args(self):
         self.bit_length = 5
         self.max_range = math.pow(2, self.bit_length - 1) - 1
         self.data_type = "float32"
         self.quant_axis = -1
-        print('TestFakeDequantizeMaxAbsOp5Bits:')
 
 
 if __name__ == "__main__":

From 324e04c639d2359734a221e81905bb64a992100e Mon Sep 17 00:00:00 2001
From: yghstill <742925032@qq.com>
Date: Mon, 4 Apr 2022 04:01:48 +0000
Subject: [PATCH 14/14] add code comments

---
 .../slim/quantization/quantization_pass.py    | 149 +++++++++++++++++-
 1 file changed, 143 insertions(+), 6 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 9b589d1549806..17ddedd9d300a 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -1780,6 +1780,17 @@ def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
 class InsertQuantizeLinear(object):
     """
     Insert quantize_linear and dequantize_linear ops before the ops to be quantized.
+
+    Args:
+        place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
+            If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the index of the GPUs.
+        scope(paddle.Scope): scope is used to get the weight tensor values.
+        quant_bits(int, optional): quantization bit number for weight. Default is 8.
+        quant_axis(int, optional): quantization dimension of channels. When it is greater than or
+            equal to 0, quantization is applied per channel; otherwise it is applied per layer.
+            Default is -1.
+        channel_wise(bool, optional): Whether to quantize per channel or not. Default is False.
+        is_test(bool, optional): Whether the graph is used for inference rather than training. Default is True.
     """
 
     def __init__(self,
@@ -1956,7 +1967,79 @@ def __init__(self,
                  act_preprocess_func=None,
                  optimizer_func=None,
                  executor=None):
+        r"""
+        Args:
+            scope(paddle.Scope): When the activation uses 'range_abs_max' as the quantize
+                type, this pass will create some new parameters. The scope is used to
+                initialize these new parameters.
+            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
+                parameters described above. If it's a string, it can be ``cpu`` or ``gpu:x``,
+                where ``x`` is the index of the GPUs.
+            weight_bits(int): quantization bit number for weights;
+                the bias is not quantized.
+            activation_bits(int): quantization bit number for activation.
+            activation_quantize_type(str): quantization type for activation,
+                now supporting 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
+                In 'abs_max' mode, the quantization scale is calculated
+                dynamically at each step in both training and testing. In
+                'range_abs_max' mode, a static quantization scale is calculated
+                during training and used in inference.
+            weight_quantize_type(str): quantization type for weights,
+                supporting 'abs_max' and 'channel_wise_abs_max'. 'range_abs_max'
+                is usually not used for weights, since weights are fixed once the
+                model is well trained.
+            window_size(int): the window size for 'range_abs_max' quantization.
+            moving_rate(float): the param for 'moving_average_abs_max' quantization.
+            skip_pattern(str or str list): The user-defined quantization skip pattern, which
+                will be presented in the name scope of an op. When the skip pattern is
+                detected in an op's name scope, the corresponding op will not be quantized.
+            quantizable_op_type(list[str]): List the types of ops that will be quantized.
+                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
+                QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
+            weight_quantize_func(function): Function that defines how to quantize weight.
+                Using this can quickly test whether a user-defined quantization method works.
+                In this function, the user should define both the quantization and
+                dequantization steps; that is, the function takes the non-quantized
+                weight as input and returns the dequantized weight. If None, the
+                quantization op defined by 'weight_quantize_type' will be used. Default is None.
+            act_quantize_func(function): Function that defines how to quantize activation.
+                Using this can quickly test whether a user-defined quantization method works.
+                In this function, the user should define both the quantization and dequantization
+                steps; that is, the function takes the non-quantized activation as input and
+                returns the dequantized activation. If None, the quantization op defined
+                by 'activation_quantize_type' will be used. Default is None.
+            weight_preprocess_func(function): Function that defines how to preprocess
+                weight before quantization. Using this can quickly test whether a user-defined
+                preprocess method works. The function takes the non-quantized weight as input
+                and returns the processed weight to be quantized. If None, the weight will
+                be quantized directly. Default is None.
+            act_preprocess_func(function): Function that defines how to preprocess
+                activation before quantization. Using this can quickly test whether a
+                user-defined preprocess method works. The function takes the non-quantized
+                activation as input and returns the processed activation to be quantized.
+                If None, the activation will be quantized directly. Default is None.
+            optimizer_func(function): Function that returns an optimizer. When 'is_test' is
+                False and the user wants to use a self-defined quantization function and
+                preprocess function, this function must be set. Default is None.
+            executor(paddle.Executor): If the user wants to use a self-defined quantization
+                function and preprocess function, executor must be set for initialization.
+                Default is None.
+
+        Examples:
+        .. code-block:: python
+            # The original graph will be rewritten.
+            import paddle
+            from paddle.fluid.contrib.slim.quantization \
+                import QuantizationTransformPassV2
+            from paddle.fluid.contrib.slim.graph import IrGraph
+            from paddle.fluid import core

+            graph = IrGraph(core.Graph(program.desc), for_test=False)
+            place = paddle.CPUPlace()
+            scope = paddle.static.global_scope()
+            transform_pass = QuantizationTransformPassV2(scope, place)
+            transform_pass.apply(graph)
+        """
         self._scope = scope
         self._place = _get_paddle_place(place)
         self._weight_bits = weight_bits
@@ -2186,11 +2269,9 @@ def __init__(self,
                  quantizable_op_type=["elementwise_add", "pool2d"],
                  is_full_quantized=False):
         """
-        Constructor.
-
         Args:
-            scope(fluid.Scope): The scope is used to initialize these new parameters.
-            place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to initialize new
+            scope(paddle.Scope): The scope is used to initialize these new parameters.
+            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
                 parameters described above. If ``place`` is a string, it can be ``cpu``
                 or ``gpu:x``, where ``x`` is the index of the GPUs.
             moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
@@ -2206,6 +2287,21 @@ def __init__(self,
                 quantization to all supported quantizable op type. If set is_full_quantized
                 as False, only apply quantization to the op type according to the input
                 quantizable_op_type.
+
+        Examples:
+        .. code-block:: python
+            # The original graph will be rewritten.
+            import paddle
+            from paddle.fluid.contrib.slim.quantization \
+                import AddQuantDequantPassV2
+            from paddle.fluid.contrib.slim.graph import IrGraph
+            from paddle.fluid import core
+
+            graph = IrGraph(core.Graph(program.desc), for_test=False)
+            place = paddle.CPUPlace()
+            scope = paddle.static.global_scope()
+            add_quant_dequant_pass = AddQuantDequantPassV2(scope, place)
+            add_quant_dequant_pass.apply(graph)
         """
         self._scope = scope
         self._place = _get_paddle_place(place)
@@ -2303,7 +2399,33 @@ def apply(self, graph):
 
 
 class ReplaceFakeQuantDequantPass(object):
+    """
+    Replace fake quant-dequant ops with quantize_linear and dequantize_linear ops.
+    """
+
     def __init__(self, scope, place):
+        r"""
+        Args:
+            scope(paddle.Scope): The scope is used to initialize these new parameters.
+            place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to initialize new
+                parameters described above. If ``place`` is a string, it can be ``cpu``
+                or ``gpu:x``, where ``x`` is the index of the GPUs.
+
+        Examples:
+        .. code-block:: python
+            # The original graph will be rewritten.
+            import paddle
+            from paddle.fluid.contrib.slim.quantization \
+                import ReplaceFakeQuantDequantPass
+            from paddle.fluid.contrib.slim.graph import IrGraph
+            from paddle.fluid import core
+
+            graph = IrGraph(core.Graph(program.desc), for_test=False)
+            place = paddle.CPUPlace()
+            scope = paddle.static.global_scope()
+            replace_pass = ReplaceFakeQuantDequantPass(scope, place)
+            replace_pass.apply(graph)
+        """
         self._place = _get_paddle_place(place)
         self._scope = scope
         assert self._scope != None, "scope must not be None."
@@ -2407,13 +2529,28 @@ class QuantWeightPass(object):
     and weight will be scaled offline.
 
     Args:
-        scope(fluid.Scope): scope is used to get the weight tensor values.
-        place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to restore the weight tensors.
+        scope(paddle.Scope): scope is used to get the weight tensor values.
+        place(paddle.CPUPlace|paddle.CUDAPlace|str): place is used to restore the weight tensors.
            If it's a string, it can be ``cpu`` or ``gpu:x``, where ``x`` is the index of the GPUs.
         bias_correction(bool): whether to use bias correction for post-training quantization.
             https://arxiv.org/abs/1810.05723.
         quant_bits(int, optional): quantization bit number for weight. Default is 8.
         save_int_weight(bool, optional): Whether to save the weight as an integer type. Default is True.
+
+    Examples:
+    .. code-block:: python
+        # The original graph will be rewritten.
+        import paddle
+        from paddle.fluid.contrib.slim.quantization \
+            import QuantWeightPass
+        from paddle.fluid.contrib.slim.graph import IrGraph
+        from paddle.fluid import core
+
+        graph = IrGraph(core.Graph(program.desc), for_test=False)
+        place = paddle.CPUPlace()
+        scope = paddle.static.global_scope()
+        quant_weight_pass = QuantWeightPass(scope, place)
+        quant_weight_pass.apply(graph)
     """
 
     def __init__(self,
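For intuition, the offline weight handling described in the QuantWeightPass docstring above (per-channel abs-max scaling, rounding, and an int8 cast when save_int_weight is True) can be sketched in plain numpy. This is an assumption-based illustration (quant_axis 0, 8 bits), not the pass implementation:

    import numpy as np

    def quant_weight_per_channel(w, quant_bits=8, quant_axis=0):
        bnt = (1 << (quant_bits - 1)) - 1                      # 127
        # abs-max over every axis except the quantization axis
        reduce_axes = tuple(i for i in range(w.ndim) if i != quant_axis)
        scales = np.max(np.abs(w), axis=reduce_axes, keepdims=True)
        q = np.clip(np.round(w / scales * bnt), -bnt, bnt)
        return q.astype(np.int8), np.squeeze(scales)

    weight = np.random.randn(32, 16, 3, 3).astype("float32")   # e.g. a conv2d filter
    int8_weight, per_channel_scales = quant_weight_per_channel(weight)
    assert int8_weight.dtype == np.int8 and per_channel_scales.shape == (32,)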