From 0631322374b70411a5d2dfb4ab2030435894cd45 Mon Sep 17 00:00:00 2001 From: tianyan01 Date: Mon, 14 Aug 2023 15:21:53 +0800 Subject: [PATCH 1/9] modify for dynamic zeus --- paddle/fluid/operators/fused/fmha_ref.h | 4 + paddle/phi/api/yaml/ops.yaml | 9 + paddle/phi/infermeta/multiary.cc | 51 + paddle/phi/infermeta/multiary.h | 22 + .../phi/kernels/fusion/beam_search_softmax.h | 47 + .../kernels/fusion/gpu/beam_search_softmax.cu | 925 ++++++++++++++++++ python/paddle/__init__.py | 2 + python/paddle/fluid/dygraph/amp/auto_cast.py | 7 +- .../nn/functional/fused_transformer.py | 87 +- .../incubate/nn/layer/fused_transformer.py | 59 +- python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/manipulation.py | 2 +- python/paddle/tensor/search.py | 90 ++ 13 files changed, 1251 insertions(+), 56 deletions(-) create mode 100644 paddle/phi/kernels/fusion/beam_search_softmax.h create mode 100644 paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 7de59dd9ee2e3..863ab8cba964b 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -115,6 +115,10 @@ class FMHARef { // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor); out_seq_len = cache_kv_out_tensor->dims()[3]; + } else { + if (cache_kv_out_tensor) { + *cache_kv_out_tensor = transpose_2_out_tensor->Slice(1, 3); + } } int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 10e617bd91243..98c09f81574ec 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -208,3 +208,12 @@ kernel : func : flip backward : flip_grad + +- op : beam_search_softmax + args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop) + output : Tensor(ids_this_time), Tensor(out_cum_scores), Tensor(cache_ids), Tensor(beam_offsets), Tensor(parent_idx), Tensor(stop_flags_out), Tensor(seq_lens_out), Tensor(step_ids_out) + infer_meta : + func : BeamSearchSoftmaxInferMeta + kernel : + func : beam_search_softmax + data_type : logits diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 375b88493a92b..7d8ef2a39379d 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -674,6 +674,57 @@ void BatchNormInferInferMeta(const MetaTensor& x, config); } +void BeamSearchSoftmaxInferMeta(const MetaTensor& logits, + const MetaTensor& cum_scores, + const MetaTensor& sequence_lengths, + const MetaTensor& stop_flags, + const MetaTensor& end_ids, + const MetaTensor& step_ids, + const MetaTensor& last_cache_ids, + const MetaTensor& last_beam_offsets, + int beam_size, + int max_seq_len, + int max_dec_len, + bool fuse_softmax, + bool early_stop, + MetaTensor* ids_this_time, + MetaTensor* out_cum_scores, + MetaTensor* cache_ids, + MetaTensor* beam_offsets, + MetaTensor* parent_idx, + MetaTensor* stop_flags_out, + MetaTensor* seq_lens_out, + MetaTensor* step_ids_out) { + auto logits_dims = logits.dims(); + auto logits_dtype = logits.dtype(); + int bbm = logits_dims[0]; + + ids_this_time->set_dims({bbm, 1}); + ids_this_time->share_lod(cum_scores); + ids_this_time->set_dtype(DataType::INT32); + cache_ids->set_dims({bbm, max_dec_len}); + cache_ids->share_lod(cum_scores); + cache_ids->set_dtype(DataType::INT32); + beam_offsets->set_dims({1, bbm, max_seq_len + max_dec_len}); + beam_offsets->share_lod(cum_scores); + beam_offsets->set_dtype(DataType::INT32); + parent_idx->set_dims({bbm}); + parent_idx->share_lod(cum_scores); + parent_idx->set_dtype(DataType::INT32); + out_cum_scores->set_dims({bbm}); + out_cum_scores->share_lod(cum_scores); + out_cum_scores->set_dtype(logits_dtype); + stop_flags_out->set_dims({bbm, 1}); + stop_flags_out->share_lod(cum_scores); + stop_flags_out->set_dtype(DataType::BOOL); + seq_lens_out->set_dims({bbm, 1}); + seq_lens_out->share_lod(cum_scores); + seq_lens_out->set_dtype(DataType::INT32); + step_ids_out->set_dims({bbm, 1}); + step_ids_out->share_lod(cum_scores); + step_ids_out->set_dtype(DataType::INT32); +} + void BilinearTensorProductInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 8c601182e8fc8..6673d66b0125d 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -190,6 +190,28 @@ void BatchNormInferInferMeta(const MetaTensor& x, MetaTensor* variance_out, MetaConfig config = MetaConfig()); +void BeamSearchSoftmaxInferMeta(const MetaTensor& logits, + const MetaTensor& cum_scores, + const MetaTensor& sequence_lengths, + const MetaTensor& stop_flags, + const MetaTensor& end_ids, + const MetaTensor& step_ids, + const MetaTensor& last_cache_ids, + const MetaTensor& last_beam_offsets, + int beam_size, + int max_seq_len, + int max_dec_len, + bool fuse_softmax, + bool early_stop, + MetaTensor* ids_this_time, + MetaTensor* out_cum_scores, + MetaTensor* cache_ids, + MetaTensor* beam_offsets, + MetaTensor* parent_idx, + MetaTensor* stop_flags_out, + MetaTensor* seq_lens_out, + MetaTensor* step_ids_out); + void BilinearTensorProductInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/kernels/fusion/beam_search_softmax.h b/paddle/phi/kernels/fusion/beam_search_softmax.h new file mode 100644 index 0000000000000..985515a0a39f3 --- /dev/null +++ b/paddle/phi/kernels/fusion/beam_search_softmax.h @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace fusion { + +template +void BeamSearchSoftmaxKernel(const Context &dev_ctx, + const DenseTensor &logits, + const DenseTensor &cum_scores, + const DenseTensor &sequence_lengths, + const DenseTensor &stop_flags, + const DenseTensor &end_ids, + const DenseTensor &step_ids, + const DenseTensor &last_cache_ids, + const DenseTensor &last_beam_offsets, + int beam_size, + int max_seq_len, + int max_dec_len, + bool fuse_softmax, + bool early_stop, + DenseTensor *ids_this_time, + DenseTensor *out_cum_scores, + DenseTensor *cache_ids, + DenseTensor *beam_offsets, + DenseTensor *parent_idx, + DenseTensor *stop_flags_out, + DenseTensor *seq_lens_out, + DenseTensor *step_ids_out); + +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu new file mode 100644 index 0000000000000..5652adfd1c50b --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu @@ -0,0 +1,925 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/fusion/beam_search_softmax.h" + +namespace phi { +namespace fusion { + +#define FLT_MAX 1e38 +// #define DEBUG_BEAM_SEARCH_SOFTMAX + +#define CASE_K(K) \ + case K: \ + invokeTopKSoftMaxLauncher(dev_ctx, \ + log_probs, \ + stop_flags, \ + sequence_lengths, \ + cum_log_probs, \ + step_ids, \ + last_cache_ids, \ + last_beam_offsets, \ + end_ids, \ + out_cum_log_probs, \ + stop_flags_out, \ + seq_lens_out, \ + step_ids_out, \ + ids, \ + tmp_ids, \ + tmp_vals, \ + parent_idx, \ + cache_ids, \ + beam_offsets, \ + batch_size, \ + beam_size, \ + vocab_size, \ + max_seq_len, \ + max_dec_len, \ + fuse_softmax, \ + early_stop, \ + stream); \ + break + +struct __align__(8) DySoftMaxStruct { + float logit; + float score; +}; + +__device__ __forceinline__ DySoftMaxStruct +reduce_softmax_op(DySoftMaxStruct a, DySoftMaxStruct b) { + bool a_bigger = (a.logit > b.logit); + DySoftMaxStruct bigger_m = a_bigger ? a : b; + DySoftMaxStruct smaller_m = a_bigger ? b : a; + DySoftMaxStruct res; + res.score = bigger_m.score + + smaller_m.score * exp(smaller_m.logit - bigger_m.logit); + res.logit = bigger_m.logit; + return res; +} + +template +struct TopK { + int ids[K]; + T vals[K]; + int parent_ids[K]; + + __device__ __forceinline__ void insert(T elem, int elem_id) { + if (elem > vals[K - 1] || (ids[K - 1] == -1) || + ((elem == vals[K - 1]) && (elem_id < ids[K - 1]))) { + vals[K - 1] = elem; + ids[K - 1] = elem_id; + } + + for (int k = K - 2; k >= 0; --k) { + if ((vals[k + 1] > vals[k]) || (ids[k] == -1) || + ((vals[k + 1] == vals[k]) && (ids[k + 1] < ids[k]))) { + T tmp_val = vals[k]; + int tmp_id = ids[k]; + vals[k] = vals[k + 1]; + ids[k] = ids[k + 1]; + vals[k + 1] = tmp_val; + ids[k + 1] = tmp_id; + } + } + } + + __device__ __forceinline__ void insert(T elem, int elem_id, int parent_id) { + if (elem > vals[K - 1] || (ids[K - 1] == -1) || + ((elem == vals[K - 1]) && (elem_id < ids[K - 1]))) { + vals[K - 1] = elem; + ids[K - 1] = elem_id; + parent_ids[K - 1] = parent_id; + } + + for (int k = K - 2; k >= 0; --k) { + if ((vals[k + 1] > vals[k]) || (ids[k] == -1) || + ((vals[k + 1] == vals[k]) && (ids[k + 1] < ids[k]))) { + T tmp_val = vals[k]; + int tmp_id = ids[k]; + int parent_id2 = parent_ids[k]; + vals[k] = vals[k + 1]; + ids[k] = ids[k + 1]; + parent_ids[k] = parent_ids[k + 1]; + vals[k + 1] = tmp_val; + ids[k + 1] = tmp_id; + parent_ids[k + 1] = parent_id2; + } + } + } +}; + +template +__device__ __forceinline__ TopK reduce_topk_op(const TopK &a, + const TopK &b) { + TopK res = a; + for (int i = 0; i < K; ++i) res.insert(b.vals[i], b.ids[i]); + return res; +} + +template +struct TopKSoftMax { + DySoftMaxStruct softmax_md; + TopK topk; +}; + +template +__device__ __forceinline__ TopKSoftMax reduce_topk_softmax_op( + const TopKSoftMax &a, const TopKSoftMax &b) { + TopKSoftMax res; + res.softmax_md = reduce_softmax_op(a.softmax_md, b.softmax_md); + res.topk = reduce_topk_op(a.topk, b.topk); + return res; +} + +template +__global__ void batch_topk(const int *topk_tmp_id_buf, + const T *topk_tmp_val_buf, + const int *step_ids, + const bool *stop_flags, // bs * beam_size + const int *seq_lens, + const int *end_ids, + int *id_buf, + T *val_buf, + int *parent_idx, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out) { + int thread_id = threadIdx.x; + int block_id = blockIdx.x; // bs + const int beam_size = K / 2; + TopK partial; + if (thread_id == 0) { + for (int i = 0; i < beam_size; ++i) { + partial.ids[i] = -1; + partial.vals[i] = -FLT_MAX; + partial.parent_ids[i] = -1; + } + + int index = block_id * beam_size * K; + if (step_ids[0] == 0) { + for (int i = 0; i < K; i++) { + partial.insert( + (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i], i / K); + } + } else { + for (int i = 0; i < beam_size * K; i++) { + partial.insert( + (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i], i / K); + } + } + index = block_id * beam_size; + for (int i = 0; i < beam_size; i++) { + id_buf[index + i] = partial.ids[i]; + val_buf[index + i] = partial.vals[i]; + int parent_id = partial.parent_ids[i]; + parent_idx[index + i] = parent_id; + stop_flags_out[index + i] = stop_flags[index + parent_id]; + seq_lens_out[index + i] = seq_lens[index + parent_id]; + step_ids_out[index + i] = step_ids[index + parent_id]; +#ifdef DEBUG_BEAM_SEARCH_SOFTMAX + printf("bi: %d, id: %d, val: %f, parent_id: %d\n", block_id, + id_buf[index+i], val_buf[index+i], parent_id); +#endif + } + } +} + +template +__global__ void batch_topk(const int *topk_tmp_id_buf, + const T *topk_tmp_val_buf, + const float *cum_log_probs, + const int *step_ids, + const bool *stop_flags, // bs * beam_size + const int *seq_lens, + const int *end_ids, + int *id_buf, + T *val_buf, + int *parent_idx, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out) { + int thread_id = threadIdx.x; + int block_id = blockIdx.x; // bs + const int beam_size = K / 2; + TopK partial; + if (thread_id == 0) { + for (int i = 0; i < beam_size; ++i) { + partial.ids[i] = -1; + partial.vals[i] = -FLT_MAX; + partial.parent_ids[i] = -1; + } + + int index = block_id * beam_size * K; + if (step_ids[0] == 0) { + for (int i = 0; i < K; i++) { + partial.insert( + (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i], i / K); + } + } else { + for (int i = 0; i < beam_size * K; i++) { + if (!stop_flags[block_id * beam_size + i / K]) { + // if stop, this branch end, no longer update. + partial.insert( + (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i], i / K); + } + } + } + index = block_id * beam_size; + int stop_num = 0; + for (int i = 0; i < beam_size; i++) { + if (stop_flags[index + i]) { + parent_idx[index + i] = i; + id_buf[index + i] = end_ids[0]; + val_buf[index + i] = cum_log_probs[index + i]; + stop_flags_out[index + i] = stop_flags[index + i]; + seq_lens_out[index + i] = seq_lens[index + i]; + step_ids_out[index + i] = step_ids_out[index + i]; + stop_num++; +#ifdef DEBUG_BEAM_SEARCH_SOFTMAX + printf("%d has end, bi: %d, stop_num: %d\n", index + i, block_id, stop_num); +#endif + } else { + int parent_id = partial.parent_ids[i - stop_num]; + parent_idx[index + i] = parent_id; + id_buf[index + i] = partial.ids[i - stop_num]; + val_buf[index + i] = partial.vals[i - stop_num]; + stop_flags_out[index + i] = stop_flags[index + parent_id]; + seq_lens_out[index + i] = seq_lens[index + parent_id]; + step_ids_out[index + i] = step_ids[index + parent_id]; +#ifdef DEBUG_BEAM_SEARCH_SOFTMAX + printf("bi: %d, id: %d, val: %f, parent_id: %d\n", block_id, + id_buf[index+i], val_buf[index+i], parent_id); +#endif + } + } + } +} + +template +__global__ void beam_search_softmax_topk_stage1(const T *logits, + const bool *stop_flags, + const int *end_ids, + float *tmp_buffer, + const int vocab_size, + const bool fuse_softmax) { + int thread_id = threadIdx.x; + int vector_id = blockIdx.x; // batch beam index. + + __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; + + const T MAX_T_VAL = FLT_MAX; + + const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y; + const int section_start = v_local * blockIdx.y; + int section_end = section_start + v_local; + section_end = (section_end > vocab_size) ? vocab_size : section_end; + + logits += vector_id * vocab_size; + if (fuse_softmax) { + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + TopKSoftMax partial; + bool finish = stop_flags[vector_id]; + for (int i = 0; i < K; ++i) { + partial.topk.ids[i] = -1; + partial.topk.vals[i] = -MAX_T_VAL; + } + partial.softmax_md.logit = -MAX_T_VAL; + partial.softmax_md.score = 0.0F; + + if (finish) { +#pragma unroll 1 + for (int elem_id = section_start + thread_id; elem_id < section_end; + elem_id += THREADBLOCK_SIZE) { + // if is_end, set to (MAX_T_VAL, 1) + T elem = (elem_id == end_ids[0]) ? MAX_T_VAL : -MAX_T_VAL; + DySoftMaxStruct new_elem{elem, 1.0F}; + partial.softmax_md = reduce_softmax_op(partial.softmax_md, new_elem); + partial.topk.insert(elem, elem_id); + } + } else { +#pragma unroll 1 + for (int elem_id = section_start + thread_id; elem_id < section_end; + elem_id += THREADBLOCK_SIZE) { + T elem = logits[elem_id]; + DySoftMaxStruct new_elem{elem, 1.0F}; + partial.softmax_md = reduce_softmax_op(partial.softmax_md, new_elem); + partial.topk.insert(elem, elem_id); + } + } + + TopKSoftMax total = + BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op); + + if (thread_id == 0) { + for (int i = 0; i < K; i++) { + reinterpret_cast(buf_s)[i] = total.topk.ids[i]; + buf_s[K + i] = total.topk.vals[i]; + } + buf_s[2 * K] = total.softmax_md.score; + buf_s[2 * K + 1] = total.softmax_md.logit; + } + } else { + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + TopK partial; + bool finish = stop_flags[vector_id]; + for (int i = 0; i < K; ++i) { + partial.ids[i] = -1; + partial.vals[i] = -MAX_T_VAL; + } + + if (finish) { +#pragma unroll 1 + for (int elem_id = section_start + thread_id; elem_id < section_end; + elem_id += THREADBLOCK_SIZE) { + // if is_end, set to (end_id, 1) + T elem = (elem_id == end_ids[0]) ? 0 : -MAX_T_VAL; + partial.insert(elem, elem_id); + } + } else { +#pragma unroll 1 + for (int elem_id = section_start + thread_id; elem_id < section_end; + elem_id += THREADBLOCK_SIZE) { + T elem = logits[elem_id]; + partial.insert(elem, elem_id); + } + } + + TopK total = + BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (thread_id == 0) { + for (int i = 0; i < K; i++) { + reinterpret_cast(buf_s)[i] = total.ids[i]; + buf_s[K + i] = total.vals[i]; + } + } + } + __syncthreads(); + for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; + elem_id += THREADBLOCK_SIZE) { + tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id]; + } +} + +template +__global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer, + const float *cum_log_probs, + int *tmp_ids, + T *tmp_vals, + const int voc_parts, + const int packed_top_kmd_size, + const bool fuse_softmax) { + const int vector_id = blockIdx.x; + const int thread_id = threadIdx.x; + const int PACKED_TOP_KMD_SIZE = packed_top_kmd_size; + + const T MAX_T_VAL = FLT_MAX; + + extern __shared__ char buf_s_[]; + float *buf_s = reinterpret_cast(buf_s_); + tmp_buffer += vector_id * PACKED_TOP_KMD_SIZE * voc_parts; + + if (fuse_softmax) { + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + TopKSoftMax partial; + for (int i = 0; i < K; ++i) { + partial.topk.ids[i] = -1; + partial.topk.vals[i] = -MAX_T_VAL; + } + partial.softmax_md.logit = -MAX_T_VAL; + partial.softmax_md.score = 0.0F; + + for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts; + idx += THREADBLOCK_SIZE) { + buf_s[idx] = tmp_buffer[idx]; + } + __syncthreads(); + + if (threadIdx.x < voc_parts) { + float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE; + for (int i = 0; i < K; i++) { + partial.topk.ids[i] = reinterpret_cast(b_s)[i]; + partial.topk.vals[i] = b_s[K + i]; + } + partial.softmax_md.score = b_s[2 * K]; + partial.softmax_md.logit = b_s[2 * K + 1]; + } + __syncthreads(); + + TopKSoftMax total = + BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op); + + if (thread_id == 0) { + tmp_ids += vector_id * K; + tmp_vals += vector_id * K; + cum_log_probs += vector_id; + + float d_total_log = log(total.softmax_md.score); + for (int i = 0; i < K; ++i) { + // float val = expf((float)total.topk.vals[i] - total.softmax_md.logit - d_total_log); + float val = total.topk.vals[i] - total.softmax_md.logit - d_total_log; + tmp_ids[i] = total.topk.ids[i]; + tmp_vals[i] = val + cum_log_probs[0]; +#ifdef DEBUG_BEAM_SEARCH_SOFTMAX + printf("vector_id: %d, vals: %f, logit: %f, d_total_log: %f, id: %d, val: %f, cum_log_probs: %f, res: %f\n", vector_id, total.topk.vals[i], total.softmax_md.logit, d_total_log, tmp_ids[i], val, cum_log_probs[0], tmp_vals[i]); +#endif + } + } + } else { + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + TopK partial; + for (int i = 0; i < K; ++i) { + partial.ids[i] = -1; + partial.vals[i] = -MAX_T_VAL; + } + + for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts; + idx += THREADBLOCK_SIZE) { + buf_s[idx] = tmp_buffer[idx]; + } + __syncthreads(); + + if (threadIdx.x < voc_parts) { + float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE; + for (int i = 0; i < K; i++) { + partial.ids[i] = reinterpret_cast(b_s)[i]; + partial.vals[i] = b_s[K + i]; + } + } + __syncthreads(); + + TopK total = + BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (thread_id == 0) { + tmp_ids += vector_id * K; + tmp_vals += vector_id * K; + cum_log_probs += vector_id; + + for (int i = 0; i < K; ++i) { + float val = total.vals[i]; + tmp_ids[i] = total.ids[i]; + tmp_vals[i] = val + cum_log_probs[0]; + } + } + } +} + +template +void invokeBeamSearchSoftmaxTopKStage2(const float *tmp_buffer, + const float *cum_log_probs, + int *ids, + T *vals, + const int batch_size, + const int beam_size, + const int voc_parts, + const int packed_top_kmd_size, + const bool fuse_softmax, + cudaStream_t stream) { + int smem_stage2_size = voc_parts * packed_top_kmd_size * sizeof(float); + + if (voc_parts <= 32) { + beam_search_softmax_topk_stage2 + <<>>( + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + return; + } + if (voc_parts <= 64) { + beam_search_softmax_topk_stage2 + <<>>( + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + return; + } + if (voc_parts <= 128) { + beam_search_softmax_topk_stage2 + <<>>( + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + return; + } +} + +__global__ void update_beam_offsets_kernel( + const int *src_indir_cache, // bs * bm * max_len + const int *beam_ids, // bs * bm + const int *sequence_lengths, // bs * bm + const bool *stop_flags, + const int *step_ids, + int *tgt_indir_cache, + const int batch_size, + const int beam_size, + const int max_seq_len, + const int max_dec_len) { + int time_step = threadIdx.x + blockIdx.x * blockDim.x; + int bb_id = blockIdx.y; + const int batch_id = bb_id / beam_size; + const int beam_id = bb_id % beam_size; + const int src_beam = beam_ids[bb_id]; + const int src_bb_id = batch_id * beam_size + src_beam; + const int seq_len = sequence_lengths[src_bb_id]; + const int max_len = max_seq_len + max_dec_len; + + + if (seq_len == 0 || time_step >= min(seq_len + 1, max_len)) { + return; + } + // if (time_step >= max_len) { + // return; + // } + + if (bb_id >= beam_size * batch_size) { + return; + } + + const uint tgt_offset = + batch_id * beam_size * max_len + beam_id * max_len + time_step; + const uint src_offset = + batch_id * beam_size * max_len + src_beam * max_len + time_step; + + tgt_indir_cache[tgt_offset] = (time_step == sequence_lengths[src_bb_id]) + ? src_beam + : src_indir_cache[src_offset]; +} + +void invokeUpdateBeamOffset(const int *src_indir_cache, + const int *beam_ids, + const int *sequence_lengths, + const bool *stop_flags, + const int *step_ids, + int *tgt_indir_cache, + const int batch_size, + const int beam_size, + const int max_seq_len, + const int max_dec_len, + cudaStream_t stream) { + const dim3 block(32); + const dim3 grid((max_seq_len + max_dec_len + block.x - 1) / block.x, + batch_size * beam_size); + update_beam_offsets_kernel<<>>(src_indir_cache, + beam_ids, + sequence_lengths, + stop_flags, + step_ids, + tgt_indir_cache, + batch_size, + beam_size, + max_seq_len, + max_dec_len); +} + +__global__ void update_cache_ids_kernel( + const int *last_cache_ids, // bs * bm * max_dec_len + const int *beam_ids, // bs * bm + const int *ids_this_time, + const int *sequence_lengths, // bs * bm + const bool *stop_flags, + const int *step_ids, + int *cache_ids, + const int batch_size, + const int beam_size, + const int max_dec_len) { + int time_step = threadIdx.x + blockIdx.x * blockDim.x; + int bb_id = blockIdx.y; + const int batch_id = bb_id / beam_size; + const int beam_id = bb_id % beam_size; + const int src_beam = beam_ids[bb_id]; + const int src_bb_id = batch_id * beam_size + src_beam; + const int step = step_ids[src_bb_id]; + + if (sequence_lengths[src_bb_id] == 0 || time_step >= min(step + 1, max_dec_len)) { + return; + } + + if (bb_id >= beam_size * batch_size) { + return; + } + + const uint tgt_offset = + batch_id * beam_size * max_dec_len + beam_id * max_dec_len + time_step; + const uint src_offset = + batch_id * beam_size * max_dec_len + src_beam * max_dec_len + time_step; + + cache_ids[tgt_offset] = + (time_step == step) ? ids_this_time[bb_id] : last_cache_ids[src_offset]; +} + +void invokeUpdateCacheIds(const int *last_cache_ids, + const int *beam_ids, + const int *sequence_lengths, + const int *ids_this_time, + const bool *stop_flags, + const int *step_ids, + int *cache_ids, + const int batch_size, + const int beam_size, + const int max_dec_len, + cudaStream_t stream) { + const dim3 block(32); + const dim3 grid((max_dec_len + block.x - 1) / block.x, + batch_size * beam_size); + update_cache_ids_kernel<<>>(last_cache_ids, + beam_ids, + ids_this_time, + sequence_lengths, + stop_flags, + step_ids, + cache_ids, + batch_size, + beam_size, + max_dec_len); +} + +template +void invokeTopKSoftMaxLauncher(const Context &dev_ctx, + const T *log_probs, + const bool *stop_flags, + const int *sequence_lengths, + const float *cum_log_probs, + const int *step_ids, + const int *last_cache_ids, + const int *last_beam_offsets, + const int *end_ids, + float *out_cum_log_probs, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out, + int *ids, + int *tmp_ids, + T *tmp_vals, + int *parent_idx, + int *cache_ids, + int *beam_offsets, + const int batch_size, + const int beam_size, + const int vocab_size, + const int max_seq_len, + const int max_dec_len, + const bool fuse_softmax, + const bool early_stop, + cudaStream_t stream) { + // K = 2 * beam_size + const int block_size = 128; + int voc_parts = vocab_size / 1024; + voc_parts = std::min(128, voc_parts); + int packed_top_kmd_size = 2 * K; + if (fuse_softmax) { + packed_top_kmd_size += 2; + } + const int tmp_buffer_size = + batch_size * beam_size * voc_parts * packed_top_kmd_size; + DenseTensor tmp_buffer_tensor; + tmp_buffer_tensor.Resize(phi::make_ddim({tmp_buffer_size})); + dev_ctx.template Alloc(&tmp_buffer_tensor); + float *tmp_buffer = tmp_buffer_tensor.data(); + + dim3 grid(batch_size * beam_size, voc_parts); + if (fuse_softmax) { + cudaFuncSetAttribute(beam_search_softmax_topk_stage1, + cudaFuncAttributePreferredSharedMemoryCarveout, + cudaSharedmemCarveoutMaxL1); + // (bs, bm, voc_parts, 2 * K + 2) + beam_search_softmax_topk_stage1 + <<>>( + log_probs, stop_flags, end_ids, tmp_buffer, vocab_size, fuse_softmax); + } else { + cudaFuncSetAttribute(beam_search_softmax_topk_stage1, + cudaFuncAttributePreferredSharedMemoryCarveout, + cudaSharedmemCarveoutMaxL1); + // (bs, bm, voc_parts, 2 * K) + beam_search_softmax_topk_stage1 + <<>>( + log_probs, stop_flags, end_ids, tmp_buffer, vocab_size, fuse_softmax); + } + // (bs, bm, K) + invokeBeamSearchSoftmaxTopKStage2(tmp_buffer, + cum_log_probs, + tmp_ids, + tmp_vals, + batch_size, + beam_size, + voc_parts, + packed_top_kmd_size, + fuse_softmax, + stream); + // (bs, bm) + if (early_stop) { + batch_topk<<>>( + tmp_ids, + tmp_vals, + cum_log_probs, + step_ids, + stop_flags, + sequence_lengths, + end_ids, + ids, + out_cum_log_probs, + parent_idx, + stop_flags_out, + seq_lens_out, + step_ids_out); + } else { + batch_topk<<>>( + tmp_ids, + tmp_vals, + step_ids, + stop_flags, + sequence_lengths, + end_ids, + ids, + out_cum_log_probs, + parent_idx, + stop_flags_out, + seq_lens_out, + step_ids_out); + } + invokeUpdateBeamOffset(last_beam_offsets, + parent_idx, + sequence_lengths, + stop_flags, + step_ids, + beam_offsets, + batch_size, + beam_size, + max_seq_len, + max_dec_len, + stream); + invokeUpdateCacheIds(last_cache_ids, + parent_idx, + sequence_lengths, + ids, + stop_flags, + step_ids, + cache_ids, + batch_size, + beam_size, + max_dec_len, + stream); +} + +template +void invokeTopkSoftMax(const Context &dev_ctx, + const T *log_probs, + const bool *stop_flags, + const int *sequence_lengths, + const float *cum_log_probs, + const int *step_ids, + const int *last_cache_ids, + const int *last_beam_offsets, + const int *end_ids, + float *out_cum_log_probs, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out, + int *ids, + int *tmp_ids, + T *tmp_vals, + int *parent_idx, + int *cache_ids, + int *beam_offsets, + const int batch_size, + const int beam_size, + const int vocab_size, + const int max_seq_len, + const int max_dec_len, + const bool fuse_softmax, + const bool early_stop, + cudaStream_t stream) { + switch (beam_size) { + CASE_K(1); + CASE_K(2); + CASE_K(3); + CASE_K(4); + CASE_K(5); + CASE_K(6); + CASE_K(7); + CASE_K(8); + CASE_K(9); + CASE_K(10); + CASE_K(11); + CASE_K(12); + CASE_K(13); + CASE_K(14); + CASE_K(15); + CASE_K(16); + default: + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "beam_size = %d is unsupport!", beam_size)); + } +} + +template +void BeamSearchSoftmaxKernel(const Context &dev_ctx, + const DenseTensor &logits, + const DenseTensor &cum_scores, + const DenseTensor &sequence_lengths, + const DenseTensor &stop_flags, + const DenseTensor &end_ids, + const DenseTensor &step_ids, + const DenseTensor &last_cache_ids, + const DenseTensor &last_beam_offsets, + int beam_size, + int max_seq_len, + int max_dec_len, + bool fuse_softmax, + bool early_stop, + DenseTensor *ids_this_time, + DenseTensor *out_cum_scores, + DenseTensor *cache_ids, + DenseTensor *beam_offsets, + DenseTensor *parent_idx, + DenseTensor *stop_flags_out, + DenseTensor *seq_lens_out, + DenseTensor *step_ids_out) { + const auto &logits_dims = logits.dims(); + int bs = logits_dims[0]; + int batch_size = bs / beam_size; + int vocab_size = logits_dims[1]; + + dev_ctx.template Alloc(ids_this_time); + dev_ctx.template Alloc(cache_ids); + dev_ctx.template Alloc(beam_offsets); + dev_ctx.template Alloc(parent_idx); + dev_ctx.template Alloc(out_cum_scores); + dev_ctx.template Alloc(stop_flags_out); + dev_ctx.template Alloc(seq_lens_out); + dev_ctx.template Alloc(step_ids_out); + + phi::Copy(dev_ctx, last_cache_ids, dev_ctx.GetPlace(), false, cache_ids); + phi::Copy( + dev_ctx, last_beam_offsets, dev_ctx.GetPlace(), false, beam_offsets); + phi::Copy( + dev_ctx, stop_flags, dev_ctx.GetPlace(), false, stop_flags_out); + phi::Copy( + dev_ctx, sequence_lengths, dev_ctx.GetPlace(), false, seq_lens_out); + phi::Copy( + dev_ctx, step_ids, dev_ctx.GetPlace(), false, step_ids_out); + + const int tmp_size = batch_size * beam_size * beam_size * 2; + DenseTensor tmp_topk_id, tmp_topk_val; + tmp_topk_id.Resize(phi::make_ddim({tmp_size})); + dev_ctx.template Alloc(&tmp_topk_id); + tmp_topk_val.Resize(phi::make_ddim({tmp_size})); + dev_ctx.template Alloc(&tmp_topk_val); + + invokeTopkSoftMax(dev_ctx, + logits.data(), + stop_flags.data(), + sequence_lengths.data(), + cum_scores.data(), + step_ids.data(), + last_cache_ids.data(), + last_beam_offsets.data(), + end_ids.data(), + out_cum_scores->data(), + stop_flags_out->data(), + seq_lens_out->data(), + step_ids_out->data(), + ids_this_time->data(), + tmp_topk_id.data(), + tmp_topk_val.data(), + parent_idx->data(), + cache_ids->data(), + beam_offsets->data(), + batch_size, + beam_size, + vocab_size, + max_seq_len, + max_dec_len, + fuse_softmax, + early_stop, + dev_ctx.stream()); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(beam_search_softmax, + GPU, + ALL_LAYOUT, + phi::fusion::BeamSearchSoftmaxKernel, + float) {} // only supports float diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index c8286c09b10fa..018a59ebc5261 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -312,6 +312,7 @@ from .tensor.search import sort # noqa: F401 from .tensor.search import kthvalue # noqa: F401 from .tensor.search import mode # noqa: F401 +from .tensor.search import beam_search_softmax # noqa: F401 from .tensor.to_string import set_printoptions # noqa: F401 @@ -667,4 +668,5 @@ 'sgn', 'triu_indices', 'take', + 'beam_search_softmax', ] diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 87df808213656..decaf45125750 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -186,10 +186,13 @@ def pure_fp16_initialize(models): if (layer._dtype == 'float16') or isinstance( layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D, paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D, - paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm)): + paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm, + paddle.nn.ParameterList)): + # tianyan01 add paddle.nn.ParameterList, hack continue if isinstance(layer, (paddle.incubate.nn.FusedFeedForward, - paddle.incubate.nn.FusedMultiHeadAttention)): + paddle.incubate.nn.FusedMultiHeadAttention, + paddle.incubate.nn.FusedMultiTransformer)): layer._amp_decorate(dtype='float16') continue layer._to_impl(dtype='float16', diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 02b844751a889..a39f1cb94c0c5 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -40,6 +40,7 @@ def fused_feedforward( ln2_bias=None, dropout1_rate=0.5, dropout2_rate=0.5, + seed=None, activation="relu", ln1_epsilon=1e-5, ln2_epsilon=1e-5, @@ -123,7 +124,7 @@ def fused_feedforward( _verify_dropout_rate(dropout1_rate) _verify_dropout_rate(dropout2_rate) - seed = None + # seed = None if mode not in ('downscale_in_infer', 'upscale_in_train'): raise ValueError( "mode argument should be 'downscale_in_infer' or 'upscale_in_train'" @@ -133,8 +134,8 @@ def fused_feedforward( ) # semantic transfer if _non_static_mode(): - if default_main_program().random_seed != 0: - seed = default_main_program().random_seed + # if default_main_program().random_seed != 0: + # seed = default_main_program().random_seed out, _, _, _, _, _, _, _, _, _, _ = _legacy_C_ops.fused_feedforward( x, None, @@ -221,8 +222,8 @@ def fused_feedforward( x.dtype, stop_gradient=True ) - if (seed is None or seed == 0) and helper.main_program.random_seed != 0: - seed = helper.main_program.random_seed + # if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + # seed = helper.main_program.random_seed helper.append_op( type='fused_feedforward', @@ -477,6 +478,8 @@ def fused_multi_head_attention( attn_mask=None, dropout_rate=0.5, attn_dropout_rate=0.5, + dropout_seed=None, + attn_dropout_seed=None, ln_epsilon=1e-05, training=True, mode='upscale_in_train', @@ -602,7 +605,7 @@ def fused_multi_head_attention( print(output.shape) """ - seed = None + # seed = None if mode not in ('downscale_in_infer', 'upscale_in_train'): raise ValueError( "mode argument should be 'downscale_in_infer' or 'upscale_in_train'" @@ -612,8 +615,8 @@ def fused_multi_head_attention( ) # semantic transfer if _non_static_mode(): - if default_main_program().random_seed != 0: - seed = default_main_program().random_seed + # if default_main_program().random_seed != 0: + # seed = default_main_program().random_seed # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out @@ -678,13 +681,13 @@ def fused_multi_head_attention( 'is_test', not training, 'attn_dropout_fix_seed', - seed is not None, + attn_dropout_seed is not None, 'dropout_fix_seed', - seed is not None, + dropout_seed is not None, 'attn_dropout_seed', - seed if seed is not None else 0, + attn_dropout_seed if attn_dropout_seed is not None else 0, 'dropout_seed', - seed if seed is not None else 0, + dropout_seed if dropout_seed is not None else 0, 'attn_dropout_implementation', mode, 'dropout_implementation', @@ -735,8 +738,8 @@ def fused_multi_head_attention( if cache_kv: inputs['CacheKV'] = [cache_kv] - if (seed is None or seed == 0) and helper.main_program.random_seed != 0: - seed = helper.main_program.random_seed + # if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + # seed = helper.main_program.random_seed # set attrs attrs = { @@ -746,10 +749,10 @@ def fused_multi_head_attention( 'dropout_rate': dropout_rate, 'attn_dropout_rate': attn_dropout_rate, 'is_test': not training, - 'attn_dropout_fix_seed': seed is not None, - 'dropout_fix_seed': seed is not None, - 'attn_dropout_seed': seed if seed is not None else 0, - 'dropout_seed': seed if seed is not None else 0, + 'attn_dropout_fix_seed': attn_dropout_seed is not None, + 'dropout_fix_seed': dropout_seed is not None, + 'attn_dropout_seed': attn_dropout_seed if attn_dropout_seed is not None else 0, + 'dropout_seed': dropout_seed if dropout_seed is not None else 0, 'attn_dropout_implementation': mode, 'dropout_implementation': mode, 'add_residual': add_residual, @@ -998,21 +1001,21 @@ def fused_multi_transformer( if _non_static_mode(): cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, + list(ln_scales), + list(ln_biases), + list(qkv_weights), + list(qkv_biases), cache_kvs, time_step, attn_mask, - linear_weights, - linear_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, + list(linear_weights), + list(linear_biases), + list(ffn_ln_scales), + list(ffn_ln_biases), + list(ffn1_weights), + list(ffn1_biases), + list(ffn2_weights), + list(ffn2_biases), cache_kvs, 'pre_layer_norm', pre_layer_norm, @@ -1048,29 +1051,29 @@ def fused_multi_transformer( # set inputs inputs = dict() inputs['X'] = [x] - inputs['LnScale'] = ln_scales - inputs['LnBias'] = ln_biases - inputs['QKVW'] = qkv_weights + inputs['LnScale'] = list(ln_scales) + inputs['LnBias'] = list(ln_biases) + inputs['QKVW'] = list(qkv_weights) if qkv_biases is not None: - inputs['QKVBias'] = qkv_biases + inputs['QKVBias'] = list(qkv_biases) if cache_kvs is not None: assert len(cache_kvs) == len(qkv_weights) inputs['CacheKV'] = cache_kvs if time_step is not None: inputs['TimeStep'] = time_step inputs['SrcMask'] = attn_mask - inputs['OutLinearW'] = linear_weights + inputs['OutLinearW'] = list(linear_weights) if linear_biases is not None: - inputs['OutLinearBias'] = linear_biases + inputs['OutLinearBias'] = list(linear_biases) - inputs['FFNLnScale'] = ffn_ln_scales - inputs['FFNLnBias'] = ffn_ln_biases - inputs['FFN1Weight'] = ffn1_weights + inputs['FFNLnScale'] = list(ffn_ln_scales) + inputs['FFNLnBias'] = list(ffn_ln_biases) + inputs['FFN1Weight'] = list(ffn1_weights) if ffn1_biases is not None: - inputs['FFN1Bias'] = ffn1_biases - inputs['FFN2Weight'] = ffn2_weights + inputs['FFN1Bias'] = list(ffn1_biases) + inputs['FFN2Weight'] = list(ffn2_weights) if ffn2_biases is not None: - inputs['FFN2Bias'] = ffn2_biases + inputs['FFN2Bias'] = list(ffn2_biases) # set attrs attrs = { diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index c3655c9d93a27..d71086fe7b07d 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -16,6 +16,7 @@ from paddle.nn import Layer from paddle.framework import ParamAttr import paddle +from paddle.nn import ParameterList from paddle.nn.layer.transformer import ( _convert_attention_mask, _convert_param_attr_to_list, @@ -269,6 +270,8 @@ def __init__( num_heads, dropout_rate=0.5, attn_dropout_rate=0.5, + dropout_seed=None, + attn_dropout_seed=None, kdim=None, vdim=None, normalize_before=False, @@ -377,6 +380,8 @@ def __init__( self.dropout_rate = dropout_rate self.attn_dropout_rate = attn_dropout_rate + self.dropout_seed = dropout_seed + self.attn_dropout_seed = attn_dropout_seed self.name = name @@ -434,6 +439,8 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): attn_mask=attn_mask, dropout_rate=self.dropout_rate, attn_dropout_rate=self.attn_dropout_rate, + dropout_seed=self.dropout_seed, + attn_dropout_seed=self.attn_dropout_seed, ln_epsilon=self._epsilon, training=self.training, ring_id=self._ring_id, @@ -546,6 +553,7 @@ def __init__( epsilon=1e-05, activation="relu", act_dropout_rate=None, + seed=None, normalize_before=False, linear1_weight_attr=None, linear1_bias_attr=None, @@ -582,6 +590,7 @@ def __init__( self._act_dropout_rate = ( dropout_rate if act_dropout_rate is None else act_dropout_rate ) + self._seed = seed self._act_method = activation self._normalize_before = normalize_before self._epsilon = epsilon @@ -661,6 +670,7 @@ def forward(self, src, cache=None): self._ln2_bias, dropout1_rate=self._act_dropout_rate, dropout2_rate=self._dropout_rate, + seed=self._seed, activation=self._act_method, ln1_epsilon=self._epsilon, ln2_epsilon=self._epsilon, @@ -1187,6 +1197,7 @@ def __init__( trans_qkvw=True, ring_id=-1, name=None, + dy_to_st=False, ): super(FusedMultiTransformer, self).__init__() @@ -1227,19 +1238,26 @@ def __init__( dim_feedforward = dim_feedforward // nranks self._dim_feedforward = dim_feedforward - if isinstance(qkv_weight_attrs, (list, tuple)): + if isinstance(qkv_weight_attrs, (list, tuple, ParameterList)): num_layers = len(qkv_weight_attrs) assert num_layers > 0 - self.ln_scales, self.ln_biases = [], [] - self.qkv_weights, self.qkv_biases = [], [] - self.linear_weights, self.linear_biases = [], [] - self.ffn_ln_scales, self.ffn_ln_biases = [], [] - self.ffn1_weights, self.ffn1_biases = [], [] - self.ffn2_weights, self.ffn2_biases = [], [] - + # if not dy_to_st: + # self.ln_scales, self.ln_biases = [], [] + # self.qkv_weights, self.qkv_biases = [], [] + # self.linear_weights, self.linear_biases = [], [] + # self.ffn_ln_scales, self.ffn_ln_biases = [], [] + # self.ffn1_weights, self.ffn1_biases = [], [] + # self.ffn2_weights, self.ffn2_biases = [], [] + # else: + self.ln_scales, self.ln_biases = ParameterList(), ParameterList() + self.qkv_weights, self.qkv_biases = ParameterList(), ParameterList() + self.linear_weights, self.linear_biases = ParameterList(), ParameterList() + self.ffn_ln_scales, self.ffn_ln_biases = ParameterList(), ParameterList() + self.ffn1_weights, self.ffn1_biases = ParameterList(), ParameterList() + self.ffn2_weights, self.ffn2_biases = ParameterList(), ParameterList() def get_attr(attrs, idx): - if isinstance(attrs, (list, tuple)): + if isinstance(attrs, (list, tuple, ParameterList)): assert len(attrs) == num_layers return attrs[idx] return attrs @@ -1263,9 +1281,10 @@ def get_attr(attrs, idx): attr=ln_scale_attr, shape=[embed_dim], default_initializer=Constant(value=1.0), + dtype="float32", ) ln_bias = self.create_parameter( - attr=ln_bias_attr, shape=[embed_dim], is_bias=True + attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32" ) qkv_weight = self.create_parameter( shape=[3, num_heads, self.head_dim, embed_dim] @@ -1299,9 +1318,10 @@ def get_attr(attrs, idx): attr=ffn_ln_scale_attr, is_bias=False, default_initializer=Constant(1.0), + dtype="float32", ) ffn_ln_bias = self.create_parameter( - shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True + shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True, dtype="float32" ) ffn1_weight = self.create_parameter( shape=[embed_dim, dim_feedforward], @@ -1418,3 +1438,20 @@ def forward(self, src, attn_mask=None, caches=None, time_step=None): name=self.name, ) return out + + def _amp_decorate(self, dtype): + # tmp fix for amp.decorator(O2) + def trans_to_fp16(l): + for param in l: + if param is not None: + with no_grad(): + param_applied = _to_dtype(param, dtype) + trans_to_fp16(self.qkv_weights) + trans_to_fp16(self.qkv_biases) + trans_to_fp16(self.linear_weights) + trans_to_fp16(self.linear_biases) + trans_to_fp16(self.ffn1_weights) + trans_to_fp16(self.ffn1_biases) + trans_to_fp16(self.ffn2_weights) + trans_to_fp16(self.ffn2_biases) + self._dtype = dtype \ No newline at end of file diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 58dfa26cfe377..4c27dae6af5e5 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -265,6 +265,7 @@ from .search import masked_select # noqa: F401 from .search import kthvalue # noqa: F401 from .search import mode # noqa: F401 +from .search import beam_search_softmax # noqa: F401 from .stat import mean # noqa: F401 from .stat import std # noqa: F401 @@ -515,6 +516,7 @@ 'take', 'bucketize', 'sgn', + 'beam_search_softmax', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index f987e8b89cf25..7733f2ddd99d6 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2721,7 +2721,7 @@ def gather(x, index, axis=None, name=None): check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], + ['bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'gather', ) check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index e4458048edc6b..17499887e7a5b 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -1106,3 +1106,93 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None): attrs=attrs) indices.stop_gradient = True return values, indices + + +def beam_search_softmax( + logits, + cum_scores, + sequence_lengths, + stop_flags, + end_ids, + step_ids, + last_cache_ids, + last_beam_offsets, + beam_size, + max_seq_len, + max_dec_len, + fuse_softmax, + early_stop, + name=None, +): + if in_dygraph_mode(): + return _C_ops.beam_search_softmax( + logits, + cum_scores, + sequence_lengths, + stop_flags, + end_ids, + step_ids, + last_cache_ids, + last_beam_offsets, + beam_size, + max_seq_len, + max_dec_len, + fuse_softmax, + early_stop + ) + + inputs = { + "logits": logits, + "cum_scores": cum_scores, + "sequence_lengths": sequence_lengths, + "stop_flags": stop_flags, + "end_ids": end_ids, + "step_ids": step_ids, + "last_cache_ids": last_cache_ids, + "last_beam_offsets": last_beam_offsets, + } + attrs = {} + attrs['beam_size'] = beam_size + attrs['max_seq_len'] = max_seq_len + attrs['max_dec_len'] = max_dec_len + attrs['fuse_softmax'] = fuse_softmax + attrs['early_stop'] = early_stop + + helper = LayerHelper('beam_search_softmax', **locals()) + ids_this_time = helper.create_variable_for_type_inference(dtype="int32") + cache_ids = helper.create_variable_for_type_inference(dtype="int32") + beam_offsets = helper.create_variable_for_type_inference(dtype="int32") + parent_idx = helper.create_variable_for_type_inference(dtype="int32") + out_cum_scores = helper.create_variable_for_type_inference( + dtype=logits.dtype + ) + stop_flags_out = helper.create_variable_for_type_inference( + dtype=stop_flags.dtype + ) + seq_lens_out = helper.create_variable_for_type_inference(dtype="int32") + step_ids_out = helper.create_variable_for_type_inference(dtype="int32") + helper.append_op( + type='beam_search_softmax', + inputs=inputs, + outputs={ + "ids_this_time": ids_this_time, + "out_cum_scores": out_cum_scores, + "cache_ids": cache_ids, + "beam_offsets": beam_offsets, + "parent_idx": parent_idx, + "stop_flags_out": stop_flags_out, + "seq_lens_out": seq_lens_out, + "step_ids_out": step_ids_out, + }, + attrs=attrs, + ) + return ( + ids_this_time, + out_cum_scores, + cache_ids, + beam_offsets, + parent_idx, + stop_flags_out, + seq_lens_out, + step_ids_out + ) \ No newline at end of file From 55bac3e7ff2375b0c1fe103ea8e3b091e1a04ca2 Mon Sep 17 00:00:00 2001 From: tianyan01 Date: Tue, 5 Sep 2023 14:26:56 +0800 Subject: [PATCH 2/9] add fused_moe op --- paddle/fluid/operators/fused/CMakeLists.txt | 2 + paddle/fluid/operators/fused/fused_moe_op.cc | 132 ++++ paddle/fluid/operators/fused/fused_moe_op.cu | 501 ++++++++++++ paddle/fluid/operators/fused/fused_moe_op.h | 718 ++++++++++++++++++ paddle/fluid/pybind/op_function_generator.h | 11 + .../contrib/mixed_precision/fp16_lists.py | 1 + .../contrib/mixed_precision/fp16_utils.py | 2 + python/paddle/fluid/dygraph/amp/auto_cast.py | 6 +- python/paddle/incubate/nn/__init__.py | 2 + .../incubate/nn/layer/fused_transformer.py | 168 ++++ 10 files changed, 1540 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_moe_op.cc create mode 100644 paddle/fluid/operators/fused/fused_moe_op.cu create mode 100644 paddle/fluid/operators/fused/fused_moe_op.h diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index fe82565bc36f3..b1676883ff39b 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -23,6 +23,7 @@ register_operators( fused_transformer_op fused_feedforward_op fused_multi_transformer_op + fused_moe_op fused_multi_transformer_int8_op fused_bias_dropout_residual_layer_norm_op resnet_unit_op @@ -120,6 +121,7 @@ if(WITH_GPU OR WITH_ROCM) # fused_attention_op op_library(fused_attention_op) op_library(fused_multi_transformer_op) + op_library(fused_moe_op) op_library(fused_multi_transformer_int8_op) op_library(fused_bias_dropout_residual_layer_norm_op) endif() diff --git a/paddle/fluid/operators/fused/fused_moe_op.cc b/paddle/fluid/operators/fused/fused_moe_op.cc new file mode 100644 index 0000000000000..faaaf5d5b1938 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_moe_op.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +class FusedMoeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *context) const override { + // input + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "fused_moe"); + OP_INOUT_CHECK(context->HasInput("GateWeight"), + "Input", + "GateWeight", + "fused_moe"); + OP_INOUT_CHECK(context->HasInput("GateBias"), + "Input", + "GateBias", + "fused_moe"); + OP_INOUT_CHECK(context->HasInput("LnScale"), + "Input", + "LnScale", + "fused_moe"); + OP_INOUT_CHECK(context->HasInput("LnBias"), + "Input", + "LnBias", + "fused_moe"); + OP_INOUT_CHECK(context->HasInputs("ExpertsWeight1"), + "Input", + "ExpertsWeight1", + "fused_moe"); + OP_INOUT_CHECK(context->HasInputs("ExpertsBias1"), + "Input", + "ExpertsBias1", + "fused_moe"); + OP_INOUT_CHECK(context->HasInputs("ExpertsWeight2"), + "Input", + "ExpertsWeight2", + "fused_moe"); + OP_INOUT_CHECK(context->HasInputs("ExpertsBias2"), + "Input", + "ExpertsBias2", + "fused_moe"); + // output + OP_INOUT_CHECK(context->HasOutput("Out"), + "Output", + "Out", + "fused_moe"); + auto x_dims = context->GetInputDim("X"); + context->SetOutputDim("Out", x_dims); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class FusedMoeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + // AsDispensable 可有可无 + // AsDuplicable 可复制 + // input + AddInput("X", "The input of FusedMoe op"); + AddInput("GateWeight", "The gate weight of FusedMoe op"); + AddInput("GateBias", "The gate bias of FusedMoe op"); + AddInput("LnScale", "The ln scale of FusedMoe op"); + AddInput("LnBias", "The LnBias of FusedMoe op"); + AddInput("ExpertsWeight1", "The expert linear1 weights of fused_moe op") + .AsDuplicable(); + AddInput("ExpertsBias1", "The expert linear1 biases of fused_moe op") + .AsDuplicable() + .AsDispensable(); + AddInput("ExpertsWeight2", "The expert linear2 weights of fused_moe op") + .AsDuplicable(); + AddInput("ExpertsBias2", "The expert linear2 biases of fused_moe op") + .AsDuplicable() + .AsDispensable(); + // output + AddOutput("Out", "Out"); + // attr + AddAttr("pre_layer_norm", "pre_layer_norm").SetDefault(true); + AddAttr("ln_epsilon", "ln_epsilon").SetDefault(1e-5f); + AddAttr("topk", "top k in gate").SetDefault(2); + AddAttr("mp_size", "mp_size").SetDefault(1); + AddAttr("mp_rank", "mp_rank").SetDefault(0); + AddAttr("num_expert", "num_expert").SetDefault(1); + AddAttr("world_size", "world_size").SetDefault(1); + AddAttr("moe_ring_id", "moe_ring_id").SetDefault(-1); + AddComment(R"DOC( + The fused_moe operator is the same as the following pseudo codes: + + pass + + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_moe, + ops::FusedMoeOp, + ops::FusedMoeOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_moe_op.cu b/paddle/fluid/operators/fused/fused_moe_op.cu new file mode 100644 index 0000000000000..0478c6cf551bc --- /dev/null +++ b/paddle/fluid/operators/fused/fused_moe_op.cu @@ -0,0 +1,501 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fused_moe_op.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +static void AllToAll(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllToAll(in_tensor, out_tensor, true, true); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t send_numel = tensor.numel(); // send_numel + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + int nranks = comm->nranks(); + auto stream = ctx.stream(); + + framework::DDim x_dims = tensor.dims(); + framework::DDim out_dims(x_dims); + PADDLE_ENFORCE_EQ( + x_dims[0] % nranks, + 0, + platform::errors::InvalidArgument( + "The first dimension size (%d) of the input tensor must be " + "divisible by the number of ranks (%d).", + x_dims[0], + nranks)); + auto send_buf = tensor.data(); + auto recv_buf = out.mutable_data(out_dims, place); + size_t offset = 0; + send_numel /= nranks; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < nranks; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + offset += send_numel; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +template +static void AllGather(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllGather(in_tensor, out_tensor, true, true); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t numel = tensor.numel(); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + auto out_dims = tensor.dims(); + int nranks = comm->nranks(); + out_dims[0] *= nranks; + out.mutable_data(out_dims, place); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( + tensor.data(), out.data(), numel, dtype, comm->comm(), stream)); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +template +class FusedMoeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + using U = LayerNormParamType; + auto& dev_ctx = context.cuda_device_context(); + // input + auto* x = context.Input("X"); + auto* gate_weight = context.Input("GateWeight"); + auto* gate_bias = context.Input("GateBias"); + const bool pre_layer_norm = context.Attr("pre_layer_norm"); + auto* ln_scale = + pre_layer_norm ? context.Input("LnScale") : nullptr; + auto* ln_bias = + pre_layer_norm ? context.Input("LnBias") : nullptr; + // linear 1 + auto experts_weight1 = context.MultiInput("ExpertsWeight1"); + auto experts_bias1 = context.MultiInput("ExpertsBias1"); + // linear 2 + auto experts_weight2 = context.MultiInput("ExpertsWeight2"); + auto experts_bias2 = context.MultiInput("ExpertsBias2"); + + // output + auto* out = context.Output("Out"); + dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + // attr + const float epsilon = context.Attr("ln_epsilon"); + const int topk = context.Attr("topk"); + const int mp_size = context.Attr("mp_size"); + const int mp_rank = context.Attr("mp_rank"); + const int num_expert = context.Attr("num_expert"); + const int world_size = context.Attr("world_size"); + const int moe_ring_id = context.Attr("moe_ring_id"); + + // dim + auto x_dim = x->dims(); + int bsz = x_dim[0]; + int seq_len = x_dim[1]; + int bsz_seq = bsz * seq_len; + int d_model = x_dim[2]; + int tot_expert = world_size * num_expert; + int dim_feedforward = experts_weight1[0]->dims()[1]; + + // pre_layer_norm + const U* ln_scale_ptr = + ln_scale == nullptr ? nullptr : ln_scale->data(); + const U* ln_bias_ptr = + ln_bias == nullptr ? nullptr : ln_bias->data(); + Tensor ln_mean, ln_variance; + ln_mean.Resize({{bsz_seq}}); + auto* ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_variance.Resize({{bsz_seq}}); + auto* ln_variance_data = + dev_ctx.Alloc(&ln_variance, ln_variance.numel() * sizeof(U)); + FusedDropoutLayerNormHelper pre_layernorm_helper( + bsz_seq, d_model, epsilon); + // tmp out + Tensor ln_out; + ln_out.Resize({{bsz, seq_len, d_model}}); + auto *ln_out_data = dev_ctx.Alloc(&ln_out, ln_out.numel() * sizeof(T)); + // after slice, bsz_seq should be change + int sliced_bsz_seq = bsz_seq; + int start = 0; + int end = 0; + if (mp_size > 1) { + start = bsz_seq / world_size * mp_rank; + end = std::min(start + bsz_seq / world_size, bsz_seq); + sliced_bsz_seq = end - start; + } + int out_batch_size = sliced_bsz_seq * topk; + // slice + Tensor sliced_inp; + sliced_inp.Resize({{sliced_bsz_seq, d_model}}); + auto* sliced_inp_data = dev_ctx.Alloc(&sliced_inp, sliced_inp.numel() * sizeof(T)); + // gate linear + Tensor gate_out; + gate_out.Resize({{sliced_bsz_seq, tot_expert}}); + auto* gate_out_data = dev_ctx.Alloc(&gate_out, gate_out.numel() * sizeof(T)); + auto gate_linear_compute = AttnMatMul( + dev_ctx, false, false, sliced_bsz_seq, tot_expert, d_model, true); + // topk + Tensor topk_value, topk_idx; + topk_value.Resize({{sliced_bsz_seq, topk}}); + auto* topk_value_data = dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); + topk_idx.Resize({{sliced_bsz_seq, topk}}); + auto* topk_idx_data = dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(int64_t)); + // local expert count, global expert count + Tensor local_expert_count, global_expert_count; + local_expert_count.Resize({{tot_expert}}); + global_expert_count.Resize({{tot_expert}}); + auto* local_expert_count_data = + dev_ctx.Alloc(&local_expert_count, local_expert_count.numel() * sizeof(int64_t)); + auto* global_expert_count_data = + dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); + // fwd_expert_count, fwd_batch_size + Tensor fwd_expert_count, fwd_batch_size; + fwd_expert_count.Resize({{world_size, num_expert}}); + fwd_batch_size.Resize({{1}}); + auto* fwd_expert_count_data = + dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); + auto* fwd_batch_size_data = + dev_ctx.Alloc(&fwd_batch_size, fwd_batch_size.numel() * sizeof(int64_t)); + // pos, temp pos + Tensor pos, temp_pos; + pos.Resize({{out_batch_size}}); + temp_pos.Resize({{out_batch_size}}); + auto* pos_data = dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); + auto* temp_pos_data = dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); + // cumsum + Tensor lec_cum; + lec_cum.Resize({{tot_expert}}); + auto* lec_cum_data = dev_ctx.Alloc(&lec_cum, lec_cum.numel() * sizeof(int64_t)); + // fused moe ffn tmp out + Tensor index_select_out; + index_select_out.Resize({{out_batch_size, d_model}}); + auto* index_select_out_data = dev_ctx.Alloc(&index_select_out, + index_select_out.numel() * sizeof(T)); + Tensor global_gather_out; + global_gather_out.Resize({{out_batch_size, d_model}}); + auto* global_gather_out_data = dev_ctx.Alloc(&global_gather_out, + global_gather_out.numel() * sizeof(T)); + Tensor moe_gather_out; + moe_gather_out.Resize({{out_batch_size, d_model}}); + auto* moe_gather_out_data = dev_ctx.Alloc(&moe_gather_out, + moe_gather_out.numel() * sizeof(T)); + Tensor bmm_out; + bmm_out.Resize({{sliced_bsz_seq, 1, d_model}}); + auto* bmm_out_data = dev_ctx.Alloc(&bmm_out, bmm_out.numel() * sizeof(T)); + Tensor all_gather_out; + all_gather_out.Resize({{bsz_seq, d_model}}); + auto* all_gather_out_data = + dev_ctx.Alloc(&all_gather_out, all_gather_out.numel() * sizeof(T)); + DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + + // step1 layer norm + if (pre_layer_norm) { + pre_layernorm_helper.LayerNorm(dev_ctx, + x->data(), + ln_scale_ptr, + ln_bias_ptr, + ln_out_data, + ln_mean_data, + ln_variance_data); + } else { + ln_out = *x; + } + // step2 resize and slice ln_out + ln_out.Resize({{bsz_seq, d_model}}); + if (mp_size > 1) { + sliced_inp = ln_out.Slice(start, end); + } else { + sliced_inp = ln_out; + } + // step3 gate & topk + gate_linear_compute.ComputeForward(gate_weight, &sliced_inp, gate_bias, &gate_out, &gate_out); + phi::TopkKernel(dev_ctx, + gate_out, + phi::Scalar(topk), + -1, + true, + false, + &topk_value, + &topk_idx); + // step4 prepare forward + // step4.1 number count + NumberCountCompute(dev_ctx, &topk_idx, tot_expert, &local_expert_count); + // step4.2 all_to_all + if (world_size > 1) { + AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); + } else { + global_expert_count = local_expert_count; + } + // global expert count resize + global_expert_count.Resize({{world_size, num_expert}}); + // fwd expert count + phi::SumKernel(dev_ctx, + global_expert_count, + phi::IntArray({0}), + global_expert_count.dtype(), + false, + &fwd_expert_count); + // fwd batch size + phi::SumKernel(dev_ctx, + fwd_expert_count, + phi::IntArray({}), // axis is None + fwd_expert_count.dtype(), + false, + &fwd_batch_size); + // step4.3 cumsum & assign pos + phi::CumsumKernel(dev_ctx, + local_expert_count, + phi::Scalar(0), + false, + false, + false, + &lec_cum); + AssignPosCompute(dev_ctx, &lec_cum, &topk_idx, &pos); + if (topk > 1) { + Tensor topk_tensor; + topk_tensor.Resize({{1}}); + auto *topk_tensor_data = dev_ctx.Alloc(&topk_tensor, topk_tensor.numel() * sizeof(int64_t)); + phi::FullKernel(dev_ctx, {1}, topk, pos.dtype(), &topk_tensor); + phi::FloorDivideKernel(dev_ctx, + pos, + topk_tensor, + &temp_pos); + } else { + temp_pos = pos; + } + Tensor fwd_expert_count_cpu; + framework::TensorCopySync(fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); + Tensor fwd_batch_size_cpu; + framework::TensorCopySync(fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + int fwd_bsz = fwd_batch_size_cpu.data()[0]; + + Tensor global_scatter_out; + global_scatter_out.Resize({{fwd_bsz, d_model}}); + auto* global_scatter_out_data = dev_ctx.Alloc(&global_scatter_out, + global_scatter_out.numel() * sizeof(T)); + std::vector tmp_expert_out; + Tensor all_expert_out; + all_expert_out.Resize({{fwd_bsz, d_model}}); + auto* all_expert_out_data = dev_ctx.Alloc(&all_expert_out, + all_expert_out.numel() * sizeof(T)); + // step 5, MOEScatter + // step 5.1, index select + // suppose tmp_pos->shape != [0] + phi::IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 5.2, global_scatter + if (map->has(moe_ring_id)) { + GlobalScatterProcessGroupFunctor functor_; + functor_(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } else { + GlobalScatterFunctor functor_; + functor_(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } + } else { + global_scatter_out = index_select_out; + } + // step 6, Expert Computation + if (global_scatter_out.dims()[0] != 0) { + int last_index = 0; + for (int idx = 0; idx < num_expert; idx++) { + int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + if (cur_expert_count <= 0) { + continue; + } + int end = cur_expert_count + last_index; + Tensor expert_out1; + expert_out1.Resize({{cur_expert_count, dim_feedforward}}); + auto *expert_out1_data = dev_ctx.Alloc(&expert_out1, + expert_out1.numel() * sizeof(T)); + Tensor act_bias_out; + act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); + auto *act_bias_out_data = dev_ctx.Alloc(&act_bias_out, + act_bias_out.numel() * sizeof(T)); + Tensor expert_out2; + expert_out2.Resize({{cur_expert_count, d_model}}); + auto *expert_out2_data = dev_ctx.Alloc(&expert_out2, + expert_out2.numel() * sizeof(T)); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, cur_expert_count, dim_feedforward, dropout_param); + + Tensor tmp_inp = global_scatter_out.Slice(last_index, end); + // linear1 matmul + MatMulAndAdd(dev_ctx, + experts_weight1[idx], + &tmp_inp, + nullptr, + false, + false, + false, // dont compute bias + &expert_out1, + nullptr); + // bias gelu + fused_act_dropout_helper.DropoutActBias(dev_ctx, + expert_out1.data(), + experts_bias1[idx]->data(), + "gelu", + act_bias_out.data(), + nullptr); + // linear2 matmul & add + MatMulAndAdd(dev_ctx, + experts_weight2[idx], + &act_bias_out, + experts_bias2[idx], + false, + false, + true, // compute bias + &expert_out2, + &expert_out2); + tmp_expert_out.emplace_back(expert_out2); + last_index = end; + } + phi::funcs::ConcatFunctor concat; + concat(dev_ctx, tmp_expert_out, 0, &all_expert_out); + } else { + all_expert_out = global_scatter_out; + } + // step7. MOEGather + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 7.1, global_gather + if (map->has(moe_ring_id)) { + GlobalGatherProcessGroupFunctor functor_; + functor_(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } else { + GlobalGatherFunctor functor_; + functor_(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } + } else { + global_gather_out = all_expert_out; + } + // step 7.2, local_gather or scatter + // suppose pos->shape != [0] + phi::ScatterKernel(dev_ctx, + moe_gather_out, + pos, + global_gather_out, + true, + &moe_gather_out); + // step 8, reshape & bmm + if (topk > 1) { + // moe gather out reshape + moe_gather_out.Resize({{sliced_bsz_seq, topk, d_model}}); + topk_value.Resize({{sliced_bsz_seq, 1, topk}}); + phi::BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); + bmm_out.Resize({{sliced_bsz_seq, d_model}}); + } else { + bmm_out = moe_gather_out; + } + // step 9, AllGather + if (mp_size > 1) { + // all gather + AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); + } else { + all_gather_out = bmm_out; + } + // step 10, reshape + all_gather_out.Resize(x_dim); + // step 11, add residual + phi::AddKernel(dev_ctx, all_gather_out, *x, out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fused_moe, + ops::FusedMoeOpKernel, + ops::FusedMoeOpKernel, + ops::FusedMoeOpKernel); diff --git a/paddle/fluid/operators/fused/fused_moe_op.h b/paddle/fluid/operators/fused/fused_moe_op.h new file mode 100644 index 0000000000000..5ebfbff589850 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_moe_op.h @@ -0,0 +1,718 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" +#include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/matmul_v2_op.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/top_k_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/elementwise_kernel.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/functors.h" +#include "paddle/phi/kernels/index_select_kernel.h" +#include "paddle/phi/kernels/scatter_kernel.h" +#include "paddle/fluid/operators/collective/global_scatter_op.h" +#include "paddle/fluid/operators/collective/global_gather_op.h" +#include "paddle/phi/kernels/bmm_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +// #include "paddle/fluid/framework/convert_utils.h" +// #include "paddle/fluid/platform/float16.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { +// number count +#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1) +#define PERTHREAD_EXPERTS 256 +#define WARP_SIZE 32 + +const int CUDA_NUM_THREADS = 512; +static inline int GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__global__ void initialize_zero_kernel(T* data, const int length) { + CUDA_KERNEL_LOOP(idx, length) { data[idx] = static_cast(0); } +} + +template +__global__ void NumberCount(const T* numbers, + T* number_count, + int64_t batch_size, + int upper_range) { + int res_tmp[PERTHREAD_EXPERTS] = {0}; + int expert_min = blockIdx.x * PERTHREAD_EXPERTS; + int expert_max = expert_min + PERTHREAD_EXPERTS; + if (expert_max > upper_range) { + expert_max = upper_range; + } + for (int i = threadIdx.x; i < batch_size; i += blockDim.x) { + T idx = numbers[i]; + if (idx == -1) { + continue; + } + if (idx < expert_min || idx >= expert_max) { + continue; + } + res_tmp[idx - expert_min] += 1; + } + for (int i = expert_min; i < expert_max; ++i) { + int x = res_tmp[i - expert_min]; +#pragma unroll + for (int j = 1; j < WARP_SIZE; j <<= 1) { +#ifdef __HIPCC__ + x = x + __shfl_down(x, j); +#else + x = x + __shfl_down_sync(-1u, x, j); +#endif + } + if (threadIdx.x % WARP_SIZE == 0) { + platform::CudaAtomicAdd(number_count + i, x); + } + } +} + +template +void NumberCountCompute(const phi::GPUContext &dev_ctx, + framework::Tensor* numbers, + int upper_range, + framework::Tensor* out) { + int64_t batch_size = numbers->numel(); + auto place = dev_ctx.GetPlace(); + + framework::DDim out_dims = phi::make_ddim({upper_range}); + auto out_data = out->mutable_data(out_dims, place); + const T* gate_data = numbers->data(); + + initialize_zero_kernel + <<>>( + out_data, upper_range); + + NumberCount + <<>>( + gate_data, out_data, batch_size, upper_range); +} + +// assign pos +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void AssignPos(T* cum_count, + const T* numbers, + T* out, + int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + int number_idx = numbers[i]; + if (number_idx > -1) { + int p = platform::CudaAtomicAdd(cum_count + number_idx, -1); + out[p - 1] = i; + } + } +} + +template +void AssignPosCompute(const phi::GPUContext &dev_ctx, + framework::Tensor* cum_count, // (counter number) int32 | int64 + framework::Tensor* numbers, // (batch_size * seq_len, topk) int32 + framework::Tensor* out) { + auto place = dev_ctx.GetPlace(); + auto numel = numbers->numel(); + T* cum_data = const_cast(cum_count->data()); + auto cum_size = cum_count->numel(); + + framework::Tensor cpu_cum_count; + int64_t cpu_eff_num_len_data = 0; + if (platform::is_cpu_place(cum_count->place())) { + cpu_eff_num_len_data = cum_count->data()[cum_size - 1]; + } else { + framework::TensorCopySync( + *cum_count, platform::CPUPlace(), &cpu_cum_count); + cpu_eff_num_len_data = cpu_cum_count.data()[cum_size - 1]; + } + + framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data}); + auto out_data = out->mutable_data(out_dims, place); + + const T* num_data = numbers->data(); + + int blocks = NumBlocks(numel); + int threads = kNumCUDAThreads; + + AssignPos<<>>( + cum_data, num_data, out_data, numel); +} + +template +struct GlobalScatterFunctor { + void operator()(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + // auto x = ctx.Input("X"); + // auto local_count = ctx.Input("local_count"); + // auto global_count = ctx.Input("global_count"); + auto local_count_type = + framework::TransToProtoVarType(local_count->dtype()); + auto global_count_type = + framework::TransToProtoVarType(global_count->dtype()); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } + // auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + } + auto global_count_len = 0; + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + global_count_len = cpu_global_count.numel(); + } + + ncclDataType_t dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); + + // int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global scatter op must be non-negative.", + ring_id)); + + auto place = ctx.GetPlace(); + // HARD CODE HERE! + // auto place = platform::CUDAPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + gpuStream_t stream = nullptr; + if (use_calc_stream) { + // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + // stream = static_cast(dev_ctx)->stream(); + stream = ctx.stream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + int64_t fwd_count = 0; + + for (auto i = 0; i < global_count_len; ++i) { + fwd_count += cpu_global_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + auto send_buf = x->data(); + auto recv_buf = out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclSend(send_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + } + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclRecv(recv_buf + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + recv_ptr += cpu_global_count_data[idx]; + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +template +struct GlobalScatterProcessGroupFunctor { + void operator()(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + // auto x = ctx.Input("X"); + // auto local_count = ctx.Input("local_count"); + // auto global_count = ctx.Input("global_count"); + auto local_count_type = + framework::TransToProtoVarType(local_count->dtype()); + auto global_count_type = + framework::TransToProtoVarType(global_count->dtype()); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } + // auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + } + auto global_count_len = 0; + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + global_count_len = cpu_global_count.numel(); + } + + // int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global scatter op must be non-negative.", + ring_id)); + + auto place = ctx.GetPlace(); + // HARD CODE HERE! + // auto place = platform::CUDAPlace(); + + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + distributed::ProcessGroup* pg = map->get(ring_id); + int nranks = pg->GetSize(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + int64_t fwd_count = 0; + + for (auto i = 0; i < global_count_len; ++i) { + fwd_count += cpu_global_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + phi::DenseTensor tmp = *x; + pg->Send_Partial(tmp, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat); + } + if (cpu_global_count_data[idx]) { + pg->Recv_Partial(*out, + j, + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat); + recv_ptr += cpu_global_count_data[idx]; + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +template +struct GlobalGatherFunctor { + void operator()(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + // auto x = ctx.Input("X"); + // auto local_count = ctx.Input("local_count"); + // auto global_count = ctx.Input("global_count"); + auto local_count_type = + framework::TransToProtoVarType(local_count->dtype()); + auto global_count_type = + framework::TransToProtoVarType(global_count->dtype()); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } + // auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + auto local_count_len = 0; + + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + local_count_len = cpu_local_count.numel(); + } + + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + } + + ncclDataType_t dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); + + // int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + auto place = ctx.GetPlace(); + // auto place = platform::CUDAPlace(); + + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + gpuStream_t stream = nullptr; + if (use_calc_stream) { + // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + // stream = static_cast(dev_ctx)->stream(); + stream = ctx.stream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + auto fwd_count = 0; + + for (auto i = 0; i < local_count_len; ++i) { + fwd_count += cpu_local_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + auto send_buf = x->data(); + auto recv_buf = out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclSend(send_buf + send_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + send_ptr += cpu_global_count_data[idx]; + } + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclRecv(recv_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +template +struct GlobalGatherProcessGroupFunctor { + void operator()(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + // auto x = ctx.Input("X"); + // auto local_count = ctx.Input("local_count"); + // auto global_count = ctx.Input("global_count"); + auto local_count_type = + framework::TransToProtoVarType(local_count->dtype()); + auto global_count_type = + framework::TransToProtoVarType(global_count->dtype()); + if (local_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in local_count.")); + } + if (global_count_type != framework::proto::VarType::INT64) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Please use int64 type in global_count.")); + } + // auto out = ctx.Output("Out"); + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + auto local_count_len = 0; + + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + local_count_len = cpu_local_count.numel(); + } + + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + } + + // int ring_id = ctx.Attr("ring_id"); + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + auto place = ctx.GetPlace(); + // auto place = platform::CUDAPlace(); + + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + distributed::ProcessGroup* pg = map->get(ring_id); + + int nranks = pg->GetSize(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + auto fwd_count = 0; + + for (auto i = 0; i < local_count_len; ++i) { + fwd_count += cpu_local_count_data[i]; + } + framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + out->mutable_data(out_dims, place); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + phi::DenseTensor tmp = *x; + pg->Send_Partial( + tmp, j, send_ptr * in_feat, cpu_global_count_data[idx] * in_feat); + send_ptr += cpu_global_count_data[idx]; + } + if (cpu_local_count_data[idx]) { + pg->Recv_Partial(*out, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat); + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif + } +}; + +template +void MatMulAndAdd(const phi::GPUContext& dev_ctx, + const framework::Tensor* weight, + const framework::Tensor* input, + const framework::Tensor* bias, + bool istransA, + bool istransB, + bool compute_bias, + framework::Tensor* output, + framework::Tensor* bias_out) { + // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. + // here: (transa, transb): nt, input * weight. + CBLAS_TRANSPOSE transA = istransA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE transB = istransB ? CblasTrans : CblasNoTrans; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + + // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.GEMM(transA, + transB, + input->dims()[0], + weight->dims()[1], + input->dims()[1], + alpha, + input->data(), + weight->data(), + beta, + output->data()); + if (compute_bias) { + // bias_out = output + bias + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); + } +} + +} // namesapce operators +} //namespace paddle \ No newline at end of file diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index af080bd0b3431..9f0ca6b62d3cb 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -71,6 +71,16 @@ std::map> op_ins_map = { "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, + {"fused_moe", + {"X", + "GateWeight", + "GateBias", + "LnScale", + "LnBias", + "ExpertsWeight1", + "ExpertsBias1", + "ExpertsWeight2", + "ExpertsBias2"}}, {"fused_multi_transformer_int8", {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep", "SrcMask", @@ -335,6 +345,7 @@ std::map> op_outs_map = { "Beta2PowOut", "MasterParamOut"}}, {"fused_multi_transformer", {"CacheKVOut", "Out"}}, + {"fused_moe", {"Out"}}, {"fused_multi_transformer_int8", {"CacheKVOut", "Out"}}, {"resnet_basic_block", {"Y", "Conv1", "SavedMean1", "SavedInvstd1", "Mean1Out", diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index b2767b1dd1cbf..677bbc0c2b288 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -166,6 +166,7 @@ def _update_list(self): 'concat', 'split', 'fused_feedforward', + 'fused_moe', 'fused_attention', 'fused_multi_transformer', } diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index b23c94c7e4994..9b0f5fa1006aa 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -112,6 +112,8 @@ def _keep_fp32_input(op, in_name): } if op_type == 'fused_multi_transformer': return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'} + if op_type == 'fused_moe': + return in_name in {'LnScale', 'LnBias'} return False diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index decaf45125750..98efe71421cf4 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -186,13 +186,13 @@ def pure_fp16_initialize(models): if (layer._dtype == 'float16') or isinstance( layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D, paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D, - paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm, - paddle.nn.ParameterList)): + paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm, paddle.nn.ParameterList)): # tianyan01 add paddle.nn.ParameterList, hack continue if isinstance(layer, (paddle.incubate.nn.FusedFeedForward, paddle.incubate.nn.FusedMultiHeadAttention, - paddle.incubate.nn.FusedMultiTransformer)): + paddle.incubate.nn.FusedMultiTransformer, + paddle.incubate.nn.FusedMoELayer)): layer._amp_decorate(dtype='float16') continue layer._to_impl(dtype='float16', diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index cf15ee7d8ffaa..2a2def22bb3bf 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -18,6 +18,7 @@ from .layer.fused_transformer import FusedMultiTransformer # noqa: F401 from .layer.fused_linear import FusedLinear # noqa: F401 from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401 +from .layer.fused_transformer import FusedMoELayer # tianyan01 add __all__ = [ #noqa 'FusedMultiHeadAttention', @@ -26,4 +27,5 @@ 'FusedMultiTransformer', 'FusedLinear', 'FusedBiasDropoutResidualLayerNorm', + 'FusedMoELayer', ] diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index d71086fe7b07d..8c2053c4bd47c 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -16,6 +16,8 @@ from paddle.nn import Layer from paddle.framework import ParamAttr import paddle +import paddle.nn as nn +from paddle import _legacy_C_ops from paddle.nn import ParameterList from paddle.nn.layer.transformer import ( _convert_attention_mask, @@ -1454,4 +1456,170 @@ def trans_to_fp16(l): trans_to_fp16(self.ffn1_biases) trans_to_fp16(self.ffn2_weights) trans_to_fp16(self.ffn2_biases) + self._dtype = dtype + + +class FusedMoELayer(Layer): + """FusedMoE Layer + Args: + d_model: (int) model dimention + num_expert: (int) expert count + top_k: (int) top-k number + some weights and bias... + moe_group: moe group for experts communication + mp_group: mp group for mp commutication + Examples: + .. code-block:: python + # required: gpu + import paddle + from paddle.incubate.nn import FusedMoELayer + + # input: [batch_size, src_len, d_model] + input = paddle.rand((2, 4, 128)) + # dim_feedforward = 128 + fused_moe_layer = FusedMoELayer(128, 128, 4, 2) + output = fused_moe_layer(input) # [2, 4, 128] + + """ + + def __init__(self, + d_model, + dim_feedforward, + num_expert, + top_k, + ln_scale=None, + ln_bias=None, + gate_weight=None, + gate_bias=None, + linear1_weights=None, + linear1_biases=None, + linear2_weights=None, + linear2_biases=None, + moe_group=None, + mp_group=None): + super(FusedMoELayer, self).__init__() + # only support mp/dp + self.group = moe_group + + self.world_size = 1 + if self.group is not None: + self.world_size = self.group.nranks + self.num_expert = num_expert + + self.mp_group = mp_group + self.mp_rank = 0 + self.mp_size = 1 + if mp_group is not None and mp_group.nranks > 1: + self.mp_rank = mp_group.rank + self.mp_size = mp_group.nranks + self.d_model = d_model + self.top_k = top_k + self.ln_scale = self.create_parameter( + shape=[d_model], + attr=None, + is_bias=False + ) + self.ln_bias = self.create_parameter( + shape=[d_model], attr=None, is_bias=True + ) + self.gate_weight = self.create_parameter( + shape=[d_model, num_expert * self.world_size], + attr=None, + dtype=self._dtype, + is_bias=False + ) + self.gate_bias = self.create_parameter( + shape=[num_expert * self.world_size], + attr=None, + dtype=self._dtype, + is_bias=True + ) + + self.linear1_weights = ParameterList() + self.linear2_weights = ParameterList() + self.linear1_biases = ParameterList() + self.linear2_biases = ParameterList() + def get_attr(attrs, idx): + if isinstance(attrs, (list, tuple, ParameterList)): + assert len(attrs) == num_expert + return attrs[idx] + return attrs + for i in range(num_expert): + w1 = get_attr(linear1_weights, i) + b1 = get_attr(linear1_biases, i) + w2 = get_attr(linear2_weights, i) + b2 = get_attr(linear2_biases, i) + + self.linear1_weights.append(self.create_parameter( + shape=[d_model, dim_feedforward], + attr=w1, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + )) + self.linear2_weights.append(self.create_parameter( + shape=[dim_feedforward, d_model], + attr=w2, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + )) + self.linear1_biases.append(self.create_parameter( + shape=[dim_feedforward], + attr=b1, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + )) + self.linear2_biases.append(self.create_parameter( + shape=[d_model], + attr=b2, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + )) + + def forward(self, inp): + inp = _legacy_C_ops.fused_moe( + inp, + self.gate_weight, + self.gate_bias, + self.ln_scale, + self.ln_bias, + list(self.linear1_weights), + list(self.linear1_biases), + list(self.linear2_weights), + list(self.linear2_biases), + 'pre_layer_norm', + True, + 'ln_epsilon', + 1e-5, + 'topk', + self.top_k, + 'mp_size', + self.mp_size, + 'mp_rank', + self.mp_rank, + 'num_expert', + self.num_expert, + 'world_size', + self.world_size, + 'moe_ring_id', + -1 if self.group is None else self.group.id + ) + return inp + + def _amp_decorate(self, dtype): + # tmp fix for amp.decorator(O2) + def trans_to_fp16(l): + for param in l: + if param is not None: + with paddle.no_grad(): + param_applied = _to_dtype(param, dtype) + trans_to_fp16(self.linear1_weights) + trans_to_fp16(self.linear1_biases) + trans_to_fp16(self.linear2_weights) + trans_to_fp16(self.linear2_biases) + _ = _to_dtype(self.gate_weight, dtype) + _ = _to_dtype(self.gate_bias, dtype) self._dtype = dtype \ No newline at end of file From 497b508d127acd98c37241555e42ff54b8092e5c Mon Sep 17 00:00:00 2001 From: tianyan01 Date: Wed, 13 Sep 2023 14:55:48 +0800 Subject: [PATCH 3/9] mv number_count and fused_moe op to phi, fix some fused_moe bug --- paddle/fluid/operators/fused/CMakeLists.txt | 2 - paddle/fluid/operators/fused/fmha_ref.h | 2 +- .../operators/fused/fused_dropout_helper.h | 72 +- paddle/fluid/operators/fused/fused_moe_op.cc | 132 ---- paddle/fluid/operators/fused/fused_moe_op.cu | 501 ------------ paddle/fluid/operators/fused/fused_moe_op.h | 718 ------------------ paddle/fluid/pybind/op_function_generator.h | 11 - paddle/phi/api/yaml/ops.yaml | 18 + paddle/phi/infermeta/multiary.cc | 30 + paddle/phi/infermeta/multiary.h | 23 + paddle/phi/kernels/fused_moe_kernel.h | 502 ++++++++++++ paddle/phi/kernels/gpu/fused_moe_kernel.cu | 484 ++++++++++++ paddle/phi/kernels/gpu/fused_moe_kernel.cu.h | 63 ++ paddle/phi/kernels/gpu/number_count_kernel.cu | 96 +++ paddle/phi/kernels/number_count_kernel.h | 36 + .../contrib/mixed_precision/fp16_lists.py | 1 - .../contrib/mixed_precision/fp16_utils.py | 2 - .../incubate/nn/layer/fused_transformer.py | 29 +- 18 files changed, 1316 insertions(+), 1406 deletions(-) delete mode 100644 paddle/fluid/operators/fused/fused_moe_op.cc delete mode 100644 paddle/fluid/operators/fused/fused_moe_op.cu delete mode 100644 paddle/fluid/operators/fused/fused_moe_op.h create mode 100644 paddle/phi/kernels/fused_moe_kernel.h create mode 100644 paddle/phi/kernels/gpu/fused_moe_kernel.cu create mode 100644 paddle/phi/kernels/gpu/fused_moe_kernel.cu.h create mode 100644 paddle/phi/kernels/gpu/number_count_kernel.cu create mode 100644 paddle/phi/kernels/number_count_kernel.h diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index b1676883ff39b..fe82565bc36f3 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -23,7 +23,6 @@ register_operators( fused_transformer_op fused_feedforward_op fused_multi_transformer_op - fused_moe_op fused_multi_transformer_int8_op fused_bias_dropout_residual_layer_norm_op resnet_unit_op @@ -121,7 +120,6 @@ if(WITH_GPU OR WITH_ROCM) # fused_attention_op op_library(fused_attention_op) op_library(fused_multi_transformer_op) - op_library(fused_moe_op) op_library(fused_multi_transformer_int8_op) op_library(fused_bias_dropout_residual_layer_norm_op) endif() diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 863ab8cba964b..46153d980409d 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -117,7 +117,7 @@ class FMHARef { out_seq_len = cache_kv_out_tensor->dims()[3]; } else { if (cache_kv_out_tensor) { - *cache_kv_out_tensor = transpose_2_out_tensor->Slice(1, 3); + *cache_kv_out_tensor = transpose_2_out_tensor->Slice(1, 3); } } diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index 2d1491fefb07e..e49ea8861823e 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -212,31 +212,57 @@ class FusedDropoutHelper { const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, - const float quant_min_bound = -127.0) { + const float quant_min_bound = -127.0, + bool approximate = false) { auto increment = GetIncrement(ctx); if (act_method == "gelu") { - GeluFunctor gelu; - LaunchDropoutActBias, InType, OutType>( - gelu, - dropout_param_.seed, - rows_, - cols_, - dropout_param_.increment, - dropout_param_.dropout_prob, - dropout_param_.is_upscale_in_train, - dropout_param_.is_test, - src, - bias, - out, - mask, - ctx, - quant_last_in_scale, - dequant_out_scale_data, - quant_out_scale_offset, - quant_next_in_scale, - quant_round_type, - quant_max_bound, - quant_min_bound); + if (approximate) { + phi::funcs::GeluFunctor gelu; + LaunchDropoutActBias, InType, OutType>( + gelu, + dropout_param_.seed, + rows_, + cols_, + dropout_param_.increment, + dropout_param_.dropout_prob, + dropout_param_.is_upscale_in_train, + dropout_param_.is_test, + src, + bias, + out, + mask, + ctx, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + GeluFunctor gelu; + LaunchDropoutActBias, InType, OutType>( + gelu, + dropout_param_.seed, + rows_, + cols_, + dropout_param_.increment, + dropout_param_.dropout_prob, + dropout_param_.is_upscale_in_train, + dropout_param_.is_test, + src, + bias, + out, + mask, + ctx, + quant_last_in_scale, + dequant_out_scale_data, + quant_out_scale_offset, + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); + } } else if (act_method == "relu") { phi::funcs::ReluFunctor relu; LaunchDropoutActBias -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -class FusedMoeOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext *context) const override { - // input - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "fused_moe"); - OP_INOUT_CHECK(context->HasInput("GateWeight"), - "Input", - "GateWeight", - "fused_moe"); - OP_INOUT_CHECK(context->HasInput("GateBias"), - "Input", - "GateBias", - "fused_moe"); - OP_INOUT_CHECK(context->HasInput("LnScale"), - "Input", - "LnScale", - "fused_moe"); - OP_INOUT_CHECK(context->HasInput("LnBias"), - "Input", - "LnBias", - "fused_moe"); - OP_INOUT_CHECK(context->HasInputs("ExpertsWeight1"), - "Input", - "ExpertsWeight1", - "fused_moe"); - OP_INOUT_CHECK(context->HasInputs("ExpertsBias1"), - "Input", - "ExpertsBias1", - "fused_moe"); - OP_INOUT_CHECK(context->HasInputs("ExpertsWeight2"), - "Input", - "ExpertsWeight2", - "fused_moe"); - OP_INOUT_CHECK(context->HasInputs("ExpertsBias2"), - "Input", - "ExpertsBias2", - "fused_moe"); - // output - OP_INOUT_CHECK(context->HasOutput("Out"), - "Output", - "Out", - "fused_moe"); - auto x_dims = context->GetInputDim("X"); - context->SetOutputDim("Out", x_dims); - } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } -}; - -class FusedMoeOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - // AsDispensable 可有可无 - // AsDuplicable 可复制 - // input - AddInput("X", "The input of FusedMoe op"); - AddInput("GateWeight", "The gate weight of FusedMoe op"); - AddInput("GateBias", "The gate bias of FusedMoe op"); - AddInput("LnScale", "The ln scale of FusedMoe op"); - AddInput("LnBias", "The LnBias of FusedMoe op"); - AddInput("ExpertsWeight1", "The expert linear1 weights of fused_moe op") - .AsDuplicable(); - AddInput("ExpertsBias1", "The expert linear1 biases of fused_moe op") - .AsDuplicable() - .AsDispensable(); - AddInput("ExpertsWeight2", "The expert linear2 weights of fused_moe op") - .AsDuplicable(); - AddInput("ExpertsBias2", "The expert linear2 biases of fused_moe op") - .AsDuplicable() - .AsDispensable(); - // output - AddOutput("Out", "Out"); - // attr - AddAttr("pre_layer_norm", "pre_layer_norm").SetDefault(true); - AddAttr("ln_epsilon", "ln_epsilon").SetDefault(1e-5f); - AddAttr("topk", "top k in gate").SetDefault(2); - AddAttr("mp_size", "mp_size").SetDefault(1); - AddAttr("mp_rank", "mp_rank").SetDefault(0); - AddAttr("num_expert", "num_expert").SetDefault(1); - AddAttr("world_size", "world_size").SetDefault(1); - AddAttr("moe_ring_id", "moe_ring_id").SetDefault(-1); - AddComment(R"DOC( - The fused_moe operator is the same as the following pseudo codes: - - pass - - )DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(fused_moe, - ops::FusedMoeOp, - ops::FusedMoeOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_moe_op.cu b/paddle/fluid/operators/fused/fused_moe_op.cu deleted file mode 100644 index 0478c6cf551bc..0000000000000 --- a/paddle/fluid/operators/fused/fused_moe_op.cu +++ /dev/null @@ -1,501 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/fused/fused_moe_op.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -template -static void AllToAll(Tensor& tensor, // NOLINT - Tensor& out, - const int ring_id, - const phi::GPUContext& ctx) { - if (ring_id == -1) return; -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - - if (map->has(ring_id)) { - paddle::distributed::ProcessGroup* pg = map->get(ring_id); - auto pg_nccl = static_cast(pg); - - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(out); - auto task = pg_nccl->AllToAll(in_tensor, out_tensor, true, true); - task->Wait(); - } else { - auto dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(tensor.dtype())); - int64_t send_numel = tensor.numel(); // send_numel - auto place = ctx.GetPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - int nranks = comm->nranks(); - auto stream = ctx.stream(); - - framework::DDim x_dims = tensor.dims(); - framework::DDim out_dims(x_dims); - PADDLE_ENFORCE_EQ( - x_dims[0] % nranks, - 0, - platform::errors::InvalidArgument( - "The first dimension size (%d) of the input tensor must be " - "divisible by the number of ranks (%d).", - x_dims[0], - nranks)); - auto send_buf = tensor.data(); - auto recv_buf = out.mutable_data(out_dims, place); - size_t offset = 0; - send_numel /= nranks; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto i = 0; i < nranks; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( - send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( - recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); - offset += send_numel; - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - } -#else - PADDLE_THROW(platform::errors::Unimplemented( - "PaddlePaddle should compile with NCCL or RCCL when used tensor model " - "parallel op.")); -#endif -} - -template -static void AllGather(Tensor& tensor, // NOLINT - Tensor& out, - const int ring_id, - const phi::GPUContext& ctx) { - if (ring_id == -1) return; -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - - if (map->has(ring_id)) { - paddle::distributed::ProcessGroup* pg = map->get(ring_id); - auto pg_nccl = static_cast(pg); - - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(out); - auto task = pg_nccl->AllGather(in_tensor, out_tensor, true, true); - task->Wait(); - } else { - auto dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(tensor.dtype())); - int64_t numel = tensor.numel(); - auto place = ctx.GetPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - auto stream = ctx.stream(); - auto out_dims = tensor.dims(); - int nranks = comm->nranks(); - out_dims[0] *= nranks; - out.mutable_data(out_dims, place); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( - tensor.data(), out.data(), numel, dtype, comm->comm(), stream)); - } -#else - PADDLE_THROW(platform::errors::Unimplemented( - "PaddlePaddle should compile with NCCL or RCCL when used tensor model " - "parallel op.")); -#endif -} - -template -class FusedMoeOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using U = LayerNormParamType; - auto& dev_ctx = context.cuda_device_context(); - // input - auto* x = context.Input("X"); - auto* gate_weight = context.Input("GateWeight"); - auto* gate_bias = context.Input("GateBias"); - const bool pre_layer_norm = context.Attr("pre_layer_norm"); - auto* ln_scale = - pre_layer_norm ? context.Input("LnScale") : nullptr; - auto* ln_bias = - pre_layer_norm ? context.Input("LnBias") : nullptr; - // linear 1 - auto experts_weight1 = context.MultiInput("ExpertsWeight1"); - auto experts_bias1 = context.MultiInput("ExpertsBias1"); - // linear 2 - auto experts_weight2 = context.MultiInput("ExpertsWeight2"); - auto experts_bias2 = context.MultiInput("ExpertsBias2"); - - // output - auto* out = context.Output("Out"); - dev_ctx.Alloc(out, out->numel() * sizeof(T)); - - // attr - const float epsilon = context.Attr("ln_epsilon"); - const int topk = context.Attr("topk"); - const int mp_size = context.Attr("mp_size"); - const int mp_rank = context.Attr("mp_rank"); - const int num_expert = context.Attr("num_expert"); - const int world_size = context.Attr("world_size"); - const int moe_ring_id = context.Attr("moe_ring_id"); - - // dim - auto x_dim = x->dims(); - int bsz = x_dim[0]; - int seq_len = x_dim[1]; - int bsz_seq = bsz * seq_len; - int d_model = x_dim[2]; - int tot_expert = world_size * num_expert; - int dim_feedforward = experts_weight1[0]->dims()[1]; - - // pre_layer_norm - const U* ln_scale_ptr = - ln_scale == nullptr ? nullptr : ln_scale->data(); - const U* ln_bias_ptr = - ln_bias == nullptr ? nullptr : ln_bias->data(); - Tensor ln_mean, ln_variance; - ln_mean.Resize({{bsz_seq}}); - auto* ln_mean_data = - dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_variance.Resize({{bsz_seq}}); - auto* ln_variance_data = - dev_ctx.Alloc(&ln_variance, ln_variance.numel() * sizeof(U)); - FusedDropoutLayerNormHelper pre_layernorm_helper( - bsz_seq, d_model, epsilon); - // tmp out - Tensor ln_out; - ln_out.Resize({{bsz, seq_len, d_model}}); - auto *ln_out_data = dev_ctx.Alloc(&ln_out, ln_out.numel() * sizeof(T)); - // after slice, bsz_seq should be change - int sliced_bsz_seq = bsz_seq; - int start = 0; - int end = 0; - if (mp_size > 1) { - start = bsz_seq / world_size * mp_rank; - end = std::min(start + bsz_seq / world_size, bsz_seq); - sliced_bsz_seq = end - start; - } - int out_batch_size = sliced_bsz_seq * topk; - // slice - Tensor sliced_inp; - sliced_inp.Resize({{sliced_bsz_seq, d_model}}); - auto* sliced_inp_data = dev_ctx.Alloc(&sliced_inp, sliced_inp.numel() * sizeof(T)); - // gate linear - Tensor gate_out; - gate_out.Resize({{sliced_bsz_seq, tot_expert}}); - auto* gate_out_data = dev_ctx.Alloc(&gate_out, gate_out.numel() * sizeof(T)); - auto gate_linear_compute = AttnMatMul( - dev_ctx, false, false, sliced_bsz_seq, tot_expert, d_model, true); - // topk - Tensor topk_value, topk_idx; - topk_value.Resize({{sliced_bsz_seq, topk}}); - auto* topk_value_data = dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); - topk_idx.Resize({{sliced_bsz_seq, topk}}); - auto* topk_idx_data = dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(int64_t)); - // local expert count, global expert count - Tensor local_expert_count, global_expert_count; - local_expert_count.Resize({{tot_expert}}); - global_expert_count.Resize({{tot_expert}}); - auto* local_expert_count_data = - dev_ctx.Alloc(&local_expert_count, local_expert_count.numel() * sizeof(int64_t)); - auto* global_expert_count_data = - dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); - // fwd_expert_count, fwd_batch_size - Tensor fwd_expert_count, fwd_batch_size; - fwd_expert_count.Resize({{world_size, num_expert}}); - fwd_batch_size.Resize({{1}}); - auto* fwd_expert_count_data = - dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); - auto* fwd_batch_size_data = - dev_ctx.Alloc(&fwd_batch_size, fwd_batch_size.numel() * sizeof(int64_t)); - // pos, temp pos - Tensor pos, temp_pos; - pos.Resize({{out_batch_size}}); - temp_pos.Resize({{out_batch_size}}); - auto* pos_data = dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); - auto* temp_pos_data = dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); - // cumsum - Tensor lec_cum; - lec_cum.Resize({{tot_expert}}); - auto* lec_cum_data = dev_ctx.Alloc(&lec_cum, lec_cum.numel() * sizeof(int64_t)); - // fused moe ffn tmp out - Tensor index_select_out; - index_select_out.Resize({{out_batch_size, d_model}}); - auto* index_select_out_data = dev_ctx.Alloc(&index_select_out, - index_select_out.numel() * sizeof(T)); - Tensor global_gather_out; - global_gather_out.Resize({{out_batch_size, d_model}}); - auto* global_gather_out_data = dev_ctx.Alloc(&global_gather_out, - global_gather_out.numel() * sizeof(T)); - Tensor moe_gather_out; - moe_gather_out.Resize({{out_batch_size, d_model}}); - auto* moe_gather_out_data = dev_ctx.Alloc(&moe_gather_out, - moe_gather_out.numel() * sizeof(T)); - Tensor bmm_out; - bmm_out.Resize({{sliced_bsz_seq, 1, d_model}}); - auto* bmm_out_data = dev_ctx.Alloc(&bmm_out, bmm_out.numel() * sizeof(T)); - Tensor all_gather_out; - all_gather_out.Resize({{bsz_seq, d_model}}); - auto* all_gather_out_data = - dev_ctx.Alloc(&all_gather_out, all_gather_out.numel() * sizeof(T)); - DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); - - // step1 layer norm - if (pre_layer_norm) { - pre_layernorm_helper.LayerNorm(dev_ctx, - x->data(), - ln_scale_ptr, - ln_bias_ptr, - ln_out_data, - ln_mean_data, - ln_variance_data); - } else { - ln_out = *x; - } - // step2 resize and slice ln_out - ln_out.Resize({{bsz_seq, d_model}}); - if (mp_size > 1) { - sliced_inp = ln_out.Slice(start, end); - } else { - sliced_inp = ln_out; - } - // step3 gate & topk - gate_linear_compute.ComputeForward(gate_weight, &sliced_inp, gate_bias, &gate_out, &gate_out); - phi::TopkKernel(dev_ctx, - gate_out, - phi::Scalar(topk), - -1, - true, - false, - &topk_value, - &topk_idx); - // step4 prepare forward - // step4.1 number count - NumberCountCompute(dev_ctx, &topk_idx, tot_expert, &local_expert_count); - // step4.2 all_to_all - if (world_size > 1) { - AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); - } else { - global_expert_count = local_expert_count; - } - // global expert count resize - global_expert_count.Resize({{world_size, num_expert}}); - // fwd expert count - phi::SumKernel(dev_ctx, - global_expert_count, - phi::IntArray({0}), - global_expert_count.dtype(), - false, - &fwd_expert_count); - // fwd batch size - phi::SumKernel(dev_ctx, - fwd_expert_count, - phi::IntArray({}), // axis is None - fwd_expert_count.dtype(), - false, - &fwd_batch_size); - // step4.3 cumsum & assign pos - phi::CumsumKernel(dev_ctx, - local_expert_count, - phi::Scalar(0), - false, - false, - false, - &lec_cum); - AssignPosCompute(dev_ctx, &lec_cum, &topk_idx, &pos); - if (topk > 1) { - Tensor topk_tensor; - topk_tensor.Resize({{1}}); - auto *topk_tensor_data = dev_ctx.Alloc(&topk_tensor, topk_tensor.numel() * sizeof(int64_t)); - phi::FullKernel(dev_ctx, {1}, topk, pos.dtype(), &topk_tensor); - phi::FloorDivideKernel(dev_ctx, - pos, - topk_tensor, - &temp_pos); - } else { - temp_pos = pos; - } - Tensor fwd_expert_count_cpu; - framework::TensorCopySync(fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); - Tensor fwd_batch_size_cpu; - framework::TensorCopySync(fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); - int fwd_bsz = fwd_batch_size_cpu.data()[0]; - - Tensor global_scatter_out; - global_scatter_out.Resize({{fwd_bsz, d_model}}); - auto* global_scatter_out_data = dev_ctx.Alloc(&global_scatter_out, - global_scatter_out.numel() * sizeof(T)); - std::vector tmp_expert_out; - Tensor all_expert_out; - all_expert_out.Resize({{fwd_bsz, d_model}}); - auto* all_expert_out_data = dev_ctx.Alloc(&all_expert_out, - all_expert_out.numel() * sizeof(T)); - // step 5, MOEScatter - // step 5.1, index select - // suppose tmp_pos->shape != [0] - phi::IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); - if (world_size > 1) { - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - // step 5.2, global_scatter - if (map->has(moe_ring_id)) { - GlobalScatterProcessGroupFunctor functor_; - functor_(dev_ctx, - &index_select_out, - &local_expert_count, - &global_expert_count, - moe_ring_id, - true, - &global_scatter_out); - } else { - GlobalScatterFunctor functor_; - functor_(dev_ctx, - &index_select_out, - &local_expert_count, - &global_expert_count, - moe_ring_id, - true, - &global_scatter_out); - } - } else { - global_scatter_out = index_select_out; - } - // step 6, Expert Computation - if (global_scatter_out.dims()[0] != 0) { - int last_index = 0; - for (int idx = 0; idx < num_expert; idx++) { - int cur_expert_count = fwd_expert_count_cpu.data()[idx]; - if (cur_expert_count <= 0) { - continue; - } - int end = cur_expert_count + last_index; - Tensor expert_out1; - expert_out1.Resize({{cur_expert_count, dim_feedforward}}); - auto *expert_out1_data = dev_ctx.Alloc(&expert_out1, - expert_out1.numel() * sizeof(T)); - Tensor act_bias_out; - act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); - auto *act_bias_out_data = dev_ctx.Alloc(&act_bias_out, - act_bias_out.numel() * sizeof(T)); - Tensor expert_out2; - expert_out2.Resize({{cur_expert_count, d_model}}); - auto *expert_out2_data = dev_ctx.Alloc(&expert_out2, - expert_out2.numel() * sizeof(T)); - FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, cur_expert_count, dim_feedforward, dropout_param); - - Tensor tmp_inp = global_scatter_out.Slice(last_index, end); - // linear1 matmul - MatMulAndAdd(dev_ctx, - experts_weight1[idx], - &tmp_inp, - nullptr, - false, - false, - false, // dont compute bias - &expert_out1, - nullptr); - // bias gelu - fused_act_dropout_helper.DropoutActBias(dev_ctx, - expert_out1.data(), - experts_bias1[idx]->data(), - "gelu", - act_bias_out.data(), - nullptr); - // linear2 matmul & add - MatMulAndAdd(dev_ctx, - experts_weight2[idx], - &act_bias_out, - experts_bias2[idx], - false, - false, - true, // compute bias - &expert_out2, - &expert_out2); - tmp_expert_out.emplace_back(expert_out2); - last_index = end; - } - phi::funcs::ConcatFunctor concat; - concat(dev_ctx, tmp_expert_out, 0, &all_expert_out); - } else { - all_expert_out = global_scatter_out; - } - // step7. MOEGather - if (world_size > 1) { - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - // step 7.1, global_gather - if (map->has(moe_ring_id)) { - GlobalGatherProcessGroupFunctor functor_; - functor_(dev_ctx, - &all_expert_out, - &local_expert_count, - &global_expert_count, - moe_ring_id, - true, - &global_gather_out); - } else { - GlobalGatherFunctor functor_; - functor_(dev_ctx, - &all_expert_out, - &local_expert_count, - &global_expert_count, - moe_ring_id, - true, - &global_gather_out); - } - } else { - global_gather_out = all_expert_out; - } - // step 7.2, local_gather or scatter - // suppose pos->shape != [0] - phi::ScatterKernel(dev_ctx, - moe_gather_out, - pos, - global_gather_out, - true, - &moe_gather_out); - // step 8, reshape & bmm - if (topk > 1) { - // moe gather out reshape - moe_gather_out.Resize({{sliced_bsz_seq, topk, d_model}}); - topk_value.Resize({{sliced_bsz_seq, 1, topk}}); - phi::BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); - bmm_out.Resize({{sliced_bsz_seq, d_model}}); - } else { - bmm_out = moe_gather_out; - } - // step 9, AllGather - if (mp_size > 1) { - // all gather - AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); - } else { - all_gather_out = bmm_out; - } - // step 10, reshape - all_gather_out.Resize(x_dim); - // step 11, add residual - phi::AddKernel(dev_ctx, all_gather_out, *x, out); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - fused_moe, - ops::FusedMoeOpKernel, - ops::FusedMoeOpKernel, - ops::FusedMoeOpKernel); diff --git a/paddle/fluid/operators/fused/fused_moe_op.h b/paddle/fluid/operators/fused/fused_moe_op.h deleted file mode 100644 index 5ebfbff589850..0000000000000 --- a/paddle/fluid/operators/fused/fused_moe_op.h +++ /dev/null @@ -1,718 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/fluid/operators/fused/fused_dropout_helper.h" -#include "paddle/fluid/operators/layer_norm_kernel.cu.h" -#include "paddle/fluid/operators/fused/attn_gemm.h" -#include "paddle/fluid/operators/matmul_v2_op.h" -#include "paddle/phi/api/include/tensor.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/broadcast_function.h" -#include "paddle/phi/kernels/funcs/elementwise_functor.h" -#include "paddle/phi/kernels/top_k_kernel.h" -#include "paddle/phi/kernels/cum_kernel.h" -#include "paddle/phi/kernels/reduce_sum_kernel.h" -#include "paddle/phi/kernels/full_kernel.h" -#include "paddle/phi/kernels/elementwise_kernel.h" -#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" -#include "paddle/phi/kernels/funcs/functors.h" -#include "paddle/phi/kernels/index_select_kernel.h" -#include "paddle/phi/kernels/scatter_kernel.h" -#include "paddle/fluid/operators/collective/global_scatter_op.h" -#include "paddle/fluid/operators/collective/global_gather_op.h" -#include "paddle/phi/kernels/bmm_kernel.h" -#include "paddle/phi/kernels/elementwise_add_kernel.h" -// #include "paddle/fluid/framework/convert_utils.h" -// #include "paddle/fluid/platform/float16.h" - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#endif - -namespace paddle { -namespace operators { -// number count -#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1) -#define PERTHREAD_EXPERTS 256 -#define WARP_SIZE 32 - -const int CUDA_NUM_THREADS = 512; -static inline int GET_BLOCKS(const int N) { - return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; -} - -template -__global__ void initialize_zero_kernel(T* data, const int length) { - CUDA_KERNEL_LOOP(idx, length) { data[idx] = static_cast(0); } -} - -template -__global__ void NumberCount(const T* numbers, - T* number_count, - int64_t batch_size, - int upper_range) { - int res_tmp[PERTHREAD_EXPERTS] = {0}; - int expert_min = blockIdx.x * PERTHREAD_EXPERTS; - int expert_max = expert_min + PERTHREAD_EXPERTS; - if (expert_max > upper_range) { - expert_max = upper_range; - } - for (int i = threadIdx.x; i < batch_size; i += blockDim.x) { - T idx = numbers[i]; - if (idx == -1) { - continue; - } - if (idx < expert_min || idx >= expert_max) { - continue; - } - res_tmp[idx - expert_min] += 1; - } - for (int i = expert_min; i < expert_max; ++i) { - int x = res_tmp[i - expert_min]; -#pragma unroll - for (int j = 1; j < WARP_SIZE; j <<= 1) { -#ifdef __HIPCC__ - x = x + __shfl_down(x, j); -#else - x = x + __shfl_down_sync(-1u, x, j); -#endif - } - if (threadIdx.x % WARP_SIZE == 0) { - platform::CudaAtomicAdd(number_count + i, x); - } - } -} - -template -void NumberCountCompute(const phi::GPUContext &dev_ctx, - framework::Tensor* numbers, - int upper_range, - framework::Tensor* out) { - int64_t batch_size = numbers->numel(); - auto place = dev_ctx.GetPlace(); - - framework::DDim out_dims = phi::make_ddim({upper_range}); - auto out_data = out->mutable_data(out_dims, place); - const T* gate_data = numbers->data(); - - initialize_zero_kernel - <<>>( - out_data, upper_range); - - NumberCount - <<>>( - gate_data, out_data, batch_size, upper_range); -} - -// assign pos -static constexpr int kNumCUDAThreads = 512; -static constexpr int kNumMaxinumNumBlocks = 4096; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaxinumNumBlocks); -} - -template -__global__ void AssignPos(T* cum_count, - const T* numbers, - T* out, - int64_t limit) { - CUDA_KERNEL_LOOP(i, limit) { - int number_idx = numbers[i]; - if (number_idx > -1) { - int p = platform::CudaAtomicAdd(cum_count + number_idx, -1); - out[p - 1] = i; - } - } -} - -template -void AssignPosCompute(const phi::GPUContext &dev_ctx, - framework::Tensor* cum_count, // (counter number) int32 | int64 - framework::Tensor* numbers, // (batch_size * seq_len, topk) int32 - framework::Tensor* out) { - auto place = dev_ctx.GetPlace(); - auto numel = numbers->numel(); - T* cum_data = const_cast(cum_count->data()); - auto cum_size = cum_count->numel(); - - framework::Tensor cpu_cum_count; - int64_t cpu_eff_num_len_data = 0; - if (platform::is_cpu_place(cum_count->place())) { - cpu_eff_num_len_data = cum_count->data()[cum_size - 1]; - } else { - framework::TensorCopySync( - *cum_count, platform::CPUPlace(), &cpu_cum_count); - cpu_eff_num_len_data = cpu_cum_count.data()[cum_size - 1]; - } - - framework::DDim out_dims = phi::make_ddim({cpu_eff_num_len_data}); - auto out_data = out->mutable_data(out_dims, place); - - const T* num_data = numbers->data(); - - int blocks = NumBlocks(numel); - int threads = kNumCUDAThreads; - - AssignPos<<>>( - cum_data, num_data, out_data, numel); -} - -template -struct GlobalScatterFunctor { - void operator()(const phi::GPUContext& ctx, - const framework::Tensor* x, - const framework::Tensor* local_count, - const framework::Tensor* global_count, - int ring_id, - bool use_calc_stream, - framework::Tensor* out) { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#if NCCL_VERSION_CODE >= 2703 - // auto x = ctx.Input("X"); - // auto local_count = ctx.Input("local_count"); - // auto global_count = ctx.Input("global_count"); - auto local_count_type = - framework::TransToProtoVarType(local_count->dtype()); - auto global_count_type = - framework::TransToProtoVarType(global_count->dtype()); - if (local_count_type != framework::proto::VarType::INT64) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Please use int64 type in local_count.")); - } - if (global_count_type != framework::proto::VarType::INT64) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Please use int64 type in global_count.")); - } - // auto out = ctx.Output("Out"); - const int64_t* cpu_local_count_data; - const int64_t* cpu_global_count_data; - framework::Tensor cpu_local_count; - if (platform::is_cpu_place(local_count->place())) { - cpu_local_count_data = local_count->data(); - } else { - framework::TensorCopySync( - *local_count, platform::CPUPlace(), &cpu_local_count); - cpu_local_count_data = cpu_local_count.data(); - } - auto global_count_len = 0; - framework::Tensor cpu_global_count; - if (platform::is_cpu_place(global_count->place())) { - cpu_global_count_data = global_count->data(); - global_count_len = global_count->numel(); - } else { - framework::TensorCopySync( - *global_count, platform::CPUPlace(), &cpu_global_count); - cpu_global_count_data = cpu_global_count.data(); - global_count_len = cpu_global_count.numel(); - } - - ncclDataType_t dtype = - platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); - - // int ring_id = ctx.Attr("ring_id"); - PADDLE_ENFORCE_GE( - ring_id, - 0, - platform::errors::InvalidArgument( - "The ring_id (%d) for global scatter op must be non-negative.", - ring_id)); - - auto place = ctx.GetPlace(); - // HARD CODE HERE! - // auto place = platform::CUDAPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - gpuStream_t stream = nullptr; - if (use_calc_stream) { - // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); - // stream = static_cast(dev_ctx)->stream(); - stream = ctx.stream(); - } else { - stream = comm->stream(); - } - int nranks = comm->nranks(); - auto in_feat = x->dims()[1]; - auto n_expert = local_count->dims()[0] / nranks; - int64_t fwd_count = 0; - - for (auto i = 0; i < global_count_len; ++i) { - fwd_count += cpu_global_count_data[i]; - } - framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); - int64_t* expert_ptr = new int64_t[n_expert * nranks]; - expert_ptr[0] = 0; - auto tot_experts = n_expert * nranks; - for (auto i = 1; i < tot_experts; ++i) { - expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; - } - - auto recv_ptr = 0; - auto send_buf = x->data(); - auto recv_buf = out->mutable_data(out_dims, place); - - for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto j = 0; j < nranks; ++j) { - int idx = i + j * n_expert; - if (cpu_local_count_data[idx]) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclSend(send_buf + expert_ptr[idx] * in_feat, - cpu_local_count_data[idx] * in_feat, - dtype, - j, - comm->comm(), - stream)); - } - if (cpu_global_count_data[idx]) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclRecv(recv_buf + recv_ptr * in_feat, - cpu_global_count_data[idx] * in_feat, - dtype, - j, - comm->comm(), - stream)); - recv_ptr += cpu_global_count_data[idx]; - } - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - } - -#else - PADDLE_THROW( - platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); -#endif -#else - PADDLE_THROW( - platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); -#endif - } -}; - -template -struct GlobalScatterProcessGroupFunctor { - void operator()(const phi::GPUContext& ctx, - const framework::Tensor* x, - const framework::Tensor* local_count, - const framework::Tensor* global_count, - int ring_id, - bool use_calc_stream, - framework::Tensor* out) { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#if NCCL_VERSION_CODE >= 2703 - // auto x = ctx.Input("X"); - // auto local_count = ctx.Input("local_count"); - // auto global_count = ctx.Input("global_count"); - auto local_count_type = - framework::TransToProtoVarType(local_count->dtype()); - auto global_count_type = - framework::TransToProtoVarType(global_count->dtype()); - if (local_count_type != framework::proto::VarType::INT64) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Please use int64 type in local_count.")); - } - if (global_count_type != framework::proto::VarType::INT64) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Please use int64 type in global_count.")); - } - // auto out = ctx.Output("Out"); - const int64_t* cpu_local_count_data; - const int64_t* cpu_global_count_data; - framework::Tensor cpu_local_count; - if (platform::is_cpu_place(local_count->place())) { - cpu_local_count_data = local_count->data(); - } else { - framework::TensorCopySync( - *local_count, platform::CPUPlace(), &cpu_local_count); - cpu_local_count_data = cpu_local_count.data(); - } - auto global_count_len = 0; - framework::Tensor cpu_global_count; - if (platform::is_cpu_place(global_count->place())) { - cpu_global_count_data = global_count->data(); - global_count_len = global_count->numel(); - } else { - framework::TensorCopySync( - *global_count, platform::CPUPlace(), &cpu_global_count); - cpu_global_count_data = cpu_global_count.data(); - global_count_len = cpu_global_count.numel(); - } - - // int ring_id = ctx.Attr("ring_id"); - PADDLE_ENFORCE_GE( - ring_id, - 0, - platform::errors::InvalidArgument( - "The ring_id (%d) for global scatter op must be non-negative.", - ring_id)); - - auto place = ctx.GetPlace(); - // HARD CODE HERE! - // auto place = platform::CUDAPlace(); - - auto map = distributed::ProcessGroupMapFromGid::getInstance(); - distributed::ProcessGroup* pg = map->get(ring_id); - int nranks = pg->GetSize(); - auto in_feat = x->dims()[1]; - auto n_expert = local_count->dims()[0] / nranks; - int64_t fwd_count = 0; - - for (auto i = 0; i < global_count_len; ++i) { - fwd_count += cpu_global_count_data[i]; - } - framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); - int64_t* expert_ptr = new int64_t[n_expert * nranks]; - expert_ptr[0] = 0; - auto tot_experts = n_expert * nranks; - for (auto i = 1; i < tot_experts; ++i) { - expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; - } - - auto recv_ptr = 0; - out->mutable_data(out_dims, place); - - for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto j = 0; j < nranks; ++j) { - int idx = i + j * n_expert; - if (cpu_local_count_data[idx]) { - phi::DenseTensor tmp = *x; - pg->Send_Partial(tmp, - j, - expert_ptr[idx] * in_feat, - cpu_local_count_data[idx] * in_feat); - } - if (cpu_global_count_data[idx]) { - pg->Recv_Partial(*out, - j, - recv_ptr * in_feat, - cpu_global_count_data[idx] * in_feat); - recv_ptr += cpu_global_count_data[idx]; - } - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - } - -#ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); -#else - PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); -#endif - -#else - PADDLE_THROW( - platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); -#endif -#else - PADDLE_THROW( - platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); -#endif - } -}; - -template -struct GlobalGatherFunctor { - void operator()(const phi::GPUContext& ctx, - const framework::Tensor* x, - const framework::Tensor* local_count, - const framework::Tensor* global_count, - int ring_id, - bool use_calc_stream, - framework::Tensor* out) { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#if NCCL_VERSION_CODE >= 2703 - // auto x = ctx.Input("X"); - // auto local_count = ctx.Input("local_count"); - // auto global_count = ctx.Input("global_count"); - auto local_count_type = - framework::TransToProtoVarType(local_count->dtype()); - auto global_count_type = - framework::TransToProtoVarType(global_count->dtype()); - if (local_count_type != framework::proto::VarType::INT64) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Please use int64 type in local_count.")); - } - if (global_count_type != framework::proto::VarType::INT64) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Please use int64 type in global_count.")); - } - // auto out = ctx.Output("Out"); - const int64_t* cpu_local_count_data; - const int64_t* cpu_global_count_data; - auto local_count_len = 0; - - framework::Tensor cpu_local_count; - if (platform::is_cpu_place(local_count->place())) { - cpu_local_count_data = local_count->data(); - local_count_len = local_count->numel(); - } else { - framework::TensorCopySync( - *local_count, platform::CPUPlace(), &cpu_local_count); - cpu_local_count_data = cpu_local_count.data(); - local_count_len = cpu_local_count.numel(); - } - - framework::Tensor cpu_global_count; - if (platform::is_cpu_place(global_count->place())) { - cpu_global_count_data = global_count->data(); - } else { - framework::TensorCopySync( - *global_count, platform::CPUPlace(), &cpu_global_count); - cpu_global_count_data = cpu_global_count.data(); - } - - ncclDataType_t dtype = - platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); - - // int ring_id = ctx.Attr("ring_id"); - PADDLE_ENFORCE_GE( - ring_id, - 0, - platform::errors::InvalidArgument( - "The ring_id (%d) for global gather op must be non-negative.", - ring_id)); - auto place = ctx.GetPlace(); - // auto place = platform::CUDAPlace(); - - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - gpuStream_t stream = nullptr; - if (use_calc_stream) { - // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); - // stream = static_cast(dev_ctx)->stream(); - stream = ctx.stream(); - } else { - stream = comm->stream(); - } - int nranks = comm->nranks(); - auto in_feat = x->dims()[1]; - auto n_expert = local_count->dims()[0] / nranks; - - auto fwd_count = 0; - - for (auto i = 0; i < local_count_len; ++i) { - fwd_count += cpu_local_count_data[i]; - } - framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); - int64_t* expert_ptr = new int64_t[n_expert * nranks]; - expert_ptr[0] = 0; - auto tot_experts = n_expert * nranks; - for (auto i = 1; i < tot_experts; ++i) { - expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; - } - auto send_ptr = 0; - auto send_buf = x->data(); - auto recv_buf = out->mutable_data(out_dims, place); - - for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto j = 0; j < nranks; ++j) { - int idx = i + j * n_expert; - if (cpu_global_count_data[idx]) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclSend(send_buf + send_ptr * in_feat, - cpu_global_count_data[idx] * in_feat, - dtype, - j, - comm->comm(), - stream)); - send_ptr += cpu_global_count_data[idx]; - } - if (cpu_local_count_data[idx]) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclRecv(recv_buf + expert_ptr[idx] * in_feat, - cpu_local_count_data[idx] * in_feat, - dtype, - j, - comm->comm(), - stream)); - } - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - } -#else - PADDLE_THROW( - platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); -#endif -#else - PADDLE_THROW( - platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); -#endif - } -}; - -template -struct GlobalGatherProcessGroupFunctor { - void operator()(const phi::GPUContext& ctx, - const framework::Tensor* x, - const framework::Tensor* local_count, - const framework::Tensor* global_count, - int ring_id, - bool use_calc_stream, - framework::Tensor* out) { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#if NCCL_VERSION_CODE >= 2703 - // auto x = ctx.Input("X"); - // auto local_count = ctx.Input("local_count"); - // auto global_count = ctx.Input("global_count"); - auto local_count_type = - framework::TransToProtoVarType(local_count->dtype()); - auto global_count_type = - framework::TransToProtoVarType(global_count->dtype()); - if (local_count_type != framework::proto::VarType::INT64) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Please use int64 type in local_count.")); - } - if (global_count_type != framework::proto::VarType::INT64) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Please use int64 type in global_count.")); - } - // auto out = ctx.Output("Out"); - const int64_t* cpu_local_count_data; - const int64_t* cpu_global_count_data; - auto local_count_len = 0; - - framework::Tensor cpu_local_count; - if (platform::is_cpu_place(local_count->place())) { - cpu_local_count_data = local_count->data(); - local_count_len = local_count->numel(); - } else { - framework::TensorCopySync( - *local_count, platform::CPUPlace(), &cpu_local_count); - cpu_local_count_data = cpu_local_count.data(); - local_count_len = cpu_local_count.numel(); - } - - framework::Tensor cpu_global_count; - if (platform::is_cpu_place(global_count->place())) { - cpu_global_count_data = global_count->data(); - } else { - framework::TensorCopySync( - *global_count, platform::CPUPlace(), &cpu_global_count); - cpu_global_count_data = cpu_global_count.data(); - } - - // int ring_id = ctx.Attr("ring_id"); - PADDLE_ENFORCE_GE( - ring_id, - 0, - platform::errors::InvalidArgument( - "The ring_id (%d) for global gather op must be non-negative.", - ring_id)); - auto place = ctx.GetPlace(); - // auto place = platform::CUDAPlace(); - - auto map = distributed::ProcessGroupMapFromGid::getInstance(); - distributed::ProcessGroup* pg = map->get(ring_id); - - int nranks = pg->GetSize(); - auto in_feat = x->dims()[1]; - auto n_expert = local_count->dims()[0] / nranks; - - auto fwd_count = 0; - - for (auto i = 0; i < local_count_len; ++i) { - fwd_count += cpu_local_count_data[i]; - } - framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat}); - int64_t* expert_ptr = new int64_t[n_expert * nranks]; - expert_ptr[0] = 0; - auto tot_experts = n_expert * nranks; - for (auto i = 1; i < tot_experts; ++i) { - expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; - } - auto send_ptr = 0; - out->mutable_data(out_dims, place); - - for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto j = 0; j < nranks; ++j) { - int idx = i + j * n_expert; - if (cpu_global_count_data[idx]) { - phi::DenseTensor tmp = *x; - pg->Send_Partial( - tmp, j, send_ptr * in_feat, cpu_global_count_data[idx] * in_feat); - send_ptr += cpu_global_count_data[idx]; - } - if (cpu_local_count_data[idx]) { - pg->Recv_Partial(*out, - j, - expert_ptr[idx] * in_feat, - cpu_local_count_data[idx] * in_feat); - } - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - } - -#ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); -#else - PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); -#endif - -#else - PADDLE_THROW( - platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); -#endif -#else - PADDLE_THROW( - platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); -#endif - } -}; - -template -void MatMulAndAdd(const phi::GPUContext& dev_ctx, - const framework::Tensor* weight, - const framework::Tensor* input, - const framework::Tensor* bias, - bool istransA, - bool istransB, - bool compute_bias, - framework::Tensor* output, - framework::Tensor* bias_out) { - // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. - // here: (transa, transb): nt, input * weight. - CBLAS_TRANSPOSE transA = istransA ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE transB = istransB ? CblasTrans : CblasNoTrans; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - - // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) - auto blas = phi::funcs::GetBlas(dev_ctx); - blas.GEMM(transA, - transB, - input->dims()[0], - weight->dims()[1], - input->dims()[1], - alpha, - input->data(), - weight->data(), - beta, - output->data()); - if (compute_bias) { - // bias_out = output + bias - std::vector ins = {output, bias}; - std::vector outs = {bias_out}; - phi::funcs::BroadcastKernel( - dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); - } -} - -} // namesapce operators -} //namespace paddle \ No newline at end of file diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 9f0ca6b62d3cb..af080bd0b3431 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -71,16 +71,6 @@ std::map> op_ins_map = { "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, - {"fused_moe", - {"X", - "GateWeight", - "GateBias", - "LnScale", - "LnBias", - "ExpertsWeight1", - "ExpertsBias1", - "ExpertsWeight2", - "ExpertsBias2"}}, {"fused_multi_transformer_int8", {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep", "SrcMask", @@ -345,7 +335,6 @@ std::map> op_outs_map = { "Beta2PowOut", "MasterParamOut"}}, {"fused_multi_transformer", {"CacheKVOut", "Out"}}, - {"fused_moe", {"Out"}}, {"fused_multi_transformer_int8", {"CacheKVOut", "Out"}}, {"resnet_basic_block", {"Y", "Conv1", "SavedMean1", "SavedInvstd1", "Mean1Out", diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 98c09f81574ec..d3cd00b3a541c 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -217,3 +217,21 @@ kernel : func : beam_search_softmax data_type : logits + +- op : number_count_v2 + args : (Tensor numbers, int upper_range) + output : Tensor(out) + infer_meta : + func : NumberCountInferMeta + kernel : + func : number_count_v2 + data_type : numbers + +- op : fused_moe_kernel + args : (Tensor x, Tensor gate_weight, Tensor gate_bias, Tensor ln_scale, Tensor ln_bias, Tensor[] experts_weight1, Tensor[] experts_bias1, Tensor[] experts_weight2, Tensor[] experts_bias2, bool pre_layer_norm, float ln_epsilon, int topk, int mp_size, int mp_rank, int num_expert, int world_size, int moe_ring_id, bool approximate) + output : Tensor(out) + infer_meta : + func : FusedMoeInferMeta + kernel : + func : fused_moe_kernel + data_type : x diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 7d8ef2a39379d..0a45e1fb0530b 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2253,6 +2253,13 @@ void MultiplexInferMeta(const std::vector& ins, out->set_dtype(ins[0]->dtype()); } +void NumberCountInferMeta(const MetaTensor& numbers, + int upper_range, + MetaTensor* out) { + out->set_dims({upper_range}); + out->set_dtype(DataType::INT64); +} + void PsroiPoolInferMeta(const MetaTensor& x, const MetaTensor& rois, const MetaTensor& rois_num, @@ -2941,6 +2948,29 @@ void GraphSendUVInferMeta(const MetaTensor& x, out->set_dims(phi::make_ddim(out_dims_array)); } +void FusedMoeInferMeta(const MetaTensor& x, + const MetaTensor& gate_weight, + const MetaTensor& gate_bias, + const MetaTensor& ln_scale, + const MetaTensor& ln_bias, + const std::vector& experts_weight1, + const std::vector& experts_bias1, + const std::vector& experts_weight2, + const std::vector& experts_bias2, + bool pre_layer_norm, + float ln_epsilon, + int topk, + int mp_size, + int mp_rank, + int num_expert, + int world_size, + int moe_ring_id, + bool approximate, + MetaTensor* out) { + out->set_dims(x.dims()); + out->set_dtype(x.dtype()); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 6673d66b0125d..948c5ca75ac4c 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -427,6 +427,10 @@ void MultiplexInferMeta(const std::vector& ins, const MetaTensor& ids, MetaTensor* out); +void NumberCountInferMeta(const MetaTensor& numbers, + int upper_range, + MetaTensor* out); + void PsroiPoolInferMeta(const MetaTensor& x, const MetaTensor& rois, const MetaTensor& rois_num, @@ -540,4 +544,23 @@ void GraphSendUVInferMeta(const MetaTensor& x, const std::string& message_op, MetaTensor* out); +void FusedMoeInferMeta(const MetaTensor& x, + const MetaTensor& gate_weight, + const MetaTensor& gate_bias, + const MetaTensor& ln_scale, + const MetaTensor& ln_bias, + const std::vector& experts_weight1, + const std::vector& experts_bias1, + const std::vector& experts_weight2, + const std::vector& experts_bias2, + bool pre_layer_norm, + float ln_epsilon, + int topk, + int mp_size, + int mp_rank, + int num_expert, + int world_size, + int moe_ring_id, + bool approximate, + MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/fused_moe_kernel.h b/paddle/phi/kernels/fused_moe_kernel.h new file mode 100644 index 0000000000000..dbafdf3918025 --- /dev/null +++ b/paddle/phi/kernels/fused_moe_kernel.h @@ -0,0 +1,502 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/top_k_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/elementwise_kernel.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/functors.h" +#include "paddle/phi/kernels/index_select_kernel.h" +#include "paddle/phi/kernels/scatter_kernel.h" +#include "paddle/fluid/operators/collective/global_scatter_op.h" +#include "paddle/fluid/operators/collective/global_gather_op.h" +#include "paddle/phi/kernels/bmm_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/phi/kernels/number_count_kernel.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + +namespace phi { +namespace framework = paddle::framework; +namespace platform = paddle::platform; + +template +void GlobalScatterFunctor(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + } + auto global_count_len = 0; + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + global_count_len = cpu_global_count.numel(); + } + + ncclDataType_t dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); + + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global scatter op must be non-negative.", + ring_id)); + + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + gpuStream_t stream = nullptr; + if (use_calc_stream) { + stream = ctx.stream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + auto send_buf = x->data(); + auto recv_buf = out->data(); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclSend(send_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + } + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclRecv(recv_buf + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + recv_ptr += cpu_global_count_data[idx]; + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif +} + +template +void GlobalScatterProcessGroupFunctor(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + } + auto global_count_len = 0; + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + global_count_len = global_count->numel(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + global_count_len = cpu_global_count.numel(); + } + + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global scatter op must be non-negative.", + ring_id)); + + auto place = ctx.GetPlace(); + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + int nranks = pg->GetSize(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + + auto recv_ptr = 0; + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + phi::DenseTensor tmp = *x; + pg->Send_Partial(tmp, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat); + } + if (cpu_global_count_data[idx]) { + pg->Recv_Partial(*out, + j, + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat); + recv_ptr += cpu_global_count_data[idx]; + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif +} + +template +void GlobalGatherFunctor(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + auto local_count_len = 0; + + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + local_count_len = cpu_local_count.numel(); + } + + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + } + + ncclDataType_t dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); + + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + gpuStream_t stream = nullptr; + if (use_calc_stream) { + stream = ctx.stream(); + } else { + stream = comm->stream(); + } + int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + auto send_buf = x->data(); + auto recv_buf = out->data(); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclSend(send_buf + send_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + send_ptr += cpu_global_count_data[idx]; + } + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclRecv(recv_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif +} + +template +void GlobalGatherProcessGroupFunctor(const phi::GPUContext& ctx, + const framework::Tensor* x, + const framework::Tensor* local_count, + const framework::Tensor* global_count, + int ring_id, + bool use_calc_stream, + framework::Tensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if NCCL_VERSION_CODE >= 2703 + const int64_t* cpu_local_count_data; + const int64_t* cpu_global_count_data; + auto local_count_len = 0; + + framework::Tensor cpu_local_count; + if (platform::is_cpu_place(local_count->place())) { + cpu_local_count_data = local_count->data(); + local_count_len = local_count->numel(); + } else { + framework::TensorCopySync( + *local_count, platform::CPUPlace(), &cpu_local_count); + cpu_local_count_data = cpu_local_count.data(); + local_count_len = cpu_local_count.numel(); + } + + framework::Tensor cpu_global_count; + if (platform::is_cpu_place(global_count->place())) { + cpu_global_count_data = global_count->data(); + } else { + framework::TensorCopySync( + *global_count, platform::CPUPlace(), &cpu_global_count); + cpu_global_count_data = cpu_global_count.data(); + } + + PADDLE_ENFORCE_GE( + ring_id, + 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for global gather op must be non-negative.", + ring_id)); + auto place = ctx.GetPlace(); + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + + int nranks = pg->GetSize(); + auto in_feat = x->dims()[1]; + auto n_expert = local_count->dims()[0] / nranks; + + int64_t* expert_ptr = new int64_t[n_expert * nranks]; + expert_ptr[0] = 0; + auto tot_experts = n_expert * nranks; + for (auto i = 1; i < tot_experts; ++i) { + expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1]; + } + auto send_ptr = 0; + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_global_count_data[idx]) { + phi::DenseTensor tmp = *x; + pg->Send_Partial( + tmp, j, send_ptr * in_feat, cpu_global_count_data[idx] * in_feat); + send_ptr += cpu_global_count_data[idx]; + } + if (cpu_local_count_data[idx]) { + pg->Recv_Partial(*out, + j, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat); + } + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + +#else + PADDLE_THROW( + platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); +#endif +#else + PADDLE_THROW( + platform::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif +} + +template +void MatMulAndAdd(const phi::GPUContext& dev_ctx, + const framework::Tensor* weight, + const framework::Tensor* input, + const framework::Tensor* bias, + bool istransA, + bool istransB, + bool compute_bias, + framework::Tensor* output, + framework::Tensor* bias_out) { + // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. + // here: (transa, transb): nt, input * weight. + CBLAS_TRANSPOSE transA = istransA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE transB = istransB ? CblasTrans : CblasNoTrans; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + + // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.GEMM(transA, + transB, + input->dims()[0], + weight->dims()[1], + input->dims()[1], + alpha, + input->data(), + weight->data(), + beta, + output->data()); + if (compute_bias) { + // bias_out = output + bias + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); + } +} + +template +void FusedMoeKernel(const DeviceContext& context, + const DenseTensor& x, + const DenseTensor& gate_weight, + const DenseTensor& gate_bias, + const DenseTensor& ln_scale, + const DenseTensor& ln_bias, + const std::vector& experts_weight1, + const std::vector& experts_bias1, + const std::vector& experts_weight2, + const std::vector& experts_bias2, + bool pre_layer_norm, + float ln_epsilon, + int topk, + int mp_size, + int mp_rank, + int num_expert, + int world_size, + int moe_ring_id, + bool approximate, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/fused_moe_kernel.cu b/paddle/phi/kernels/gpu/fused_moe_kernel.cu new file mode 100644 index 0000000000000..0d72d7e3b058f --- /dev/null +++ b/paddle/phi/kernels/gpu/fused_moe_kernel.cu @@ -0,0 +1,484 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gpu/fused_moe_kernel.cu.h" + +namespace phi { +using Tensor = DenseTensor; +namespace framework = paddle::framework; +namespace platform = paddle::platform; + +template +static void AllToAll(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllToAll(in_tensor, out_tensor, true, true); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t send_numel = tensor.numel(); // send_numel + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + int nranks = comm->nranks(); + auto stream = ctx.stream(); + + framework::DDim x_dims = tensor.dims(); + framework::DDim out_dims(x_dims); + PADDLE_ENFORCE_EQ( + x_dims[0] % nranks, + 0, + platform::errors::InvalidArgument( + "The first dimension size (%d) of the input tensor must be " + "divisible by the number of ranks (%d).", + x_dims[0], + nranks)); + auto send_buf = tensor.data(); + auto recv_buf = out.mutable_data(out_dims, place); + size_t offset = 0; + send_numel /= nranks; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < nranks; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + offset += send_numel; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +template +static void AllGather(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllGather(in_tensor, out_tensor, true, true); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t numel = tensor.numel(); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + auto out_dims = tensor.dims(); + int nranks = comm->nranks(); + out_dims[0] *= nranks; + out.mutable_data(out_dims, place); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( + tensor.data(), out.data(), numel, dtype, comm->comm(), stream)); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +template +void FusedMoeKernel(const DeviceContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& gate_weight, + const DenseTensor& gate_bias, + const DenseTensor& ln_scale, + const DenseTensor& ln_bias, + const std::vector& experts_weight1, + const std::vector& experts_bias1, + const std::vector& experts_weight2, + const std::vector& experts_bias2, + bool pre_layer_norm, + float ln_epsilon, + int topk, + int mp_size, + int mp_rank, + int num_expert, + int world_size, + int moe_ring_id, + bool approximate, + DenseTensor* out) { + using U = paddle::operators::LayerNormParamType; + // output + dev_ctx.template Alloc(out); + // dim + auto x_dim = x.dims(); + int bsz = x_dim[0]; + int seq_len = x_dim[1]; + int bsz_seq = bsz * seq_len; + int d_model = x_dim[2]; + int tot_expert = world_size * num_expert; + int dim_feedforward = experts_weight1[0]->dims()[1]; + + // pre_layer_norm + const U* ln_scale_ptr = ln_scale.data(); + const U* ln_bias_ptr = ln_bias.data(); + Tensor ln_mean, ln_variance; + ln_mean.Resize({{bsz_seq}}); + auto* ln_mean_data = dev_ctx.template Alloc(&ln_mean); + ln_variance.Resize({{bsz_seq}}); + auto* ln_variance_data = dev_ctx.template Alloc(&ln_variance); + paddle::operators::FusedDropoutLayerNormHelper pre_layernorm_helper( + bsz_seq, d_model, ln_epsilon); + // tmp out + Tensor ln_out; + ln_out.Resize({{bsz, seq_len, d_model}}); + auto *ln_out_data = dev_ctx.template Alloc(&ln_out); + // after slice, bsz_seq should be change + int sliced_bsz_seq = bsz_seq; + int start = 0; + int end = 0; + if (mp_size > 1) { + start = bsz_seq / world_size * mp_rank; + end = std::min(start + bsz_seq / world_size, bsz_seq); + sliced_bsz_seq = end - start; + } + int out_batch_size = sliced_bsz_seq * topk; + // slice + Tensor sliced_inp; + sliced_inp.Resize({{sliced_bsz_seq, d_model}}); + dev_ctx.template Alloc(&sliced_inp); + // gate linear + Tensor gate_out; + gate_out.Resize({{sliced_bsz_seq, tot_expert}}); + dev_ctx.template Alloc(&gate_out); + // topk + Tensor topk_value, topk_idx; + topk_value.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.template Alloc(&topk_value); + topk_idx.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.template Alloc(&topk_idx); + // local expert count, global expert count + Tensor local_expert_count, global_expert_count; + local_expert_count.Resize({{tot_expert}}); + global_expert_count.Resize({{tot_expert}}); + dev_ctx.template Alloc(&local_expert_count); + dev_ctx.template Alloc(&global_expert_count); + // fwd_expert_count, fwd_batch_size + Tensor fwd_expert_count, fwd_batch_size; + fwd_expert_count.Resize({{num_expert}}); + fwd_batch_size.Resize({{1}}); + dev_ctx.template Alloc(&fwd_expert_count); + dev_ctx.template Alloc(&fwd_batch_size); + // pos, temp pos + Tensor pos, temp_pos; + pos.Resize({{out_batch_size}}); + temp_pos.Resize({{out_batch_size}}); + dev_ctx.template Alloc(&pos); + dev_ctx.template Alloc(&temp_pos); + // cumsum + Tensor lec_cum; + lec_cum.Resize({{tot_expert}}); + dev_ctx.template Alloc(&lec_cum); + // fused moe ffn tmp out + Tensor index_select_out; + index_select_out.Resize({{out_batch_size, d_model}}); + dev_ctx.template Alloc(&index_select_out); + Tensor global_gather_out; + global_gather_out.Resize({{out_batch_size, d_model}}); + dev_ctx.template Alloc(&global_gather_out); + Tensor moe_gather_out; + moe_gather_out.Resize({{out_batch_size, d_model}}); + dev_ctx.template Alloc(&moe_gather_out); + Tensor bmm_out; + bmm_out.Resize({{sliced_bsz_seq, 1, d_model}}); + dev_ctx.template Alloc(&bmm_out); + Tensor all_gather_out; + all_gather_out.Resize({{bsz_seq, d_model}}); + dev_ctx.template Alloc(&all_gather_out); + paddle::operators::DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + + // step1 layer norm + if (pre_layer_norm) { + pre_layernorm_helper.LayerNorm(dev_ctx, + x.data(), + ln_scale_ptr, + ln_bias_ptr, + ln_out_data, + ln_mean_data, + ln_variance_data); + } else { + ln_out = x; + } + // step2 resize and slice ln_out + ln_out.Resize({{bsz_seq, d_model}}); + if (mp_size > 1) { + sliced_inp = ln_out.Slice(start, end); + } else { + sliced_inp = ln_out; + } + // step3 gate & topk + MatMulAndAdd(dev_ctx, + &gate_weight, + &sliced_inp, + &gate_bias, + false, + false, + true, // compute bias + &gate_out, + &gate_out); + TopkKernel(dev_ctx, + gate_out, + Scalar(topk), + -1, + true, + false, + &topk_value, + &topk_idx); + // step4 prepare forward + // step4.1 number count + NumberCountKernel(dev_ctx, topk_idx, tot_expert, &local_expert_count); + // step4.2 all_to_all + if (world_size > 1) { + AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); + } else { + global_expert_count = local_expert_count; + } + // global expert count resize + global_expert_count.Resize({{world_size, num_expert}}); + // fwd expert count + SumKernel(dev_ctx, + global_expert_count, + IntArray({0}), + global_expert_count.dtype(), + false, + &fwd_expert_count); + // fwd batch size + SumKernel(dev_ctx, + fwd_expert_count, + IntArray({}), // axis is None + fwd_expert_count.dtype(), + false, + &fwd_batch_size); + // step4.3 cumsum & assign pos + CumsumKernel(dev_ctx, + local_expert_count, + Scalar(0), + false, + false, + false, + &lec_cum); + AssignPosCompute(dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size); + if (topk > 1) { + Tensor topk_tensor; + topk_tensor.Resize({{1}}); + dev_ctx.template Alloc(&topk_tensor); + FullKernel(dev_ctx, {1}, topk, pos.dtype(), &topk_tensor); + FloorDivideKernel(dev_ctx, + pos, + topk_tensor, + &temp_pos); + } else { + temp_pos = pos; + } + + Tensor fwd_expert_count_cpu; + framework::TensorCopySync(fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); + + Tensor fwd_batch_size_cpu; + framework::TensorCopySync(fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + int fwd_bsz = fwd_batch_size_cpu.data()[0]; + + Tensor global_scatter_out; + global_scatter_out.Resize({{fwd_bsz, d_model}}); + dev_ctx.template Alloc(&global_scatter_out); + + std::vector tmp_expert_out; + Tensor all_expert_out; + all_expert_out.Resize({{fwd_bsz, d_model}}); + dev_ctx.template Alloc(&all_expert_out); + + // step 5, MOEScatter + // step 5.1, index select + // suppose tmp_pos->shape != [0] + IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 5.2, global_scatter + if (map->has(moe_ring_id)) { + GlobalScatterProcessGroupFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } else { + GlobalScatterFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } + } else { + global_scatter_out = index_select_out; + } + + // step 6, Expert Computation + if (fwd_bsz != 0) { + int last_index = 0; + for (int idx = 0; idx < num_expert; idx++) { + int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + if (cur_expert_count <= 0) { + continue; + } + int end = cur_expert_count + last_index; + Tensor expert_out1; + expert_out1.Resize({{cur_expert_count, dim_feedforward}}); + dev_ctx.template Alloc(&expert_out1); + Tensor act_bias_out; + act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); + dev_ctx.template Alloc(&act_bias_out); + Tensor expert_out2; + expert_out2.Resize({{cur_expert_count, d_model}}); + dev_ctx.template Alloc(&expert_out2); + paddle::operators::FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, cur_expert_count, dim_feedforward, dropout_param); + + Tensor tmp_inp = global_scatter_out.Slice(last_index, end); + // linear1 matmul + MatMulAndAdd(dev_ctx, + experts_weight1[idx], + &tmp_inp, + nullptr, + false, + false, + false, // dont compute bias + &expert_out1, + nullptr); + // bias gelu + fused_act_dropout_helper.DropoutActBias(dev_ctx, + expert_out1.data(), + experts_bias1[idx]->data(), + "gelu", + act_bias_out.data(), + nullptr, + 1.0, + nullptr, + 0, + 1.0, + 1, + 127.0, + -127.0, + approximate); + // linear2 matmul & add + MatMulAndAdd(dev_ctx, + experts_weight2[idx], + &act_bias_out, + experts_bias2[idx], + false, + false, + true, // compute bias + &expert_out2, + &expert_out2); + tmp_expert_out.emplace_back(expert_out2); + last_index = end; + } + phi::funcs::ConcatFunctor concat; + concat(dev_ctx, tmp_expert_out, 0, &all_expert_out); + } else { + all_expert_out = global_scatter_out; + } + // step7. MOEGather + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 7.1, global_gather + if (map->has(moe_ring_id)) { + GlobalGatherProcessGroupFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } else { + GlobalGatherFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } + } else { + global_gather_out = all_expert_out; + } + // step 7.2, local_gather or scatter + // suppose pos->shape != [0] + ScatterKernel(dev_ctx, + moe_gather_out, + pos, + global_gather_out, + true, + &moe_gather_out); + // step 8, reshape & bmm + // moe gather out reshape + moe_gather_out.Resize({{sliced_bsz_seq, topk, d_model}}); + topk_value.Resize({{sliced_bsz_seq, 1, topk}}); + BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); + bmm_out.Resize({{sliced_bsz_seq, d_model}}); + // step 9, AllGather + if (mp_size > 1) { + // all gather + AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); + } else { + all_gather_out = bmm_out; + } + // step 10, reshape + all_gather_out.Resize(x_dim); + // step 11, add residual + AddKernel(dev_ctx, all_gather_out, x, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(fused_moe_kernel, + GPU, + ALL_LAYOUT, + phi::FusedMoeKernel, + float, + double, + paddle::platform::float16) {} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/fused_moe_kernel.cu.h b/paddle/phi/kernels/gpu/fused_moe_kernel.cu.h new file mode 100644 index 0000000000000..6ef7eb3f81b94 --- /dev/null +++ b/paddle/phi/kernels/gpu/fused_moe_kernel.cu.h @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fused_moe_kernel.h" + +DECLARE_bool(avoid_op_randomness); + +namespace phi { +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void AssignPos(T* cum_count, + const T* numbers, + T* out, + int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + int number_idx = numbers[i]; + if (number_idx > -1) { + int p = platform::CudaAtomicAdd(cum_count + number_idx, -1); + out[p - 1] = i; + } + } +} + +template +void AssignPosCompute(const phi::GPUContext &dev_ctx, + framework::Tensor* cum_count, // (counter number) int32 | int64 + framework::Tensor* numbers, // (batch_size * seq_len, topk) int32 + framework::Tensor* out, + const int eff_num_len) { + auto place = dev_ctx.GetPlace(); + auto numel = numbers->numel(); + T* cum_data = const_cast(cum_count->data()); + + framework::DDim out_dims = phi::make_ddim({eff_num_len}); + auto out_data = out->mutable_data(out_dims, place); + + const T* num_data = numbers->data(); + + int blocks = NumBlocks(numel); + int threads = kNumCUDAThreads; + + AssignPos<<>>( + cum_data, num_data, out_data, numel); +} +} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/number_count_kernel.cu b/paddle/phi/kernels/gpu/number_count_kernel.cu new file mode 100644 index 0000000000000..184aaaaa138df --- /dev/null +++ b/paddle/phi/kernels/gpu/number_count_kernel.cu @@ -0,0 +1,96 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/number_count_kernel.h" + +namespace phi { +#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1) +#define PERTHREAD_EXPERTS 256 +#define WARP_SIZE 32 + +const int CUDA_NUM_THREADS = 512; +static inline int GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__global__ void initialize_zero_kernel(T* data, const int length) { + CUDA_KERNEL_LOOP(idx, length) { data[idx] = static_cast(0); } +} + +template +__global__ void NumberCount(const T* numbers, + T* number_count, + int64_t batch_size, + int upper_range) { + int res_tmp[PERTHREAD_EXPERTS] = {0}; + int expert_min = blockIdx.x * PERTHREAD_EXPERTS; + int expert_max = expert_min + PERTHREAD_EXPERTS; + if (expert_max > upper_range) { + expert_max = upper_range; + } + for (int i = threadIdx.x; i < batch_size; i += blockDim.x) { + T idx = numbers[i]; + if (idx == -1) { + continue; + } + if (idx < expert_min || idx >= expert_max) { + continue; + } + res_tmp[idx - expert_min] += 1; + } + for (int i = expert_min; i < expert_max; ++i) { + int x = res_tmp[i - expert_min]; +#pragma unroll + for (int j = 1; j < WARP_SIZE; j <<= 1) { +#ifdef __HIPCC__ + x = x + __shfl_down(x, j); +#else + x = x + __shfl_down_sync(-1u, x, j); +#endif + } + if (threadIdx.x % WARP_SIZE == 0) { + paddle::platform::CudaAtomicAdd(number_count + i, x); + } + } +} + +template +void NumberCountKernel(const Context& dev_ctx, + const DenseTensor& numbers, + int upper_range, + DenseTensor* out){ + int64_t batch_size = numbers.numel(); + auto place = dev_ctx.GetPlace(); + + auto out_dims = make_ddim({upper_range}); + auto out_data = out->mutable_data(out_dims, place); + const T* gate_data = numbers.data(); + + initialize_zero_kernel + <<>>( + out_data, upper_range); + + NumberCount + <<>>( + gate_data, out_data, batch_size, upper_range); +} + +} // namespace phi + +PD_REGISTER_KERNEL(number_count_v2, + GPU, + ALL_LAYOUT, + phi::NumberCountKernel, + int64_t) {} \ No newline at end of file diff --git a/paddle/phi/kernels/number_count_kernel.h b/paddle/phi/kernels/number_count_kernel.h new file mode 100644 index 0000000000000..f86670868cc5e --- /dev/null +++ b/paddle/phi/kernels/number_count_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +#if defined(PADDLE_WITH_GLOO) +#include "paddle/fluid/framework/fleet/gloo_wrapper.h" +#endif + +namespace phi { + +template +void NumberCountKernel(const Context& dev_ctx, + const DenseTensor& numbers, + int upper_range, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 677bbc0c2b288..b2767b1dd1cbf 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -166,7 +166,6 @@ def _update_list(self): 'concat', 'split', 'fused_feedforward', - 'fused_moe', 'fused_attention', 'fused_multi_transformer', } diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 9b0f5fa1006aa..b23c94c7e4994 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -112,8 +112,6 @@ def _keep_fp32_input(op, in_name): } if op_type == 'fused_multi_transformer': return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'} - if op_type == 'fused_moe': - return in_name in {'LnScale', 'LnBias'} return False diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 8c2053c4bd47c..b4f173fc725d3 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -17,7 +17,7 @@ from paddle.framework import ParamAttr import paddle import paddle.nn as nn -from paddle import _legacy_C_ops +from paddle import _legacy_C_ops, _C_ops from paddle.nn import ParameterList from paddle.nn.layer.transformer import ( _convert_attention_mask, @@ -1487,6 +1487,9 @@ def __init__(self, dim_feedforward, num_expert, top_k, + approximate, + moe_group=None, + mp_group=None, ln_scale=None, ln_bias=None, gate_weight=None, @@ -1494,9 +1497,7 @@ def __init__(self, linear1_weights=None, linear1_biases=None, linear2_weights=None, - linear2_biases=None, - moe_group=None, - mp_group=None): + linear2_biases=None): super(FusedMoELayer, self).__init__() # only support mp/dp self.group = moe_group @@ -1514,6 +1515,7 @@ def __init__(self, self.mp_size = mp_group.nranks self.d_model = d_model self.top_k = top_k + self.approximate = approximate self.ln_scale = self.create_parameter( shape=[d_model], attr=None, @@ -1578,9 +1580,13 @@ def get_attr(attrs, idx): is_bias=True, default_initializer=nn.initializer.Constant(value=0.0) )) + self.linear1_weights[i].name = "expert_" + self.linear1_weights[i].name + self.linear2_weights[i].name = "expert_" + self.linear2_weights[i].name + self.linear1_biases[i].name = "expert_" + self.linear1_biases[i].name + self.linear2_biases[i].name = "expert_" + self.linear2_biases[i].name def forward(self, inp): - inp = _legacy_C_ops.fused_moe( + inp = _C_ops.fused_moe_kernel( inp, self.gate_weight, self.gate_bias, @@ -1590,22 +1596,15 @@ def forward(self, inp): list(self.linear1_biases), list(self.linear2_weights), list(self.linear2_biases), - 'pre_layer_norm', True, - 'ln_epsilon', 1e-5, - 'topk', self.top_k, - 'mp_size', self.mp_size, - 'mp_rank', self.mp_rank, - 'num_expert', self.num_expert, - 'world_size', self.world_size, - 'moe_ring_id', - -1 if self.group is None else self.group.id + -1 if self.group is None else self.group.id, + self.approximate ) return inp @@ -1622,4 +1621,4 @@ def trans_to_fp16(l): trans_to_fp16(self.linear2_biases) _ = _to_dtype(self.gate_weight, dtype) _ = _to_dtype(self.gate_bias, dtype) - self._dtype = dtype \ No newline at end of file + self._dtype = dtype From 9df6dabfed3c6553a0778f262bfc843a063a706d Mon Sep 17 00:00:00 2001 From: tianyan01 Date: Mon, 25 Sep 2023 11:57:29 +0800 Subject: [PATCH 4/9] fix grid dim.y should less than 65535 bug --- .../fluid/operators/fused/fused_dropout_act_bias.h | 12 +++++++----- paddle/fluid/operators/fused/fused_dropout_common.h | 8 ++++++-- .../operators/fused/fused_residual_dropout_bias.h | 12 +++++++----- paddle/fluid/operators/fused/fused_softmax_mask.cu.h | 8 +++++--- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index e3e19d9ea6ebc..aa6c56f937524 100644 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -86,8 +86,8 @@ __global__ void FusedDropoutActBias( const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { - int col_id = blockDim.x * blockIdx.x + threadIdx.x; - int row_id = blockIdx.y; + int col_id = threadIdx.x; + int row_id = gridDim.y * blockIdx.x + blockIdx.y; int idx = row_id * cols + col_id; curandStatePhilox4_32_10_t state; @@ -95,9 +95,11 @@ __global__ void FusedDropoutActBias( const T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); - for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { - for (int i = col_id * VecSize; i < cols; - i += blockDim.x * gridDim.x * VecSize) { + int i = col_id * VecSize; + int r = row_id; + int stride = blockDim.x * VecSize; + for (; r < rows; r += blockDim.y * gridDim.y * gridDim.x) { + for (; i < cols; i += stride) { FusedResidualDropoutBiasOneThread(std::min( ctx.GetMaxThreadsPerBlock(), 512)))); - const auto blocks_x = + auto blocks_x = std::max(static_cast(1), (tmp_cols + threads - 1) / threads); - const auto blocks_y = std::max(static_cast(1), rows); + auto blocks_y = std::max(static_cast(1), rows); platform::GpuLaunchConfig config; + while (blocks_y > 65535) { + blocks_x *= 2; + blocks_y /= 2; + } config.block_per_grid.x = blocks_x; config.block_per_grid.y = blocks_y; config.thread_per_block.x = threads; diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index f162d200abfe1..cabcbea2d2e60 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -174,16 +174,18 @@ __global__ void FusedResidualDropoutBias( const float *dequant_out_scale_data = nullptr, const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0) { - int col_id = blockDim.x * blockIdx.x + threadIdx.x; - int row_id = blockIdx.y; + int col_id = threadIdx.x; + int row_id = gridDim.y * blockIdx.x + blockIdx.y; int idx = row_id * cols + col_id; curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); const T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); phi::funcs::ReluFunctor relu; - for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { - for (int i = col_id * VecSize; i < cols; - i += blockDim.x * gridDim.x * VecSize) { + int i = col_id * VecSize; + int r = row_id; + int stride = blockDim.x * VecSize; + for (; r < rows; r += blockDim.y * gridDim.y * gridDim.x) { + for (; i < cols; i += stride) { FusedResidualDropoutBiasOneThread= seq_len) return; // ((bid*head_num + hid)*seq_len + seq_id) * seq_len - int offset = + int64_t offset = ((blockIdx.y * gridDim.z + blockIdx.z) * seq_len + seq_id) * seq_len; // (bid * seq_len + seq_id) * seq_len - int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len; + int64_t mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len; src += offset; dst += offset; mask += mask_offset; From 186abc2d4a390f8710ed5df7a7736247d37eba72 Mon Sep 17 00:00:00 2001 From: tianyan01 Date: Fri, 10 Nov 2023 15:28:44 +0800 Subject: [PATCH 5/9] add new fmt & fmt_int8 & fmt_moe & fmt_moe_int8 --- paddle/fluid/framework/operator.cc | 2 + paddle/fluid/operators/fused/CMakeLists.txt | 4 + paddle/fluid/operators/fused/attn_gemm_int8.h | 141 ++- paddle/fluid/operators/fused/cublaslt.h | 849 +++++++++++-- paddle/fluid/operators/fused/fmha_ref.h | 357 +++++- .../fused/fused_multi_transformer_int8_op.cc | 83 +- .../fused/fused_multi_transformer_int8_op.cu | 713 ++++++++--- .../fused_multi_transformer_moe_int8_op.cc | 392 ++++++ .../fused_multi_transformer_moe_int8_op.cu | 782 ++++++++++++ .../fused/fused_multi_transformer_moe_op.cc | 319 +++++ .../fused/fused_multi_transformer_moe_op.cu | 844 +++++++++++++ .../fused/fused_multi_transformer_moe_op.h | 273 +++++ .../fused/fused_multi_transformer_op.cc | 95 +- .../fused/fused_multi_transformer_op.cu | 511 +++++--- .../fused/fused_multi_transformer_op.h | 869 ++++++++++++- .../operators/fused/layernorm_quant_dequant.h | 1076 +++++++++++++++++ .../operators/fused/quant_dequant_kernel.h | 134 +- paddle/fluid/platform/dynload/cublasLt.h | 32 +- paddle/fluid/pybind/op_function_generator.h | 35 +- paddle/phi/api/yaml/ops.yaml | 5 +- paddle/phi/backends/dynload/cublasLt.h | 31 +- paddle/phi/backends/gpu/gpu_resources.cc | 4 +- paddle/phi/infermeta/multiary.cc | 5 + paddle/phi/infermeta/multiary.h | 5 + paddle/phi/kernels/fused_moe_kernel.h | 123 ++ .../kernels/fusion/gpu/beam_search_softmax.cu | 1 + paddle/phi/kernels/gpu/fused_moe_kernel.cu | 153 +-- .../contrib/mixed_precision/fp16_lists.py | 1 + .../contrib/mixed_precision/fp16_utils.py | 2 +- python/paddle/fluid/dygraph/amp/auto_cast.py | 1 + python/paddle/incubate/nn/__init__.py | 4 +- .../nn/functional/fused_transformer.py | 8 + .../incubate/nn/layer/fused_transformer.py | 344 +++++- 33 files changed, 7372 insertions(+), 826 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cc create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cc create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu create mode 100644 paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h create mode 100644 paddle/fluid/operators/fused/layernorm_quant_dequant.h diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 613cd4989276d..ad52f4513b7b5 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1420,6 +1420,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // result of HasAttr. if (!enable_cache_runtime_context_ && HasAttr(kEnableCacheRuntimeContext)) enable_cache_runtime_context_ = true; + if (this->Type() == "fused_multi_transformer_int8" || this->Type() == "fused_multi_transformer_moe_int8") + enable_cache_runtime_context_ = true; if (!all_kernels_must_compute_runtime_shape_ && HasAttr(kAllKernelsMustComputeRuntimeShape)) all_kernels_must_compute_runtime_shape_ = true; diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index fe82565bc36f3..05c52b850db14 100755 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -24,6 +24,8 @@ register_operators( fused_feedforward_op fused_multi_transformer_op fused_multi_transformer_int8_op + fused_multi_transformer_moe_op + fused_multi_transformer_moe_int8_op fused_bias_dropout_residual_layer_norm_op resnet_unit_op fused_gemm_epilogue_op @@ -121,6 +123,8 @@ if(WITH_GPU OR WITH_ROCM) op_library(fused_attention_op) op_library(fused_multi_transformer_op) op_library(fused_multi_transformer_int8_op) + op_library(fused_multi_transformer_moe_op) + op_library(fused_multi_transformer_moe_int8_op) op_library(fused_bias_dropout_residual_layer_norm_op) endif() # resnet_unit needs cudnn 8.0 above diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h index ba114df9085fb..ce392e98ba606 100644 --- a/paddle/fluid/operators/fused/attn_gemm_int8.h +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -20,13 +20,14 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/quant_dequant_kernel.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; +using phi::backends::gpu::GpuLaunchConfig; template class AttnMatmulINT8 { @@ -34,23 +35,26 @@ class AttnMatmulINT8 { AttnMatmulINT8( const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias) : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) { - auto helper = std::make_shared(m, k, n); - helpers_.emplace_back(helper); + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + helper_ = std::make_unique>(m, k, n, lt_handle); + gpu_config_ = std::make_unique( + phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, m * n, DequantKernelVecSize)); } ~AttnMatmulINT8() {} // This function is used to execute GEMM, with input and output's types are // both T. - void ComputeForward(const framework::Tensor* weight, - const framework::Tensor* input, - framework::Tensor* input_tmp, - const framework::Tensor* bias, - framework::Tensor* output, - framework::Tensor* output_tmp, - framework::Tensor* bias_out, + void ComputeForward(const phi::DenseTensor* weight, + const phi::DenseTensor* input, + phi::DenseTensor* input_tmp, + const phi::DenseTensor* bias, + phi::DenseTensor* output, + phi::DenseTensor* output_tmp, + phi::DenseTensor* bias_out, const float quant_in_scale, - const framework::Tensor* dequant_out_scale, - const int quant_out_scale_offset, + const phi::DenseTensor* dequant_out_scale, + phi::DenseTensor* workspace = nullptr, const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { @@ -64,24 +68,26 @@ class AttnMatmulINT8 { quant_min_bound, dev_ctx_.stream()); - helpers_[0]->GEMM(input_tmp->data(), - weight->data(), - output_tmp->data(), - dev_ctx_.stream()); + helper_->GEMM(input_tmp->data(), + weight->data(), + output_tmp->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); dequantize_kernel_launcher(output_tmp->data(), output->data(), m_, n_, dev_ctx_.stream(), + gpu_config_.get(), quant_in_scale, - dequant_out_scale->data(), - quant_out_scale_offset); + dequant_out_scale->data()); if (compute_bias_) { // bias_out = output + bias - std::vector ins = {output, bias}; - std::vector outs = {bias_out}; + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; phi::funcs::BroadcastKernel( dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); PADDLE_ENFORCE_EQ(cudaGetLastError(), @@ -95,66 +101,72 @@ class AttnMatmulINT8 { // This function is used to execute GEMM, with input and output's types are // both INT8. - void ComputeForwardINT8ToINT8(const framework::Tensor* weight, - framework::Tensor* input, - const framework::Tensor* bias, - framework::Tensor* output, - framework::Tensor* bias_out) { - helpers_[0]->GEMM(input->data(), - weight->data(), - output->data(), - dev_ctx_.stream()); + void ComputeForwardINT8ToINT8(const phi::DenseTensor* weight, + phi::DenseTensor* input, + const phi::DenseTensor* bias, + phi::DenseTensor* output, + phi::DenseTensor* bias_out, + phi::DenseTensor* workspace = nullptr) { + helper_->GEMM(input->data(), + weight->data(), + output->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); } // This function is used to execute GEMM, with input and output's types are // INT8 and T. - void ComputeForwardINT8ToT(const framework::Tensor* weight, + void ComputeForwardINT8ToT(const phi::DenseTensor* weight, const float quant_in_scale, - framework::Tensor* input, - const framework::Tensor* bias, - framework::Tensor* output, - framework::Tensor* output_tmp, - framework::Tensor* bias_out, - const framework::Tensor* dequant_out_scale, - const int quant_out_scale_offset) { - helpers_[0]->GEMM(input->data(), - weight->data(), - output_tmp->data(), - dev_ctx_.stream()); + phi::DenseTensor* input, + const phi::DenseTensor* bias, + phi::DenseTensor* output, + phi::DenseTensor* output_tmp, + phi::DenseTensor* bias_out, + const phi::DenseTensor* dequant_out_scale, + phi::DenseTensor* workspace = nullptr) { + helper_->GEMM(input->data(), + weight->data(), + output_tmp->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); dequantize_kernel_launcher(output_tmp->data(), output->data(), m_, n_, dev_ctx_.stream(), + gpu_config_.get(), quant_in_scale, - dequant_out_scale->data(), - quant_out_scale_offset); + dequant_out_scale->data()); if (compute_bias_) { // bias_out = output + bias - std::vector ins = {output, bias}; - std::vector outs = {bias_out}; + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; phi::funcs::BroadcastKernel( dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor()); - PADDLE_ENFORCE_EQ(cudaGetLastError(), - cudaSuccess, - platform::errors::Fatal( - "cuda error occured after computing bias. " - "But it does not mean this error is caused by " - "bias computing")); + // PADDLE_ENFORCE_EQ(cudaGetLastError(), + // cudaSuccess, + // platform::errors::Fatal( + // "cuda error occured after computing bias. " + // "But it does not mean this error is caused by " + // "bias computing")); } } // This function is used to execute GEMM, with input and output's types are T // and INT8. - void ComputeForwardTToINT8(const framework::Tensor* weight, + void ComputeForwardTToINT8(const phi::DenseTensor* weight, const float quant_in_scale, - const framework::Tensor* input, - framework::Tensor* input_tmp, - const framework::Tensor* bias, - framework::Tensor* output, - framework::Tensor* bias_out, + const phi::DenseTensor* input, + phi::DenseTensor* input_tmp, + const phi::DenseTensor* bias, + phi::DenseTensor* output, + phi::DenseTensor* bias_out, + phi::DenseTensor* workspace = nullptr, const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { @@ -168,10 +180,12 @@ class AttnMatmulINT8 { quant_min_bound, dev_ctx_.stream()); - helpers_[0]->GEMM(input_tmp->data(), - weight->data(), - output->data(), - dev_ctx_.stream()); + helper_->GEMM(input_tmp->data(), + weight->data(), + output->data(), + dev_ctx_.stream(), + (void*)workspace->data(), + workspace->numel()); } private: @@ -182,8 +196,9 @@ class AttnMatmulINT8 { int k_; // k int compute_bias_; - std::vector> helpers_; + std::unique_ptr> helper_; + std::unique_ptr gpu_config_; }; } // namespace operators -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/fused/cublaslt.h b/paddle/fluid/operators/fused/cublaslt.h index b9cc6b56f13ee..b889d3a4d219d 100644 --- a/paddle/fluid/operators/fused/cublaslt.h +++ b/paddle/fluid/operators/fused/cublaslt.h @@ -1,4 +1,5 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,175 +15,796 @@ limitations under the License. */ #pragma once +#include #include #include #include #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/dynload/cublasLt.h" +DECLARE_int64(cublaslt_exhaustive_search_times); + namespace dyl = paddle::platform::dynload; namespace paddle { namespace operators { + +#define PADDLE_CUBLASLT_STATUS_CHECK(name) \ + PADDLE_ENFORCE_EQ( \ + status, \ + CUBLAS_STATUS_SUCCESS, \ + platform::errors::External( \ + #name \ + "execution error" \ + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " \ + "information")) + +const int split_k_candidates[] = {2, 3, 4, 5, 6, 8, 12, 16, 32}; + +struct CublasLtAlgoSelectorParam { + cublasLtMatmulAlgo_t algo; + int m; + int n; + int k; + int algo_id; + int swizzle; + int custom_option; + int tile; + int split_k_val; + int reduction_scheme; + int stages; + void* workspace; + size_t workspace_size; + float time; +}; + +inline bool compare_algo_time(const CublasLtAlgoSelectorParam& param_a, + const CublasLtAlgoSelectorParam& param_b) { + return (param_a.time < param_b.time); +} +#if CUDA_VERSION >= 11020 +class CublasLtAlgoCache { + public: + static CublasLtAlgoCache& Instance() { + static CublasLtAlgoCache instance(FLAGS_cublaslt_exhaustive_search_times); + return instance; + } + + template + void TestMatmulRun(cublasLtHandle_t handle, + cublasLtMatmulDesc_t matmul_desc, + cublasLtMatrixLayout_t a_desc, + cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, + void* alpha, + void* beta, + const InT* a, + const InT* b, + OutT* c, + CublasLtAlgoSelectorParam& param, // NOLINT + cudaEvent_t& start_event, // NOLINT + cudaEvent_t& stop_event, // NOLINT + cudaStream_t stream) { + cublasStatus_t status; + cublasLtMatmulHeuristicResult_t heuristic_result; + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + c_desc, + ¶m.algo, + &heuristic_result); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCheck); + if (status != CUBLAS_STATUS_SUCCESS || + heuristic_result.workspaceSize > param.workspace_size) { + // VLOG(0) << "param.workspace_size is " << param.workspace_size; + param.time = std::numeric_limits::max(); + return; + } + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); + int repeats = search_times_; + + for (int loop = 0; loop < repeats; loop++) { + status = dyl::cublasLtMatmul(handle, + matmul_desc, + alpha, + a, + a_desc, + b, + b_desc, + beta, + c, + c_desc, + c, + c_desc, + ¶m.algo, + param.workspace, + param.workspace_size, + stream); + if (status != CUBLAS_STATUS_SUCCESS) { + param.time = std::numeric_limits::max(); + return; + } + } + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); + + float time; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaEventElapsedTime(&time, start_event, stop_event)); + + param.time = time / repeats; + } + + template + cublasLtMatmulAlgo_t* CublasLtAlgoSelect(cublasLtHandle_t handle, + int m, + int n, + int k, + const InT* a, + const InT* b, + OutT* c, + void* alpha, + void* beta, + cublasLtMatmulDesc_t matmul_desc, + cublasLtMatrixLayout_t a_desc, + cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, + cublasComputeType_t compute_type, + cudaDataType_t scale_type, + cudaDataType_t a_type, + cudaDataType_t b_type, + cudaDataType_t c_type, + void* workspace, + size_t workspace_size, + cudaStream_t stream) { + if (search_times_ <= 0) { + VLOG(3) << "Skip CublasLtAlgoSelect process, use default algo instead. " + "If you want to enable CublasLtAlgoSelect, " + "please set FLAGS_cublaslt_exhaustive_search_times > 0"; + return nullptr; + } + + VLOG(1) << "m n k " << m << " " << n << " " << k; + + int64_t seed = 0; + std::hash hash_fn; + + HashMatmulDesc_(matmul_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(a_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(b_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(c_desc, &seed, hash_fn); + + cublasLtMatmulAlgo_t ret; + { + std::lock_guard lock(cache_mutex_); + auto it = map_.find(seed); + if (it != map_.end()) { + VLOG(3) << "CublasLtAlgoSelect Found in cache"; + return &(it->second); + } + } + VLOG(3) << "CublasLtAlgoSelect Not Found in cache"; + + // Get Ids + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoGetIds + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + // std::vector algo_ids(requested_algo_count_); + int algo_ids[requested_algo_count_]; // NOLINT + + int num_algo_ids; + status = dyl::cublasLtMatmulAlgoGetIds(handle, + compute_type, + scale_type, + a_type, + b_type, + c_type, + c_type, + requested_algo_count_, + algo_ids, + &num_algo_ids); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoGetIds); + + // Traverse all posssible algo combinations + int step = 0; + int limit = 20000; + std::vector params; + + for (int idx = 0; idx < num_algo_ids; idx++) { + cublasLtMatmulAlgo_t algo; + + /* Initialize algo structure with given Algp ID */ + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoInit + status = dyl::cublasLtMatmulAlgoInit(handle, + compute_type, + scale_type, + a_type, + b_type, + c_type, + c_type, + algo_ids[idx], + &algo); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoInit); + + // Query the tiles enums supported by that algo which is used to alloc + // enough space to store it + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCapGetAttribute + size_t attr_size = 0; + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, CUBLASLT_ALGO_CAP_TILE_IDS, nullptr, 0, &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + + int num_tiles = static_cast(attr_size / sizeof(int)); + std::vector tiles(num_tiles == 0 ? 1 : num_tiles); + if (num_tiles == 0) { + tiles[0] = CUBLASLT_MATMUL_TILE_UNDEFINED; + num_tiles = 1; + } else { + status = + dyl::cublasLtMatmulAlgoCapGetAttribute(&algo, + CUBLASLT_ALGO_CAP_TILE_IDS, + tiles.data(), + sizeof(int) * num_tiles, + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + } + + // Query the stages enums supported by that algo (cuda must >= 11.0) + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, CUBLASLT_ALGO_CAP_STAGES_IDS, nullptr, 0, &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + int num_stages = static_cast(attr_size / sizeof(int)); + std::vector stages(num_stages == 0 ? 1 : num_stages); + if (num_stages == 0) { + stages[0] = CUBLASLT_MATMUL_STAGES_UNDEFINED; + num_stages = 1; + } else { + status = + dyl::cublasLtMatmulAlgoCapGetAttribute(&algo, + CUBLASLT_ALGO_CAP_STAGES_IDS, + stages.data(), + sizeof(int) * num_stages, + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + } + + // Retrieve Other Algo Capabilities attributes + int splitk_support, red_mask, swizzling_max, custom_option_max; + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_SPLITK_SUPPORT, + &splitk_support, + sizeof(splitk_support), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK, + &red_mask, + sizeof(red_mask), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT, + &swizzling_max, + sizeof(swizzling_max), + &attr_size); + status = dyl::cublasLtMatmulAlgoCapGetAttribute( + &algo, + CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX, + &custom_option_max, + sizeof(custom_option_max), + &attr_size); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulAlgoCapGetAttribute); + + /* Loop over the different tiles */ + for (int tile_id = 0; tile_id < num_tiles && step < limit; tile_id++) { + /* Loop over different stages count */ + for (int stage_id = 0; stage_id < num_stages && step < limit; + stage_id++) { + /* Loop over the different custom option if any */ + for (int custom_option = 0; + custom_option <= custom_option_max && step < limit; + custom_option++) { + /* Loop over the CTAs swizzling support */ + for (int k = 0; k <= swizzling_max && step < limit; k++) { + int splir_k_trial = 0; + if (splitk_support) { + splir_k_trial += + sizeof(split_k_candidates) / sizeof(split_k_candidates[0]); + } + + for (int l = 0; (l < (1 + splir_k_trial)) && (step < limit); + l++) { + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_TILE_ID, + &tiles[tile_id], + sizeof(tiles[tile_id])); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_STAGES_ID, + &stages[stage_id], + sizeof(stages[stage_id])); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &custom_option, + sizeof(custom_option)); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &k, sizeof(k)); + int split_k_val = 0; + int reduction_scheme = CUBLASLT_REDUCTION_SCHEME_NONE; + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &split_k_val, + sizeof(split_k_val)); + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &reduction_scheme, + sizeof(int)); + if (l > 0) { // Split-K case + split_k_val = split_k_candidates[l - 1]; + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &split_k_candidates[l - 1], + sizeof(split_k_candidates[l - 1])); + for (reduction_scheme = 1; + reduction_scheme < + static_cast(CUBLASLT_REDUCTION_SCHEME_MASK) && + (step < limit); + reduction_scheme = reduction_scheme << 1) { + if (reduction_scheme & red_mask) { + status = dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &reduction_scheme, + sizeof(reduction_scheme)); + PADDLE_CUBLASLT_STATUS_CHECK( + cublasLtMatmulAlgoConfigSetAttribute); + + cublasLtMatmulHeuristicResult_t heurResult; + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + c_desc, + &algo, + &heurResult); + if (status == CUBLAS_STATUS_SUCCESS) { + CublasLtAlgoSelectorParam algo_select_params; + algo_select_params.algo = algo; + algo_select_params.m = m; + algo_select_params.n = n; + algo_select_params.k = k; + algo_select_params.algo_id = algo_ids[idx]; + algo_select_params.tile = tiles[tile_id]; + algo_select_params.swizzle = k; + algo_select_params.custom_option = custom_option; + algo_select_params.split_k_val = split_k_val; + algo_select_params.reduction_scheme = reduction_scheme; + algo_select_params.stages = stages[stage_id]; + algo_select_params.workspace_size = workspace_size; + algo_select_params.workspace = workspace; + params.emplace_back(algo_select_params); + step++; + } + } // end if + } + } else { + // Prepare algos + cublasLtMatmulHeuristicResult_t heurResult; + // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulAlgoCheck + status = dyl::cublasLtMatmulAlgoCheck(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + c_desc, + &algo, + &heurResult); + if (status == CUBLAS_STATUS_SUCCESS) { + CublasLtAlgoSelectorParam algo_select_params; + algo_select_params.algo = algo; + algo_select_params.m = m; + algo_select_params.n = n; + algo_select_params.k = k; + algo_select_params.algo_id = algo_ids[idx]; + algo_select_params.tile = tiles[tile_id]; + algo_select_params.swizzle = k; + algo_select_params.custom_option = custom_option; + algo_select_params.split_k_val = split_k_val; + algo_select_params.reduction_scheme = reduction_scheme; + algo_select_params.stages = stages[stage_id]; + algo_select_params.workspace_size = workspace_size; + algo_select_params.workspace = workspace; + params.emplace_back(algo_select_params); + step++; + } + } + } + } + } + } + } + } + cudaEvent_t start_event; + cudaEvent_t stop_event; + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event)); + + if (step == 0) { + VLOG(3) << "No algo can be used"; + return nullptr; + } + + VLOG(3) << "CublasLtAlgoSelect Start testRun " << step << " " + << params.size(); + + for (int i = 0; i < step; i++) { + TestMatmulRun(handle, + matmul_desc, + a_desc, + b_desc, + c_desc, + alpha, + beta, + a, + b, + c, + params[i], + start_event, + stop_event, + stream); + } + std::sort(params.begin(), params.end(), compare_algo_time); + + int res_id = 0; + while (params[res_id].time == 0) res_id++; + + if (res_id >= params.size()) { + VLOG(3) << "No algo can be used"; + return nullptr; + } + + VLOG(3) << "algo selected"; + + ret = params[res_id].algo; + std::lock_guard lock(cache_mutex_); + auto& algo_in_map = map_[seed]; + algo_in_map = ret; + return &algo_in_map; + } + + ~CublasLtAlgoCache() { + // Serialize map_ to cache file + std::ofstream outfile; + outfile.open(config_filename_, std::ios::out | std::ios::trunc); + outfile << dyl::cublasLtGetCudartVersion() << std::endl; + + for (const auto p : map_) { + outfile << p.first << " "; + for (int i = 0; i < 8; ++i) { + outfile << p.second.data[i] << " "; + } + outfile << std::endl; + } + outfile.close(); + } + + private: + explicit CublasLtAlgoCache(int search_times) : search_times_(search_times) { + // Init map_ from cache file + std::ifstream infile; + infile.open(config_filename_); + if (!infile.is_open()) { + VLOG(3) << "No CublasLtAlgoCache file found"; + return; + } + size_t cublaslt_version, real_cublaslt_version; + int64_t seed = 0; + uint64_t algo_data[8]; + infile >> cublaslt_version; + VLOG(1) << "cublaslt_version " << cublaslt_version; + + if (dyl::cublasLtGetCudartVersion() != cublaslt_version) { + LOG(INFO) << config_filename_ + << " is not compatible with current cublaslt_version " + << real_cublaslt_version; + return; + } + + while (!infile.eof()) { + infile >> seed >> algo_data[0] >> algo_data[1] >> algo_data[2] >> + algo_data[3] >> algo_data[4] >> algo_data[5] >> algo_data[6] >> + algo_data[7]; + + for (int i = 0; i < 8; ++i) { + map_[seed].data[i] = algo_data[i]; + } + } + infile.close(); + } + + std::string config_filename_{"/tmp/paddle_cublaslt_cache"}; + std::unordered_map map_; + int search_times_; + const int requested_algo_count_ = 100; + std::mutex cache_mutex_; + + void HashMatmulDesc_(cublasLtMatmulDesc_t desc, + int64_t* seed, + const std::hash& hash_fn) { + size_t size_to_write; + int trans_a, trans_b; + uint32_t epilogue; + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &trans_a, + sizeof(trans_a), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(trans_a)); + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &trans_b, + sizeof(trans_b), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(trans_b)); + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatmulDescGetAttribute(desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, + sizeof(epilogue), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(epilogue)); + } + + void HashMatrixLayoutDesc_(cublasLtMatrixLayout_t desc, + int64_t* seed, + const std::hash& hash_fn) { + size_t size_to_write; + uint32_t dtype; + int32_t batch; + uint64_t row, col; + int64_t ld, batch_offset; + + PADDLE_ENFORCE_GPU_SUCCESS( + dyl::cublasLtMatrixLayoutGetAttribute(desc, + CUBLASLT_MATRIX_LAYOUT_TYPE, + &dtype, + sizeof(dtype), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(dtype)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch, + sizeof(batch), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(batch)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), &size_to_write)); + HashValue_(seed, hash_fn, static_cast(row)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), &size_to_write)); + HashValue_(seed, hash_fn, static_cast(col)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); + HashValue_(seed, hash_fn, static_cast(ld)); + + PADDLE_ENFORCE_GPU_SUCCESS(dyl::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_offset, + sizeof(batch_offset), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(batch_offset)); + } + + void HashValue_(int64_t* seed, + const std::hash& hash_fn, + int64_t value) { + *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + } +}; +#endif + +template class CublasLtHelper { public: - CublasLtHelper(int m, int k, int n) - : alpha_(1), beta_(0), m_(m), k_(k), n_(n) { + CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle) + : alpha_(1), beta_(0), m_(m), k_(k), n_(n), handle_(handle) { cublasStatus_t status; // handle and matmul desc - status = dyl::cublasLtCreate(&handle_); + // status = dyl::cublasLtCreate(&handle_); + // PADDLE_CUBLASLT_STATUS_CHECK(cublasLtCreate); + if (std::is_same::value) { + scale_type_ = CUDA_R_16F; + a_type_ = CUDA_R_16F; + b_type_ = CUDA_R_16F; + c_type_ = CUDA_R_16F; #if CUBLAS_VER_MAJOR < 11 - cudaDataType_t cudaComputeType = CUDA_R_32I; + compute_type_ = CUDA_R_16F; #else - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; + compute_type_ = CUBLAS_COMPUTE_16F; #endif - - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + } else if (std::is_same::value) { + scale_type_ = CUDA_R_32F; + a_type_ = CUDA_R_32F; + b_type_ = CUDA_R_32F; + c_type_ = CUDA_R_32F; +#if CUBLAS_VER_MAJOR < 11 + compute_type_ = CUDA_R_32F; +#else + compute_type_ = CUBLAS_COMPUTE_32F; +#endif + } else if (std::is_same::value) { + scale_type_ = CUDA_R_32I; + a_type_ = CUDA_R_8I; + b_type_ = CUDA_R_8I; + c_type_ = CUDA_R_32I; +#if CUBLAS_VER_MAJOR < 11 + compute_type_ = CUDA_R_32I; +#else + compute_type_ = CUBLAS_COMPUTE_32I; +#endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "CublasLtHelper just implement for FP16/FP32/INT32.")); + } #if CUBLAS_VER_MAJOR < 11 - status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); + status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, compute_type_); #else status = dyl::cublasLtMatmulDescCreate( - &matmul_desc_, cudaComputeType, CUDA_R_32I); + &matmul_desc_, compute_type_, scale_type_); #endif + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulDescCreate); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatmulDescCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + // Node: Just test for int8 cublasOperation_t op_transpose = CUBLAS_OP_T; status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, CUBLASLT_MATMUL_DESC_TRANSA, &op_transpose, sizeof(op_transpose)); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatmulDescSetAttribute execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulDescSetAttribute); // matrix desc - status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + status = dyl::cublasLtMatrixLayoutCreate(&b_desc_, a_type_, k, n, k); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); - status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + status = dyl::cublasLtMatrixLayoutCreate(&a_desc_, b_type_, k, m, k); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); - status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n); - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - platform::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); + status = dyl::cublasLtMatrixLayoutCreate(&c_desc_, c_type_, n, m, n); + PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatrixLayoutCreate); } ~CublasLtHelper() { - if (handle_) dyl::cublasLtDestroy(handle_); - if (matmul_desc_) dyl::cublasLtMatmulDescDestroy(matmul_desc_); - if (A_desc_) dyl::cublasLtMatrixLayoutDestroy(A_desc_); - if (B_desc_) dyl::cublasLtMatrixLayoutDestroy(B_desc_); - if (C_desc_) dyl::cublasLtMatrixLayoutDestroy(C_desc_); + dyl::cublasLtMatmulDescDestroy(matmul_desc_); + dyl::cublasLtMatrixLayoutDestroy(a_desc_); + dyl::cublasLtMatrixLayoutDestroy(b_desc_); + dyl::cublasLtMatrixLayoutDestroy(c_desc_); } - void GEMM(int8_t* A_dev, - const int8_t* B_dev, - int32_t* C_dev, - cudaStream_t stream) { + template + void GEMM(const InT* a_dev, + const InT* b_dev, + OutT* c_dev, + cudaStream_t stream, + void* workspace = nullptr, + size_t workspace_size = 0) { cublasStatus_t status; -#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 - cublasLtMatmulAlgo_t algo; - int algoId = 21; - int swizzle = 0; - int customOption = 0; - int tile = 15; - int splitK_val = 0; - int reductionScheme = 0; -#if CUDA_VERSION >= 11000 - int stages = 23; -#endif - -#if CUBLAS_VER_MAJOR < 11 - cudaDataType_t cudaComputeType = CUDA_R_32I; -#else - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; -#endif - - dyl::cublasLtMatmulAlgoInit(handle_, - cudaComputeType, +#if CUDA_VERSION >= 11020 + cublasLtMatmulAlgo_t* algo = + CublasLtAlgoCache::Instance().CublasLtAlgoSelect(handle_, + m_, + n_, + k_, + b_dev, + a_dev, + c_dev, + &alpha_, + &beta_, + matmul_desc_, + b_desc_, + a_desc_, + c_desc_, + compute_type_, + scale_type_, + b_type_, + a_type_, + c_type_, + workspace, + workspace_size, + stream); + + cublasLtMatmulAlgo_t algo_; + if (algo == nullptr) { + int algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 15; + int splitK_val = 0; + int reductionScheme = 0; + int stages = 23; + if (m_ >= 128) { + tile = 20; + stages = 17; + } + dyl::cublasLtMatmulAlgoInit(handle_, + compute_type_, CUDA_R_32I, CUDA_R_8I, CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, - &algo); + &algo_); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, + &algo_, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(customOption), sizeof(customOption)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); - dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo, + &algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); + dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(splitK_val), sizeof(splitK_val)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + &algo_, + CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, + &(swizzle), + sizeof(swizzle)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, + &algo_, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); -#if CUDA_VERSION >= 11000 dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); -#endif + &algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); + algo = &algo_; + } #endif + status = dyl::cublasLtMatmul(handle_, matmul_desc_, &alpha_, - B_dev, - B_desc_, - A_dev, - A_desc_, + b_dev, + b_desc_, + a_dev, + a_desc_, &beta_, - C_dev, - C_desc_, - C_dev, - C_desc_, -#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 - &algo, + c_dev, + c_desc_, + c_dev, + c_desc_, +#if CUDA_VERSION >= 11020 + algo, + workspace, + workspace_size, #else nullptr, -#endif nullptr, 0, +#endif stream); PADDLE_ENFORCE_EQ( status, @@ -196,11 +818,22 @@ class CublasLtHelper { private: cublasLtHandle_t handle_; cublasLtMatmulDesc_t matmul_desc_; - cublasLtMatrixLayout_t A_desc_; - cublasLtMatrixLayout_t B_desc_; - cublasLtMatrixLayout_t C_desc_; - int32_t alpha_; - int32_t beta_; + cublasLtMatrixLayout_t a_desc_; + cublasLtMatrixLayout_t b_desc_; + cublasLtMatrixLayout_t c_desc_; + + cudaDataType_t scale_type_; + cudaDataType_t a_type_; + cudaDataType_t b_type_; + cudaDataType_t c_type_; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t compute_type_; +#else + cublasComputeType_t compute_type_; +#endif + + T alpha_; + T beta_; int m_; int k_; @@ -208,4 +841,4 @@ class CublasLtHelper { }; } // namespace operators -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 46153d980409d..be102980b4d78 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/operators/dropout_impl.cu.h" #include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h" #include "paddle/fluid/operators/transpose_op.cu.h" @@ -27,7 +28,22 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; +template +class PDTraits; + +template <> +class PDTraits { +public: + typedef float DataType; + typedef float data_t; +}; + +template <> +class PDTraits { +public: + typedef half DataType; + typedef paddle::float16 data_t; +}; class AttnDropoutParam { public: @@ -46,7 +62,7 @@ class AttnDropoutParam { bool is_upscale_in_train, bool is_fix_seed, int seed_val, - const Tensor* seed) { + const phi::DenseTensor* seed) { is_test_ = is_test; dropout_implementation_ = dropout_implementation; dropout_prob_ = dropout_prob; @@ -61,9 +77,81 @@ class AttnDropoutParam { bool is_upscale_in_train_; bool is_fix_seed_; int seed_val_; - const Tensor* seed_; + const phi::DenseTensor* seed_; }; +template +__global__ void TransposeRemovingPadding(const T* input_data, + T* output_data, + const int batch_size, + const int num_head, + const int seq_len, + const int head_dim, + const int token_num, + const int elem_cnt, + const int* padding_offset) { + // transpose and remove padding + // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, + // head_dim] + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + const int dim_embed = num_head * head_dim; + using LoadT = phi::AlignedVector; + LoadT src_vec; + + for (int32_t linear_index = idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / dim_embed; + const int ori_token_idx = + token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]); + const int ori_batch_id = ori_token_idx / seq_len; + const int ori_seq_id = ori_token_idx % seq_len; + const int ori_head_id = (linear_index % dim_embed) / head_dim; + const int ori_head_lane = (linear_index % dim_embed) % head_dim; + const int ori_idx = ori_batch_id * num_head * seq_len * head_dim + + ori_head_id * seq_len * head_dim + + ori_seq_id * head_dim + ori_head_lane; + phi::Load(&input_data[ori_idx], &src_vec); + phi::Store(src_vec, &output_data[linear_index]); + } +} + +template +void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx, + const T* input_data, + T* output_data, + const int batch_size, + const int num_head, + const int seq_len, + const int head_dim, + const int token_num, + const int* padding_offset) { + // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, + // head_dim] + constexpr int VEC_16B = 16; + const int elem_cnt = token_num * num_head * head_dim; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ( + head_dim % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t block_size = 128; + int32_t grid_size = (pack_num + block_size - 1) / block_size; + TransposeRemovingPadding + <<>>(input_data, + output_data, + batch_size, + num_head, + seq_len, + head_dim, + token_num, + elem_cnt, + padding_offset); +} + template class FMHARef { public: @@ -82,18 +170,18 @@ class FMHARef { ~FMHARef() {} - void ComputeForward(const Tensor& qkv_input_tensor, - const Tensor* cache_kv_tensor, - const Tensor* src_mask_tensor, - Tensor* transpose_2_out_tensor, - Tensor* cache_kv_out_tensor, - Tensor* qk_out_tensor, - Tensor* src_mask_out_tensor, - Tensor* softmax_out_tensor, - Tensor* dropout_mask_out_tensor, - Tensor* dropout_out_tensor, - Tensor* qktv_out_tensor, - Tensor* fmha_out_tensor) { + void ComputeForward(const phi::DenseTensor& qkv_input_tensor, + const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + phi::DenseTensor* transpose_2_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor) { // input shape: [bs, seq_len, 3, num_head, head_dim] // transpose with perm [2, 0, 3, 1, 4], // output_shape: [3, bs, num_head, seq_len, head_dim] @@ -104,7 +192,6 @@ class FMHARef { T* qk_out_data = qk_out_tensor->data(); T* qktv_out_data = qktv_out_tensor->data(); T* softmax_out_data = softmax_out_tensor->data(); - T* dropout_out_data = dropout_out_tensor->data(); T* fmha_out_data = fmha_out_tensor->data(); auto out_seq_len = seq_len_; @@ -142,8 +229,8 @@ class FMHARef { float alpha = 1.0 / sqrt(head_dim_); auto q_tensor = transpose_2_out_tensor->Slice(0, 1); auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {&q_tensor}; - std::vector outs = {&q_tensor}; + std::vector ins = {&q_tensor}; + std::vector outs = {&q_tensor}; phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); } @@ -183,8 +270,8 @@ class FMHARef { seq_len_, dev_ctx_.stream()); } else { - std::vector ins; - std::vector outs; + std::vector ins; + std::vector outs; ins.emplace_back(qk_out_tensor); ins.emplace_back(src_mask_tensor); outs.emplace_back(src_mask_out_tensor); @@ -220,11 +307,12 @@ class FMHARef { dropout_param_.is_upscale_in_train_, dropout_param_.is_fix_seed_, dropout_param_.seed_val_, - static_cast(*softmax_out_tensor), + static_cast(*softmax_out_tensor), dropout_param_.seed_, dropout_mask_out_tensor, dropout_out_tensor, false); + T* dropout_out_data = dropout_out_tensor->data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -262,22 +350,210 @@ class FMHARef { dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); } - void ComputeBackward(const Tensor& transpose_2_out_tensor, - const Tensor* src_mask_tensor, - const Tensor& softmax_out_tensor, - const Tensor& dropout_mask_out_tensor, - const Tensor& dropout_out_tensor, - const Tensor& qk_out_tensor, - const Tensor& src_mask_out_tensor, - const Tensor& fmha_out_grad_tensor, - Tensor* qktv_out_grad_tensor, - Tensor* dropout_out_grad_tensor, - Tensor* softmax_out_grad_tensor, - Tensor* src_mask_out_grad_tensor, - Tensor* qk_out_grad_tensor, - Tensor* transpose_2_out_grad_tensor, - Tensor* src_mask_grad_tensor, - Tensor* qkv_input_grad_tensor) { + void ComputeForwardWithoutTranspose( + const phi::DenseTensor* cache_kv_tensor, + const phi::DenseTensor* src_mask_tensor, + const phi::DenseTensor* padding_offset_tensor, + phi::DenseTensor* q_transpose_out_tensor, + phi::DenseTensor* kv_transpose_out_tensor, + phi::DenseTensor* cache_kv_out_tensor, + phi::DenseTensor* qk_out_tensor, + phi::DenseTensor* src_mask_out_tensor, + phi::DenseTensor* softmax_out_tensor, + phi::DenseTensor* dropout_mask_out_tensor, + phi::DenseTensor* dropout_out_tensor, + phi::DenseTensor* qktv_out_tensor, + phi::DenseTensor* fmha_out_tensor, + const int token_num) { + // input shape: [bs, seq_len, 3, num_head, head_dim] + // transpose with perm [2, 0, 3, 1, 4], + // output_shape: [3, bs, num_head, seq_len, head_dim] + T* qk_out_data = qk_out_tensor->data(); + T* qktv_out_data = qktv_out_tensor->data(); + T* softmax_out_data = softmax_out_tensor->data(); + T* fmha_out_data = fmha_out_tensor->data(); + + auto out_seq_len = seq_len_; + if (cache_kv_tensor) { + // kv [2, bs, num_head, seq_len, head_dim] + phi::funcs::ConcatFunctor concat; + // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] + concat(dev_ctx_, + {*cache_kv_tensor, *kv_transpose_out_tensor}, + 3, + cache_kv_out_tensor); + out_seq_len = cache_kv_out_tensor->dims()[3]; + } + + int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; + T* q_ptr = q_transpose_out_tensor->data(); + T* k_ptr = nullptr; + T* v_ptr = nullptr; + + if (cache_kv_tensor) { + int64_t k_size = cache_kv_out_tensor->numel() / 2; + k_ptr = cache_kv_out_tensor->data(); + v_ptr = k_ptr + k_size; + } else { + int64_t k_size = q_size; + k_ptr = kv_transpose_out_tensor->data(); + v_ptr = k_ptr + k_size; + } + + { + // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for + // float16 calculation, INF may appear in QK^T if we do not scale before. + float alpha = 1.0 / sqrt(head_dim_); + auto functor = phi::funcs::ScaleFunctor(alpha); + std::vector ins = {q_transpose_out_tensor}; + std::vector outs = {q_transpose_out_tensor}; + phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); + } + + // q*k^t, batched_gemm + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasTrans; + auto blas = phi::funcs::GetBlas(dev_ctx_); + int gemm_batch_size = batch_size_ * num_head_; + int gemm_m = seq_len_; + int gemm_n = out_seq_len; + int gemm_k = head_dim_; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + int64_t stride_a = gemm_m * gemm_k; + int64_t stride_b = gemm_k * gemm_n; + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + q_ptr, + k_ptr, + beta, + qk_out_data, + gemm_batch_size, + stride_a, + stride_b); + int softmax_axis = -1; + if (src_mask_tensor != nullptr) { + if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { + LaunchFusedSoftmaxMaskKernel(qk_out_data, + src_mask_tensor->data(), + softmax_out_data, + batch_size_, + num_head_, + seq_len_, + dev_ctx_.stream()); + } else { + std::vector ins; + std::vector outs; + ins.emplace_back(qk_out_tensor); + ins.emplace_back(src_mask_tensor); + outs.emplace_back(src_mask_out_tensor); + int elewise_add_axis = -1; + phi::funcs::BroadcastKernel( + dev_ctx_, + ins, + &outs, + elewise_add_axis, + phi::funcs::AddFunctor()); + + phi::SoftmaxForwardCUDAKernelDriver( + dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); + } + } else { + phi::SoftmaxForwardCUDAKernelDriver( + dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); + } + + transB = CblasNoTrans; + gemm_m = seq_len_; + gemm_n = head_dim_; + gemm_k = out_seq_len; + alpha = static_cast(1.0); + stride_a = gemm_m * gemm_k; + stride_b = gemm_k * gemm_n; + + if (dropout_param_.dropout_prob_) { + T* dropout_out_data = dropout_out_tensor->data(); + DropoutFwGPUKernelDriver( + static_cast(dev_ctx_), + dropout_param_.is_test_, + dropout_param_.dropout_prob_, + dropout_param_.is_upscale_in_train_, + dropout_param_.is_fix_seed_, + dropout_param_.seed_val_, + static_cast(*softmax_out_tensor), + dropout_param_.seed_, + dropout_mask_out_tensor, + dropout_out_tensor, + false); + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + dropout_out_data, + v_ptr, + beta, + qktv_out_data, + gemm_batch_size, + stride_a, + stride_b); + } else { + // softmax_out * v, batched_gemm + // output shape: [batch_size, num_heads, seq_len, head_dim] + blas.BatchedGEMM(transA, + transB, + gemm_m, + gemm_n, + gemm_k, + alpha, + softmax_out_data, + v_ptr, + beta, + qktv_out_data, + gemm_batch_size, + stride_a, + stride_b); + } + // transpose: [0, 2, 1, 3] + // output shape: [batch_size, seq_len, num_heads, head_dim] + if (!padding_offset_tensor) { + std::vector perm_3 = {0, 2, 1, 3}; + TransposeGPUKernelDriver( + dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); + } else { + InvokeTransposeRemovePadding(dev_ctx_, + qktv_out_data, + fmha_out_data, + batch_size_, + num_head_, + seq_len_, + head_dim_, + token_num, + padding_offset_tensor->data()); + } + } + + void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor, + const phi::DenseTensor* src_mask_tensor, + const phi::DenseTensor& softmax_out_tensor, + const phi::DenseTensor& dropout_mask_out_tensor, + const phi::DenseTensor& dropout_out_tensor, + const phi::DenseTensor& qk_out_tensor, + const phi::DenseTensor& src_mask_out_tensor, + const phi::DenseTensor& fmha_out_grad_tensor, + phi::DenseTensor* qktv_out_grad_tensor, + phi::DenseTensor* dropout_out_grad_tensor, + phi::DenseTensor* softmax_out_grad_tensor, + phi::DenseTensor* src_mask_out_grad_tensor, + phi::DenseTensor* qk_out_grad_tensor, + phi::DenseTensor* transpose_2_out_grad_tensor, + phi::DenseTensor* src_mask_grad_tensor, + phi::DenseTensor* qkv_input_grad_tensor) { auto blas = phi::funcs::GetBlas(dev_ctx_); int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; int k_size = q_size; @@ -294,8 +570,6 @@ class FMHARef { const T* softmax_out_data = softmax_out_tensor.data(); T* softmax_out_grad_data = softmax_out_grad_tensor->data(); - const T* dropout_out_data = dropout_out_tensor.data(); - T* dropout_out_grad_data = dropout_out_grad_tensor->data(); T* qktv_out_grad_data = qktv_out_grad_tensor->data(); // transpose bw @@ -317,6 +591,7 @@ class FMHARef { int64_t stride_b = gemm_k * gemm_n; // bw: dy = x^t * dout if (dropout_param_.dropout_prob_) { + const T* dropout_out_data = dropout_out_tensor.data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -354,6 +629,7 @@ class FMHARef { stride_a = gemm_m * gemm_k; stride_b = gemm_k * gemm_n; if (dropout_param_.dropout_prob_) { + T* dropout_out_grad_data = dropout_out_grad_tensor->data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -389,7 +665,7 @@ class FMHARef { false, dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_, - static_cast(*dropout_out_grad_tensor), + static_cast(*dropout_out_grad_tensor), dropout_mask_out_tensor, softmax_out_grad_tensor, false); @@ -495,3 +771,4 @@ class FMHARef { } // namespace operators } // namespace paddle + diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc index 9572a87aba21d..bd84667e21e0a 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; +using Tensor = phi::DenseTensor; class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { private: @@ -58,6 +58,12 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { CHECK_INPUTS(FFN1Weight); CHECK_INPUTS(FFN2Weight); + // scale + CHECK_INPUTS(QKVOutScale); + CHECK_INPUTS(OutLinearOutScale); + CHECK_INPUTS(FFN1OutScale); + CHECK_INPUTS(FFN2OutScale); + CHECK_OUTPUT(Out); // x: qkv's input [batch_size, seq_len, dim_embed] @@ -93,26 +99,6 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { x_dim, y_dim)); - if (ctx->Attrs().Get("ring_id") == -1) { - if (trans_qkvw) { - PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], - y_dim[3], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - - } else { - PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3], - y_dim[0], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(dim_embed, 3, num_head, dim_head)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - } - } if (ctx->HasInputs("CacheKV")) { // [2, batch_size, num_head, max_seq_len, head_size] @@ -129,13 +115,7 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { paddle::platform::errors::InvalidArgument( "The first dim of CacheKV must be 2, but got %d", c_dim[0])); // 2 - PADDLE_ENFORCE_EQ(c_dim[1], - x_dim[0], - paddle::platform::errors::InvalidArgument( - "The second dim of CacheKV must be equal with " - "batch size %d, but got %d", - x_dim[0], - c_dim[1])); // batch_size + PADDLE_ENFORCE_EQ(c_dim[2], trans_qkvw ? y_dim[1] : y_dim[2], paddle::platform::errors::InvalidArgument( @@ -143,12 +123,7 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { "head %d, but got %d", trans_qkvw ? y_dim[1] : y_dim[2], c_dim[2])); // num_head - PADDLE_ENFORCE_GT( - c_dim[3], - 0, - paddle::platform::errors::InvalidArgument( - "The forth dim of CacheKV must be greater than 0, but got %d", - c_dim[3])); // cache_seq_len + PADDLE_ENFORCE_EQ(c_dim[4], trans_qkvw ? y_dim[2] : y_dim[3], paddle::platform::errors::InvalidArgument( @@ -200,9 +175,21 @@ class FusedMultiTransformerINT8OpMaker AddInput("CacheKV", "(optional) The cached KV for generation inference.") .AsDispensable() .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); AddInput("TimeStep", "(optional, int) The time step for generation inference.") .AsDispensable(); + AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") + .AsDispensable(); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); @@ -232,20 +219,24 @@ class FusedMultiTransformerINT8OpMaker "In order to keep consistent with the PTQ/QAT calculation logic," "QKVOutScale should be max_bound * max_bound / max_range." "Here max_range is per-channel weight scale." - "The shape of QKVOutScale is [num_layers, num_channels]") - .AsDispensable(); + "The shape of QKVOutScale is [num_channels]") + .AsDispensable() + .AsDuplicable(); AddInput("OutLinearOutScale", "OutLinearOutScale is used to dequantize out_linear output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddInput("FFN1OutScale", "FFN1OutScale is used to dequantize ffn1 output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddInput("FFN2OutScale", "FFN2OutScale is used to dequantize ffn2 output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") .AsDispensable() @@ -352,6 +343,18 @@ class FusedMultiTransformerINT8OpMaker "quant_min_bound", "(float, default -127.0) the min bound of float type to int type") .SetDefault(-127.0); + AddAttr("rotary_emb_dims", + "the Attr(dims) for RotaryPosEmb's Computation [default 0].") + .SetDefault(0) + .AddCustomChecker([](const int &rotary_emb_dims) { + PADDLE_ENFORCE_EQ( + rotary_emb_dims >= 0 && rotary_emb_dims <= 2, + true, + platform::errors::InvalidArgument( + "'rotary_emb_dims' in Op(Rotray) should be between" + "0 and 2, But received [%s].", + rotary_emb_dims)); + }); AddComment(R"DOC(fused multi transformer layers op)DOC"); } @@ -366,4 +369,4 @@ REGISTER_OPERATOR( ops::FusedMultiTransformerINT8Op, ops::FusedMultiTransformerINT8OpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index 8e200275f8171..d87b3db45cb19 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -14,10 +14,33 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/attn_gemm_int8.h" #include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" +#include "paddle/fluid/operators/fused/layernorm_quant_dequant.h" + +// DECLARE_int32(debug_layer_id); namespace paddle { namespace operators { +template +static void PrintMatrix(const T* mat_d, int num, std::string name) { + std::vector tmp(num); + cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); + + std::ofstream outfile; + outfile.open(name+".txt", std::ios::out); + std::stringstream ss; + + for (int i = 0; i < num; ++i) { + if(std::is_same::value) { + ss << static_cast(tmp[i]) << std::endl; + } else { + ss << std::setprecision(8) << tmp[i] << std::endl; + } + } + outfile << ss.str(); + outfile.close(); +} + template class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { public: @@ -25,9 +48,9 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { using U = LayerNormParamType; auto &dev_ctx = ctx.cuda_device_context(); - auto *time_step = ctx.Input("TimeStep"); + auto *time_step = ctx.Input("TimeStep"); // 0. input - auto *input_x = ctx.Input("X"); + auto *input_x = ctx.Input("X"); const auto input_x_dims = input_x->dims(); int bsz = input_x_dims[0]; int seq_len = input_x_dims[1]; @@ -48,36 +71,87 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { // dequant output scales, tensor, size = [num_layers, n], n is gemm output // size - auto *qkv_out_scale = ctx.Input("QKVOutScale"); - auto *out_linear_out_scale = ctx.Input("OutLinearOutScale"); - auto *ffn1_out_scale = ctx.Input("FFN1OutScale"); - auto *ffn2_out_scale = ctx.Input("FFN2OutScale"); + auto qkv_out_scales = ctx.MultiInput("QKVOutScale"); + auto out_linear_out_scales = + ctx.MultiInput("OutLinearOutScale"); + auto ffn1_out_scales = ctx.MultiInput("FFN1OutScale"); + auto ffn2_out_scales = ctx.MultiInput("FFN2OutScale"); + + bool remove_padding = false; + auto *sequence_lengths = ctx.Input("SeqLengths"); + if (sequence_lengths) { + remove_padding = true; + } + phi::DenseTensor d_token_tensor; + phi::DenseTensor padding_offset_tensor; + phi::DenseTensor x_remove_padding; + bool encoder_remove_padding = (remove_padding && !time_step); + int token_num = 0; + + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } - int qkv_out_scale_n = qkv_out_scale->dims()[1]; - int out_linear_out_scale_n = out_linear_out_scale->dims()[1]; - int ffn1_out_scale_n = ffn1_out_scale->dims()[1]; - int ffn2_out_scale_n = ffn2_out_scale->dims()[1]; + // remove padding in encoder + if (encoder_remove_padding) { + // just for encoder + d_token_tensor.Resize({{1}}); + auto *d_token_num = dev_ctx.Alloc( + &d_token_tensor, d_token_tensor.numel() * sizeof(int)); + // alloc the max size of padding_offset_tensor + padding_offset_tensor.Resize({{bsz_seq}}); + dev_ctx.Alloc(&padding_offset_tensor, + padding_offset_tensor.numel() * sizeof(int)); + InvokeGetPaddingOffset(dev_ctx, + &token_num, + d_token_num, + padding_offset_tensor.data(), + sequence_lengths->data(), + bsz, + seq_len); + padding_offset_tensor.Resize({{token_num}}); + // VLOG(0) << "padding_offset_tensor: " << padding_offset_tensor; + x_remove_padding.Resize({{token_num, dim_embed}}); + dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); + InvokeRemovePadding(dev_ctx, + x_remove_padding.data(), + input_x->data(), + padding_offset_tensor.data(), + token_num, + dim_embed); + } else { + token_num = bsz_seq; + } + + if (token_num == 0) { + return; + } + + auto *padding_offset_data = + encoder_remove_padding ? padding_offset_tensor.data() : nullptr; // 1. layer norm const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); const float epsilon = ctx.Attr("epsilon"); - auto ln_scales = ctx.MultiInput("LnScale"); - auto ln_biases = ctx.MultiInput("LnBias"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); auto ln_compute = - AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); - Tensor ln_mean, ln_var; - ln_mean.Resize({{bsz_seq}}); + AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + phi::DenseTensor ln_mean, ln_var; + ln_mean.Resize({{token_num}}); auto *ln_mean_data = dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({{bsz_seq}}); + ln_var.Resize({{token_num}}); auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); // 2. qkv // x: qkv's input [batch_size, seq_len, dim_embed] // y: qkv's weight: [3, num_head, dim_head, dim_embed] - auto qkv_weights = ctx.MultiInput("QKVW"); - auto qkv_biases = ctx.MultiInput("QKVBias"); + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); const bool trans_qkvw = ctx.Attr("trans_qkvw"); const auto qkv_w_dims = qkv_weights[0]->dims(); int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; @@ -89,21 +163,31 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; // (transA, transB, compute_bias) = (false, trans_qkvw, false) AttnMatmulINT8 qkv_compute( - dev_ctx, bsz_seq, output_size, input_size, compute_bias); - Tensor qkv_out; - qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + dev_ctx, token_num, output_size, input_size, false /*compute_bias*/); + phi::DenseTensor qkv_out; + qkv_out.Resize({{token_num, 3, num_head, dim_head}}); auto *qkv_out_data = dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + // 2.1 rotary + auto *rotary_tensor = ctx.Input("RotaryPosEmb"); + const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); + // 3. fmha AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); auto fmha_compute = FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - auto *src_mask = ctx.Input("SrcMask"); - auto cache_kvs = ctx.MultiInput("CacheKV"); - auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + // auto *time_step = ctx.Input("TimeStep"); + + auto pre_caches = ctx.MultiInput("PreCaches"); + int cache_offset = 0; + if (pre_caches.size() > 0) { + cache_offset = pre_caches[0]->dims()[3]; + } auto out_seq_len = seq_len; if (time_step) { @@ -125,147 +209,203 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { "In decode stage, the seq_len of input must be 1, but now is %d", seq_len)); out_seq_len += time_step_value; + } else { + out_seq_len += cache_offset; + } + + phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; + q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *q_transpose_out_data = + dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + + kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); + auto *kv_transpose_out_data = dev_ctx.Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + + if (encoder_remove_padding) { + InitValue(dev_ctx, q_transpose_out_data, q_transpose_out.numel(), static_cast(0.)); + InitValue(dev_ctx, kv_transpose_out_data, kv_transpose_out.numel(), static_cast(0.)); } - Tensor transpose_out_2, qk_out; - transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); - auto *transpose_out_2_data = - dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); - Tensor softmax_out; - Tensor attn_dropout_mask_out, attn_dropout_out; - Tensor qktv_out, fmha_out; + phi::DenseTensor src_mask_out; + if (cache_offset > 0) { + src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *src_mask_out_data = + dev_ctx.Alloc(&src_mask_out, src_mask_out.numel() * sizeof(T)); + } + + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + phi::DenseTensor pre_cache_kv_out; + if (cache_offset > 0) { + pre_cache_kv_out.Resize( + {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); + auto *pre_cache_kv_out_data = dev_ctx.Alloc( + &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); + } + + phi::DenseTensor softmax_out; + phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; + phi::DenseTensor qktv_out, fmha_out; softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *softmax_out_data = dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); + T *attn_dropout_mask_out_data = nullptr; + T *attn_dropout_data_data = nullptr; qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); auto *qktv_out_data = dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + fmha_out.Resize({{token_num, num_head, dim_head}}); auto *fmha_out_data = dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); // 4. out_linear - auto out_linear_weights = ctx.MultiInput("OutLinearW"); - auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); int ring_id = ctx.Attr("ring_id"); // (transA, transB, compute_bias) = (false, false, false) AttnMatmulINT8 out_linear_compute( - dev_ctx, bsz_seq, dim_embed, hidden_size, false); + dev_ctx, token_num, dim_embed, hidden_size, false); // 5. ln(residual + bias) DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + FusedDropoutLayerNormHelper + fused_dropout_layernorm_helper_just_dequant( + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); - auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); - auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); - Tensor bias_dropout_residual_out, dropout_mask_out; + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + + using LayerNormComputeType = float; + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; T *bias_dropout_residual_out_data = nullptr; if (pre_layer_norm) { - bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}}); + bias_dropout_residual_out.Resize({{token_num, dim_embed}}); bias_dropout_residual_out_data = dev_ctx.Alloc(&bias_dropout_residual_out, bias_dropout_residual_out.numel() * sizeof(T)); } - dropout_mask_out.Resize({{bsz, seq_len, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + uint8_t *dropout_mask_out_data = nullptr; // 6. ffn matmul1 - auto ffn1_weights = ctx.MultiInput("FFN1Weight"); - auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + auto ffn1_biases = ctx.MultiInput("FFN1Bias"); auto ffn1_weight_dim = ffn1_weights[0]->dims(); int dim_ffn = ffn1_weight_dim[0]; AttnMatmulINT8 ffn1_linear_compute( - dev_ctx, bsz_seq, dim_ffn, dim_embed, false); - Tensor ffn1_out; - ffn1_out.Resize({{bsz_seq, dim_ffn}}); + dev_ctx, token_num, dim_ffn, dim_embed, false); + phi::DenseTensor ffn1_out; + ffn1_out.Resize({{token_num, dim_ffn}}); auto *ffn1_out_data = dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); // 7. ffn act + bias DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); FusedDropoutHelper fused_act_dropout_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); - Tensor ffn1_dropout_out, ffn1_dropout_mask; - ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}}); + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); + phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; + ffn1_dropout_out.Resize({{token_num, dim_ffn}}); auto *ffn1_dropout_out_data = dev_ctx.Alloc( &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); - ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}}); - auto *ffn1_dropout_mask_data = dev_ctx.Alloc( - &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); + uint8_t * ffn1_dropout_mask_data = nullptr; // 8. ffn2 matmul - auto ffn2_weights = ctx.MultiInput("FFN2Weight"); - auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + auto ffn2_biases = ctx.MultiInput("FFN2Bias"); AttnMatmulINT8 ffn2_linear_compute( - dev_ctx, bsz_seq, dim_embed, dim_ffn, false); + dev_ctx, token_num, dim_embed, dim_ffn, false); // 9. ffn2 residual bias DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); FusedDropoutLayerNormHelper ffn2_fused_dropout_dequant_helper( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); FusedDropoutLayerNormHelper ffn2_fused_dropout_helper_for_post_layernorm( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); // []. init workspace for cublasLt transform - Tensor input_workspace, output_workspace; + phi::DenseTensor input_workspace, output_workspace, cublaslt_workspace; // for input and output transform data is CUBLASLT_ORDER_COL32 format, - int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn), + int m_max = token_num, k_max = std::max(dim_embed, dim_ffn), n_max = std::max({output_size, dim_embed, dim_ffn}); - input_workspace.Resize( - {{32 * ((m_max + 32 - 1) / 32), (k_max + 31) / 32 * 32}}); + input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}}); dev_ctx.Alloc(&input_workspace, input_workspace.numel() * sizeof(int8_t)); - output_workspace.Resize({{n_max * 4, (m_max + 31) / 32 * 32 * 4}}); + + output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}}); dev_ctx.Alloc(&output_workspace, output_workspace.numel() * sizeof(int32_t)); + cublaslt_workspace.Resize({{3000000}}); + dev_ctx.Alloc(&cublaslt_workspace, + cublaslt_workspace.numel() * sizeof(int8_t)); + // calc - auto *out = ctx.Output("Out"); + auto *out = ctx.Output("Out"); auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - Tensor *from_tensor = out; - Tensor tmp_out; - tmp_out.Resize({{bsz, seq_len, dim_embed}}); + if (encoder_remove_padding) { + InitValue(dev_ctx, from_data, out->numel(), static_cast(0.)); + } + + // phi::DenseTensor *from_tensor = out; + // phi::DenseTensor tmp_out; + // tmp_out.Resize({{token_num, dim_embed}}); + + phi::DenseTensor tmp_out, tmp_out_rm_padding; + tmp_out.Resize({{token_num, dim_embed}}); + if (encoder_remove_padding) { + tmp_out_rm_padding.Resize({{token_num, dim_embed}}); + auto *tmp_out_rm_padding_data = dev_ctx.Alloc( + &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); + } + auto *tmp_out_data = dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - auto *x_data = input_x->data(); - Tensor *buf0 = nullptr; - Tensor *buf1 = nullptr; + const T *x_data; + if (encoder_remove_padding) { + x_data = x_remove_padding.data(); + } else { + x_data = input_x->data(); + } + + phi::DenseTensor *buf0 = nullptr; + phi::DenseTensor *buf1 = nullptr; // step0: x --> buf1 // step1: buf1 --> buf0 // step2: buf0 --> buf1 int layers = qkv_weights.size(); - if (pre_layer_norm) { - buf1 = out; - } else { + if (encoder_remove_padding) { + // In the case of variable lengths, the padding needs to be rebuilt + // eventually. So buf0 and buf1 do not need to be changed according to the + // pre_layer_norm and the number of layers. buf0 = &tmp_out; - buf1 = out; + buf1 = &tmp_out_rm_padding; + } else { + if (pre_layer_norm) { + buf1 = out; + } else { + buf0 = &tmp_out; + buf1 = out; + } } for (int i = 0; i < layers; ++i) { @@ -274,6 +414,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { auto *ln_scale_data = ln_scales[i]->data(); auto *ln_bias_data = ln_biases[i]->data(); // TODO(wangxi): can remove mean var in inference + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "fmt in " << *input_x; ln_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, @@ -292,20 +434,23 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #endif // step2. qkv - const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + const phi::DenseTensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; // NOTE: in decoder stage, bias is fused in fmha - const Tensor *bias = time_step ? nullptr : qkv_bias; + const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; if (!pre_layer_norm && i == 0) { + const phi::DenseTensor *tmp_input_x = + (encoder_remove_padding) ? &x_remove_padding : input_x; qkv_compute.ComputeForward(qkv_weights[i], - input_x, + tmp_input_x, &input_workspace, bias, &qkv_out, &output_workspace, &qkv_out, qkv_in_scale[i], - qkv_out_scale, - i * qkv_out_scale_n, + qkv_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -318,12 +463,16 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, &qkv_out, qkv_in_scale[i], - qkv_out_scale, - i * qkv_out_scale_n, + qkv_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); } else { + // if (i == FLAGS_debug_layer_id) { + // VLOG(2) << "qkv in " << input_workspace; + // VLOG(2) << "qkv weight " << *qkv_weights[i]; + // } qkv_compute.ComputeForwardINT8ToT(qkv_weights[i], qkv_in_scale[i], &input_workspace, @@ -331,53 +480,116 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &qkv_out, &output_workspace, &qkv_out, - qkv_out_scale, - i * qkv_out_scale_n); + qkv_out_scales[i], + &cublaslt_workspace); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step2"; #endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "qkv out " << qkv_out; // step3. fmha - const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; - Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + const phi::DenseTensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; if (time_step) { // generation decoder stage // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; fmha(dev_ctx, qkv_out, *qkv_bias, *src_mask, + sequence_lengths, + rotary_tensor, + beam_cache_offset, cache_kv_out, &fmha_out, bsz, + beam_size, + // 1, max_seq_len, num_head, dim_head, - time_step->data()[0], + src_mask->dims()[3] - 1, + rotary_emb_dims, 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - nullptr, - src_mask, - &transpose_out_2, - nullptr, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); - // [3, bsz, num_head, seq_len, head_dim] - T *qkv_data = transpose_out_2_data; - int64_t q_size = bsz * seq_len * num_head * dim_head; - int64_t k_size = q_size; - const T *q_ptr = qkv_data; - const T *k_ptr = q_ptr + q_size; - const T *v_ptr = k_ptr + k_size; + const phi::DenseTensor *pre_cache_kv_tensor = + pre_caches.size() > 0 ? pre_caches[i] : nullptr; + phi::DenseTensor *pre_cache_kv_out_tmp = + cache_offset > 0 ? &pre_cache_kv_out : nullptr; + phi::DenseTensor *src_mask_tmp = + cache_offset > 0 ? &src_mask_out : nullptr; + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out.data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; + } else { + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + const T *q_ptr = q_transpose_out_data; + k_ptr = kv_transpose_out_data; + v_ptr = k_ptr + k_size; + } // [2, bsz, num_head, max_seq_len, head_dim] int max_seq_len = cache_kv_out->dims()[3]; @@ -387,30 +599,68 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { T *cache_k_ptr = cache_kv_data; T *cache_v_ptr = cache_kv_data + cache_k_size; + const int seq_len_tmp = seq_len + cache_offset; write_cache_kv(dev_ctx, cache_k_ptr, cache_v_ptr, k_ptr, v_ptr, + sequence_lengths_data, bsz, num_head, - seq_len, + seq_len_tmp, max_seq_len, dim_head); } else { // not generation // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - cache_kv, - src_mask, - &transpose_out_2, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step3"; @@ -424,6 +674,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { nullptr, &output_workspace, nullptr, + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -431,6 +682,12 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { ring_id, bsz * seq_len * num_head * dim_head, dev_ctx); + // if (i == FLAGS_debug_layer_id) { + // VLOG(2) << "fmha_out " << fmha_out; + // VLOG(2) << "out_linear weight " << *out_linear_weights[i]; + // VLOG(2) << out_linear_in_scale[i]; + // VLOG(2) << "out_linear_out " << output_workspace; + // } } else { out_linear_compute.ComputeForward(out_linear_weights[i], &fmha_out, @@ -440,8 +697,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, out_linear_in_scale[i], - out_linear_out_scale, - i * out_linear_out_scale_n, + out_linear_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -453,31 +710,67 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { // step5. ln(residual + dropout(input + bias)) if (pre_layer_norm) { + VLOG(1) << "ffn1 in scale " << ffn1_in_scale[i]; auto *ln_scale_data = ffn_ln_scales[i]->data(); auto *ln_bias_data = ffn_ln_biases[i]->data(); auto *out_linear_bias_data = out_linear_biases[i]->data(); + + // inplace // non-inplace: buf1 -> input_workspace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - output_workspace.data(), - x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - input_workspace.data(), - ln_mean_data, - ln_var_data, - out_linear_in_scale[i], - out_linear_out_scale->data(), - i * out_linear_out_scale_n, - ffn1_in_scale[i], - quant_round_type, - quant_max_bound, - quant_min_bound); + // fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // x_data, + // out_linear_bias_data, + // ln_scale_data, + // ln_bias_data, + // bias_dropout_residual_out_data, + // dropout_mask_out_data, + // input_workspace.data(), + // ln_mean_data, + // ln_var_data, + // out_linear_in_scale[i], + // out_linear_out_scales[i]->data(), + // ffn1_in_scale[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + + // phi::DenseTensor ffn_ln_out; + // ffn_ln_out.Resize(input_x->dims()); + // dev_ctx.Alloc(&ffn_ln_out); + + // fused_dropout_layernorm_helper_just_dequant.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // x_data, + // out_linear_bias_data, + // ln_scale_data, + // ln_bias_data, + // bias_dropout_residual_out_data, + // dropout_mask_out_data, + // ffn_ln_out.data(), + // ln_mean_data, + // ln_var_data, + // out_linear_in_scale[i], + // out_linear_out_scales[i]->data(), + // ffn1_in_scale[i], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // LaunchQuantActKernel(ffn_ln_out.data(), bsz_seq, dim_embed, input_workspace.data(), ffn1_in_scale[i], quant_max_bound, quant_min_bound, dev_ctx.stream()); + + // VLOG(1) << "RIGHT out " << input_workspace; + // DequantSkipLoad load(output_workspace.data(), out_linear_bias_data, x_data, out_linear_out_scales[i]->data(), 0.0f, dim_embed); + DequantSkipLoadAndStoreResidual load(output_workspace.data(), out_linear_bias_data, x_data, + out_linear_out_scales[i]->data(), bias_dropout_residual_out_data, 0.0f, dim_embed); + AffineQuantStore store(input_workspace.data(), dim_embed, + ln_scale_data, ln_bias_data, ffn1_in_scale[i], quant_round_type, quant_max_bound, quant_min_bound); + DispatchLayerNorm(dev_ctx.stream(), load, store, token_num, dim_embed, epsilon, ln_mean_data, ln_var_data); + VLOG(1) << "WRONG out " << input_workspace; + } else { auto *ln_scale_data = ln_scales[i]->data(); auto *ln_bias_data = ln_biases[i]->data(); @@ -498,7 +791,9 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step5"; -#endif +#endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "ffn1_in " << input_workspace; // step6. ffn matmul1 @@ -507,7 +802,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &input_workspace, nullptr, &output_workspace, - nullptr); + nullptr, + &cublaslt_workspace); } else { ffn1_linear_compute.ComputeForward(ffn1_weights[i], buf1, @@ -517,8 +813,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, ffn1_in_scale[i], - ffn1_out_scale, - i * ffn1_out_scale_n, + ffn1_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -526,6 +822,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step6"; #endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "ffn1 out " << output_workspace; // step7. act bias // TODO(wangxi): remove dropout mask in inference @@ -538,8 +836,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { input_workspace.data(), ffn1_dropout_mask_data, ffn1_in_scale[i], - ffn1_out_scale->data(), - i * ffn1_out_scale_n, + ffn1_out_scales[i]->data(), ffn2_in_scale[i], quant_round_type, quant_max_bound, @@ -556,6 +853,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step7"; #endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "ffn2 in " << input_workspace; // step8. ffn matmul2 if (pre_layer_norm) { @@ -563,7 +862,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &input_workspace, nullptr, &output_workspace, - nullptr); + nullptr, + &cublaslt_workspace); } else { ffn2_linear_compute.ComputeForward(ffn2_weights[i], &ffn1_dropout_out, @@ -573,8 +873,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, ffn2_in_scale[i], - ffn2_out_scale, - i * ffn2_out_scale_n, + ffn2_out_scales[i], + &cublaslt_workspace, quant_round_type, quant_max_bound, quant_min_bound); @@ -582,6 +882,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step8.0"; #endif + // if (i == FLAGS_debug_layer_id) + // VLOG(2) << "ffn2 out " << output_workspace; if (pre_layer_norm) { AllReduce(output_workspace, @@ -602,25 +904,57 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { auto *ln_scale_data = ln_scales[i + 1]->data(); auto *ln_bias_data = ln_biases[i + 1]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - output_workspace.data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - input_workspace.data(), - ln_mean_data, - ln_var_data, - ffn2_in_scale[i], - ffn2_out_scale->data(), - i * ffn2_out_scale_n, - qkv_in_scale[i + 1], - quant_round_type, - quant_max_bound, - quant_min_bound); + // ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // bias_dropout_residual_out_data, + // ffn2_biases[i]->data(), + // ln_scale_data, + // ln_bias_data, + // buf1->data(), + // dropout_mask_out_data, + // input_workspace.data(), + // ln_mean_data, + // ln_var_data, + // ffn2_in_scale[i], + // ffn2_out_scales[i]->data(), + // qkv_in_scale[i + 1], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + + phi::DenseTensor ln_out; + ln_out.Resize(input_x->dims()); + dev_ctx.Alloc(&ln_out); + + // fused_dropout_layernorm_helper_just_dequant.LayernormResidualDropoutBias( + // dev_ctx, + // output_workspace.data(), + // bias_dropout_residual_out_data, + // ffn2_biases[i]->data(), + // ln_scale_data, + // ln_bias_data, + // buf1->data(), + // dropout_mask_out_data, + // ln_out.data(), + // ln_mean_data, + // ln_var_data, + // ffn2_in_scale[i], + // ffn2_out_scales[i]->data(), + // qkv_in_scale[i + 1], + // quant_round_type, + // quant_max_bound, + // quant_min_bound); + // LaunchQuantActKernel(ln_out.data(), bsz_seq, dim_embed, input_workspace.data(), qkv_in_scale[i + 1], quant_max_bound, quant_min_bound, dev_ctx.stream()); + // VLOG(1) << "RIGHT out " << input_workspace; + + // DequantSkipLoad load(output_workspace.data(), ffn2_biases[i]->data(), bias_dropout_residual_out_data, ffn2_out_scales[i]->data(), 0.0f, dim_embed); + DequantSkipLoadAndStoreResidual load(output_workspace.data(), ffn2_biases[i]->data(), bias_dropout_residual_out_data, + ffn2_out_scales[i]->data(), buf1->data(), 0.0f, dim_embed); + AffineQuantStore store(input_workspace.data(), dim_embed, + ln_scale_data, ln_bias_data, qkv_in_scale[i + 1], quant_round_type, quant_max_bound, quant_min_bound); + DispatchLayerNorm(dev_ctx.stream(), load, store, token_num, dim_embed, epsilon, ln_mean_data, ln_var_data); + VLOG(1) << "WRONG out " << input_workspace; } else { ffn2_fused_dropout_dequant_helper.ResidualDropoutBias( dev_ctx, @@ -630,8 +964,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { buf1->data(), dropout_mask_out_data, ffn2_in_scale[i], - ffn2_out_scale->data(), - i * ffn2_out_scale_n, + ffn2_out_scales[i]->data(), 1.0); } } else { @@ -656,6 +989,24 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { if (pre_layer_norm) { x_data = buf1->data(); } + VLOG(2) << "out layer " << i << " " << *buf1; + } + if (encoder_remove_padding) { + if (pre_layer_norm) { + InvokeRebuildPadding(dev_ctx, + from_data, + buf0->data(), + padding_offset_data, + token_num, + dim_embed); + } else { + InvokeRebuildPadding(dev_ctx, + from_data, + buf1->data(), + padding_offset_data, + token_num, + dim_embed); + } } } }; @@ -667,4 +1018,4 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_int8, ops::FusedMultiTransformerINT8OpKernel, - ops::FusedMultiTransformerINT8OpKernel); + ops::FusedMultiTransformerINT8OpKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cc new file mode 100644 index 0000000000000..bc84a4613c56b --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cc @@ -0,0 +1,392 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; + +class FusedMultiTransformerMoeINT8Op : public framework::OperatorWithKernel { + private: + static constexpr const char *OpName = "FusedMultiTransformerMoeINT8Op"; + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { +#define CHECK_INPUT(name) \ + OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName) +#define CHECK_INPUTS(name) \ + OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName) +#define CHECK_OUTPUT(name) \ + OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName) +#define CHECK_OUTPUTS(name) \ + OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName) + + CHECK_INPUT(X); + + // attention + CHECK_INPUTS(QKVW); + CHECK_INPUTS(OutLinearW); + + if (ctx->HasInput("TimeStep")) { + CHECK_INPUTS(CacheKV); + } + + if (ctx->HasInputs("CacheKV")) { + CHECK_OUTPUTS(CacheKVOut); + } + + // moe + CHECK_INPUTS(GateWeight); + CHECK_INPUTS(GateBias); + CHECK_INPUTS(ExpertWeight1); + CHECK_INPUTS(ExpertWeight2); + + // scale + CHECK_INPUTS(QKVOutScale); + CHECK_INPUTS(OutLinearOutScale); + CHECK_INPUTS(ExpertWeight1OutScale); + CHECK_INPUTS(ExpertWeight2OutScale); + + CHECK_OUTPUT(Out); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputsDim("QKVW")[0]; + bool trans_qkvw = ctx->Attrs().Get("trans_qkvw"); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + platform::errors::InvalidArgument("The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ(y_dim.size(), + 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ( + x_dim[2], + trans_qkvw ? y_dim[3] : y_dim[0], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " + "true) or y_dim[0](trans_qkvw is false)" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, + y_dim)); + + + if (ctx->HasInputs("CacheKV")) { + // [2, batch_size, num_head, max_seq_len, head_size] + const auto &c_dims = ctx->GetInputsDim("CacheKV"); + const auto &c_dim = c_dims[0]; + + PADDLE_ENFORCE_EQ( + c_dim.size(), + 5, + paddle::platform::errors::InvalidArgument( + "The CacheKV must be 5 dims, but got %d", c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], + 2, + paddle::platform::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + + PADDLE_ENFORCE_EQ(c_dim[2], + trans_qkvw ? y_dim[1] : y_dim[2], + paddle::platform::errors::InvalidArgument( + "The third dim of CacheKV must be equal with num " + "head %d, but got %d", + trans_qkvw ? y_dim[1] : y_dim[2], + c_dim[2])); // num_head + + PADDLE_ENFORCE_EQ(c_dim[4], + trans_qkvw ? y_dim[2] : y_dim[3], + paddle::platform::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal with head " + "size %d, but got %d", + trans_qkvw ? y_dim[2] : y_dim[3], + c_dim[4])); // head_size + } + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, + const Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "TimeStep") { + VLOG(10) << "var_name:" << var_name << " need not to transform"; + return expected_kernel_type; + } + return framework::OpKernelType( + expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + } +}; + +class FusedMultiTransformerMoeINT8OpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("LnBias", + "Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); + + AddInput("CacheKV", "(optional) The cached KV for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); + AddInput("TimeStep", + "(optional, int) The time step for generation inference.") + .AsDispensable(); + AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") + .AsDispensable(); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); + AddInput("OutLinearBias", "The out_linear bias tensor.") + .AsDispensable() + .AsDuplicable(); + + AddInput("GateWeight", "The gate_weights in moe") + .AsDuplicable(); + AddInput("GateBias", "The gate_biases in moe") + .AsDuplicable(); + AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") + .AsDuplicable(); + AddInput("ExpertWeight1", "The expert_weights1 in moe") + .AsDuplicable(); + AddInput("ExpertBias1", "The expert_biases1 in moe") + .AsDuplicable(); + AddInput("ExpertWeight2", "The expert_weights2 in moe") + .AsDuplicable(); + AddInput("ExpertBias2", "The expert_biases2 in moe") + .AsDuplicable(); + + // out scale + AddInput("QKVOutScale", + "QKVOutScale is used to dequantize qkv output tensor." + "In order to keep consistent with the PTQ/QAT calculation logic," + "QKVOutScale should be max_bound * max_bound / max_range." + "Here max_range is per-channel weight scale." + "The shape of QKVOutScale is [num_layers]") + .AsDispensable() + .AsDuplicable(); + AddInput("OutLinearOutScale", + "OutLinearOutScale is used to dequantize out_linear output tensor." + "The definition and shape is the same as QKVOutScale") + .AsDispensable() + .AsDuplicable(); + AddInput("ExpertWeight1OutScale", + "ExpertWeight1OutScale is used to dequantize ffn1 output tensor." + "The definition and shape is num_layers * num_expert") + .AsDispensable() + .AsDuplicable(); + AddInput("ExpertWeight2OutScale", + "ExpertWeight2OutScale is used to dequantize ffn2 output tensor." + "The definition and shape is num_layers * num_expert") + .AsDispensable() + .AsDuplicable(); + + AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") + .AsDispensable() + .AsDuplicable(); + AddOutput("Out", "Result after multi ."); + + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default true].") + .SetDefault(true); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, + true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, + true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", + true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("act_method", "act_method").SetDefault("gelu"); + AddAttr( + "trans_qkvw", + "Whether the weights of qkv should be transposed. If true," + "the shape eights of qkv should be [3, num_head, dim_head, dim_embed]." + "Otherwise the shape of weights of qkv should be" + "[dim_embed, 3, num_head, dim_head]") + .SetDefault(true); + AddAttr( + "ring_id", + "ring id for tensor model parallel. distributed training and inference") + .SetDefault(-1); + + // for moe layer + AddAttr( + "topk", + "gate's topk im moe") + .SetDefault(2); + AddAttr( + "mp_size", + "mp size") + .SetDefault(1); + AddAttr( + "mp_rank", + "mp rank") + .SetDefault(0); + AddAttr( + "num_expert", + "experts num im moe") + .SetDefault(1); + AddAttr( + "world_size", + "world size") + .SetDefault(1); + AddAttr( + "moe_ring_id", + "experts communicate group's ring id") + .SetDefault(1); + AddAttr( + "approximate", + "approximate in expert compute gelu") + .SetDefault(true); + + // int8 add + // AddAttr("num_head", "num_head").SetDefault(0); + // AddAttr("dim_head", "dim_head").SetDefault(0); + // AddAttr("dim_ffn", "dim_ffn").SetDefault(0); + + AddAttr>( + "qkv_in_scale", + "qkv_in_scale is used to quantize qkv input tensor." + "in_scale is generated by PTQ or QAT, which represents valid max range " + "of this tensor." + "the size of qkv_in_scale should be num_layers, which is equal to " + "QKVW.dims()[0]") + .SetDefault({}); + AddAttr>( + "out_linear_in_scale", + "out_linear_in_scale is used to quantize out_linear input tensor." + "the size of out_linear_in_scale is the same as qkv_in_scale") + .SetDefault({}); + AddAttr>( + "expert_weight1_in_scale", + "expert_weight1_in_scale is used to quantize ffn1 input tensor." + "the size of expert_weight1_in_scale should be num_layers * num_expert") + .SetDefault({}); + AddAttr>( + "expert_weight2_in_scale", + "expert_weight2_in_scale is used to quantize ffn2 input tensor." + "the size of expert_weight2_in_scale should be num_layers * num_expert") + .SetDefault({}); + + AddAttr( + "quant_round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(-2.5)=-3") + .SetDefault(1); + AddAttr( + "quant_max_bound", + "(float, default 127.0) the max bound of float type to int type") + .SetDefault(127.0); + AddAttr( + "quant_min_bound", + "(float, default -127.0) the min bound of float type to int type") + .SetDefault(-127.0); + AddComment(R"DOC(fused multi transformer layers op)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + fused_multi_transformer_moe_int8, + ops::FusedMultiTransformerMoeINT8Op, + ops::FusedMultiTransformerMoeINT8OpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu new file mode 100644 index 0000000000000..4869b14ca10de --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu @@ -0,0 +1,782 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h" +#include "paddle/fluid/operators/fused/layernorm_quant_dequant.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; + +template +static void PrintMatrix(const T* mat_d, int num, std::string name) { + std::vector tmp(num); + cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); + + std::ofstream outfile; + outfile.open(name+".txt", std::ios::out); + std::stringstream ss; + + for (int i = 0; i < num; ++i) { + if(std::is_same::value) { + ss << static_cast(tmp[i]) << std::endl; + } else { + ss << std::setprecision(8) << tmp[i] << std::endl; + } + } + outfile << ss.str(); + outfile.close(); +} + +template +class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto &dev_ctx = ctx.cuda_device_context(); + + auto *time_step = ctx.Input("TimeStep"); + // 0. input + auto *input_x = ctx.Input("X"); + const auto input_x_dims = input_x->dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + if (bsz_seq == 0) { + return; + } + + // quant input scales, vector, size = num_layers + auto qkv_in_scale = ctx.Attr>("qkv_in_scale"); + auto out_linear_in_scale = + ctx.Attr>("out_linear_in_scale"); + // moe expert scales, vector, size = num_expert * num_layers + auto expert_weight1_in_scale = ctx.Attr>("expert_weight1_in_scale"); + auto expert_weight2_in_scale = ctx.Attr>("expert_weight2_in_scale"); + + // quant round type and bound + auto quant_round_type = ctx.Attr("quant_round_type"); + auto quant_max_bound = ctx.Attr("quant_max_bound"); + auto quant_min_bound = ctx.Attr("quant_min_bound"); + + // dequant output scales, tensor, size = [num_layers, n], n is gemm output + // size + auto qkv_out_scales = ctx.MultiInput("QKVOutScale"); + auto out_linear_out_scales = + ctx.MultiInput("OutLinearOutScale"); + // dequant output scales, tensor, size = [num_layers * num_expert, n], n is gemm output + // size + auto expert_weight1_out_scales = ctx.MultiInput("ExpertWeight1OutScale"); + auto expert_weight2_out_scales = ctx.MultiInput("ExpertWeight2OutScale"); + + auto *sequence_lengths = ctx.Input("SeqLengths"); + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } + + // 1. layer norm + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + if (!pre_layer_norm) { + VLOG(0) << "not support post layer norm!"; + return; + } + const float epsilon = ctx.Attr("epsilon"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); + + // in type is T, out type is int8_t + auto ln_compute = + AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + Tensor ln_mean, ln_var; + ln_mean.Resize({{bsz_seq}}); + auto *ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({{bsz_seq}}); + auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); + const bool trans_qkvw = ctx.Attr("trans_qkvw"); + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + AttnMatmulINT8 qkv_compute( + dev_ctx, bsz_seq, output_size, input_size, compute_bias); + Tensor qkv_out; + qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + auto *qkv_out_data = + dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 3. fmha + AttnDropoutParam attn_param( + true, "upscale_in_train", 0.0, true, true, 0, nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + + int time_step_cpu = 0; + auto out_seq_len = seq_len; + if (time_step) { + time_step_cpu = src_mask->dims()[3] - 1; + out_seq_len += time_step_cpu; + } + + Tensor transpose_out_2, qk_out; + transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); + auto *transpose_out_2_data = + dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + + Tensor softmax_out; + Tensor attn_dropout_mask_out, attn_dropout_out; + Tensor qktv_out, fmha_out; + softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *softmax_out_data = + dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *qktv_out_data = + dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + fmha_out.Resize({{bsz_seq, num_head, dim_head}}); + auto *fmha_out_data = + dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // 4. out_linear 注意!!这个weight的维度需要transpose!! + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + int ring_id = ctx.Attr("ring_id"); + // (transA, transB, compute_bias) = (false, false, false) + AttnMatmulINT8 out_linear_compute( + dev_ctx, bsz_seq, dim_embed, hidden_size, false); + + // 5. ln(residual + bias) + DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + + using LayerNormComputeType = float; + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + Tensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; + bias_dropout_residual_out.Resize({{bsz_seq, dim_embed}}); + bias_dropout_residual_out_data = + dev_ctx.Alloc(&bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); + uint8_t *dropout_mask_out_data = nullptr; + + // 6. moe layer: gate / expert_w & b / some attrs + auto gate_weights = ctx.MultiInput("GateWeight"); + auto gate_biases = ctx.MultiInput("GateBias"); + // weight的维度需要transpose!!!! + auto expert_weights1 = ctx.MultiInput("ExpertWeight1"); + auto expert_biases1 = ctx.MultiInput("ExpertBias1"); + auto expert_weights2 = ctx.MultiInput("ExpertWeight2"); + auto expert_biases2 = ctx.MultiInput("ExpertBias2"); + int dim_feedforward = expert_weights1[0]->dims()[0]; // dim is [dim_feedforward, dim_embed] + int topk = ctx.Attr("topk"); + int mp_size = ctx.Attr("mp_size"); + int mp_rank = ctx.Attr("mp_rank"); + int num_expert = ctx.Attr("num_expert"); + int world_size = ctx.Attr("world_size"); + int moe_ring_id = ctx.Attr("moe_ring_id"); + bool approximate = ctx.Attr("approximate"); + + int tot_expert = world_size * num_expert; + // after slice, bsz_seq should be change + int sliced_bsz_seq = bsz_seq; + int start = 0; + int end = 0; + if (mp_size > 1) { + start = bsz_seq / world_size * mp_rank; + end = std::min(start + bsz_seq / world_size, bsz_seq); + sliced_bsz_seq = end - start; + } + int out_batch_size = sliced_bsz_seq * topk; + // slice + Tensor sliced_inp; + sliced_inp.Resize({{sliced_bsz_seq, dim_embed}}); + dev_ctx.Alloc(&sliced_inp, sliced_inp.numel() * sizeof(T)); + // gate linear + Tensor gate_out; + gate_out.Resize({{sliced_bsz_seq, tot_expert}}); + dev_ctx.Alloc(&gate_out, gate_out.numel() * sizeof(T)); + // topk + Tensor topk_value, topk_idx; + topk_value.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); + topk_idx.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(T)); + // local expert count, global expert count + Tensor local_expert_count, global_expert_count; + local_expert_count.Resize({{tot_expert}}); + global_expert_count.Resize({{tot_expert}}); + dev_ctx.Alloc(&local_expert_count, local_expert_count.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); + // fwd_expert_count, fwd_batch_size + Tensor fwd_expert_count, fwd_batch_size; + Tensor fwd_expert_count_cpu, fwd_batch_size_cpu; + fwd_expert_count.Resize({{num_expert}}); + fwd_batch_size.Resize({{1}}); + dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&fwd_batch_size, fwd_batch_size.numel() * sizeof(int64_t)); + // pos, temp pos + Tensor pos, temp_pos; + pos.Resize({{out_batch_size}}); + temp_pos.Resize({{out_batch_size}}); + dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); + // cumsum + Tensor lec_cum; + lec_cum.Resize({{tot_expert}}); + dev_ctx.Alloc(&lec_cum, lec_cum.numel() * sizeof(int64_t)); + // fused moe ffn tmp out + Tensor index_select_out; + index_select_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&index_select_out, index_select_out.numel() * sizeof(T)); + Tensor global_gather_out; + global_gather_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&global_gather_out, global_gather_out.numel() * sizeof(T)); + Tensor moe_gather_out; + moe_gather_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&moe_gather_out, moe_gather_out.numel() * sizeof(T)); + Tensor bmm_out; + bmm_out.Resize({{sliced_bsz_seq, 1, dim_embed}}); + dev_ctx.Alloc(&bmm_out, bmm_out.numel() * sizeof(T)); + Tensor all_gather_out; + all_gather_out.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&all_gather_out, all_gather_out.numel() * sizeof(T)); + // topk tensor + Tensor topk_tensor; + topk_tensor.Resize({{1}}); + dev_ctx.Alloc(&topk_tensor, topk_tensor.numel() * sizeof(int64_t)); + phi::FullKernel(dev_ctx, {1}, topk, pos.dtype(), &topk_tensor); + + // []. init workspace for cublasLt transform + Tensor input_workspace, output_workspace, cublaslt_workspace; + // for input and output transform data is CUBLASLT_ORDER_COL32 format, + int m_max = bsz_seq, k_max = std::max({dim_embed, dim_feedforward}), + n_max = std::max({output_size, dim_embed, dim_feedforward}); + // maybe need to change the size of workspace here + + input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}}); + dev_ctx.Alloc(&input_workspace, + input_workspace.numel() * sizeof(int8_t)); + + output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}}); + dev_ctx.Alloc(&output_workspace, + output_workspace.numel() * sizeof(int32_t)); + + cublaslt_workspace.Resize({{3000000}}); + dev_ctx.Alloc(&cublaslt_workspace, + cublaslt_workspace.numel() * sizeof(int8_t)); + + // calc + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + Tensor buf0, moe_out; + buf0.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&buf0, buf0.numel() * sizeof(T)); + moe_out.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&moe_out, moe_out.numel() * sizeof(T)); + + const T *x_data; + x_data = input_x->data(); + + int layers = qkv_weights.size(); + + for (int i = 0; i < layers; ++i) { +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step1, pre layernorm"; +#endif + // step1. layer_norm + if (i == 0) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // layer norm后,对输出做scale,因此输出是int8,在input_workspace中 + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + input_workspace.data(), + ln_mean_data, + ln_var_data, + nullptr, + 0, + qkv_in_scale[i], + quant_round_type, + quant_max_bound, + quant_min_bound); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step2, qkv"; +#endif + // step2. qkv + const Tensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const Tensor *bias = time_step ? nullptr : qkv_bias; + // 输入是int8,input workspace,输出是T,qkv_out + qkv_compute.ComputeForwardINT8ToT(qkv_weights[i], + qkv_in_scale[i], + &input_workspace, // input + bias, + &qkv_out, // out, T + &output_workspace, // out tmp, int32 + &qkv_out, // bias out, T + qkv_out_scales[i], + &cublaslt_workspace); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3.1 fmha"; +#endif + // step3. fmha + const Tensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask, + sequence_lengths, + nullptr, + beam_cache_offset, + cache_kv_out, + &fmha_out, + bsz, + beam_size, + max_seq_len, + num_head, + dim_head, + time_step_cpu, + 0, + 1. / sqrt(dim_head)); + } else if (cache_kv_out) { // generation context stage + fmha_compute.ComputeForward(qkv_out, + nullptr, + src_mask, + &transpose_out_2, + nullptr, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); + // [3, bsz, num_head, seq_len, head_dim] + T *qkv_data = transpose_out_2_data; + int64_t q_size = bsz * seq_len * num_head * dim_head; + int64_t k_size = q_size; + const T *q_ptr = qkv_data; + const T *k_ptr = q_ptr + q_size; + const T *v_ptr = k_ptr + k_size; + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len, + max_seq_len, + dim_head); + } else { // not generation + VLOG(0) << "not support!"; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3.2 out linear"; +#endif + // T -> int8 + out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i], + out_linear_in_scale[i], + &fmha_out, + &input_workspace, // input tmp, 先将输入量化 + nullptr, + &output_workspace, // output, int32 + nullptr, + &cublaslt_workspace, + quant_round_type, + quant_max_bound, + quant_min_bound); + // 输出在output_workspace + AllReduce(output_workspace, + ring_id, + bsz * seq_len * num_head * dim_head, + dev_ctx); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif + + // step5. ln(residual + dropout(input + bias)) + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + // input type is int32, src is T, dst is T + DequantSkipLoadAndStoreResidual load(output_workspace.data(), out_linear_bias_data, x_data, + out_linear_out_scales[i]->data(), bias_dropout_residual_out_data, 0.0f, dim_embed); + // 改为输出先不做scale,输出是fp16,输出到buf0 + AffineQuantStore store(buf0.data(), dim_embed, ln_scale_data, ln_bias_data); + DispatchLayerNorm(dev_ctx.stream(), load, store, bsz_seq, dim_embed, epsilon, ln_mean_data, ln_var_data); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step5"; +#endif + // moe + // step2 resize and slice ln_out + if (mp_size > 1) { + sliced_inp = buf0.Slice(start, end); + } else { + sliced_inp = buf0; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, gate & topk"; +#endif + // step3 gate & topk + // 这里不做量化 + phi::MatMulAndAdd(dev_ctx, + gate_weights[i], + &sliced_inp, + gate_biases[i], + false, + false, + true, // compute bias + &gate_out, + &gate_out); + phi::TopkKernel(dev_ctx, + gate_out, + topk, // scalar + -1, + true, + false, + &topk_value, + &topk_idx); + // step4 prepare forward + // step4.1 number count +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, number count"; +#endif + phi::NumberCountKernel(dev_ctx, topk_idx, tot_expert, &local_expert_count); + // step4.2 all_to_all +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, all_to_all"; +#endif + if (world_size > 1) { + phi::AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); + } else { + global_expert_count = local_expert_count; + } + + // global expert count resize + global_expert_count.Resize({{world_size, num_expert}}); + // fwd expert count +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, fwd expert count"; +#endif + phi::SumKernel(dev_ctx, + global_expert_count, + phi::IntArray({0}), + global_expert_count.dtype(), + false, + &fwd_expert_count); + // fwd batch size + phi::SumKernel(dev_ctx, + fwd_expert_count, + phi::IntArray({}), // axis is None + fwd_expert_count.dtype(), + false, + &fwd_batch_size); + // step4.3 cumsum & assign pos +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, cumsum"; +#endif + phi::CumsumKernel(dev_ctx, + local_expert_count, + 0, + false, + false, + false, + &lec_cum); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, assign pos"; +#endif + phi::AssignPosCompute(dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, floor divide"; +#endif + if (topk > 1) { + phi::FloorDivideKernel(dev_ctx, + pos, + topk_tensor, + &temp_pos); + } else { + temp_pos = pos; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, tensor copy"; +#endif + framework::TensorCopySync(fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); + framework::TensorCopySync(fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + int fwd_bsz = fwd_batch_size_cpu.data()[0]; + + Tensor global_scatter_out; + global_scatter_out.Resize({{fwd_bsz, dim_embed}}); + dev_ctx.Alloc(&global_scatter_out, global_scatter_out.numel() * sizeof(T)); + + Tensor all_expert_out; + all_expert_out.Resize({{fwd_bsz, dim_embed}}); + dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); + + // global_scatter_out.Resize({{fwd_bsz, dim_embed}}); + // all_expert_out.Resize({{fwd_bsz, dim_embed}}); + + // step 5, MOEScatter + // step 5.1, index select + // suppose tmp_pos->shape != [0] +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, index select"; +#endif + phi::IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 5.2, global_scatter + if (map->has(moe_ring_id)) { + phi::GlobalScatterProcessGroupFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } else { + phi::GlobalScatterFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } + } else { + global_scatter_out = index_select_out; + } + + // step 6, Expert Computation +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, Expert Computation"; +#endif + if (fwd_bsz != 0) { + phi::funcs::ConcatFunctor concat; // fp16 + std::vector tmp_expert_out; + int last_index = 0; + for (int idx = 0; idx < num_expert; idx++) { + int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + if (cur_expert_count <= 0) { + continue; + } + int end = cur_expert_count + last_index; + + Tensor expert_in_tmp; // int8_t + expert_in_tmp.Resize({{(cur_expert_count * dim_feedforward + 31) / 32 * 32 }}); + dev_ctx.Alloc(&expert_in_tmp, expert_in_tmp.numel() * sizeof(int8_t)); + + Tensor expert_out1; // int32_t + expert_out1.Resize({{(cur_expert_count * dim_feedforward + 31) / 32 * 32}}); + dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(int32_t)); + + Tensor expert_out2; // T(fp16) + expert_out2.Resize({{cur_expert_count, dim_embed}}); + dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); + // act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); maybe int8_t? + // maybe use input_workspace and output workspace? + // dev_ctx.Alloc(&act_bias_out, act_bias_out.numel() * sizeof(T)); + + // input is int32_t, output is int8_t + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, cur_expert_count, dim_feedforward, dropout_param); + + Tensor tmp_inp = global_scatter_out.Slice(last_index, end); // fp16, T + int expert_idx = i * num_expert + idx; + // T to int8_t, matmul, dont compute bias + MatMulTToINT8(dev_ctx, + expert_weights1[expert_idx], + expert_weight1_in_scale[expert_idx], + &tmp_inp, + &expert_in_tmp, + &expert_out1, + cur_expert_count, + dim_feedforward, + dim_embed, + &cublaslt_workspace, // maybe space not enough + quant_round_type, + quant_max_bound, + quant_min_bound); + // act bias, input is int32_t, output is int8_t + fused_act_dropout_helper.DropoutActBias( + dev_ctx, + expert_out1.data(), + expert_biases1[expert_idx]->data(), + "gelu", + expert_in_tmp.data(), + nullptr, + expert_weight1_in_scale[expert_idx], + expert_weight1_out_scales[expert_idx]->data(), + 0, // data offset + expert_weight2_in_scale[expert_idx], + quant_round_type, + quant_max_bound, + quant_min_bound, + approximate); + // linear2, int8_t to T + MatMulINT8ToT(dev_ctx, + expert_weights2[expert_idx], + expert_weight2_in_scale[expert_idx], + &expert_in_tmp, + expert_biases2[expert_idx], + &expert_out2, + &expert_out1, // output_tmp + &expert_out2, + expert_weight2_out_scales[expert_idx], + cur_expert_count, + dim_embed, + dim_feedforward, + true, + &cublaslt_workspace); + tmp_expert_out.emplace_back(expert_out2); + last_index = end; + } + concat(dev_ctx, tmp_expert_out, 0, &all_expert_out); + } else { + all_expert_out = global_scatter_out; + } + + // step7. MOEGather +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, MOEGather"; +#endif + if (world_size > 1) { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 7.1, global_gather + if (map->has(moe_ring_id)) { + phi::GlobalGatherProcessGroupFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } else { + phi::GlobalGatherFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } + } else { + global_gather_out = all_expert_out; + } + // step 7.2, local_gather or scatter + // suppose pos->shape != [0] +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, local_gather or scatter"; +#endif + phi::ScatterKernel(dev_ctx, + moe_gather_out, + pos, + global_gather_out, + true, + &moe_gather_out); + // step 8, reshape & bmm + // moe gather out reshape +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, reshape & bmm"; +#endif + moe_gather_out.Resize({{sliced_bsz_seq, topk, dim_embed}}); + topk_value.Resize({{sliced_bsz_seq, 1, topk}}); + phi::BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); + bmm_out.Resize({{sliced_bsz_seq, dim_embed}}); + // step 9, AllGather +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, AllGather"; +#endif + if (mp_size > 1) { + // all gather + phi::AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); + } else { + all_gather_out = bmm_out; + } + + // step 11, add residual +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, add residual"; +#endif + if (i < layers - 1) { + // add residual & next layer norm & qkv quant + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); + // input type is T, src is T, dst is T + DequantSkipLoadAndStoreResidual load(all_gather_out.data(), nullptr, bias_dropout_residual_out_data, + nullptr, moe_out.data(), 0.0f, dim_embed); + AffineQuantStore store(input_workspace.data(), dim_embed, + ln_scale_data, ln_bias_data, qkv_in_scale[i + 1], quant_round_type, quant_max_bound, quant_min_bound); + DispatchLayerNorm(dev_ctx.stream(), load, store, bsz_seq, dim_embed, epsilon, ln_mean_data, ln_var_data); + } else { + // last layer, only add residual, T + phi::AddKernel(dev_ctx, all_gather_out, bias_dropout_residual_out, &moe_out); + } + + x_data = moe_out.data(); + + } // end for layer loop + moe_out.Resize({{bsz, seq_len, dim_embed}}); + *out = moe_out; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_moe_int8, + ops::FusedMultiTransformerMoeINT8OpKernel, + ops::FusedMultiTransformerMoeINT8OpKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cc new file mode 100644 index 0000000000000..2132d9774eb02 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cc @@ -0,0 +1,319 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { + +class FusedMultiTransformerMoeOp : public framework::OperatorWithKernel { + private: + static constexpr const char *OpName = "FusedMultiTransformerMoeOp"; + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { +#define CHECK_INPUT(name) \ + OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName) +#define CHECK_INPUTS(name) \ + OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName) +#define CHECK_OUTPUT(name) \ + OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName) +#define CHECK_OUTPUTS(name) \ + OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName) + + CHECK_INPUT(X); + + // attention + CHECK_INPUTS(QKVW); + CHECK_INPUTS(OutLinearW); + + if (ctx->HasInput("TimeStep")) { + CHECK_INPUTS(CacheKV); + } + + if (ctx->HasInputs("CacheKV")) { + CHECK_OUTPUTS(CacheKVOut); + } + + // moe + CHECK_INPUTS(GateWeight); + CHECK_INPUTS(GateBias); + CHECK_INPUTS(ExpertWeight1); + CHECK_INPUTS(ExpertWeight2); + + // out + CHECK_OUTPUT(Out); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputsDim("QKVW")[0]; + bool trans_qkvw = ctx->Attrs().Get("trans_qkvw"); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + platform::errors::InvalidArgument("The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ(y_dim.size(), + 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ( + x_dim[2], + trans_qkvw ? y_dim[3] : y_dim[0], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " + "true) or y_dim[0](trans_qkvw is false)" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, + y_dim)); + + if (ctx->HasInputs("CacheKV")) { + // [2, batch_size, num_head, max_seq_len, head_size] + const auto &c_dims = ctx->GetInputsDim("CacheKV"); + const auto &c_dim = c_dims[0]; + + PADDLE_ENFORCE_EQ( + c_dim.size(), + 5, + paddle::platform::errors::InvalidArgument( + "The CacheKV must be 5 dims, but got %d", c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], + 2, + paddle::platform::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + PADDLE_ENFORCE_EQ(c_dim[2], + trans_qkvw ? y_dim[1] : y_dim[2], + paddle::platform::errors::InvalidArgument( + "The third dim of CacheKV must be equal with num " + "head %d, but got %d", + trans_qkvw ? y_dim[1] : y_dim[2], + c_dim[2])); // num_head + PADDLE_ENFORCE_EQ(c_dim[4], + trans_qkvw ? y_dim[2] : y_dim[3], + paddle::platform::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal with head " + "size %d, but got %d", + trans_qkvw ? y_dim[2] : y_dim[3], + c_dim[4])); // head_size + } + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, + const phi::DenseTensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "TimeStep") { + VLOG(10) << "var_name:" << var_name << " need not to transform"; + return expected_kernel_type; + } + return framework::OpKernelType( + expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + } +}; + +class FusedMultiTransformerMoeOpOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("LnBias", + "Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDuplicable(); + AddInput("QKVW", "The qkv weight tensor.").AsDuplicable(); + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable(); + AddInput("CacheKV", "(optional) The cached KV for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); + AddInput("TimeStep", + "(optional, int) The time step for generation inference.") + .AsDispensable(); + AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") + .AsDispensable(); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); + AddInput("OutLinearBias", "The out_linear bias tensor.") + .AsDispensable() + .AsDuplicable(); + AddInput("GateWeight", "The gate_weights in moe") + .AsDuplicable(); + AddInput("GateBias", "The gate_biases in moe") + .AsDuplicable(); + AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") + .AsDuplicable(); + AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") + .AsDuplicable(); + AddInput("ExpertWeight1", "The expert_weights1 in moe") + .AsDuplicable(); + AddInput("ExpertBias1", "The expert_biases1 in moe") + .AsDuplicable(); + AddInput("ExpertWeight2", "The expert_weights2 in moe") + .AsDuplicable(); + AddInput("ExpertBias2", "The expert_biases2 in moe") + .AsDuplicable(); + AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") + .AsDispensable() + .AsDuplicable(); + AddOutput("Out", "Result after multi ."); + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default true].") + .SetDefault(true); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, + true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, + true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", + true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("act_method", "act_method") + .SetDefault("gelu") + .AddCustomChecker([](const std::string &act_type) { + PADDLE_ENFORCE_EQ( + act_type == "gelu" || act_type == "geglu" || act_type == "relu" || act_type == "none", + true, + platform::errors::InvalidArgument( + "Only support `gelu`, `geglu`, `relu`, `none` activation in " + "FusedMultiTransformer. ")); + }); + + AddAttr( + "trans_qkvw", + "Whether the weights of qkv should be transposed. If true," + "the shape eights of qkv should be [3, num_head, dim_head, dim_embed]." + "Otherwise the shape of weights of qkv should be" + "[dim_embed, 3, num_head, dim_head]") + .SetDefault(true); + + AddAttr( + "ring_id", + "ring id for tensor model parallel. distributed training and inference") + .SetDefault(-1); + // for moe layer + AddAttr( + "topk", + "gate's topk im moe") + .SetDefault(2); + AddAttr( + "mp_size", + "mp size") + .SetDefault(1); + AddAttr( + "mp_rank", + "mp rank") + .SetDefault(0); + AddAttr( + "num_expert", + "experts num im moe") + .SetDefault(1); + AddAttr( + "world_size", + "world size") + .SetDefault(1); + AddAttr( + "moe_ring_id", + "experts communicate group's ring id") + .SetDefault(1); + AddAttr( + "approximate", + "approximate in expert compute gelu") + .SetDefault(true); + AddComment(R"DOC(fused multi transformer layers op)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + fused_multi_transformer_moe, + ops::FusedMultiTransformerMoeOp, + ops::FusedMultiTransformerMoeOpOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu new file mode 100644 index 0000000000000..91fa104498bf4 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu @@ -0,0 +1,844 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; +// #define _DEBUG_FUSED_MULTI_TRANSFORMER + +template +class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto &dev_ctx = ctx.cuda_device_context(); + + auto *time_step = ctx.Input("TimeStep"); + // 0. input + auto *input_x = ctx.Input("X"); + const auto input_x_dims = input_x->dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + if (bsz_seq == 0) { + return; + } + // LOG(INFO) << "intput X: bsz: " << bsz << ", seq_len: " << seq_len << ", dim_embed: " << dim_embed; + const std::string act_method = ctx.Attr("act_method"); + auto *sequence_lengths = ctx.Input("SeqLengths"); // nullptr + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } + // LOG(INFO) << "beam_size: " << beam_size; + + auto *out = ctx.Output("Out"); + dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + // 1. layer norm + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + if (!pre_layer_norm) { + VLOG(0) << "not support post layer norm!"; + return; + } + const float epsilon = ctx.Attr("epsilon"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); + + auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + Tensor ln_mean, ln_var; + ln_mean.Resize({{bsz_seq}}); + auto *ln_mean_data = + dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({{bsz_seq}}); + auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); + const bool trans_qkvw = ctx.Attr("trans_qkvw"); + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + auto qkv_compute = AttnMatMul(dev_ctx, + false, + trans_qkvw, + bsz_seq, + output_size, + input_size, + compute_bias); + Tensor qkv_out; + qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + auto *qkv_out_data = + dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 3. fmha + auto dropout_implementation = ctx.Attr("dropout_implementation"); + AttnDropoutParam attn_param( + true, dropout_implementation, 0.0, true, true, 0, nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + + int time_step_cpu = 0; + if (time_step) { + time_step_cpu = src_mask->dims()[3] - 1; + } + + auto out_seq_len = seq_len; + if (time_step) { + PADDLE_ENFORCE_GT(time_step_cpu, + 0, + platform::errors::PreconditionNotMet( + "The value of time_step must > 0, but now is %d", + time_step_cpu)); + PADDLE_ENFORCE_EQ( + seq_len, + 1, + platform::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + out_seq_len += time_step_cpu; + } + + Tensor transpose_out_2, qk_out; + transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); + auto *transpose_out_2_data = + dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + + Tensor softmax_out; + Tensor attn_dropout_mask_out, attn_dropout_out; + Tensor qktv_out, fmha_out; + softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); + auto *softmax_out_data = + dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *qktv_out_data = + dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); + auto *fmha_out_data = + dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // 4. out_linear + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + int ring_id = ctx.Attr("ring_id"); + // (transA, transB, compute_bias) = (false, false, false) + auto out_linear_compute = AttnMatMul( + dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false); + + // 5. ln(residual + bias), pre layernorm in ffn/moe + DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + dev_ctx, bsz_seq, dim_embed, dropout_param, epsilon); + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + Tensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; + bias_dropout_residual_out.Resize({{bsz_seq, dim_embed}}); + bias_dropout_residual_out_data = + dev_ctx.Alloc(&bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); + uint8_t *dropout_mask_out_data = nullptr; + + // 6. moe layer: gate / expert_w & b / some attrs + auto gate_weights = ctx.MultiInput("GateWeight"); + auto gate_biases = ctx.MultiInput("GateBias"); + auto expert_weights1 = ctx.MultiInput("ExpertWeight1"); + auto expert_biases1 = ctx.MultiInput("ExpertBias1"); + auto expert_weights2 = ctx.MultiInput("ExpertWeight2"); + auto expert_biases2 = ctx.MultiInput("ExpertBias2"); + int dim_feedforward = expert_weights1[0]->dims()[1]; + // int dim_feedforward = expert_weights1[0]->dims()[2]; // batched gemm + int topk = ctx.Attr("topk"); + int mp_size = ctx.Attr("mp_size"); + int mp_rank = ctx.Attr("mp_rank"); + int num_expert = ctx.Attr("num_expert"); + int world_size = ctx.Attr("world_size"); + int moe_ring_id = ctx.Attr("moe_ring_id"); + bool approximate = ctx.Attr("approximate"); + + int tot_expert = world_size * num_expert; + // after slice, bsz_seq should be change + int sliced_bsz_seq = bsz_seq; + int start = 0; + int end = 0; + if (mp_size > 1) { + start = bsz_seq / world_size * mp_rank; + end = std::min(start + bsz_seq / world_size, bsz_seq); + sliced_bsz_seq = end - start; + } + int out_batch_size = sliced_bsz_seq * topk; + // slice + Tensor sliced_inp; + sliced_inp.Resize({{sliced_bsz_seq, dim_embed}}); + dev_ctx.Alloc(&sliced_inp, sliced_inp.numel() * sizeof(T)); + // gate linear + Tensor gate_out; + gate_out.Resize({{sliced_bsz_seq, tot_expert}}); + dev_ctx.Alloc(&gate_out, gate_out.numel() * sizeof(T)); + // topk + Tensor topk_value, topk_idx; + topk_value.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); + topk_idx.Resize({{sliced_bsz_seq, topk}}); + dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(T)); + // local expert count, global expert count + Tensor local_expert_count, global_expert_count; + local_expert_count.Resize({{tot_expert}}); + global_expert_count.Resize({{tot_expert}}); + dev_ctx.Alloc(&local_expert_count, local_expert_count.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&global_expert_count, global_expert_count.numel() * sizeof(int64_t)); + // fwd_expert_count, fwd_batch_size + Tensor fwd_expert_count, fwd_batch_size; + Tensor fwd_expert_count_cpu, fwd_batch_size_cpu; + fwd_expert_count.Resize({{num_expert}}); + fwd_batch_size.Resize({{1}}); + dev_ctx.Alloc(&fwd_expert_count, fwd_expert_count.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&fwd_batch_size, fwd_batch_size.numel() * sizeof(int64_t)); + // pos, temp pos + Tensor pos, temp_pos; + pos.Resize({{out_batch_size}}); + temp_pos.Resize({{out_batch_size}}); + dev_ctx.Alloc(&pos, pos.numel() * sizeof(int64_t)); + dev_ctx.Alloc(&temp_pos, temp_pos.numel() * sizeof(int64_t)); + // cumsum + Tensor lec_cum; + lec_cum.Resize({{tot_expert}}); + dev_ctx.Alloc(&lec_cum, lec_cum.numel() * sizeof(int64_t)); + // fused moe ffn tmp out + Tensor index_select_out; + index_select_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&index_select_out, index_select_out.numel() * sizeof(T)); + Tensor global_gather_out; + global_gather_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&global_gather_out, global_gather_out.numel() * sizeof(T)); + Tensor moe_gather_out; + moe_gather_out.Resize({{out_batch_size, dim_embed}}); + dev_ctx.Alloc(&moe_gather_out, moe_gather_out.numel() * sizeof(T)); + Tensor bmm_out; + bmm_out.Resize({{sliced_bsz_seq, 1, dim_embed}}); + dev_ctx.Alloc(&bmm_out, bmm_out.numel() * sizeof(T)); + Tensor all_gather_out; + all_gather_out.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&all_gather_out, all_gather_out.numel() * sizeof(T)); + // topk tensor + Tensor topk_tensor; + topk_tensor.Resize({{1}}); + dev_ctx.Alloc(&topk_tensor, topk_tensor.numel() * sizeof(int64_t)); + phi::FullKernel(dev_ctx, {1}, topk, pos.dtype(), &topk_tensor); + // for nccl comm + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + // // expert out, alloc max size + // Tensor global_scatter_out; + // global_scatter_out.Resize({{2 * bsz_seq, dim_embed}}); + // dev_ctx.Alloc(&global_scatter_out, global_scatter_out.numel() * sizeof(T)); + + // Tensor expert_out1, expert_out2, all_expert_out; + // expert_out1.Resize({{2 * bsz_seq, dim_feedforward}}); + // // act_bias_out.Resize({{bsz_seq, dim_feedforward}}); + // expert_out2.Resize({{2 * bsz_seq, dim_embed}}); + // all_expert_out.Resize({{2 * bsz_seq, dim_embed}}); + // dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); + // // dev_ctx.Alloc(&act_bias_out, act_bias_out.numel() * sizeof(T)); + // dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); + // dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); + + Tensor buf0, moe_out; + buf0.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&buf0, buf0.numel() * sizeof(T)); + moe_out.Resize({{bsz_seq, dim_embed}}); + dev_ctx.Alloc(&moe_out, moe_out.numel() * sizeof(T)); + + const T *x_data; + x_data = input_x->data(); + + int layers = qkv_weights.size(); + + for (int i = 0; i < layers; ++i) { +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step1, pre layernorm"; +#endif + // step1. layer_norm, only layer 0 + if (i == 0) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // TODO(wangxi): can remove mean var in inference + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + buf0.data(), + ln_mean_data, + ln_var_data); + } + // auto *ln_scale_data = ln_scales[i]->data(); + // auto *ln_bias_data = ln_biases[i]->data(); + // // TODO(wangxi): can remove mean var in inference + // ln_compute.ComputeForward(x_data, + // ln_scale_data, + // ln_bias_data, + // buf0.data(), + // ln_mean_data, + // ln_var_data); + + // step2. qkv +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step2, qkv"; +#endif + const Tensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const Tensor *bias = time_step ? nullptr : qkv_bias; + qkv_compute.ComputeForward( + qkv_weights[i], &buf0, bias, &qkv_out, &qkv_out); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3.1 fmha"; +#endif + // step3. fmha + const Tensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask, + sequence_lengths, + nullptr, + beam_cache_offset, + cache_kv_out, + &fmha_out, + bsz, + beam_size, + max_seq_len, + num_head, + dim_head, + time_step_cpu, + 0, + 1. / sqrt(dim_head)); + } else if (cache_kv_out) { // generation encoder stage + fmha_compute.ComputeForward(qkv_out, + nullptr, + src_mask, + &transpose_out_2, + nullptr, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out); + // [3, bsz, num_head, seq_len, head_dim] + T *qkv_data = transpose_out_2_data; + int64_t q_size = bsz * seq_len * num_head * dim_head; + int64_t k_size = q_size; + const T *q_ptr = qkv_data; + const T *k_ptr = q_ptr + q_size; + const T *v_ptr = k_ptr + k_size; + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len, + max_seq_len, + dim_head); + } else { // not generation + VLOG(0) << "not support!"; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step3.2 out linear"; +#endif + // 输出到buf0 + out_linear_compute.ComputeForward( + out_linear_weights[i], &fmha_out, nullptr, &buf0, nullptr); + AllReduce(buf0, ring_id, buf0.numel(), dev_ctx); + +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif + + // step5. ln(residual + dropout(input + bias)),在MHA里的 + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases[i]->data(); + + // pre layer norm : bias_dropout_residual_out is residual + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0.data(), + x_data, // residual, moe out + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + bias_dropout_residual_out_data, + dropout_mask_out_data, + buf0.data(), // output to buf0 + ln_mean_data, + ln_var_data); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step5"; +#endif + // moe + // step2 resize and slice ln_out + if (mp_size > 1) { + sliced_inp = buf0.Slice(start, end); + } else { + sliced_inp = buf0; + } + VLOG(0) << "sliced_inp shape[0]: " << sliced_inp.dims()[0]; +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, gate & topk"; +#endif + // step3 gate & topk + phi::MatMulAndAdd(dev_ctx, + gate_weights[i], + &sliced_inp, + gate_biases[i], + false, + false, + true, // compute bias + &gate_out, + &gate_out); + phi::TopkKernel(dev_ctx, + gate_out, + topk, // scalar + -1, + true, + false, + &topk_value, + &topk_idx); + // step4 prepare forward + // step4.1 number count +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, number count"; +#endif + phi::NumberCountKernel(dev_ctx, topk_idx, tot_expert, &local_expert_count); + // step4.2 all_to_all +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, all_to_all"; +#endif + if (world_size > 1) { + phi::AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); + } else { + global_expert_count = local_expert_count; + } + + // global expert count resize + global_expert_count.Resize({{world_size, num_expert}}); + // fwd expert count +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, fwd expert count"; +#endif + phi::SumKernel(dev_ctx, + global_expert_count, + phi::IntArray({0}), + global_expert_count.dtype(), + false, + &fwd_expert_count); + // fwd batch size + phi::SumKernel(dev_ctx, + fwd_expert_count, + phi::IntArray({}), // axis is None + fwd_expert_count.dtype(), + false, + &fwd_batch_size); + // step4.3 cumsum & assign pos +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, cumsum"; +#endif + phi::CumsumKernel(dev_ctx, + local_expert_count, + 0, + false, + false, + false, + &lec_cum); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, assign pos"; +#endif + phi::AssignPosCompute(dev_ctx, &lec_cum, &topk_idx, &pos, out_batch_size); +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, floor divide"; +#endif + if (topk > 1) { + phi::FloorDivideKernel(dev_ctx, + pos, + topk_tensor, + &temp_pos); + } else { + temp_pos = pos; + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, tensor copy"; +#endif + framework::TensorCopySync(fwd_expert_count, platform::CPUPlace(), &fwd_expert_count_cpu); + framework::TensorCopySync(fwd_batch_size, platform::CPUPlace(), &fwd_batch_size_cpu); + int fwd_bsz = fwd_batch_size_cpu.data()[0]; + + Tensor global_scatter_out; + global_scatter_out.Resize({{fwd_bsz, dim_embed}}); + dev_ctx.Alloc(&global_scatter_out, global_scatter_out.numel() * sizeof(T)); + + Tensor all_expert_out; + all_expert_out.Resize({{fwd_bsz, dim_embed}}); + dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); + + // global_scatter_out.Resize({{fwd_bsz, dim_embed}}); + // all_expert_out.Resize({{fwd_bsz, dim_embed}}); + + // step 5, MOEScatter + // step 5.1, index select + // suppose tmp_pos->shape != [0] +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, index select"; +#endif + phi::IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); + if (world_size > 1) { + // auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 5.2, global_scatter + if (map->has(moe_ring_id)) { + phi::GlobalScatterProcessGroupFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_scatter_out); + } else { + phi::GlobalScatterFunctor(dev_ctx, + &index_select_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + false, + &global_scatter_out); + } + } else { + global_scatter_out = index_select_out; + } + + // step 6, Expert Computation +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, Expert Computation"; +#endif + if (fwd_bsz != 0) { + phi::funcs::ConcatFunctor concat; + std::vector tmp_expert_out; + // if (time_step) { + // // decoder, use batched gemm + // Tensor expert_out1, expert_out2; + // expert_out1.Resize({{num_expert, fwd_bsz, dim_feedforward}}); + // expert_out2.Resize({{num_expert, fwd_bsz, dim_embed}}); + // dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); + // dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); + + // BatchedMatMulAndAdd(dev_ctx, + // expert_weights1[i], + // &global_scatter_out, + // expert_biases1[i], // bias + // false, + // false, + // true, // compute bias + // true, // is linear1 + // &expert_out1, + // &expert_out1); // bias out + // phi::GeluKernel(dev_ctx, expert_out1, approximate, &expert_out1); + // BatchedMatMulAndAdd(dev_ctx, + // expert_weights2[i], + // &expert_out1, // input + // expert_biases2[i], + // false, + // false, + // true, // compute bias + // false, // is linear1 + // &expert_out2, + // &expert_out2); + // int last_index = 0; + // for (int idx = 0; idx < num_expert; idx++) { + // int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + // if (cur_expert_count <= 0) { + // continue; + // } + // int end = cur_expert_count + last_index; + // // expert_out2 slice + // Tensor tmp_sliced; + // phi::SliceCompute(dev_ctx, + // expert_out2, + // {0, 1}, + // {idx, last_index}, + // {idx + 1, end}, + // {1, 1}, + // {}, + // &tmp_sliced); + // tmp_sliced.Resize({{cur_expert_count, dim_embed}}); // maybe dont need resize + // tmp_expert_out.emplace_back(tmp_sliced); + // last_index = end; + // } + // } else { + + // encoder, use matmul + int last_index = 0; + // std::vector tmp_expert_out; + for (int idx = 0; idx < num_expert; idx++) { + int cur_expert_count = fwd_expert_count_cpu.data()[idx]; + if (cur_expert_count <= 0) { + continue; + } + int end = cur_expert_count + last_index; + + Tensor expert_out1, expert_out2, act_bias_out; + expert_out1.Resize({{cur_expert_count, dim_feedforward}}); + expert_out2.Resize({{cur_expert_count, dim_embed}}); + act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); + dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(T)); + dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); + dev_ctx.Alloc(&act_bias_out, act_bias_out.numel() * sizeof(T)); + + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, cur_expert_count, dim_feedforward, dropout_param); + + Tensor tmp_inp = global_scatter_out.Slice(last_index, end); + int expert_idx = i * num_expert + idx; + + // linear1 matmul + // VLOG(0) << "moe, Expert Computation, linear1 mul"; + phi::MatMulAndAdd(dev_ctx, + expert_weights1[expert_idx], + &tmp_inp, + nullptr, + false, + false, + false, // dont compute bias + &expert_out1, + nullptr); + // MatMulAndAdd(dev_ctx, + // expert_weights1[i]->data() + idx * dim_embed * dim_feedforward, + // tmp_inp.data(), + // nullptr, // bias + // cur_expert_count, + // dim_feedforward, + // dim_embed, + // false, + // false, + // false, // dont compute bias + // expert_out1.data(), + // nullptr); + + // bias gelu + // VLOG(0) << "moe, Expert Computation, add bias & gelu"; + // inplace + fused_act_dropout_helper.DropoutActBias(dev_ctx, + expert_out1.data(), + expert_biases1[expert_idx]->data(), + "gelu", + act_bias_out.data(), + nullptr, + 1.0, + nullptr, + 0, + 1.0, + 1, + 127.0, + -127.0, + approximate); + // fused_act_dropout_helper.DropoutActBias(dev_ctx, + // expert_out1.data(), + // expert_biases1[i]->data() + idx * dim_feedforward, + // "gelu", + // act_bias_out.data(), + // nullptr, + // 1.0, + // nullptr, + // 0, + // 1.0, + // 1, + // 127.0, + // -127.0, + // approximate); + + // linear2 matmul & add + // VLOG(0) << "moe, Expert Computation, linear2 matmul & add"; + phi::MatMulAndAdd(dev_ctx, + expert_weights2[expert_idx], + &act_bias_out, + expert_biases2[expert_idx], + false, + false, + true, // compute bias + &expert_out2, + &expert_out2); + // MatMulAndAdd(dev_ctx, + // expert_weights2[i]->data() + idx * dim_embed * dim_feedforward, + // act_bias_out.data(), + // expert_biases2[i]->data() + idx * dim_embed, + // cur_expert_count, + // dim_embed, + // dim_feedforward, + // false, + // false, + // true, // compute bias + // expert_out2.data(), + // expert_out2.data()); + // Addmm(dev_ctx, + // *expert_biases2[expert_idx], + // act_bias_out, + // *expert_weights2[expert_idx], + // 1.0, + // 1.0, + // &expert_out2); + tmp_expert_out.emplace_back(expert_out2); + last_index = end; + // } + } + // at last, concat all expert out + concat(dev_ctx, tmp_expert_out, 0, &all_expert_out); + } else { + all_expert_out = global_scatter_out; + } + + // step7. MOEGather +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, MOEGather"; +#endif + if (world_size > 1) { + // auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // step 7.1, global_gather + if (map->has(moe_ring_id)) { + phi::GlobalGatherProcessGroupFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + true, + &global_gather_out); + } else { + phi::GlobalGatherFunctor(dev_ctx, + &all_expert_out, + &local_expert_count, + &global_expert_count, + moe_ring_id, + false, + &global_gather_out); + } + } else { + global_gather_out = all_expert_out; + } + // step 7.2, local_gather or scatter + // suppose pos->shape != [0] +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, local_gather or scatter"; +#endif + phi::ScatterKernel(dev_ctx, + moe_gather_out, + pos, + global_gather_out, + true, + &moe_gather_out); + // step 8, reshape & bmm + // moe gather out reshape +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, reshape & bmm"; +#endif + moe_gather_out.Resize({{sliced_bsz_seq, topk, dim_embed}}); + topk_value.Resize({{sliced_bsz_seq, 1, topk}}); + phi::BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); + bmm_out.Resize({{sliced_bsz_seq, dim_embed}}); + // step 9, AllGather +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, AllGather"; +#endif + if (mp_size > 1) { + // all gather + phi::AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); + } else { + all_gather_out = bmm_out; + } + + // step 11, add residual +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "moe, add residual"; +#endif + if (i < layers - 1) { + // add residual & next layer norm + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + all_gather_out.data(), // src + bias_dropout_residual_out_data, // residual + nullptr, // bias + ln_scale_data, + ln_bias_data, + moe_out.data(), // add out, next layer real input, for residual + dropout_mask_out_data, + buf0.data(), // out, after layernorm + ln_mean_data, + ln_var_data); + } else { + // last layer, only add residual + phi::AddKernel(dev_ctx, all_gather_out, bias_dropout_residual_out, &moe_out); + } + + // phi::AddKernel(dev_ctx, all_gather_out, bias_dropout_residual_out, &moe_out); + x_data = moe_out.data(); + VLOG(0) << "layer " << i + 1 << " end"; + + } // layers loop end + moe_out.Resize({{bsz, seq_len, dim_embed}}); + *out = moe_out; + VLOG(0) << "kernel end"; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_multi_transformer_moe, + ops::FusedMultiTransformerMoeOpKernel, + ops::FusedMultiTransformerMoeOpKernel); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h new file mode 100644 index 0000000000000..01a5e344ecc54 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h @@ -0,0 +1,273 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +// This file has been adapted from FasterTransformer file: +// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu +// We add License in the head. + +#pragma once +// #include +#include "paddle/fluid/operators/fused/fused_multi_transformer_op.h" +#include "paddle/phi/kernels/gpu/fused_moe_kernel.cu.h" +#include "paddle/fluid/operators/fused/attn_gemm_int8.h" +// #include "paddle/phi/kernels/funcs/eigen/common.h" +// #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +// #include "paddle/phi/kernels/impl/slice_kernel_impl.h" +// #include "paddle/phi/kernels/gelu_kernel.h" +// #include "paddle/fluid/operators/fused/attn_bias_add.cu.h" + +namespace paddle { +namespace operators { + +using Tensor = Tensor; + +// template +// void BatchedMatMulAndAdd(const phi::GPUContext& dev_ctx, +// const Tensor* weight, +// const Tensor* input, +// const Tensor* bias, +// bool istransA, +// bool istransB, +// bool compute_bias, +// bool is_linear1, +// Tensor* output, +// Tensor* bias_out) { +// // Note: for blas.BatchedGEMM API in Paddle, it treats all inputs as row-major. +// // for input [bsz_seqlen, dim_embed] * expert_weight [expert_num, dim_embed, dim_feedforward] +// CBLAS_TRANSPOSE transA = istransA ? CblasTrans : CblasNoTrans; +// CBLAS_TRANSPOSE transB = istransB ? CblasTrans : CblasNoTrans; +// T alpha = static_cast(1.0); +// T beta = static_cast(0.0); + +// // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) +// auto blas = phi::funcs::GetBlas(dev_ctx); +// const int x_ndim = input->dims().size(); +// auto M = input->dims()[x_ndim - 2]; +// auto N = weight->dims()[2]; +// auto K = input->dims()[x_ndim - 1]; +// auto out_batch_size = weight->dims()[0]; +// int64_t strideA = is_linear1 ? 0 : M * K; +// blas.BatchedGEMM(transA, +// transB, +// M, +// N, +// K, +// alpha, +// input->data(), +// weight->data(), +// beta, +// output->data(), +// out_batch_size, +// strideA, +// K * N); +// if (compute_bias) { +// // bias_out = output + bias +// std::vector ins = {output, bias}; +// std::vector outs = {bias_out}; +// phi::funcs::BroadcastKernel( +// dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); +// } +// } + +// template +// void MatMulAndAdd(const phi::GPUContext& dev_ctx, +// const T* weight, // input & output params is data pointer +// const T* input, +// const T* bias, +// int M, +// int N, +// int K, +// bool istransA, +// bool istransB, +// bool compute_bias, +// T* output, +// T* bias_out) { +// // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. +// // here: (transa, transb): nt, input * weight. +// CBLAS_TRANSPOSE transA = istransA ? CblasTrans : CblasNoTrans; +// CBLAS_TRANSPOSE transB = istransB ? CblasTrans : CblasNoTrans; +// T alpha = static_cast(1.0); +// T beta = static_cast(0.0); +// // input->dims()[0], // M +// // weight->dims()[1], // N +// // input->dims()[1], // K +// // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) +// auto blas = phi::funcs::GetBlas(dev_ctx); +// blas.GEMM(transA, +// transB, +// M, +// N, +// K, +// alpha, +// input, +// weight, +// beta, +// output); +// if (compute_bias) { +// // bias_out = output + bias +// // std::vector ins = {output, bias}; +// // std::vector outs = {bias_out}; +// // phi::funcs::BroadcastKernel( +// // dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); +// LaunchBiasAddFwKernel(dev_ctx, +// M, +// N, +// output, +// bias, +// bias_out); +// } +// } + +// template +// using PhiEigenTensor = phi::EigenTensor; + +// using Array1 = Eigen::DSizes; +// using Array2 = Eigen::DSizes; + +// template +// void Addmm(const phi::GPUContext& dev_ctx, +// const Tensor& input, // bias +// const Tensor& x, // input +// const Tensor& y, // weight +// float alpha, +// float beta, +// Tensor* out) { +// auto input_dims = input.dims(); +// auto x_dims = x.dims(); +// auto y_dims = y.dims(); + +// Tensor input_2d(input); +// if (input.dims().size() == 1) { +// input_dims = {1, input.dims()[0]}; +// input_2d.Resize(input_dims); +// } + +// // dev_ctx.template Alloc(out); +// auto blas = phi::funcs::GetBlas(dev_ctx); + +// // calc broadcast dim +// Array2 bcast_dims; +// bcast_dims[0] = x_dims[0] / input_dims[0]; +// bcast_dims[1] = y_dims[1] / input_dims[1]; +// VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "]"; +// // broadcast using eigen +// const Tensor& const_ref_input = input_2d; +// auto eigen_input = PhiEigenTensor::From(const_ref_input); +// auto eigen_out = PhiEigenTensor::From(*out); +// auto& place = *dev_ctx.eigen_device(); +// phi::funcs::EigenBroadcast, T, 2>::Eval( +// place, eigen_out, eigen_input, bcast_dims); + +// T t_alpha = static_cast(alpha); +// T t_beta = static_cast(beta); +// blas.GEMM(false, +// false, +// x_dims[0], +// y_dims[1], +// x_dims[1], +// t_alpha, +// x.data(), +// x_dims[1], +// y.data(), +// y_dims[1], +// t_beta, +// out->data(), +// y_dims[1]); +// } + +using phi::backends::gpu::GpuLaunchConfig; +// This function is used to execute GEMM, with input and output's types are T +// and INT8. +template +void MatMulTToINT8(const phi::GPUContext& dev_ctx, + const Tensor* weight, + const float quant_in_scale, + const Tensor* input, + Tensor* input_tmp, + Tensor* output, + int m, + int n, + int k, + Tensor* workspace = nullptr, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + auto helper = std::make_unique>(m, k, n, lt_handle); + quantize_kernel_launcher(input->data(), + input_tmp->data(), + quant_in_scale, + m, + k, + quant_round_type, + quant_max_bound, + quant_min_bound, + dev_ctx.stream()); + + helper->GEMM(input_tmp->data(), + weight->data(), + output->data(), + dev_ctx.stream(), + (void*)workspace->data(), + workspace->numel()); +} + +template +void MatMulINT8ToT(const phi::GPUContext& dev_ctx, + const Tensor* weight, + const float quant_in_scale, + const Tensor* input, + const Tensor* bias, + Tensor* output, + Tensor* output_tmp, + Tensor* bias_out, + const Tensor* dequant_out_scale, + int m, + int n, + int k, + bool compute_bias, + Tensor* workspace = nullptr) { + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + auto helper = std::make_unique>(m, k, n, lt_handle); + auto gpu_config = std::make_unique( + phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, m * n, DequantKernelVecSize)); + + helper->GEMM(input->data(), + weight->data(), + output_tmp->data(), + dev_ctx.stream(), + (void*)workspace->data(), + workspace->numel()); + + dequantize_kernel_launcher(output_tmp->data(), + output->data(), + m, + n, + dev_ctx.stream(), + gpu_config.get(), + quant_in_scale, + dequant_out_scale->data()); + + if (compute_bias) { + // bias_out = output + bias + std::vector ins = {output, bias}; + std::vector outs = {bias_out}; + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, -1, phi::funcs::AddFunctor()); + } +} + +} // operators +} // paddle \ No newline at end of file diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index 86de140b9cde8..3edb1a733e29d 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -21,8 +21,6 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; - class FusedMultiTransformerOp : public framework::OperatorWithKernel { private: static constexpr const char *OpName = "FusedMultiTransformerOp"; @@ -93,27 +91,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { x_dim, y_dim)); - if (ctx->Attrs().Get("ring_id") == -1) { - if (trans_qkvw) { - PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], - y_dim[3], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - - } else { - PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3], - y_dim[0], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(dim_embed, 3, num_head, dim_head)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - } - } - if (ctx->HasInputs("CacheKV")) { // [2, batch_size, num_head, max_seq_len, head_size] const auto &c_dims = ctx->GetInputsDim("CacheKV"); @@ -129,13 +106,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { paddle::platform::errors::InvalidArgument( "The first dim of CacheKV must be 2, but got %d", c_dim[0])); // 2 - PADDLE_ENFORCE_EQ(c_dim[1], - x_dim[0], - paddle::platform::errors::InvalidArgument( - "The second dim of CacheKV must be equal with " - "batch size %d, but got %d", - x_dim[0], - c_dim[1])); // batch_size PADDLE_ENFORCE_EQ(c_dim[2], trans_qkvw ? y_dim[1] : y_dim[2], paddle::platform::errors::InvalidArgument( @@ -143,12 +113,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { "head %d, but got %d", trans_qkvw ? y_dim[1] : y_dim[2], c_dim[2])); // num_head - PADDLE_ENFORCE_GT( - c_dim[3], - 0, - paddle::platform::errors::InvalidArgument( - "The forth dim of CacheKV must be greater than 0, but got %d", - c_dim[3])); // cache_seq_len PADDLE_ENFORCE_EQ(c_dim[4], trans_qkvw ? y_dim[2] : y_dim[3], paddle::platform::errors::InvalidArgument( @@ -170,7 +134,7 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, - const Tensor &tensor, + const phi::DenseTensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { if (var_name == "TimeStep") { VLOG(10) << "var_name:" << var_name << " need not to transform"; @@ -199,16 +163,27 @@ class FusedMultiTransformerOpOpMaker AddInput("CacheKV", "(optional) The cached KV for generation inference.") .AsDispensable() .AsDuplicable(); + AddInput("PreCaches", + "(optional) The prefix caches for generation inference.") + .AsDispensable() + .AsDuplicable(); + AddInput("RotaryPosEmb", + "(optional) The RoPE embeddings for generation inference.") + .AsDispensable(); + AddInput("BeamCacheOffset", + "(optional) The offset of CacheKV when using BeamSearch.") + .AsDispensable(); AddInput("TimeStep", "(optional, int) The time step for generation inference.") .AsDispensable(); + AddInput("SeqLengths", "(optional) The sequence length tensor of inputs.") + .AsDispensable(); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable(); AddInput("OutLinearBias", "The out_linear bias tensor.") .AsDispensable() .AsDuplicable(); - AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op") .AsDuplicable(); AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op") @@ -223,17 +198,39 @@ class FusedMultiTransformerOpOpMaker AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op") .AsDispensable() .AsDuplicable(); - + AddInput("QKVWScale", "QKVWScale") + .AsDispensable() + .AsDuplicable(); + AddInput("OutLinearWScale", "OutLinearWScale") + .AsDispensable() + .AsDuplicable(); + AddInput("FFN1WeightScale", "FFN1WeightScale") + .AsDispensable() + .AsDuplicable(); + AddInput("FFN2WeightScale", "FFN2WeightScale") + .AsDispensable() + .AsDuplicable(); AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") .AsDispensable() .AsDuplicable(); AddOutput("Out", "Result after multi ."); - AddAttr("pre_layer_norm", "if true, the attention op uses pre_layer_norm architecure, " "else, uses post_layer_norm architecuture. " "[default true].") .SetDefault(true); + AddAttr("rotary_emb_dims", + "the Attr(dims) for RotaryPosEmb's Computation [default 0].") + .SetDefault(0) + .AddCustomChecker([](const int &rotary_emb_dims) { + PADDLE_ENFORCE_EQ( + rotary_emb_dims >= 0 && rotary_emb_dims <= 2, + true, + platform::errors::InvalidArgument( + "'rotary_emb_dims' in Op(Rotray) should be between" + "0 and 2, But received [%s].", + rotary_emb_dims)); + }); AddAttr("epsilon", "Constant for numerical stability [default 1e-5].") .SetDefault(1e-5) @@ -272,7 +269,17 @@ class FusedMultiTransformerOpOpMaker "dropout_implementation can only be downgrade_in_infer or " "upscale_in_train")); }); - AddAttr("act_method", "act_method").SetDefault("gelu"); + AddAttr("act_method", "act_method") + .SetDefault("gelu") + .AddCustomChecker([](const std::string &act_type) { + PADDLE_ENFORCE_EQ( + act_type == "gelu" || act_type == "geglu" || act_type == "relu" || act_type == "none", + true, + platform::errors::InvalidArgument( + "Only support `gelu`, `geglu`, `relu`, `none` activation in " + "FusedMultiTransformer. ")); + }); + AddAttr( "trans_qkvw", "Whether the weights of qkv should be transposed. If true," @@ -281,11 +288,14 @@ class FusedMultiTransformerOpOpMaker "[dim_embed, 3, num_head, dim_head]") .SetDefault(true); + AddAttr("quant_weight","Whether do weight quant") + .SetDefault(false); + AddAttr( "ring_id", "ring id for tensor model parallel. distributed training and inference") .SetDefault(-1); - + AddComment(R"DOC(fused multi transformer layers op)DOC"); } }; @@ -309,3 +319,4 @@ REGISTER_OP_VERSION(fused_multi_transformer) "trans_qkvw", "A flag to indicate whether to transpose for weights of qkv.", true)); + diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 5cf22885aabba..d0f2c7ba08fe9 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -21,34 +21,99 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { using U = LayerNormParamType; auto &dev_ctx = ctx.cuda_device_context(); - auto *time_step = ctx.Input("TimeStep"); + auto *time_step = ctx.Input("TimeStep"); // 0. input - auto *input_x = ctx.Input("X"); + auto *input_x = ctx.Input("X"); const auto input_x_dims = input_x->dims(); int bsz = input_x_dims[0]; int seq_len = input_x_dims[1]; int dim_embed = input_x_dims[2]; int bsz_seq = bsz * seq_len; + // LOG(INFO) << "intput X: bsz: " << bsz << ", seq_len: " << seq_len << ", dim_embed: " << dim_embed; + const std::string act_method = ctx.Attr("act_method"); + bool use_glu = (act_method == "geglu"); + bool remove_padding = false; + auto *sequence_lengths = ctx.Input("SeqLengths"); + if (sequence_lengths) { + remove_padding = true; + } + auto *beam_cache_offset = ctx.Input("BeamCacheOffset"); + int beam_size = 1; + if (beam_cache_offset) { + beam_size = beam_cache_offset->dims()[1]; + } + // LOG(INFO) << "beam_size: " << beam_size; + phi::DenseTensor d_token_tensor; + phi::DenseTensor padding_offset_tensor; + phi::DenseTensor x_remove_padding; + bool encoder_remove_padding = (remove_padding && !time_step); + int token_num = 0; + + auto *out = ctx.Output("Out"); + auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); + + // Init out + if (encoder_remove_padding) { + InitValue(dev_ctx, from_data, out->numel(), static_cast(0.)); + } + + // remove padding in encoder + if (encoder_remove_padding) { + // just for encoder + d_token_tensor.Resize({{1}}); + auto *d_token_num = dev_ctx.Alloc( + &d_token_tensor, d_token_tensor.numel() * sizeof(int)); + // alloc the max size of padding_offset_tensor + padding_offset_tensor.Resize({{bsz_seq}}); + dev_ctx.Alloc(&padding_offset_tensor, + padding_offset_tensor.numel() * sizeof(int)); + InvokeGetPaddingOffset(dev_ctx, + &token_num, + d_token_num, + padding_offset_tensor.data(), + sequence_lengths->data(), + bsz, + seq_len); + padding_offset_tensor.Resize({{token_num}}); + x_remove_padding.Resize({{token_num, dim_embed}}); + dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); + InvokeRemovePadding(dev_ctx, + x_remove_padding.data(), + input_x->data(), + padding_offset_tensor.data(), + token_num, + dim_embed); + } else { + token_num = bsz_seq; + } + + if (token_num == 0) { + return; + } + + auto *padding_offset_data = + encoder_remove_padding ? padding_offset_tensor.data() : nullptr; + // whether do weight only quant // 1. layer norm const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); const float epsilon = ctx.Attr("epsilon"); - auto ln_scales = ctx.MultiInput("LnScale"); - auto ln_biases = ctx.MultiInput("LnBias"); + auto ln_scales = ctx.MultiInput("LnScale"); + auto ln_biases = ctx.MultiInput("LnBias"); - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); - Tensor ln_mean, ln_var; - ln_mean.Resize({{bsz_seq}}); + auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + phi::DenseTensor ln_mean, ln_var; + ln_mean.Resize({{token_num}}); auto *ln_mean_data = dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({{bsz_seq}}); + ln_var.Resize({{token_num}}); auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); // 2. qkv // x: qkv's input [batch_size, seq_len, dim_embed] // y: qkv's weight: [3, num_head, dim_head, dim_embed] - auto qkv_weights = ctx.MultiInput("QKVW"); - auto qkv_biases = ctx.MultiInput("QKVBias"); + auto qkv_weights = ctx.MultiInput("QKVW"); + auto qkv_biases = ctx.MultiInput("QKVBias"); const bool trans_qkvw = ctx.Attr("trans_qkvw"); const auto qkv_w_dims = qkv_weights[0]->dims(); int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; @@ -59,71 +124,95 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; // (transA, transB, compute_bias) = (false, trans_qkvw, false) + // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we + // set compute_bias as false. auto qkv_compute = AttnMatMul(dev_ctx, false, trans_qkvw, - bsz_seq, + token_num, output_size, input_size, - compute_bias); - Tensor qkv_out; - qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}}); + /*compute_bias=*/false); + phi::DenseTensor qkv_out; + qkv_out.Resize({{token_num, 3, num_head, dim_head}}); auto *qkv_out_data = dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + // 2.1 rotary + auto *rotary_tensor = ctx.Input("RotaryPosEmb"); + const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); + // 3. fmha AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); auto fmha_compute = FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - auto *src_mask = ctx.Input("SrcMask"); - auto cache_kvs = ctx.MultiInput("CacheKV"); - auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); + auto *src_mask = ctx.Input("SrcMask"); + auto cache_kvs = ctx.MultiInput("CacheKV"); + auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); + + int cache_offset = 0; + + int time_step_cpu = 0; + if (time_step) { + // VLOG(0) << "time_step: " << *time_step; + time_step_cpu = src_mask->dims()[3] - 1; + // VLOG(0) << "time_step_cpu: " << time_step_cpu; + } auto out_seq_len = seq_len; if (time_step) { - PADDLE_ENFORCE_EQ(time_step->place(), - platform::CPUPlace(), - platform::errors::PreconditionNotMet( - "The place of input(TimeStep) must be CPUPlace.")); - // cache_seq_len - int time_step_value = time_step->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, + PADDLE_ENFORCE_GT(time_step_cpu, 0, platform::errors::PreconditionNotMet( "The value of time_step must > 0, but now is %d", - time_step_value)); + time_step_cpu)); PADDLE_ENFORCE_EQ( seq_len, 1, platform::errors::PreconditionNotMet( "In decode stage, the seq_len of input must be 1, but now is %d", seq_len)); - out_seq_len += time_step_value; + out_seq_len += time_step_cpu; + } else { + out_seq_len += cache_offset; + } + + phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; + q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); + auto *q_transpose_out_data = + dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + + kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); + auto *kv_transpose_out_data = dev_ctx.Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + + if (encoder_remove_padding) { + InitValue(dev_ctx, + q_transpose_out_data, + q_transpose_out.numel(), + static_cast(0.)); + InitValue(dev_ctx, + kv_transpose_out_data, + kv_transpose_out.numel(), + static_cast(0.)); } - Tensor transpose_out_2, qk_out; - transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); - auto *transpose_out_2_data = - dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); - Tensor softmax_out; - Tensor attn_dropout_mask_out, attn_dropout_out; - Tensor qktv_out, fmha_out; + phi::DenseTensor src_mask_out; + + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + phi::DenseTensor pre_cache_kv_out; + + phi::DenseTensor softmax_out; + phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; + phi::DenseTensor qktv_out, fmha_out; softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *softmax_out_data = dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); - qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); auto *qktv_out_data = dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); @@ -132,97 +221,113 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); // 4. out_linear - auto out_linear_weights = ctx.MultiInput("OutLinearW"); - auto out_linear_biases = ctx.MultiInput("OutLinearBias"); + auto out_linear_weights = ctx.MultiInput("OutLinearW"); + auto out_linear_biases = ctx.MultiInput("OutLinearBias"); int ring_id = ctx.Attr("ring_id"); // (transA, transB, compute_bias) = (false, false, false) auto out_linear_compute = AttnMatMul( - dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false); + dev_ctx, false, false, token_num, dim_embed, hidden_size, false); // 5. ln(residual + bias) DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon); - auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); - auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); - Tensor bias_dropout_residual_out, dropout_mask_out; + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); + auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); + phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; T *bias_dropout_residual_out_data = nullptr; if (pre_layer_norm) { - bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}}); + bias_dropout_residual_out.Resize({{token_num, dim_embed}}); bias_dropout_residual_out_data = dev_ctx.Alloc(&bias_dropout_residual_out, bias_dropout_residual_out.numel() * sizeof(T)); } - dropout_mask_out.Resize({{bsz, seq_len, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + uint8_t *dropout_mask_out_data = nullptr; // 6. ffn matmul1 - auto ffn1_weights = ctx.MultiInput("FFN1Weight"); - auto ffn1_biases = ctx.MultiInput("FFN1Bias"); + auto ffn1_weights = ctx.MultiInput("FFN1Weight"); + auto ffn1_weights_scales = + ctx.MultiInput("FFN1WeightScale"); + auto ffn1_biases = ctx.MultiInput("FFN1Bias"); auto ffn1_weight_dim = ffn1_weights[0]->dims(); int dim_ffn = ffn1_weight_dim[1]; + FFNGluHelper ffn1_glu_helper( + dev_ctx, act_method, token_num, dim_ffn / 2, dim_ffn, dim_embed); auto ffn1_linear_compute = AttnMatMul( - dev_ctx, false, false, bsz_seq, dim_ffn, dim_embed, false); - Tensor ffn1_out; - ffn1_out.Resize({{bsz_seq, dim_ffn}}); + dev_ctx, false, false, token_num, dim_ffn, dim_embed, false); + phi::DenseTensor ffn1_out; + ffn1_out.Resize({{token_num, dim_ffn}}); auto *ffn1_out_data = dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); // 7. ffn act + bias DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param); - Tensor ffn1_dropout_out, ffn1_dropout_mask; - ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}}); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); + phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; + int tmp_dim_ffn = dim_ffn; + if (use_glu) tmp_dim_ffn /= 2; + int8_t *ffn1_dropout_mask_data = nullptr; + ffn1_dropout_out.Resize({{token_num, tmp_dim_ffn}}); auto *ffn1_dropout_out_data = dev_ctx.Alloc( &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); - ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}}); - auto *ffn1_dropout_mask_data = dev_ctx.Alloc( - &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); // 8. ffn2 matmul - auto ffn2_weights = ctx.MultiInput("FFN2Weight"); - auto ffn2_biases = ctx.MultiInput("FFN2Bias"); + auto ffn2_weights = ctx.MultiInput("FFN2Weight"); + auto ffn2_biases = ctx.MultiInput("FFN2Bias"); auto ffn2_linear_compute = AttnMatMul( - dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false); + dev_ctx, false, false, token_num, dim_embed, tmp_dim_ffn, false); // 9. ffn2 residual bias DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); - - // calc - auto *out = ctx.Output("Out"); - auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - Tensor *from_tensor = out; - Tensor tmp_out; - tmp_out.Resize({{bsz, seq_len, dim_embed}}); + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); + + phi::DenseTensor tmp_out, tmp_out_rm_padding; + tmp_out.Resize({{token_num, dim_embed}}); + if (encoder_remove_padding) { + tmp_out_rm_padding.Resize({{token_num, dim_embed}}); + auto *tmp_out_rm_padding_data = dev_ctx.Alloc( + &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); + } auto *tmp_out_data = dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - auto *x_data = input_x->data(); - Tensor *buf0 = nullptr; - Tensor *buf1 = nullptr; + const T *x_data; + if (encoder_remove_padding) { + x_data = x_remove_padding.data(); + } else { + x_data = input_x->data(); + } + phi::DenseTensor *buf0 = nullptr; + phi::DenseTensor *buf1 = nullptr; // step0: x --> buf1 // step1: buf1 --> buf0 // step2: buf0 --> buf1 int layers = qkv_weights.size(); - if (pre_layer_norm) { - if (layers & 1) { - // odd, set buf1 as out + if (encoder_remove_padding) { + // In the case of variable lengths, the padding needs to be rebuilt + // eventually. So buf0 and buf1 do not need to be changed according to the + // pre_layer_norm and the number of layers. + buf0 = &tmp_out; + buf1 = &tmp_out_rm_padding; + } else { + if (pre_layer_norm) { + if (layers & 1) { + // odd, set buf1 as out + buf0 = &tmp_out; + buf1 = out; + } else { + // even, set buf0 as out + buf0 = out; + buf1 = &tmp_out; + } + } else { buf0 = &tmp_out; buf1 = out; - } else { - // even, set buf0 as out - buf0 = out; - buf1 = &tmp_out; } - } else { - buf0 = &tmp_out; - buf1 = out; } for (int i = 0; i < layers; ++i) { @@ -238,28 +343,26 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { ln_mean_data, ln_var_data); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; -#endif // step2. qkv - const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; + const phi::DenseTensor *qkv_bias = + qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; // NOTE: in decoder stage, bias is fused in fmha - const Tensor *bias = time_step ? nullptr : qkv_bias; + const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; if (!pre_layer_norm && i == 0) { + const phi::DenseTensor *tmp_input_x = + (encoder_remove_padding) ? &x_remove_padding : input_x; qkv_compute.ComputeForward( - qkv_weights[i], input_x, bias, &qkv_out, &qkv_out); + qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); } else { qkv_compute.ComputeForward( qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; -#endif // step3. fmha - const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; - Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + const phi::DenseTensor *cache_kv = + cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; + phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; if (time_step) { // generation decoder stage // [2, batch_size, num_head, max_seq_len, head_size] @@ -268,35 +371,89 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { qkv_out, *qkv_bias, *src_mask, + sequence_lengths, + rotary_tensor, + beam_cache_offset, cache_kv_out, &fmha_out, bsz, + beam_size, max_seq_len, num_head, dim_head, - time_step->data()[0], + time_step_cpu, + rotary_emb_dims, 1. / sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage - // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - nullptr, - src_mask, - &transpose_out_2, - nullptr, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); - // [3, bsz, num_head, seq_len, head_dim] - T *qkv_data = transpose_out_2_data; - int64_t q_size = bsz * seq_len * num_head * dim_head; - int64_t k_size = q_size; - const T *q_ptr = qkv_data; - const T *k_ptr = q_ptr + q_size; - const T *v_ptr = k_ptr + k_size; + const phi::DenseTensor *pre_cache_kv_tensor = nullptr; + phi::DenseTensor *pre_cache_kv_out_tmp = nullptr; + phi::DenseTensor *src_mask_tmp = nullptr; + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out.data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; + } else { + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + const T *q_ptr = q_transpose_out_data; + k_ptr = kv_transpose_out_data; + v_ptr = k_ptr + k_size; + } // [2, bsz, num_head, max_seq_len, head_dim] int max_seq_len = cache_kv_out->dims()[3]; @@ -306,35 +463,72 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { T *cache_k_ptr = cache_kv_data; T *cache_v_ptr = cache_kv_data + cache_k_size; + const int seq_len_tmp = seq_len + cache_offset; write_cache_kv(dev_ctx, cache_k_ptr, cache_v_ptr, k_ptr, v_ptr, + sequence_lengths_data, bsz, num_head, - seq_len, + seq_len_tmp, max_seq_len, dim_head); } else { // not generation // TODO(wangxi): can remove dropout in inference - fmha_compute.ComputeForward(qkv_out, - cache_kv, - src_mask, - &transpose_out_2, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out); + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? sequence_lengths->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, + src_mask, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step3"; #endif - if (pre_layer_norm) { out_linear_compute.ComputeForward( out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); @@ -390,25 +584,31 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { #endif // step6. ffn matmul1 - ffn1_linear_compute.ComputeForward( - ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); + if (use_glu) { + ffn1_glu_helper.Compute(buf1, + ffn1_weights[i], + ffn1_biases[i], + &ffn1_out, + &ffn1_dropout_out); + } else { + ffn1_linear_compute.ComputeForward( + ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step6"; #endif // step7. act bias // TODO(wangxi): remove dropout mask in inference - fused_act_dropout_helper.DropoutActBias(dev_ctx, - ffn1_out_data, - ffn1_biases[i]->data(), - "gelu", - ffn1_dropout_out_data, - ffn1_dropout_mask_data); -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; -#endif - - // step8. ffn matmul2 + if (!use_glu) { + fused_act_dropout_helper.DropoutActBias(dev_ctx, + ffn1_out_data, + ffn1_biases[i]->data(), + act_method, + ffn1_dropout_out_data, + ffn1_dropout_mask_data); + } + // step8. ffn2 matmul if (pre_layer_norm) { ffn2_linear_compute.ComputeForward( ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr); @@ -480,6 +680,23 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel { std::swap(buf0, buf1); } } + if (encoder_remove_padding) { + if (pre_layer_norm) { + InvokeRebuildPadding(dev_ctx, + from_data, + buf0->data(), + padding_offset_data, + token_num, + dim_embed); + } else { + InvokeRebuildPadding(dev_ctx, + from_data, + buf1->data(), + padding_offset_data, + token_num, + dim_embed); + } + } } }; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.h index 761a31ce094d1..94865f4415413 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.h @@ -1,12 +1,9 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,6 +13,8 @@ limitations under the License. */ // https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu // We add License in the head. +#pragma once + #include #include @@ -32,22 +31,37 @@ limitations under the License. */ #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/kernels/funcs/math_function.h" -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +// #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/ProcessGroup.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#endif +// #endif + +#include +#include + +DECLARE_bool(gemm_use_half_precision_compute_type); namespace paddle { namespace operators { -using Tensor = framework::Tensor; +template +void print_tensor(const T *t, int size, const char *name){ + using namespace std; + ofstream out_txt_file; + out_txt_file.open(name, ios::out | ios::trunc); + out_txt_file << fixed; + for(int i=0; i < size; i++){ + out_txt_file << setprecision(8) << static_cast(t[i]) << endl; + } + out_txt_file.close(); +} // for debug // #define _DEBUG_FUSED_MULTI_TRANSFORMER template -static void AllReduce(framework::Tensor &tensor, // NOLINT +static void AllReduce(phi::DenseTensor &tensor, // NOLINT const int ring_id, const int count, const phi::GPUContext &ctx) { @@ -91,6 +105,9 @@ using float16 = plat::float16; #define MMHA_USE_FP32_ACUM_FOR_LOGITS #define MMHA_USE_FP32_ACUM_FOR_OUT +#define MMHA_USE_FP32_ACUM_FOR_FMA +// #define MMHA_USE_HMMA_FOR_REDUCTION + template struct Masked_multihead_attention_params { @@ -108,8 +125,18 @@ struct Masked_multihead_attention_params { // k [B, num_head, dim_head/x, max_seq_len, x], that is `seq_len` first // v [B, num_head, max_seq_len, dim_head] T *cache_kv; + // [B, max_seq_len] + const int* beam_cache_offset = nullptr; + + const int *sequence_lengths{nullptr}; + + // The RoPE embedding, [B, 1, 1, dim_head] + // rotary_emb_dims = 1 if pos_ids_extra is null else 2 + const T *rotary_emb; + int rotary_emb_dims; - int batch_size; + int batch_size; // batch * beam + int beam_width; int num_head; int timestep; // cache_seq_length int max_seq_length; @@ -153,6 +180,17 @@ template <> struct V_vec_ { using Type = uint32_t; }; template <> struct V_vec_ { using Type = uint2; }; template <> struct V_vec_ { using Type = uint4; }; +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +#endif + #ifdef MMHA_USE_FP32_ACUM_FOR_OUT template struct V_vec_acum_fp32_ {}; // template <> struct V_vec_acum_fp32_ { using Type = float; }; @@ -321,6 +359,15 @@ inline __device__ uint32_t mul(uint32_t a, float b) { return res; } +template <> +inline __device__ float2 mul(uint32_t a, float b) { + float2 tmp = half2_to_float2(a); + float2 res; + res.x = tmp.x * b; + res.y = tmp.y * b; + return res; +} + template <> inline __device__ uint2 mul(uint2 a, float b) { uint2 res; @@ -347,6 +394,15 @@ inline __device__ float2 mul(float2 a, float b) { return res; } +template <> +inline __device__ float2 mul(float2 a, uint32_t b) { + float2 tmp_b = half2_to_float2(b); + float2 res; + res.x = a.x * tmp_b.x; + res.y = a.y * tmp_b.y; + return res; +} + template <> inline __device__ float4 mul(float4 a, float b) { float4 res; @@ -357,6 +413,18 @@ inline __device__ float4 mul(float4 a, float b) { return res; } +template +inline __device__ Qk_vec apply_rotary_emb(Qk_vec input_left, + Qk_vec input_right, + Qk_vec cos_emb, + Qk_vec sin_emb, + float alpha) { + Qk_vec res1 = mul(input_left, cos_emb); + Qk_vec res2 = mul(input_right, sin_emb); + res2 = mul(res2, alpha); + return add(res1, res2); +} + inline __device__ float sum(float v) { return v; } inline __device__ float sum(float2 v) { return v.x + v.y; } inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } @@ -406,6 +474,12 @@ inline __device__ float2 fma(float2 a, float2 b, float2 c) { return d; } +inline __device__ float2 fma(float2 a, uint32_t b, float2 c) { + float2 tmp_b = half2_to_float2(b); + float2 d = fma(a, tmp_b, c); + return d; +} + inline __device__ float4 fma(float4 a, float4 b, float4 c) { float4 d; d.x = fma(a.x, b.x, c.x); @@ -527,6 +601,50 @@ inline __device__ float qk_dot_(const K_vec (&q)[N], return qk; } +inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) { + float4 c; + float zero = 0.f; + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum inv_q = mul(q[0], inv_sqrt_dh); + K_vec_acum qk_vec = mul(inv_q, k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + template struct Qk_dot { template @@ -537,6 +655,21 @@ struct Qk_dot { } }; +template <> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 + return qk_hmma_dot_(q, k, inv_sqrt_dh); +#else + return qk_dot_<4>(q, k, inv_sqrt_dh); +#endif + } +}; + template inline __device__ float block_sum(float *red_smem, float sum) { int warp = threadIdx.x / WARP_SIZE; @@ -630,14 +763,24 @@ __global__ void masked_multihead_attention_kernel( using Qk_vec = typename Qk_vec_::Type; __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; + // batch * beam idx const int bi = blockIdx.y; + // beam id + const int beami = bi % params.beam_width; + // real batch id + const int bbi = bi / params.beam_width; const int hi = blockIdx.x; const int bhi = bi * params.num_head + hi; + const int bbhi = bbi * params.beam_width * params.num_head + hi; const int tid = threadIdx.x; - + const int bi_seq_len_offset = bi * params.max_seq_length; float qk_max = -FLT_MAX; float qk = 0; + int act_time_step = params.sequence_lengths == nullptr + ? params.timestep + : params.sequence_lengths[bi]; + // qkv [B, S=1, 3, num_head, head_dim] int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh; @@ -690,13 +833,74 @@ __global__ void masked_multihead_attention_kernel( // we may not require k_bias. k = add(k, k_bias); + // rotary pos emb + if (params.rotary_emb_dims != 0) { + int last_dim = Dh / params.rotary_emb_dims; + int half_lastdim = last_dim / 2; + int rotary_offset = bi * Dh + tid * QK_VEC_SIZE; + const T *cos_base = params.rotary_emb; + const T *sin_base = params.rotary_emb + params.batch_size * Dh; + int stride = half_lastdim / QK_VEC_SIZE; + int stride_all_lastdim = 2 * stride; + int right_id = tid / stride_all_lastdim * stride_all_lastdim + + (tid + stride) % (stride_all_lastdim); + int qk_right_offset = qkv_base_offset + right_id * QK_VEC_SIZE; + int qk_right_bias_offset = hi * Dh + right_id * QK_VEC_SIZE; + Qk_vec q_right; + zero(q_right); + q_right = + (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&q_base[qk_right_offset]) + : q_right; + Qk_vec k_right; + zero(k_right); + k_right = + (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&k_base[qk_right_offset]) + : k_right; + + Qk_vec q_right_bias; + zero(q_right_bias); + q_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &q_bias_base[qk_right_bias_offset]) + : q_right_bias; + Qk_vec k_right_bias; + zero(k_right_bias); + k_right_bias = (Dh == Dh_MAX || right_id * QK_VEC_SIZE < Dh) + ? *reinterpret_cast( + &k_bias_base[qk_right_bias_offset]) + : k_right_bias; + + q_right = add(q_right, q_right_bias); + k_right = add(k_right, k_right_bias); + + Qk_vec cos_emb; + zero(cos_emb); + cos_emb = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&cos_base[rotary_offset]) + : cos_emb; + + Qk_vec sin_emb; + zero(sin_emb); + sin_emb = + (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) + ? *reinterpret_cast(&sin_base[rotary_offset]) + : sin_emb; + float alpha = (tid % stride_all_lastdim) < stride ? static_cast(-1) + : static_cast(1); + q = apply_rotary_emb(q, q_right, cos_emb, sin_emb, alpha); + k = apply_rotary_emb(k, k_right, cos_emb, sin_emb, alpha); + } + *reinterpret_cast(&q_smem[tid * QK_VEC_SIZE]) = q; int co = tid / QK_VECS_IN_16B; int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE; int offset = bhi * params.max_seq_length * Dh + co * params.max_seq_length * QK_ELTS_IN_16B + - params.timestep * QK_ELTS_IN_16B + ci; + act_time_step * QK_ELTS_IN_16B + ci; if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { *reinterpret_cast(¶ms.cache_kv[offset]) = k; } @@ -710,6 +914,7 @@ __global__ void masked_multihead_attention_kernel( } } } + if (QK_VECS_PER_WARP > WARP_SIZE) { constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; @@ -722,7 +927,7 @@ __global__ void masked_multihead_attention_kernel( // qk += static_cast(mask); qk *= params.inv_sqrt_dh; qk_max = qk; - qk_smem[params.timestep] = qk; + qk_smem[act_time_step] = qk; } __syncthreads(); @@ -735,14 +940,14 @@ __global__ void masked_multihead_attention_kernel( __syncthreads(); #endif - using K_vec = typename K_vec_::Type; - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + using K_vec = typename K_vec_::Type; // uint2 + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); // 2 static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; // 32 + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; // 16 - int ko = tid / THREADS_PER_KEY; - int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; + int ko = tid / THREADS_PER_KEY; // 0 ~ 63 + int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE; // 0 or 2 static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, ""); @@ -753,11 +958,13 @@ __global__ void masked_multihead_attention_kernel( &q_smem[ki + i * THREADS_PER_KEY * K_VEC_SIZE]); } - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; //128/2 = 64 + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // 32/2 = 16 T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; - int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + T *k_cache_batch = ¶ms.cache_kv[bbhi * params.max_seq_length * Dh + ki]; + int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; // 160 + const int *beam_offsets = params.beam_cache_offset ? ¶ms.beam_cache_offset[bi_seq_len_offset] : nullptr; for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { K_vec k[K_VECS_PER_THREAD]; @@ -765,12 +972,19 @@ __global__ void masked_multihead_attention_kernel( zero(k_vec_zero); #pragma unroll for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + // get beam_offset of this location + const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh : 0; int jj = ii * params.max_seq_length + ti; - if (ti < params.timestep) { + if (ti < act_time_step) { + // k[ii] = + // (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) + // ? *reinterpret_cast( + // &k_cache[jj * QK_ELTS_IN_16B]) + // : k_vec_zero; k[ii] = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length) ? *reinterpret_cast( - &k_cache[jj * QK_ELTS_IN_16B]) + &k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]) : k_vec_zero; } } @@ -780,7 +994,7 @@ __global__ void masked_multihead_attention_kernel( float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); // bool is_mask = false; - if (ti < params.timestep && tid % THREADS_PER_KEY == 0) { + if (ti < act_time_step && tid % THREADS_PER_KEY == 0) { // qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); T mask = params.attn_mask[bi * (params.timestep + 1) + ti]; qk += static_cast(mask); @@ -822,7 +1036,7 @@ __global__ void masked_multihead_attention_kernel( #endif float sum = 0.f; - for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { // bool is_mask = false; // float logit = is_mask ? 0.f : __expf(qk_smem[ti] - qk_max); float logit = __expf(qk_smem[ti] - qk_max); @@ -834,7 +1048,7 @@ __global__ void masked_multihead_attention_kernel( // FIXME(wangxi): need add 1.e-6f? float inv_sum = __fdividef(1.f, sum + 1.e-6f); - for (int ti = tid; ti <= params.timestep; ti += THREADS_PER_BLOCK) { + for (int ti = tid; ti <= act_time_step; ti += THREADS_PER_BLOCK) { convert_from_float(logits_smem[ti], qk_smem[ti] * inv_sum); } __syncthreads(); @@ -848,6 +1062,9 @@ __global__ void masked_multihead_attention_kernel( T *v_cache = ¶ms.cache_kv[params.batch_size * params.num_head * params.max_seq_length * Dh + bhi * params.max_seq_length * Dh + vi]; + T *v_cache_batch = ¶ms.cache_kv[params.batch_size * params.num_head * + params.max_seq_length * Dh + + bbhi * params.max_seq_length * Dh + vi]; #ifdef MMHA_USE_FP32_ACUM_FOR_OUT using V_vec_acum = typename V_vec_acum_fp32_::Type; @@ -860,8 +1077,9 @@ __global__ void masked_multihead_attention_kernel( constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; if (Dh == Dh_MAX || vi < Dh) { - for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) { - V_vec v = *reinterpret_cast(&v_cache[ti * Dh]); + for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) { + const int beam_offset = beam_offsets ? beam_offsets[ti] * params.num_head * params.max_seq_length * Dh : 0; + V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh]); #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti]; out = fma(logit, cast_to_float(v), out); @@ -884,18 +1102,18 @@ __global__ void masked_multihead_attention_kernel( V_vec v_bias; zero(v_bias); - if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { + if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) { V_vec v = *reinterpret_cast( ¶ms.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]); v_bias = *reinterpret_cast( ¶ms.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]); v = add(v, v_bias); - *reinterpret_cast(&v_cache[params.timestep * Dh]) = v; + *reinterpret_cast(&v_cache[act_time_step * Dh]) = v; #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[act_time_step], cast_to_float(v), out); #else - out = fma(logits_smem[params.timestep], v, out); + out = fma(logits_smem[act_time_step], v, out); #endif } @@ -970,18 +1188,17 @@ inline size_t smem_size_in_bytes( return max(softmax_sz, red_sz); } -#define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ - size_t smem_sz = \ - smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_head, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + constexpr auto kernel_fn = masked_multihead_attention_kernel; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + dim3 grid(params.num_head, params.batch_size); \ + kernel_fn<<>>(params) template void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, @@ -990,7 +1207,12 @@ void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, if (params.timestep < 32) { MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); } else if (params.timestep < 2048) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 750 + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream); +#else MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); +#endif } else { MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); } @@ -998,16 +1220,21 @@ void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, template void fmha(const phi::GPUContext &dev_ctx, - const Tensor &qkv_tensor, - const Tensor &qkv_bias_tensor, - const Tensor &src_mask_tensor, - Tensor *cache_kv_tensor, - Tensor *out_tensor, + const phi::DenseTensor &qkv_tensor, + const phi::DenseTensor &qkv_bias_tensor, + const phi::DenseTensor &src_mask_tensor, + const phi::DenseTensor *sequence_lengths_tensor, + const phi::DenseTensor *rotary_tensor, + const phi::DenseTensor *beam_cache_offset_tensor, + phi::DenseTensor *cache_kv_tensor, + phi::DenseTensor *out_tensor, int batch_size, + int beam_width, int max_seq_length, int num_head, int dim_head, int timestep, + int rotary_emb_dims, float inv_sqrt_dh) { Masked_multihead_attention_params params; params.out = out_tensor->data(); @@ -1016,11 +1243,28 @@ void fmha(const phi::GPUContext &dev_ctx, params.attn_mask = src_mask_tensor.data(); params.cache_kv = cache_kv_tensor->data(); + if (sequence_lengths_tensor) { + params.sequence_lengths = sequence_lengths_tensor->data(); + } + + if (rotary_emb_dims > 0) { + params.rotary_emb = rotary_tensor->data(); + } else { + params.rotary_emb = nullptr; + } + + if (beam_cache_offset_tensor) { + // LOG(INFO) << "beam_cache_offset_tensor.dims: " << beam_cache_offset_tensor->dims().to_str(); + params.beam_cache_offset = beam_cache_offset_tensor->data(); + } + params.batch_size = batch_size; + params.beam_width = beam_width; params.num_head = num_head; params.timestep = timestep; params.max_seq_length = max_seq_length; params.inv_sqrt_dh = inv_sqrt_dh; + params.rotary_emb_dims = rotary_emb_dims; switch (dim_head) { case 10: @@ -1050,17 +1294,54 @@ void fmha(const phi::GPUContext &dev_ctx, } } +template +void fmha(const phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv_tensor, + const phi::DenseTensor &qkv_bias_tensor, + const phi::DenseTensor &src_mask_tensor, + phi::DenseTensor *cache_kv_tensor, + phi::DenseTensor *out_tensor, + int batch_size, + int max_seq_length, + int num_head, + int dim_head, + int timestep, + float inv_sqrt_dh) { + fmha(dev_ctx, + qkv_tensor, + qkv_bias_tensor, + src_mask_tensor, + nullptr, + nullptr, + nullptr, + cache_kv_tensor, + out_tensor, + batch_size, + 1, + max_seq_length, + num_head, + dim_head, + timestep, + 0, + inv_sqrt_dh); +} + // NOTE: simd with 16Bytes(128bit), float is 4, float16 is 8 constexpr int VEC_16B = 16; template __global__ void write_cache_k_kernel(T *cache_k, const T *k, + const int *seq_lens, const int num_head, const int dim_head, const int seq_len, const int max_seq_len) { const int bi = blockIdx.y; + if (seq_lens && seq_lens[bi] == 0) { + return; + } + const int hi = blockIdx.z; constexpr int X_ELEMS = VEC_16B / sizeof(T); @@ -1094,11 +1375,16 @@ __global__ void write_cache_k_kernel(T *cache_k, template __global__ void write_cache_v_kernel(T *cache_v, const T *v, + const int *seq_lens, const int num_head, const int dim_head, const int seq_len, const int max_seq_len) { const int bi = blockIdx.y; + if (seq_lens && seq_lens[bi] == 0) { + return; + } + const int hi = blockIdx.z; // [bsz, num_head, seq_len, dim_head/x, x] @@ -1124,6 +1410,7 @@ void write_cache_kv(const phi::GPUContext &dev_ctx, T *cache_v, const T *k, const T *v, + const int *seq_lens, const int bsz, const int num_head, const int seq_len, @@ -1147,14 +1434,496 @@ void write_cache_kv(const phi::GPUContext &dev_ctx, // transpose [bsz, num_head, seq_len, dim_head/x, x]-> // [bsz, num_head, dim_head/x, max_seq_len, x] write_cache_k_kernel<<>>( - cache_k, k, num_head, dim_head, seq_len, max_seq_len); + cache_k, k, seq_lens, num_head, dim_head, seq_len, max_seq_len); // copy [bsz, num_head, seq_len, dim_head/x, x]-> // [bsz, num_head, max_seq_len, dim_head/x, x] write_cache_v_kernel<<>>( - cache_v, v, num_head, dim_head, seq_len, max_seq_len); + cache_v, v, seq_lens, num_head, dim_head, seq_len, max_seq_len); +} + +template +void write_cache_kv(const phi::GPUContext &dev_ctx, + T *cache_k, + T *cache_v, + const T *k, + const T *v, + const int bsz, + const int num_head, + const int seq_len, + const int max_seq_len, + const int dim_head) { + write_cache_kv(dev_ctx, + cache_k, + cache_v, + k, v, nullptr, + bsz, num_head, seq_len, + max_seq_len, dim_head); +} + +template +__global__ void add_fusedQKV_bias_transpose_split_kernel( + T *q_buf, + T *kv_buf, + const T *qkv, + const T *qkv_bias, + const int *padding_offset, + const int32_t elem_cnt, + const int batch_size, + const int seq_len, + const int token_num, + const int head_num, + const int size_per_head) { + const int32_t offset = batch_size * seq_len * head_num * size_per_head; + const int32_t hidden_size = head_num * size_per_head; + const int32_t fused_hidden_size = 3 * hidden_size; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = phi::AlignedVector; + LoadT src_vec; + LoadT bias_vec; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + phi::Load(&qkv[linear_index], &src_vec); + int32_t bias_idx = linear_index % fused_hidden_size; + if (ComputeBias) { + phi::Load(&qkv_bias[bias_idx], &bias_vec); +#pragma unroll + for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { + src_vec[unroll_idx] += bias_vec[unroll_idx]; + } + } + const int32_t token_idx = linear_index / fused_hidden_size; + const int32_t ori_token_idx = + token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]); + const int32_t target_batch_id = ori_token_idx / seq_len; + const int32_t seq_id = ori_token_idx % seq_len; + + // equal to: + // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size; + const int32_t qkv_id = bias_idx / hidden_size; + const int32_t head_id = (linear_index % hidden_size) / size_per_head; + const int32_t size_id = linear_index % size_per_head; + + if (qkv_id == 0) { + phi::Store( + src_vec, + &q_buf[target_batch_id * head_num * seq_len * size_per_head + + head_id * seq_len * size_per_head + seq_id * size_per_head + + size_id]); + } else { + const int32_t kv_store_offset = (qkv_id - 1) * offset; + phi::Store( + src_vec, + &kv_buf[kv_store_offset + + target_batch_id * head_num * seq_len * size_per_head + + head_id * seq_len * size_per_head + seq_id * size_per_head + + size_id]); + } + } +} + +inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { + constexpr int kBlockSize = 128; + constexpr int kNumWaves = 16; + + const int device_id = phi::backends::gpu::GetCurrentDeviceId(); + const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id); + const int max_thread_per_multiprocessor = + phi::backends::gpu::GetGPUMultiProcessors(device_id); + + *num_blocks = + std::max(1, + std::min((n + kBlockSize - 1) / kBlockSize, + sm_count * max_thread_per_multiprocessor / + kBlockSize * kNumWaves)); + return cudaSuccess; +} + +template +void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx, + T *q_buf, + T *kv_buf, + const T *qkv, + const T *qkv_bias, + const int *padding_offset, + const int token_num, + const int batch_size, + const int head_num, + const int seq_len, + const int size_per_head, + bool compute_bias) { + const int32_t elem_cnt = token_num * head_num * size_per_head * 3; + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(size_per_head % PackSize, + 0, + platform::errors::PreconditionNotMet( + "dim_head=%d must be divisible by vec_size=%d", + size_per_head, + PackSize)); + const int32_t pack_num = elem_cnt / PackSize; + const int32_t blocksize = 128; + int32_t grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + if (compute_bias) { + add_fusedQKV_bias_transpose_split_kernel + <<>>(q_buf, + kv_buf, + qkv, + qkv_bias, + padding_offset, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); + } else { + add_fusedQKV_bias_transpose_split_kernel + <<>>(q_buf, + kv_buf, + qkv, + qkv_bias, + padding_offset, + elem_cnt, + batch_size, + seq_len, + token_num, + head_num, + size_per_head); + } +} + +template +__global__ void RotrayKernel(const T *input, + const T *cos_emb, + const T *sin_emb, + const int *sequence_lengths, + T *output, + const int rotary_emb_dims, + const int batch_size, + const int head_num, + const int seq_len, + const int last_dim) { + int bi = blockIdx.x; + int hi = blockIdx.y; + int si = blockIdx.z; + if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return; + int half_lastdim = last_dim / 2; + // Note(ZhenyuLi): Calculate the relevant data at one time, so that no + // additional space is required. + for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) { + int base_idx = bi * head_num * seq_len * last_dim + + hi * seq_len * last_dim + si * last_dim; + int left_idx = base_idx + ti; + const int right_idx = base_idx + ti + half_lastdim; + int emb_idx = bi * seq_len * last_dim + si * last_dim + ti; + T input_left = input[left_idx]; + T input_right = input[right_idx]; + T cos_tmp = cos_emb[emb_idx]; + T sin_tmp = sin_emb[emb_idx]; + T res1 = input_left * cos_tmp - input_right * sin_tmp; + T res2 = input_right * cos_tmp + input_left * sin_tmp; + output[left_idx] = res1; + output[right_idx] = res2; + } +} + +template +void rotary_qk(const phi::GPUContext &dev_ctx, + T *q, + T *k, // kv + const T *q_input, // q + const T *k_input, // kv + const T *rotary_emb, + const int *sequence_lengths, + const int rotary_emb_dims, + const int batch_size, + const int head_num, + const int seq_len, + const int dim_head) { + // q_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, + // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] + // kv_transpose_out_data [bs, head_num, seq_len, dim_head] -> [bs, head_num, + // seq_len * rotary_emb_dims, dim_head / rotary_emb_dims] rotary_emb [2, bs, + // 1, seq_len, dim_head] -> [2, bs, 1, seq_len * rotary_emb_dims, dim_head / + // rotary_emb_dims] + dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims); + const int last_dim = dim_head / rotary_emb_dims; + auto getBlockSize = [](int dim) { + if (dim > 256) { + return 512; + } else if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; + } + }; + int BlockSize = getBlockSize(last_dim / 2); + const T *cos_emb = rotary_emb; + const T *sin_emb = rotary_emb + batch_size * seq_len * dim_head; + RotrayKernel<<>>( + q_input, + cos_emb, + sin_emb, + sequence_lengths, + q, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); + RotrayKernel<<>>( + k_input, + cos_emb, + sin_emb, + sequence_lengths, + k, + rotary_emb_dims, + batch_size, + head_num, + seq_len * rotary_emb_dims, + last_dim); +} + +__global__ void GetPaddingOffset(int *d_token_num, + int *padding_offset, + const int *sequence_lengths, + const int batch_size, + const int max_seq_len) { + // get padding offset of each batch + int total_seq_len = 0; + int cum_offset = 0; + int index = 0; + for (int i = 0; i < batch_size; i++) { + const int seq_len = sequence_lengths[i]; + for (int j = 0; j < seq_len; j++) { + padding_offset[index] = cum_offset; + index++; + } + cum_offset += max_seq_len - seq_len; + total_seq_len += seq_len; + } + d_token_num[0] = total_seq_len; +} + +void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx, + int *h_token_num, + int *d_token_num, + int *padding_offset, + const int *sequence_lengths, + const int batch_size, + const int max_seq_len) { + GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>( + d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len); + memory::Copy(platform::CPUPlace(), + h_token_num, + dev_ctx.GetPlace(), + d_token_num, + sizeof(int), + dev_ctx.stream()); } +template +__global__ void RemovePadding(T *output_data, + const T *input_data, + const int *padding_offset, + const int dim_embed) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int src_seq_id = bid + padding_offset[bid]; + const int tgt_seq_id = bid; + + for (int i = tid; i < dim_embed; i += blockDim.x) { + output_data[tgt_seq_id * dim_embed + i] = + input_data[src_seq_id * dim_embed + i]; + } +} + +template +void InvokeRemovePadding(const phi::GPUContext &dev_ctx, + T *output_data, + const T *input_data, + const int *padding_offset, + const int token_num, + const int dim_embed) { + RemovePadding<<>>( + output_data, input_data, padding_offset, dim_embed); +} + +template +__global__ void RebuildPadding(T *output_data, + const T *input_data, + const int *padding_offset, + const int dim_embed) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int dst_seq_id = bid + padding_offset[bid]; + const int src_seq_id = bid; + + for (int i = tid; i < dim_embed; i += blockDim.x) { + output_data[dst_seq_id * dim_embed + i] = + input_data[src_seq_id * dim_embed + i]; + } +} + +template +void InvokeRebuildPadding(const phi::GPUContext &dev_ctx, + T *output_data, + const T *input_data, + const int *padding_offset, + const int token_num, + const int dim_embed) { + // src: [token_num, dim_embed] + // dst: [batch_size * max_seq_len, dim_embed] + RebuildPadding<<>>( + output_data, input_data, padding_offset, dim_embed); +} + +template +__global__ void InitOutValueKernel(T *output_data, + const int64_t numel, + const T init_value) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + int64_t global_thread_idx = bid * blockDim.x + tid; + + for (int linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < numel; + linear_index += step) { + for (int i = 0; i < VecSize; i ++) { + output_data[linear_index + i] = init_value; + } + } +} + +template +void InitValue(const phi::GPUContext &dev_ctx, + T *output_data, + const int64_t numel, + const T init_value) { + constexpr int PackSize = VEC_16B / sizeof(T); + PADDLE_ENFORCE_EQ(numel % PackSize, + 0, + platform::errors::PreconditionNotMet( + "numel=%d must be divisible by vec_size=%d", + numel, + PackSize)); + const int pack_num = numel / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + InitOutValueKernel<<>>( + output_data, numel, init_value); +} + +template +__global__ void ActFFNGlu(const T *input, + T *output, + Functor act_functor, + const int token_num, + const int hid_dim, + const int elem_num) { + using LoadT = phi::AlignedVector; + LoadT src_vec1; + LoadT src_vec2; + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = global_tid * VecSize; i < elem_num; + i += gridDim.x * blockDim.x * VecSize) { + int bi = i / hid_dim; + int idx = i % hid_dim; + const T *input_this_thread = input + bi * hid_dim * 2; + T *output_this_thread = output + bi * hid_dim; + phi::Load(&input_this_thread[idx], &src_vec1); + phi::Load(&input_this_thread[idx + hid_dim], &src_vec2); +#pragma unroll + for (int j = 0; j < VecSize; j++) { + src_vec1[j] = act_functor(src_vec1[j]); + src_vec1[j] *= src_vec2[j]; + } + phi::Store(src_vec1, &output_this_thread[idx]); + } +} + +template +class FFNGluHelper { + public: + FFNGluHelper(const phi::GPUContext &dev_ctx, + const std::string &act_method, + int token_num, + int hid_dim, + int dim_ffn, + int dim_embed) + : dev_ctx_(dev_ctx), + act_method_(act_method), + token_num_(token_num), + hid_dim_(hid_dim), + dim_ffn_(dim_ffn), + dim_embed_(dim_embed) {} + + // dst = act(fc(src[0]) + bias) * src[1] + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, + const phi::DenseTensor *bias, + phi::DenseTensor *bias_out, + phi::DenseTensor *output) { + // input's shape [token_num, dim_ffn], bias' shape [dim_ffn] + // output's shape [token_num, hid_dim], bias_out's shape [token_num, + // dim_ffn] + auto ffn_linear_compute = AttnMatMul( + dev_ctx_, false, false, token_num_, dim_ffn_, dim_embed_, true); + ffn_linear_compute.ComputeForward(weight, input, bias, bias_out, bias_out); + + using Functor = GeluFunctor; + + Functor functor; + constexpr int VecSize = 16; + constexpr int PackSize = VecSize / sizeof(T); + const int elem_cnt = token_num_ * hid_dim_; + const int blocksize = 128; + int grid_size = 1; + switch (hid_dim_ % PackSize) { + case 0: + GetNumBlocks(elem_cnt / PackSize, &grid_size); + ActFFNGlu + <<>>( + bias_out->data(), + output->data(), + functor, + token_num_, + hid_dim_, + elem_cnt); + break; + default: + GetNumBlocks(elem_cnt, &grid_size); + ActFFNGlu + <<>>( + bias_out->data(), + output->data(), + functor, + token_num_, + hid_dim_, + elem_cnt); + break; + } + } + + private: + const phi::GPUContext &dev_ctx_; + std::string act_method_; + int token_num_; + int hid_dim_; + int dim_ffn_; + int dim_embed_; +}; + } // namespace } // namespace operators diff --git a/paddle/fluid/operators/fused/layernorm_quant_dequant.h b/paddle/fluid/operators/fused/layernorm_quant_dequant.h new file mode 100644 index 0000000000000..6d9fccbcb6e48 --- /dev/null +++ b/paddle/fluid/operators/fused/layernorm_quant_dequant.h @@ -0,0 +1,1076 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace operators { + +constexpr int kWarpSize = 32; + +template +struct SumOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; } +}; + +template +struct MaxOp { + __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); } +}; + +template class ReductionOp, typename T, int thread_group_width = kWarpSize> +__inline__ __device__ T WarpAllReduce(T val) { + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width)); + } + return val; +} + +template class ReductionOp, typename T, int block_size> +__inline__ __device__ T BlockAllReduce(T val) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T result_broadcast; + T result = BlockReduce(temp_storage).Reduce(val, ReductionOp()); + if (threadIdx.x == 0) { result_broadcast = result; } + __syncthreads(); + return result_broadcast; +} + +template +__inline__ __device__ T Div(T a, T b); + +template<> +__inline__ __device__ float Div(float a, float b) { +#ifdef OF_LAYER_NORM_USE_FAST_MATH + return __fdividef(a, b); +#else + return a / b; +#endif +} + +template<> +__inline__ __device__ double Div(double a, double b) { + return a / b; +} + +template +__inline__ __device__ T Rsqrt(T x); + +template<> +__inline__ __device__ float Rsqrt(float x) { +#ifdef OF_LAYER_NORM_USE_FAST_MATH + return __frsqrt_rn(x); +#else + return rsqrt(x); +#endif +} + +template<> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} + +template +inline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size, + int64_t max_blocks, int64_t waves, int* num_blocks) { + int dev; + { + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) { return err; } + } + int sm_count; + { + cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) { return err; } + } + int max_active_blocks; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, func, + block_size, dynamic_smem_size); + } + *num_blocks = + std::max(1, std::min(max_blocks, sm_count * max_active_blocks * waves)); + return cudaSuccess; +} + +template +struct DefaultComputeType { + using type = T; +}; + +template<> +struct DefaultComputeType { + using type = float; +}; + +#if CUDA_VERSION >= 11000 +template<> +struct DefaultComputeType { + using type = float; +}; +#endif // CUDA_VERSION >= 11000 + +template +class HasCanPackAs { + typedef char one; + struct two { + char x[2]; + }; + + template + static one test(decltype(&C::CanPackAs)); + template + static two test(...); + + public: + enum { value = sizeof(test(0)) == sizeof(char) }; +}; + +template +typename std::enable_if::value == true, bool>::type CanPackAs(T t, + size_t pack_size) { + return t.CanPackAs(pack_size); +} + +template +typename std::enable_if::value == false, bool>::type CanPackAs(T t, + size_t pack_size) { + return true; +} + +template +struct GetPackType { + using type = typename std::aligned_storage::type; +}; + +template +using PackType = typename GetPackType::type; + +template +union Pack { + static_assert(sizeof(PackType) == sizeof(T) * N, ""); + __device__ Pack() { + // do nothing + } + PackType storage; + T elem[N]; +}; + +template +struct DirectLoad { + using LoadType = DST; + DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {} + template + __device__ void load(DST* dst, int64_t row, int64_t col) const { + Pack pack; + const int64_t offset = (row * row_size + col) / N; + pack.storage = *(reinterpret_cast*>(src) + offset); +#pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); } + } + const SRC* src; + int64_t row_size; +}; + +template +struct DirectStore { + DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {} + template + __device__ void store(const SRC* src, int64_t row, int64_t col) { + Pack pack; + const int64_t offset = (row * row_size + col) / N; +#pragma unroll + for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast(src[i]); } + *(reinterpret_cast*>(dst) + offset) = pack.storage; + } + DST* dst; + int64_t row_size; +}; + +template +inline __device__ void WelfordCombine(T val, T* mean, T* m2, T* count) { + // Use Welford Online algorithem to compute mean and variance + // For more details you can refer to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + *count += 1; + T delta1 = val - *mean; + *mean += Div(delta1, *count); + T delta2 = val - *mean; + *m2 += delta1 * delta2; +} + +template +inline __device__ void WelfordCombine(T b_mean, T b_m2, T b_count, T* mean, T* m2, T* count) { + if (b_count == 0) { return; } + T new_count = *count + b_count; + T nb_over_n = Div(b_count, new_count); + T delta = b_mean - *mean; + *mean += delta * nb_over_n; + *m2 += b_m2 + delta * delta * (*count) * nb_over_n; + *count = new_count; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count, T* mean, + T* m2, T* count) { + *mean = thread_mean; + *m2 = thread_m2; + *count = thread_count; + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width); + T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width); + T b_count = __shfl_down_sync(0xffffffff, *count, mask, thread_group_width); + WelfordCombine(b_mean, b_m2, b_count, mean, m2, count); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean, + T* m2, T* count) { + WelfordWarpReduce(thread_mean, thread_m2, thread_count, mean, m2, count); + *mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width); + *m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width); + *count = __shfl_sync(0xffffffff, *count, 0, thread_group_width); +} + +template +__inline__ __device__ T WarpReduceSum(T x) { + T result = 0.0f; + #pragma unroll + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + result += __shfl_xor_sync(0xffffffff, x, mask, thread_group_width); + } + return result; +} + +template +__inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T thread_count, + T* result_mean, T* result_m2, T* result_count) { + __shared__ T mean_shared[kWarpSize]; + __shared__ T m2_shared[kWarpSize]; + __shared__ T count_shared[kWarpSize]; + __shared__ T mean_result_broadcast; + __shared__ T m2_result_broadcast; + __shared__ T count_result_broadcast; + const int lid = threadIdx.x % kWarpSize; + const int wid = threadIdx.x / kWarpSize; + T warp_mean = 0; + T warp_m2 = 0; + T warp_count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count); + __syncthreads(); + if (lid == 0) { + mean_shared[wid] = warp_mean; + m2_shared[wid] = warp_m2; + count_shared[wid] = warp_count; + } + __syncthreads(); + if (wid == 0) { + if (threadIdx.x < blockDim.x / kWarpSize) { + warp_mean = mean_shared[lid]; + warp_m2 = m2_shared[lid]; + warp_count = count_shared[lid]; + } else { + warp_mean = static_cast(0); + warp_m2 = static_cast(0); + warp_count = static_cast(0); + } + __syncwarp(); + T block_mean = 0; + T block_m2 = 0; + T block_count = 0; + WelfordWarpReduce(warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count); + if (lid == 0) { + mean_result_broadcast = block_mean; + m2_result_broadcast = block_m2; + count_result_broadcast = block_count; + } + } + __syncthreads(); + *result_mean = mean_result_broadcast; + *result_m2 = m2_result_broadcast; + *result_count = count_result_broadcast; +} + +template +__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + using LoadType = typename LOAD::LoadType; + static_assert(max_cols_per_thread % pack_size == 0, ""); + static_assert(min_cols_per_thread % pack_size == 0, ""); + static_assert(thread_group_width <= kWarpSize, ""); + static_assert(kWarpSize % thread_group_width == 0, ""); + constexpr int max_num_packs = max_cols_per_thread / pack_size; + constexpr int min_num_packs = min_cols_per_thread / pack_size; + assert(cols <= max_cols_per_thread * thread_group_width); + ComputeType buf[rows_per_access][max_cols_per_thread]; + const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; + const int64_t num_global_thread_group = gridDim.x * blockDim.y; + const int64_t lane_id = threadIdx.x; + const int64_t step = num_global_thread_group * rows_per_access; + for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) { + ComputeType thread_mean[rows_per_access]; + ComputeType thread_m2[rows_per_access]; + ComputeType thread_count[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + thread_mean[row_id] = 0; + thread_m2[row_id] = 0; + thread_count[row_id] = 0; + ComputeType* row_buf = buf[row_id]; +#pragma unroll + for (int pack_id = 0; pack_id < min_num_packs; ++pack_id) { + const int col = (pack_id * thread_group_width + lane_id) * pack_size; + const int pack_offset = pack_id * pack_size; + LoadType pack[pack_size]; + load.template load(pack, row + row_id, col); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + row_buf[pack_offset + i] = static_cast(pack[i]); + WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id, + thread_count + row_id); + } + } + for (int pack_id = min_num_packs; pack_id < max_num_packs; ++pack_id) { + const int col = (pack_id * thread_group_width + lane_id) * pack_size; + const int pack_offset = pack_id * pack_size; + if (!padding || col < cols) { + LoadType pack[pack_size]; + load.template load(pack, row + row_id, col); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + row_buf[pack_offset + i] = static_cast(pack[i]); + WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id, + thread_count + row_id); + } + } else { +#pragma unroll + for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = 0; } + } + } + } + ComputeType warp_mean[rows_per_access]; + ComputeType warp_m2[rows_per_access]; + ComputeType warp_count[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + int global_row_id = row + row_id; + ComputeType* row_buf = buf[row_id]; + WelfordWarpAllReduce( + thread_mean[row_id], thread_m2[row_id], thread_count[row_id], warp_mean + row_id, + warp_m2 + row_id, warp_count + row_id); + ComputeType row_mean = warp_mean[row_id]; + ComputeType row_variance = + max(Div(warp_m2[row_id], warp_count[row_id]), static_cast(0.0)); + ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); + if (lane_id == 0) { + mean[global_row_id] = row_mean; + inv_variance[global_row_id] = row_inv_var; + } +#pragma unroll + for (int i = 0; i < max_cols_per_thread; ++i) { + row_buf[i] = (row_buf[i] - row_mean) * row_inv_var; + } +#pragma unroll + for (int i = 0; i < min_num_packs; ++i) { + const int col = (i * thread_group_width + lane_id) * pack_size; + store.template store(row_buf + i * pack_size, global_row_id, col); + } +#pragma unroll + for (int i = min_num_packs; i < max_num_packs; ++i) { + const int col = (i * thread_group_width + lane_id) * pack_size; + if (!padding || col < cols) { + store.template store(row_buf + i * pack_size, global_row_id, col); + } + } + } + } +} + +template +inline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + constexpr int block_size = 128; + constexpr int waves = 32; + static_assert(block_size % thread_group_width == 0, ""); + constexpr int thread_groups_per_block = block_size / thread_group_width; + dim3 block_dim(thread_group_width, thread_groups_per_block); + const int64_t num_blocks = + (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block; + int grid_dim_x; + { + cudaError_t err = GetNumBlocks( + LayerNormWarpImpl, + block_size, 0, num_blocks, waves, &grid_dim_x); + if (err != cudaSuccess) { return err; } + } + LayerNormWarpImpl + <<>>(load, store, rows, cols, epsilon, mean, inv_variance); + return cudaPeekAtLastError(); +} + +template +inline cudaError_t DispatchLayerNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + if (cols == max_cols_per_thread * thread_group_width) { + // when not padding, min_cols_per_thread must equals to max_cols_per_thread, pass + // max_cols_per_thread as min_cols_per_thread and max_cols_per_thread param. + return LaunchLayerNormWarpImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } else { + return LaunchLayerNormWarpImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } +} + +template +typename std::enable_if::type DispatchLayerNormWarpImplCols( + cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, ComputeType* inv_variance) { + if (cols <= 0) { return cudaErrorInvalidValue; } +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + return DispatchLayerNormWarpImplPadding( \ + stream, load, store, rows, cols, epsilon, mean, inv_variance); \ + } else { \ + return DispatchLayerNormWarpImplPadding( \ + stream, load, store, rows, cols, epsilon, mean, inv_variance); \ + } \ + } + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(max_col, min_col) \ + else if (cols <= (max_col)*kWarpSize) { \ + return DispatchLayerNormWarpImplPadding(stream, load, store, rows, cols, \ + epsilon, mean, inv_variance); \ + } + DEFINE_ONE_ELIF(2, 1) + DEFINE_ONE_ELIF(4, 2) + DEFINE_ONE_ELIF(8, 4) + DEFINE_ONE_ELIF(12, 8) + DEFINE_ONE_ELIF(16, 12) + DEFINE_ONE_ELIF(20, 16) + DEFINE_ONE_ELIF(24, 20) + DEFINE_ONE_ELIF(28, 24) + +#undef DEFINE_ONE_ELIF + else { + return cudaErrorInvalidValue; + } +} + +template +typename std::enable_if::type DispatchLayerNormWarpImplCols( + cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, ComputeType* inv_variance) { + if (cols <= 0) { return cudaErrorInvalidValue; } +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + return DispatchLayerNormWarpImplPadding( \ + stream, load, store, rows, cols, epsilon, mean, inv_variance); \ + } else { \ + return DispatchLayerNormWarpImplPadding( \ + stream, load, store, rows, cols, epsilon, mean, inv_variance); \ + } \ + } + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(max_col, min_col) \ + else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \ + return DispatchLayerNormWarpImplPadding(stream, load, store, rows, cols, \ + epsilon, mean, inv_variance); \ + } + DEFINE_ONE_ELIF(4, 2) + DEFINE_ONE_ELIF(8, 4) + DEFINE_ONE_ELIF(12, 8) + DEFINE_ONE_ELIF(16, 12) + DEFINE_ONE_ELIF(20, 16) + DEFINE_ONE_ELIF(24, 20) + DEFINE_ONE_ELIF(28, 24) + DEFINE_ONE_ELIF(32, 28) + DEFINE_ONE_ELIF(48, 44) + DEFINE_ONE_ELIF(52, 48) +#undef DEFINE_ONE_ELIF + else { + return cudaErrorInvalidValue; + } +} + + +template +struct DispatchLayerNormWarpImplPackSize { + cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { + return DispatchLayerNormWarpImplCols( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } + + else { + return DispatchLayerNormWarpImplCols( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } + } +}; + +template +inline cudaError_t DispatchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + return DispatchLayerNormWarpImplPackSize()( + stream, load, store, rows, cols, epsilon, mean, inv_variance); +} + +template +__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, + ComputeType* mean, + ComputeType* inv_variance, + ComputeType col_divisor) { + using LoadType = typename LOAD::LoadType; + extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; + auto* buf = reinterpret_cast(shared_buf); + const int tid = threadIdx.x; + assert(cols % pack_size == 0); + const int num_packs = static_cast(cols) / pack_size; + for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { + ComputeType thread_sum = 0; + ComputeType thread_sum_square = 0; + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + LoadType pack[pack_size]; + load.template load(pack, row, pack_id * pack_size); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + buf[i * num_packs + pack_id] = pack[i]; + ComputeType pack_val = static_cast(pack[i]); + thread_sum += pack_val; + thread_sum_square += pack_val * pack_val; + } + } + + const ComputeType row_sum = BlockAllReduce(thread_sum); + const ComputeType row_sum_square = BlockAllReduce(thread_sum_square); + + // use multiply instead of divide. + ComputeType row_mean = row_sum * col_divisor; + ComputeType row_sum_square_mean = row_sum_square * col_divisor; + ComputeType row_variance = max(row_sum_square_mean - row_mean * row_mean, static_cast(0.0)); + ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); + if (threadIdx.x == 0) { + mean[row] = row_mean; + inv_variance[row] = row_inv_var; + } + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + ComputeType pack[pack_size]; +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] = (static_cast(buf[i * num_packs + pack_id]) - row_mean) * row_inv_var; + } + store.template store(pack, row, pack_id * pack_size); + } + } +} + + +template +inline cudaError_t LaunchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, + int smem, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance, + ComputeType col_divisor) { + constexpr int waves = 32; + int grid_dim_x; + { + cudaError_t err = + GetNumBlocks(LayerNormBlockSMemImpl, + block_size, smem, rows, waves, &grid_dim_x); + if (err != cudaSuccess) { return err; } + } + LayerNormBlockSMemImpl + <<>>(load, store, rows, cols, epsilon, mean, + inv_variance, col_divisor); + return cudaPeekAtLastError(); +} + +template +cudaError_t MaximizeDynamicSharedMemorySize(Func func, const int max_smem_size) { + cudaFuncAttributes attr{}; + cudaError_t err = cudaFuncGetAttributes(&attr, func); + if (err != cudaSuccess) { return err; } + constexpr int reserved_smem = 1024; // 1K + return cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_smem_size - attr.sharedSizeBytes - reserved_smem); +} + +template +inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize( + cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, ComputeType* inv_variance, ComputeType col_divisor, bool* success) { + constexpr int block_size_conf_1 = 128; + constexpr int block_size_conf_2 = 256; + constexpr int block_size_conf_3 = 512; + constexpr int block_size_conf_4 = 1024; + + int dev = 0; + { + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) { return err; } + } + + int sm_count = 0; + { + cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) { return err; } + } + + static const bool max_smem_configed = [=]() { + int max_smem_size = 0; + cudaError_t err = + cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + if (err != cudaSuccess) { return false; } + + err = MaximizeDynamicSharedMemorySize( + LayerNormBlockSMemImpl, + max_smem_size); + if (err != cudaSuccess) { return false; } + err = MaximizeDynamicSharedMemorySize( + LayerNormBlockSMemImpl, + max_smem_size); + if (err != cudaSuccess) { return false; } + err = MaximizeDynamicSharedMemorySize( + LayerNormBlockSMemImpl, + max_smem_size); + if (err != cudaSuccess) { return false; } + err = MaximizeDynamicSharedMemorySize( + LayerNormBlockSMemImpl, + max_smem_size); + if (err != cudaSuccess) { return false; } + + return true; + }(); + + const size_t smem = cols * sizeof(typename LOAD::LoadType); + + int max_active_blocks_conf_1; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_1, + LayerNormBlockSMemImpl, + block_size_conf_1, smem); + if (err != cudaSuccess) { return err; } + } + if (max_active_blocks_conf_1 <= 0) { + *success = false; + return cudaSuccess; + } + + int max_active_blocks_conf_4; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_4, + LayerNormBlockSMemImpl, + block_size_conf_4, smem); + if (err != cudaSuccess) { return err; } + } + + if (max_active_blocks_conf_4 == max_active_blocks_conf_1 + || (max_active_blocks_conf_4 > 0 && rows <= sm_count)) { + *success = true; + return LaunchLayerNormBlockSMemImpl( + stream, load, store, smem, rows, cols, epsilon, mean, inv_variance, col_divisor); + } + + int max_active_blocks_conf_3; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_3, + LayerNormBlockSMemImpl, + block_size_conf_3, smem); + if (err != cudaSuccess) { return err; } + } + if (max_active_blocks_conf_3 == max_active_blocks_conf_1 + || (max_active_blocks_conf_3 > 0 && rows <= sm_count)) { + *success = true; + return LaunchLayerNormBlockSMemImpl( + stream, load, store, smem, rows, cols, epsilon, mean, inv_variance, col_divisor); + } + + int max_active_blocks_conf_2; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_2, + LayerNormBlockSMemImpl, + block_size_conf_2, smem); + if (err != cudaSuccess) { return err; } + } + if (max_active_blocks_conf_2 == max_active_blocks_conf_1 + || (max_active_blocks_conf_2 > 0 && rows <= sm_count)) { + *success = true; + return LaunchLayerNormBlockSMemImpl( + stream, load, store, smem, rows, cols, epsilon, mean, inv_variance, col_divisor); + } + + *success = true; + return LaunchLayerNormBlockSMemImpl( + stream, load, store, smem, rows, cols, epsilon, mean, inv_variance, col_divisor); +} + +template +struct TryDispatchLayerNormBlockSMemImplPackSize { + cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance, + ComputeType col_divisor, bool* success) { + if (cols % 4 == 0 && CanPackAs(load, 4) && CanPackAs(store, 4)) { + return TryDispatchLayerNormBlockSMemImplBlockSize( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, success); + } else if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { + return TryDispatchLayerNormBlockSMemImplBlockSize( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, success); + } else { + return TryDispatchLayerNormBlockSMemImplBlockSize( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, success); + } + } +}; + +template +inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance, + ComputeType col_divisor, bool* success) { + return TryDispatchLayerNormBlockSMemImplPackSize()( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, success); +} + +template +__global__ void __launch_bounds__(1024) + LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, ComputeType* inv_variance) { + using LoadType = typename LOAD::LoadType; + const int tid = threadIdx.x; + assert(cols % pack_size == 0); + const int num_packs = static_cast(cols) / pack_size; + for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) { + ComputeType thread_mean = 0; + ComputeType thread_m2 = 0; + ComputeType thread_count = 0; + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + LoadType pack[pack_size]; + load.template load(pack, row, pack_id * pack_size); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + WelfordCombine(static_cast(pack[i]), &thread_mean, &thread_m2, &thread_count); + } + } + ComputeType row_mean = 0; + ComputeType row_m2 = 0; + ComputeType row_count = 0; + WelfordBlockAllReduce(thread_mean, thread_m2, thread_count, &row_mean, &row_m2, + &row_count); + ComputeType row_variance = max(Div(row_m2, row_count), static_cast(0.0)); + ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon)); + if (threadIdx.x == 0) { + mean[row] = row_mean; + inv_variance[row] = row_inv_var; + } + for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { + LoadType pack[pack_size]; + ComputeType dst_pack[pack_size]; + const int pack_offset = pack_id * pack_size; + load.template load(pack, row, pack_offset); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + dst_pack[i] = (static_cast(pack[i]) - row_mean) * row_inv_var; + } + store.template store(dst_pack, row, pack_offset); + } + } +} + +template +inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + constexpr int block_size = 1024; + constexpr int waves = 32; + int grid_dim_x; + { + cudaError_t err = + GetNumBlocks(LayerNormBlockUncachedImpl, + block_size, 0, rows, waves, &grid_dim_x); + if (err != cudaSuccess) { return err; } + } + LayerNormBlockUncachedImpl + <<>>(load, store, rows, cols, epsilon, mean, inv_variance); + return cudaPeekAtLastError(); +} + +template +struct DispatchLayerNormBlockUncachedImplPackSize { + cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + if (cols % 4 == 0 && CanPackAs(load, 4) && CanPackAs(store, 4)) { + return LaunchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } else if (cols % 2 == 0 && CanPackAs(load, 2) && CanPackAs(store, 2)) { + return LaunchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } else { + return LaunchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } + } +}; + +template +inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, + const int64_t rows, const int64_t cols, + const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + return DispatchLayerNormBlockUncachedImplPackSize()( + stream, load, store, rows, cols, epsilon, mean, inv_variance); +} + +template +inline typename std::enable_if::value, cudaError_t>::type +DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + const ComputeType col_divisor = 1.0f / cols; + if (cols <= 1024) { + return DispatchLayerNormWarpImpl(stream, load, store, rows, cols, + epsilon, mean, inv_variance); + } else { + bool dispatch_smem_impl_success; + { + cudaError_t err = TryDispatchLayerNormBlockSMemImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance, col_divisor, + &dispatch_smem_impl_success); + if (err != cudaSuccess) { return err; } + } + if (!dispatch_smem_impl_success) { + return DispatchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); + } + return cudaSuccess; + } +} + +template +inline typename std::enable_if::value, cudaError_t>::type +DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols, const double epsilon, ComputeType* mean, + ComputeType* inv_variance) { + return DispatchLayerNormBlockUncachedImpl( + stream, load, store, rows, cols, epsilon, mean, inv_variance); +} + +template +struct DequantSkipLoad { + using LoadType = DST; + DequantSkipLoad(const InputType* src, const SRC* bias, const SRC* skip, const float* dequant_scale, float alpha, int64_t row_size) + : src(src), bias(bias), skip(skip), dequant_scale(dequant_scale), alpha(alpha), row_size(row_size) {} + template + __device__ void load(DST* dst, int64_t row, int64_t col) const { + Pack src_pack; + Pack bias_pack; + Pack skip_pack; + Pack dequant_scale_pack; + const int64_t offset = (row * row_size + col) / N; + const int64_t bias_offset = col / N; + src_pack.storage = *(reinterpret_cast*>(src) + offset); + bias_pack.storage = *(reinterpret_cast*>(bias) + bias_offset); + skip_pack.storage = *(reinterpret_cast*>(skip) + offset); + dequant_scale_pack.storage = *(reinterpret_cast*>(dequant_scale) + bias_offset); // equal to col. +#pragma unroll + for (int i = 0; i < N; ++i) { + // First we need to cast src and dequant. + dst[i] = static_cast(static_cast(static_cast(src_pack.elem[i]) * dequant_scale_pack.elem[i]) + + bias_pack.elem[i] + + skip_pack.elem[i]); + } + } + const InputType* src; + const SRC* bias; + const SRC* skip; + const float* dequant_scale; + double alpha; + int64_t row_size; +}; + +template +__device__ __inline__ +T ClipFunc(const T v, const T min, const T max){ + if(v > max) return max; + if(v < min) return min; + return v; +} + +template +__forceinline__ __device__ OutType QuantHelperFunc(const InType input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * scale * input; + + if (round_type == 0) { + // quant_value = static_cast(roundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } +// quant_value = quant_value > max_bound ? max_bound : quant_value; +// quant_value = quant_value < min_bound ? min_bound : quant_value; +// return static_cast(quant_value); + return static_cast(ClipFunc(quant_value, min_bound, max_bound)); +} + +template +struct AffineQuantStore { + AffineQuantStore(OutType* y, + const int64_t row_size, + const float* gamma, const float* beta, + const float quant_out_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) : y(y), row_size(row_size), gamma(gamma), beta(beta), quant_round_type(quant_round_type), + quant_out_scale(quant_out_scale), quant_max_bound(quant_max_bound), quant_min_bound(quant_min_bound) {} + + template + __device__ void store(const SRC* src, int64_t row, int64_t col) { + Pack y_pack; + Pack gamma_pack; + Pack beta_pack; + // Pack out_pack; + const int64_t offset = (row * row_size + col) / N; + const int64_t gamma_offset = col / N; + gamma_pack.storage = *(reinterpret_cast*>(gamma) + gamma_offset); + beta_pack.storage = *(reinterpret_cast*>(beta) + gamma_offset); +#pragma unroll + for (int i = 0; i < N; ++i) { + float normalized_i = static_cast(src[i]); + float normalized_val = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i]; + if (do_scale) { + y_pack.elem[i] = QuantHelperFunc(normalized_val, quant_out_scale, quant_round_type, quant_max_bound, quant_min_bound); + } else { + y_pack.elem[i] = static_cast(normalized_val); + } + } + *(reinterpret_cast*>(y) + offset) = y_pack.storage; + } + + OutType* y; + int64_t row_size; + const float* gamma; + const float* beta; + const int quant_round_type; + const float quant_out_scale; + const float quant_max_bound; + const float quant_min_bound; +}; + +template +struct DequantSkipLoadAndStoreResidual { + using LoadType = DST; + // need to aseert SRC equals to DST. + DequantSkipLoadAndStoreResidual(const InputType* src, + const SRC* bias, + const SRC* skip, + const float* dequant_scale, + SRC* residual_bias_out, + float alpha, int64_t row_size) + : src(src), bias(bias), skip(skip), dequant_scale(dequant_scale), residual_bias_out(residual_bias_out), alpha(alpha), row_size(row_size) {} + template + __device__ void load(DST* dst, int64_t row, int64_t col) const { + Pack src_pack; + Pack bias_pack; + Pack skip_pack; + Pack dequant_scale_pack; + Pack residual_out_pack; + + const int64_t offset = (row * row_size + col) / N; + const int64_t bias_offset = col / N; + src_pack.storage = *(reinterpret_cast*>(src) + offset); + bias_pack.storage = *(reinterpret_cast*>(bias) + bias_offset); + skip_pack.storage = *(reinterpret_cast*>(skip) + offset); + dequant_scale_pack.storage = *(reinterpret_cast*>(dequant_scale) + bias_offset); // equal to col. +#pragma unroll + for (int i = 0; i < N; ++i) { + // First we need to cast src and dequant. + if (do_dequant) { + residual_out_pack.elem[i] = static_cast(static_cast(static_cast(src_pack.elem[i]) * dequant_scale_pack.elem[i]) + + bias_pack.elem[i] + + skip_pack.elem[i]); + } else { + residual_out_pack.elem[i] = static_cast(static_cast(src_pack.elem[i]) + bias_pack.elem[i] + + skip_pack.elem[i]); + } + } +#pragma unroll + for (int i = 0; i < N; ++i) { + dst[i] = residual_out_pack.elem[i]; + } + *(reinterpret_cast*>(residual_bias_out) + offset) = residual_out_pack.storage; + } + const InputType* src; + const SRC* bias; + const SRC* skip; + const float* dequant_scale; + SRC* residual_bias_out; + double alpha; + int64_t row_size; +}; + +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/fused/quant_dequant_kernel.h b/paddle/fluid/operators/fused/quant_dequant_kernel.h index 21b7b0f345466..bd490555a8b86 100644 --- a/paddle/fluid/operators/fused/quant_dequant_kernel.h +++ b/paddle/fluid/operators/fused/quant_dequant_kernel.h @@ -18,17 +18,85 @@ limitations under the License. */ #include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" namespace paddle { namespace operators { +using phi::backends::gpu::GpuLaunchConfig; + +constexpr int DequantKernelVecSize = 4; + + +template +struct QuantFunc{ + HOSTDEVICE int8_t operator()(const T x, const float scale, const float max_bound, + const float min_bound) { + float tmp = static_cast(x) * max_bound * scale; + tmp = round(tmp); + if (tmp > max_bound) + tmp = max_bound; + else if (tmp < min_bound) + tmp = min_bound; + return static_cast(tmp); + } +}; + +template +__global__ void QuantActKernel(const T* x, const int32_t rows, const int32_t cols, float scale, int8_t* quant_x, + const float max_bound, + const float min_bound) { + + using InVec = phi::AlignedVector; + using OutVec = phi::AlignedVector; + + const int stride = blockDim.x * gridDim.x * VecSize; + const int num_items = rows * cols; + + InVec in_vec; + OutVec out_vec; + for(int32_t linear_index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; linear_index < num_items; linear_index += stride){ + phi::Load(x + linear_index, &in_vec); + #pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = QuantFunc()(in_vec[i], scale, max_bound, min_bound); + } + phi::Store(out_vec, quant_x + linear_index); + } +} + + +template +void LaunchQuantActKernel(const T* x, const int32_t rows, const int32_t cols, int8_t* quant_x, float scale, + const float max_bound, const float min_bound, gpuStream_t stream) { + constexpr int NumThreads=256; + constexpr int VecSize= 16 / sizeof(T); + + constexpr int kNumWaves = 8; + int dev; + cudaGetDevice(&dev); + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + int tpm; + cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); + const int elem_cnt = rows*cols; + const int launch_elem_cnt = elem_cnt / VecSize; + const int grid_size = std::max(1, std::min((launch_elem_cnt + NumThreads - 1) / NumThreads, + sm_count * tpm / NumThreads * kNumWaves)); + + QuantActKernel<<>>(x, rows, cols, scale, quant_x, max_bound, min_bound); +} + + template __forceinline__ __device__ int8_t quant_helper(const T input, const float scale, const int round_type, const float max_bound, const float min_bound) { - float quant_value = max_bound * inverse(scale) * static_cast(input); + float quant_value = max_bound * scale * static_cast(input); + if (round_type == 0) { quant_value = static_cast(roundWithTiesToEven(quant_value)); } else { @@ -77,7 +145,7 @@ void quantize_kernel_launcher(const T* input, const float min_bound, gpuStream_t stream) { // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 - dim3 grid((n + 31) / 32, (m + 31) / 32); + dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); dim3 block(32, 32); quantize_kernel<<>>(input, @@ -90,47 +158,49 @@ void quantize_kernel_launcher(const T* input, min_bound); } -// dequantize using weight scales and input scales -template +template __global__ void dequantize_kernel(T* output, const int32_t* input, - const int m, // hidden - const int n, // batch size + const int m, // batch size + const int n, // hidden const float quant_in_scale, - const float* dequant_out_scale_data, - const int quant_out_scale_offset) { - int m_id = blockIdx.x * blockDim.x + threadIdx.x; // hidden - int n_id = blockIdx.y * blockDim.y + threadIdx.y; // batch size - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - float out_scale = dequant_out_scale_data[quant_out_scale_offset + m_id]; - output[n_id * m + m_id] = - static_cast(static_cast(input[n_id * m + m_id]) * - quant_in_scale / out_scale); + const float* dequant_out_scale_data) { + int numel = m * n; + int stride = blockDim.x * gridDim.x * VecSize; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int col_id = idx % n; + + phi::AlignedVector in_vec; + phi::AlignedVector out_scale_vec; + phi::AlignedVector out_vec; + + for (; idx < numel; idx += stride) { + phi::Load(input + idx, &in_vec); + phi::Load(dequant_out_scale_data + col_id, &out_scale_vec); + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = + static_cast(static_cast(in_vec[i]) * out_scale_vec[i]); + } + + phi::Store(out_vec, output + idx); } } template void dequantize_kernel_launcher(const int32_t* input, T* output, - const int batch_size, // m - const int hidden_units, // n + const int m, // m + const int n, // n gpuStream_t stream, + GpuLaunchConfig* gpu_config, const float quant_in_scale, - const float* dequant_out_scale_data, - const int quant_out_scale_offset) { - dim3 grid((hidden_units + 31) / 32, (batch_size + 31) / 32); - dim3 block(32, 32); - - dequantize_kernel<<>>(output, - input, - hidden_units, - batch_size, - quant_in_scale, - dequant_out_scale_data, - quant_out_scale_offset); + const float* dequant_out_scale_data) { + dequantize_kernel + <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( + output, input, m, n, quant_in_scale, dequant_out_scale_data); } } // namespace operators -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/platform/dynload/cublasLt.h b/paddle/fluid/platform/dynload/cublasLt.h index c3425ac604858..43a62ac2f9742 100644 --- a/paddle/fluid/platform/dynload/cublasLt.h +++ b/paddle/fluid/platform/dynload/cublasLt.h @@ -39,7 +39,7 @@ namespace dynload { extern DynLoad__##__name __name // APIs available after CUDA 10.1 -// #if CUDA_VERSION >= 10100 +#if CUDA_VERSION >= 11010 #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -61,7 +61,33 @@ namespace dynload { __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); \ __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif @@ -69,4 +95,4 @@ CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) #undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP } // namespace dynload } // namespace platform -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index af080bd0b3431..661e0edcd441d 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -61,7 +61,9 @@ std::map> op_ins_map = { "QKVW", "QKVBias", "CacheKV", + "BeamCacheOffset", "TimeStep", + "SeqLengths", "SrcMask", "OutLinearW", "OutLinearBias", @@ -71,12 +73,39 @@ std::map> op_ins_map = { "FFN1Bias", "FFN2Weight", "FFN2Bias"}}, + {"fused_multi_transformer_moe", + {"X", + "LnScale", + "LnBias", + "QKVW", + "QKVBias", + "CacheKV", + "BeamCacheOffset", + "TimeStep", + "SeqLengths", + "SrcMask", + "OutLinearW", + "OutLinearBias", + "GateWeight", + "GateBias", + "FFNLnScale", + "FFNLnBias", + "ExpertWeight1", + "ExpertBias1", + "ExpertWeight2", + "ExpertBias2"}}, {"fused_multi_transformer_int8", {"X", "LnScale", "LnBias", "QKVW", - "QKVBias", "CacheKV", "TimeStep", "SrcMask", + "QKVBias", "CacheKV", "BeamCacheOffset", "TimeStep", "SeqLengths", "SrcMask", "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias", "FFN1Weight", "FFN1Bias", "FFN2Weight", "FFN2Bias", "QKVOutScale", "OutLinearOutScale", "FFN1OutScale", "FFN2OutScale"}}, + {"fused_multi_transformer_moe_int8", + {"X", "LnScale", "LnBias", "QKVW", + "QKVBias", "CacheKV", "BeamCacheOffset", "TimeStep", "SeqLengths", "SrcMask", + "OutLinearW", "OutLinearBias", "GateWeight", "GateBias", "FFNLnScale", "FFNLnBias", + "ExpertWeight1", "ExpertBias1", "ExpertWeight2", "ExpertBias2", + "QKVOutScale", "OutLinearOutScale", "ExpertWeight1OutScale", "ExpertWeight2OutScale"}}, {"fused_bias_dropout_residual_layer_norm", {"X", "Residual", "Bias", "LnScale", "LnBias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, @@ -335,7 +364,9 @@ std::map> op_outs_map = { "Beta2PowOut", "MasterParamOut"}}, {"fused_multi_transformer", {"CacheKVOut", "Out"}}, + {"fused_multi_transformer_moe", {"CacheKVOut", "Out"}}, {"fused_multi_transformer_int8", {"CacheKVOut", "Out"}}, + {"fused_multi_transformer_moe_int8", {"CacheKVOut", "Out"}}, {"resnet_basic_block", {"Y", "Conv1", "SavedMean1", "SavedInvstd1", "Mean1Out", "Var1Out", "Conv2", "SavedMean2", "SavedInvstd2", "Mean2Out", @@ -440,7 +471,9 @@ std::map> op_passing_outs_map = { {"split", {"Out"}}, {"concat", {"Out"}}, {"fused_multi_transformer", {"CacheKVOut"}}, + {"fused_multi_transformer_moe", {"CacheKVOut"}}, {"fused_multi_transformer_int8", {"CacheKVOut"}}, + {"fused_multi_transformer_moe_int8", {"CacheKVOut"}}, {"group_norm", {"Mean", "Variance"}}, {"resnet_basic_block", {"Mean1Out", "Var1Out", "Mean2Out", "Var2Out", "Mean3Out", "Var3Out"}}, diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index d3cd00b3a541c..38bdacc3bcf6f 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -228,7 +228,10 @@ data_type : numbers - op : fused_moe_kernel - args : (Tensor x, Tensor gate_weight, Tensor gate_bias, Tensor ln_scale, Tensor ln_bias, Tensor[] experts_weight1, Tensor[] experts_bias1, Tensor[] experts_weight2, Tensor[] experts_bias2, bool pre_layer_norm, float ln_epsilon, int topk, int mp_size, int mp_rank, int num_expert, int world_size, int moe_ring_id, bool approximate) + args : (Tensor x, Tensor residual, Tensor gate_weight, Tensor gate_bias, Tensor ln_scale, Tensor ln_bias, + Tensor[] experts_weight1, Tensor[] experts_bias1, Tensor[] experts_weight2, Tensor[] experts_bias2, + bool pre_layer_norm, float ln_epsilon, int topk, int mp_size, int mp_rank, int num_expert, int world_size, + int moe_ring_id, bool approximate, int bsz, int seq_len, int d_model, int dim_feedforward) output : Tensor(out) infer_meta : func : FusedMoeInferMeta diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h index 90492ff4ba69d..d078feef9c28e 100644 --- a/paddle/phi/backends/dynload/cublasLt.h +++ b/paddle/phi/backends/dynload/cublasLt.h @@ -54,6 +54,7 @@ extern void *cublasLt_dso_handle; // APIs available after CUDA 10.1 // #if CUDA_VERSION >= 10100 +#if CUDA_VERSION >= 11010 #define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasLtCreate); \ __macro(cublasLtDestroy); \ @@ -75,11 +76,37 @@ extern void *cublasLt_dso_handle; __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); \ __macro(cublasLtMatmulAlgoInit); \ - __macro(cublasLtMatmulAlgoConfigSetAttribute); + __macro(cublasLtMatmulAlgoConfigSetAttribute); \ + __macro(cublasLtMatmulAlgoGetIds); \ + __macro(cublasLtMatmulAlgoCapGetAttribute); \ + __macro(cublasLtMatmulAlgoCheck); \ + __macro(cublasLtGetCudartVersion); +#else +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ + __macro(cublasLtMatrixTransformDescSetAttribute); +#endif CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) // #endif #undef DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP } // namespace dynload -} // namespace phi +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index 4a16480101a70..08975eca01948 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -174,13 +174,13 @@ void DestroyBlasHandle(blasHandle_t handle) { } void InitBlasLtHandle(blasLtHandle_t* blaslt_handle) { -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10020 phi::dynload::cublasLtCreate(blaslt_handle); #endif } void DestroyBlasLtHandle(blasLtHandle_t handle) { -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10020 if (handle != nullptr) { phi::dynload::cublasLtDestroy(handle); handle = nullptr; diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 0a45e1fb0530b..6340dbfd57401 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2949,6 +2949,7 @@ void GraphSendUVInferMeta(const MetaTensor& x, } void FusedMoeInferMeta(const MetaTensor& x, + const MetaTensor& residual, const MetaTensor& gate_weight, const MetaTensor& gate_bias, const MetaTensor& ln_scale, @@ -2966,6 +2967,10 @@ void FusedMoeInferMeta(const MetaTensor& x, int world_size, int moe_ring_id, bool approximate, + int bsz, + int seq_len, + int d_model, + int dim_feedforward, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(x.dtype()); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 948c5ca75ac4c..2d1a45f31c0ce 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -545,6 +545,7 @@ void GraphSendUVInferMeta(const MetaTensor& x, MetaTensor* out); void FusedMoeInferMeta(const MetaTensor& x, + const MetaTensor& residual, const MetaTensor& gate_weight, const MetaTensor& gate_bias, const MetaTensor& ln_scale, @@ -562,5 +563,9 @@ void FusedMoeInferMeta(const MetaTensor& x, int world_size, int moe_ring_id, bool approximate, + int bsz, + int seq_len, + int d_model, + int dim_feedforward, MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/fused_moe_kernel.h b/paddle/phi/kernels/fused_moe_kernel.h index dbafdf3918025..a4abd127beb9a 100644 --- a/paddle/phi/kernels/fused_moe_kernel.h +++ b/paddle/phi/kernels/fused_moe_kernel.h @@ -49,9 +49,110 @@ #endif namespace phi { +using Tensor = DenseTensor; namespace framework = paddle::framework; namespace platform = paddle::platform; +template +static void AllToAll(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllToAll(in_tensor, out_tensor); + task->Wait(); + VLOG(0) << "wait, all to all success !"; + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t send_numel = tensor.numel(); // send_numel + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + int nranks = comm->nranks(); + auto stream = ctx.stream(); + + framework::DDim x_dims = tensor.dims(); + framework::DDim out_dims(x_dims); + PADDLE_ENFORCE_EQ( + x_dims[0] % nranks, + 0, + platform::errors::InvalidArgument( + "The first dimension size (%d) of the input tensor must be " + "divisible by the number of ranks (%d).", + x_dims[0], + nranks)); + auto send_buf = tensor.data(); + auto recv_buf = out.mutable_data(out_dims, place); + size_t offset = 0; + send_numel /= nranks; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < nranks; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + offset += send_numel; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + +template +static void AllGather(Tensor& tensor, // NOLINT + Tensor& out, + const int ring_id, + const phi::GPUContext& ctx) { + if (ring_id == -1) return; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + + if (map->has(ring_id)) { + paddle::distributed::ProcessGroup* pg = map->get(ring_id); + auto pg_nccl = static_cast(pg); + + std::vector in_tensor; + std::vector out_tensor; + in_tensor.push_back(tensor); + out_tensor.push_back(out); + auto task = pg_nccl->AllGather(in_tensor, out_tensor, true, true); + task->Wait(); + } else { + auto dtype = platform::ToNCCLDataType( + framework::TransToProtoVarType(tensor.dtype())); + int64_t numel = tensor.numel(); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + auto stream = ctx.stream(); + auto out_dims = tensor.dims(); + int nranks = comm->nranks(); + out_dims[0] *= nranks; + out.mutable_data(out_dims, place); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( + tensor.data(), out.data(), numel, dtype, comm->comm(), stream)); + } +#else + PADDLE_THROW(platform::errors::Unimplemented( + "PaddlePaddle should compile with NCCL or RCCL when used tensor model " + "parallel op.")); +#endif +} + template void GlobalScatterFunctor(const phi::GPUContext& ctx, const framework::Tensor* x, @@ -142,6 +243,12 @@ void GlobalScatterFunctor(const phi::GPUContext& ctx, } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); } +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); + // VLOG(0) << "GlobalScatterFunctor cudaDeviceSynchronize success !"; +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif #else PADDLE_THROW( @@ -227,10 +334,12 @@ void GlobalScatterProcessGroupFunctor(const phi::GPUContext& ctx, } } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + // VLOG(0) << "GlobalScatterProcessGroupFunctor ncclGroupEnd " << i; } #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); + // VLOG(0) << "GlobalScatterProcessGroupFunctor cudaDeviceSynchronize success!"; #else PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); #endif @@ -337,6 +446,14 @@ void GlobalGatherFunctor(const phi::GPUContext& ctx, } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); + // VLOG(0) << "GlobalGatherFunctor cudaDeviceSynchronize success !"; +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + #else PADDLE_THROW( platform::errors::Unavailable("NCCL version >= 2.7.3 is needed.")); @@ -425,6 +542,7 @@ void GlobalGatherProcessGroupFunctor(const phi::GPUContext& ctx, #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); + // VLOG(0) << "GlobalGatherProcessGroupFunctor cudaDeviceSynchronize success !"; #else PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); #endif @@ -480,6 +598,7 @@ void MatMulAndAdd(const phi::GPUContext& dev_ctx, template void FusedMoeKernel(const DeviceContext& context, const DenseTensor& x, + const DenseTensor& residual, const DenseTensor& gate_weight, const DenseTensor& gate_bias, const DenseTensor& ln_scale, @@ -497,6 +616,10 @@ void FusedMoeKernel(const DeviceContext& context, int world_size, int moe_ring_id, bool approximate, + int bsz, + int seq_len, + int d_model, + int dim_feedforward, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu index 5652adfd1c50b..0deaa33c6adab 100644 --- a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu +++ b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu @@ -826,6 +826,7 @@ void invokeTopkSoftMax(const Context &dev_ctx, CASE_K(14); CASE_K(15); CASE_K(16); + CASE_K(50); default: PADDLE_THROW(paddle::platform::errors::Unimplemented( "beam_size = %d is unsupport!", beam_size)); diff --git a/paddle/phi/kernels/gpu/fused_moe_kernel.cu b/paddle/phi/kernels/gpu/fused_moe_kernel.cu index 0d72d7e3b058f..d9d5abab64e82 100644 --- a/paddle/phi/kernels/gpu/fused_moe_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_moe_kernel.cu @@ -19,108 +19,10 @@ using Tensor = DenseTensor; namespace framework = paddle::framework; namespace platform = paddle::platform; -template -static void AllToAll(Tensor& tensor, // NOLINT - Tensor& out, - const int ring_id, - const phi::GPUContext& ctx) { - if (ring_id == -1) return; -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - - if (map->has(ring_id)) { - paddle::distributed::ProcessGroup* pg = map->get(ring_id); - auto pg_nccl = static_cast(pg); - - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(out); - auto task = pg_nccl->AllToAll(in_tensor, out_tensor, true, true); - task->Wait(); - } else { - auto dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(tensor.dtype())); - int64_t send_numel = tensor.numel(); // send_numel - auto place = ctx.GetPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - int nranks = comm->nranks(); - auto stream = ctx.stream(); - - framework::DDim x_dims = tensor.dims(); - framework::DDim out_dims(x_dims); - PADDLE_ENFORCE_EQ( - x_dims[0] % nranks, - 0, - platform::errors::InvalidArgument( - "The first dimension size (%d) of the input tensor must be " - "divisible by the number of ranks (%d).", - x_dims[0], - nranks)); - auto send_buf = tensor.data(); - auto recv_buf = out.mutable_data(out_dims, place); - size_t offset = 0; - send_numel /= nranks; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto i = 0; i < nranks; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( - send_buf + offset, send_numel, dtype, i, comm->comm(), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( - recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); - offset += send_numel; - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - } -#else - PADDLE_THROW(platform::errors::Unimplemented( - "PaddlePaddle should compile with NCCL or RCCL when used tensor model " - "parallel op.")); -#endif -} - -template -static void AllGather(Tensor& tensor, // NOLINT - Tensor& out, - const int ring_id, - const phi::GPUContext& ctx) { - if (ring_id == -1) return; -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - - if (map->has(ring_id)) { - paddle::distributed::ProcessGroup* pg = map->get(ring_id); - auto pg_nccl = static_cast(pg); - - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(tensor); - out_tensor.push_back(out); - auto task = pg_nccl->AllGather(in_tensor, out_tensor, true, true); - task->Wait(); - } else { - auto dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(tensor.dtype())); - int64_t numel = tensor.numel(); - auto place = ctx.GetPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - auto stream = ctx.stream(); - auto out_dims = tensor.dims(); - int nranks = comm->nranks(); - out_dims[0] *= nranks; - out.mutable_data(out_dims, place); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( - tensor.data(), out.data(), numel, dtype, comm->comm(), stream)); - } -#else - PADDLE_THROW(platform::errors::Unimplemented( - "PaddlePaddle should compile with NCCL or RCCL when used tensor model " - "parallel op.")); -#endif -} - template void FusedMoeKernel(const DeviceContext& dev_ctx, const DenseTensor& x, + const DenseTensor& residual, const DenseTensor& gate_weight, const DenseTensor& gate_bias, const DenseTensor& ln_scale, @@ -138,18 +40,21 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, int world_size, int moe_ring_id, bool approximate, + int bsz, + int seq_len, + int d_model, + int dim_feedforward, DenseTensor* out) { using U = paddle::operators::LayerNormParamType; - // output - dev_ctx.template Alloc(out); // dim auto x_dim = x.dims(); - int bsz = x_dim[0]; - int seq_len = x_dim[1]; + // output + out->Resize(x_dim); + dev_ctx.template Alloc(out); + // auto out_dim = out->dims(); int bsz_seq = bsz * seq_len; - int d_model = x_dim[2]; int tot_expert = world_size * num_expert; - int dim_feedforward = experts_weight1[0]->dims()[1]; + // VLOG(0) << "moe, get dim: bsz_seq:" << bsz_seq << ", x.dim:" << x_dim << ", out.dim:" << out_dim; // pre_layer_norm const U* ln_scale_ptr = ln_scale.data(); @@ -165,6 +70,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, Tensor ln_out; ln_out.Resize({{bsz, seq_len, d_model}}); auto *ln_out_data = dev_ctx.template Alloc(&ln_out); + // VLOG(0) << "moe, alloc pre layer norm"; // after slice, bsz_seq should be change int sliced_bsz_seq = bsz_seq; int start = 0; @@ -228,8 +134,11 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, all_gather_out.Resize({{bsz_seq, d_model}}); dev_ctx.template Alloc(&all_gather_out); paddle::operators::DropoutParam dropout_param(false, 0, true, true, 0.0, nullptr, 0); + // for naccl comm + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); // step1 layer norm + // VLOG(0) << "moe, layer norm"; if (pre_layer_norm) { pre_layernorm_helper.LayerNorm(dev_ctx, x.data(), @@ -241,6 +150,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, } else { ln_out = x; } + // VLOG(0) << "moe, resize and slice ln_out"; // step2 resize and slice ln_out ln_out.Resize({{bsz_seq, d_model}}); if (mp_size > 1) { @@ -248,6 +158,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, } else { sliced_inp = ln_out; } + // VLOG(0) << "moe, gate & topk"; // step3 gate & topk MatMulAndAdd(dev_ctx, &gate_weight, @@ -268,8 +179,10 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, &topk_idx); // step4 prepare forward // step4.1 number count + // VLOG(0) << "moe, number count"; NumberCountKernel(dev_ctx, topk_idx, tot_expert, &local_expert_count); // step4.2 all_to_all + // VLOG(0) << "moe, all_to_all"; if (world_size > 1) { AllToAll(local_expert_count, global_expert_count, moe_ring_id, dev_ctx); } else { @@ -278,6 +191,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, // global expert count resize global_expert_count.Resize({{world_size, num_expert}}); // fwd expert count + // VLOG(0) << "moe, fwd expert count"; SumKernel(dev_ctx, global_expert_count, IntArray({0}), @@ -285,6 +199,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, false, &fwd_expert_count); // fwd batch size + // VLOG(0) << "moe, fwd batch size"; SumKernel(dev_ctx, fwd_expert_count, IntArray({}), // axis is None @@ -292,6 +207,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, false, &fwd_batch_size); // step4.3 cumsum & assign pos + // VLOG(0) << "moe, cumsum & assign pos"; CumsumKernel(dev_ctx, local_expert_count, Scalar(0), @@ -332,10 +248,12 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, // step 5, MOEScatter // step 5.1, index select // suppose tmp_pos->shape != [0] + // VLOG(0) << "moe, index select"; IndexSelectKernel(dev_ctx, sliced_inp, temp_pos, 0, &index_select_out); if (world_size > 1) { - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); // step 5.2, global_scatter + // VLOG(0) << "moe, global_scatter"; if (map->has(moe_ring_id)) { GlobalScatterProcessGroupFunctor(dev_ctx, &index_select_out, @@ -358,6 +276,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, } // step 6, Expert Computation + // VLOG(0) << "moe, Expert Computation"; if (fwd_bsz != 0) { int last_index = 0; for (int idx = 0; idx < num_expert; idx++) { @@ -380,6 +299,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, Tensor tmp_inp = global_scatter_out.Slice(last_index, end); // linear1 matmul + // VLOG(0) << "moe, Expert Computation, linear1 mul"; MatMulAndAdd(dev_ctx, experts_weight1[idx], &tmp_inp, @@ -390,6 +310,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, &expert_out1, nullptr); // bias gelu + // VLOG(0) << "moe, Expert Computation, add bias & gelu"; fused_act_dropout_helper.DropoutActBias(dev_ctx, expert_out1.data(), experts_bias1[idx]->data(), @@ -405,6 +326,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, -127.0, approximate); // linear2 matmul & add + // VLOG(0) << "moe, Expert Computation, linear2 matmul & add"; MatMulAndAdd(dev_ctx, experts_weight2[idx], &act_bias_out, @@ -423,8 +345,9 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, all_expert_out = global_scatter_out; } // step7. MOEGather + // VLOG(0) << "moe, MOEGather"; if (world_size > 1) { - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + // auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); // step 7.1, global_gather if (map->has(moe_ring_id)) { GlobalGatherProcessGroupFunctor(dev_ctx, @@ -448,6 +371,7 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, } // step 7.2, local_gather or scatter // suppose pos->shape != [0] + // VLOG(0) << "moe, local_gather or scatter"; ScatterKernel(dev_ctx, moe_gather_out, pos, @@ -456,11 +380,13 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, &moe_gather_out); // step 8, reshape & bmm // moe gather out reshape + // VLOG(0) << "moe, reshape & bmm"; moe_gather_out.Resize({{sliced_bsz_seq, topk, d_model}}); topk_value.Resize({{sliced_bsz_seq, 1, topk}}); BmmKernel(dev_ctx, topk_value, moe_gather_out, &bmm_out); bmm_out.Resize({{sliced_bsz_seq, d_model}}); // step 9, AllGather + // VLOG(0) << "moe, AllGather"; if (mp_size > 1) { // all gather AllGather(bmm_out, all_gather_out, moe_ring_id, dev_ctx); @@ -468,9 +394,20 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, all_gather_out = bmm_out; } // step 10, reshape + // VLOG(0) << "moe, reshape"; all_gather_out.Resize(x_dim); // step 11, add residual - AddKernel(dev_ctx, all_gather_out, x, out); + // VLOG(0) << "moe, add residual"; + AddKernel(dev_ctx, all_gather_out, residual, out); + if (!pre_layer_norm) { + pre_layernorm_helper.LayerNorm(dev_ctx, + out->data(), + ln_scale_ptr, + ln_bias_ptr, + out->data(), + ln_mean_data, + ln_variance_data); + } } } // namespace phi diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index b2767b1dd1cbf..c081bffbef993 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -168,6 +168,7 @@ def _update_list(self): 'fused_feedforward', 'fused_attention', 'fused_multi_transformer', + 'fused_multi_transformer_moe', } # The set of ops that don't support fp16 calculation diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index b23c94c7e4994..8884caca96ba6 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -110,7 +110,7 @@ def _keep_fp32_input(op, in_name): return in_name in { 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" } - if op_type == 'fused_multi_transformer': + if op_type in ['fused_multi_transformer', 'fused_multi_transformer_moe']: return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'} return False diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 98efe71421cf4..6ff2f9dc46c43 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -192,6 +192,7 @@ def pure_fp16_initialize(models): if isinstance(layer, (paddle.incubate.nn.FusedFeedForward, paddle.incubate.nn.FusedMultiHeadAttention, paddle.incubate.nn.FusedMultiTransformer, + paddle.incubate.nn.FusedMultiTransformerMoe, paddle.incubate.nn.FusedMoELayer)): layer._amp_decorate(dtype='float16') continue diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index 2a2def22bb3bf..c5ca8e38215d7 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -18,13 +18,15 @@ from .layer.fused_transformer import FusedMultiTransformer # noqa: F401 from .layer.fused_linear import FusedLinear # noqa: F401 from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401 -from .layer.fused_transformer import FusedMoELayer # tianyan01 add +from .layer.fused_transformer import FusedMoELayer # noqa: F401 +from .layer.fused_transformer import FusedMultiTransformerMoe # noqa: F401 __all__ = [ #noqa 'FusedMultiHeadAttention', 'FusedFeedForward', 'FusedTransformerEncoderLayer', 'FusedMultiTransformer', + 'FusedMultiTransformerMoe', 'FusedLinear', 'FusedBiasDropoutResidualLayerNorm', 'FusedMoELayer', diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index a39f1cb94c0c5..cd925f2a3df2e 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -847,6 +847,8 @@ def fused_multi_transformer( pre_layer_norm=True, epsilon=1e-05, cache_kvs=None, + beam_offset=None, + seq_lens=None, time_step=None, attn_mask=None, dropout_rate=0.0, @@ -1006,7 +1008,9 @@ def fused_multi_transformer( list(qkv_weights), list(qkv_biases), cache_kvs, + beam_offset, time_step, + seq_lens, attn_mask, list(linear_weights), list(linear_biases), @@ -1054,6 +1058,8 @@ def fused_multi_transformer( inputs['LnScale'] = list(ln_scales) inputs['LnBias'] = list(ln_biases) inputs['QKVW'] = list(qkv_weights) + if seq_lens is not None: + inputs['SeqLengths'] = seq_lens if qkv_biases is not None: inputs['QKVBias'] = list(qkv_biases) if cache_kvs is not None: @@ -1061,6 +1067,8 @@ def fused_multi_transformer( inputs['CacheKV'] = cache_kvs if time_step is not None: inputs['TimeStep'] = time_step + if beam_offset is not None: + inputs['BeamCacheOffset'] = beam_offset inputs['SrcMask'] = attn_mask inputs['OutLinearW'] = list(linear_weights) if linear_biases is not None: diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index b4f173fc725d3..d66097cdf51a1 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1379,7 +1379,7 @@ def get_attr(attrs, idx): self.activation = activation self.name = name - def forward(self, src, attn_mask=None, caches=None, time_step=None): + def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None): """ Applies multi transformer layers on the input. @@ -1429,7 +1429,9 @@ def forward(self, src, attn_mask=None, caches=None, time_step=None): pre_layer_norm=self.normalize_before, epsilon=self._epsilon, cache_kvs=caches, + beam_offset=beam_offset, time_step=time_step, + seq_lens=seq_lens, attn_mask=attn_mask, dropout_rate=self.dropout_rate, activation=self.activation, @@ -1514,6 +1516,7 @@ def __init__(self, self.mp_rank = mp_group.rank self.mp_size = mp_group.nranks self.d_model = d_model + self.dim_feedforward = dim_feedforward self.top_k = top_k self.approximate = approximate self.ln_scale = self.create_parameter( @@ -1586,7 +1589,10 @@ def get_attr(attrs, idx): self.linear2_biases[i].name = "expert_" + self.linear2_biases[i].name def forward(self, inp): - inp = _C_ops.fused_moe_kernel( + bsz = inp.shape[0] + seq_len = inp.shape[1] + out = _C_ops.fused_moe_kernel( + inp, inp, self.gate_weight, self.gate_bias, @@ -1604,9 +1610,13 @@ def forward(self, inp): self.num_expert, self.world_size, -1 if self.group is None else self.group.id, - self.approximate + self.approximate, + bsz, + seq_len, + self.d_model, + self.dim_feedforward ) - return inp + return out def _amp_decorate(self, dtype): # tmp fix for amp.decorator(O2) @@ -1622,3 +1632,329 @@ def trans_to_fp16(l): _ = _to_dtype(self.gate_weight, dtype) _ = _to_dtype(self.gate_bias, dtype) self._dtype = dtype + + +class FusedMultiTransformerMoe(Layer): + """ + FusedMultiTransformerMoe + """ + def __init__( + self, + d_model, + embed_dim, + num_heads, + dim_feedforward, + dropout_rate=0.0, + activation="gelu", + normalize_before=True, + ln_scale_attrs=None, + ln_bias_attrs=None, + qkv_weight_attrs=None, + qkv_bias_attrs=None, + linear_weight_attrs=None, + linear_bias_attrs=None, + gate_weight_attrs=None, + gate_bias_attrs=None, + ffn_ln_scale_attrs=None, + ffn_ln_bias_attrs=None, + expert_weight1_attrs=None, + expert_bias1_attrs=None, + expert_weight2_attrs=None, + expert_bias2_attrs=None, + epsilon=1e-5, + num_layers=-1, + nranks=1, + trans_qkvw=True, + ring_id=-1, + num_expert=1, + top_k=2, + approximate=True, + moe_group=None, + mp_group=None, + name=None, + ): + super(FusedMultiTransformerMoe, self).__init__() + assert embed_dim > 0, ( + "Expected embed_dim to be greater than 0, " + "but received {}".format(embed_dim) + ) + assert ( + num_heads > 0 + ), "Expected nhead to be greater than 0, " "but received {}".format( + num_heads + ) + assert ( + dim_feedforward > 0 + ), "Expected dim_feedforward to be greater than 0, but received {}".format( + dim_feedforward + ) + # only support mp/dp + # for moe config + self.group = moe_group + self.world_size = 1 + if self.group is not None: + self.world_size = self.group.nranks + self.num_expert = num_expert + + self.mp_rank = 0 + self.mp_size = 1 + if mp_group is not None and mp_group.nranks > 1: + self.mp_rank = mp_group.rank + self.mp_size = mp_group.nranks + self.top_k = top_k + self.approximate = approximate + + # origin fmt config + self.normalize_before = normalize_before + self._dtype = self._helper.get_default_dtype() + self._epsilon = epsilon + self._trans_qkvw = trans_qkvw + self._ring_id = ring_id + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + # tensor model parallel + if nranks > 1: + assert ring_id != -1 + assert num_heads % nranks == 0 + num_heads = num_heads // nranks + + if isinstance(qkv_weight_attrs, (list, tuple, ParameterList)): + num_layers = len(qkv_weight_attrs) + assert num_layers > 0 + + self.ln_scales, self.ln_biases = ParameterList(), ParameterList() + self.qkv_weights, self.qkv_biases = ParameterList(), ParameterList() + self.linear_weights, self.linear_biases = ParameterList(), ParameterList() + self.gate_weights, self.gate_biases = ParameterList(), ParameterList() + self.ffn_ln_scales, self.ffn_ln_biases = ParameterList(), ParameterList() + self.expert_weights1, self.expert_biases1 = ParameterList(), ParameterList() + self.expert_weights2, self.expert_biases2 = ParameterList(), ParameterList() + def get_attr(attrs, idx): + if isinstance(attrs, (list, tuple, ParameterList)): + assert len(attrs) == num_layers + return attrs[idx] + return attrs + + for i in range(num_layers): + ln_scale_attr = get_attr(ln_scale_attrs, i) + ln_bias_attr = get_attr(ln_bias_attrs, i) + qkv_weight_attr = get_attr(qkv_weight_attrs, i) + qkv_bias_attr = get_attr(qkv_bias_attrs, i) + linear_weight_attr = get_attr(linear_weight_attrs, i) + linear_bias_attr = get_attr(linear_bias_attrs, i) + + ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i) + ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i) + gate_weight_attr = get_attr(gate_weight_attrs, i) + gate_bias_attr = get_attr(gate_bias_attrs, i) + + ln_scale = self.create_parameter( + attr=ln_scale_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0), + dtype="float32", + ) + ln_bias = self.create_parameter( + attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32" + ) + qkv_weight = self.create_parameter( + shape=[3, num_heads, self.head_dim, embed_dim] + if trans_qkvw + else [embed_dim, 3, num_heads, self.head_dim], + attr=qkv_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=qkv_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + linear_weight = self.create_parameter( + shape=[num_heads * self.head_dim, embed_dim], + attr=linear_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + linear_bias = self.create_parameter( + shape=[embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + + ffn_ln_scale = self.create_parameter( + shape=[embed_dim], + attr=ffn_ln_scale_attr, + is_bias=False, + default_initializer=Constant(1.0), + dtype="float32", + ) + ffn_ln_bias = self.create_parameter( + shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True, dtype="float32" + ) + gate_weight = self.create_parameter( + shape=[d_model, num_expert * self.world_size], + attr=gate_weight_attr, + dtype=self._dtype, + is_bias=False + ) + gate_bias = self.create_parameter( + shape=[num_expert * self.world_size], + attr=gate_bias_attr, + dtype=self._dtype, + is_bias=True + ) + + # tensor model parallel + if nranks > 1: + # column parallel + _set_var_distributed(qkv_weight) + _set_var_distributed(qkv_bias) + # row parallel + _set_var_distributed(linear_weight) + + self.ln_scales.append(ln_scale) + self.ln_biases.append(ln_bias) + self.qkv_weights.append(qkv_weight) + self.qkv_biases.append(qkv_bias) + self.linear_weights.append(linear_weight) + self.linear_biases.append(linear_bias) + + self.ffn_ln_scales.append(ffn_ln_scale) + self.ffn_ln_biases.append(ffn_ln_bias) + self.gate_weights.append(gate_weight) + self.gate_biases.append(gate_bias) + + for j in range(num_expert): + expert_weight1_attr = get_attr(expert_weight1_attrs, i * num_expert + j) + expert_bias1_attr = get_attr(expert_bias1_attrs, i * num_expert + j) + expert_weight2_attr = get_attr(expert_weight2_attrs, i * num_expert + j) + expert_bias2_attr = get_attr(expert_bias2_attrs, i * num_expert + j) + + expert_weight1 = self.create_parameter( + shape=[d_model, dim_feedforward], + attr=expert_weight1_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + ) + expert_bias1 = self.create_parameter( + shape=[dim_feedforward], + attr=expert_bias1_attr, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + ) + expert_weight2 = self.create_parameter( + shape=[dim_feedforward, d_model], + attr=expert_weight2_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + ) + expert_bias2 = self.create_parameter( + shape=[d_model], + attr=expert_bias2_attr, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + ) + expert_weight1.name = "expert_" + expert_weight1.name + expert_bias1.name = "expert_" + expert_bias1.name + expert_weight2.name = "expert_" + expert_weight2.name + expert_bias2.name = "expert_" + expert_bias2.name + self.expert_weights1.append(expert_weight1) + self.expert_biases1.append(expert_bias1) + self.expert_weights2.append(expert_weight2) + self.expert_biases2.append(expert_bias2) + self.dropout_rate = dropout_rate + self.activation = activation + self.name = name + + def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None): + """ + forward + """ + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer_moe( + src, + list(self.ln_scales), + list(self.ln_biases), + list(self.qkv_weights), + list(self.qkv_biases), + caches, + beam_offset, + time_step, + seq_lens, + attn_mask, + list(self.linear_weights), + list(self.linear_biases), + list(self.gate_weights), + list(self.gate_biases), + list(self.ffn_ln_scales), + list(self.ffn_ln_biases), + list(self.expert_weights1), + list(self.expert_biases1), + list(self.expert_weights2), + list(self.expert_biases2), + caches, + 'pre_layer_norm', + self.normalize_before, + 'epsilon', + self._epsilon, + 'dropout_rate', + self.dropout_rate, + 'is_test', + not self.training, + 'dropout_implementation', + 'upscale_in_train', + 'act_method', + self.activation, + 'trans_qkvw', + self._trans_qkvw, + 'ring_id', + self._ring_id, + 'topk', + self.top_k, + 'mp_size', + self.mp_size, + 'mp_rank', + self.mp_rank, + 'num_expert', + self.num_expert, + 'world_size', + self.world_size, + 'moe_ring_id', + -1 if self.group is None else self.group.id, + 'approximate', + self.approximate + ) + if caches is not None: + return final_out, cache_kv_out + return final_out + + def _amp_decorate(self, dtype): + # tmp fix for amp.decorator(O2) + def trans_to_fp16(l): + for param in l: + if param is not None: + with no_grad(): + param_applied = _to_dtype(param, dtype) + trans_to_fp16(self.qkv_weights) + trans_to_fp16(self.qkv_biases) + trans_to_fp16(self.linear_weights) + trans_to_fp16(self.linear_biases) + trans_to_fp16(self.gate_weights) + trans_to_fp16(self.gate_biases) + trans_to_fp16(self.expert_weights1) + trans_to_fp16(self.expert_biases1) + trans_to_fp16(self.expert_weights2) + trans_to_fp16(self.expert_biases2) + self._dtype = dtype \ No newline at end of file From 19dc46926917e7d0cb12687df71afd28f2408fe9 Mon Sep 17 00:00:00 2001 From: tianyan01 Date: Fri, 17 Nov 2023 11:39:55 +0800 Subject: [PATCH 6/9] rm some logs --- .../fluid/operators/fused/fused_multi_transformer_moe_op.cu | 3 --- paddle/phi/kernels/fused_moe_kernel.h | 6 ------ paddle/phi/kernels/gpu/fused_moe_kernel.cu | 4 +--- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu index 91fa104498bf4..6e6b41dd6ab74 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu @@ -429,7 +429,6 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { } else { sliced_inp = buf0; } - VLOG(0) << "sliced_inp shape[0]: " << sliced_inp.dims()[0]; #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "moe, gate & topk"; #endif @@ -825,12 +824,10 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { // phi::AddKernel(dev_ctx, all_gather_out, bias_dropout_residual_out, &moe_out); x_data = moe_out.data(); - VLOG(0) << "layer " << i + 1 << " end"; } // layers loop end moe_out.Resize({{bsz, seq_len, dim_embed}}); *out = moe_out; - VLOG(0) << "kernel end"; } }; diff --git a/paddle/phi/kernels/fused_moe_kernel.h b/paddle/phi/kernels/fused_moe_kernel.h index a4abd127beb9a..d6a1f112c762c 100644 --- a/paddle/phi/kernels/fused_moe_kernel.h +++ b/paddle/phi/kernels/fused_moe_kernel.h @@ -72,7 +72,6 @@ static void AllToAll(Tensor& tensor, // NOLINT out_tensor.push_back(out); auto task = pg_nccl->AllToAll(in_tensor, out_tensor); task->Wait(); - VLOG(0) << "wait, all to all success !"; } else { auto dtype = platform::ToNCCLDataType( framework::TransToProtoVarType(tensor.dtype())); @@ -245,7 +244,6 @@ void GlobalScatterFunctor(const phi::GPUContext& ctx, } #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); - // VLOG(0) << "GlobalScatterFunctor cudaDeviceSynchronize success !"; #else PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); #endif @@ -334,12 +332,10 @@ void GlobalScatterProcessGroupFunctor(const phi::GPUContext& ctx, } } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - // VLOG(0) << "GlobalScatterProcessGroupFunctor ncclGroupEnd " << i; } #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); - // VLOG(0) << "GlobalScatterProcessGroupFunctor cudaDeviceSynchronize success!"; #else PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); #endif @@ -449,7 +445,6 @@ void GlobalGatherFunctor(const phi::GPUContext& ctx, #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); - // VLOG(0) << "GlobalGatherFunctor cudaDeviceSynchronize success !"; #else PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); #endif @@ -542,7 +537,6 @@ void GlobalGatherProcessGroupFunctor(const phi::GPUContext& ctx, #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); - // VLOG(0) << "GlobalGatherProcessGroupFunctor cudaDeviceSynchronize success !"; #else PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); #endif diff --git a/paddle/phi/kernels/gpu/fused_moe_kernel.cu b/paddle/phi/kernels/gpu/fused_moe_kernel.cu index d9d5abab64e82..657f53b9e29a8 100644 --- a/paddle/phi/kernels/gpu/fused_moe_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_moe_kernel.cu @@ -54,7 +54,6 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, // auto out_dim = out->dims(); int bsz_seq = bsz * seq_len; int tot_expert = world_size * num_expert; - // VLOG(0) << "moe, get dim: bsz_seq:" << bsz_seq << ", x.dim:" << x_dim << ", out.dim:" << out_dim; // pre_layer_norm const U* ln_scale_ptr = ln_scale.data(); @@ -70,7 +69,6 @@ void FusedMoeKernel(const DeviceContext& dev_ctx, Tensor ln_out; ln_out.Resize({{bsz, seq_len, d_model}}); auto *ln_out_data = dev_ctx.template Alloc(&ln_out); - // VLOG(0) << "moe, alloc pre layer norm"; // after slice, bsz_seq should be change int sliced_bsz_seq = bsz_seq; int start = 0; @@ -418,4 +416,4 @@ PD_REGISTER_KERNEL(fused_moe_kernel, phi::FusedMoeKernel, float, double, - paddle::platform::float16) {} \ No newline at end of file + paddle::platform::float16) {} From e54fcce2384d7d4018b9f2dd3a1259f3e6335f0e Mon Sep 17 00:00:00 2001 From: wanglipeng Date: Sun, 5 Nov 2023 22:16:23 +0800 Subject: [PATCH 7/9] trie search --- cmake/cuda.cmake | 5 +- paddle/fluid/framework/CMakeLists.txt | 7 + paddle/fluid/framework/trie.cc | 160 ++++++++++++++++++++++ paddle/fluid/framework/trie.h | 167 +++++++++++++++++++++++ paddle/fluid/framework/trie_manager.cc | 156 +++++++++++++++++++++ paddle/fluid/framework/trie_manager.cu | 38 ++++++ paddle/fluid/framework/trie_manager.h | 147 ++++++++++++++++++++ paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/trie_search_op.cc | 72 ++++++++++ paddle/fluid/operators/trie_search_op.cu | 133 ++++++++++++++++++ paddle/fluid/operators/trie_search_op.h | 42 ++++++ paddle/fluid/pybind/box_helper_py.cc | 33 +++++ paddle/fluid/pybind/box_helper_py.h | 1 + paddle/fluid/pybind/pybind.cc | 1 + python/paddle/fluid/layers/nn.py | 35 +++++ 15 files changed, 996 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/framework/trie.cc create mode 100644 paddle/fluid/framework/trie.h create mode 100644 paddle/fluid/framework/trie_manager.cc create mode 100644 paddle/fluid/framework/trie_manager.cu create mode 100644 paddle/fluid/framework/trie_manager.h create mode 100644 paddle/fluid/operators/trie_search_op.cc create mode 100644 paddle/fluid/operators/trie_search_op.cu create mode 100644 paddle/fluid/operators/trie_search_op.h diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index ab48a16529e63..578e4979636ef 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -28,7 +28,8 @@ elseif(NEW_RELEASE_JIT) else() set(paddle_known_gpu_archs "35 50 52 60 61 70 75 80") set(paddle_known_gpu_archs10 "35 50 52 60 61 70 75") - set(paddle_known_gpu_archs11 "52 60 61 70 75 80") + #set(paddle_known_gpu_archs11 "52 60 61 70 75 80") + set(paddle_known_gpu_archs11 "70 80") endif() ###################################################################################### @@ -161,7 +162,7 @@ function(select_nvcc_arch_flags out_variable) if(WITH_NV_JETSON) set(cuda_arch_bin "72") else() - set(cuda_arch_bin "70") + set(cuda_arch_bin "70 80") endif() elseif(${CUDA_ARCH_NAME} STREQUAL "Turing") set(cuda_arch_bin "75") diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 6e3897717596b..eb39600660dd7 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -41,6 +41,13 @@ endif() proto_library(framework_proto SRCS framework.proto) proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto) +if(WITH_GPU) + nv_library( + trie_manager + SRCS trie_manager.cc trie_manager.cu trie.cc + DEPS tensor device_context math_function) +endif() + proto_library(op_def_proto SRCS op_def.proto DEPS framework_proto) cc_library( op_def_api diff --git a/paddle/fluid/framework/trie.cc b/paddle/fluid/framework/trie.cc new file mode 100644 index 0000000000000..3ef91c6849f58 --- /dev/null +++ b/paddle/fluid/framework/trie.cc @@ -0,0 +1,160 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "trie.h" +#include "paddle/phi/core/enforce.h" + +namespace paddle { +namespace framework { +#define ENFORCE PADDLE_ENFORCE + +// Trie +int Trie::load(const std::string& dir, const uint32_t thr_num) { + std::string list_file = dir + "/file_list"; + std::ifstream ifs(list_file.c_str()); + if (!ifs.is_open()) { + printf("open file %s failed\n", list_file.c_str()); + return -1; + } + + std::vector files; + std::string line; + uint32_t node_num = 0; + while (getline(ifs, line)) { + std::stringstream ss(line); + File file; + ss >> file.filename; + ss >> file.node_num; + + file.filename = dir + "/" + file.filename; + file.node_off = node_num; + node_num += file.node_num; + files.emplace_back(std::move(file)); + } + printf("total file_num: %zu, node_num: %u\n", files.size(), node_num); + + { + ScopedNanoTimer t("Trie stat"); + parallel_run_range(files.size(), thr_num, + [this, &files](uint32_t thr_id, uint32_t start, uint32_t end) { + for (uint32_t i = start; i < end; ++i) { + stat_file(thr_id, files.at(i)); + } + }); + } + + Node root; + for (auto& file: files) { + root.child.insert(root.child.end(), file.root.begin(), file.root.end()); + } + + { + ScopedNanoTimer t("Trie resize"); + resize(node_num + 1); // +1 for root + + size_t off = root.child.size(); + for (size_t i = 0; i < files.size(); ++i) { + mem_off(files[i].node_off + 1) = off; //+1 for root + ENFORCE(files[i].node_num >= files[i].root.size()); + off += files[i].node_num - files[i].root.size(); + } + ENFORCE(off == node_num); + } + + { + ScopedNanoTimer t("Trie load"); + parallel_run_range(files.size(), thr_num, + [this, &files](uint32_t thr_id, uint32_t start, uint32_t end) { + for (size_t i = start; i < end; ++i) { + load_file(thr_id, files.at(i)); + } + }); + } + + link(root); + + return 0; +} + +void Trie::parse(std::string& line, Node& node, uint32_t off) { + node.child.clear(); + + char* str = const_cast(line.c_str()); + char* endstr = nullptr; + size_t len = 0; + + node.id = std::strtoul(str, &endstr, 10) + off; + str = endstr; + ENFORCE(*str == '\t'); + ++str; + + node.label = std::strtoul(str, &endstr, 10); + str = endstr; + ENFORCE(*str == '\t'); + ++str; + + len = std::strtoul(str, &endstr, 10); + str = endstr; + for (size_t k = 0; k < len; ++k) { + node.child.push_back(std::strtoul(str, &endstr, 10) + off); + ENFORCE(str != endstr); + str = endstr; + ++str; + } + + node.aleaf = std::strtoul(str, &endstr, 10); + str = endstr; + ENFORCE(*str == '\0'); +} + +void Trie::stat_file(uint32_t thr_id, File& file) { + printf("stat file %s\n", file.filename.c_str()); + Node node; + + std::ifstream ifs(file.filename.c_str()); + ENFORCE(ifs.is_open(), "open file %s failed\n", file.filename.c_str()); + + std::string line; + getline(ifs, line); + + parse(line, node, file.node_off); + file.root = std::move(node.child); +} + +void Trie::load_file(uint32_t thr_id, File& file) { + printf("load file %s\n", file.filename.c_str()); + + std::ifstream ifs(file.filename.c_str()); + ENFORCE(ifs.is_open(), "open file %s failed\n", file.filename.c_str()); + + Node node; + std::string line; + // don't link root + if (getline(ifs, line)) { + parse(line, node, file.node_off); + file.root = std::move(node.child); + } + + while(getline(ifs, line)) { + parse(line, node, file.node_off); + link(node); + } +} + +#undef ENFORCE +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/trie.h b/paddle/fluid/framework/trie.h new file mode 100644 index 0000000000000..e0d72d55ae10a --- /dev/null +++ b/paddle/fluid/framework/trie.h @@ -0,0 +1,167 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { + +template +void parallel_run_range(uint32_t n, uint32_t thr_num, T&& func) { + std::vector threads; + for (size_t i = 0; i < thr_num; ++i) { + threads.emplace_back(std::thread([i, n, thr_num, &func]() { + func(i, n * i / thr_num, n * (i + 1) / thr_num); + })); + } + for (auto& t : threads) { + t.join(); + } +} + +class ScopedNanoTimer { +public: + ScopedNanoTimer(const std::string& n) : t0(std::chrono::high_resolution_clock::now()), m(n) { + }; + ~ScopedNanoTimer() { + auto t1 = std::chrono::high_resolution_clock::now(); + auto nanos = std::chrono::duration_cast(t1-t0).count(); + printf("%s cost %fs\n", m.c_str(), nanos/1000000000.0); + } + +protected: + std::chrono::high_resolution_clock::time_point t0; + std::string m; +}; + +class Trie { +struct File { + std::string filename; + std::vector root; + uint32_t node_num = 0; + uint32_t node_off = 0; +}; + +struct Node { + uint32_t id = 0; + uint16_t label = 0; + std::vector child; + uint8_t aleaf = 0; +}; + +public: + Trie() {} + virtual ~Trie() {} + int load(const std::string& dir, const uint32_t thr_num=20u); + + uint16_t label(uint32_t id) { + return label_.at(id); + } + + uint8_t aleaf(uint32_t id) { + return aleaf_.at(id); + } + + void child(uint32_t id, std::vector& child) { + child.clear(); + size_t s = mem_off(id); + size_t e = mem_off(id + 1); + for (size_t i = s; i < e; ++i) { + child.push_back(child_mem_.at(i)); + } + } + + size_t child_size(uint32_t id) { + size_t s = mem_off(id); + size_t e = mem_off(id + 1); + + return e - s; + } + + size_t child_at(uint32_t id, size_t i) { + size_t s = mem_off(id); + + return child_mem_.at(s + i); + } + + void print() { + // level order traversal + std::queue q; + q.push(0); + std::vector child; + + while(!q.empty()) { + size_t len = q.size(); + for (size_t i = 0; i < len; ++i) { + uint32_t id = q.front(); + q.pop(); + + printf("[#%u,%u,%u,<", id, label(id), aleaf(id)); + this->child(id, child); + for (auto j : child) { + q.push(j); + printf("#%u,", j); + } + printf(">] "); + } + printf("\n"); + } + } + +protected: + void resize(uint32_t node_num) { + label_.resize(node_num); + aleaf_.resize(node_num); + child_mem_.resize(node_num); + mem_off_.resize(node_num + 1, 0); + } + + uint32_t& mem_off(uint32_t id) { + return mem_off_.at(id); + } + + void link(const Node& node) { + label_.at(node.id) = node.label; + aleaf_.at(node.id) = node.aleaf; + + uint32_t addr = mem_off(node.id); + for (size_t i = 0; i < node.child.size(); ++i) { + child_mem_.at(addr++) = node.child[i]; + } + if (mem_off(node.id + 1) == 0) { + mem_off(node.id + 1) = addr; + } + } + + void parse(std::string& line, Node& node, uint32_t off=0); + void load_file(uint32_t thr_id, File& file); + void stat_file(uint32_t thr_id, File& file); + + std::vector label_; + std::vector aleaf_; + std::vector child_mem_; + std::vector mem_off_; +}; + +} // end namespace framework +} // end namespace paddle \ No newline at end of file diff --git a/paddle/fluid/framework/trie_manager.cc b/paddle/fluid/framework/trie_manager.cc new file mode 100644 index 0000000000000..1842b4e7badc1 --- /dev/null +++ b/paddle/fluid/framework/trie_manager.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/trie_manager.h" +#include + +namespace paddle { +namespace framework { +std::shared_ptr TrieManager::_s_instance = nullptr; + +void TrieManager::reset() { + VLOG(3) << "trie reset..."; + std::unique_lock lock(mtx_); + + // Todo: cache level1 + parent_idx_.mutable_data({1}, phi::GPUPinnedPlace()); + int64_t* parent_idx = parent_idx_.data(); + parent_idx[0] = 0; + + select_ids_.mutable_data({1}, phi::GPUPinnedPlace()); + int64_t* select_ids = select_ids_.data(); + select_ids[0] = 0; + + label2node_.resize(1); + label2node_[0][0] = 0; + + phase_ = Phase::run; + cv_.notify_one(); +} + +void TrieManager::search_start(const Tensor* d_parent, const Tensor* d_select) { + VLOG(3) << "trie search_start: " << d_parent->dims() << " # " << d_select->dims(); + if (d_parent->numel() == 0) { + return; + } + + std::unique_lock lock(mtx_); + + TensorCopySync(*d_parent, phi::GPUPinnedPlace(), &parent_idx_); + TensorCopySync(*d_select, phi::GPUPinnedPlace(), &select_ids_); + + phase_ = Phase::run; + cv_.notify_one(); +} + +void TrieManager::search_wait() { + VLOG(3) << "trie search_wait"; + + std::unique_lock lock(mtx_); + while (phase_ != Phase::done) { + cv_.wait(lock); + } +} + +void TrieManager::run() { + uint32_t thr_num = 10; + + for (;;) { + std::unique_lock lock(mtx_); + while (phase_ != Phase::run) { + cv_.wait(lock); + if (phase_ == Phase::stop) { + VLOG(3) << "phase_ == stop and exit"; + return; + } + } + VLOG(3) << "trie run, parent_idx: " << parent_idx_ << "\n, select_ids: " << select_ids_; + VLOG(3) << "label2node_ size: " << label2node_.size(); + + // 1. + int numel = parent_idx_.numel(); + PADDLE_ENFORCE(numel == select_ids_.numel()); + int64_t* parent_idx = parent_idx_.data(); + int64_t* select_ids = select_ids_.data(); + + std::vector> label2node(numel); + std::vector> outs(numel); + parallel_run_range(numel, thr_num, [this, parent_idx, select_ids, &outs, &label2node] ( + uint32_t thr_id, uint32_t start, uint32_t end) { + for (size_t i = start; i < end; ++i) { + auto& out = outs.at(i); + + int64_t idx = parent_idx[i]; + if (label2node_.size() == 1) { + idx = 0; + } + + auto& l2n_ = label2node_.at(idx); + auto& l2n = label2node.at(i); + auto it = l2n_.find(select_ids[i]); + if (it == l2n_.end()) { + out.push_back(endid_); + l2n.insert({endid_, 0}); + continue; + } + + size_t chs = trie_.child_size(it->second); + if (chs == 0 || trie_.aleaf(it->second)) { + out.push_back(endid_); + l2n.insert({endid_, 0}); + } + + for (size_t j = 0; j < chs; ++j) { + uint32_t cid = trie_.child_at(it->second, j); + uint32_t lab = trie_.label(cid); + + out.push_back(lab); + l2n.insert({lab, cid}); + } + } + }); + label2node_.swap(label2node); + + numel = 0; + for (size_t i = 0; i < outs.size(); ++i) { + numel += outs[i].size(); + } + + auto next_out = next_out_.mutable_data({numel, 1}, phi::GPUPinnedPlace()); + auto next_lod = next_lod_.mutable_data({int(outs.size())+1, 1}, phi::GPUPinnedPlace()); + + // 2. + next_lod[0] = 0; + int k = 0; + for (size_t i = 0; i < outs.size(); ++i) { + for (size_t j = 0; j < outs[i].size(); ++j) { + next_out[k] = outs[i][j]; + ++k; + } + next_lod[i+1] = next_lod[i] + int64_t(outs[i].size()); + } + + VLOG(3) << "out " << next_out_ << "\n lod " << next_lod_; + + // 3. + TensorCopySync(next_out_, place_, &next_out_d_); + TensorCopySync(next_lod_, place_, &next_lod_d_); + + phase_ = Phase::done; + cv_.notify_one(); + } +} + +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/trie_manager.cu b/paddle/fluid/framework/trie_manager.cu new file mode 100644 index 0000000000000..5b2dfe7e095fa --- /dev/null +++ b/paddle/fluid/framework/trie_manager.cu @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/trie_manager.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace framework { +const int CUDA_NUM_THREADS = platform::PADDLE_CUDA_NUM_THREADS; +#define GET_BLOCK(N) ((N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS) + +#define CUDA_KERNEL_LOOP(i, n) \ + for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) +#define CUDA_BLOCK(N) GET_BLOCK(N), CUDA_NUM_THREADS, 0 + +__global__ void fill_mask_kernel(const int len, const int N, const int64_t* pos, int64_t* arr) { + CUDA_KERNEL_LOOP(i, len) { + int x = pos[i*2]; + int y = pos[i*2+1]; + + arr[x * N + y] = 1; + } +} + +} // end namespace framework +} // end namespace paddle \ No newline at end of file diff --git a/paddle/fluid/framework/trie_manager.h b/paddle/fluid/framework/trie_manager.h new file mode 100644 index 0000000000000..456fa06a49f1d --- /dev/null +++ b/paddle/fluid/framework/trie_manager.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "trie.h" + +namespace paddle { +namespace framework { + +template +std::string PrintTensor(const framework::Tensor& tensor, + int64_t print_len = 2000) { + std::stringstream sstream; + + if (print_len == -1) { + print_len = tensor.numel(); + } + print_len = std::min(tensor.numel(), print_len); + + const T* data = nullptr; + if (platform::is_cpu_place(tensor.place())) { + data = tensor.data(); + } else { + framework::Tensor cpu_tensor; + platform::CPUPlace cpu_place; + TensorCopySync(tensor, cpu_place, &cpu_tensor); + data = cpu_tensor.data(); + } + + sstream << "\n - dtype: " << tensor.dtype(); + sstream << "\n - place: " << tensor.place(); + sstream << "\n - dims: " << tensor.dims(); + sstream << "\n - data: ["; + if (print_len > 0) { + sstream << data[0]; + for (int64_t i = 1; i < print_len; ++i) { + sstream << " " << data[i]; + } + } + sstream << "]" << std::endl; + + return sstream.str(); +} + +class TrieManager { +enum class Phase { + init, + run, + done, + stop +}; + +public: + TrieManager(uint16_t endid) : endid_(endid), + place_(platform::GetCurrentDeviceId()) { + thread_ = std::thread(&TrieManager::run, this); + } + + ~TrieManager() { + { + std::unique_lock lock(mtx_); + phase_ = Phase::stop; + } + cv_.notify_all(); + if (thread_.joinable()) { + thread_.join(); + } + } + + static std::shared_ptr GetInstance() { + PADDLE_ENFORCE_EQ( + _s_instance == nullptr, + false, + platform::errors::PreconditionNotMet( + "GetInstance failed in TrieManager, you should use SetInstance firstly")); + return _s_instance; + } + + static std::shared_ptr SetInstance(uint16_t endid) { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (nullptr == _s_instance) { + VLOG(3) << "TrieManager _s_instance is null"; + _s_instance.reset(new TrieManager(endid)); + } else { + LOG(WARNING) << "You have already used TrieManager SetInstance() before"; + } + + return _s_instance; + } + + int load(const std::string& dir, const uint32_t thr_num=20) { + return trie_.load(dir, thr_num); + } + void reset(); + void search_start(const Tensor* d_parent, const Tensor* d_select); + void search_wait(); + + // gpu + Tensor next_out_d_; + Tensor next_lod_d_; + +protected: + static std::shared_ptr _s_instance; + + // cpu + Tensor parent_idx_; + Tensor select_ids_; + std::vector> label2node_; + + // cpu + Tensor next_out_; + Tensor next_lod_; + + Trie trie_; + size_t endid_; + size_t vocab_size_; + + std::mutex mtx_; + std::condition_variable cv_; + std::thread thread_; + Phase phase_ = Phase::init; + + platform::CUDAPlace place_; + + void run(); +}; + +} // end namespace framework +} // end namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ac1d89ede5021..4fd94be9e509d 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -158,7 +158,7 @@ sequence_pooling executor device_memory_aligment generator) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc_functor matrix_inverse matrix_solve) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper ps_gpu_wrapper) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper trie_manager ps_gpu_wrapper) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function) if (WITH_GPU OR WITH_ROCM) diff --git a/paddle/fluid/operators/trie_search_op.cc b/paddle/fluid/operators/trie_search_op.cc new file mode 100644 index 0000000000000..6325389eb121f --- /dev/null +++ b/paddle/fluid/operators/trie_search_op.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/trie_search_op.h" + +namespace paddle { +namespace operators { + +class TrieSearchStartOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->ShareDim("ids", "Out"); + ctx->ShareLoD("ids", "Out"); + } +}; + +class TrieSearchStartOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("parent_idx", "(Tensor), The parent_idx input tensor of trie_search_start op."); + AddInput("ids", "(Tensor), The ids input tensor of trie_search_start op."); + AddOutput("Out", "(Tensor), The output tensor of trie_search_start op."); + AddComment(R"DOC(trie_search_start)DOC"); + } +}; + +class TrieSearchWaitOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound( + "Input(X) of TrieSearchWaitOp should not be null.")); + + ctx->ShareDim("X", "Out"); + ctx->ShareLoD("X", "Out"); + } +}; + +class TrieSearchWaitOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of trie_search_wait op."); + AddOutput("Out", "(Tensor), The output tensor of trie_search_wait op."); + AddComment(R"DOC(trie_search_wait)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(trie_search_start, ops::TrieSearchStartOp, + ops::TrieSearchStartOpMaker) +REGISTER_OP_WITHOUT_GRADIENT(trie_search_wait, ops::TrieSearchWaitOp, + ops::TrieSearchWaitOpMaker) + +REGISTER_OP_CPU_KERNEL(trie_search_start, ops::TrieSearchStartCPUKernel) +REGISTER_OP_CPU_KERNEL(trie_search_wait, ops::TrieSearchWaitCPUKernel) \ No newline at end of file diff --git a/paddle/fluid/operators/trie_search_op.cu b/paddle/fluid/operators/trie_search_op.cu new file mode 100644 index 0000000000000..482212ca9cbd1 --- /dev/null +++ b/paddle/fluid/operators/trie_search_op.cu @@ -0,0 +1,133 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/trie_search_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/fluid/framework/trie_manager.h" + +namespace paddle { +namespace operators { + +const int CUDA_NUM_THREADS = platform::PADDLE_CUDA_NUM_THREADS; +#define GET_BLOCK(N) ((N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS) + +#define CUDA_KERNEL_LOOP(i, n) \ + for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) +#define CUDA_BLOCK(N) GET_BLOCK(N), CUDA_NUM_THREADS, 0 + +__global__ void MaskKernel(const int64_t len, const int64_t* lod, const int64_t* idx, int64_t C, int64_t N, float* arr) { + CUDA_KERNEL_LOOP(i, len) { + int low = 0; + int high = C - 1; + while (low < high) { + int mid = (low + high) / 2; + if (i < lod[mid + 1]) { + high = mid; + } else { + low = mid + 1; + } + } + + int x = low * N + idx[i]; + arr[x] = 0; + } +} + +__global__ void MaskKernel1D(const int64_t total, const int64_t idx_len, const int64_t* idx, int64_t C,int64_t N, float* arr) { + CUDA_KERNEL_LOOP(i, total) { + int64_t c = i / idx_len; + int64_t x = c * N + idx[i % idx_len]; + arr[x] = 0; + } +} + +template +class TrieSearchStartCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + VLOG(3) << "TrieSearchStartCUDAKernel"; + auto s = framework::TrieManager::GetInstance(); + const framework::Tensor* idx = context.Input("parent_idx"); + const framework::Tensor* ids =context.Input("ids"); + s->search_start(idx, ids); + + framework::Tensor* out = context.Output("Out"); + + framework::TensorCopy(*ids, ids->place(), out); + } +}; + +template +class TrieSearchWaitCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + VLOG(3) << "TrieSearchWaitCUDAKernel"; + + auto s = framework::TrieManager::GetInstance(); + auto place = context.GetPlace(); + auto gpu_ctx = dynamic_cast( + platform::DeviceContextPool::Instance().Get(place)); + + const framework::Tensor* mask = context.Input("X"); // b,s,V + framework::Tensor* out = context.Output("Out"); // b,s,V + s->search_wait(); // b,s,x + + auto dims = mask->dims(); + PADDLE_ENFORCE_EQ( + dims.size(), 2, + platform::errors::InvalidArgument( + "Input X'dim size should be equal to 2. " + "But received X's shape = [%s].", + dims)); + + const int C = dims[0]; + const int N = dims[1]; + + int64_t* idx = s->next_out_d_.data(); + int64_t* lod = s->next_lod_d_.data(); // b,s,x + + VLOG(3) << "mask" << framework::PrintTensor(*mask, 100); + VLOG(3) << "idx" << framework::PrintTensor(s->next_out_d_); + VLOG(3) << "lod" << framework::PrintTensor(s->next_lod_d_); + + framework::TensorCopy(*mask, mask->place(), out); + int64_t len = s->next_out_d_.numel(); + float* arr = out->data(); + + if (s->next_lod_d_.numel() == 2) { + MaskKernel1D<<stream()>>>(len*C, len, idx, C, N, arr); + } else { + PADDLE_ENFORCE_EQ( + C+1, + s->next_lod_d_.numel(), + platform::errors::InvalidArgument("C+1 != lod %s %s", C+1, s->next_lod_d_.numel())); + MaskKernel<<stream()>>>(len, lod, idx, C, N, arr); + } + + gpu_ctx->Wait(); + VLOG(3) << "out" << framework::PrintTensor(*out, -1); + // std::cout << "out " << framework::PrintTensor(*out, -1); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL(trie_search_start, + ops::TrieSearchStartCUDAKernel) +REGISTER_OP_CUDA_KERNEL(trie_search_wait, + ops::TrieSearchWaitCUDAKernel) \ No newline at end of file diff --git a/paddle/fluid/operators/trie_search_op.h b/paddle/fluid/operators/trie_search_op.h new file mode 100644 index 0000000000000..514dac196fd69 --- /dev/null +++ b/paddle/fluid/operators/trie_search_op.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { + +template +class TrieSearchStartCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Not implemented TrieSearchStartCPUKernel.")); + } +}; + +template +class TrieSearchWaitCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Not implemented TrieSearchWaitCPUKernel.")); + } +}; + +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/pybind/box_helper_py.cc b/paddle/fluid/pybind/box_helper_py.cc index 227c9dedba1f8..a68ffe903ac4f 100644 --- a/paddle/fluid/pybind/box_helper_py.cc +++ b/paddle/fluid/pybind/box_helper_py.cc @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/fleet/box_wrapper.h" #include "paddle/fluid/pybind/box_helper_py.h" +#include "paddle/fluid/framework/trie_manager.h" #ifdef PADDLE_WITH_BOX_PS #include #endif @@ -113,5 +114,37 @@ void BindBoxWrapper(py::module* m) { } // end BoxWrapper #endif +void BindTrieManager(py::module* m) { + py::class_>( + *m, "Trie") + .def(py::init([]() { + return std::make_shared(); })) + .def("load", + &framework::Trie::load, + py::arg("dir"), + py::arg("thr_num")=20u, + py::call_guard()) + .def("label", + &framework::Trie::label, + py::call_guard()) + .def("aleaf", + &framework::Trie::aleaf, + py::call_guard()); + + py::class_>( + *m, "TrieManager") + .def(py::init([](int endid) { + // return std::make_shared(endid); })) + return framework::TrieManager::SetInstance(endid); })) + .def("load", + &framework::TrieManager::load, + py::arg("dir"), + py::arg("thr_num")=20u, + py::call_guard()) + .def("reset", + &framework::TrieManager::reset, + py::call_guard()); +} // end TrieManager + } // end namespace pybind } // end namespace paddle diff --git a/paddle/fluid/pybind/box_helper_py.h b/paddle/fluid/pybind/box_helper_py.h index 7bc36516c6580..f3ef1d1150547 100644 --- a/paddle/fluid/pybind/box_helper_py.h +++ b/paddle/fluid/pybind/box_helper_py.h @@ -26,6 +26,7 @@ void BindBoxHelper(py::module* m); #ifdef PADDLE_WITH_BOX_PS void BindBoxWrapper(py::module* m); #endif +void BindTrieManager(py::module* m); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 67f0d9cc8eb85..34a9d0281eba5 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2536,6 +2536,7 @@ All parameter, weight, gradient are variables in Paddle. #ifdef PADDLE_WITH_BOX_PS BindBoxWrapper(&m); #endif +BindTrieManager(&m); #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) BindNCCLWrapper(&m); #endif diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 49180f8c9670f..dc819c94d05a6 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -16657,3 +16657,38 @@ def unbind(input, axis=0): attrs={"axis": axis}, ) return outs + +def trie_search_start(parent_idx, ids): + """ + """ + helper = LayerHelper("trie_search_start", **locals()) + + out = helper.create_variable_for_type_inference(dtype=ids.dtype) + + helper.append_op( + type="trie_search_start", + inputs={ + "parent_idx": parent_idx, + "ids": ids + }, + outputs={"Out": out} + ) + + return out + +def trie_search_wait(mask): + """ + """ + helper = LayerHelper("trie_search_wait", **locals()) + + out = helper.create_variable_for_type_inference(dtype=mask.dtype) + + helper.append_op( + type="trie_search_wait", + inputs={ + "X": mask + }, + outputs={"Out": out}, + ) + + return out \ No newline at end of file From 456ab5bc7e67b7ac131c7a74d8b1c6790846a5e7 Mon Sep 17 00:00:00 2001 From: wanglipeng Date: Tue, 28 Nov 2023 14:34:13 +0800 Subject: [PATCH 8/9] beam search support length_penalty --- paddle/phi/api/yaml/ops.yaml | 2 +- paddle/phi/infermeta/multiary.cc | 1 + paddle/phi/infermeta/multiary.h | 1 + .../phi/kernels/fusion/beam_search_softmax.h | 1 + .../kernels/fusion/gpu/beam_search_softmax.cu | 28 +++++++++++++++---- python/paddle/tensor/search.py | 6 ++-- 6 files changed, 30 insertions(+), 9 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 38bdacc3bcf6f..904cd7de62e39 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -210,7 +210,7 @@ backward : flip_grad - op : beam_search_softmax - args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop) + args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop, float length_penalty=0.0) output : Tensor(ids_this_time), Tensor(out_cum_scores), Tensor(cache_ids), Tensor(beam_offsets), Tensor(parent_idx), Tensor(stop_flags_out), Tensor(seq_lens_out), Tensor(step_ids_out) infer_meta : func : BeamSearchSoftmaxInferMeta diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 6340dbfd57401..d134759bb62c5 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -687,6 +687,7 @@ void BeamSearchSoftmaxInferMeta(const MetaTensor& logits, int max_dec_len, bool fuse_softmax, bool early_stop, + float length_penalty, MetaTensor* ids_this_time, MetaTensor* out_cum_scores, MetaTensor* cache_ids, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 2d1a45f31c0ce..01018f3ef7cd9 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -203,6 +203,7 @@ void BeamSearchSoftmaxInferMeta(const MetaTensor& logits, int max_dec_len, bool fuse_softmax, bool early_stop, + float length_penalty, MetaTensor* ids_this_time, MetaTensor* out_cum_scores, MetaTensor* cache_ids, diff --git a/paddle/phi/kernels/fusion/beam_search_softmax.h b/paddle/phi/kernels/fusion/beam_search_softmax.h index 985515a0a39f3..6970286f0370f 100644 --- a/paddle/phi/kernels/fusion/beam_search_softmax.h +++ b/paddle/phi/kernels/fusion/beam_search_softmax.h @@ -34,6 +34,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx, int max_dec_len, bool fuse_softmax, bool early_stop, + float length_penalty, DenseTensor *ids_this_time, DenseTensor *out_cum_scores, DenseTensor *cache_ids, diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu index 0deaa33c6adab..cea15b1e3a6ab 100644 --- a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu +++ b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu @@ -55,6 +55,7 @@ namespace fusion { max_dec_len, \ fuse_softmax, \ early_stop, \ + length_penalty, \ stream); \ break @@ -202,6 +203,7 @@ __global__ void batch_topk(const int *topk_tmp_id_buf, } } +// early stop template __global__ void batch_topk(const int *topk_tmp_id_buf, const T *topk_tmp_val_buf, @@ -391,7 +393,9 @@ __global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer, T *tmp_vals, const int voc_parts, const int packed_top_kmd_size, - const bool fuse_softmax) { + const bool fuse_softmax, + const float length_penalty, + const int *step_ids) { const int vector_id = blockIdx.x; const int thread_id = threadIdx.x; const int PACKED_TOP_KMD_SIZE = packed_top_kmd_size; @@ -402,6 +406,10 @@ __global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer, float *buf_s = reinterpret_cast(buf_s_); tmp_buffer += vector_id * PACKED_TOP_KMD_SIZE * voc_parts; + // Since cum_log_probs is the penalized values, need to be restored before accumulation. + T previous_penalty = static_cast(powf(step_ids[vector_id], length_penalty)); + T current_penalty = static_cast(powf(step_ids[vector_id] + 1, length_penalty)); + if (fuse_softmax) { typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -443,7 +451,7 @@ __global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer, // float val = expf((float)total.topk.vals[i] - total.softmax_md.logit - d_total_log); float val = total.topk.vals[i] - total.softmax_md.logit - d_total_log; tmp_ids[i] = total.topk.ids[i]; - tmp_vals[i] = val + cum_log_probs[0]; + tmp_vals[i] = (val + cum_log_probs[0] * previous_penalty) / current_penalty; #ifdef DEBUG_BEAM_SEARCH_SOFTMAX printf("vector_id: %d, vals: %f, logit: %f, d_total_log: %f, id: %d, val: %f, cum_log_probs: %f, res: %f\n", vector_id, total.topk.vals[i], total.softmax_md.logit, d_total_log, tmp_ids[i], val, cum_log_probs[0], tmp_vals[i]); #endif @@ -485,7 +493,7 @@ __global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer, for (int i = 0; i < K; ++i) { float val = total.vals[i]; tmp_ids[i] = total.ids[i]; - tmp_vals[i] = val + cum_log_probs[0]; + tmp_vals[i] = (val + cum_log_probs[0] * previous_penalty) / current_penalty; } } } @@ -501,25 +509,27 @@ void invokeBeamSearchSoftmaxTopKStage2(const float *tmp_buffer, const int voc_parts, const int packed_top_kmd_size, const bool fuse_softmax, + const float length_penalty, + const int *step_ids, cudaStream_t stream) { int smem_stage2_size = voc_parts * packed_top_kmd_size * sizeof(float); if (voc_parts <= 32) { beam_search_softmax_topk_stage2 <<>>( - tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax, length_penalty, step_ids); return; } if (voc_parts <= 64) { beam_search_softmax_topk_stage2 <<>>( - tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax, length_penalty, step_ids); return; } if (voc_parts <= 128) { beam_search_softmax_topk_stage2 <<>>( - tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax, length_penalty, step_ids); return; } } @@ -681,6 +691,7 @@ void invokeTopKSoftMaxLauncher(const Context &dev_ctx, const int max_dec_len, const bool fuse_softmax, const bool early_stop, + const float length_penalty, cudaStream_t stream) { // K = 2 * beam_size const int block_size = 128; @@ -725,6 +736,8 @@ void invokeTopKSoftMaxLauncher(const Context &dev_ctx, voc_parts, packed_top_kmd_size, fuse_softmax, + length_penalty, + step_ids, stream); // (bs, bm) if (early_stop) { @@ -808,6 +821,7 @@ void invokeTopkSoftMax(const Context &dev_ctx, const int max_dec_len, const bool fuse_softmax, const bool early_stop, + const float length_penalty, cudaStream_t stream) { switch (beam_size) { CASE_K(1); @@ -848,6 +862,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx, int max_dec_len, bool fuse_softmax, bool early_stop, + float length_penalty, DenseTensor *ids_this_time, DenseTensor *out_cum_scores, DenseTensor *cache_ids, @@ -913,6 +928,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx, max_dec_len, fuse_softmax, early_stop, + length_penalty, dev_ctx.stream()); } diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 17499887e7a5b..eb1975fdcc246 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -1122,7 +1122,7 @@ def beam_search_softmax( max_dec_len, fuse_softmax, early_stop, - name=None, + length_penalty=0.0, ): if in_dygraph_mode(): return _C_ops.beam_search_softmax( @@ -1138,7 +1138,8 @@ def beam_search_softmax( max_seq_len, max_dec_len, fuse_softmax, - early_stop + early_stop, + length_penalty ) inputs = { @@ -1157,6 +1158,7 @@ def beam_search_softmax( attrs['max_dec_len'] = max_dec_len attrs['fuse_softmax'] = fuse_softmax attrs['early_stop'] = early_stop + attrs['length_penalty'] = length_penalty helper = LayerHelper('beam_search_softmax', **locals()) ids_this_time = helper.create_variable_for_type_inference(dtype="int32") From a96bfb1f96e2cf588ad4ad4c8022ab03455a1882 Mon Sep 17 00:00:00 2001 From: tianyan01 Date: Tue, 28 Nov 2023 14:39:16 +0800 Subject: [PATCH 9/9] add fmt support int8 and fix some int8 bug --- .../fused/fused_multi_transformer_int8_op.cu | 1 + .../fused_multi_transformer_moe_int8_op.cu | 25 +- .../fused/fused_multi_transformer_moe_op.cu | 3 +- .../fused/fused_multi_transformer_moe_op.h | 10 +- .../fused/fused_residual_dropout_bias.h | 4 +- .../operators/fused/layernorm_quant_dequant.h | 8 +- python/paddle/fluid/dygraph/amp/auto_cast.py | 5 + python/paddle/fluid/framework.py | 23 + python/paddle/incubate/nn/__init__.py | 2 + .../incubate/nn/layer/fused_transformer.py | 387 ++++++- python/paddle/nn/quant/__init__.py | 37 +- python/paddle/nn/quant/format.py | 241 ++++ python/paddle/nn/quant/functional_layers.py | 40 +- python/paddle/nn/quant/lsq.py | 371 ++++++ python/paddle/nn/quant/qat/__init__.py | 15 + python/paddle/nn/quant/qat/conv.py | 88 ++ python/paddle/nn/quant/qat/linear.py | 66 ++ python/paddle/nn/quant/quant_layers.py | 1003 ++++++++++------- python/paddle/nn/quant/quantized_linear.py | 245 ++++ python/paddle/nn/quant/stub.py | 99 ++ python/setup.py.in | 1 + 21 files changed, 2222 insertions(+), 452 deletions(-) create mode 100644 python/paddle/nn/quant/format.py create mode 100644 python/paddle/nn/quant/lsq.py create mode 100644 python/paddle/nn/quant/qat/__init__.py create mode 100644 python/paddle/nn/quant/qat/conv.py create mode 100644 python/paddle/nn/quant/qat/linear.py create mode 100644 python/paddle/nn/quant/quantized_linear.py create mode 100644 python/paddle/nn/quant/stub.py diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index d87b3db45cb19..245748d65a492 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -837,6 +837,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { ffn1_dropout_mask_data, ffn1_in_scale[i], ffn1_out_scales[i]->data(), + 0, ffn2_in_scale[i], quant_round_type, quant_max_bound, diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu index 4869b14ca10de..dbfcd9948f758 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_int8_op.cu @@ -19,6 +19,7 @@ namespace paddle { namespace operators { using Tensor = phi::DenseTensor; +// #define _DEBUG_FUSED_MULTI_TRANSFORMER template static void PrintMatrix(const T* mat_d, int num, std::string name) { @@ -72,7 +73,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { auto quant_max_bound = ctx.Attr("quant_max_bound"); auto quant_min_bound = ctx.Attr("quant_min_bound"); - // dequant output scales, tensor, size = [num_layers, n], n is gemm output + // dequant output scales, vertor, size = [num_layers, n], n is gemm output // size auto qkv_out_scales = ctx.MultiInput("QKVOutScale"); auto out_linear_out_scales = @@ -164,7 +165,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); auto *qktv_out_data = dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({{bsz_seq, num_head, dim_head}}); + fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); auto *fmha_out_data = dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); @@ -231,7 +232,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { topk_value.Resize({{sliced_bsz_seq, topk}}); dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); topk_idx.Resize({{sliced_bsz_seq, topk}}); - dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(T)); + dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(int64_t)); // local expert count, global expert count Tensor local_expert_count, global_expert_count; local_expert_count.Resize({{tot_expert}}); @@ -424,7 +425,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step3.2 out linear"; #endif - // T -> int8 + // T -> int32 out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i], out_linear_in_scale[i], &fmha_out, @@ -444,7 +445,6 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step4"; #endif - // step5. ln(residual + dropout(input + bias)) auto *ln_scale_data = ffn_ln_scales[i]->data(); auto *ln_bias_data = ffn_ln_biases[i]->data(); @@ -455,7 +455,6 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { // 改为输出先不做scale,输出是fp16,输出到buf0 AffineQuantStore store(buf0.data(), dim_embed, ln_scale_data, ln_bias_data); DispatchLayerNorm(dev_ctx.stream(), load, store, bsz_seq, dim_embed, epsilon, ln_mean_data, ln_var_data); - #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step5"; #endif @@ -564,9 +563,6 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { all_expert_out.Resize({{fwd_bsz, dim_embed}}); dev_ctx.Alloc(&all_expert_out, all_expert_out.numel() * sizeof(T)); - // global_scatter_out.Resize({{fwd_bsz, dim_embed}}); - // all_expert_out.Resize({{fwd_bsz, dim_embed}}); - // step 5, MOEScatter // step 5.1, index select // suppose tmp_pos->shape != [0] @@ -614,19 +610,16 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { int end = cur_expert_count + last_index; Tensor expert_in_tmp; // int8_t - expert_in_tmp.Resize({{(cur_expert_count * dim_feedforward + 31) / 32 * 32 }}); + expert_in_tmp.Resize({{cur_expert_count, dim_feedforward}}); dev_ctx.Alloc(&expert_in_tmp, expert_in_tmp.numel() * sizeof(int8_t)); Tensor expert_out1; // int32_t - expert_out1.Resize({{(cur_expert_count * dim_feedforward + 31) / 32 * 32}}); + expert_out1.Resize({{cur_expert_count, dim_feedforward}}); dev_ctx.Alloc(&expert_out1, expert_out1.numel() * sizeof(int32_t)); Tensor expert_out2; // T(fp16) expert_out2.Resize({{cur_expert_count, dim_embed}}); dev_ctx.Alloc(&expert_out2, expert_out2.numel() * sizeof(T)); - // act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); maybe int8_t? - // maybe use input_workspace and output workspace? - // dev_ctx.Alloc(&act_bias_out, act_bias_out.numel() * sizeof(T)); // input is int32_t, output is int8_t FusedDropoutHelper fused_act_dropout_helper( @@ -654,7 +647,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { expert_out1.data(), expert_biases1[expert_idx]->data(), "gelu", - expert_in_tmp.data(), + expert_in_tmp.data(), // output nullptr, expert_weight1_in_scale[expert_idx], expert_weight1_out_scales[expert_idx]->data(), @@ -668,7 +661,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel { MatMulINT8ToT(dev_ctx, expert_weights2[expert_idx], expert_weight2_in_scale[expert_idx], - &expert_in_tmp, + &expert_in_tmp, // input expert_biases2[expert_idx], &expert_out2, &expert_out1, // output_tmp diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu index 6e6b41dd6ab74..0307a737b81d6 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.cu @@ -208,7 +208,7 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { topk_value.Resize({{sliced_bsz_seq, topk}}); dev_ctx.Alloc(&topk_value, topk_value.numel() * sizeof(T)); topk_idx.Resize({{sliced_bsz_seq, topk}}); - dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(T)); + dev_ctx.Alloc(&topk_idx, topk_idx.numel() * sizeof(int64_t)); // local expert count, global expert count Tensor local_expert_count, global_expert_count; local_expert_count.Resize({{tot_expert}}); @@ -642,7 +642,6 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel { Tensor tmp_inp = global_scatter_out.Slice(last_index, end); int expert_idx = i * num_expert + idx; - // linear1 matmul // VLOG(0) << "moe, Expert Computation, linear1 mul"; phi::MatMulAndAdd(dev_ctx, diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h index 01a5e344ecc54..108ba22c4b8db 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h @@ -216,11 +216,11 @@ void MatMulTToINT8(const phi::GPUContext& dev_ctx, dev_ctx.stream()); helper->GEMM(input_tmp->data(), - weight->data(), - output->data(), - dev_ctx.stream(), - (void*)workspace->data(), - workspace->numel()); + weight->data(), + output->data(), + dev_ctx.stream(), + (void*)workspace->data(), + workspace->numel()); } template diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index cabcbea2d2e60..55c04f429a74f 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -107,8 +107,8 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( for (int ii = 0; ii < VecSize; ii++) { T tmp; if (std::is_same::value) { - T tmp0 = static_cast(static_cast(src_vec[ii]) * - quant_last_in_scale / quant_out_scale_vec[ii]); + T tmp0 = static_cast(static_cast(src_vec[ii]) * + quant_out_scale_vec[ii]); tmp = tmp0 + bias_vec[ii]; } else { tmp = static_cast(src_vec[ii]) + bias_vec[ii]; diff --git a/paddle/fluid/operators/fused/layernorm_quant_dequant.h b/paddle/fluid/operators/fused/layernorm_quant_dequant.h index 6d9fccbcb6e48..3762a173890ba 100644 --- a/paddle/fluid/operators/fused/layernorm_quant_dequant.h +++ b/paddle/fluid/operators/fused/layernorm_quant_dequant.h @@ -1044,7 +1044,9 @@ struct DequantSkipLoadAndStoreResidual { src_pack.storage = *(reinterpret_cast*>(src) + offset); bias_pack.storage = *(reinterpret_cast*>(bias) + bias_offset); skip_pack.storage = *(reinterpret_cast*>(skip) + offset); - dequant_scale_pack.storage = *(reinterpret_cast*>(dequant_scale) + bias_offset); // equal to col. + if (do_dequant) { + dequant_scale_pack.storage = *(reinterpret_cast*>(dequant_scale) + bias_offset); // equal to col. + } #pragma unroll for (int i = 0; i < N; ++i) { // First we need to cast src and dequant. @@ -1053,8 +1055,8 @@ struct DequantSkipLoadAndStoreResidual { + bias_pack.elem[i] + skip_pack.elem[i]); } else { - residual_out_pack.elem[i] = static_cast(static_cast(src_pack.elem[i]) + bias_pack.elem[i] - + skip_pack.elem[i]); + // trick for smoe, dont add bias. + residual_out_pack.elem[i] = static_cast(static_cast(src_pack.elem[i]) + skip_pack.elem[i]); } } #pragma unroll diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 6ff2f9dc46c43..9454d3ce9e274 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -183,6 +183,8 @@ def pure_fp16_initialize(models): for idx in range(len(models)): for layer in models[idx].sublayers(include_self=True): layer._casted_by_pure_fp16 = True + if isinstance(layer, paddle.incubate.nn.FusedMultiTransformerMoeINT8): + continue if (layer._dtype == 'float16') or isinstance( layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D, paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D, @@ -196,6 +198,9 @@ def pure_fp16_initialize(models): paddle.incubate.nn.FusedMoELayer)): layer._amp_decorate(dtype='float16') continue + # if isinstance(layer, paddle.incubate.nn.FusedMultiTransformerMoeINT8): + # layer._amp_decorate(dtype='int8') + # continue layer._to_impl(dtype='float16', include_sublayers=False, floating_only=True) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4fbcdc78536ac..e72bdb3e51138 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1281,6 +1281,29 @@ def _debug_string_(proto, throw_on_error=True): return proto.__str__() +def _create_tensor( + type=core.VarDesc.VarType.LOD_TENSOR, + name=None, + shape=None, + dtype=None, + persistable=None, + **kwargs, +): + if dtype is not None: + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + + eager_tensor = core.eager.Tensor( + dtype if dtype else core.VarDesc.VarType.FP32, + list(shape) if shape else [], + name, + type if type else core.VarDesc.VarType.LOD_TENSOR, + True if persistable else False, + ) + eager_tensor.retain_grads() + return eager_tensor + + def _varbase_creator( type=core.VarDesc.VarType.LOD_TENSOR, name=None, diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index c5ca8e38215d7..5bd4c5bbc4566 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -20,6 +20,7 @@ from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401 from .layer.fused_transformer import FusedMoELayer # noqa: F401 from .layer.fused_transformer import FusedMultiTransformerMoe # noqa: F401 +from .layer.fused_transformer import FusedMultiTransformerMoeINT8 # noqa: F401 __all__ = [ #noqa 'FusedMultiHeadAttention', @@ -27,6 +28,7 @@ 'FusedTransformerEncoderLayer', 'FusedMultiTransformer', 'FusedMultiTransformerMoe', + 'FusedMultiTransformerMoeINT8', 'FusedLinear', 'FusedBiasDropoutResidualLayerNorm', 'FusedMoELayer', diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index d66097cdf51a1..4f2b77e6ab4e0 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -1957,4 +1957,389 @@ def trans_to_fp16(l): trans_to_fp16(self.expert_biases1) trans_to_fp16(self.expert_weights2) trans_to_fp16(self.expert_biases2) - self._dtype = dtype \ No newline at end of file + self._dtype = dtype + + +class FusedMultiTransformerMoeINT8(Layer): + """ + FusedMultiTransformerMoeINT8 + """ + def __init__( + self, + d_model, + embed_dim, + num_heads, + dim_feedforward, + dropout_rate=0.0, + activation="gelu", + normalize_before=True, + ln_scale_attrs=None, + ln_bias_attrs=None, + qkv_weight_attrs=None, + qkv_bias_attrs=None, + linear_weight_attrs=None, + linear_bias_attrs=None, + gate_weight_attrs=None, + gate_bias_attrs=None, + ffn_ln_scale_attrs=None, + ffn_ln_bias_attrs=None, + expert_weight1_attrs=None, + expert_bias1_attrs=None, + expert_weight2_attrs=None, + expert_bias2_attrs=None, + qkv_out_scales_attrs=None, # out scales + out_linear_out_scales_attrs=None, + expert_weight1_out_scales_attrs=None, + expert_weight2_out_scales_attrs=None, + qkv_in_scale=None, + out_linear_in_scale=None, + expert_weight1_in_scale=None, + expert_weight2_in_scale=None, + epsilon=1e-5, + num_layers=-1, + nranks=1, + trans_qkvw=True, + ring_id=-1, + num_expert=1, + top_k=2, + approximate=True, + moe_group=None, + mp_group=None, + name=None, + ): + super(FusedMultiTransformerMoeINT8, self).__init__() + assert embed_dim > 0, ( + "Expected embed_dim to be greater than 0, " + "but received {}".format(embed_dim) + ) + assert ( + num_heads > 0 + ), "Expected nhead to be greater than 0, " "but received {}".format( + num_heads + ) + assert ( + dim_feedforward > 0 + ), "Expected dim_feedforward to be greater than 0, but received {}".format( + dim_feedforward + ) + # only support mp/dp + # for moe config + self.group = moe_group + self.world_size = 1 + if self.group is not None: + self.world_size = self.group.nranks + self.num_expert = num_expert + + self.mp_rank = 0 + self.mp_size = 1 + if mp_group is not None and mp_group.nranks > 1: + self.mp_rank = mp_group.rank + self.mp_size = mp_group.nranks + self.top_k = top_k + self.approximate = approximate + + # origin fmt config + self.normalize_before = normalize_before + # self._dtype = self._helper.get_default_dtype() + self._dtype = "float16" # fix, default is fp16 + self._epsilon = epsilon + self._trans_qkvw = trans_qkvw + self._ring_id = ring_id + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + # tensor model parallel + if nranks > 1: + assert ring_id != -1 + assert num_heads % nranks == 0 + num_heads = num_heads // nranks + + if isinstance(qkv_weight_attrs, (list, tuple, ParameterList)): + num_layers = len(qkv_weight_attrs) + assert num_layers > 0 + + self.qkv_in_scale = qkv_in_scale + self.out_linear_in_scale = out_linear_in_scale + self.expert_weight1_in_scale = expert_weight1_in_scale + self.expert_weight2_in_scale = expert_weight2_in_scale + + self.ln_scales, self.ln_biases = ParameterList(), ParameterList() + self.qkv_weights, self.qkv_biases = ParameterList(), ParameterList() + self.linear_weights, self.linear_biases = ParameterList(), ParameterList() + self.gate_weights, self.gate_biases = ParameterList(), ParameterList() + self.ffn_ln_scales, self.ffn_ln_biases = ParameterList(), ParameterList() + self.expert_weights1, self.expert_biases1 = ParameterList(), ParameterList() + self.expert_weights2, self.expert_biases2 = ParameterList(), ParameterList() + self.qkv_out_scales, self.out_linear_out_scales = ParameterList(), ParameterList() + self.expert_weight1_out_scales, self.expert_weight2_out_scales = ParameterList(), ParameterList() + def get_attr(attrs, idx): + if isinstance(attrs, (list, tuple, ParameterList)): + assert len(attrs) == num_layers + return attrs[idx] + return attrs + + for i in range(num_layers): + ln_scale_attr = get_attr(ln_scale_attrs, i) + ln_bias_attr = get_attr(ln_bias_attrs, i) + qkv_weight_attr = get_attr(qkv_weight_attrs, i) + qkv_bias_attr = get_attr(qkv_bias_attrs, i) + linear_weight_attr = get_attr(linear_weight_attrs, i) + linear_bias_attr = get_attr(linear_bias_attrs, i) + + ffn_ln_scale_attr = get_attr(ffn_ln_scale_attrs, i) + ffn_ln_bias_attr = get_attr(ffn_ln_bias_attrs, i) + gate_weight_attr = get_attr(gate_weight_attrs, i) + gate_bias_attr = get_attr(gate_bias_attrs, i) + + qkv_out_scales_attr = get_attr(qkv_out_scales_attrs, i) + out_linear_out_scales_attr = get_attr(out_linear_out_scales_attrs, i) + expert_weight1_out_scales_attr = get_attr(expert_weight1_out_scales_attrs, i) + expert_weight2_out_scales_attr = get_attr(expert_weight2_out_scales_attrs, i) + + ln_scale = self.create_parameter( + attr=ln_scale_attr, + shape=[embed_dim], + default_initializer=Constant(value=1.0), + dtype="float32", + ) + ln_bias = self.create_parameter( + attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32" + ) + qkv_weight = self.create_parameter( + shape=[3, num_heads, self.head_dim, embed_dim] + if trans_qkvw + else [embed_dim, 3, num_heads, self.head_dim], + attr=qkv_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + qkv_bias = self.create_parameter( + shape=[3, num_heads, self.head_dim], + attr=qkv_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + linear_weight = self.create_parameter( + shape=[num_heads * self.head_dim, embed_dim], + attr=linear_weight_attr, + dtype=self._dtype, + is_bias=False, + ) + linear_bias = self.create_parameter( + shape=[embed_dim], + attr=linear_bias_attr, + dtype=self._dtype, + is_bias=True, + ) + qkv_out_scale = self.create_parameter( + shape=[3 * embed_dim], + attr=qkv_out_scales_attr, + dtype="float32", + is_bias=False + ) + out_linear_out_scale = self.create_parameter( + shape=[embed_dim], + attr=out_linear_out_scales_attr, + dtype="float32", + is_bias=False + ) + + ffn_ln_scale = self.create_parameter( + shape=[embed_dim], + attr=ffn_ln_scale_attr, + is_bias=False, + default_initializer=Constant(1.0), + dtype="float32", + ) + ffn_ln_bias = self.create_parameter( + shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True, dtype="float32" + ) + gate_weight = self.create_parameter( + shape=[d_model, num_expert * self.world_size], + attr=gate_weight_attr, + dtype=self._dtype, + is_bias=False + ) + gate_bias = self.create_parameter( + shape=[num_expert * self.world_size], + attr=gate_bias_attr, + dtype=self._dtype, + is_bias=True + ) + + # tensor model parallel + if nranks > 1: + # column parallel + _set_var_distributed(qkv_weight) + _set_var_distributed(qkv_bias) + # row parallel + _set_var_distributed(linear_weight) + + self.ln_scales.append(ln_scale) + self.ln_biases.append(ln_bias) + self.qkv_weights.append(qkv_weight) + self.qkv_biases.append(qkv_bias) + self.linear_weights.append(linear_weight) + self.linear_biases.append(linear_bias) + self.qkv_out_scales.append(qkv_out_scale) + self.out_linear_out_scales.append(out_linear_out_scale) + + self.ffn_ln_scales.append(ffn_ln_scale) + self.ffn_ln_biases.append(ffn_ln_bias) + self.gate_weights.append(gate_weight) + self.gate_biases.append(gate_bias) + + for j in range(num_expert): + expert_weight1_attr = get_attr(expert_weight1_attrs, i * num_expert + j) + expert_bias1_attr = get_attr(expert_bias1_attrs, i * num_expert + j) + expert_weight2_attr = get_attr(expert_weight2_attrs, i * num_expert + j) + expert_bias2_attr = get_attr(expert_bias2_attrs, i * num_expert + j) + + expert_weight1 = self.create_parameter( + # shape=[d_model, dim_feedforward], + shape=[dim_feedforward, d_model], + attr=expert_weight1_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + ) + expert_bias1 = self.create_parameter( + shape=[dim_feedforward], + attr=expert_bias1_attr, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + ) + expert_weight2 = self.create_parameter( + # shape=[dim_feedforward, d_model], + shape=[d_model, dim_feedforward], + attr=expert_weight2_attr, + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.KaimingUniform() + ) + expert_bias2 = self.create_parameter( + shape=[d_model], + attr=expert_bias2_attr, + dtype=self._dtype, + is_bias=True, + default_initializer=nn.initializer.Constant(value=0.0) + ) + expert_weight1_out_scale = self.create_parameter( + shape=[4 * embed_dim], + attr=expert_weight1_out_scales_attr, + dtype="float32", + is_bias=False + ) + expert_weight2_out_scale = self.create_parameter( + shape=[embed_dim], + attr=expert_weight2_out_scales_attr, + dtype="float32", + is_bias=False + ) + expert_weight1.name = "expert_" + expert_weight1.name + expert_bias1.name = "expert_" + expert_bias1.name + expert_weight2.name = "expert_" + expert_weight2.name + expert_bias2.name = "expert_" + expert_bias2.name + self.expert_weights1.append(expert_weight1) + self.expert_biases1.append(expert_bias1) + self.expert_weights2.append(expert_weight2) + self.expert_biases2.append(expert_bias2) + self.expert_weight1_out_scales.append(expert_weight1_out_scale) + self.expert_weight2_out_scales.append(expert_weight2_out_scale) + self.dropout_rate = dropout_rate + self.activation = activation + self.name = name + # int8 + self._int8_decorate() + + def forward(self, src, attn_mask=None, caches=None, seq_lens=None, beam_offset=None, time_step=None): + """ + forward + """ + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer_moe_int8( + src, + list(self.ln_scales), + list(self.ln_biases), + list(self.qkv_weights), + list(self.qkv_biases), + caches, + beam_offset, + time_step, + seq_lens, + attn_mask, + list(self.linear_weights), + list(self.linear_biases), + list(self.gate_weights), + list(self.gate_biases), + list(self.ffn_ln_scales), + list(self.ffn_ln_biases), + list(self.expert_weights1), + list(self.expert_biases1), + list(self.expert_weights2), + list(self.expert_biases2), + list(self.qkv_out_scales), + list(self.out_linear_out_scales), + list(self.expert_weight1_out_scales), + list(self.expert_weight2_out_scales), + caches, + 'pre_layer_norm', + self.normalize_before, + 'epsilon', + self._epsilon, + 'dropout_rate', + self.dropout_rate, + 'is_test', + not self.training, + 'dropout_implementation', + 'upscale_in_train', + 'act_method', + self.activation, + 'trans_qkvw', + self._trans_qkvw, + 'ring_id', + self._ring_id, + 'topk', + self.top_k, + 'mp_size', + self.mp_size, + 'mp_rank', + self.mp_rank, + 'num_expert', + self.num_expert, + 'world_size', + self.world_size, + 'moe_ring_id', + -1 if self.group is None else self.group.id, + 'approximate', + self.approximate, + 'qkv_in_scale', + self.qkv_in_scale, + 'out_linear_in_scale', + self.out_linear_in_scale, + 'expert_weight1_in_scale', + self.expert_weight1_in_scale, + 'expert_weight2_in_scale', + self.expert_weight2_in_scale + ) + if caches is not None: + return final_out, cache_kv_out + return final_out + + def _int8_decorate(self, dtype="int8"): + # tmp fix for INT8 + def trans_to_int8(l): + for param in l: + if param is not None: + with no_grad(): + param_applied = _to_dtype(param, dtype) + trans_to_int8(self.qkv_weights) + trans_to_int8(self.linear_weights) + trans_to_int8(self.expert_weights1) + trans_to_int8(self.expert_weights2) + self._dtype = "int8" \ No newline at end of file diff --git a/python/paddle/nn/quant/__init__.py b/python/paddle/nn/quant/__init__.py index 8973761ab6944..0a9ef677b200e 100644 --- a/python/paddle/nn/quant/__init__.py +++ b/python/paddle/nn/quant/__init__.py @@ -12,15 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .functional_layers import FloatFunctionalLayer # noqa: F401 -from .functional_layers import add # noqa: F401 -from .functional_layers import subtract # noqa: F401 -from .functional_layers import multiply # noqa: F401 -from .functional_layers import divide # noqa: F401 -from .functional_layers import reshape # noqa: F401 -from .functional_layers import transpose # noqa: F401 -from .functional_layers import concat # noqa: F401 -from .functional_layers import flatten # noqa: F401 +from . import qat # noqa: F401 +from .functional_layers import ( # noqa: F401 + FloatFunctionalLayer, + add, + concat, + divide, + flatten, + matmul, + multiply, + reshape, + subtract, + transpose, +) from .quant_layers import QuantStub # noqa: F401 +from .quantized_linear import ( + llm_int8_linear, + weight_dequantize, + weight_only_linear, + weight_quantize, +) +from .stub import Stub -__all__ = [] +__all__ = [ + "Stub", + "weight_only_linear", + "llm_int8_linear", + "weight_quantize", + "weight_dequantize", +] diff --git a/python/paddle/nn/quant/format.py b/python/paddle/nn/quant/format.py new file mode 100644 index 0000000000000..1d4bc5ce44f64 --- /dev/null +++ b/python/paddle/nn/quant/format.py @@ -0,0 +1,241 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Define some layers used to export quantization model with ONNX style.""" +import abc +from typing import List, Tuple + +import paddle +from paddle import _legacy_C_ops as _C_ops +from paddle.framework import in_dynamic_mode + +#from ..layer.layers import Layer +from paddle.fluid.dygraph.layers import Layer + + +class LinearQuanterDequanter(Layer): + def __init__(self, quanter, dequanter): + super().__init__() + self._quanter = quanter + self._dequanter = dequanter + + def forward(self, input): + out = input + if self._quanter is not None: + out = self._quanter(out) + if self._dequanter is not None: + out = self._dequanter(out) + return out + + @staticmethod + def from_quanter(quanter): + assert quanter is not None + return LinearQuanterDequanter( + LinearQuanter.from_quanter(quanter), + LinearDequanter.from_quanter(quanter), + ) + + +class LinearQuanter(Layer): + def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8): + super().__init__() + self._scales = paddle.to_tensor(scales, dtype="float32") + self._zero_point = ( + paddle.zeros([1], dtype="float32") + if zero_point is None + else paddle.to_tensor(zero_point) + ) + self._quant_axis = -1 if quant_axis is None else quant_axis + self._bit_length = bit_length + + def forward(self, input): + if in_dynamic_mode(): + return _C_ops.quantize_linear( + input.cast('float32'), + self._scales, + self._zero_point, + "quant_axis", + self._quant_axis, + "bit_length", + self._bit_length, + ).cast(input.dtype) + else: + out = self._helper.create_variable_for_type_inference(input.dtype) + self._helper.append_op( + type='quantize_linear', + inputs={ + 'X': input, + 'Scale': self._scales, + 'ZeroPoint': self._zero_point, + }, + outputs={'Y': out}, + attrs={ + 'quant_axis': self._quant_axis, + 'bit_length': self._bit_length, + }, + ) + return out + + @staticmethod + def from_quanter(quanter): + return LinearQuanter( + quanter.scales(), + zero_point=quanter.zero_points(), + quant_axis=quanter.quant_axis(), + bit_length=quanter.bit_length(), + ) + + +class LinearDequanter(Layer): + def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8): + super().__init__() + self._scales = paddle.to_tensor(scales, dtype="float32") + self._zero_point = ( + paddle.zeros([1], dtype="float32") + if zero_point is None + else paddle.to_tensor(zero_point) + ) + self._quant_axis = -1 if quant_axis is None else quant_axis + self._bit_length = bit_length + + def forward(self, input): + if in_dynamic_mode(): + return _C_ops.dequantize_linear( + input.cast('float32'), + self._scales, + self._zero_point, + "quant_axis", + self._quant_axis, + "bit_length", + self._bit_length, + ).cast(input.dtype) + else: + out = self._helper.create_variable_for_type_inference(input.dtype) + self._helper.append_op( + type='dequantize_linear', + inputs={ + 'X': input, + 'Scale': self._scales, + 'ZeroPoint': self._zero_point, + }, + outputs={'Y': out}, + attrs={ + 'quant_axis': self._quant_axis, + 'bit_length': self._bit_length, + }, + ) + return out + + @staticmethod + def from_quanter(quanter): + return LinearDequanter( + quanter.scales(), + zero_point=quanter.zero_points(), + quant_axis=quanter.quant_axis(), + bit_length=quanter.bit_length(), + ) + + +class ConvertibleQuantedLayer(Layer, metaclass=abc.ABCMeta): + r"""Abstract class to help convert quantized layer to inference model. + It defines some functions to convert quantizers and observers to quantize + or dequantize operators that maintain the quantization parameters used + during inference. + + Examples: + .. code-block:: python + + >>> # Given codes in ./customized_quanter.py + >>> class CustomizedQuantedLayer(ConvertibleQuantedLayer): + ... def __init__(self): + ... super().__init__() + ... self.weight_a = paddle.create_parameter(shape=[1], dtype='float32') + ... self.weight_b = paddle.create_parameter(shape=[1], dtype='float32') + ... self.quanter_for_weight_a = None + ... self.activation_weight = None + ... + ... def forward(self, input): + ... qweight_a = self.quanter_for_weight_a(self.weight_a) + ... weight_b = self.weight_b + ... qinput = self.activation_weight(input) + ... # compute with qweight_a, weight_b and qinput. + ... return qweight * qinput + weight_b + ... + ... def weights_to_quanters(self): + ... return [('weight_a', 'quanter_for_weight_a')] + ... + ... def activation_quanters(self): + ... return ['activation_weight'] + """ + + def __init__(self): + super().__init__() + self.converted = False + + @abc.abstractmethod + def weights_to_quanters(self) -> List[Tuple[str, str]]: + r"""Get the name pairs of weights to be quantized and their corresponding + quantizers. In the convert function of this abstract class, it will call + the ‘weights_to_quanters’ function and do something as follows: + For each pair, the quantizer will be converted to a quantize operator and + a dequantize operator. Then, the weight will be quantized by the quantize + operator. Finally, the quantize operator will be removed and the weights + will be stored in integer data type. + + Returns: A list of name pairs. Each pair contains two names. The first is name of weight + to be quantized and the second is name of corresponding quanter. + """ + pass + + @abc.abstractmethod + def activation_quanters(self) -> List[str]: + r"""Get the names of quanters used to quantize activations. + All the quanters or observers returned by this function will be converted to quantize + and dequantize operators for deployment. + Returns: A list of quanter names. + """ + pass + + def _convert_quanter_to_qdq(self, quanter_name) -> LinearQuanterDequanter: + r"""Convert quanter to an instance of LinearQuanterDequanter.""" + assert hasattr( + self, quanter_name + ), f"{quanter_name} is not attribute of current layer." + quanter = getattr(self, quanter_name) + if quanter is None: + return None + quanter = LinearQuanterDequanter.from_quanter(quanter) + setattr(self, quanter_name, quanter) + self._sub_layers[quanter_name] = quanter + return quanter + + def _quant_weights(self, weight_name, quanter): + r"""Quantize the weight by given quanter.""" + weight = getattr(self, weight_name) + qweight = quanter(weight) + weight.set_value(qweight) + + def _convert(self): + r"""Convert current layer to onnx style for inference.""" + assert not self.converted, "The model should be converted only once." + for weight_name, quanter_name in self.weights_to_quanters(): + qdq = self._convert_quanter_to_qdq(quanter_name) + if qdq is not None: + self._quant_weights(weight_name, qdq._quanter) + qdq._quanter = None + qdq._sub_layers['_quanter'] = None + + for quanter_name in self.activation_quanters(): + self._convert_quanter_to_qdq(quanter_name) + + self.converted = True diff --git a/python/paddle/nn/quant/functional_layers.py b/python/paddle/nn/quant/functional_layers.py index ca1eb5f4fb3c1..834bfebac4b37 100644 --- a/python/paddle/nn/quant/functional_layers.py +++ b/python/paddle/nn/quant/functional_layers.py @@ -12,85 +12,85 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ...tensor import math, manipulation -from .. import Layer +from ...tensor import linalg, manipulation, math +#from ..layer.layers import Layer +from paddle.fluid.dygraph.layers import Layer __all__ = [] class FloatFunctionalLayer(Layer): - def __init__(self): - super(FloatFunctionalLayer, self).__init__() + super().__init__() class add(FloatFunctionalLayer): - def __init__(self): - super(add, self).__init__() + super().__init__() def forward(self, x, y, name=None): return math.add(x, y, name) class subtract(FloatFunctionalLayer): - def __init__(self): - super(subtract, self).__init__() + super().__init__() def forward(self, x, y, name=None): return math.subtract(x, y, name) class multiply(FloatFunctionalLayer): - def __init__(self): - super(multiply, self).__init__() + super().__init__() def forward(self, x, y, name=None): return math.multiply(x, y, name) class divide(FloatFunctionalLayer): - def __init__(self): - super(divide, self).__init__() + super().__init__() def forward(self, x, y, name=None): return math.divide(x, y, name) class reshape(FloatFunctionalLayer): - def __init__(self): - super(reshape, self).__init__() + super().__init__() def forward(self, x, shape, name=None): return manipulation.reshape(x, shape, name) class transpose(FloatFunctionalLayer): - def __init__(self): - super(transpose, self).__init__() + super().__init__() def forward(self, x, perm, name=None): return manipulation.transpose(x, perm, name) class concat(FloatFunctionalLayer): - def __init__(self): - super(concat, self).__init__() + super().__init__() def forward(self, x, axis=0, name=None): return manipulation.concat(x, axis, name) class flatten(FloatFunctionalLayer): - def __init__(self): - super(flatten, self).__init__() + super().__init__() def forward(self, x, start_axis=0, stop_axis=-1, name=None): return manipulation.flatten(x, start_axis, stop_axis, name) + + +class matmul(FloatFunctionalLayer): + def __init__(self): + super().__init__() + + def forward(self, x, y, transpose_x=False, transpose_y=False, name=None): + return linalg.matmul(x, y, transpose_x, transpose_y, name) diff --git a/python/paddle/nn/quant/lsq.py b/python/paddle/nn/quant/lsq.py new file mode 100644 index 0000000000000..0fcae2fb6b1ec --- /dev/null +++ b/python/paddle/nn/quant/lsq.py @@ -0,0 +1,371 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import paddle +from paddle.autograd import PyLayer +from paddle.framework import ParamAttr +from paddle.nn.initializer import Constant +from paddle.utils import unique_name + +#from ..layer.layers import Layer +from paddle.fluid.dygraph.layers import Layer + + +def round(x): + sign = paddle.sign(x) + x = sign * paddle.floor(paddle.abs(x) + 0.5) + return x + + +class LsqFunc(PyLayer): + @staticmethod + def forward(ctx, weight, alpha, g, Qn, Qp, per_channel=False, quant_axis=0): + ctx.save_for_backward(weight, alpha) + ctx.other = g, Qn, Qp, per_channel, quant_axis + if per_channel: + sizes = weight.shape + weight = weight.reshape((weight.shape[quant_axis], -1)) + weight = weight.transpose((1, 0)) + alpha = paddle.broadcast_to(alpha, weight.shape) + quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp) + quant_w = quant_w * alpha + quant_w = quant_w.transpose((1, 0)) + quant_w = quant_w.reshape(sizes) + else: + quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp) + quant_w = quant_w * alpha + return quant_w + + @staticmethod + def backward(ctx, grad_weight): + weight, alpha = ctx.saved_tensor() + g, Qn, Qp, per_channel, quant_axis = ctx.other + if per_channel: + sizes = weight.shape + weight = weight.reshape((weight.shape[quant_axis], -1)) + weight = weight.transpose((1, 0)) + alpha = paddle.broadcast_to(alpha, weight.shape) + q_w = paddle.divide(weight, alpha) + q_w = q_w.transpose((1, 0)) + q_w = q_w.reshape(sizes) + else: + q_w = paddle.divide(weight, alpha) + lower_flag = paddle.cast((q_w < Qn), 'float32') + upper_flag = paddle.cast((q_w > Qp), 'float32') + middle_flag = 1.0 - lower_flag - upper_flag + if per_channel: + grad_alpha = ( + ( + lower_flag * Qn + + upper_flag * Qp + + middle_flag * round(q_w) + - middle_flag * q_w + ) + * grad_weight + * g + ) + grad_alpha = grad_alpha.reshape( + (grad_alpha.shape[quant_axis], -1) + ).sum(axis=1) + else: + grad_alpha = ( + ( + ( + lower_flag * Qn + + upper_flag * Qp + + middle_flag * round(q_w) + - middle_flag * q_w + ) + * grad_weight + * g + ) + .sum() + .unsqueeze(axis=0)[0] + ) + grad_weight = middle_flag * grad_weight + return grad_weight, grad_alpha + + +class LsqPlusActFunc(PyLayer): + @staticmethod + def forward(ctx, x, alpha, beta, g, Qn, Qp): + ctx.save_for_backward(x, alpha, beta) + ctx.other = g, Qn, Qp + quant_x = round(paddle.divide((x - beta), alpha)).clip(Qn, Qp) + return quant_x * alpha + beta + + @staticmethod + def backward(ctx, grad_x): + x, alpha, beta = ctx.saved_tensor() + g, Qn, Qp = ctx.other + q_x = (x - beta) / alpha + lower_flag = paddle.cast((q_x < Qn), 'float32') + upper_flag = paddle.cast((q_x > Qp), 'float32') + middle_flag = 1.0 - lower_flag - upper_flag + grad_alpha = ( + ( + ( + lower_flag * Qn + + upper_flag * Qp + + middle_flag * round(q_x) + - middle_flag * q_x + ) + * grad_x + * g + ) + .sum() + .unsqueeze(axis=0)[0] + ) + grad_beta = ( + ((lower_flag + upper_flag) * grad_x * g).sum().unsqueeze(axis=0)[0] + ) + grad_x = middle_flag * grad_x + return grad_x, grad_alpha, grad_beta + + +class FakeQuantActLSQPlus(Layer): + def __init__( + self, + quant_bits, + all_postive=False, + symmetric=False, + batch_init=20, + dtype='float32', + name=None, + reduce_type=None, + ): + super().__init__() + ''' + Args: + quant_bits(int): quantization bit number for weights. + all_postive(bool): whether unsigned or signed quantization, where True for unsigned quantization and False for signed quantization. + symmetric(bool): whether symmetric or asymmetric quantization. + batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer. + dtype(str): data type. + name(str): the name of the weight. + reduce_type(str): the reduce type which is needed when parallel training. + ''' + self.bits = quant_bits + self.all_positive = all_postive + self.symmetric = symmetric + self.batch_init = batch_init + self.name = name + self.reduce_type = reduce_type + + if self.all_positive: + # unsigned activation + self.Qn = 0 + self.Qp = 2**self.bits - 1 + else: + # signed activation + self.Qn = -(2 ** (self.bits - 1)) + self.Qp = 2 ** (self.bits - 1) - 1 + + scale_prefix = f"{name}.scale" if name else 'quant_dequant.scale' + self._scale_name = unique_name.generate(scale_prefix) + + s_attr = ParamAttr( + name=self._scale_name, initializer=Constant(1.0), trainable=True + ) + self.s = self.create_parameter(shape=[], attr=s_attr, dtype='float32') + self.s.stop_gradient = False + + if not self.symmetric: + beta_prefix = f"{name}.beta" if name else 'quant_dequant.beta' + self._beta_name = unique_name.generate(beta_prefix) + + beta_attr = ParamAttr( + name=self._beta_name, initializer=Constant(0.0), trainable=True + ) + self.beta = self.create_parameter( + shape=[], attr=beta_attr, dtype='float32' + ) + self.beta.stop_gradient = False + + self.init_state = 0 + + def forward(self, activation): + if self.reduce_type == "max": + paddle.distributed.all_reduce( + self.s, op=paddle.distributed.ReduceOp.MAX + ) + + if not self.symmetric and self.reduce_type == "max": + paddle.distributed.all_reduce( + self.beta, op=paddle.distributed.ReduceOp.MAX + ) + + if self.init_state == 0: + self.g = paddle.to_tensor( + 1.0 / math.sqrt(activation.numel() * self.Qp) + ) + min_a = paddle.min(activation.detach()) + max_a = paddle.max(activation.detach()) + self.s.set_value((max_a - min_a) / (self.Qp - self.Qn)) + if not self.symmetric: + self.beta.set_value(min_a - self.s * self.Qn) + self.init_state += 1 + elif self.init_state < self.batch_init: + min_a = paddle.min(activation.detach()) + max_a = paddle.max(activation.detach()) + self.s.set_value( + self.s * 0.9 + 0.1 * (max_a - min_a) / (self.Qp - self.Qn) + ) + if not self.symmetric: + self.beta.set_value( + self.s * 0.9 + 0.1 * (min_a - self.s * self.Qn) + ) + self.init_state += 1 + else: + self.init_state += 1 + activation.stop_gradient = False + if not self.symmetric: + q_a = LsqPlusActFunc.apply( + activation, self.s, self.beta, self.g, self.Qn, self.Qp + ) + else: + q_a = LsqFunc.apply( + activation, self.s, self.g, self.Qn, self.Qp, per_channel=False + ) + return q_a + + +class FakeQuantWeightLSQPlus(Layer): + def __init__( + self, + quant_bits, + all_postive=False, + per_channel=False, + batch_init=20, + channel_num=None, + quant_linear=False, + dtype='float32', + name=None, + reduce_type=None, + ): + super().__init__() + ''' + Args: + quant_bits(int): quantization bit number for weights. + all_postive(bool): whether unsigned or signed quantization, where True for unsigned quantization and False for signed quantization. + per_channel(bool): whether layer-wise or channel-wise quantization, where True for layer-wise quantization and False for channel-wise quantization. + batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer. + channel_num(int): the channel number of the weight which is needed when per_channel is True. + quant_linear(bool): whether the weight is from Linear. + dtype(str): data type. + name(str): the name of the weight. + reduce_type(str): the reduce type which is needed when parallel training. + ''' + + self.bits = quant_bits + self.all_positive = all_postive + self.per_channel = per_channel + self.quant_linear = quant_linear + self.batch_init = batch_init + self.name = name + self.quant_axis = 1 if quant_linear else 0 + self.collect_axis = 0 if quant_linear else 1 + self.reduce_type = reduce_type + + if self.all_positive: + # unsigned weight + self.Qn = 0 + self.Qp = 2**self.bits - 1 + else: + # signed weight + self.Qn = -(2 ** (self.bits - 1)) + self.Qp = 2 ** (self.bits - 1) - 1 + + self.init_state = 0 + scale_prefix = f"{name}.scale" if name else 'quant_dequant.scale' + self._scale_name = unique_name.generate(scale_prefix) + s_attr = ParamAttr( + name=self._scale_name, initializer=Constant(1.0), trainable=True + ) + self.s = self.create_parameter( + shape=[channel_num], attr=s_attr, dtype=dtype + ) + self.s.stop_gradient = False + + def forward(self, weight): + if self.reduce_type == "max": + paddle.distributed.all_reduce( + self.s, op=paddle.distributed.ReduceOp.MAX + ) + + if self.init_state == 0: + self.g = paddle.to_tensor(1.0 / math.sqrt(weight.numel() * self.Qp)) + self.div = 2**self.bits - 1 + if self.per_channel: + weight_tmp = weight.detach().reshape((weight.shape[0], -1)) + mean = paddle.mean(weight_tmp, axis=self.collect_axis) + std = paddle.std(weight_tmp, axis=self.collect_axis) + s = paddle.max( + paddle.stack( + [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)] + ), + axis=0, + ) + self.s.set_value(s / self.div) + else: + mean = paddle.mean(weight.detach()) + std = paddle.std(weight.detach()) + self.s.set_value( + max( + [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)] + ) + / self.div + ) + self.init_state += 1 + elif self.init_state < self.batch_init: + self.div = 2**self.bits - 1 + if self.per_channel: + weight_tmp = weight.detach().reshape((weight.shape[0], -1)) + mean = paddle.mean(weight_tmp, axis=self.collect_axis) + std = paddle.std(weight_tmp, axis=self.collect_axis) + s = paddle.max( + paddle.stack( + [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)] + ), + axis=0, + ) + self.s.set_value(s * 0.9 + 0.1 * s / self.div) + else: + mean = paddle.mean(weight.detach()) + std = paddle.std(weight.detach()) + self.s.set_value( + self.s * 0.9 + + 0.1 + * max( + [paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std)] + ) + / self.div + ) + self.init_state += 1 + elif self.init_state == self.batch_init: + self.init_state += 1 + + weight.stop_gradient = False + w_q = LsqFunc.apply( + weight, + self.s, + self.g, + self.Qn, + self.Qp, + self.per_channel, + self.quant_axis, + ) + return w_q diff --git a/python/paddle/nn/quant/qat/__init__.py b/python/paddle/nn/quant/qat/__init__.py new file mode 100644 index 0000000000000..8701b8af763b7 --- /dev/null +++ b/python/paddle/nn/quant/qat/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .conv import QuantedConv2D +from .linear import QuantedLinear diff --git a/python/paddle/nn/quant/qat/conv.py b/python/paddle/nn/quant/qat/conv.py new file mode 100644 index 0000000000000..d639f181c3ce1 --- /dev/null +++ b/python/paddle/nn/quant/qat/conv.py @@ -0,0 +1,88 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Layers used for QAT. +""" +from paddle.nn import functional as F + +#from ...layer.layers import Layer +from paddle.fluid.dygraph.layers import Layer +from ..format import ConvertibleQuantedLayer + + +class QuantedConv2D(ConvertibleQuantedLayer): + """ + The computational logic of QuantizedConv2D is the same as Conv2D. + The only difference is that its inputs are all fake quantized. + """ + + def __init__(self, layer: Layer, q_config): + super().__init__() + + # For Conv2D + self._groups = layer._groups + self._stride = layer._stride + self._padding = layer._padding + self._padding_mode = layer._padding_mode + if self._padding_mode != 'zeros': + self._reversed_padding_repeated_twice = ( + layer._reversed_padding_repeated_twice + ) + self._dilation = layer._dilation + self._data_format = layer._data_format + self.weight = layer.weight + self.bias = layer.bias + + self.weight_quanter = None + self.activation_quanter = None + if q_config.weight is not None: + self.weight_quanter = q_config.weight._instance(layer) + if q_config.activation is not None: + self.activation_quanter = q_config.activation._instance(layer) + + def forward(self, input): + quant_input = input + quant_weight = self.weight + if self.activation_quanter is not None: + quant_input = self.activation_quanter(input) + if self.weight_quanter is not None: + quant_weight = self.weight_quanter(self.weight) + return self._conv_forward(quant_input, quant_weight) + + def _conv_forward(self, inputs, weights): + if self._padding_mode != 'zeros': + inputs = F.pad( + inputs, + self._reversed_padding_repeated_twice, + mode=self._padding_mode, + data_format=self._data_format, + ) + self._padding = 0 + + return F.conv2d( + inputs, + weights, + bias=self.bias, + padding=self._padding, + stride=self._stride, + dilation=self._dilation, + groups=self._groups, + data_format=self._data_format, + ) + + def weights_to_quanters(self): + return [('weight', 'weight_quanter')] + + def activation_quanters(self): + return ['activation_quanter'] diff --git a/python/paddle/nn/quant/qat/linear.py b/python/paddle/nn/quant/qat/linear.py new file mode 100644 index 0000000000000..b24a89752e9af --- /dev/null +++ b/python/paddle/nn/quant/qat/linear.py @@ -0,0 +1,66 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from paddle.nn import functional as F + +#from ...layer.layers import Layer +from paddle.fluid.dygraph.layers import Layer +from ..format import ConvertibleQuantedLayer +from paddle.distributed import fleet + + +class QuantedLinear(ConvertibleQuantedLayer): + """ + The computational logic of QuantizedLinear is the same as Linear. + The only difference is that its inputs are all fake quantized. + """ + + def __init__(self, layer: Layer, q_config): + super().__init__() + # For Linear + self.weight = layer.weight + self.bias = layer.bias + if isinstance(layer, fleet.meta_parallel.ColumnParallelLinear) or \ + isinstance(layer, fleet.meta_parallel.RowParallelLinear): + self.name = layer._name + else: + self.name = layer.name + # For FakeQuant + + self.weight_quanter = None + self.activation_quanter = None + if q_config.weight is not None: + self.weight_quanter = q_config.weight._instance(layer) + if q_config.activation is not None: + self.activation_quanter = q_config.activation._instance(layer) + + def forward(self, input): + quant_input = input + quant_weight = self.weight + if self.activation_quanter is not None: + quant_input = self.activation_quanter(input) + if self.weight_quanter is not None: + quant_weight = self.weight_quanter(self.weight) + return self._linear_forward(quant_input, quant_weight) + + def _linear_forward(self, input, weight): + out = F.linear(x=input, weight=weight, bias=self.bias, name=self.name) + return out + + def weights_to_quanters(self): + return [('weight', 'weight_quanter')] + + def activation_quanters(self): + return ['activation_quanter'] diff --git a/python/paddle/nn/quant/quant_layers.py b/python/paddle/nn/quant/quant_layers.py index 72bad0d44a8c3..4ab052d0d90f5 100644 --- a/python/paddle/nn/quant/quant_layers.py +++ b/python/paddle/nn/quant/quant_layers.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + import paddle -from paddle.framework import core -from paddle.fluid import dygraph_utils -from paddle.utils import unique_name -from paddle.framework import ParamAttr -from paddle.fluid.framework import _varbase_creator -from paddle.nn.initializer import Constant +from paddle import _legacy_C_ops, in_dynamic_mode from paddle.fluid.data_feeder import check_variable_and_dtype -from paddle.nn import functional as F -import logging +from paddle.fluid.framework import _create_tensor from paddle.fluid.log_helper import get_logger -from paddle import _C_ops, _legacy_C_ops -from paddle import in_dynamic_mode -from paddle.nn import Layer +from paddle.framework import ParamAttr, core +from paddle.nn import functional as F +from paddle.nn.initializer import Constant +from paddle.nn.quant.lsq import FakeQuantActLSQPlus, FakeQuantWeightLSQPlus +from paddle.utils import unique_name + +#from ..layer.layers import Layer +from paddle.fluid.dygraph.layers import Layer __all__ = [ 'FakeQuantAbsMax', @@ -40,11 +41,12 @@ 'QuantStub', 'QuantizedRowParallelLinear', 'QuantizedColumnParallelLinear', + 'QuantizedMatmul', ] -_logger = get_logger(__name__, - logging.INFO, - fmt='%(asctime)s-%(levelname)s: %(message)s') +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) class FakeQuantAbsMax(Layer): @@ -57,26 +59,29 @@ class FakeQuantAbsMax(Layer): :math:`Out = round(X / scale * range) * scale / range` """ - def __init__(self, - name=None, - quant_bits=8, - dtype='float32', - quant_on_weight=False, - reduce_type=None): - super(FakeQuantAbsMax, self).__init__() + def __init__( + self, + name=None, + quant_bits=8, + dtype='float32', + quant_on_weight=False, + reduce_type=None, + ): + super().__init__() self._quant_bits = quant_bits self._name = name self._reduce_type = reduce_type - scale_prefix = "{}.scale".format( - name) if name else 'quant_dequant.scale' + scale_prefix = f"{name}.scale" if name else 'quant_dequant.scale' self._scale_name = unique_name.generate(scale_prefix) if quant_on_weight: - scale_attr = ParamAttr(name=self._scale_name, - initializer=Constant(0.001), - trainable=False) - self._scale = self.create_parameter(shape=[1], - attr=scale_attr, - dtype=self._dtype) + scale_attr = ParamAttr( + name=self._scale_name, + initializer=Constant(0.001), + trainable=False, + ) + self._scale = self.create_parameter( + shape=[1], attr=scale_attr, dtype=self._dtype + ) self._scale.stop_gradient = True else: self._scale = None @@ -84,38 +89,46 @@ def __init__(self, def forward(self, input): if in_dynamic_mode(): attrs = ('bit_length', self._quant_bits) - quant_out = _varbase_creator(type=input.type, - name="{}.quantized.dequantized".format( - input.name), - shape=input.shape, - dtype=input.dtype, - persistable=False) + quant_out = _create_tensor( + type=input.type, + name=f"{input.name}.quantized.dequantized", + shape=input.shape, + dtype=input.dtype, + persistable=False, + ) out_scale = self._scale if self._reduce_type == "max": paddle.distributed.all_reduce( - out_scale, op=paddle.distributed.ReduceOp.MAX) + out_scale, op=paddle.distributed.ReduceOp.MAX + ) if not out_scale: - out_scale = _varbase_creator( + out_scale = _create_tensor( type=core.VarDesc.VarType.LOD_TENSOR, name=self._scale_name, shape=[1], dtype=self._dtype, - persistable=False) + persistable=False, + ) out_scale.stop_gradient = True - out, _, = _legacy_C_ops.fake_quantize_dequantize_abs_max( - input, quant_out, out_scale, *attrs) + ( + out, + _, + ) = _legacy_C_ops.fake_quantize_dequantize_abs_max( + input, quant_out, out_scale, *attrs + ) return out check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax") attrs = {'bit_length': self._quant_bits} inputs = {"X": [input]} quant_out = self._helper.create_variable( - name="{}.quantized.dequantized".format(input.name), + name=f"{input.name}.quantized.dequantized", dtype=input.dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=False) + stop_gradient=False, + ) out_scale = self._scale if not out_scale: out_scale = self._helper.create_variable( @@ -123,13 +136,16 @@ def forward(self, input): dtype=self._dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=True) + stop_gradient=True, + ) outputs = {"Out": [quant_out], "OutScale": [out_scale]} - self._helper.append_op(type="fake_quantize_dequantize_abs_max", - inputs=inputs, - outputs=outputs, - attrs=attrs) + self._helper.append_op( + type="fake_quantize_dequantize_abs_max", + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) return quant_out @@ -144,83 +160,111 @@ class FakeQuantMovingAverageAbsMax(Layer): :math:`Out = round(X / scale * range) * scale / range` """ - def __init__(self, - name=None, - moving_rate=0.9, - quant_bits=8, - dtype='float32', - reduce_type=None): - super(FakeQuantMovingAverageAbsMax, self).__init__() + def __init__( + self, + name=None, + moving_rate=0.9, + quant_bits=8, + dtype='float32', + reduce_type=None, + ): + super().__init__() self._moving_rate = moving_rate self._quant_bits = quant_bits self._reduce_type = reduce_type - scale_prefix = "{}.scale".format( - name) if name else 'quant_dequant.scale' - scale_attr = ParamAttr(name=unique_name.generate(scale_prefix), - initializer=Constant(0.001), - trainable=False) - self._scale = self.create_parameter(shape=[1], - attr=scale_attr, - dtype=dtype) + scale_prefix = f"{name}.scale" if name else 'quant_dequant.scale' + scale_attr = ParamAttr( + name=unique_name.generate(scale_prefix), + initializer=Constant(0.001), + trainable=False, + ) + self._scale = self.create_parameter( + shape=[1], attr=scale_attr, dtype=dtype + ) self._scale.stop_gradient = True - state_prefix = "{}.state".format( - name) if name else 'quant_dequant.state' - state_attr = ParamAttr(name=unique_name.generate(state_prefix), - initializer=Constant(1), - trainable=False) - self._state = self.create_parameter(shape=[1], - attr=state_attr, - dtype=dtype) + state_prefix = f"{name}.state" if name else 'quant_dequant.state' + state_attr = ParamAttr( + name=unique_name.generate(state_prefix), + initializer=Constant(1), + trainable=False, + ) + self._state = self.create_parameter( + shape=[1], attr=state_attr, dtype=dtype + ) self._state.stop_gradient = True - accum_prefix = "{}.accum".format( - name) if name else 'quant_dequant.accum' - accum_attr = ParamAttr(name=unique_name.generate(accum_prefix), - initializer=Constant(1), - trainable=False) - self._accum = self.create_parameter(shape=[1], - attr=accum_attr, - dtype=dtype) + accum_prefix = f"{name}.accum" if name else 'quant_dequant.accum' + accum_attr = ParamAttr( + name=unique_name.generate(accum_prefix), + initializer=Constant(1), + trainable=False, + ) + self._accum = self.create_parameter( + shape=[1], attr=accum_attr, dtype=dtype + ) self._accum.stop_gradient = True def forward(self, input): if in_dynamic_mode(): - attrs = ('moving_rate', self._moving_rate, 'bit_length', - self._quant_bits, 'is_test', not self.training) - quant_out = _varbase_creator(type=input.type, - name="{}.quantized.dequantized".format( - input.name), - shape=input.shape, - dtype=input.dtype, - persistable=False) + attrs = ( + 'moving_rate', + self._moving_rate, + 'bit_length', + self._quant_bits, + 'is_test', + not self.training, + ) + quant_out = _create_tensor( + type=input.type, + name=f"{input.name}.quantized.dequantized", + shape=input.shape, + dtype=input.dtype, + persistable=False, + ) if self._reduce_type == "max": paddle.distributed.all_reduce( - self._scale, op=paddle.distributed.ReduceOp.MAX) + self._scale, op=paddle.distributed.ReduceOp.MAX + ) state = self._state if self.training else None accum = self._accum if self.training else None - out, _, _, _ = _legacy_C_ops.fake_quantize_dequantize_moving_average_abs_max( - input, self._scale, accum, state, quant_out, self._scale, state, - accum, *attrs) + ( + out, + _, + _, + _, + ) = _legacy_C_ops.fake_quantize_dequantize_moving_average_abs_max( + input, + self._scale, + accum, + state, + quant_out, + self._scale, + state, + accum, + *attrs, + ) return out - check_variable_and_dtype(input, 'input', ['float32'], - "FakeQuantMovingAverageAbsMax") + check_variable_and_dtype( + input, 'input', ['float32'], "FakeQuantMovingAverageAbsMax" + ) attrs = { 'moving_rate': self._moving_rate, 'bit_length': self._quant_bits, - 'is_test': not self.training + 'is_test': not self.training, } inputs = {"X": [input], "InScale": [self._scale]} quant_out = self._helper.create_variable( - name="{}.quantized.dequantized".format(input.name), + name=f"{input.name}.quantized.dequantized", dtype=input.dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=False) + stop_gradient=False, + ) outputs = {"Out": [quant_out], "OutScale": [self._scale]} if self.training: @@ -233,81 +277,99 @@ def forward(self, input): type="fake_quantize_dequantize_moving_average_abs_max", inputs=inputs, outputs=outputs, - attrs=attrs) + attrs=attrs, + ) return quant_out class FakeQuantChannelWiseAbsMax(Layer): - - def __init__(self, - name=None, - channel_num=None, - quant_bits=8, - quant_axis=0, - dtype='float32', - quant_on_weight=False, - reduce_type=None): - assert quant_on_weight == True, "Channel_wise only can be used on weight quantization." - super(FakeQuantChannelWiseAbsMax, self).__init__() + def __init__( + self, + name=None, + channel_num=None, + quant_bits=8, + quant_axis=0, + dtype='float32', + quant_on_weight=False, + reduce_type=None, + ): + assert ( + quant_on_weight + ), "Channel_wise only can be used on weight quantization." + super().__init__() self._quant_bits = quant_bits self._quant_axis = quant_axis self._dtype = dtype self._name = name self._channel_num = channel_num self._reduce_type = reduce_type - scale_prefix = "{}.scale".format( - name) if name else 'quant_dequant.scale' + scale_prefix = f"{name}.scale" if name else 'quant_dequant.scale' self._scale_name = unique_name.generate(scale_prefix) if quant_on_weight: - scale_attr = ParamAttr(name=self._scale_name, - initializer=Constant(0.0), - trainable=False) - self._scale = self.create_parameter(shape=[self._channel_num], - attr=scale_attr, - dtype=self._dtype) + scale_attr = ParamAttr( + name=self._scale_name, + initializer=Constant(0.0), + trainable=False, + ) + self._scale = self.create_parameter( + shape=[self._channel_num], attr=scale_attr, dtype=self._dtype + ) self._scale.stop_gradient = True else: self._scale = None def forward(self, input): if in_dynamic_mode(): - attrs = ('bit_length', self._quant_bits, 'quant_axis', - self._quant_axis) - quant_out = _varbase_creator(type=input.type, - name="{}.quantized.dequantized".format( - input.name), - shape=input.shape, - dtype=input.dtype, - persistable=False) + attrs = ( + 'bit_length', + self._quant_bits, + 'quant_axis', + self._quant_axis, + ) + quant_out = _create_tensor( + type=input.type, + name=f"{input.name}.quantized.dequantized", + shape=input.shape, + dtype=input.dtype, + persistable=False, + ) out_scale = self._scale if self._reduce_type == "max": paddle.distributed.all_reduce( - out_scale, op=paddle.distributed.ReduceOp.MAX) + out_scale, op=paddle.distributed.ReduceOp.MAX + ) if out_scale is None: - out_scale = _varbase_creator( + out_scale = _create_tensor( type=core.VarDesc.VarType.LOD_TENSOR, name=self._scale_name, shape=[self._channel_num], dtype=self._dtype, - persistable=False) + persistable=False, + ) out_scale.stop_gradient = True - out, _, = _legacy_C_ops.fake_channel_wise_quantize_dequantize_abs_max( - input, quant_out, out_scale, *attrs) + ( + out, + _, + ) = _legacy_C_ops.fake_channel_wise_quantize_dequantize_abs_max( + input, quant_out, out_scale, *attrs + ) return out - check_variable_and_dtype(input, 'input', ['float32'], - "FakeQuantChannelWiseAbsMax") + check_variable_and_dtype( + input, 'input', ['float32'], "FakeQuantChannelWiseAbsMax" + ) attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis} inputs = {"X": [input]} quant_out = self._helper.create_variable( - name="{}.quantized.dequantized".format(input.name), + name=f"{input.name}.quantized.dequantized", dtype=input.dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=False) + stop_gradient=False, + ) out_scale = self._scale if not out_scale: out_scale = self._helper.create_variable( @@ -315,25 +377,24 @@ def forward(self, input): dtype=self._dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=True) + stop_gradient=True, + ) outputs = {"Out": [quant_out], "OutScale": [out_scale]} self._helper.append_op( type="fake_channel_wise_quantize_dequantize_abs_max", inputs=inputs, outputs=outputs, - attrs=attrs) + attrs=attrs, + ) return quant_out class MovingAverageAbsMaxScale(Layer): - - def __init__(self, - name=None, - moving_rate=0.9, - dtype='float32', - reduce_type=None): + def __init__( + self, name=None, moving_rate=0.9, dtype='float32', reduce_type=None + ): r""" MovingAverageMaxScale layer is used to calculating the output quantization scale of Layer. Its computational formula is described as below: @@ -341,70 +402,90 @@ def __init__(self, :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)` :math:`Out = X` """ - super(MovingAverageAbsMaxScale, self).__init__() + super().__init__() self._moving_rate = moving_rate self._reduce_type = reduce_type - scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale' + scale_prefix = f'{name}.scale' if name else 'outscale.scale' scale_name = unique_name.generate(scale_prefix) - scale_attr = ParamAttr(name=scale_name, - initializer=Constant(0), - trainable=False) - self._scale = self.create_parameter(shape=[1], - attr=scale_attr, - dtype=dtype) + scale_attr = ParamAttr( + name=scale_name, initializer=Constant(0), trainable=False + ) + self._scale = self.create_parameter( + shape=[1], attr=scale_attr, dtype=dtype + ) self._scale.stop_gradient = True - state_prefix = "{}.state".format(name) if name else 'outscale.state' - state_attr = ParamAttr(name=unique_name.generate(state_prefix), - initializer=Constant(0), - trainable=False) - self._state = self.create_parameter(shape=[1], - attr=state_attr, - dtype=dtype) + state_prefix = f"{name}.state" if name else 'outscale.state' + state_attr = ParamAttr( + name=unique_name.generate(state_prefix), + initializer=Constant(0), + trainable=False, + ) + self._state = self.create_parameter( + shape=[1], attr=state_attr, dtype=dtype + ) self._state.stop_gradient = True - accum_prefix = "{}.accum".format(name) if name else 'outscale.accum' - accum_attr = ParamAttr(name=unique_name.generate(accum_prefix), - initializer=Constant(0), - trainable=False) - self._accum = self.create_parameter(shape=[1], - attr=accum_attr, - dtype=dtype) + accum_prefix = f"{name}.accum" if name else 'outscale.accum' + accum_attr = ParamAttr( + name=unique_name.generate(accum_prefix), + initializer=Constant(0), + trainable=False, + ) + self._accum = self.create_parameter( + shape=[1], attr=accum_attr, dtype=dtype + ) self._accum.stop_gradient = True def forward(self, input): if in_dynamic_mode(): - attrs = ('moving_rate', self._moving_rate, 'is_test', - not self.training) - - quant_out = _varbase_creator(type=input.type, - name="{}.tmp".format(input.name), - shape=input.shape, - dtype=input.dtype, - persistable=False) + attrs = ( + 'moving_rate', + self._moving_rate, + 'is_test', + not self.training, + ) + + quant_out = _create_tensor( + type=input.type, + name=f"{input.name}.tmp", + shape=input.shape, + dtype=input.dtype, + persistable=False, + ) if self._reduce_type == "max": paddle.distributed.all_reduce( - self._scale, op=paddle.distributed.ReduceOp.MAX) + self._scale, op=paddle.distributed.ReduceOp.MAX + ) state = self._state if self.training else None accum = self._accum if self.training else None out, _, _, _ = _legacy_C_ops.moving_average_abs_max_scale( - input, accum, state, quant_out, self._scale, state, accum, - *attrs) + input, + accum, + state, + quant_out, + self._scale, + state, + accum, + *attrs, + ) return out - check_variable_and_dtype(input, 'input', ['float32', 'float64'], - 'MovingAverageAbsMaxScale') + check_variable_and_dtype( + input, 'input', ['float32', 'float64'], 'MovingAverageAbsMaxScale' + ) attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training} inputs = {"X": [input]} quant_out = self._helper.create_variable( - name="{}.tmp".format(input.name), + name=f"{input.name}.tmp", dtype=input.dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=False) + stop_gradient=False, + ) outputs = {"Out": [quant_out], "OutScale": [self._scale]} if self.training: @@ -413,10 +494,12 @@ def forward(self, input): outputs['OutState'] = [self._state] outputs['OutAccum'] = [self._accum] - self._helper.append_op(type="moving_average_abs_max_scale", - inputs=inputs, - outputs=outputs, - attrs=attrs) + self._helper.append_op( + type="moving_average_abs_max_scale", + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) return quant_out @@ -430,30 +513,33 @@ class QuantizedConv2D(Layer): The only difference is that its inputs are all fake quantized. """ - def __init__(self, - layer, - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_quantize_type='abs_max', - activation_quantize_type='abs_max', - weight_pre_layer=None, - act_pre_layer=None, - weight_quant_layer=None, - act_quant_layer=None): - super(QuantizedConv2D, self).__init__() + def __init__( + self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None, + ): + super().__init__() # For Conv2D - self._groups = getattr(layer, '_groups') - self._stride = getattr(layer, '_stride') - self._padding = getattr(layer, '_padding') - self._padding_mode = getattr(layer, '_padding_mode') + self._groups = layer._groups + self._stride = layer._stride + self._padding = layer._padding + self._padding_mode = layer._padding_mode if self._padding_mode != 'zeros': - self._reversed_padding_repeated_twice = getattr( - layer, '_reversed_padding_repeated_twice') - self._dilation = getattr(layer, '_dilation') - self._data_format = getattr(layer, '_data_format') - self.weight = getattr(layer, 'weight') - self.bias = getattr(layer, 'bias') + self._reversed_padding_repeated_twice = ( + layer._reversed_padding_repeated_twice + ) + self._dilation = layer._dilation + self._data_format = layer._data_format + self.weight = layer.weight + self.bias = layer.bias # For FakeQuant self._conv2d_quant_axis = 0 @@ -468,7 +554,8 @@ def __init__(self, dtype=self._dtype, quant_on_weight=True, channel_num=self.weight.shape[self._conv2d_quant_axis], - quant_axis=self._conv2d_quant_axis) + quant_axis=self._conv2d_quant_axis, + ) if act_quant_layer is not None: self._fake_quant_input = act_quant_layer() else: @@ -478,12 +565,15 @@ def __init__(self, moving_rate=moving_rate, quant_bits=activation_bits, dtype=self._dtype, - quant_on_weight=False) + quant_on_weight=False, + ) - self._act_preprocess = act_pre_layer( - ) if act_pre_layer is not None else None - self._weight_preprocess = weight_pre_layer( - ) if weight_pre_layer is not None else None + self._act_preprocess = ( + act_pre_layer() if act_pre_layer is not None else None + ) + self._weight_preprocess = ( + weight_pre_layer() if weight_pre_layer is not None else None + ) def forward(self, input): if self._act_preprocess is not None: @@ -496,20 +586,24 @@ def forward(self, input): quant_weight = self._fake_quant_weight(weight) if self._padding_mode != 'zeros': - quant_input = F.pad(quant_input, - self._reversed_padding_repeated_twice, - mode=self._padding_mode, - data_format=self._data_format) + quant_input = F.pad( + quant_input, + self._reversed_padding_repeated_twice, + mode=self._padding_mode, + data_format=self._data_format, + ) self._padding = 0 - return F.conv2d(quant_input, - quant_weight, - bias=self.bias, - padding=self._padding, - stride=self._stride, - dilation=self._dilation, - groups=self._groups, - data_format=self._data_format) + return F.conv2d( + quant_input, + quant_weight, + bias=self.bias, + padding=self._padding, + stride=self._stride, + dilation=self._dilation, + groups=self._groups, + data_format=self._data_format, + ) class QuantizedConv2DTranspose(Layer): @@ -517,52 +611,54 @@ class QuantizedConv2DTranspose(Layer): The computational logic of QuantizedConv2DTranspose is the same with Conv2DTranspose. The only difference is that its inputs are all fake quantized. - + Examples: - .. code-block:: python - - import paddle - import paddle.nn as nn - from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose - - x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.) - conv = nn.Conv2DTranspose(4, 6, (3, 3)) - conv_quantized = QuantizedConv2DTranspose(conv) - y_quantized = conv_quantized(x_var) - y_var = conv(x_var) - y_quantized_np = y_quantized.numpy() - y_np = y_var.numpy() - print(y_np.shape, y_quantized_np.shape) - # (2, 6, 10, 10), (2, 6, 10, 10) + .. code-block:: python + + >>> import paddle + >>> import paddle.nn as nn + >>> from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose + + >>> x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.) + >>> conv = nn.Conv2DTranspose(4, 6, (3, 3)) + >>> conv_quantized = QuantizedConv2DTranspose(conv) + >>> y_quantized = conv_quantized(x_var) + >>> y_var = conv(x_var) + >>> print(y_var.shape) + [2, 6, 10, 10] + >>> print(y_quantized.shape) + [2, 6, 10, 10] """ - def __init__(self, - layer, - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_quantize_type='abs_max', - activation_quantize_type='abs_max', - weight_pre_layer=None, - act_pre_layer=None, - weight_quant_layer=None, - act_quant_layer=None): + def __init__( + self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None, + ): r""" Constructor. The arguments are the same as ImperativeQuantAware. """ - super(QuantizedConv2DTranspose, self).__init__() + super().__init__() # For Conv2DTranspose - self._groups = getattr(layer, '_groups') - self._stride = getattr(layer, '_stride') - self._padding = getattr(layer, '_padding') - self._output_padding = getattr(layer, 'output_padding') - self._dilation = getattr(layer, '_dilation') - self._data_format = getattr(layer, '_data_format') - self.weight = getattr(layer, 'weight') - self.bias = getattr(layer, 'bias') + self._groups = layer._groups + self._stride = layer._stride + self._padding = layer._padding + self._output_padding = layer.output_padding + self._dilation = layer._dilation + self._data_format = layer._data_format + self.weight = layer.weight + self.bias = layer.bias # For FakeQuant self._conv2d_transpose_quant_axis = 1 if weight_quant_layer is not None: @@ -576,8 +672,10 @@ def __init__(self, dtype=self._dtype, quant_on_weight=True, channel_num=self.weight.shape[ - self._conv2d_transpose_quant_axis], - quant_axis=self._conv2d_transpose_quant_axis) + self._conv2d_transpose_quant_axis + ], + quant_axis=self._conv2d_transpose_quant_axis, + ) if act_quant_layer is not None: self._fake_quant_input = act_quant_layer() else: @@ -587,12 +685,15 @@ def __init__(self, moving_rate=moving_rate, quant_bits=activation_bits, dtype=self._dtype, - quant_on_weight=False) + quant_on_weight=False, + ) - self._act_preprocess = act_pre_layer( - ) if act_pre_layer is not None else None - self._weight_preprocess = weight_pre_layer( - ) if weight_pre_layer is not None else None + self._act_preprocess = ( + act_pre_layer() if act_pre_layer is not None else None + ) + self._weight_preprocess = ( + weight_pre_layer() if weight_pre_layer is not None else None + ) def forward(self, input, output_size=None): if self._act_preprocess is not None: @@ -609,16 +710,18 @@ def forward(self, input, output_size=None): else: output_padding = 0 - return F.conv2d_transpose(quant_input, - quant_weight, - bias=self.bias, - padding=self._padding, - output_padding=output_padding, - stride=self._stride, - dilation=self._dilation, - groups=self._groups, - output_size=output_size, - data_format=self._data_format) + return F.conv2d_transpose( + quant_input, + quant_weight, + bias=self.bias, + padding=self._padding, + output_padding=output_padding, + stride=self._stride, + dilation=self._dilation, + groups=self._groups, + output_size=output_size, + data_format=self._data_format, + ) class QuantizedLinear(Layer): @@ -627,22 +730,24 @@ class QuantizedLinear(Layer): The only difference is that its inputs are all fake quantized. """ - def __init__(self, - layer, - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_quantize_type='abs_max', - activation_quantize_type='abs_max', - weight_pre_layer=None, - act_pre_layer=None, - weight_quant_layer=None, - act_quant_layer=None): - super(QuantizedLinear, self).__init__() + def __init__( + self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None, + ): + super().__init__() # For Linear - self.weight = getattr(layer, 'weight') - self.bias = getattr(layer, 'bias') - self.name = getattr(layer, 'name') + self.weight = layer.weight + self.bias = layer.bias + self.name = layer.name # For FakeQuant self._linear_quant_axis = 1 @@ -657,7 +762,9 @@ def __init__(self, dtype=self._dtype, quant_on_weight=True, channel_num=self.weight.shape[self._linear_quant_axis], - quant_axis=self._linear_quant_axis) + quant_axis=self._linear_quant_axis, + quant_linear=True, + ) if act_quant_layer is not None: self._fake_quant_input = act_quant_layer() @@ -668,12 +775,15 @@ def __init__(self, moving_rate=moving_rate, quant_bits=activation_bits, dtype=self._dtype, - quant_on_weight=False) + quant_on_weight=False, + ) - self._act_preprocess = act_pre_layer( - ) if act_pre_layer is not None else None - self._weight_preprocess = weight_pre_layer( - ) if weight_pre_layer is not None else None + self._act_preprocess = ( + act_pre_layer() if act_pre_layer is not None else None + ) + self._weight_preprocess = ( + weight_pre_layer() if weight_pre_layer is not None else None + ) def forward(self, input): if self._act_preprocess is not None: @@ -685,42 +795,46 @@ def forward(self, input): weight = self._weight_preprocess(self.weight) quant_weight = self._fake_quant_weight(weight) - out = F.linear(x=quant_input, - weight=quant_weight, - bias=self.bias, - name=self.name) + out = F.linear( + x=quant_input, weight=quant_weight, bias=self.bias, name=self.name + ) return out class QuantizedColumnParallelLinear(Layer): - - def __init__(self, - layer, - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_quantize_type='abs_max', - activation_quantize_type='abs_max', - weight_pre_layer=None, - act_pre_layer=None, - weight_quant_layer=None, - act_quant_layer=None): - super(QuantizedColumnParallelLinear, self).__init__() - ''' - + def __init__( + self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None, + ): + super().__init__() ''' - assert weight_quant_layer is None, "When quantizing ColumnParallelLinear, weight_quant_layer should be None." - assert act_quant_layer is None, "When quantizing ColumnParallelLinear, act_quant_layer should be None." - self.weight = getattr(layer, 'weight') - self.bias = getattr(layer, 'bias') - self.name = getattr(layer, '_name') + ''' + assert ( + weight_quant_layer is None + ), "When quantizing ColumnParallelLinear, weight_quant_layer should be None." + assert ( + act_quant_layer is None + ), "When quantizing ColumnParallelLinear, act_quant_layer should be None." + + self.weight = layer.weight + self.bias = layer.bias + self.name = layer._name # For FakeQuant self._linear_quant_axis = 1 - self.is_mp = getattr(layer, 'is_mp') - self.model_parallel_group = getattr(layer, 'model_parallel_group') - self.gather_output = getattr(layer, 'gather_output') + self.is_mp = layer.is_mp + self.model_parallel_group = layer.model_parallel_group + self.gather_output = layer.gather_output self._fake_quant_weight = _get_fake_quant_type( weight_quantize_type, @@ -732,7 +846,9 @@ def __init__(self, channel_num=self.weight.shape[self._linear_quant_axis], quant_axis=self._linear_quant_axis, reduce_type='max' - if paddle.distributed.get_world_size() > 1 else None) + if paddle.distributed.get_world_size() > 1 + else None, + ) self._fake_quant_input = _get_fake_quant_type( activation_quantize_type, @@ -741,17 +857,21 @@ def __init__(self, quant_bits=activation_bits, dtype=self._dtype, quant_on_weight=False, - reduce_type=None) + reduce_type=None, + ) - self._act_preprocess = act_pre_layer( - ) if act_pre_layer is not None else None - self._weight_preprocess = weight_pre_layer( - ) if weight_pre_layer is not None else None + self._act_preprocess = ( + act_pre_layer() if act_pre_layer is not None else None + ) + self._weight_preprocess = ( + weight_pre_layer() if weight_pre_layer is not None else None + ) def forward(self, input): if self.is_mp: input_parallel = paddle.distributed.collective._c_identity( - input, group=self.model_parallel_group) + input, group=self.model_parallel_group + ) else: input_parallel = input @@ -764,46 +884,51 @@ def forward(self, input): weight = self._weight_preprocess(self.weight) quant_weight = self._fake_quant_weight(weight) - output_parallel = F.linear(x=quant_input, - weight=quant_weight, - bias=self.bias, - name=self.name) + output_parallel = F.linear( + x=quant_input, weight=quant_weight, bias=self.bias, name=self.name + ) if self.gather_output and self.is_mp: output = paddle.distributed.collective._c_concat( - output_parallel, group=self.model_parallel_group) + output_parallel, group=self.model_parallel_group + ) else: output = output_parallel return output class QuantizedRowParallelLinear(Layer): - - def __init__(self, - layer, - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - weight_quantize_type='abs_max', - activation_quantize_type='abs_max', - weight_pre_layer=None, - act_pre_layer=None, - weight_quant_layer=None, - act_quant_layer=None): - super(QuantizedRowParallelLinear, self).__init__() - assert weight_quant_layer is None, "When quantizing RowParallelLinear, weight_quant_layer cannot defined by yourself." - assert act_quant_layer is None, "When quantizing RowParallelLinear, act_quant_layer cannot defined by yourself." + def __init__( + self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None, + ): + super().__init__() + assert ( + weight_quant_layer is None + ), "When quantizing RowParallelLinear, weight_quant_layer cannot defined by yourself." + assert ( + act_quant_layer is None + ), "When quantizing RowParallelLinear, act_quant_layer cannot defined by yourself." # For Linear - self.weight = getattr(layer, 'weight') - self.bias = getattr(layer, 'bias') - self.name = getattr(layer, '_name') + self.weight = layer.weight + self.bias = layer.bias + self.name = layer._name # For FakeQuant self._linear_quant_axis = 1 - self.input_is_parallel = getattr(layer, 'input_is_parallel') - self.is_mp = getattr(layer, 'is_mp') - self.model_parallel_group = getattr(layer, 'model_parallel_group') + self.input_is_parallel = layer.input_is_parallel + self.is_mp = layer.is_mp + self.model_parallel_group = layer.model_parallel_group self._fake_quant_weight = _get_fake_quant_type( weight_quantize_type, @@ -815,7 +940,9 @@ def __init__(self, channel_num=self.weight.shape[self._linear_quant_axis], quant_axis=self._linear_quant_axis, reduce_type='max' - if paddle.distributed.get_world_size() > 1 else None) + if paddle.distributed.get_world_size() > 1 + else None, + ) self._fake_quant_input = _get_fake_quant_type( activation_quantize_type, @@ -825,12 +952,16 @@ def __init__(self, dtype=self._dtype, quant_on_weight=False, reduce_type='max' - if paddle.distributed.get_world_size() > 1 else None) + if paddle.distributed.get_world_size() > 1 + else None, + ) - self._act_preprocess = act_pre_layer( - ) if act_pre_layer is not None else None - self._weight_preprocess = weight_pre_layer( - ) if weight_pre_layer is not None else None + self._act_preprocess = ( + act_pre_layer() if act_pre_layer is not None else None + ) + self._weight_preprocess = ( + weight_pre_layer() if weight_pre_layer is not None else None + ) def forward(self, input): if self.input_is_parallel or (not self.is_mp): @@ -838,7 +969,8 @@ def forward(self, input): else: # split last dim input_parallel = paddle.distributed.collective._c_split( - input, group=self.model_parallel_group) + input, group=self.model_parallel_group + ) if self._act_preprocess is not None: input_parallel = self._act_preprocess(input_parallel) @@ -849,48 +981,110 @@ def forward(self, input): weight = self._weight_preprocess(self.weight) quant_weight = self._fake_quant_weight(weight) - output_parallel = F.linear(x=quant_input, - weight=quant_weight, - name=self.name) + output_parallel = F.linear( + x=quant_input, weight=quant_weight, name=self.name + ) if self.is_mp: output_ = paddle.distributed.collective._mp_allreduce( output_parallel, group=self.model_parallel_group, use_calc_stream=True, - use_model_parallel=True) + use_model_parallel=True, + ) else: output_ = output_parallel output = output_ + self.bias if self.bias is not None else output_ return output +class QuantizedMatmul(Layer): + """ + The computational logic of QuantizedMatmul is the same with Matmul. + The only difference is that its inputs are all fake quantized. + """ + + def __init__( + self, + layer=None, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None, + ): + super().__init__() + + # For FakeQuant + if act_quant_layer is not None: + self._fake_quant_x = act_quant_layer() + self._fake_quant_y = act_quant_layer() + else: + self._fake_quant_x = _get_fake_quant_type( + activation_quantize_type, + moving_rate=moving_rate, + quant_bits=activation_bits, + quant_on_weight=False, + ) + self._fake_quant_y = _get_fake_quant_type( + activation_quantize_type, + moving_rate=moving_rate, + quant_bits=activation_bits, + quant_on_weight=False, + ) + + self._act_preprocess_x = ( + act_pre_layer() if act_pre_layer is not None else None + ) + self._act_preprocess_y = ( + act_pre_layer() if act_pre_layer is not None else None + ) + + def forward(self, x, y, transpose_x=False, transpose_y=False, name=None): + if self._act_preprocess_x is not None: + x = self._act_preprocess_x(x) + quant_x = self._fake_quant_x(x) + + if self._act_preprocess_y is not None: + y = self._act_preprocess_y(y) + quant_y = self._fake_quant_y(y) + + out = paddle.matmul(quant_x, quant_y, transpose_x, transpose_y, name) + return out + + class MAOutputScaleLayer(Layer): """ Add MovingAverageMaxScale layer to the behind of the input layer. Calculate the scale (moving average abs max) for the output of the input layer. """ - def __init__(self, - layer=None, - moving_rate=0.9, - name=None, - dtype='float32', - reduce_type=None): + def __init__( + self, + layer=None, + moving_rate=0.9, + name=None, + dtype='float32', + reduce_type=None, + ): r""" Construct """ - super(MAOutputScaleLayer, self).__init__() + super().__init__() self._layer = layer if name is None: name = layer.full_name() - self._ma_output_scale = \ - MovingAverageAbsMaxScale(name, moving_rate, dtype, reduce_type) + self._ma_output_scale = MovingAverageAbsMaxScale( + name, moving_rate, dtype, reduce_type + ) def forward(self, *inputs, **kwargs): out = self._layer(*inputs, **kwargs) # TODO (jc): support the ops of several outputs - if (isinstance(out, list) or isinstance(out, tuple) - or isinstance(out, dict)): + if isinstance(out, (list, tuple, dict)): return out else: return self._ma_output_scale(out) @@ -901,17 +1095,18 @@ class FakeQuantMAOutputScaleLayer(Layer): Add FakeQuantMovingAverageAbsMax layer to the behind of the input layer. """ - def __init__(self, - layer, - weight_bits=8, - activation_bits=8, - moving_rate=0.9, - name=None, - reduce_type=None, - *args, - **kwargs): - - super(FakeQuantMAOutputScaleLayer, self).__init__() + def __init__( + self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + name=None, + reduce_type=None, + *args, + **kwargs, + ): + super().__init__() self._layer = layer self._fake_quant_output = _get_fake_quant_type( 'moving_average_abs_max', @@ -920,12 +1115,13 @@ def __init__(self, quant_bits=activation_bits, dtype=self._dtype, quant_on_weight=False, - reduce_type=reduce_type) + reduce_type=reduce_type, + ) def forward(self, *inputs, **kwargs): out = self._layer(*inputs, **kwargs) # TODO (jc): support the ops of several outputs - if (isinstance(out, list) or isinstance(out, tuple)) and len(out) > 1: + if (isinstance(out, (list, tuple))) and len(out) > 1: return out else: return self._fake_quant_output(out) @@ -936,7 +1132,7 @@ def _get_fake_quant_type(quant_type, **kwargs): "name": kwargs.get("name", None), "quant_bits": kwargs.get("quant_bits", 8), "dtype": kwargs.get("dtype", "float32"), - "reduce_type": kwargs.get("reduce_type", None) + "reduce_type": kwargs.get("reduce_type", None), } if quant_type == 'abs_max': @@ -949,11 +1145,32 @@ def _get_fake_quant_type(quant_type, **kwargs): call_args["quant_axis"] = kwargs.get("quant_axis", 0) assert call_args["channel_num"] is not None, ( "You need to input channel_num" - "when you use channel_wise_abs_max strategy.") + "when you use channel_wise_abs_max strategy." + ) + elif quant_type == 'lsq_weight': + call_args["all_postive"] = kwargs.get("all_postive", False) + call_args["per_channel"] = False + call_args["channel_num"] = 1 + call_args["quant_linear"] = kwargs.get("quant_linear", False) + elif quant_type == 'channel_wise_lsq_weight': + quant_type = 'lsq_weight' + call_args["all_postive"] = kwargs.get("all_postive", False) + call_args["per_channel"] = True + call_args["channel_num"] = kwargs.get("channel_num", None) + call_args["quant_linear"] = kwargs.get("quant_linear", False) + assert call_args["channel_num"] is not None, ( + "You need to input channel_num" + "when you use channel_wise_abs_max strategy." + ) + elif quant_type == 'lsq_act': + call_args["all_postive"] = kwargs.get("all_postive", False) + call_args["symmetric"] = kwargs.get("symmetric", True) fake_quant_map = { 'abs_max': FakeQuantAbsMax, 'moving_average_abs_max': FakeQuantMovingAverageAbsMax, - 'channel_wise_abs_max': FakeQuantChannelWiseAbsMax + 'channel_wise_abs_max': FakeQuantChannelWiseAbsMax, + 'lsq_weight': FakeQuantWeightLSQPlus, + 'lsq_act': FakeQuantActLSQPlus, } return fake_quant_map[quant_type](**call_args) diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py new file mode 100644 index 0000000000000..0de9fb5fe5fd3 --- /dev/null +++ b/python/paddle/nn/quant/quantized_linear.py @@ -0,0 +1,245 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import _C_ops +from paddle.fluid.data_feeder import check_dtype +from paddle.fluid.framework import convert_np_dtype_to_dtype_ +from paddle.framework import LayerHelper, in_dynamic_mode + + +def weight_quantize(x, algo="weight_only_int8"): + """ + Quantization function for weight_only and llm.int8's weight. + + Args: + x (Tensor): The input Tensor to be quantized, the data type is float16 or bfloat16. + algo (str): The algo that is x will be apply, must be one of 'weight_only_int8', + 'weight_only_int4' and 'llm.int8', default: 'weight_only_int8'. + + Returns: + out (Tensor): The Tensor which is the quantitative results, the data type is int8, the shape is transposition of x. + scale (Tensor): The scale Tensor which is the scale of pre-channel, the data type is float32. + Examples: + .. code-block:: python + + >>> # doctest: +SKIP('No testing required') + >>> import paddle + >>> from paddle.nn.quant import weight_quantize + + >>> paddle.seed(2023) + >>> x = paddle.rand(shape=[64, 32], dtype=paddle.float16) + >>> out, scale = weight_quantize(x, algo='weight_only_int8') + >>> print(out.shape) + [32, 64] + >>> print(scale.shape) + [32] + """ + + if in_dynamic_mode(): + return _C_ops.weight_quantize(x, algo) + else: + type = "weight_quantize" + helper = LayerHelper(type, **locals()) + out = helper.create_variable_for_type_inference('int8') + scale = helper.create_variable_for_type_inference('float') + + helper.append_op( + type=type, + inputs={"x": x}, + outputs={'out': out, "scale": scale}, + attrs={"algo": algo}, + ) + return (out, scale) + + +def weight_dequantize(x, scale, algo="weight_only_int8", out_dtype='float16'): + """ + Dequantization function for weight_only and llm.int8's weight. + + Args: + x (Tensor): The input Tensor to be dequantized, the data type is int8. + scale (Tensor): The scale Tensor which is the output of weight_quantize, the data type is float32. + algo (str): The algo that is x will be apply, must be one of 'weight_only_int8', + 'weight_only_int4' and 'llm.int8', default: 'weight_only_int8'. + out_dtype (str|np.dtype): The output Tensor's data type, must be one of 'float16' and 'bfloat16', default: 'float16'. + + Returns: + out (Tensor): The Tensor which is the dequantitative results, the data type is float16 or bfloat16, the shape is transposition of x. + + Examples: + .. code-block:: python + + >>> # doctest: +SKIP('No testing required') + >>> import paddle + >>> from paddle.nn.quant import weight_quantize, weight_dequantize + + >>> paddle.seed(2023) + >>> x = paddle.rand(shape=[64, 32], dtype=paddle.float16) + >>> out, scale = weight_quantize(x, algo='weight_only_int8') + >>> x_dequant = weight_dequantize(out, scale) + """ + check_dtype( + out_dtype, 'out_dtype', ['float16', 'bfloat16'], 'weight_dequantize' + ) + out_dtype = convert_np_dtype_to_dtype_(out_dtype) + if in_dynamic_mode(): + return _C_ops.weight_dequantize(x, scale, algo, out_dtype) + else: + type = "weight_dequantize" + helper = LayerHelper(type, **locals()) + out = helper.create_variable_for_type_inference(out_dtype) + + helper.append_op( + type=type, + inputs={"x": x, "scale": scale}, + outputs={'out': out}, + attrs={"algo": algo, "out_dtype": out_dtype}, + ) + return out + + +def weight_only_linear( + x, + weight, + bias=None, + weight_scale=None, + weight_dtype="int8", +): + """ + Applies matrix multiplication of two tensors and then bias addition if provided. + This method requires CUDA version >= 11.2. + + Args: + x (Tensor): The first input Tensor to be multiplied, the data type is float16 or bfloat16. + weight (Tensor): The second input Tensor to be multiplied. Its rank must be 2. + bias (Tensor|None): The input bias Tensor. If it is None, no bias addition would + be performed. Otherwise, The bias is added to the matrix multiplication result. + weight_scale (Tensor|None): The input scale Tensor Provided to weight for dequantization. Its rank must be 1. + weight_dtype(str): The dtype of weight Tensor, must be one of 'int8', 'int4', Defaulted to 'int8'. + Returns: + Tensor: the output Tensor, the data type is the same as that of x. + + Examples: + .. code-block:: python + + >>> # doctest: +SKIP('No testing required') + >>> import paddle + >>> from paddle.nn.quant import weight_only_linear + + >>> x = paddle.cast(paddle.randn([1, 2, 64]), dtype='float16') + >>> weight = paddle.cast(paddle.randint(0, 127, [32, 64]), dtype='int8') + >>> scale = paddle.randn([32], dtype='float32') + >>> bias = paddle.cast(paddle.randn([32]), dtype='float16') + >>> if paddle.device.cuda.get_device_capability()[0] >= 8: + ... out = weight_only_linear(x, weight, bias=bias, weight_scale=scale, weight_dtype='int8') + ... print(out.shape) + [1, 2, 32] + """ + if in_dynamic_mode(): + out = _C_ops.weight_only_linear( + x, weight, bias, weight_scale, weight_dtype + ) + return out + else: + check_dtype( + weight_dtype, 'weight_dtype', ['int8', 'int4'], 'weight_only_linear' + ) + type = "weight_only_linear" + helper = LayerHelper(type, **locals()) + dtype = x.dtype + + inputs = { + 'x': [x], + 'weight': [weight], + 'weight_scale': [weight_scale], + } + if bias is not None: + inputs["bias"] = [bias] + attrs = {'weight_dtype': weight_dtype} + + out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=type, + inputs=inputs, + outputs={'out': out}, + attrs=attrs, + ) + return out + + +def llm_int8_linear( + x, + weight, + bias=None, + weight_scale=None, + threshold=6.0, +): + """ + Applies matrix multiplication of two tensors and then bias addition if provided. + This method requires CUDA version >= 11.2. + + Args: + x (Tensor): the first input Tensor to be multiplied, the data type is float16 or bfloat16. + weight (Tensor): the second input Tensor to be multiplied. Its rank must be 2. + bias (Tensor|None): the input bias Tensor. If it is None, no bias addition would + be performed. Otherwise, the bias is added to the matrix multiplication result. + weight_scale (Tensor|None): the input scale Tensor Provided to weight for dequantization. Its rank must be 1. + threshold(float): The min value of outlier in activation, outlier's channel will be apply multiply with x.dtype. + + Returns: + Tensor: the output Tensor, the data type is the same as that of x. + + Examples: + .. code-block:: python + + >>> # doctest: +SKIP('No testing required') + >>> import paddle + >>> from paddle.nn.quant import llm_int8_linear + + >>> x = paddle.cast(paddle.randn([1, 2, 64]), dtype='float16') + >>> weight = paddle.cast(paddle.randint(0, 127, [32, 64]), dtype='int8') + >>> scale = paddle.randn([32], dtype='float32') + >>> bias = paddle.cast(paddle.randn([32]), dtype='float16') + >>> if paddle.device.cuda.get_device_capability()[0] >= 8: + ... out = llm_int8_linear(x, weight, bias=bias, weight_scale=scale, threshold=6.0) + ... print(out.shape) + [1, 2, 32] + """ + if in_dynamic_mode(): + out = _C_ops.llm_int8_linear(x, weight, bias, weight_scale, threshold) + return out + else: + type = "llm_int8_linear" + helper = LayerHelper(type, **locals()) + dtype = x.dtype + + inputs = { + 'x': [x], + 'weight': [weight], + 'weight_scale': [weight_scale], + } + if bias: + inputs["bias"] = [bias] + attrs = {'threshold': threshold} + + out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=type, + inputs=inputs, + outputs={'out': out}, + attrs=attrs, + ) + return out diff --git a/python/paddle/nn/quant/stub.py b/python/paddle/nn/quant/stub.py new file mode 100644 index 0000000000000..b7be50d8d8b2a --- /dev/null +++ b/python/paddle/nn/quant/stub.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Define stub used in quantization.""" + +#from ..layer.layers import Layer +from paddle.fluid.dygraph.layers import Layer + + +class Stub(Layer): + r""" + The stub is used as placeholders that will be replaced by observers before PTQ or QAT. + It is hard to assign a quantization configuration to a functional API called in + the forward of a layer. Instead, we can create a stub and add it to the sublayers of the layer. + And call the stub before the functional API in the forward. The observer held by the + stub will observe or quantize the inputs of the functional API. + + Args: + observer(QuanterFactory) - The configured information of the observer to be inserted. + It will use a global configuration to create the observers if the 'observer' is none. + + Examples: + .. code-block:: python + + >>> import paddle + >>> from paddle.nn.quant import Stub + >>> from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver + >>> from paddle.nn import Conv2D + >>> from paddle.quantization import QAT, QuantConfig + + >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9) + >>> class Model(paddle.nn.Layer): + ... def __init__(self, num_classes=10): + ... super().__init__() + ... self.conv = Conv2D(3, 6, 3, stride=1, padding=1) + ... self.quant = Stub(quanter) + ... + ... def forward(self, inputs): + ... out = self.conv(inputs) + ... out = self.quant(out) + ... return paddle.nn.functional.relu(out) + + >>> model = Model() + >>> q_config = QuantConfig(activation=quanter, weight=quanter) + >>> qat = QAT(q_config) + >>> quant_model = qat.quantize(model) + >>> print(quant_model) + Model( + (conv): QuantedConv2D( + (weight_quanter): FakeQuanterWithAbsMaxObserverLayer() + (activation_quanter): FakeQuanterWithAbsMaxObserverLayer() + ) + (quant): QuanterStub( + (_observer): FakeQuanterWithAbsMaxObserverLayer() + ) + ) + """ + + def __init__(self, observer=None): + super().__init__() + self._observer = observer + + def forward(self, input): + return input + + +class QuanterStub(Layer): + r""" + It is an identity layer with an observer observing the input. + Before QAT or PTQ, the stub in the model will be replaced with an instance of QuanterStub. + The user should not use this class directly. + + Args: + layer(paddle.nn.Layer) - The stub layer with an observer configure factory. If the observer + of the stub layer is none, it will use 'q_config' to create an observer instance. + q_config(QuantConfig) - The quantization configuration for the current stub layer. + """ + + def __init__(self, layer: Stub, q_config): + super().__init__() + self._observer = None + if layer._observer is not None: + self._observer = layer._observer._instance(layer) + elif q_config.activation is not None: + self._observer = q_config.activation._instance(layer) + + def forward(self, input): + return self._observer(input) if self._observer is not None else input diff --git a/python/setup.py.in b/python/setup.py.in index 92ff4a8e5407e..e50e305296fbb 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -397,6 +397,7 @@ packages=['paddle', 'paddle.nn.functional', 'paddle.nn.layer', 'paddle.nn.quant', + 'paddle.nn.quant.qat', 'paddle.nn.initializer', 'paddle.nn.utils', 'paddle.metric',