Merge pull request PaddlePaddle#1 from tianyan01/v2.4.2
modify for dynamic zeus
laipaang committed Aug 14, 2023
2 parents 0e92adc + 0631322 commit 052c62e
Showing 13 changed files with 1,251 additions and 56 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/operators/fused/fmha_ref.h
@@ -115,6 +115,10 @@ class FMHARef {
      // out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
      concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor);
      out_seq_len = cache_kv_out_tensor->dims()[3];
    } else {
      if (cache_kv_out_tensor) {
        *cache_kv_out_tensor = transpose_2_out_tensor->Slice(1, 3);
      }
    }

    int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
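The new else branch in fmha_ref.h above appears to seed the KV cache directly from the current step when no existing cache is concatenated: assuming transpose_2_out_tensor holds the transposed QKV activations stacked as [3, bs, num_head, seq_len, head_dim], Slice(1, 3) on the leading axis keeps only the K and V blocks, matching the [2, bs, num_head, ..., head_dim] cache layout noted in the comment. That layout and the Slice semantics are inferred from the surrounding comment, not confirmed by this excerpt; a minimal standalone sketch of the idea:

// Standalone sketch (not Paddle code): with Q, K, V stacked on the leading
// axis of a [3, bs, num_head, seq_len, head_dim] buffer, taking indices
// [1, 3) of that axis yields the [2, bs, num_head, seq_len, head_dim]
// K/V block used to seed the cache.
#include <cstdio>
#include <vector>

int main() {
  const int bs = 1, num_head = 2, seq_len = 3, head_dim = 4;
  const int per_matrix = bs * num_head * seq_len * head_dim;

  // Fake "transpose_2_out" buffer: Q filled with 0s, K with 1s, V with 2s.
  std::vector<float> qkv(3 * per_matrix);
  for (int m = 0; m < 3; ++m)
    for (int i = 0; i < per_matrix; ++i) qkv[m * per_matrix + i] = m;

  // Slice(1, 3) on the leading axis == skip Q, keep K and V contiguously.
  const float* kv = qkv.data() + 1 * per_matrix;
  const int kv_size = 2 * per_matrix;

  printf("kv[0]=%.0f (K), kv[%d]=%.0f (V)\n", kv[0], per_matrix, kv[per_matrix]);
  printf("kv block has %d elements = [2, %d, %d, %d, %d]\n",
         kv_size, bs, num_head, seq_len, head_dim);
  return 0;
}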
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -208,3 +208,12 @@
    kernel :
      func : flip
  backward : flip_grad

- op : beam_search_softmax
  args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop)
  output : Tensor(ids_this_time), Tensor(out_cum_scores), Tensor(cache_ids), Tensor(beam_offsets), Tensor(parent_idx), Tensor(stop_flags_out), Tensor(seq_lens_out), Tensor(step_ids_out)
  infer_meta :
    func : BeamSearchSoftmaxInferMeta
  kernel :
    func : beam_search_softmax
    data_type : logits
51 changes: 51 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -674,6 +674,57 @@ void BatchNormInferInferMeta(const MetaTensor& x,
config);
}

void BeamSearchSoftmaxInferMeta(const MetaTensor& logits,
                                const MetaTensor& cum_scores,
                                const MetaTensor& sequence_lengths,
                                const MetaTensor& stop_flags,
                                const MetaTensor& end_ids,
                                const MetaTensor& step_ids,
                                const MetaTensor& last_cache_ids,
                                const MetaTensor& last_beam_offsets,
                                int beam_size,
                                int max_seq_len,
                                int max_dec_len,
                                bool fuse_softmax,
                                bool early_stop,
                                MetaTensor* ids_this_time,
                                MetaTensor* out_cum_scores,
                                MetaTensor* cache_ids,
                                MetaTensor* beam_offsets,
                                MetaTensor* parent_idx,
                                MetaTensor* stop_flags_out,
                                MetaTensor* seq_lens_out,
                                MetaTensor* step_ids_out) {
  auto logits_dims = logits.dims();
  auto logits_dtype = logits.dtype();
  int bbm = logits_dims[0];

  ids_this_time->set_dims({bbm, 1});
  ids_this_time->share_lod(cum_scores);
  ids_this_time->set_dtype(DataType::INT32);
  cache_ids->set_dims({bbm, max_dec_len});
  cache_ids->share_lod(cum_scores);
  cache_ids->set_dtype(DataType::INT32);
  beam_offsets->set_dims({1, bbm, max_seq_len + max_dec_len});
  beam_offsets->share_lod(cum_scores);
  beam_offsets->set_dtype(DataType::INT32);
  parent_idx->set_dims({bbm});
  parent_idx->share_lod(cum_scores);
  parent_idx->set_dtype(DataType::INT32);
  out_cum_scores->set_dims({bbm});
  out_cum_scores->share_lod(cum_scores);
  out_cum_scores->set_dtype(logits_dtype);
  stop_flags_out->set_dims({bbm, 1});
  stop_flags_out->share_lod(cum_scores);
  stop_flags_out->set_dtype(DataType::BOOL);
  seq_lens_out->set_dims({bbm, 1});
  seq_lens_out->share_lod(cum_scores);
  seq_lens_out->set_dtype(DataType::INT32);
  step_ids_out->set_dims({bbm, 1});
  step_ids_out->share_lod(cum_scores);
  step_ids_out->set_dtype(DataType::INT32);
}

void BilinearTensorProductInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
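A side note on the infer-meta above: every output shape BeamSearchSoftmaxInferMeta assigns is derived only from logits.dims()[0] (called bbm, presumably batch_size * beam_size) and the max_seq_len / max_dec_len attributes. A standalone sketch that just prints the resulting shapes for hypothetical sizes (not values taken from this commit):

// Illustrative only: bbm = 8, max_seq_len = 1024, max_dec_len = 128 are
// hypothetical numbers used to show the shape relationships set above.
#include <cstdio>

int main() {
  const int bbm = 8, max_seq_len = 1024, max_dec_len = 128;
  printf("ids_this_time : [%d, 1]        (INT32)\n", bbm);
  printf("cache_ids     : [%d, %d]      (INT32)\n", bbm, max_dec_len);
  printf("beam_offsets  : [1, %d, %d]  (INT32)\n", bbm, max_seq_len + max_dec_len);
  printf("parent_idx    : [%d]           (INT32)\n", bbm);
  printf("out_cum_scores: [%d]           (same dtype as logits)\n", bbm);
  printf("stop_flags_out: [%d, 1]        (BOOL)\n", bbm);
  printf("seq_lens_out  : [%d, 1]        (INT32)\n", bbm);
  printf("step_ids_out  : [%d, 1]        (INT32)\n", bbm);
  return 0;
}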
22 changes: 22 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -190,6 +190,28 @@ void BatchNormInferInferMeta(const MetaTensor& x,
MetaTensor* variance_out,
MetaConfig config = MetaConfig());

void BeamSearchSoftmaxInferMeta(const MetaTensor& logits,
                                const MetaTensor& cum_scores,
                                const MetaTensor& sequence_lengths,
                                const MetaTensor& stop_flags,
                                const MetaTensor& end_ids,
                                const MetaTensor& step_ids,
                                const MetaTensor& last_cache_ids,
                                const MetaTensor& last_beam_offsets,
                                int beam_size,
                                int max_seq_len,
                                int max_dec_len,
                                bool fuse_softmax,
                                bool early_stop,
                                MetaTensor* ids_this_time,
                                MetaTensor* out_cum_scores,
                                MetaTensor* cache_ids,
                                MetaTensor* beam_offsets,
                                MetaTensor* parent_idx,
                                MetaTensor* stop_flags_out,
                                MetaTensor* seq_lens_out,
                                MetaTensor* step_ids_out);

void BilinearTensorProductInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
47 changes: 47 additions & 0 deletions paddle/phi/kernels/fusion/beam_search_softmax.h
@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace fusion {

template <typename T, typename Context>
void BeamSearchSoftmaxKernel(const Context &dev_ctx,
                             const DenseTensor &logits,
                             const DenseTensor &cum_scores,
                             const DenseTensor &sequence_lengths,
                             const DenseTensor &stop_flags,
                             const DenseTensor &end_ids,
                             const DenseTensor &step_ids,
                             const DenseTensor &last_cache_ids,
                             const DenseTensor &last_beam_offsets,
                             int beam_size,
                             int max_seq_len,
                             int max_dec_len,
                             bool fuse_softmax,
                             bool early_stop,
                             DenseTensor *ids_this_time,
                             DenseTensor *out_cum_scores,
                             DenseTensor *cache_ids,
                             DenseTensor *beam_offsets,
                             DenseTensor *parent_idx,
                             DenseTensor *stop_flags_out,
                             DenseTensor *seq_lens_out,
                             DenseTensor *step_ids_out);

} // namespace fusion
} // namespace phi
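This header only declares the fused kernel; its implementation lives in one of the other changed files not shown in this excerpt. As rough intuition for what one beam-search-softmax step computes when fuse_softmax is enabled — a log-softmax over the vocabulary, each beam's cumulative score added in, and a global top-beam_size selection that also records the parent beam — here is a minimal CPU-only approximation. It deliberately ignores sequence_lengths, stop_flags, end_ids, and early_stop, and is not the kernel from this commit:

// One beam-search step for a single example: score(b, v) = cum_scores[b] +
// log_softmax(logits[b])[v]; keep the best beam_size (token, parent) pairs.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int beam_size = 2, vocab = 4;
  std::vector<std::vector<float>> logits = {{0.1f, 2.0f, 0.3f, 0.0f},
                                            {1.5f, 0.2f, 0.2f, 1.4f}};
  std::vector<float> cum_scores = {-0.5f, -0.7f};

  struct Cand { float score; int token; int parent; };
  std::vector<Cand> cands;

  for (int b = 0; b < beam_size; ++b) {
    // log-softmax over the vocabulary (stabilized with the row max).
    float mx = *std::max_element(logits[b].begin(), logits[b].end());
    float sum = 0.f;
    for (float x : logits[b]) sum += std::exp(x - mx);
    float log_z = mx + std::log(sum);
    for (int v = 0; v < vocab; ++v)
      cands.push_back({cum_scores[b] + logits[b][v] - log_z, v, b});
  }

  // Keep the global top beam_size candidates across all parent beams.
  std::partial_sort(cands.begin(), cands.begin() + beam_size, cands.end(),
                    [](const Cand &a, const Cand &b) { return a.score > b.score; });

  for (int k = 0; k < beam_size; ++k)
    printf("ids_this_time[%d]=%d parent_idx[%d]=%d out_cum_scores[%d]=%.4f\n",
           k, cands[k].token, k, cands[k].parent, k, cands[k].score);
  return 0;
}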