Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dygraph qat] Use layer to calculate output scale #31861

Merged
merged 6 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions paddle/fluid/operators/fake_quantize_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -649,13 +649,18 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
"MovingAverageAbsMaxScale");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"MovingAverageAbsMaxScale");

if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
if (ctx->HasOutput("OutAccum")) {
ctx->SetOutputDim("OutAccum", {1});
}
ctx->SetOutputDim("OutScale", {1});
if (ctx->HasOutput("Out")) {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out");
}
}

protected:
Expand All @@ -673,6 +678,9 @@ class MovingAverageAbsMaxScaleOpMaker
AddInput("X", "(Tensor) Input is float data type.");
AddInput("InAccum", "Last accum.").AsDispensable();
AddInput("InState", "Last state.").AsDispensable();
AddOutput("Out",
"(Tensor) Output tensor is just equivalent to the input tensor.")
.AsDispensable();
AddOutput("OutScale", " Current scale");
AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
Expand All @@ -693,17 +701,17 @@ And it will not quantize the input tensor.
}
};

class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
"FakeQuantDequantGradOp");
"StrightThroughEstimatorGradOp");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
"FakeQuantDequantGradOp");
"StrightThroughEstimatorGradOp");

ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
}
Expand All @@ -717,13 +725,13 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
};

template <typename T>
class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
class StrightThroughEstimatorMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fake_quantize_dequantize_grad");
grad_op->SetType("stright_throuth_estimator_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
Expand All @@ -744,11 +752,11 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CPU, float>);

REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
ops::FakeQuantOrWithDequantAbsMaxOp,
ops::FakeQuantOrWithDequantAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
ops::FakeQuantOrWithDequantAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);

Expand All @@ -769,11 +777,12 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);

REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
Expand All @@ -789,20 +798,22 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
REGISTER_OPERATOR(
moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
ops::MovingAverageAbsMaxScaleOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleKernel<CPU, float>);

REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CPU, float>);
REGISTER_OPERATOR(stright_throuth_estimator_grad,
ops::StrightThroughEstimatorGradOp);
REGISTER_OP_CPU_KERNEL(stright_throuth_estimator_grad,
ops::StrightThroughEstimatorGradKernel<CPU, float>);

REGISTER_OPERATOR(fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>);
Expand All @@ -820,4 +831,8 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale)
"Out",
"Delete output in order to make the inference model not "
"save moving_average_abs_max_scale operator. This will "
"make the quantitative model be correctly applied in inference."));
"make the quantitative model be correctly applied in inference."))
.AddCheckpoint(
R"ROC(Incompatible upgrade of output [Out])ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput(
"Out", "In order to support dygraph qat, add output again."));
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fake_quantize_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,8 @@ REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
REGISTER_OP_CUDA_KERNEL(
fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(stright_throuth_estimator_grad,
ops::StrightThroughEstimatorGradKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);
16 changes: 11 additions & 5 deletions paddle/fluid/operators/fake_quantize_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,12 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X");
auto& dev_ctx = context.template device_context<DeviceContext>();

if (context.HasOutput("Out")) {
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
}

bool is_test = context.Attr<bool>("is_test");
// testing
if (is_test) {
Expand Down Expand Up @@ -344,17 +350,17 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
};

template <typename DeviceContext, typename T>
class FakeQuantDequantGradKernel : public framework::OpKernel<T> {
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
auto* d_x = context.Output<framework::LoDTensor>(x_grad_name);
PADDLE_ENFORCE_NOT_NULL(
d_x, platform::errors::PreconditionNotMet(
"FakeQuantDequantGradOp doesn't have the output named %s.",
x_grad_name));
PADDLE_ENFORCE_NOT_NULL(d_x, platform::errors::PreconditionNotMet(
"StrightThroughEstimatorGradKernel "
"doesn't have the output named %s.",
x_grad_name));

// Initialize dx as same as d_out
d_x->mutable_data<T>(context.GetPlace());
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/pybind/op_function_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"matrix_nms", {"Out", "Index", "RoisNum"}},
{"distribute_fpn_proposals",
{"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"moving_average_abs_max_scale",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"multiclass_nms3", {"Out", "NmsRoisNum"}},
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"momentum", {"ParamOut", "VelocityOut"}},
Expand Down Expand Up @@ -137,7 +138,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"check_finite_and_unscale", {"Out", "FoundInfinite"}},
{"update_loss_scaling",
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"moving_average_abs_max_scale",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"rnn", {"DropoutState"}},
Expand Down
Loading