Merge pull request PaddlePaddle#2 from wwbitejotunn/swin_multihead_pass_temp

Swin multihead pass temp
wwbitejotunn committed Aug 10, 2022
2 parents 63230e3 + a3ab439 commit ec5a74f
Showing 8 changed files with 132 additions and 7 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/framework/ir/swin_attention1_fuse_pass.cc
@@ -93,6 +93,14 @@ void SwinAttention1FusePass::ApplyImpl(ir::Graph* graph) const {
auto bias_qkv_dims = phi::make_ddim({3, bias_qkv_tensor->dims()[0]/3});
bias_qkv_tensor->Resize(bias_qkv_dims);

auto* bias_qk_1_var = scope->FindVar(elementwise_70_in_y->Name());
auto* bias_qk_1_tensor = bias_qk_1_var->GetMutable<LoDTensor>();
auto bias_qk_1_dims = bias_qk_1_tensor->dims();
auto* bias_qk_1_data = bias_qk_1_tensor->mutable_data<float>(platform::CPUPlace());
printf("@@@ in pass biasqk 0: %f ", bias_qk_1_data[0]);
VLOG(0) << "@@@ bias_qk_1_tensor:";
VLOG(0) << bias_qk_1_dims;

std::vector<int64_t> softmax_shape = softmax_80_out->Var()->GetShape();
float alpha = PADDLE_GET_CONST(float, scale_50_op->Op()->GetAttr("scale"));

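For context, the Resize above only rewrites metadata: the fused QKV bias of length 3 * hidden is viewed as a [3, hidden] tensor (typically the Q, K and V biases concatenated) without moving any data. A minimal standalone sketch of that indexing, with illustrative names:

#include <cassert>
#include <cstdint>

// Illustrative only: address a flat QKV bias of length 3 * hidden as if it had
// been resized to [3, hidden]; rows 0/1/2 are assumed to hold the Q/K/V biases.
inline float QkvBiasAt(const float* flat_bias, int64_t hidden, int row,
                       int64_t col) {
  assert(row >= 0 && row < 3 && col >= 0 && col < hidden);
  return flat_bias[static_cast<int64_t>(row) * hidden + col];
}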
4 changes: 2 additions & 2 deletions paddle/fluid/framework/var_desc.cc
@@ -83,7 +83,7 @@ void VarDesc::SetShapes(
}

std::vector<int64_t> VarDesc::GetShape() const {
// VLOG(0)<<"@@@ VarDesc::GetShape()"<<tensor_desc().dims();
// VLOG(1)<<"@@@ VarDesc::GetShape()"<<tensor_desc().dims();
return RepeatedToVector(tensor_desc().dims());
}

@@ -210,7 +210,7 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}

const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
// VLOG(0)<<"@@@ tensor name: "<<this->Name();
// VLOG(1)<<"@@@ tensor name: "<<this->Name();
PADDLE_ENFORCE_EQ(
desc_.has_type(),
true,
@@ -318,6 +318,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_->GetFp32TrtWeight(biasqk_name, *biasqk_t);
biasQK_constLayer = TRT_ENGINE_ADD_LAYER(
engine_, Constant, biasqk_dims, biasqk_const_weight.get());
float* biasqk_data = const_cast<float*>(static_cast<const float*>(
    engine_->GetFp32TrtWeight(biasqk_name, *biasqk_t).get().values));
printf("@@ in convert biasqk_data 0 1 2 3: %f %f %f %f \r\n",
       biasqk_data[0], biasqk_data[1], biasqk_data[2], biasqk_data[3]);

engine_->SetITensor(biasqk_name, biasQK_constLayer->getOutput(0));
op_desc.SetInput("BiasQK", {biasqk_name});
}
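The constant layer added above makes the fp32 BiasQK weights available as a regular network tensor. A hedged sketch of the equivalent raw TensorRT calls, assuming a network, a weight buffer biasqk_values with biasqk_count elements, and biasqk_dims are already at hand (names are illustrative, not the converter's actual variables):

#include <NvInfer.h>

// Sketch only: wrap the BiasQK weights in an IConstantLayer so downstream
// layers can consume them as an ITensor.
nvinfer1::ITensor* AddBiasQkConstant(nvinfer1::INetworkDefinition* network,
                                     const nvinfer1::Dims& biasqk_dims,
                                     const float* biasqk_values,
                                     int64_t biasqk_count) {
  nvinfer1::Weights w{nvinfer1::DataType::kFLOAT, biasqk_values, biasqk_count};
  nvinfer1::IConstantLayer* constant = network->addConstant(biasqk_dims, w);
  return constant->getOutput(0);  // usable as the "BiasQK" input of the fused op
}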
6 changes: 4 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
@@ -163,8 +163,10 @@ class SkipLayerNormOpConverter : public OpConverter {
auto scale_weight = GetFp32Weight("Scale").get();

float eps = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
// bool with_fp16 =
// engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
bool with_fp16 = false;

plugin::SkipLayerNormPluginDynamic* plugin =
new plugin::SkipLayerNormPluginDynamic(
static_cast<const float*>(bias_weight.values),
69 changes: 69 additions & 0 deletions paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -298,6 +298,24 @@ __global__ void broadcast(const T *src,
}
}

template <typename T>
__global__ void broadcast_batch(const T *src,
T *dst,
const int seq_len,
const int head_num,
const int window_num) {
int WindownumHeadSeqlen_id = blockIdx.x % (window_num * head_num * seq_len);
int dst_offset = blockIdx.x * seq_len;
if (threadIdx.x < seq_len) {
dst[threadIdx.x + dst_offset] = src[threadIdx.x + WindownumHeadSeqlen_id * seq_len];
}
}

// TODO(wangbojun): debug-only helper.
__global__ void print_float(const float *src, int index){
printf("%f:",src[index]);
}

int QkvToContextPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
@@ -329,6 +347,7 @@ int QkvToContextPluginDynamic::enqueue(
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
framework::Tensor temp_qk_bias_tensor;
float *qk_bias = const_cast<float *>(static_cast<const float *>(inputs[1]));

if (ProductDim(input_desc[1].dims) == (batch * seq_len)) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data<float>(
Expand All @@ -342,7 +361,36 @@ int QkvToContextPluginDynamic::enqueue(
head_number_);
qk_bias = temp_qk_bias;
}
// If bias_qk is [window_num, head_number, seq_len, seq_len]: in the Swin
// SW-MSA block, dim[0] of the input is batch_number * window_num, so we
// broadcast bias_qk to [batch_num * window_num, head_number, seq_len, seq_len].
int window_num = input_desc[1].dims.d[0];
if (ProductDim(input_desc[1].dims) ==
    window_num * head_number_ * seq_len * seq_len) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data<float>(
platform::CUDAPlace(device_id));
int grid = batch * head_number_ * seq_len;
int block = round_up(seq_len);
broadcast_batch<<<grid, block, 0, stream>>>(
static_cast<const float *>(inputs[1]),
temp_qk_bias,
seq_len,
head_number_,
window_num);
qk_bias = temp_qk_bias;
}

printf("@@@ input_desc[0] shape: %d, %d, %d \r\n",input_desc[0].dims.d[0],input_desc[0].dims.d[1],input_desc[0].dims.d[2]);
printf("@@@ input_desc[1] shape: %d, %d, %d, %d \r\n",input_desc[1].dims.d[0],input_desc[1].dims.d[1],input_desc[1].dims.d[2],input_desc[1].dims.d[3]);
printf("\r\n");

const float *input1_data = static_cast<const float *>(qk_bias);
printf("@@@ in plugin biasqk 0 1 2 3: ");
print_float<<<1,1,0,stream>>>(input1_data,0);
print_float<<<1,1,0,stream>>>(input1_data,1);
print_float<<<1,1,0,stream>>>(input1_data,2);
print_float<<<1,1,0,stream>>>(input1_data,3);
printf("\r\n");
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV(
batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
@@ -398,6 +446,27 @@ int QkvToContextPluginDynamic::enqueue(
head_number_);
qk_bias = temp_qk_bias;
}
// If bias_qk is [window_num, head_number, seq_len, seq_len]: in the Swin
// SW-MSA block, dim[0] of the input is batch_number * window_num, so we
// broadcast bias_qk to [batch_num * window_num, head_number, seq_len, seq_len].
int window_num = input_desc[1].dims.d[0];
if (ProductDim(input_desc[1].dims) ==
    window_num * head_number_ * seq_len * seq_len) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias =
reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
platform::CUDAPlace(device_id)));
int grid = batch * head_number_ * seq_len;
int block = round_up(seq_len);
broadcast_batch<<<grid, block, 0, stream>>>(
static_cast<const half *>(inputs[1]),
temp_qk_bias,
seq_len,
head_number_,
window_num);
qk_bias = temp_qk_bias;
}


const half *input1_data = static_cast<const half *>(qk_bias);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV(
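For reference, a CPU sketch of what broadcast_batch computes in both branches above, under the assumption stated in the comments (bias_qk is [window_num, head_num, seq_len, seq_len] and the folded batch is original_batch * window_num); names are illustrative:

#include <vector>

// CPU reference for broadcast_batch (illustrative only): destination row i is
// source row i % (window_num * head_num * seq_len), which tiles the per-window
// bias across the folded batch dimension.
std::vector<float> BroadcastBatchRef(const std::vector<float>& src, int batch,
                                     int head_num, int seq_len,
                                     int window_num) {
  const int src_rows = window_num * head_num * seq_len;
  const int dst_rows = batch * head_num * seq_len;
  std::vector<float> dst(static_cast<size_t>(dst_rows) * seq_len);
  for (int row = 0; row < dst_rows; ++row) {
    const int src_row = row % src_rows;
    for (int col = 0; col < seq_len; ++col) {
      dst[static_cast<size_t>(row) * seq_len + col] =
          src[static_cast<size_t>(src_row) * seq_len + col];
    }
  }
  return dst;
}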
42 changes: 41 additions & 1 deletion paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -256,6 +256,23 @@ __global__ void broadcast(const T *src,
}
}

template <typename T>
__global__ void broadcast_batch(const T *src,
T *dst,
const int seq_len,
const int head_num,
const int window_num) {
int WindownumHeadSeqlen_id = blockIdx.x % (window_num * head_num * seq_len);
int dst_offset = blockIdx.x * seq_len;
if (threadIdx.x < seq_len) {
dst[threadIdx.x + dst_offset] = src[threadIdx.x + WindownumHeadSeqlen_id * seq_len];
}
}

template <typename T>
__global__ void print_float(const T *src, int index) {
  printf("@@@ %f ", src[index]);
}

template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
public:
@@ -274,7 +291,15 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto *bias_d = bias->data<T>();
auto *bias_qk_d = bias_qk.template data<T>();
T scale = static_cast<T>(context.Attr<float>("alpha"));


auto bias_qk_dims = bias_qk.dims();
int window_num = bias_qk_dims[0];
printf("@@@@ multihead op \r\n");
printf("@@@ bias qk dims: %d %d %d %d \r\n", static_cast<int>(bias_qk_dims[0]),
       static_cast<int>(bias_qk_dims[1]), static_cast<int>(bias_qk_dims[2]),
       static_cast<int>(bias_qk_dims[3]));
// print_float<T><<<1, 1>>>(w_d, 0);
// print_float<T><<<1, 1>>>(bias_d, 0);
print_float<T><<<1, 1>>>(bias_qk_d, 0);
printf("\r\n @@@ scale %f: \r\n", static_cast<float>(scale));
int head_number = context.Attr<int>("head_number");
// compute q*k with eltadd
auto &device_ctx = context.template device_context<DeviceContext>();
Expand All @@ -286,6 +311,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int batch = input_dims[0];
int seq_len = input_dims[1];
int hidden = input_dims[2];

Tensor temp_bias_tensor;
// if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
if (bias_qk.numel() == (batch * seq_len)) {
Expand All @@ -297,6 +323,20 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
bias_qk_d, temp_qk_bias, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
// If bias_qk is [window_num, head_number, seq_len, seq_len]: in the Swin
// SW-MSA block, dim[0] of the input is batch_number * window_num, so we
// broadcast bias_qk to [window_num * original_batch, head_number, seq_len, seq_len].
if (bias_qk.numel() == (window_num * head_number * seq_len * seq_len)) {
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
printf("@@@@ type of qk_bias: %s \r\n",__PRETTY_FUNCTION__);
auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace());
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast_batch<<<grid, block, 0, stream>>>(
bias_qk_d, temp_qk_bias, seq_len, head_number, window_num);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}

int all_head_size = w_dims[2];
int head_size = all_head_size / head_number;

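A worked shape example for the broadcast above, with hypothetical Swin-style numbers (7x7 windows, 64 windows per image, 3 heads); the figures are illustrative only:

#include <cassert>

int main() {
  const int original_batch = 2;   // images per batch (hypothetical)
  const int window_num = 64;      // windows per image (hypothetical)
  const int head_number = 3;
  const int seq_len = 49;         // tokens per 7 x 7 window
  const int batch = original_batch * window_num;  // folded batch seen by the op

  // bias_qk arrives as [window_num, head, seq, seq]; broadcast_batch tiles it
  // original_batch times so it matches the folded batch [batch, head, seq, seq].
  const long long src_elems = 1LL * window_num * head_number * seq_len * seq_len;
  const long long dst_elems = 1LL * batch * head_number * seq_len * seq_len;
  assert(dst_elems == src_elems * original_batch);
  return 0;
}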
3 changes: 2 additions & 1 deletion paddle/fluid/operators/math/bert_encoder_functor.cu
@@ -547,7 +547,8 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
T beta) {
CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;

printf("@@ MatMulWithHeadQK: batch_size:%d, head_num:%d, seq_len:%d\r\n",
batch_size,head_num,seq_len);
typedef typename CUDATypeTraits<T>::TYPE run_type;
auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
auto stream = context.stream();
3 changes: 2 additions & 1 deletion paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc
@@ -105,7 +105,8 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
}
}
}
if (scale[0] > 0.0f && scale[1] > 0.0f && scale[2] > 0.0f) {
if (scale.size() == 3 && scale[0] > 0.0f && scale[1] > 0.0f &&
scale[2] > 0.0f) {
int j = 0;
std::vector<int64_t> in_dhw_vec = phi::vectorize(in_dhw_dims);
std::transform(
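The added scale.size() == 3 check above prevents reading scale[1] and scale[2] when fewer than three scale factors were provided. A minimal standalone illustration of the guarded pattern (names are illustrative):

#include <vector>

// Illustrative only: touch scale[0..2] solely when three factors are present.
bool UseExplicitScale(const std::vector<float>& scale) {
  return scale.size() == 3 && scale[0] > 0.0f && scale[1] > 0.0f &&
         scale[2] > 0.0f;
}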
