Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#5 from tianyan01/v2.4.2
Browse files Browse the repository at this point in the history
V2.4.2
  • Loading branch information
laipaang committed Nov 23, 2023
2 parents 5063c04 + 241c6d5 commit aed70d9
Show file tree
Hide file tree
Showing 33 changed files with 7,362 additions and 827 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// result of HasAttr.
if (!enable_cache_runtime_context_ && HasAttr(kEnableCacheRuntimeContext))
enable_cache_runtime_context_ = true;
if (this->Type() == "fused_multi_transformer_int8" || this->Type() == "fused_multi_transformer_moe_int8")
enable_cache_runtime_context_ = true;
if (!all_kernels_must_compute_runtime_shape_ &&
HasAttr(kAllKernelsMustComputeRuntimeShape))
all_kernels_must_compute_runtime_shape_ = true;
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ register_operators(
fused_feedforward_op
fused_multi_transformer_op
fused_multi_transformer_int8_op
fused_multi_transformer_moe_op
fused_multi_transformer_moe_int8_op
fused_bias_dropout_residual_layer_norm_op
resnet_unit_op
fused_gemm_epilogue_op
Expand Down Expand Up @@ -121,6 +123,8 @@ if(WITH_GPU OR WITH_ROCM)
op_library(fused_attention_op)
op_library(fused_multi_transformer_op)
op_library(fused_multi_transformer_int8_op)
op_library(fused_multi_transformer_moe_op)
op_library(fused_multi_transformer_moe_int8_op)
op_library(fused_bias_dropout_residual_layer_norm_op)
endif()
# resnet_unit needs cudnn 8.0 above
Expand Down
141 changes: 78 additions & 63 deletions paddle/fluid/operators/fused/attn_gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,41 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using phi::backends::gpu::GpuLaunchConfig;

template <typename T>
class AttnMatmulINT8 {
public:
AttnMatmulINT8(
const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias)
: dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) {
auto helper = std::make_shared<CublasLtHelper>(m, k, n);
helpers_.emplace_back(helper);
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
helper_ = std::make_unique<CublasLtHelper<int32_t>>(m, k, n, lt_handle);
gpu_config_ = std::make_unique<GpuLaunchConfig>(
phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, m * n, DequantKernelVecSize));
}
~AttnMatmulINT8() {}

// This function is used to execute GEMM, with input and output's types are
// both T.
void ComputeForward(const framework::Tensor* weight,
const framework::Tensor* input,
framework::Tensor* input_tmp,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* output_tmp,
framework::Tensor* bias_out,
void ComputeForward(const phi::DenseTensor* weight,
const phi::DenseTensor* input,
phi::DenseTensor* input_tmp,
const phi::DenseTensor* bias,
phi::DenseTensor* output,
phi::DenseTensor* output_tmp,
phi::DenseTensor* bias_out,
const float quant_in_scale,
const framework::Tensor* dequant_out_scale,
const int quant_out_scale_offset,
const phi::DenseTensor* dequant_out_scale,
phi::DenseTensor* workspace = nullptr,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
Expand All @@ -64,24 +68,26 @@ class AttnMatmulINT8 {
quant_min_bound,
dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());
helper_->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream(),
(void*)workspace->data<int8_t>(),
workspace->numel());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);
dequant_out_scale->data<float>());

if (compute_bias_) {
// bias_out = output + bias
std::vector<const framework::Tensor*> ins = {output, bias};
std::vector<framework::Tensor*> outs = {bias_out};
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
Expand All @@ -95,66 +101,72 @@ class AttnMatmulINT8 {

// This function is used to execute GEMM, with input and output's types are
// both INT8.
void ComputeForwardINT8ToINT8(const framework::Tensor* weight,
framework::Tensor* input,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* bias_out) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
void ComputeForwardINT8ToINT8(const phi::DenseTensor* weight,
phi::DenseTensor* input,
const phi::DenseTensor* bias,
phi::DenseTensor* output,
phi::DenseTensor* bias_out,
phi::DenseTensor* workspace = nullptr) {
helper_->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream(),
(void*)workspace->data<int8_t>(),
workspace->numel());
}

// This function is used to execute GEMM, with input and output's types are
// INT8 and T.
void ComputeForwardINT8ToT(const framework::Tensor* weight,
void ComputeForwardINT8ToT(const phi::DenseTensor* weight,
const float quant_in_scale,
framework::Tensor* input,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* output_tmp,
framework::Tensor* bias_out,
const framework::Tensor* dequant_out_scale,
const int quant_out_scale_offset) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());
phi::DenseTensor* input,
const phi::DenseTensor* bias,
phi::DenseTensor* output,
phi::DenseTensor* output_tmp,
phi::DenseTensor* bias_out,
const phi::DenseTensor* dequant_out_scale,
phi::DenseTensor* workspace = nullptr) {
helper_->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream(),
(void*)workspace->data<int8_t>(),
workspace->numel());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
gpu_config_.get(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);
dequant_out_scale->data<float>());

if (compute_bias_) {
// bias_out = output + bias
std::vector<const framework::Tensor*> ins = {output, bias};
std::vector<framework::Tensor*> outs = {bias_out};
std::vector<const phi::DenseTensor*> ins = {output, bias};
std::vector<phi::DenseTensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"));
// PADDLE_ENFORCE_EQ(cudaGetLastError(),
// cudaSuccess,
// platform::errors::Fatal(
// "cuda error occured after computing bias. "
// "But it does not mean this error is caused by "
// "bias computing"));
}
}

// This function is used to execute GEMM, with input and output's types are T
// and INT8.
void ComputeForwardTToINT8(const framework::Tensor* weight,
void ComputeForwardTToINT8(const phi::DenseTensor* weight,
const float quant_in_scale,
const framework::Tensor* input,
framework::Tensor* input_tmp,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* bias_out,
const phi::DenseTensor* input,
phi::DenseTensor* input_tmp,
const phi::DenseTensor* bias,
phi::DenseTensor* output,
phi::DenseTensor* bias_out,
phi::DenseTensor* workspace = nullptr,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
Expand All @@ -168,10 +180,12 @@ class AttnMatmulINT8 {
quant_min_bound,
dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
helper_->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream(),
(void*)workspace->data<int8_t>(),
workspace->numel());
}

private:
Expand All @@ -182,8 +196,9 @@ class AttnMatmulINT8 {
int k_; // k

int compute_bias_;
std::vector<std::shared_ptr<CublasLtHelper>> helpers_;
std::unique_ptr<CublasLtHelper<int32_t>> helper_;
std::unique_ptr<GpuLaunchConfig> gpu_config_;
};

} // namespace operators
} // namespace paddle
} // namespace paddle
Loading

0 comments on commit aed70d9

Please sign in to comment.