Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add fused vision transformer #3034

Merged
merged 11 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions csrc/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# PaddleClas 自定义 OP

此文档介绍如何编译安装 PaddleClas 自定义 OP。

## 安装 pip 依赖

```shell
pip install -r requirements.txt
```

## 编译 Cuda 算子

```shell
python setup_cuda.py install
```
103 changes: 103 additions & 0 deletions csrc/generation/helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/extension.h"
#include <cub/cub.cuh>
#include <curand_kernel.h>

constexpr int kBlockSize = 256;
constexpr int kNumWaves = 16;

inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) { return err; }
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) { return err; }
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) { return err; }
}
*num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return cudaSuccess;
}

template<typename T>
__device__ T max_func(const T a, const T b) {
return a > b ? a : b;
}

template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const {
return max_func(a, b);
}
};

template <paddle::DataType D>
class PDTraits;

template <>
class PDTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
};

template <>
class PDTraits<paddle::DataType::FLOAT16> {
public:
typedef half DataType;
typedef paddle::float16 data_t;
};

template <>
class PDTraits<paddle::DataType::BFLOAT16> {
public:
typedef __nv_bfloat16 DataType;
typedef paddle::bfloat16 data_t;
};

template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];

HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};

template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
const AlignedVector<T, Size>* addr_vec =
reinterpret_cast<const AlignedVector<T, Size>*>(addr);
*vec = *addr_vec;
}

template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
AlignedVector<T, Size>* addr_vec =
reinterpret_cast<AlignedVector<T, Size>*>(addr);
*addr_vec = vec;
}

constexpr int VEC_16B = 16;
193 changes: 193 additions & 0 deletions csrc/generation/qkv_transpose_split.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <typename T, int VecSize>
__global__ void fusedQKV_transpose_split_kernel(
T *q_buf,
T *k_buf,
T *v_buf,
const T *qkv,
const int *padding_offset,
const int *seq_lens,
const int32_t elem_cnt,
const int batch_size,
const int max_len_this_time,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head) {
const int32_t offset = batch_size * max_len_this_time * head_num * size_per_head;
const int32_t hidden_size = head_num * size_per_head;
const int32_t fused_hidden_size = 3 * hidden_size;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
LoadT bias_vec;

for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
Load<T, VecSize>(&qkv[linear_index], &src_vec);
int32_t bias_idx = linear_index % fused_hidden_size;
const int32_t token_idx = linear_index / fused_hidden_size;
const int32_t ori_token_idx =
token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int32_t target_batch_id = ori_token_idx / seq_len;
if (seq_lens[target_batch_id] == 0) continue;
const int32_t seq_id = ori_token_idx % seq_len;

// equal to:
// const int qkv_id = (linear_index % fused_hidden_size) / hidden_size;
const int32_t qkv_id = bias_idx / hidden_size;
const int32_t head_id = (linear_index % hidden_size) / size_per_head;
const int32_t size_id = linear_index % size_per_head;

if (qkv_id == 0) {
Store<T, VecSize>(
src_vec,
&q_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
size_id]);
} else if (qkv_id == 1) {
Store<T, VecSize>(
src_vec,
&k_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
size_id]);
} else {
Store<T, VecSize>(
src_vec,
&v_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
size_id]);
}
}
}

template <paddle::DataType D>
std::vector<paddle::Tensor> qkv_transpose_split(const paddle::Tensor& qkv, // [token_num, dim_embed]
const paddle::Tensor& padding_offset, // [bsz, 1]
const paddle::Tensor& seq_lens,
const paddle::Tensor& input_ids,
int num_head,
int head_size) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

auto cu_stream = qkv.stream();
std::vector<int64_t> qkv_shape = qkv.shape();
const int token_num = qkv_shape[0];
const int bsz = seq_lens.shape()[0];
const int max_seq_len = input_ids.shape()[1]; //max_seq_len_tensor.copy_to(paddle::CPUPlace(), false).data<int>()[0];
auto q_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
auto k_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
auto v_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
constexpr int PackSize = VEC_16B / sizeof(DataType_);
const int elem_cnt = token_num * num_head * head_size * 3;
const int pack_num = elem_cnt / PackSize;
const int blocksize = 128;
const int grid_size = (pack_num + blocksize - 1) / blocksize;
fusedQKV_transpose_split_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, qkv.stream()>>>(
reinterpret_cast<DataType_*>(q_out.data<data_t>()),
reinterpret_cast<DataType_*>(k_out.data<data_t>()),
reinterpret_cast<DataType_*>(v_out.data<data_t>()),
reinterpret_cast<DataType_*>(const_cast<data_t*>(qkv.data<data_t>())),
padding_offset.data<int>(),
seq_lens.data<int>(),
elem_cnt,
bsz,
max_seq_len,
max_seq_len,
token_num,
num_head,
head_size);
return {q_out, k_out, v_out};
}

std::vector<paddle::Tensor> QKVTransposeSplit(const paddle::Tensor& qkv,
const paddle::Tensor& padding_offset,
const paddle::Tensor& seq_lens,
const paddle::Tensor& input_ids,
int num_head,
int head_size) {
switch (qkv.type()) {
case paddle::DataType::BFLOAT16: {
return qkv_transpose_split<paddle::DataType::BFLOAT16>(
qkv,
padding_offset,
seq_lens,
input_ids,
num_head,
head_size
);
}
case paddle::DataType::FLOAT16: {
return qkv_transpose_split<paddle::DataType::FLOAT16>(
qkv,
padding_offset,
seq_lens,
input_ids,
num_head,
head_size
);
}
case paddle::DataType::FLOAT32: {
return qkv_transpose_split<paddle::DataType::FLOAT32>(
qkv,
padding_offset,
seq_lens,
input_ids,
num_head,
head_size
);
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16, bfloat16 and float32 are supported. ");
break;
}
}
}

std::vector<std::vector<int64_t>> QKVTransposeSplitInferShape(const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& padding_offset_shape,
const std::vector<int64_t>& seq_lens_shape,
const std::vector<int64_t>& input_ids_shape,
int num_head,
int head_size) {
int64_t bsz = seq_lens_shape[0];
return {{bsz, num_head, -1, head_size}, {bsz, num_head, -1, head_size}, {bsz, num_head, -1, head_size}};
}

std::vector<paddle::DataType> QKVTransposeSplitInferDtype(const paddle::DataType& qkv_dtype,
const paddle::DataType& padding_offset_dtype,
const paddle::DataType& seq_lens_dtype,
const paddle::DataType& input_ids_dtype) {
return {qkv_dtype, qkv_dtype, qkv_dtype};
}

PD_BUILD_OP(qkv_transpose_split)
.Inputs({"qkv", "padding_offset", "seq_lens", "input_ids"})
.Outputs({"q_out", "k_out", "v_out"})
.Attrs({"num_head: int",
"head_size: int"})
.SetKernelFn(PD_KERNEL(QKVTransposeSplit))
.SetInferShapeFn(PD_INFER_SHAPE(QKVTransposeSplitInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(QKVTransposeSplitInferDtype));
Loading