remove commits, fix cuda11.4 support int8
humingqing authored and committed Nov 28, 2023
1 parent 8a9d061 commit 652fcb3
Showing 4 changed files with 78 additions and 102 deletions.
20 changes: 0 additions & 20 deletions paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu
@@ -406,26 +406,6 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel<T> {
seq_len,
max_seq_len,
dim_head);
// if (dev_ctx.GetPlace().GetDeviceId() == 0) {
// VLOG(0) << "layer=" << i << ", debug flash attn dims[" <<
// bsz << ", " << num_head << "," << seq_len << "," << dim_head
//<< "]"
// << "bs=" << qkv_out.dims()[0] << ", seq_len=" <<
//qkv_out.dims()[1]
// << ", max_seq_len=" << cache_kv_out->dims()[3];
// if (i == 0 || i == 1) {
// char szname[512] = {0};
// char szinname[512] = {0};
// snprintf(szinname, sizeof(szinname), "./%d_input.txt", i);
// snprintf(szname, sizeof(szname), "./%d_output.txt", i);
// if (access(szname, 0) != 0) {
// std::ofstream fout(szinname, std::ios::binary |
//std::ofstream::out | std::ofstream::app); fout << qkv_out;
// std::ofstream fout2(szname, std::ios::binary |
//std::ofstream::out | std::ofstream::app); fout2 << fmha_out;
// }
// }
// }
} else { // not generation
VLOG(0) << "not support!";
}
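The twenty deleted lines above were a commented-out debug dump: on device 0, layers 0 and 1 wrote their input and output tensors to `./<i>_input.txt` / `./<i>_output.txt`, guarded so each file was written only once. A minimal standalone sketch of that write-once dump pattern, using a hypothetical helper over a plain float buffer rather than Paddle tensors:

```cpp
#include <cstdio>
#include <fstream>
#include <unistd.h>  // access()

// Write a buffer to ./<layer>_output.txt once; skip if the file already
// exists, mirroring the access() guard in the removed block.
void DumpOnce(int layer, const float* data, size_t n) {
  char name[512] = {0};
  snprintf(name, sizeof(name), "./%d_output.txt", layer);
  if (access(name, 0) != 0) {  // only write the first time
    std::ofstream fout(name, std::ios::binary | std::ofstream::out);
    for (size_t i = 0; i < n; ++i) fout << data[i] << '\n';
  }
}
```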
2 changes: 1 addition & 1 deletion paddle/phi/backends/gpu/cuda/cuda_helper.h
@@ -88,7 +88,7 @@ cudaDataType_t ToCudaDataType() {
} else if (std::is_same<T, phi::dtype::bfloat16>::value) {
return CUDA_R_16BF;
#endif
#if CUDA_VERSION >= 11060
#if CUDA_VERSION >= 11040
} else if (std::is_same<T, int8_t>::value) {
return CUDA_R_8I;
} else if (std::is_same<T, int32_t>::value) {
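For context, the guard change is the whole fix here: `CUDA_R_8I` and `CUDA_R_32I` have been defined in `<library_types.h>` since well before CUDA 11.6, so lowering the floor to 11040 lets the int8/int32 mapping compile on CUDA 11.4. A compilable sketch of the mapping, assuming only the branches visible in the hunk plus a float default:

```cpp
#include <cuda.h>
#include <library_types.h>

#include <cstdint>
#include <type_traits>

template <typename T>
cudaDataType_t ToCudaDataTypeSketch() {
  if (std::is_same<T, float>::value) {
    return CUDA_R_32F;
#if CUDA_VERSION >= 11040  // was 11060 before this commit
  } else if (std::is_same<T, int8_t>::value) {
    return CUDA_R_8I;
  } else if (std::is_same<T, int32_t>::value) {
    return CUDA_R_32I;
#endif
  }
  return CUDA_R_32F;  // fallback for types not handled in this sketch
}
```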
128 changes: 62 additions & 66 deletions paddle/phi/kernels/fused_moe_kernel.h
@@ -14,33 +14,33 @@

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/collective/global_gather_op.h"
#include "paddle/fluid/operators/collective/global_scatter_op.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/top_k_kernel.h"
#include "paddle/phi/kernels/bmm_kernel.h"
#include "paddle/phi/kernels/cum_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/phi/kernels/scatter_kernel.h"
#include "paddle/fluid/operators/collective/global_scatter_op.h"
#include "paddle/fluid/operators/collective/global_gather_op.h"
#include "paddle/phi/kernels/bmm_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/number_count_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/scatter_kernel.h"
#include "paddle/phi/kernels/top_k_kernel.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
@@ -77,7 +77,7 @@ static void AllToAll(Tensor& tensor, // NOLINT
} else {
auto dtype = platform::ToNCCLDataType(
framework::TransToProtoVarType(tensor.dtype()));
int64_t send_numel = tensor.numel(); // send_numel
int64_t send_numel = tensor.numel(); // send_numel
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
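The elided body presumably splits the tensor evenly across ranks and exchanges the pieces with grouped point-to-point calls. A standalone sketch of that equal-split all-to-all, assuming `count` is divisible by `nranks` (the helper name and flat-buffer interface are illustrative, not the Paddle API):

```cpp
#include <cuda_runtime.h>
#include <nccl.h>

// Equal-split all-to-all via grouped ncclSend/ncclRecv: peer j receives this
// rank's j-th chunk and supplies the chunk stored at offset j in recv.
void AllToAllSketch(const void* send, void* recv, size_t count,
                    ncclDataType_t dtype, size_t elem_bytes, int nranks,
                    ncclComm_t comm, cudaStream_t stream) {
  const size_t chunk = count / nranks;  // elements exchanged with each peer
  const char* s = static_cast<const char*>(send);
  char* r = static_cast<char*>(recv);
  ncclGroupStart();  // batch all sends/recvs to avoid deadlock
  for (int peer = 0; peer < nranks; ++peer) {
    ncclSend(s + peer * chunk * elem_bytes, chunk, dtype, peer, comm, stream);
    ncclRecv(r + peer * chunk * elem_bytes, chunk, dtype, peer, comm, stream);
  }
  ncclGroupEnd();
}
```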
@@ -170,8 +170,7 @@ void GlobalScatterFunctor(const phi::GPUContext& ctx,
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
} else {
framework::TensorCopy(
*local_count, platform::CPUPlace(), &cpu_local_count);
framework::TensorCopy(*local_count, platform::CPUPlace(), &cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
}
auto global_count_len = 0;
@@ -277,8 +276,7 @@ void GlobalScatterProcessGroupFunctor(const phi::GPUContext& ctx,
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
} else {
framework::TensorCopy(
*local_count, platform::CPUPlace(), &cpu_local_count);
framework::TensorCopy(*local_count, platform::CPUPlace(), &cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
}
auto global_count_len = 0;
@@ -323,15 +321,13 @@ void GlobalScatterProcessGroupFunctor(const phi::GPUContext& ctx,
if (cpu_local_count_data[idx]) {
phi::DenseTensor tmp = *x;
pg->Send_Partial(tmp,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
}
if (cpu_global_count_data[idx]) {
pg->Recv_Partial(*out,
j,
recv_ptr * in_feat,
cpu_global_count_data[idx] * in_feat);
pg->Recv_Partial(
*out, j, recv_ptr * in_feat, cpu_global_count_data[idx] * in_feat);
recv_ptr += cpu_global_count_data[idx];
}
}
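The `expert_ptr[idx] * in_feat` and `cpu_local_count_data[idx] * in_feat` offsets above follow the usual exclusive-prefix-sum layout: each expert owns a contiguous run of rows, and scaling row offsets by the feature width gives the flat element offsets passed to `Send_Partial`/`Recv_Partial`. A minimal sketch of building that pointer table (hypothetical free function, simplified to one count per expert):

```cpp
#include <cstdint>
#include <vector>

// expert_ptr[i] = number of rows owned by experts 0..i-1, so expert i's rows
// occupy [expert_ptr[i], expert_ptr[i] + local_count[i]).
std::vector<int64_t> BuildExpertPtr(const std::vector<int64_t>& local_count) {
  std::vector<int64_t> expert_ptr(local_count.size(), 0);
  for (size_t i = 1; i < local_count.size(); ++i) {
    expert_ptr[i] = expert_ptr[i - 1] + local_count[i - 1];
  }
  return expert_ptr;
}
```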
@@ -373,8 +369,7 @@ void GlobalGatherFunctor(const phi::GPUContext& ctx,
cpu_local_count_data = local_count->data<int64_t>();
local_count_len = local_count->numel();
} else {
framework::TensorCopy(
*local_count, platform::CPUPlace(), &cpu_local_count);
framework::TensorCopy(*local_count, platform::CPUPlace(), &cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
local_count_len = cpu_local_count.numel();
}
@@ -483,8 +478,7 @@ void GlobalGatherProcessGroupFunctor(const phi::GPUContext& ctx,
cpu_local_count_data = local_count->data<int64_t>();
local_count_len = local_count->numel();
} else {
framework::TensorCopy(
*local_count, platform::CPUPlace(), &cpu_local_count);
framework::TensorCopy(*local_count, platform::CPUPlace(), &cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
local_count_len = cpu_local_count.numel();
}
@@ -533,9 +527,9 @@ void GlobalGatherProcessGroupFunctor(const phi::GPUContext& ctx,
}
if (cpu_local_count_data[idx]) {
pg->Recv_Partial(*out,
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
j,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
@@ -559,27 +553,29 @@ void GlobalGatherProcessGroupFunctor(const phi::GPUContext& ctx,

template <typename T>
void MatMulAndAddGelu(const phi::GPUContext& dev_ctx,
const framework::Tensor* weight,
const framework::Tensor* input,
const framework::Tensor* bias,
bool istransA,
bool istransB,
bool compute_bias,
framework::Tensor* output) {
const framework::Tensor* weight,
const framework::Tensor* input,
const framework::Tensor* bias,
bool istransA,
bool istransB,
bool compute_bias,
framework::Tensor* output) {
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11040)
phi::funcs::LinearWithCublasLt<T>::Run(
dev_ctx,
input, // x
weight, // y
output, // out
((compute_bias) ? static_cast<const void*>(bias->data<T>()): nullptr), // bias
nullptr,
input->dims()[0], // M bsz_seq
weight->dims()[1], // N output_size
input->dims()[1], // K input_size
istransA,
istransB,
((compute_bias) ? phi::funcs::MatmulFusedType::kMatmulBiasGelu: phi::funcs::MatmulFusedType::kMatmulGelu));
dev_ctx,
input, // x
weight, // y
output, // out
((compute_bias) ? static_cast<const void*>(bias->data<T>())
: nullptr), // bias
nullptr,
input->dims()[0], // M bsz_seq
weight->dims()[1], // N output_size
input->dims()[1], // K input_size
istransA,
istransB,
((compute_bias) ? phi::funcs::MatmulFusedType::kMatmulBiasGelu
: phi::funcs::MatmulFusedType::kMatmulGelu));
#endif
}
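Behind `kMatmulGelu` and `kMatmulBiasGelu`, cuBLASLt applies the activation (and optionally the bias) as a matmul epilogue, which lines up with the `CUDA_VERSION >= 11040` fence: the GELU epilogues appeared in the CUDA 11.3/11.4 era of cuBLASLt. A hedged sketch of the attribute the wrapper presumably sets; the cuBLASLt calls are real API, but the 1:1 mapping from `MatmulFusedType` to this attribute is an assumption:

```cpp
#include <cublasLt.h>

// Select a fused epilogue on a cuBLASLt matmul descriptor; GELU-with-bias
// folds both extra ops into the GEMM, saving separate elementwise kernels.
cublasStatus_t SetGeluEpilogue(cublasLtMatmulDesc_t desc, bool with_bias) {
  cublasLtEpilogue_t epilogue =
      with_bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
  return cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
}
```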

@@ -596,18 +592,18 @@ void MatMulAndAdd(const phi::GPUContext& dev_ctx,
#if (CUDA_VERSION >= 11040)
if (compute_bias) {
phi::funcs::LinearWithCublasLt<T>::Run(
dev_ctx,
input, // x
weight, // y
bias_out, // out
static_cast<const void*>(bias->data<T>()), // bias
nullptr,
input->dims()[0], // M bsz_seq
weight->dims()[1], // N output_size
input->dims()[1], // K input_size
istransA,
istransB,
phi::funcs::MatmulFusedType::kMatmulBias);
dev_ctx,
input, // x
weight, // y
bias_out, // out
static_cast<const void*>(bias->data<T>()), // bias
nullptr,
input->dims()[0], // M bsz_seq
weight->dims()[1], // N output_size
input->dims()[1], // K input_size
istransA,
istransB,
phi::funcs::MatmulFusedType::kMatmulBias);
return;
}
#endif
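The elided tail of `MatMulAndAdd` is the pre-11.4 path. Judging by similar Paddle helpers, it presumably falls back to a plain GEMM, with the bias added in a separate kernel when `compute_bias` is set. A sketch of that fallback under those assumptions (bias add omitted; `phi::funcs::GetBlas` and `Blas::GEMM` are existing Paddle utilities, but their use here is inferred, not taken from this diff):

```cpp
// Fallback sketch: row-major [M,K] x [K,N] GEMM via the Blas wrapper.
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
blas.GEMM(istransA ? CblasTrans : CblasNoTrans,
          istransB ? CblasTrans : CblasNoTrans,
          input->dims()[0],   // M: bsz_seq
          weight->dims()[1],  // N: output_size
          input->dims()[1],   // K: input_size
          static_cast<T>(1.0), input->data<T>(), weight->data<T>(),
          static_cast<T>(0.0), output->data<T>());
```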
30 changes: 15 additions & 15 deletions paddle/phi/kernels/gpu/fused_moe_kernel.cu
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gpu/fused_moe_kernel.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/gpu/fused_moe_kernel.cu.h"

namespace phi {
using Tensor = DenseTensor;
@@ -265,13 +265,13 @@ void FusedMoeKernel(const DeviceContext& dev_ctx,
// cuda 11.4
#if (CUDA_VERSION >= 11040)
MatMulAndAddGelu<T>(dev_ctx,
experts_weight1[idx],
&tmp_inp,
experts_bias1[idx],
false,
false,
false, // dont compute bias
&expert_out1);
experts_weight1[idx],
&tmp_inp,
experts_bias1[idx],
false,
false,
false, // dont compute bias
&expert_out1);
#else
// linear1 matmul
MatMulAndAdd<T>(dev_ctx,
Expand All @@ -283,10 +283,10 @@ void FusedMoeKernel(const DeviceContext& dev_ctx,
false, // dont compute bias
&expert_out1,
nullptr);

paddle::operators::FusedDropoutHelper<T, uint8_t>
fused_act_dropout_helper(
dev_ctx, cur_expert_count, dim_feedforward, dropout_param);
fused_act_dropout_helper(
dev_ctx, cur_expert_count, dim_feedforward, dropout_param);
// bias gelu
fused_act_dropout_helper.DropoutActBias(dev_ctx,
expert_out1.data<T>(),
@@ -390,11 +390,11 @@
// layer norm
if (!pre_layer_norm) {
auto* ln_mean_data = dev_ctx.template Alloc<U>(&ln_mean);
auto* ln_variance_data = dev_ctx.template Alloc<U>(&ln_variance);
auto* ln_out_data = dev_ctx.template Alloc<T>(&ln_out);
auto* ln_variance_data = dev_ctx.template Alloc<U>(&ln_variance);
auto* ln_out_data = dev_ctx.template Alloc<T>(&ln_out);

const U* ln_scale_ptr = ln_scale.data<U>();
const U* ln_bias_ptr = ln_bias.data<U>();
const U* ln_scale_ptr = ln_scale.data<U>();
const U* ln_bias_ptr = ln_bias.data<U>();
pre_layernorm_helper.LayerNorm(dev_ctx,
out->data<T>(),
ln_scale_ptr,
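Pulling the pieces together, the expert loop shown above toggles the first linear of each expert FFN on the CUDA toolkit version: on CUDA >= 11.4 the GEMM and GELU fuse into one cuBLASLt call, otherwise a plain GEMM runs first and `FusedDropoutHelper::DropoutActBias` applies bias + GELU afterwards. A condensed sketch using the two helpers from fused_moe_kernel.h, with signatures as shown in the hunks; the wrapper function itself is hypothetical:

```cpp
template <typename T>
void ExpertFirstLinear(const phi::GPUContext& dev_ctx,
                       const DenseTensor* weight1, const DenseTensor* input,
                       const DenseTensor* bias1, DenseTensor* expert_out1) {
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11040)
  // GEMM and GELU fused into a single cuBLASLt call; bias handled elsewhere.
  MatMulAndAddGelu<T>(dev_ctx, weight1, input, bias1,
                      /*istransA=*/false, /*istransB=*/false,
                      /*compute_bias=*/false, expert_out1);
#else
  // Plain GEMM; bias + GELU then run as a separate fused activation kernel.
  MatMulAndAdd<T>(dev_ctx, weight1, input, bias1,
                  /*istransA=*/false, /*istransB=*/false,
                  /*compute_bias=*/false, expert_out1, /*bias_out=*/nullptr);
#endif
}
```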
