forked from hutuxian/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add fuse mt weight only quant (#105)
* add fuse mt weight only quant
- Loading branch information
Showing
7 changed files
with
1,587 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
311 changes: 311 additions & 0 deletions
311
paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,311 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <memory> | ||
#include <string> | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class FusedMultiTransformerWeightOnlyOp : public framework::OperatorWithKernel { | ||
private: | ||
static constexpr const char *OpName = "FusedMultiTransformerWeightOnlyOp"; | ||
|
||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
#define CHECK_INPUT(name) \ | ||
OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName) | ||
#define CHECK_INPUTS(name) \ | ||
OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName) | ||
#define CHECK_OUTPUT(name) \ | ||
OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName) | ||
#define CHECK_OUTPUTS(name) \ | ||
OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName) | ||
|
||
CHECK_INPUT(X); | ||
|
||
// attention | ||
CHECK_INPUTS(QKVW); | ||
CHECK_INPUTS(OutLinearW); | ||
|
||
if (ctx->HasInput("TimeStep")) { | ||
CHECK_INPUTS(CacheKV); | ||
} | ||
|
||
if (ctx->HasInputs("CacheKV")) { | ||
CHECK_OUTPUTS(CacheKVOut); | ||
} | ||
|
||
// ffn | ||
CHECK_INPUTS(FFN1Weight); | ||
CHECK_INPUTS(FFN2Weight); | ||
|
||
CHECK_OUTPUT(Out); | ||
|
||
// x: qkv's input [batch_size, seq_len, dim_embed] | ||
// y: qkv's weight: [3, num_head, dim_head, dim_embed] | ||
auto x_dim = ctx->GetInputDim("X"); | ||
auto y_dim = ctx->GetInputsDim("QKVW")[0]; | ||
//bool trans_qkvw = ctx->Attrs().Get<bool>("trans_qkvw"); | ||
PADDLE_ENFORCE_EQ( | ||
x_dim.size(), | ||
3, | ||
platform::errors::InvalidArgument("The dimensions of x must be 3" | ||
"(batch_size, seq_len, dim_embed)," | ||
"but received dimensions of" | ||
"Input is [%d]", | ||
x_dim.size())); | ||
PADDLE_ENFORCE_EQ(y_dim.size(), | ||
4, | ||
platform::errors::InvalidArgument( | ||
"The dimensions of qkv_weight must be 4" | ||
"(3, num_head, dim_head, dim_embed)," | ||
"but received dimensions of" | ||
"Input is [%d]", | ||
y_dim.size())); | ||
PADDLE_ENFORCE_EQ( | ||
x_dim[2], | ||
y_dim[3], | ||
platform::errors::InvalidArgument( | ||
"ShapeError: the dimension of x_dim[2] and y_dim[3]" | ||
"must be equal. But received: the shape " | ||
"of input x = [%s], and the shape of " | ||
"input qkv_weight = [%s]", | ||
x_dim, | ||
y_dim)); | ||
|
||
if (ctx->HasInputs("CacheKV")) { | ||
// [2, batch_size, num_head, max_seq_len, head_size] | ||
const auto &c_dims = ctx->GetInputsDim("CacheKV"); | ||
const auto &c_dim = c_dims[0]; | ||
|
||
PADDLE_ENFORCE_EQ( | ||
c_dim.size(), | ||
5, | ||
paddle::platform::errors::InvalidArgument( | ||
"The CacheKV must be 5 dims, but got %d", c_dim.size())); | ||
PADDLE_ENFORCE_EQ(c_dim[0], | ||
2, | ||
paddle::platform::errors::InvalidArgument( | ||
"The first dim of CacheKV must be 2, but got %d", | ||
c_dim[0])); // 2 | ||
PADDLE_ENFORCE_EQ(c_dim[2], | ||
y_dim[1], | ||
paddle::platform::errors::InvalidArgument( | ||
"The third dim of CacheKV must be equal with num " | ||
"head %d, but got %d", | ||
y_dim[1], | ||
c_dim[2])); // num_head | ||
PADDLE_ENFORCE_EQ(c_dim[4], | ||
y_dim[2], | ||
paddle::platform::errors::InvalidArgument( | ||
"The fifth dim of CacheKV must be equal with head " | ||
"size %d, but got %d", | ||
y_dim[2], | ||
c_dim[4])); // head_size | ||
} | ||
|
||
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
//auto *input_x = ctx.Input<phi::DenseTensor>("X"); | ||
//VLOG(0) << "input x type: " << (*input_x).dtype(); | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
|
||
framework::OpKernelType GetKernelTypeForVar( | ||
const std::string &var_name, | ||
const phi::DenseTensor &tensor, | ||
const framework::OpKernelType &expected_kernel_type) const override { | ||
if (var_name == "TimeStep") { | ||
VLOG(10) << "var_name:" << var_name << " need not to transform"; | ||
return expected_kernel_type; | ||
} | ||
return framework::OpKernelType( | ||
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); | ||
} | ||
}; | ||
|
||
class FusedMultiTransformerWeightOnlyOpMaker | ||
: public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("X", "The input tensor."); | ||
AddInput("LnScale", | ||
"Scale is a 1-dimensional tensor of size " | ||
"H. Here, H represents the last dimension of its input tensor.") | ||
.AsDuplicable(); | ||
AddInput("LnBias", | ||
"Bias is a 1-dimensional tensor of size " | ||
"H. Here, H represents the last dimension of its input tensor.") | ||
.AsDuplicable(); | ||
AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); | ||
AddInput("QKVWScale", "The qkv weight scale tensor.").AsDuplicable(); | ||
AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); | ||
AddInput("CacheKV", "(optional) The cached KV for generation inference.") | ||
.AsDispensable() | ||
.AsDuplicable(); | ||
AddInput("PreCaches", | ||
"(optional) The prefix caches for generation inference.") | ||
.AsDispensable() | ||
.AsDuplicable(); | ||
AddInput("RotaryPosEmb", | ||
"(optional) The RoPE embeddings for generation inference.") | ||
.AsDispensable(); | ||
AddInput("BeamCacheOffset", | ||
"(optional) The offset of CacheKV when using BeamSearch.") | ||
.AsDispensable(); | ||
AddInput("TimeStep", | ||
"(optional, int) The time step for generation inference.") | ||
.AsDispensable(); | ||
AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") | ||
.AsDispensable(); | ||
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") | ||
.AsDispensable(); | ||
AddInput("OutLinearW", "The out_linear weight tensor.") | ||
.AsDuplicable(); | ||
AddInput("OutLinearWScale", "The out_linear weight scale tensor.") | ||
.AsDuplicable(); | ||
AddInput("OutLinearBias", "The out_linear bias tensor.") | ||
.AsDispensable() | ||
.AsDuplicable(); | ||
AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") | ||
.AsDuplicable(); | ||
AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") | ||
.AsDuplicable(); | ||
AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op") | ||
.AsDuplicable(); | ||
AddInput("FFN1WeightScale", "The ffn1 weight scale tensor.") | ||
.AsDuplicable(); | ||
AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op") | ||
.AsDispensable() | ||
.AsDuplicable(); | ||
AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op") | ||
.AsDuplicable(); | ||
AddInput("FFN2WeightScale", "The ffn2 weight scale tensor.") | ||
.AsDuplicable(); | ||
AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op") | ||
.AsDispensable() | ||
.AsDuplicable(); | ||
AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") | ||
.AsDispensable() | ||
.AsDuplicable(); | ||
AddOutput("Out", "Result after multi ."); | ||
AddAttr<bool>("pre_layer_norm", | ||
"if true, the attention op uses pre_layer_norm architecure, " | ||
"else, uses post_layer_norm architecuture. " | ||
"[default true].") | ||
.SetDefault(true); | ||
AddAttr<int>("rotary_emb_dims", | ||
"the Attr(dims) for RotaryPosEmb's Computation [default 0].") | ||
.SetDefault(0) | ||
.AddCustomChecker([](const int &rotary_emb_dims) { | ||
PADDLE_ENFORCE_EQ( | ||
rotary_emb_dims >= 0 && rotary_emb_dims <= 2, | ||
true, | ||
platform::errors::InvalidArgument( | ||
"'rotary_emb_dims' in Op(Rotray) should be between" | ||
"0 and 2, But received [%s].", | ||
rotary_emb_dims)); | ||
}); | ||
AddAttr<float>("epsilon", | ||
"Constant for numerical stability [default 1e-5].") | ||
.SetDefault(1e-5) | ||
.AddCustomChecker([](const float &epsilon) { | ||
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, | ||
true, | ||
platform::errors::InvalidArgument( | ||
"'epsilon' in Op(LayerNorm) should be between" | ||
"0.0 and 0.001, But received [%s].", | ||
epsilon)); | ||
}); | ||
|
||
AddAttr<float>("dropout_rate", "Probability of setting units to zero.") | ||
.SetDefault(.5f) | ||
.AddCustomChecker([](const float &drop_p) { | ||
PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, | ||
true, | ||
platform::errors::InvalidArgument( | ||
"'dropout_rate' must be between 0.0 and 1.0.")); | ||
}); | ||
|
||
AddAttr<bool>("is_test", | ||
"(bool, default false) Set to true for inference only, false " | ||
"for training. Some layers may run faster when this is true.") | ||
.SetDefault(false); | ||
|
||
AddAttr<std::string>( | ||
"dropout_implementation", | ||
"[\"downgrade_in_infer\"|\"upscale_in_train\"]" | ||
"The meaning is the same as 'attn_dropout_implementation'.") | ||
.SetDefault("downgrade_in_infer") | ||
.AddCustomChecker([](const std::string &type) { | ||
PADDLE_ENFORCE_EQ( | ||
type == "downgrade_in_infer" || type == "upscale_in_train", | ||
true, | ||
platform::errors::InvalidArgument( | ||
"dropout_implementation can only be downgrade_in_infer or " | ||
"upscale_in_train")); | ||
}); | ||
|
||
AddAttr<std::string>("act_method", "act_method") | ||
.SetDefault("gelu") | ||
.AddCustomChecker([](const std::string &act_type) { | ||
PADDLE_ENFORCE_EQ( | ||
act_type == "gelu" || act_type == "geglu" || act_type == "relu" || act_type == "none", | ||
true, | ||
platform::errors::InvalidArgument( | ||
"Only support `gelu`, `geglu`, `relu`, `none` activation in " | ||
"FusedMultiTransformer. ")); | ||
}); | ||
|
||
AddAttr<std::string>("weight_dtype", "weight_dtype") | ||
.SetDefault("int8") | ||
.AddCustomChecker([](const std::string &weight_dtype) { | ||
PADDLE_ENFORCE_EQ(weight_dtype == "int8" || weight_dtype == "int4", | ||
true, | ||
platform::errors::InvalidArgument( | ||
"Only support `int8`, `int4` weight dtype in " | ||
"FusedMultiTransformer. ")); | ||
}); | ||
|
||
AddAttr<int>( | ||
"ring_id", | ||
"ring id for tensor model parallel. distributed training and inference") | ||
.SetDefault(-1); | ||
|
||
AddComment(R"DOC(fused multi transformer quant weight only layers op)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OPERATOR( | ||
fused_multi_transformer_weight_only, | ||
ops::FusedMultiTransformerWeightOnlyOp, | ||
ops::FusedMultiTransformerWeightOnlyOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); | ||
|
Oops, something went wrong.