diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 7de59dd9ee2e3..863ab8cba964b 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -115,6 +115,10 @@ class FMHARef { // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor); out_seq_len = cache_kv_out_tensor->dims()[3]; + } else { + if (cache_kv_out_tensor) { + *cache_kv_out_tensor = transpose_2_out_tensor->Slice(1, 3); + } } int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 10e617bd91243..98c09f81574ec 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -208,3 +208,12 @@ kernel : func : flip backward : flip_grad + +- op : beam_search_softmax + args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop) + output : Tensor(ids_this_time), Tensor(out_cum_scores), Tensor(cache_ids), Tensor(beam_offsets), Tensor(parent_idx), Tensor(stop_flags_out), Tensor(seq_lens_out), Tensor(step_ids_out) + infer_meta : + func : BeamSearchSoftmaxInferMeta + kernel : + func : beam_search_softmax + data_type : logits diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 375b88493a92b..7d8ef2a39379d 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -674,6 +674,57 @@ void BatchNormInferInferMeta(const MetaTensor& x, config); } +void BeamSearchSoftmaxInferMeta(const MetaTensor& logits, + const MetaTensor& cum_scores, + const MetaTensor& sequence_lengths, + const MetaTensor& stop_flags, + const MetaTensor& end_ids, + const MetaTensor& step_ids, + const MetaTensor& last_cache_ids, + const MetaTensor& last_beam_offsets, + int beam_size, + int max_seq_len, + int max_dec_len, + bool fuse_softmax, + bool early_stop, + MetaTensor* ids_this_time, + MetaTensor* out_cum_scores, + MetaTensor* cache_ids, + MetaTensor* beam_offsets, + MetaTensor* parent_idx, + MetaTensor* stop_flags_out, + MetaTensor* seq_lens_out, + MetaTensor* step_ids_out) { + auto logits_dims = logits.dims(); + auto logits_dtype = logits.dtype(); + int bbm = logits_dims[0]; + + ids_this_time->set_dims({bbm, 1}); + ids_this_time->share_lod(cum_scores); + ids_this_time->set_dtype(DataType::INT32); + cache_ids->set_dims({bbm, max_dec_len}); + cache_ids->share_lod(cum_scores); + cache_ids->set_dtype(DataType::INT32); + beam_offsets->set_dims({1, bbm, max_seq_len + max_dec_len}); + beam_offsets->share_lod(cum_scores); + beam_offsets->set_dtype(DataType::INT32); + parent_idx->set_dims({bbm}); + parent_idx->share_lod(cum_scores); + parent_idx->set_dtype(DataType::INT32); + out_cum_scores->set_dims({bbm}); + out_cum_scores->share_lod(cum_scores); + out_cum_scores->set_dtype(logits_dtype); + stop_flags_out->set_dims({bbm, 1}); + stop_flags_out->share_lod(cum_scores); + stop_flags_out->set_dtype(DataType::BOOL); + seq_lens_out->set_dims({bbm, 1}); + seq_lens_out->share_lod(cum_scores); + seq_lens_out->set_dtype(DataType::INT32); + step_ids_out->set_dims({bbm, 1}); + step_ids_out->share_lod(cum_scores); + step_ids_out->set_dtype(DataType::INT32); +} + void BilinearTensorProductInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 8c601182e8fc8..6673d66b0125d 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -190,6 +190,28 @@ void BatchNormInferInferMeta(const MetaTensor& x, MetaTensor* variance_out, MetaConfig config = MetaConfig()); +void BeamSearchSoftmaxInferMeta(const MetaTensor& logits, + const MetaTensor& cum_scores, + const MetaTensor& sequence_lengths, + const MetaTensor& stop_flags, + const MetaTensor& end_ids, + const MetaTensor& step_ids, + const MetaTensor& last_cache_ids, + const MetaTensor& last_beam_offsets, + int beam_size, + int max_seq_len, + int max_dec_len, + bool fuse_softmax, + bool early_stop, + MetaTensor* ids_this_time, + MetaTensor* out_cum_scores, + MetaTensor* cache_ids, + MetaTensor* beam_offsets, + MetaTensor* parent_idx, + MetaTensor* stop_flags_out, + MetaTensor* seq_lens_out, + MetaTensor* step_ids_out); + void BilinearTensorProductInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/kernels/fusion/beam_search_softmax.h b/paddle/phi/kernels/fusion/beam_search_softmax.h new file mode 100644 index 0000000000000..985515a0a39f3 --- /dev/null +++ b/paddle/phi/kernels/fusion/beam_search_softmax.h @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace fusion { + +template +void BeamSearchSoftmaxKernel(const Context &dev_ctx, + const DenseTensor &logits, + const DenseTensor &cum_scores, + const DenseTensor &sequence_lengths, + const DenseTensor &stop_flags, + const DenseTensor &end_ids, + const DenseTensor &step_ids, + const DenseTensor &last_cache_ids, + const DenseTensor &last_beam_offsets, + int beam_size, + int max_seq_len, + int max_dec_len, + bool fuse_softmax, + bool early_stop, + DenseTensor *ids_this_time, + DenseTensor *out_cum_scores, + DenseTensor *cache_ids, + DenseTensor *beam_offsets, + DenseTensor *parent_idx, + DenseTensor *stop_flags_out, + DenseTensor *seq_lens_out, + DenseTensor *step_ids_out); + +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu new file mode 100644 index 0000000000000..5652adfd1c50b --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu @@ -0,0 +1,925 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/fusion/beam_search_softmax.h" + +namespace phi { +namespace fusion { + +#define FLT_MAX 1e38 +// #define DEBUG_BEAM_SEARCH_SOFTMAX + +#define CASE_K(K) \ + case K: \ + invokeTopKSoftMaxLauncher(dev_ctx, \ + log_probs, \ + stop_flags, \ + sequence_lengths, \ + cum_log_probs, \ + step_ids, \ + last_cache_ids, \ + last_beam_offsets, \ + end_ids, \ + out_cum_log_probs, \ + stop_flags_out, \ + seq_lens_out, \ + step_ids_out, \ + ids, \ + tmp_ids, \ + tmp_vals, \ + parent_idx, \ + cache_ids, \ + beam_offsets, \ + batch_size, \ + beam_size, \ + vocab_size, \ + max_seq_len, \ + max_dec_len, \ + fuse_softmax, \ + early_stop, \ + stream); \ + break + +struct __align__(8) DySoftMaxStruct { + float logit; + float score; +}; + +__device__ __forceinline__ DySoftMaxStruct +reduce_softmax_op(DySoftMaxStruct a, DySoftMaxStruct b) { + bool a_bigger = (a.logit > b.logit); + DySoftMaxStruct bigger_m = a_bigger ? a : b; + DySoftMaxStruct smaller_m = a_bigger ? b : a; + DySoftMaxStruct res; + res.score = bigger_m.score + + smaller_m.score * exp(smaller_m.logit - bigger_m.logit); + res.logit = bigger_m.logit; + return res; +} + +template +struct TopK { + int ids[K]; + T vals[K]; + int parent_ids[K]; + + __device__ __forceinline__ void insert(T elem, int elem_id) { + if (elem > vals[K - 1] || (ids[K - 1] == -1) || + ((elem == vals[K - 1]) && (elem_id < ids[K - 1]))) { + vals[K - 1] = elem; + ids[K - 1] = elem_id; + } + + for (int k = K - 2; k >= 0; --k) { + if ((vals[k + 1] > vals[k]) || (ids[k] == -1) || + ((vals[k + 1] == vals[k]) && (ids[k + 1] < ids[k]))) { + T tmp_val = vals[k]; + int tmp_id = ids[k]; + vals[k] = vals[k + 1]; + ids[k] = ids[k + 1]; + vals[k + 1] = tmp_val; + ids[k + 1] = tmp_id; + } + } + } + + __device__ __forceinline__ void insert(T elem, int elem_id, int parent_id) { + if (elem > vals[K - 1] || (ids[K - 1] == -1) || + ((elem == vals[K - 1]) && (elem_id < ids[K - 1]))) { + vals[K - 1] = elem; + ids[K - 1] = elem_id; + parent_ids[K - 1] = parent_id; + } + + for (int k = K - 2; k >= 0; --k) { + if ((vals[k + 1] > vals[k]) || (ids[k] == -1) || + ((vals[k + 1] == vals[k]) && (ids[k + 1] < ids[k]))) { + T tmp_val = vals[k]; + int tmp_id = ids[k]; + int parent_id2 = parent_ids[k]; + vals[k] = vals[k + 1]; + ids[k] = ids[k + 1]; + parent_ids[k] = parent_ids[k + 1]; + vals[k + 1] = tmp_val; + ids[k + 1] = tmp_id; + parent_ids[k + 1] = parent_id2; + } + } + } +}; + +template +__device__ __forceinline__ TopK reduce_topk_op(const TopK &a, + const TopK &b) { + TopK res = a; + for (int i = 0; i < K; ++i) res.insert(b.vals[i], b.ids[i]); + return res; +} + +template +struct TopKSoftMax { + DySoftMaxStruct softmax_md; + TopK topk; +}; + +template +__device__ __forceinline__ TopKSoftMax reduce_topk_softmax_op( + const TopKSoftMax &a, const TopKSoftMax &b) { + TopKSoftMax res; + res.softmax_md = reduce_softmax_op(a.softmax_md, b.softmax_md); + res.topk = reduce_topk_op(a.topk, b.topk); + return res; +} + +template +__global__ void batch_topk(const int *topk_tmp_id_buf, + const T *topk_tmp_val_buf, + const int *step_ids, + const bool *stop_flags, // bs * beam_size + const int *seq_lens, + const int *end_ids, + int *id_buf, + T *val_buf, + int *parent_idx, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out) { + int thread_id = threadIdx.x; + int block_id = blockIdx.x; // bs + const int beam_size = K / 2; + TopK partial; + if (thread_id == 0) { + for (int i = 0; i < beam_size; ++i) { + partial.ids[i] = -1; + partial.vals[i] = -FLT_MAX; + partial.parent_ids[i] = -1; + } + + int index = block_id * beam_size * K; + if (step_ids[0] == 0) { + for (int i = 0; i < K; i++) { + partial.insert( + (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i], i / K); + } + } else { + for (int i = 0; i < beam_size * K; i++) { + partial.insert( + (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i], i / K); + } + } + index = block_id * beam_size; + for (int i = 0; i < beam_size; i++) { + id_buf[index + i] = partial.ids[i]; + val_buf[index + i] = partial.vals[i]; + int parent_id = partial.parent_ids[i]; + parent_idx[index + i] = parent_id; + stop_flags_out[index + i] = stop_flags[index + parent_id]; + seq_lens_out[index + i] = seq_lens[index + parent_id]; + step_ids_out[index + i] = step_ids[index + parent_id]; +#ifdef DEBUG_BEAM_SEARCH_SOFTMAX + printf("bi: %d, id: %d, val: %f, parent_id: %d\n", block_id, + id_buf[index+i], val_buf[index+i], parent_id); +#endif + } + } +} + +template +__global__ void batch_topk(const int *topk_tmp_id_buf, + const T *topk_tmp_val_buf, + const float *cum_log_probs, + const int *step_ids, + const bool *stop_flags, // bs * beam_size + const int *seq_lens, + const int *end_ids, + int *id_buf, + T *val_buf, + int *parent_idx, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out) { + int thread_id = threadIdx.x; + int block_id = blockIdx.x; // bs + const int beam_size = K / 2; + TopK partial; + if (thread_id == 0) { + for (int i = 0; i < beam_size; ++i) { + partial.ids[i] = -1; + partial.vals[i] = -FLT_MAX; + partial.parent_ids[i] = -1; + } + + int index = block_id * beam_size * K; + if (step_ids[0] == 0) { + for (int i = 0; i < K; i++) { + partial.insert( + (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i], i / K); + } + } else { + for (int i = 0; i < beam_size * K; i++) { + if (!stop_flags[block_id * beam_size + i / K]) { + // if stop, this branch end, no longer update. + partial.insert( + (T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i], i / K); + } + } + } + index = block_id * beam_size; + int stop_num = 0; + for (int i = 0; i < beam_size; i++) { + if (stop_flags[index + i]) { + parent_idx[index + i] = i; + id_buf[index + i] = end_ids[0]; + val_buf[index + i] = cum_log_probs[index + i]; + stop_flags_out[index + i] = stop_flags[index + i]; + seq_lens_out[index + i] = seq_lens[index + i]; + step_ids_out[index + i] = step_ids_out[index + i]; + stop_num++; +#ifdef DEBUG_BEAM_SEARCH_SOFTMAX + printf("%d has end, bi: %d, stop_num: %d\n", index + i, block_id, stop_num); +#endif + } else { + int parent_id = partial.parent_ids[i - stop_num]; + parent_idx[index + i] = parent_id; + id_buf[index + i] = partial.ids[i - stop_num]; + val_buf[index + i] = partial.vals[i - stop_num]; + stop_flags_out[index + i] = stop_flags[index + parent_id]; + seq_lens_out[index + i] = seq_lens[index + parent_id]; + step_ids_out[index + i] = step_ids[index + parent_id]; +#ifdef DEBUG_BEAM_SEARCH_SOFTMAX + printf("bi: %d, id: %d, val: %f, parent_id: %d\n", block_id, + id_buf[index+i], val_buf[index+i], parent_id); +#endif + } + } + } +} + +template +__global__ void beam_search_softmax_topk_stage1(const T *logits, + const bool *stop_flags, + const int *end_ids, + float *tmp_buffer, + const int vocab_size, + const bool fuse_softmax) { + int thread_id = threadIdx.x; + int vector_id = blockIdx.x; // batch beam index. + + __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; + + const T MAX_T_VAL = FLT_MAX; + + const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y; + const int section_start = v_local * blockIdx.y; + int section_end = section_start + v_local; + section_end = (section_end > vocab_size) ? vocab_size : section_end; + + logits += vector_id * vocab_size; + if (fuse_softmax) { + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + TopKSoftMax partial; + bool finish = stop_flags[vector_id]; + for (int i = 0; i < K; ++i) { + partial.topk.ids[i] = -1; + partial.topk.vals[i] = -MAX_T_VAL; + } + partial.softmax_md.logit = -MAX_T_VAL; + partial.softmax_md.score = 0.0F; + + if (finish) { +#pragma unroll 1 + for (int elem_id = section_start + thread_id; elem_id < section_end; + elem_id += THREADBLOCK_SIZE) { + // if is_end, set to (MAX_T_VAL, 1) + T elem = (elem_id == end_ids[0]) ? MAX_T_VAL : -MAX_T_VAL; + DySoftMaxStruct new_elem{elem, 1.0F}; + partial.softmax_md = reduce_softmax_op(partial.softmax_md, new_elem); + partial.topk.insert(elem, elem_id); + } + } else { +#pragma unroll 1 + for (int elem_id = section_start + thread_id; elem_id < section_end; + elem_id += THREADBLOCK_SIZE) { + T elem = logits[elem_id]; + DySoftMaxStruct new_elem{elem, 1.0F}; + partial.softmax_md = reduce_softmax_op(partial.softmax_md, new_elem); + partial.topk.insert(elem, elem_id); + } + } + + TopKSoftMax total = + BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op); + + if (thread_id == 0) { + for (int i = 0; i < K; i++) { + reinterpret_cast(buf_s)[i] = total.topk.ids[i]; + buf_s[K + i] = total.topk.vals[i]; + } + buf_s[2 * K] = total.softmax_md.score; + buf_s[2 * K + 1] = total.softmax_md.logit; + } + } else { + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + TopK partial; + bool finish = stop_flags[vector_id]; + for (int i = 0; i < K; ++i) { + partial.ids[i] = -1; + partial.vals[i] = -MAX_T_VAL; + } + + if (finish) { +#pragma unroll 1 + for (int elem_id = section_start + thread_id; elem_id < section_end; + elem_id += THREADBLOCK_SIZE) { + // if is_end, set to (end_id, 1) + T elem = (elem_id == end_ids[0]) ? 0 : -MAX_T_VAL; + partial.insert(elem, elem_id); + } + } else { +#pragma unroll 1 + for (int elem_id = section_start + thread_id; elem_id < section_end; + elem_id += THREADBLOCK_SIZE) { + T elem = logits[elem_id]; + partial.insert(elem, elem_id); + } + } + + TopK total = + BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (thread_id == 0) { + for (int i = 0; i < K; i++) { + reinterpret_cast(buf_s)[i] = total.ids[i]; + buf_s[K + i] = total.vals[i]; + } + } + } + __syncthreads(); + for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; + elem_id += THREADBLOCK_SIZE) { + tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id]; + } +} + +template +__global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer, + const float *cum_log_probs, + int *tmp_ids, + T *tmp_vals, + const int voc_parts, + const int packed_top_kmd_size, + const bool fuse_softmax) { + const int vector_id = blockIdx.x; + const int thread_id = threadIdx.x; + const int PACKED_TOP_KMD_SIZE = packed_top_kmd_size; + + const T MAX_T_VAL = FLT_MAX; + + extern __shared__ char buf_s_[]; + float *buf_s = reinterpret_cast(buf_s_); + tmp_buffer += vector_id * PACKED_TOP_KMD_SIZE * voc_parts; + + if (fuse_softmax) { + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + TopKSoftMax partial; + for (int i = 0; i < K; ++i) { + partial.topk.ids[i] = -1; + partial.topk.vals[i] = -MAX_T_VAL; + } + partial.softmax_md.logit = -MAX_T_VAL; + partial.softmax_md.score = 0.0F; + + for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts; + idx += THREADBLOCK_SIZE) { + buf_s[idx] = tmp_buffer[idx]; + } + __syncthreads(); + + if (threadIdx.x < voc_parts) { + float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE; + for (int i = 0; i < K; i++) { + partial.topk.ids[i] = reinterpret_cast(b_s)[i]; + partial.topk.vals[i] = b_s[K + i]; + } + partial.softmax_md.score = b_s[2 * K]; + partial.softmax_md.logit = b_s[2 * K + 1]; + } + __syncthreads(); + + TopKSoftMax total = + BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op); + + if (thread_id == 0) { + tmp_ids += vector_id * K; + tmp_vals += vector_id * K; + cum_log_probs += vector_id; + + float d_total_log = log(total.softmax_md.score); + for (int i = 0; i < K; ++i) { + // float val = expf((float)total.topk.vals[i] - total.softmax_md.logit - d_total_log); + float val = total.topk.vals[i] - total.softmax_md.logit - d_total_log; + tmp_ids[i] = total.topk.ids[i]; + tmp_vals[i] = val + cum_log_probs[0]; +#ifdef DEBUG_BEAM_SEARCH_SOFTMAX + printf("vector_id: %d, vals: %f, logit: %f, d_total_log: %f, id: %d, val: %f, cum_log_probs: %f, res: %f\n", vector_id, total.topk.vals[i], total.softmax_md.logit, d_total_log, tmp_ids[i], val, cum_log_probs[0], tmp_vals[i]); +#endif + } + } + } else { + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + TopK partial; + for (int i = 0; i < K; ++i) { + partial.ids[i] = -1; + partial.vals[i] = -MAX_T_VAL; + } + + for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts; + idx += THREADBLOCK_SIZE) { + buf_s[idx] = tmp_buffer[idx]; + } + __syncthreads(); + + if (threadIdx.x < voc_parts) { + float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE; + for (int i = 0; i < K; i++) { + partial.ids[i] = reinterpret_cast(b_s)[i]; + partial.vals[i] = b_s[K + i]; + } + } + __syncthreads(); + + TopK total = + BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (thread_id == 0) { + tmp_ids += vector_id * K; + tmp_vals += vector_id * K; + cum_log_probs += vector_id; + + for (int i = 0; i < K; ++i) { + float val = total.vals[i]; + tmp_ids[i] = total.ids[i]; + tmp_vals[i] = val + cum_log_probs[0]; + } + } + } +} + +template +void invokeBeamSearchSoftmaxTopKStage2(const float *tmp_buffer, + const float *cum_log_probs, + int *ids, + T *vals, + const int batch_size, + const int beam_size, + const int voc_parts, + const int packed_top_kmd_size, + const bool fuse_softmax, + cudaStream_t stream) { + int smem_stage2_size = voc_parts * packed_top_kmd_size * sizeof(float); + + if (voc_parts <= 32) { + beam_search_softmax_topk_stage2 + <<>>( + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + return; + } + if (voc_parts <= 64) { + beam_search_softmax_topk_stage2 + <<>>( + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + return; + } + if (voc_parts <= 128) { + beam_search_softmax_topk_stage2 + <<>>( + tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax); + return; + } +} + +__global__ void update_beam_offsets_kernel( + const int *src_indir_cache, // bs * bm * max_len + const int *beam_ids, // bs * bm + const int *sequence_lengths, // bs * bm + const bool *stop_flags, + const int *step_ids, + int *tgt_indir_cache, + const int batch_size, + const int beam_size, + const int max_seq_len, + const int max_dec_len) { + int time_step = threadIdx.x + blockIdx.x * blockDim.x; + int bb_id = blockIdx.y; + const int batch_id = bb_id / beam_size; + const int beam_id = bb_id % beam_size; + const int src_beam = beam_ids[bb_id]; + const int src_bb_id = batch_id * beam_size + src_beam; + const int seq_len = sequence_lengths[src_bb_id]; + const int max_len = max_seq_len + max_dec_len; + + + if (seq_len == 0 || time_step >= min(seq_len + 1, max_len)) { + return; + } + // if (time_step >= max_len) { + // return; + // } + + if (bb_id >= beam_size * batch_size) { + return; + } + + const uint tgt_offset = + batch_id * beam_size * max_len + beam_id * max_len + time_step; + const uint src_offset = + batch_id * beam_size * max_len + src_beam * max_len + time_step; + + tgt_indir_cache[tgt_offset] = (time_step == sequence_lengths[src_bb_id]) + ? src_beam + : src_indir_cache[src_offset]; +} + +void invokeUpdateBeamOffset(const int *src_indir_cache, + const int *beam_ids, + const int *sequence_lengths, + const bool *stop_flags, + const int *step_ids, + int *tgt_indir_cache, + const int batch_size, + const int beam_size, + const int max_seq_len, + const int max_dec_len, + cudaStream_t stream) { + const dim3 block(32); + const dim3 grid((max_seq_len + max_dec_len + block.x - 1) / block.x, + batch_size * beam_size); + update_beam_offsets_kernel<<>>(src_indir_cache, + beam_ids, + sequence_lengths, + stop_flags, + step_ids, + tgt_indir_cache, + batch_size, + beam_size, + max_seq_len, + max_dec_len); +} + +__global__ void update_cache_ids_kernel( + const int *last_cache_ids, // bs * bm * max_dec_len + const int *beam_ids, // bs * bm + const int *ids_this_time, + const int *sequence_lengths, // bs * bm + const bool *stop_flags, + const int *step_ids, + int *cache_ids, + const int batch_size, + const int beam_size, + const int max_dec_len) { + int time_step = threadIdx.x + blockIdx.x * blockDim.x; + int bb_id = blockIdx.y; + const int batch_id = bb_id / beam_size; + const int beam_id = bb_id % beam_size; + const int src_beam = beam_ids[bb_id]; + const int src_bb_id = batch_id * beam_size + src_beam; + const int step = step_ids[src_bb_id]; + + if (sequence_lengths[src_bb_id] == 0 || time_step >= min(step + 1, max_dec_len)) { + return; + } + + if (bb_id >= beam_size * batch_size) { + return; + } + + const uint tgt_offset = + batch_id * beam_size * max_dec_len + beam_id * max_dec_len + time_step; + const uint src_offset = + batch_id * beam_size * max_dec_len + src_beam * max_dec_len + time_step; + + cache_ids[tgt_offset] = + (time_step == step) ? ids_this_time[bb_id] : last_cache_ids[src_offset]; +} + +void invokeUpdateCacheIds(const int *last_cache_ids, + const int *beam_ids, + const int *sequence_lengths, + const int *ids_this_time, + const bool *stop_flags, + const int *step_ids, + int *cache_ids, + const int batch_size, + const int beam_size, + const int max_dec_len, + cudaStream_t stream) { + const dim3 block(32); + const dim3 grid((max_dec_len + block.x - 1) / block.x, + batch_size * beam_size); + update_cache_ids_kernel<<>>(last_cache_ids, + beam_ids, + ids_this_time, + sequence_lengths, + stop_flags, + step_ids, + cache_ids, + batch_size, + beam_size, + max_dec_len); +} + +template +void invokeTopKSoftMaxLauncher(const Context &dev_ctx, + const T *log_probs, + const bool *stop_flags, + const int *sequence_lengths, + const float *cum_log_probs, + const int *step_ids, + const int *last_cache_ids, + const int *last_beam_offsets, + const int *end_ids, + float *out_cum_log_probs, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out, + int *ids, + int *tmp_ids, + T *tmp_vals, + int *parent_idx, + int *cache_ids, + int *beam_offsets, + const int batch_size, + const int beam_size, + const int vocab_size, + const int max_seq_len, + const int max_dec_len, + const bool fuse_softmax, + const bool early_stop, + cudaStream_t stream) { + // K = 2 * beam_size + const int block_size = 128; + int voc_parts = vocab_size / 1024; + voc_parts = std::min(128, voc_parts); + int packed_top_kmd_size = 2 * K; + if (fuse_softmax) { + packed_top_kmd_size += 2; + } + const int tmp_buffer_size = + batch_size * beam_size * voc_parts * packed_top_kmd_size; + DenseTensor tmp_buffer_tensor; + tmp_buffer_tensor.Resize(phi::make_ddim({tmp_buffer_size})); + dev_ctx.template Alloc(&tmp_buffer_tensor); + float *tmp_buffer = tmp_buffer_tensor.data(); + + dim3 grid(batch_size * beam_size, voc_parts); + if (fuse_softmax) { + cudaFuncSetAttribute(beam_search_softmax_topk_stage1, + cudaFuncAttributePreferredSharedMemoryCarveout, + cudaSharedmemCarveoutMaxL1); + // (bs, bm, voc_parts, 2 * K + 2) + beam_search_softmax_topk_stage1 + <<>>( + log_probs, stop_flags, end_ids, tmp_buffer, vocab_size, fuse_softmax); + } else { + cudaFuncSetAttribute(beam_search_softmax_topk_stage1, + cudaFuncAttributePreferredSharedMemoryCarveout, + cudaSharedmemCarveoutMaxL1); + // (bs, bm, voc_parts, 2 * K) + beam_search_softmax_topk_stage1 + <<>>( + log_probs, stop_flags, end_ids, tmp_buffer, vocab_size, fuse_softmax); + } + // (bs, bm, K) + invokeBeamSearchSoftmaxTopKStage2(tmp_buffer, + cum_log_probs, + tmp_ids, + tmp_vals, + batch_size, + beam_size, + voc_parts, + packed_top_kmd_size, + fuse_softmax, + stream); + // (bs, bm) + if (early_stop) { + batch_topk<<>>( + tmp_ids, + tmp_vals, + cum_log_probs, + step_ids, + stop_flags, + sequence_lengths, + end_ids, + ids, + out_cum_log_probs, + parent_idx, + stop_flags_out, + seq_lens_out, + step_ids_out); + } else { + batch_topk<<>>( + tmp_ids, + tmp_vals, + step_ids, + stop_flags, + sequence_lengths, + end_ids, + ids, + out_cum_log_probs, + parent_idx, + stop_flags_out, + seq_lens_out, + step_ids_out); + } + invokeUpdateBeamOffset(last_beam_offsets, + parent_idx, + sequence_lengths, + stop_flags, + step_ids, + beam_offsets, + batch_size, + beam_size, + max_seq_len, + max_dec_len, + stream); + invokeUpdateCacheIds(last_cache_ids, + parent_idx, + sequence_lengths, + ids, + stop_flags, + step_ids, + cache_ids, + batch_size, + beam_size, + max_dec_len, + stream); +} + +template +void invokeTopkSoftMax(const Context &dev_ctx, + const T *log_probs, + const bool *stop_flags, + const int *sequence_lengths, + const float *cum_log_probs, + const int *step_ids, + const int *last_cache_ids, + const int *last_beam_offsets, + const int *end_ids, + float *out_cum_log_probs, + bool *stop_flags_out, + int *seq_lens_out, + int *step_ids_out, + int *ids, + int *tmp_ids, + T *tmp_vals, + int *parent_idx, + int *cache_ids, + int *beam_offsets, + const int batch_size, + const int beam_size, + const int vocab_size, + const int max_seq_len, + const int max_dec_len, + const bool fuse_softmax, + const bool early_stop, + cudaStream_t stream) { + switch (beam_size) { + CASE_K(1); + CASE_K(2); + CASE_K(3); + CASE_K(4); + CASE_K(5); + CASE_K(6); + CASE_K(7); + CASE_K(8); + CASE_K(9); + CASE_K(10); + CASE_K(11); + CASE_K(12); + CASE_K(13); + CASE_K(14); + CASE_K(15); + CASE_K(16); + default: + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "beam_size = %d is unsupport!", beam_size)); + } +} + +template +void BeamSearchSoftmaxKernel(const Context &dev_ctx, + const DenseTensor &logits, + const DenseTensor &cum_scores, + const DenseTensor &sequence_lengths, + const DenseTensor &stop_flags, + const DenseTensor &end_ids, + const DenseTensor &step_ids, + const DenseTensor &last_cache_ids, + const DenseTensor &last_beam_offsets, + int beam_size, + int max_seq_len, + int max_dec_len, + bool fuse_softmax, + bool early_stop, + DenseTensor *ids_this_time, + DenseTensor *out_cum_scores, + DenseTensor *cache_ids, + DenseTensor *beam_offsets, + DenseTensor *parent_idx, + DenseTensor *stop_flags_out, + DenseTensor *seq_lens_out, + DenseTensor *step_ids_out) { + const auto &logits_dims = logits.dims(); + int bs = logits_dims[0]; + int batch_size = bs / beam_size; + int vocab_size = logits_dims[1]; + + dev_ctx.template Alloc(ids_this_time); + dev_ctx.template Alloc(cache_ids); + dev_ctx.template Alloc(beam_offsets); + dev_ctx.template Alloc(parent_idx); + dev_ctx.template Alloc(out_cum_scores); + dev_ctx.template Alloc(stop_flags_out); + dev_ctx.template Alloc(seq_lens_out); + dev_ctx.template Alloc(step_ids_out); + + phi::Copy(dev_ctx, last_cache_ids, dev_ctx.GetPlace(), false, cache_ids); + phi::Copy( + dev_ctx, last_beam_offsets, dev_ctx.GetPlace(), false, beam_offsets); + phi::Copy( + dev_ctx, stop_flags, dev_ctx.GetPlace(), false, stop_flags_out); + phi::Copy( + dev_ctx, sequence_lengths, dev_ctx.GetPlace(), false, seq_lens_out); + phi::Copy( + dev_ctx, step_ids, dev_ctx.GetPlace(), false, step_ids_out); + + const int tmp_size = batch_size * beam_size * beam_size * 2; + DenseTensor tmp_topk_id, tmp_topk_val; + tmp_topk_id.Resize(phi::make_ddim({tmp_size})); + dev_ctx.template Alloc(&tmp_topk_id); + tmp_topk_val.Resize(phi::make_ddim({tmp_size})); + dev_ctx.template Alloc(&tmp_topk_val); + + invokeTopkSoftMax(dev_ctx, + logits.data(), + stop_flags.data(), + sequence_lengths.data(), + cum_scores.data(), + step_ids.data(), + last_cache_ids.data(), + last_beam_offsets.data(), + end_ids.data(), + out_cum_scores->data(), + stop_flags_out->data(), + seq_lens_out->data(), + step_ids_out->data(), + ids_this_time->data(), + tmp_topk_id.data(), + tmp_topk_val.data(), + parent_idx->data(), + cache_ids->data(), + beam_offsets->data(), + batch_size, + beam_size, + vocab_size, + max_seq_len, + max_dec_len, + fuse_softmax, + early_stop, + dev_ctx.stream()); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(beam_search_softmax, + GPU, + ALL_LAYOUT, + phi::fusion::BeamSearchSoftmaxKernel, + float) {} // only supports float diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index c8286c09b10fa..018a59ebc5261 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -312,6 +312,7 @@ from .tensor.search import sort # noqa: F401 from .tensor.search import kthvalue # noqa: F401 from .tensor.search import mode # noqa: F401 +from .tensor.search import beam_search_softmax # noqa: F401 from .tensor.to_string import set_printoptions # noqa: F401 @@ -667,4 +668,5 @@ 'sgn', 'triu_indices', 'take', + 'beam_search_softmax', ] diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 87df808213656..decaf45125750 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -186,10 +186,13 @@ def pure_fp16_initialize(models): if (layer._dtype == 'float16') or isinstance( layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D, paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D, - paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm)): + paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm, + paddle.nn.ParameterList)): + # tianyan01 add paddle.nn.ParameterList, hack continue if isinstance(layer, (paddle.incubate.nn.FusedFeedForward, - paddle.incubate.nn.FusedMultiHeadAttention)): + paddle.incubate.nn.FusedMultiHeadAttention, + paddle.incubate.nn.FusedMultiTransformer)): layer._amp_decorate(dtype='float16') continue layer._to_impl(dtype='float16', diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 02b844751a889..a39f1cb94c0c5 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -40,6 +40,7 @@ def fused_feedforward( ln2_bias=None, dropout1_rate=0.5, dropout2_rate=0.5, + seed=None, activation="relu", ln1_epsilon=1e-5, ln2_epsilon=1e-5, @@ -123,7 +124,7 @@ def fused_feedforward( _verify_dropout_rate(dropout1_rate) _verify_dropout_rate(dropout2_rate) - seed = None + # seed = None if mode not in ('downscale_in_infer', 'upscale_in_train'): raise ValueError( "mode argument should be 'downscale_in_infer' or 'upscale_in_train'" @@ -133,8 +134,8 @@ def fused_feedforward( ) # semantic transfer if _non_static_mode(): - if default_main_program().random_seed != 0: - seed = default_main_program().random_seed + # if default_main_program().random_seed != 0: + # seed = default_main_program().random_seed out, _, _, _, _, _, _, _, _, _, _ = _legacy_C_ops.fused_feedforward( x, None, @@ -221,8 +222,8 @@ def fused_feedforward( x.dtype, stop_gradient=True ) - if (seed is None or seed == 0) and helper.main_program.random_seed != 0: - seed = helper.main_program.random_seed + # if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + # seed = helper.main_program.random_seed helper.append_op( type='fused_feedforward', @@ -477,6 +478,8 @@ def fused_multi_head_attention( attn_mask=None, dropout_rate=0.5, attn_dropout_rate=0.5, + dropout_seed=None, + attn_dropout_seed=None, ln_epsilon=1e-05, training=True, mode='upscale_in_train', @@ -602,7 +605,7 @@ def fused_multi_head_attention( print(output.shape) """ - seed = None + # seed = None if mode not in ('downscale_in_infer', 'upscale_in_train'): raise ValueError( "mode argument should be 'downscale_in_infer' or 'upscale_in_train'" @@ -612,8 +615,8 @@ def fused_multi_head_attention( ) # semantic transfer if _non_static_mode(): - if default_main_program().random_seed != 0: - seed = default_main_program().random_seed + # if default_main_program().random_seed != 0: + # seed = default_main_program().random_seed # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out @@ -678,13 +681,13 @@ def fused_multi_head_attention( 'is_test', not training, 'attn_dropout_fix_seed', - seed is not None, + attn_dropout_seed is not None, 'dropout_fix_seed', - seed is not None, + dropout_seed is not None, 'attn_dropout_seed', - seed if seed is not None else 0, + attn_dropout_seed if attn_dropout_seed is not None else 0, 'dropout_seed', - seed if seed is not None else 0, + dropout_seed if dropout_seed is not None else 0, 'attn_dropout_implementation', mode, 'dropout_implementation', @@ -735,8 +738,8 @@ def fused_multi_head_attention( if cache_kv: inputs['CacheKV'] = [cache_kv] - if (seed is None or seed == 0) and helper.main_program.random_seed != 0: - seed = helper.main_program.random_seed + # if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + # seed = helper.main_program.random_seed # set attrs attrs = { @@ -746,10 +749,10 @@ def fused_multi_head_attention( 'dropout_rate': dropout_rate, 'attn_dropout_rate': attn_dropout_rate, 'is_test': not training, - 'attn_dropout_fix_seed': seed is not None, - 'dropout_fix_seed': seed is not None, - 'attn_dropout_seed': seed if seed is not None else 0, - 'dropout_seed': seed if seed is not None else 0, + 'attn_dropout_fix_seed': attn_dropout_seed is not None, + 'dropout_fix_seed': dropout_seed is not None, + 'attn_dropout_seed': attn_dropout_seed if attn_dropout_seed is not None else 0, + 'dropout_seed': dropout_seed if dropout_seed is not None else 0, 'attn_dropout_implementation': mode, 'dropout_implementation': mode, 'add_residual': add_residual, @@ -998,21 +1001,21 @@ def fused_multi_transformer( if _non_static_mode(): cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, + list(ln_scales), + list(ln_biases), + list(qkv_weights), + list(qkv_biases), cache_kvs, time_step, attn_mask, - linear_weights, - linear_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, + list(linear_weights), + list(linear_biases), + list(ffn_ln_scales), + list(ffn_ln_biases), + list(ffn1_weights), + list(ffn1_biases), + list(ffn2_weights), + list(ffn2_biases), cache_kvs, 'pre_layer_norm', pre_layer_norm, @@ -1048,29 +1051,29 @@ def fused_multi_transformer( # set inputs inputs = dict() inputs['X'] = [x] - inputs['LnScale'] = ln_scales - inputs['LnBias'] = ln_biases - inputs['QKVW'] = qkv_weights + inputs['LnScale'] = list(ln_scales) + inputs['LnBias'] = list(ln_biases) + inputs['QKVW'] = list(qkv_weights) if qkv_biases is not None: - inputs['QKVBias'] = qkv_biases + inputs['QKVBias'] = list(qkv_biases) if cache_kvs is not None: assert len(cache_kvs) == len(qkv_weights) inputs['CacheKV'] = cache_kvs if time_step is not None: inputs['TimeStep'] = time_step inputs['SrcMask'] = attn_mask - inputs['OutLinearW'] = linear_weights + inputs['OutLinearW'] = list(linear_weights) if linear_biases is not None: - inputs['OutLinearBias'] = linear_biases + inputs['OutLinearBias'] = list(linear_biases) - inputs['FFNLnScale'] = ffn_ln_scales - inputs['FFNLnBias'] = ffn_ln_biases - inputs['FFN1Weight'] = ffn1_weights + inputs['FFNLnScale'] = list(ffn_ln_scales) + inputs['FFNLnBias'] = list(ffn_ln_biases) + inputs['FFN1Weight'] = list(ffn1_weights) if ffn1_biases is not None: - inputs['FFN1Bias'] = ffn1_biases - inputs['FFN2Weight'] = ffn2_weights + inputs['FFN1Bias'] = list(ffn1_biases) + inputs['FFN2Weight'] = list(ffn2_weights) if ffn2_biases is not None: - inputs['FFN2Bias'] = ffn2_biases + inputs['FFN2Bias'] = list(ffn2_biases) # set attrs attrs = { diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index c3655c9d93a27..d71086fe7b07d 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -16,6 +16,7 @@ from paddle.nn import Layer from paddle.framework import ParamAttr import paddle +from paddle.nn import ParameterList from paddle.nn.layer.transformer import ( _convert_attention_mask, _convert_param_attr_to_list, @@ -269,6 +270,8 @@ def __init__( num_heads, dropout_rate=0.5, attn_dropout_rate=0.5, + dropout_seed=None, + attn_dropout_seed=None, kdim=None, vdim=None, normalize_before=False, @@ -377,6 +380,8 @@ def __init__( self.dropout_rate = dropout_rate self.attn_dropout_rate = attn_dropout_rate + self.dropout_seed = dropout_seed + self.attn_dropout_seed = attn_dropout_seed self.name = name @@ -434,6 +439,8 @@ def forward(self, query, key=None, value=None, attn_mask=None, cache=None): attn_mask=attn_mask, dropout_rate=self.dropout_rate, attn_dropout_rate=self.attn_dropout_rate, + dropout_seed=self.dropout_seed, + attn_dropout_seed=self.attn_dropout_seed, ln_epsilon=self._epsilon, training=self.training, ring_id=self._ring_id, @@ -546,6 +553,7 @@ def __init__( epsilon=1e-05, activation="relu", act_dropout_rate=None, + seed=None, normalize_before=False, linear1_weight_attr=None, linear1_bias_attr=None, @@ -582,6 +590,7 @@ def __init__( self._act_dropout_rate = ( dropout_rate if act_dropout_rate is None else act_dropout_rate ) + self._seed = seed self._act_method = activation self._normalize_before = normalize_before self._epsilon = epsilon @@ -661,6 +670,7 @@ def forward(self, src, cache=None): self._ln2_bias, dropout1_rate=self._act_dropout_rate, dropout2_rate=self._dropout_rate, + seed=self._seed, activation=self._act_method, ln1_epsilon=self._epsilon, ln2_epsilon=self._epsilon, @@ -1187,6 +1197,7 @@ def __init__( trans_qkvw=True, ring_id=-1, name=None, + dy_to_st=False, ): super(FusedMultiTransformer, self).__init__() @@ -1227,19 +1238,26 @@ def __init__( dim_feedforward = dim_feedforward // nranks self._dim_feedforward = dim_feedforward - if isinstance(qkv_weight_attrs, (list, tuple)): + if isinstance(qkv_weight_attrs, (list, tuple, ParameterList)): num_layers = len(qkv_weight_attrs) assert num_layers > 0 - self.ln_scales, self.ln_biases = [], [] - self.qkv_weights, self.qkv_biases = [], [] - self.linear_weights, self.linear_biases = [], [] - self.ffn_ln_scales, self.ffn_ln_biases = [], [] - self.ffn1_weights, self.ffn1_biases = [], [] - self.ffn2_weights, self.ffn2_biases = [], [] - + # if not dy_to_st: + # self.ln_scales, self.ln_biases = [], [] + # self.qkv_weights, self.qkv_biases = [], [] + # self.linear_weights, self.linear_biases = [], [] + # self.ffn_ln_scales, self.ffn_ln_biases = [], [] + # self.ffn1_weights, self.ffn1_biases = [], [] + # self.ffn2_weights, self.ffn2_biases = [], [] + # else: + self.ln_scales, self.ln_biases = ParameterList(), ParameterList() + self.qkv_weights, self.qkv_biases = ParameterList(), ParameterList() + self.linear_weights, self.linear_biases = ParameterList(), ParameterList() + self.ffn_ln_scales, self.ffn_ln_biases = ParameterList(), ParameterList() + self.ffn1_weights, self.ffn1_biases = ParameterList(), ParameterList() + self.ffn2_weights, self.ffn2_biases = ParameterList(), ParameterList() def get_attr(attrs, idx): - if isinstance(attrs, (list, tuple)): + if isinstance(attrs, (list, tuple, ParameterList)): assert len(attrs) == num_layers return attrs[idx] return attrs @@ -1263,9 +1281,10 @@ def get_attr(attrs, idx): attr=ln_scale_attr, shape=[embed_dim], default_initializer=Constant(value=1.0), + dtype="float32", ) ln_bias = self.create_parameter( - attr=ln_bias_attr, shape=[embed_dim], is_bias=True + attr=ln_bias_attr, shape=[embed_dim], is_bias=True, dtype="float32" ) qkv_weight = self.create_parameter( shape=[3, num_heads, self.head_dim, embed_dim] @@ -1299,9 +1318,10 @@ def get_attr(attrs, idx): attr=ffn_ln_scale_attr, is_bias=False, default_initializer=Constant(1.0), + dtype="float32", ) ffn_ln_bias = self.create_parameter( - shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True + shape=[embed_dim], attr=ffn_ln_bias_attr, is_bias=True, dtype="float32" ) ffn1_weight = self.create_parameter( shape=[embed_dim, dim_feedforward], @@ -1418,3 +1438,20 @@ def forward(self, src, attn_mask=None, caches=None, time_step=None): name=self.name, ) return out + + def _amp_decorate(self, dtype): + # tmp fix for amp.decorator(O2) + def trans_to_fp16(l): + for param in l: + if param is not None: + with no_grad(): + param_applied = _to_dtype(param, dtype) + trans_to_fp16(self.qkv_weights) + trans_to_fp16(self.qkv_biases) + trans_to_fp16(self.linear_weights) + trans_to_fp16(self.linear_biases) + trans_to_fp16(self.ffn1_weights) + trans_to_fp16(self.ffn1_biases) + trans_to_fp16(self.ffn2_weights) + trans_to_fp16(self.ffn2_biases) + self._dtype = dtype \ No newline at end of file diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 58dfa26cfe377..4c27dae6af5e5 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -265,6 +265,7 @@ from .search import masked_select # noqa: F401 from .search import kthvalue # noqa: F401 from .search import mode # noqa: F401 +from .search import beam_search_softmax # noqa: F401 from .stat import mean # noqa: F401 from .stat import std # noqa: F401 @@ -515,6 +516,7 @@ 'take', 'bucketize', 'sgn', + 'beam_search_softmax', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index f987e8b89cf25..7733f2ddd99d6 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2721,7 +2721,7 @@ def gather(x, index, axis=None, name=None): check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], + ['bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'gather', ) check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather') diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index e4458048edc6b..17499887e7a5b 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -1106,3 +1106,93 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None): attrs=attrs) indices.stop_gradient = True return values, indices + + +def beam_search_softmax( + logits, + cum_scores, + sequence_lengths, + stop_flags, + end_ids, + step_ids, + last_cache_ids, + last_beam_offsets, + beam_size, + max_seq_len, + max_dec_len, + fuse_softmax, + early_stop, + name=None, +): + if in_dygraph_mode(): + return _C_ops.beam_search_softmax( + logits, + cum_scores, + sequence_lengths, + stop_flags, + end_ids, + step_ids, + last_cache_ids, + last_beam_offsets, + beam_size, + max_seq_len, + max_dec_len, + fuse_softmax, + early_stop + ) + + inputs = { + "logits": logits, + "cum_scores": cum_scores, + "sequence_lengths": sequence_lengths, + "stop_flags": stop_flags, + "end_ids": end_ids, + "step_ids": step_ids, + "last_cache_ids": last_cache_ids, + "last_beam_offsets": last_beam_offsets, + } + attrs = {} + attrs['beam_size'] = beam_size + attrs['max_seq_len'] = max_seq_len + attrs['max_dec_len'] = max_dec_len + attrs['fuse_softmax'] = fuse_softmax + attrs['early_stop'] = early_stop + + helper = LayerHelper('beam_search_softmax', **locals()) + ids_this_time = helper.create_variable_for_type_inference(dtype="int32") + cache_ids = helper.create_variable_for_type_inference(dtype="int32") + beam_offsets = helper.create_variable_for_type_inference(dtype="int32") + parent_idx = helper.create_variable_for_type_inference(dtype="int32") + out_cum_scores = helper.create_variable_for_type_inference( + dtype=logits.dtype + ) + stop_flags_out = helper.create_variable_for_type_inference( + dtype=stop_flags.dtype + ) + seq_lens_out = helper.create_variable_for_type_inference(dtype="int32") + step_ids_out = helper.create_variable_for_type_inference(dtype="int32") + helper.append_op( + type='beam_search_softmax', + inputs=inputs, + outputs={ + "ids_this_time": ids_this_time, + "out_cum_scores": out_cum_scores, + "cache_ids": cache_ids, + "beam_offsets": beam_offsets, + "parent_idx": parent_idx, + "stop_flags_out": stop_flags_out, + "seq_lens_out": seq_lens_out, + "step_ids_out": step_ids_out, + }, + attrs=attrs, + ) + return ( + ids_this_time, + out_cum_scores, + cache_ids, + beam_offsets, + parent_idx, + stop_flags_out, + seq_lens_out, + step_ids_out + ) \ No newline at end of file