diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index eb1e436807a77..64059475275eb 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1422,7 +1422,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     enable_cache_runtime_context_ = true;
   if (this->Type() == "fused_multi_transformer_int8" ||
       this->Type() == "fused_multi_transformer_moe_int8" ||
-      this->Type() == "fused_multi_transformer_moe_weight_only")
+      this->Type() == "fused_multi_transformer_moe_weight_only" ||
+      this->Type() == "fused_multi_transformer_weight_only")
     enable_cache_runtime_context_ = true;
   if (!all_kernels_must_compute_runtime_shape_ &&
       HasAttr(kAllKernelsMustComputeRuntimeShape))
diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 249473d44f698..1a907ef884d44 100755
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -26,6 +26,7 @@ register_operators(
   fused_multi_transformer_int8_op
   fused_multi_transformer_moe_op
   fused_multi_transformer_moe_weight_only_op
+  fused_multi_transformer_weight_only_op
   fused_multi_transformer_moe_int8_op
   fused_bias_dropout_residual_layer_norm_op
   resnet_unit_op
@@ -126,6 +127,7 @@ if(WITH_GPU OR WITH_ROCM)
     op_library(fused_multi_transformer_int8_op)
     op_library(fused_multi_transformer_moe_op)
     op_library(fused_multi_transformer_moe_weight_only_op)
+    op_library(fused_multi_transformer_weight_only_op)
     op_library(fused_multi_transformer_moe_int8_op)
     op_library(fused_bias_dropout_residual_layer_norm_op)
   endif()
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cc
new file mode 100644
index 0000000000000..4a4fb2c17b1c4
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cc
@@ -0,0 +1,311 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class FusedMultiTransformerWeightOnlyOp : public framework::OperatorWithKernel {
+ private:
+  static constexpr const char *OpName = "FusedMultiTransformerWeightOnlyOp";
+
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+#define CHECK_INPUT(name) \
+  OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName)
+#define CHECK_INPUTS(name) \
+  OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName)
+#define CHECK_OUTPUT(name) \
+  OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName)
+#define CHECK_OUTPUTS(name) \
+  OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName)
+
+    CHECK_INPUT(X);
+
+    // attention
+    CHECK_INPUTS(QKVW);
+    CHECK_INPUTS(OutLinearW);
+
+    if (ctx->HasInput("TimeStep")) {
+      CHECK_INPUTS(CacheKV);
+    }
+
+    if (ctx->HasInputs("CacheKV")) {
+      CHECK_OUTPUTS(CacheKVOut);
+    }
+
+    // ffn
+    CHECK_INPUTS(FFN1Weight);
+    CHECK_INPUTS(FFN2Weight);
+
+    CHECK_OUTPUT(Out);
+
+    // x: qkv's input [batch_size, seq_len, dim_embed]
+    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
+    auto x_dim = ctx->GetInputDim("X");
+    auto y_dim = ctx->GetInputsDim("QKVW")[0];
+    // bool trans_qkvw = ctx->Attrs().Get<bool>("trans_qkvw");
+    PADDLE_ENFORCE_EQ(
+        x_dim.size(),
+        3,
+        platform::errors::InvalidArgument("The dimensions of x must be 3 "
+                                          "(batch_size, seq_len, dim_embed), "
+                                          "but received dimensions of "
+                                          "Input is [%d]",
+                                          x_dim.size()));
+    PADDLE_ENFORCE_EQ(y_dim.size(),
+                      4,
+                      platform::errors::InvalidArgument(
+                          "The dimensions of qkv_weight must be 4 "
+                          "(3, num_head, dim_head, dim_embed), "
+                          "but received dimensions of "
+                          "Input is [%d]",
+                          y_dim.size()));
+    PADDLE_ENFORCE_EQ(
+        x_dim[2],
+        y_dim[3],
+        platform::errors::InvalidArgument(
+            "ShapeError: the dimension of x_dim[2] and y_dim[3] "
+            "must be equal. But received: the shape "
+            "of input x = [%s], and the shape of "
+            "input qkv_weight = [%s]",
+            x_dim,
+            y_dim));
+
+    if (ctx->HasInputs("CacheKV")) {
+      // [2, batch_size, num_head, max_seq_len, head_size]
+      const auto &c_dims = ctx->GetInputsDim("CacheKV");
+      const auto &c_dim = c_dims[0];
+
+      PADDLE_ENFORCE_EQ(
+          c_dim.size(),
+          5,
+          paddle::platform::errors::InvalidArgument(
+              "The CacheKV must be 5 dims, but got %d", c_dim.size()));
+      PADDLE_ENFORCE_EQ(c_dim[0],
+                        2,
+                        paddle::platform::errors::InvalidArgument(
+                            "The first dim of CacheKV must be 2, but got %d",
+                            c_dim[0]));  // 2
+      PADDLE_ENFORCE_EQ(c_dim[2],
+                        y_dim[1],
+                        paddle::platform::errors::InvalidArgument(
+                            "The third dim of CacheKV must be equal with num "
+                            "head %d, but got %d",
+                            y_dim[1],
+                            c_dim[2]));  // num_head
+      PADDLE_ENFORCE_EQ(c_dim[4],
+                        y_dim[2],
+                        paddle::platform::errors::InvalidArgument(
+                            "The fifth dim of CacheKV must be equal with head "
+                            "size %d, but got %d",
+                            y_dim[2],
+                            c_dim[4]));  // head_size
+    }
+
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    // auto *input_x = ctx.Input<phi::DenseTensor>("X");
+    // VLOG(0) << "input x type: " << (*input_x).dtype();
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name,
+      const phi::DenseTensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override {
+    if (var_name == "TimeStep") {
+      VLOG(10) << "var_name:" << var_name << " need not to transform";
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(
+        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
+  }
+};
+
+class FusedMultiTransformerWeightOnlyOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input tensor.");
+    AddInput("LnScale",
+             "Scale is a 1-dimensional tensor of size "
+             "H. Here, H represents the last dimension of its input tensor.")
+        .AsDuplicable();
+    AddInput("LnBias",
+             "Bias is a 1-dimensional tensor of size "
+             "H. Here, H represents the last dimension of its input tensor.")
+        .AsDuplicable();
+    AddInput("QKVW", "The qkv weight tensor.").AsDuplicable();
+    AddInput("QKVWScale", "The qkv weight scale tensor.").AsDuplicable();
+    AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable();
+    AddInput("CacheKV", "(optional) The cached KV for generation inference.")
+        .AsDispensable()
+        .AsDuplicable();
+    AddInput("PreCaches",
+             "(optional) The prefix caches for generation inference.")
+        .AsDispensable()
+        .AsDuplicable();
+    AddInput("RotaryPosEmb",
+             "(optional) The RoPE embeddings for generation inference.")
+        .AsDispensable();
+    AddInput("BeamCacheOffset",
+             "(optional) The offset of CacheKV when using BeamSearch.")
+        .AsDispensable();
+    AddInput("TimeStep",
+             "(optional, int) The time step for generation inference.")
+        .AsDispensable();
+    AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.")
+        .AsDispensable();
+    AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
+        .AsDispensable();
+    AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
+    AddInput("OutLinearWScale", "The out_linear weight scale tensor.")
+        .AsDuplicable();
+    AddInput("OutLinearBias", "The out_linear bias tensor.")
+        .AsDispensable()
+        .AsDuplicable();
+    AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op")
+        .AsDuplicable();
+    AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op")
+        .AsDuplicable();
+    AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op")
+        .AsDuplicable();
+    AddInput("FFN1WeightScale", "The ffn1 weight scale tensor.")
+        .AsDuplicable();
+    AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op")
+        .AsDispensable()
+        .AsDuplicable();
+    AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op")
+        .AsDuplicable();
+    AddInput("FFN2WeightScale", "The ffn2 weight scale tensor.")
+        .AsDuplicable();
+    AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op")
+        .AsDispensable()
+        .AsDuplicable();
+    AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV")
+        .AsDispensable()
+        .AsDuplicable();
+    AddOutput("Out", "Result after multi transformer layers.");
+    AddAttr<bool>("pre_layer_norm",
+                  "if true, the attention op uses pre_layer_norm architecture, "
+                  "else, uses post_layer_norm architecture. "
+                  "[default true].")
+        .SetDefault(true);
+    AddAttr<int>("rotary_emb_dims",
+                 "the Attr(dims) for RotaryPosEmb's Computation [default 0].")
+        .SetDefault(0)
+        .AddCustomChecker([](const int &rotary_emb_dims) {
+          PADDLE_ENFORCE_EQ(
+              rotary_emb_dims >= 0 && rotary_emb_dims <= 2,
+              true,
+              platform::errors::InvalidArgument(
+                  "'rotary_emb_dims' in Op(Rotary) should be between "
+                  "0 and 2, but received [%d].",
+                  rotary_emb_dims));
+        });
+    AddAttr<float>("epsilon",
+                   "Constant for numerical stability [default 1e-5].")
+        .SetDefault(1e-5)
+        .AddCustomChecker([](const float &epsilon) {
+          PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
+                            true,
+                            platform::errors::InvalidArgument(
+                                "'epsilon' in Op(LayerNorm) should be between "
+                                "0.0 and 0.001, but received [%f].",
+                                epsilon));
+        });
+
+    AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
+        .SetDefault(.5f)
+        .AddCustomChecker([](const float &drop_p) {
+          PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f,
+                            true,
+                            platform::errors::InvalidArgument(
+                                "'dropout_rate' must be between 0.0 and 1.0."));
+        });
+
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
+
+    AddAttr<std::string>(
+        "dropout_implementation",
+        "[\"downgrade_in_infer\"|\"upscale_in_train\"] "
+        "The meaning is the same as 'attn_dropout_implementation'.")
+        .SetDefault("downgrade_in_infer")
+        .AddCustomChecker([](const std::string &type) {
+          PADDLE_ENFORCE_EQ(
+              type == "downgrade_in_infer" || type == "upscale_in_train",
+              true,
+              platform::errors::InvalidArgument(
+                  "dropout_implementation can only be downgrade_in_infer or "
+                  "upscale_in_train"));
+        });
+
+    AddAttr<std::string>("act_method", "act_method")
+        .SetDefault("gelu")
+        .AddCustomChecker([](const std::string &act_type) {
+          PADDLE_ENFORCE_EQ(
+              act_type == "gelu" || act_type == "geglu" ||
+                  act_type == "relu" || act_type == "none",
+              true,
+              platform::errors::InvalidArgument(
+                  "Only support `gelu`, `geglu`, `relu`, `none` activation in "
+                  "FusedMultiTransformer."));
+        });
+
+    AddAttr<std::string>("weight_dtype", "weight_dtype")
+        .SetDefault("int8")
+        .AddCustomChecker([](const std::string &weight_dtype) {
+          PADDLE_ENFORCE_EQ(weight_dtype == "int8" || weight_dtype == "int4",
+                            true,
+                            platform::errors::InvalidArgument(
+                                "Only support `int8`, `int4` weight dtype in "
+                                "FusedMultiTransformer."));
+        });
+
+    AddAttr<int>("ring_id",
+                 "ring id for tensor model parallel "
+                 "(distributed training and inference)")
+        .SetDefault(-1);
+
+    AddComment(
+        R"DOC(fused multi transformer weight-only quantization layers op)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    fused_multi_transformer_weight_only,
+    ops::FusedMultiTransformerWeightOnlyOp,
+    ops::FusedMultiTransformerWeightOnlyOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cu
new file mode 100644
index 0000000000000..52659b60b0d81
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_weight_only_op.cu
@@ -0,0 +1,845 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" + +namespace paddle { +namespace operators { + +template +static void PrintMatrix(const T* mat_d, int num, std::string name, int i) { + std::vector tmp(num); + cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); + + std::ofstream outfile; + outfile.open(name+".txt", std::ios::app); + std::stringstream ss; + + ss << "begin print " << i << " th layer:" << std::endl; + for (int i = 0; i < num; ++i) { + ss << tmp[i] << " "; + } + ss << std::endl; + outfile << ss.str(); + outfile.close(); +} + + + +template +class FusedMultiTransformerWeightOnlyOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto &dev_ctx = ctx.cuda_device_context(); + + auto *time_step = ctx.Input("TimeStep"); + // 0. input + auto *input_x = ctx.Input("X"); + const auto input_x_dims = input_x->dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + LOG(INFO) << "intput X: bsz: " << bsz << ", seq_len: " << seq_len << ", dim_embed: " << dim_embed; + const std::string act_method = ctx.Attr("act_method"); + const std::string none_act = "none"; + bool use_glu = (act_method == "geglu"); + bool remove_padding = false; + auto *sequence_lengths = ctx.Input("SeqLengths"); + if (sequence_lengths) { + remove_padding = true; + } + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } + // LOG(INFO) << "beam_size: " << beam_size; + phi::DenseTensor d_token_tensor; + phi::DenseTensor padding_offset_tensor; + phi::DenseTensor x_remove_padding; + bool encoder_remove_padding = (remove_padding && !time_step); + LOG(INFO) << "remove padding: " << encoder_remove_padding; + int token_num = 0; + + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + // Init out + if (encoder_remove_padding) { + InitValue(dev_ctx, from_data, out->numel(), static_cast(0.)); + } + + // remove padding in encoder + if (encoder_remove_padding) { + // just for encoder + d_token_tensor.Resize({{1}}); + auto *d_token_num = dev_ctx.Alloc( + &d_token_tensor, d_token_tensor.numel() * sizeof(int)); + // alloc the max size of padding_offset_tensor + padding_offset_tensor.Resize({{bsz_seq}}); + dev_ctx.Alloc(&padding_offset_tensor, + padding_offset_tensor.numel() * sizeof(int)); + InvokeGetPaddingOffset(dev_ctx, + &token_num, + d_token_num, + padding_offset_tensor.data(), + sequence_lengths->data(), + bsz, + seq_len); + padding_offset_tensor.Resize({{token_num}}); + x_remove_padding.Resize({{token_num, dim_embed}}); + dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); + InvokeRemovePadding(dev_ctx, + x_remove_padding.data(), + input_x->data(), + padding_offset_tensor.data(), + token_num, + dim_embed); + } else { + token_num = bsz_seq; + } + + if (token_num == 0) { + return; + } + + auto *padding_offset_data = + encoder_remove_padding ? padding_offset_tensor.data() : nullptr; + // whether do weight only quant + + // 1. 
layer norm + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); + + auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + phi::DenseTensor ln_mean, ln_var; + ln_mean.Resize({{token_num}}); + auto *ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({{token_num}}); + auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_scales = ctx.MultiInput("QKVWScale"); + auto qkv_biases = ctx.MultiInput("QKVBias"); + const std::string weight_dtype = ctx.Attr("weight_dtype"); + //const bool trans_qkvw = ctx.Attr("trans_qkvw"); + const auto qkv_w_dims = qkv_weights[0]->dims(); + //int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + //int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + int num_head = qkv_w_dims[1]; + int dim_head = qkv_w_dims[2]; + int hidden_size = num_head * dim_head; + LOG(INFO) << "num head: " << num_head << ", dim head: " << dim_head << ", hidden size:" << hidden_size; + int output_size = 3 * hidden_size; + int qkv_output_size = 3 * hidden_size; + int input_size = dim_embed; + //weight only gemm + auto weight_only_gemm = + AttnMatMulWeightOnly(dev_ctx, (weight_dtype == "int4")); + int default_act = weight_only_gemm.GetActivation("none"); + int ffn_act = weight_only_gemm.GetActivation(act_method); + + bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we + // set compute_bias as false. + const bool trans_qkvw = true; + auto qkv_compute = AttnMatMul(dev_ctx, + false, + trans_qkvw, + token_num, + output_size, + input_size, + /*compute_bias=*/false); + phi::DenseTensor qkv_out; + qkv_out.Resize({{token_num, 3, num_head, dim_head}}); + auto *qkv_out_data = + dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 2.1 rotary + auto *rotary_tensor = ctx.Input("RotaryPosEmb"); + const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); + + // 3. 
fmha + AttnDropoutParam attn_param( + true, "upscale_in_train", 0.0, true, true, 0, nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + + int cache_offset = 0; + + int time_step_cpu = 0; + if (time_step) { + // VLOG(0) << "time_step: " << *time_step; + time_step_cpu = src_mask->dims()[3] - 1; + // VLOG(0) << "time_step_cpu: " << time_step_cpu; + } + + auto out_seq_len = seq_len; + if (time_step) { + PADDLE_ENFORCE_GT(time_step_cpu, + 0, + platform::errors::PreconditionNotMet( + "The value of time_step must > 0, but now is %d", + time_step_cpu)); + PADDLE_ENFORCE_EQ( + seq_len, + 1, + platform::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + out_seq_len += time_step_cpu; + } else { + out_seq_len += cache_offset; + } + + phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; + q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *q_transpose_out_data = + dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + + kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); + auto *kv_transpose_out_data = dev_ctx.Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + + if (encoder_remove_padding) { + InitValue(dev_ctx, + q_transpose_out_data, + q_transpose_out.numel(), + static_cast(0.)); + InitValue(dev_ctx, + kv_transpose_out_data, + kv_transpose_out.numel(), + static_cast(0.)); + } + + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + + phi::DenseTensor src_mask_out; + + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + phi::DenseTensor pre_cache_kv_out; + + phi::DenseTensor softmax_out; + phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; + phi::DenseTensor qktv_out, fmha_out; + softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *softmax_out_data = + dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *qktv_out_data = + dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + auto *fmha_out_data = + dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // 4. out_linear + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_scales = ctx.MultiInput("OutLinearWScale"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + int ring_id = ctx.Attr("ring_id"); + // (transA, transB, compute_bias) = (false, false, false) + auto out_linear_compute = AttnMatMul( + dev_ctx, false, false, token_num, dim_embed, hidden_size, false); + + // 5. 
ln(residual + bias) + DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; + if (pre_layer_norm) { + bias_dropout_residual_out.Resize({{token_num, dim_embed}}); + bias_dropout_residual_out_data = + dev_ctx.Alloc(&bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); + } + uint8_t *dropout_mask_out_data = nullptr; + + // 6. ffn matmul1 + auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + auto ffn1_weights_scales = + ctx.MultiInput("FFN1WeightScale"); + auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + auto ffn1_weight_dim = ffn1_weights[0]->dims(); + + int dim_ffn = ffn1_weight_dim[0]; + //int dim_ffn = ffn1_weight_dim[1]; + FFNGluHelper ffn1_glu_helper( + dev_ctx, act_method, token_num, dim_ffn / 2, dim_ffn, dim_embed); + auto ffn1_linear_compute = AttnMatMul( + dev_ctx, false, false, token_num, dim_ffn, dim_embed, false); + phi::DenseTensor ffn1_out; + ffn1_out.Resize({{token_num, dim_ffn}}); + auto *ffn1_out_data = + dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); + + // 7. ffn act + bias + DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); + phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; + int tmp_dim_ffn = dim_ffn; + if (use_glu) tmp_dim_ffn /= 2; + int8_t *ffn1_dropout_mask_data = nullptr; + ffn1_dropout_out.Resize({{token_num, tmp_dim_ffn}}); + auto *ffn1_dropout_out_data = dev_ctx.Alloc( + &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); + + // 8. ffn2 matmul + auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + auto ffn2_weights_scales = ctx.MultiInput("FFN2WeightScale"); + auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + auto ffn2_linear_compute = AttnMatMul( + dev_ctx, false, false, token_num, dim_embed, tmp_dim_ffn, false); + + // 9. ffn2 residual bias + DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); + + phi::DenseTensor tmp_out, tmp_out_rm_padding; + tmp_out.Resize({{token_num, dim_embed}}); + //dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); + if (encoder_remove_padding) { + tmp_out_rm_padding.Resize({{token_num, dim_embed}}); + auto *tmp_out_rm_padding_data = dev_ctx.Alloc( + &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); + } + auto *tmp_out_data = + dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); + + const T *x_data; + if (encoder_remove_padding) { + x_data = x_remove_padding.data(); + } else { + x_data = input_x->data(); + } + phi::DenseTensor *buf0 = nullptr; + phi::DenseTensor *buf1 = nullptr; + + // step0: x --> buf1 + // step1: buf1 --> buf0 + // step2: buf0 --> buf1 + int layers = qkv_weights.size(); + if (encoder_remove_padding) { + // In the case of variable lengths, the padding needs to be rebuilt + // eventually. So buf0 and buf1 do not need to be changed according to the + // pre_layer_norm and the number of layers. 
+ buf0 = &tmp_out; + buf1 = &tmp_out_rm_padding; + } else { + if (pre_layer_norm) { + //buf1 = &tmp_out; + //buf0 = out; + //buf0->Resize({{token_num, dim_embed}}); + if (layers & 1) { + // odd, set buf1 as out + buf0 = &tmp_out; + buf1 = out; + } else { + // even, set buf0 as out + buf0 = out; + buf1 = &tmp_out; + } + } else { + buf0 = &tmp_out; + buf1 = out; + } + } + + for (int i = 0; i < layers; ++i) { + // step1. layer_norm + if (i == 0 && pre_layer_norm) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // TODO(wangxi): can remove mean var in inference + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } + + // step2. qkv + const phi::DenseTensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; + if (!pre_layer_norm && i == 0) { + const phi::DenseTensor *tmp_input_x = + (encoder_remove_padding) ? &x_remove_padding : input_x; + weight_only_gemm.Linear( + *tmp_input_x, + *qkv_weights[i], + bias, + *qkv_scales[i], + token_num, + qkv_output_size, + dim_embed, + default_act, + &qkv_out); + //qkv_compute.ComputeForward( + // qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); + } else { + //qkv_compute.ComputeForward( + // qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); + VLOG(0) << "layer id=" << i << ", qkv input=" << buf1->dims() + << ", weight=" << qkv_weights[i]->dims() + << ", scale=" << qkv_scales[i]->dims() + << ", output=" << qkv_out.dims(); + VLOG(0) << "token num=" << token_num << ", output size=" << qkv_output_size + << ", dim_embed=" << dim_embed; + weight_only_gemm.Linear( + *buf1, + *qkv_weights[i], + bias, + *qkv_scales[i], + token_num, + qkv_output_size, + dim_embed, + default_act, + &qkv_out); + } + + // step3. fmha + const phi::DenseTensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask, + sequence_lengths, + rotary_tensor, + beam_cache_offset, + cache_kv_out, + &fmha_out, + bsz, + beam_size, + max_seq_len, + num_head, + dim_head, + time_step_cpu, + rotary_emb_dims, + 1. / sqrt(dim_head)); + } else if (cache_kv_out) { // generation context stage + const phi::DenseTensor *pre_cache_kv_tensor = nullptr; + phi::DenseTensor *pre_cache_kv_out_tmp = nullptr; + phi::DenseTensor *src_mask_tmp = nullptr; + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? 
&padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out.data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; + } else { + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + const T *q_ptr = q_transpose_out_data; + k_ptr = kv_transpose_out_data; + v_ptr = k_ptr + k_size; + } + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + const int seq_len_tmp = seq_len + cache_offset; + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + sequence_lengths_data, + bsz, + num_head, + seq_len_tmp, + max_seq_len, + dim_head); + } else { // not generation + // TODO(wangxi): can remove dropout in inference + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? 
&padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3"; +#endif + VLOG(0) << "layer id=" << i << ", out linear input=" << fmha_out.dims() + << ", weight=" << out_linear_weights[i]->dims() + << ", scale=" << out_linear_scales[i]->dims() + << ", out linear out: " << buf1->dims(); + VLOG(0) << "token num=" << token_num << ", dim embed=" << dim_embed + << ", hidden size=" << hidden_size; + //PrintMatrix(fmha_out_data, bsz*seq_len*num_head*dim_head, "fmha_out", i); + if (pre_layer_norm) { + //out_linear_compute.ComputeForward( + // out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); + weight_only_gemm.Linear(fmha_out, + *out_linear_weights[i], + nullptr, + *out_linear_scales[i], + token_num, + dim_embed, + hidden_size, + default_act, + buf1); + //PrintMatrix(buf1->data(), token_num * dim_embed, "out_linear_output", i); + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } else { + //out_linear_compute.ComputeForward( + // out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); + weight_only_gemm.Linear(fmha_out, + *out_linear_weights[i], + nullptr, + *out_linear_scales[i], + token_num, + dim_embed, + hidden_size, + default_act, + buf0); + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif + + // step5. ln(residual + dropout(input + bias)) + if (pre_layer_norm) { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + + // inplace + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf1->data(), + x_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + bias_dropout_residual_out_data, + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } else { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + auto *residual_data = (i == 0 ? x_data : buf1->data()); + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0->data(), + residual_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step5"; +#endif + + // step6. ffn matmul1 + /** + if (use_glu) { + ffn1_glu_helper.Compute(buf1, + ffn1_weights[i], + ffn1_biases[i], + &ffn1_out, + &ffn1_dropout_out); + } else { + ffn1_linear_compute.ComputeForward( + ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); + } + **/ + + VLOG(0) << "layer id=" << i << ", ffn1 input=" << buf1->dims() + << ", weight=" << ffn1_weights[i]->dims() + << ", scale=" << ffn1_weights_scales[i]->dims() + << ", ffn1 out: " << (ffn1_out).dims(); + VLOG(0) << "token num=" << token_num << ", dim ffn=" << dim_ffn + << ", dim_embed=" << dim_embed; + weight_only_gemm.Linear(*buf1, + *ffn1_weights[i], + nullptr, + *ffn1_weights_scales[i], + token_num, + dim_ffn, + dim_embed, + default_act, + &ffn1_out); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step6"; +#endif + + // step7. 
act bias + // TODO(wangxi): remove dropout mask in inference + if (!use_glu) { + fused_act_dropout_helper.DropoutActBias(dev_ctx, + ffn1_out_data, + ffn1_biases[i]->data(), + act_method, + ffn1_dropout_out_data, + ffn1_dropout_mask_data); + } + // step8. ffn2 matmul + if (pre_layer_norm) { + //ffn2_linear_compute.ComputeForward( + // ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr); + VLOG(0) << "layer id=" << i << ", ffn2 input=" << ffn1_dropout_out.dims() + << ", weight=" << ffn2_weights[i]->dims() + << ", scale=" << ffn2_weights_scales[i]->dims() + << ", ffn2 out: " << buf1->dims(); + VLOG(0) << "token num=" << token_num << ", dim embed=" << dim_embed + << ", dim_ffn=" << dim_ffn; + weight_only_gemm.Linear(ffn1_dropout_out, + *ffn2_weights[i], + nullptr, + *ffn2_weights_scales[i], + token_num, + dim_embed, + dim_ffn, + default_act, + buf1); + } else { + //ffn2_linear_compute.ComputeForward( + // ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr); + weight_only_gemm.Linear(ffn1_dropout_out, + *ffn2_weights[i], + nullptr, + *ffn2_weights_scales[i], + token_num, + dim_embed, + dim_ffn, + default_act, + buf0); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step8.0"; +#endif + + if (pre_layer_norm) { + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } else { + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step8.1"; +#endif + + // step9. residual bias + if (pre_layer_norm) { + // TODO(wangxi): remove dropout mask in inference + if (i < layers - 1) { + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); + ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + dev_ctx, + buf1->data(), + bias_dropout_residual_out_data, + ffn2_biases[i]->data(), + ln_scale_data, + ln_bias_data, + buf1->data(), + dropout_mask_out_data, + buf0->data(), + ln_mean_data, + ln_var_data); + } else { + ffn2_fused_dropout_helper.ResidualDropoutBias( + dev_ctx, + buf1->data(), + bias_dropout_residual_out_data, + ffn2_biases[i]->data(), + buf1->data(), + dropout_mask_out_data); + } + } else { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0->data(), + buf1->data(), + ffn2_biases[i]->data(), + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step9"; +#endif + if (pre_layer_norm) { + x_data = buf1->data(); + std::swap(buf0, buf1); + } + } + + if (encoder_remove_padding) { + if (pre_layer_norm) { + InvokeRebuildPadding(dev_ctx, + from_data, + buf0->data(), + padding_offset_data, + token_num, + dim_embed); + } else { + InvokeRebuildPadding(dev_ctx, + from_data, + buf1->data(), + padding_offset_data, + token_num, + dim_embed); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_weight_only, + ops::FusedMultiTransformerWeightOnlyOpKernel); diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 4334c5ad5544a..8e3b1848d9fb4 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -73,6 +73,29 @@ std::map> op_ins_map = { "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, + {"fused_multi_transformer_weight_only", 
+ {"X", + "LnScale", + "LnBias", + "QKVW", + "QKVWScale", + "QKVBias", + "CacheKV", + "BeamCacheOffset", + "TimeStep", + "SeqLengths", + "SrcMask", + "OutLinearW", + "OutLinearWScale", + "OutLinearBias", + "FFNLnScale", + "FFNLnBias", + "FFN1Weight", + "FFN1WeightScale", + "FFN1Bias", + "FFN2Weight", + "FFN2WeightScale", + "FFN2Bias"}}, {"fused_multi_transformer_moe", {"X", "LnScale", @@ -425,6 +448,7 @@ std::map> op_outs_map = { "Beta2PowOut", "MasterParamOut"}}, {"fused_multi_transformer", {"CacheKVOut", "Out"}}, + {"fused_multi_transformer_weight_only", {"CacheKVOut", "Out"}}, {"fused_multi_transformer_moe", {"CacheKVOut", "Out"}}, {"fused_multi_transformer_moe_weight_only", {"CacheKVOut", "Out"}}, {"fused_multi_transformer_int8", {"CacheKVOut", "Out"}}, @@ -533,6 +557,7 @@ std::map> op_passing_outs_map = { {"split", {"Out"}}, {"concat", {"Out"}}, {"fused_multi_transformer", {"CacheKVOut"}}, + {"fused_multi_transformer_weight_only", {"CacheKVOut"}}, {"fused_multi_transformer_moe", {"CacheKVOut"}}, {"fused_multi_transformer_moe_weight_only", {"CacheKVOut"}}, {"fused_multi_transformer_int8", {"CacheKVOut"}}, diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index 83ce42ccc5bfc..0cad08e11ccad 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -23,6 +23,7 @@ from .layer.fused_transformer import FusedMultiTransformerMoe # noqa: F401 from .layer.fused_transformer import FusedMultiTransformerMoeINT8 # noqa: F401 from .layer.fused_transformer import FusedMultiTransformerMoeWeightOnly +from .layer.fused_transformer import FusedMultiTransformerWeightOnly # noqa: F401 __all__ = [ #noqa 'FusedMultiHeadAttention', @@ -36,4 +37,5 @@ 'FusedBiasDropoutResidualLayerNorm', 'FusedMoELayer', 'FusedMultiTransformerMoeWeightOnly', + 'FusedMultiTransformerWeightOnly', ] diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index b4ccb531a6474..8878c4e29b4de 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1220,7 +1220,8 @@ def __init__( ) self.normalize_before = normalize_before - self._dtype = self._helper.get_default_dtype() + self._dtype = "float16" + #self._dtype = self._helper.get_default_dtype() self._epsilon = epsilon self._trans_qkvw = trans_qkvw self._ring_id = ring_id @@ -1461,6 +1462,343 @@ def trans_to_fp16(l): trans_to_fp16(self.ffn2_biases) self._dtype = dtype +class FusedMultiTransformerWeightOnly(Layer): + """ + FusedMultiTransfor on weight quant + """ + def __init__( + self, + embed_dim, + num_heads, + dim_feedforward, + dropout_rate=0.0, + activation="gelu", + weight_dtype="int8", + normalize_before=True, + ln_scale_attrs=None, + ln_bias_attrs=None, + qkv_weight_attrs=None, + qkv_scale_attrs=None, + qkv_bias_attrs=None, + linear_weight_attrs=None, + linear_scale_attrs=None, + linear_bias_attrs=None, + ffn_ln_scale_attrs=None, + ffn_ln_bias_attrs=None, + ffn1_weight_attrs=None, + ffn1_scale_attrs=None, + ffn1_bias_attrs=None, + ffn2_weight_attrs=None, + ffn2_scale_attrs=None, + ffn2_bias_attrs=None, + epsilon=1e-5, + num_layers=-1, + nranks=1, + ring_id=-1, + name=None, + dy_to_st=False, + ): + super(FusedMultiTransformerWeightOnly, self).__init__() + + assert embed_dim > 0, ( + "Expected embed_dim to be greater than 0, " + "but received {}".format(embed_dim) + ) + assert ( + num_heads > 0 + ), "Expected nhead to be greater than 0, " "but received {}".format( + 
num_heads + ) + assert ( + dim_feedforward > 0 + ), "Expected dim_feedforward to be greater than 0, but received {}".format( + dim_feedforward + ) + + self.normalize_before = normalize_before + #self._dtype = self._helper.get_default_dtype() + self._dtype = "float16" + self._epsilon = epsilon + self._ring_id = ring_id + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + # tensor model parallel + if nranks > 1: + assert ring_id != -1 + assert num_heads % nranks == 0 + assert dim_feedforward % nranks == 0 + num_heads = num_heads // nranks + #dim_feedforward = dim_feedforward // nranks + self._dim_feedforward = dim_feedforward + self._weight_dtype = weight_dtype + + if isinstance(qkv_weight_attrs, (list, tuple, ParameterList)): + num_layers = len(qkv_weight_attrs) + assert num_layers > 0 + + self.ln_scales, self.ln_biases = ParameterList(), ParameterList() + self.qkv_weights, self.qkv_scales, self.qkv_biases = ParameterList(), ParameterList(), ParameterList() + #self.qkv_weights, self.qkv_biases = ParameterList(), ParameterList() + self.linear_weights, self.linear_scales, self.linear_biases = ParameterList(), ParameterList(), ParameterList() + #self.linear_weights, self.linear_biases = ParameterList(), ParameterList() + self.ffn_ln_scales, self.ffn_ln_biases = ParameterList(), ParameterList() + self.ffn1_weights, self.ffn1_scales, self.ffn1_biases = ParameterList(), ParameterList(), ParameterList() + #self.ffn1_weights, self.ffn1_biases = ParameterList(), ParameterList() + self.ffn2_weights, self.ffn2_scales, self.ffn2_biases = ParameterList(), ParameterList(), ParameterList() + #self.ffn2_weights, self.ffn2_biases = ParameterList(), ParameterList() + def get_attr(attrs, idx): + if isinstance(attrs, (list, tuple, ParameterList)): + assert len(attrs) == num_layers + return attrs[idx] + return attrs + weight_int8 = False if self._weight_dtype == "int4" else True + print(f"_weight_dtype: {self._weight_dtype}, weight_int8: {weight_int8}") + + for i in range(num_layers): + ln_scale_attr = get_attr(ln_scale_attrs, i) + ln_bias_attr = get_attr(ln_bias_attrs, i) + qkv_weight_attr = get_attr(qkv_weight_attrs, i) + qkv_scale_attr = get_attr(qkv_scale_attrs, i) + qkv_bias_attr = get_attr(qkv_bias_attrs, i) + linear_weight_attr = get_attr(linear_weight_attrs, i) + linear_scale_attr = get_attr(linear_scale_attrs, i) + linear_bias_attr = get_attr(linear_bias_attrs, i) + + ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i) + ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i) + ffn1_weight_attr = get_attr(ffn1_weight_attrs, i) + ffn1_scale_attr = get_attr(ffn1_scale_attrs, i) + ffn1_bias_attr = get_attr(ffn1_bias_attrs, i) + ffn2_weight_attr = get_attr(ffn2_weight_attrs, i) + ffn2_scale_attr = get_attr(ffn2_scale_attrs, i) + ffn2_bias_attr = get_attr(ffn2_bias_attrs, i) + + ln_scale = self.create_parameter( + attr=ln_scale_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0), + dtype="float32", + ) + ln_bias = self.create_parameter( + attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32" + ) + qkv_weight = self.create_parameter( + shape=[3, num_heads, self.head_dim, embed_dim], + attr=qkv_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + qkv_scale = self.create_parameter( + shape=[int(3 * num_heads * self.head_dim)], + attr=qkv_scale_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=Constant(value=1.0), + ) 
+ qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=qkv_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + ''' + linear_weight = self.create_parameter( + shape=[int(num_heads * self.head_dim), embed_dim], + attr=linear_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + ''' + linear_weight = self.create_parameter( + shape=[embed_dim if weight_int8 else int(embed_dim / 2), + int(num_heads * self.head_dim)], + attr=linear_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + linear_scale = self.create_parameter( + shape=[embed_dim], + attr=linear_scale_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=Constant(1.0), + ) + linear_bias = self.create_parameter( + shape=[embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + ffn_ln_scale = self.create_parameter( + shape=[embed_dim], + attr=ffn_ln_scale_attr, + is_bias=False, + default_initializer=Constant(value=1.0), + dtype="float32", + ) + ffn_ln_bias = self.create_parameter( + shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True, dtype="float32" + ) + ''' + ffn1_weight = self.create_parameter( + shape=[embed_dim, dim_feedforward], + attr=ffn1_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + ''' + ffn1_weight = self.create_parameter( + shape=[dim_feedforward, embed_dim], + attr=ffn1_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + ffn1_scale = self.create_parameter( + shape=[dim_feedforward], + attr=ffn1_scale_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=Constant(value=1.0), + ) + ffn1_bias = self.create_parameter( + shape=[dim_feedforward], + attr=ffn1_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + ''' + ffn2_weight = self.create_parameter( + shape=[dim_feedforward, embed_dim], + attr=ffn2_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + ''' + ffn2_weight = self.create_parameter( + shape=[embed_dim, dim_feedforward], + attr=ffn2_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + ffn2_scale = self.create_parameter( + shape=[embed_dim], + attr=ffn2_scale_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=Constant(value=1.0), + ) + ffn2_bias = self.create_parameter( + shape=[embed_dim], + attr=ffn2_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + + # tensor model parallel + if nranks > 1: + # column parallel + _set_var_distributed(qkv_weight) + _set_var_distributed(qkv_bias) + _set_var_distributed(ffn1_weight) + _set_var_distributed(ffn1_bias) + # row parallel + _set_var_distributed(linear_weight) + _set_var_distributed(ffn2_weight) + + self.ln_scales.append(ln_scale) + self.ln_biases.append(ln_bias) + self.qkv_weights.append(qkv_weight) + self.qkv_scales.append(qkv_scale) + self.qkv_biases.append(qkv_bias) + self.linear_weights.append(linear_weight) + self.linear_scales.append(linear_scale) + self.linear_biases.append(linear_bias) + + self.ffn_ln_scales.append(ffn_ln_scale) + self.ffn_ln_biases.append(ffn_ln_bias) + self.ffn1_weights.append(ffn1_weight) + self.ffn1_scales.append(ffn1_scale) + self.ffn1_biases.append(ffn1_bias) + self.ffn2_weights.append(ffn2_weight) + self.ffn2_scales.append(ffn2_scale) + self.ffn2_biases.append(ffn2_bias) + + self.dropout_rate = dropout_rate + self.activation = activation + self.name = name + #trans weight to int8 + self._int8_decorate() + + + def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None): + """ + Applies multi transformer weight only layers on the input. 
+ """ + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer_weight_only( + src, + list(self.ln_scales), + list(self.ln_biases), + list(self.qkv_weights), + list(self.qkv_scales), + list(self.qkv_biases), + caches, + beam_offset, + time_step, + seq_lens, + attn_mask, + list(self.linear_weights), + list(self.linear_scales), + list(self.linear_biases), + list(self.ffn_ln_scales), + list(self.ffn_ln_biases), + list(self.ffn1_weights), + list(self.ffn1_scales), + list(self.ffn1_biases), + list(self.ffn2_weights), + list(self.ffn2_scales), + list(self.ffn2_biases), + caches, + 'pre_layer_norm', + self.normalize_before, + 'epsilon', + self._epsilon, + 'dropout_rate', + self.dropout_rate, + 'is_test', + not self.training, + 'dropout_implementation', + 'upscale_in_train', + 'act_method', + self.activation, + 'weight_dtype', + self._weight_dtype, + 'ring_id', + self._ring_id + ) + if caches is not None: + return final_out, cache_kv_out + return final_out + + def _int8_decorate(self): + # tmp fix for amp.decorator(O2) + def trans_to_int8(l): + for param in l: + if param is not None: + with no_grad(): + param_applied = _to_dtype(param, "int8") + trans_to_int8(self.qkv_weights) + trans_to_int8(self.linear_weights) + trans_to_int8(self.ffn1_weights) + trans_to_int8(self.ffn2_weights) + self._dtype = "int8" + + class FusedMultiTransformerINT8(Layer): def __init__(self, @@ -1655,7 +1993,7 @@ def get_attr(attrs, idx): shape=[embed_dim], attr=ffn2_out_scales_attr, dtype="float32", - is_bias=False) + is_bias=False) # tensor model parallel if nranks > 1: @@ -1698,51 +2036,51 @@ def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=N forward """ cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer_int8( - src, - list(self.ln_scales), - list(self.ln_biases), - list(self.qkv_weights), - list(self.qkv_biases), + src, + list(self.ln_scales), + list(self.ln_biases), + list(self.qkv_weights), + list(self.qkv_biases), caches, beam_offset, - time_step, - seq_lens, - attn_mask, - list(self.linear_weights), - list(self.linear_biases), + time_step, + seq_lens, + attn_mask, + list(self.linear_weights), + list(self.linear_biases), list(self.ffn_ln_scales), - list(self.ffn_ln_biases), - list(self.ffn1_weights), - list(self.ffn1_biases), - list(self.ffn2_weights), + list(self.ffn_ln_biases), + list(self.ffn1_weights), + list(self.ffn1_biases), + list(self.ffn2_weights), list(self.ffn2_biases), - list(self.qkv_out_scales), - list(self.out_linear_out_scales), + list(self.qkv_out_scales), + list(self.out_linear_out_scales), list(self.ffn1_out_scales), - list(self.ffn2_out_scales), - caches, - 'qkv_in_scale', + list(self.ffn2_out_scales), + caches, + 'qkv_in_scale', self.qkv_in_scale, - 'out_linear_in_scale', - self.out_linear_in_scale, + 'out_linear_in_scale', + self.out_linear_in_scale, 'ffn1_in_scale', - self.ffn1_in_scale, + self.ffn1_in_scale, 'ffn2_in_scale', - self.ffn2_in_scale, + self.ffn2_in_scale, 'pre_layer_norm', - self.normalize_before, - 'epsilon', - self._epsilon, - 'dropout_rate', + self.normalize_before, + 'epsilon', + self._epsilon, + 'dropout_rate', self.dropout_rate, - 'is_test', - not self.training, - 'dropout_implementation', + 'is_test', + not self.training, + 'dropout_implementation', 'upscale_in_train', - 'act_method', - self.activation, - 'trans_qkvw', - self._trans_qkvw, + 'act_method', + self.activation, + 'trans_qkvw', + self._trans_qkvw, 'ring_id', self._ring_id) @@ -1784,7 +2122,7 @@ class FusedMoELayer(Layer): # dim_feedforward = 128 fused_moe_layer 
= FusedMoELayer(128, 128, 4, 2) output = fused_moe_layer(input) # [2, 4, 128] - + """ def __init__(self, @@ -1837,8 +2175,8 @@ def __init__(self, is_bias=False ) self.gate_bias = self.create_parameter( - shape=[num_expert * self.world_size], - attr=None, + shape=[num_expert * self.world_size], + attr=None, dtype=self._dtype, is_bias=True ) @@ -1890,7 +2228,7 @@ def get_attr(attrs, idx): self.linear2_weights[i].name = "expert_" + self.linear2_weights[i].name self.linear1_biases[i].name = "expert_" + self.linear1_biases[i].name self.linear2_biases[i].name = "expert_" + self.linear2_biases[i].name - + def forward(self, inp): bsz = inp.shape[0] seq_len = inp.shape[1] @@ -1915,7 +2253,7 @@ def forward(self, inp): self.approximate, ) return out - + def _amp_decorate(self, dtype): # tmp fix for amp.decorator(O2) def trans_to_fp16(l): @@ -2108,8 +2446,8 @@ def get_attr(attrs, idx): is_bias=False ) gate_bias = self.create_parameter( - shape=[num_expert * self.world_size], - attr=gate_bias_attr, + shape=[num_expert * self.world_size], + attr=gate_bias_attr, dtype=self._dtype, is_bias=True ) @@ -2139,7 +2477,7 @@ def get_attr(attrs, idx): expert_bias1_attr = get_attr(expert_bias1_attrs, i * num_expert + j) expert_weight2_attr = get_attr(expert_weight2_attrs, i * num_expert + j) expert_bias2_attr = get_attr(expert_bias2_attrs, i * num_expert + j) - + expert_weight1 = self.create_parameter( shape=[d_model, dim_feedforward] if not gemm_cutlass else [dim_feedforward, d_model], attr=expert_weight1_attr, @@ -2517,8 +2855,8 @@ def get_attr(attrs, idx): is_bias=False ) gate_bias = self.create_parameter( - shape=[num_expert * self.world_size], - attr=gate_bias_attr, + shape=[num_expert * self.world_size], + attr=gate_bias_attr, dtype=self._dtype, is_bias=True ) @@ -2550,7 +2888,7 @@ def get_attr(attrs, idx): expert_bias1_attr = get_attr(expert_bias1_attrs, i * num_expert + j) expert_weight2_attr = get_attr(expert_weight2_attrs, i * num_expert + j) expert_bias2_attr = get_attr(expert_bias2_attrs, i * num_expert + j) - + expert_weight1 = self.create_parameter( # shape=[d_model, dim_feedforward], shape=[dim_feedforward, d_model], @@ -2694,8 +3032,8 @@ def trans_to_int8(l): trans_to_int8(self.expert_weights1) trans_to_int8(self.expert_weights2) self._dtype = "int8" - - + + class FusedMultiTransformerMoeWeightOnly(Layer): """ FusedMultiTransformerMoe @@ -2776,7 +3114,7 @@ def __init__( self._epsilon = epsilon self._ring_id = ring_id self._weight_dtype = weight_dtype - + self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads @@ -2806,7 +3144,7 @@ def get_attr(attrs, idx): assert len(attrs) == num_layers return attrs[idx] return attrs - + weight_int8 = False if self._weight_dtype == "int4" else True print(f"_weight_dtype: {self._weight_dtype}, weight_int8: {weight_int8}") for i in range(num_layers): @@ -2834,9 +3172,9 @@ def get_attr(attrs, idx): attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32" ) qkv_weight = self.create_parameter( - shape=[3, - num_heads, - self.head_dim if weight_int8 else int(self.head_dim / 2), + shape=[3, + num_heads, + self.head_dim if weight_int8 else int(self.head_dim / 2), embed_dim], attr=qkv_weight_attr, dtype="uint8", @@ -2857,7 +3195,7 @@ def get_attr(attrs, idx): is_bias=True, ) linear_weight = self.create_parameter( - shape=[embed_dim if weight_int8 else int(embed_dim / 2), + shape=[embed_dim if weight_int8 else int(embed_dim / 2), int(num_heads * self.head_dim)], attr=linear_weight_attr, dtype="uint8", @@ -2895,8 +3233,8 @@ def 
get_attr(attrs, idx): is_bias=False ) gate_bias = self.create_parameter( - shape=[num_expert * self.world_size], - attr=gate_bias_attr, + shape=[num_expert * self.world_size], + attr=gate_bias_attr, dtype=self._dtype, is_bias=True ) @@ -2930,9 +3268,9 @@ def get_attr(attrs, idx): expert_weight2_attr = get_attr(expert_weight2_attrs, i * num_expert + j) expert_scale2_attr = get_attr(expert_scale2_attrs, i * num_expert + j) expert_bias2_attr = get_attr(expert_bias2_attrs, i * num_expert + j) - + expert_weight1 = self.create_parameter( - shape=[dim_feedforward if weight_int8 else int(dim_feedforward / 2), + shape=[dim_feedforward if weight_int8 else int(dim_feedforward / 2), d_model], attr=expert_weight1_attr, dtype="uint8", @@ -2954,7 +3292,7 @@ def get_attr(attrs, idx): default_initializer=nn.initializer.Constant(value=0.0) ) expert_weight2 = self.create_parameter( - shape=[d_model if weight_int8 else int(d_model / 2), + shape=[d_model if weight_int8 else int(d_model / 2), dim_feedforward], attr=expert_weight2_attr, dtype="uint8", @@ -3058,7 +3396,7 @@ def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=N if caches is not None: return final_out, cache_kv_out return final_out - + def _int8_decorate(self): # tmp fix for INT8 def trans_to_int8(l): @@ -3133,6 +3471,4 @@ def shard_tensor(dst_tensor, parent_tensor, pos): self.shared_weights2.append(shared_weight2) self.shared_scales2.append(shared_scale2) - self.shared_biases2.append(shared_bias2) - - + self.shared_biases2.append(shared_bias2) \ No newline at end of file
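
Usage sketch (reviewer note, not part of the patch): based on the constructor and forward()
signature added to python/paddle/incubate/nn/layer/fused_transformer.py above, the new layer
could be driven roughly as below. This is a minimal sketch assuming a CUDA device and a Paddle
build that includes this patch; the randomly initialized weights and default unit scales are
placeholders, so the output is numerically meaningless. Real use requires offline-quantized
int8/int4 weights and their per-channel scale tensors.

    import paddle
    from paddle.incubate.nn import FusedMultiTransformerWeightOnly

    embed_dim, num_heads, ffn_dim, num_layers = 1024, 16, 4096, 2
    # num_layers must be given explicitly when no per-layer ParamAttr lists are passed.
    layer = FusedMultiTransformerWeightOnly(
        embed_dim,
        num_heads,
        ffn_dim,
        activation="gelu",
        weight_dtype="int8",  # or "int4"
        num_layers=num_layers,
    )
    layer.eval()

    # [bsz, seq_len, dim_embed] input and an additive attention mask [bsz, 1, seq_len, seq_len].
    src = paddle.rand([2, 128, embed_dim]).astype("float16")
    attn_mask = paddle.zeros([2, 1, 128, 128], dtype="float16")
    # Encoder-style pass: no caches, so only Out is returned; pass caches/time_step for decoding.
    out = layer(src, attn_mask=attn_mask)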