
Commit

[Backend] cuda normalize and permute, cuda concat, optimized ppcls, ppdet & ppseg (PaddlePaddle#546)

* cuda normalize and permute, cuda concat

* add use cuda option for preprocessor

* ppyoloe use cuda normalize

* ppseg use cuda normalize

* add proclib cuda in processor base

* ppcls add use cuda preprocess api

* ppcls preprocessor set gpu id

* fix pybind

* refine ppcls preprocessing use gpu logic

* fdtensor device id is -1 by default

* refine assert message

Co-authored-by: heliqi <1101791222@qq.com>
2 people authored and felixhjh committed Nov 25, 2022
1 parent 412deb7 commit c84de29
Showing 20 changed files with 204 additions and 26 deletions.
7 changes: 4 additions & 3 deletions fastdeploy/core/fd_tensor.cc
@@ -252,7 +252,8 @@ void FDTensor::FreeFn() {
}
}

void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes,
const Device& device, bool is_pinned_memory) {
if (device == Device::GPU) {
#ifdef WITH_GPU
FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToDevice) == 0,
@@ -295,7 +296,7 @@ FDTensor::FDTensor(const FDTensor& other)
size_t nbytes = Nbytes();
FDASSERT(ReallocFn(nbytes),
"The FastDeploy FDTensor allocate memory error");
CopyBuffer(buffer_, other.buffer_, nbytes);
CopyBuffer(buffer_, other.buffer_, nbytes, device, is_pinned_memory);
}
}

@@ -325,7 +326,7 @@ FDTensor& FDTensor::operator=(const FDTensor& other) {
} else {
Resize(other.shape, other.dtype, other.name, other.device);
size_t nbytes = Nbytes();
CopyBuffer(buffer_, other.buffer_, nbytes);
CopyBuffer(buffer_, other.buffer_, nbytes, device, is_pinned_memory);
}
external_data_ptr = other.external_data_ptr;
}
8 changes: 6 additions & 2 deletions fastdeploy/core/fd_tensor.h
@@ -39,6 +39,9 @@ struct FASTDEPLOY_DECL FDTensor {
// GPU to inference the model
// so we can skip data transfer, which may improve efficiency
Device device = Device::CPU;
// By default the device id of FDTensor is -1, which means this value is
// invalid, and FDTensor is using the same device id as Runtime.
int device_id = -1;

// Whether the data buffer is in pinned memory, which is allocated
// with cudaMallocHost()
@@ -130,8 +133,9 @@ struct FASTDEPLOY_DECL FDTensor {

~FDTensor() { FreeFn(); }

private:
void CopyBuffer(void* dst, const void* src, size_t nbytes);
static void CopyBuffer(void* dst, const void* src, size_t nbytes,
const Device& device = Device::CPU,
bool is_pinned_memory = false);
};

} // namespace fastdeploy
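CopyBuffer is now a public static helper that takes the source device into account, so code outside FDTensor (such as the Concat kernel below) can copy device buffers directly. A minimal sketch of how it might be called; the wrapper function is illustrative only and not part of this commit:

#include "fastdeploy/core/fd_tensor.h"

// Illustrative: copy one tensor's buffer into another, already-allocated tensor,
// letting CopyBuffer choose cudaMemcpy or std::memcpy based on src.device.
void CopyTensorData(fastdeploy::FDTensor& src, fastdeploy::FDTensor& dst) {
  // dst is assumed to already hold at least src.Nbytes() bytes on the same device.
  fastdeploy::FDTensor::CopyBuffer(dst.Data(), src.Data(), src.Nbytes(),
                                   src.device, src.is_pinned_memory);
}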
13 changes: 7 additions & 6 deletions fastdeploy/function/concat.cc
@@ -85,8 +85,9 @@ struct ConcatFunctor {
int64_t col_len = input_cols[j];
const T* input_data = reinterpret_cast<const T*>(input[j].Data());
for (int64_t k = 0; k < out_rows; ++k) {
std::memcpy(output_data + k * out_cols + col_idx,
input_data + k * col_len, sizeof(T) * col_len);
FDTensor::CopyBuffer(output_data + k * out_cols + col_idx,
input_data + k * col_len, sizeof(T) * col_len,
input[j].device, input[j].is_pinned_memory);
}
col_idx += col_len;
}
@@ -97,7 +98,8 @@ template <typename T>
void ConcatKernel(const std::vector<FDTensor>& input, FDTensor* output,
int axis) {
auto output_shape = ComputeAndCheckConcatOutputShape(input, axis);
output->Allocate(output_shape, TypeToDataType<T>::dtype);
output->Resize(output_shape, TypeToDataType<T>::dtype, output->name,
input[0].device);

ConcatFunctor<T> functor;
functor(input, axis, output);
@@ -115,10 +117,9 @@ void Concat(const std::vector<FDTensor>& x, FDTensor* out, int axis) {
if (axis < 0) {
axis += rank;
}
FDTensor out_temp;

FD_VISIT_ALL_TYPES(x[0].dtype, "Concat",
([&] { ConcatKernel<data_t>(x, &out_temp, axis); }));
*out = std::move(out_temp);
([&] { ConcatKernel<data_t>(x, out, axis); }));
}

} // namespace function
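Because ConcatKernel now resizes its output on the inputs' device, tensors produced by the CUDA preprocessing path can be batched without an intermediate host copy. A rough usage sketch; the wrapper function and the tensor layout are illustrative assumptions, not part of this commit:

#include <vector>

#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/concat.h"

// Illustrative: stack per-image tensors of shape {1, C, H, W}, all on the same
// device, into a single {N, C, H, W} tensor that stays on that device.
void BatchTensors(const std::vector<fastdeploy::FDTensor>& per_image,
                  fastdeploy::FDTensor* batched) {
  fastdeploy::function::Concat(per_image, batched, 0);
}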
5 changes: 5 additions & 0 deletions fastdeploy/runtime.cc
@@ -568,6 +568,11 @@ std::vector<TensorInfo> Runtime::GetOutputInfos() {

bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
std::vector<FDTensor>* output_tensors) {
for (auto& tensor: input_tensors) {
FDASSERT(tensor.device_id < 0 || tensor.device_id == option.device_id,
"Device id of input tensor(%d) and runtime(%d) are not same.",
tensor.device_id, option.device_id);
}
return backend_->Infer(input_tensors, output_tensors);
}

Empty file modified fastdeploy/vision/classification/ppcls/model.cc (file mode 100755 → 100644)
3 changes: 3 additions & 0 deletions fastdeploy/vision/classification/ppcls/ppcls_pybind.cc
@@ -28,6 +28,9 @@ void BindPaddleClas(pybind11::module& m) {
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleClasPreprocessor.')");
}
return outputs;
})
.def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) {
self.UseGpu(gpu_id);
});

pybind11::class_<vision::classification::PaddleClasPostprocessor>(
34 changes: 30 additions & 4 deletions fastdeploy/vision/classification/ppcls/preprocessor.cc
@@ -15,17 +15,22 @@
#include "fastdeploy/vision/classification/ppcls/preprocessor.h"
#include "fastdeploy/function/concat.h"
#include "yaml-cpp/yaml.h"
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif

namespace fastdeploy {
namespace vision {
namespace classification {

PaddleClasPreprocessor::PaddleClasPreprocessor(const std::string& config_file) {
FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleClasPreprocessor.");
FDASSERT(BuildPreprocessPipelineFromConfig(config_file),
"Failed to create PaddleClasPreprocessor.");
initialized_ = true;
}

bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) {
bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(
const std::string& config_file) {
processors_.clear();
YAML::Node cfg;
try {
@@ -73,6 +78,19 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(const std::string
return true;
}

void PaddleClasPreprocessor::UseGpu(int gpu_id) {
#ifdef WITH_GPU
use_cuda_ = true;
if (gpu_id < 0) return;
device_id_ = gpu_id;
cudaSetDevice(device_id_);
#else
FDWARNING << "FastDeploy didn't compile with WITH_GPU. "
<< "Will force to use CPU to run preprocessing." << std::endl;
use_cuda_ = false;
#endif
}

bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
@@ -85,8 +103,15 @@ bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso

for (size_t i = 0; i < images->size(); ++i) {
for (size_t j = 0; j < processors_.size(); ++j) {
if (!(*(processors_[j].get()))(&((*images)[i]))) {
FDERROR << "Failed to processs image:" << i << " in " << processors_[i]->Name() << "." << std::endl;
bool ret = false;
if (processors_[j]->Name() == "NormalizeAndPermute" && use_cuda_) {
ret = (*(processors_[j].get()))(&((*images)[i]), ProcLib::CUDA);
} else {
ret = (*(processors_[j].get()))(&((*images)[i]));
}
if (!ret) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[i]->Name() << "." << std::endl;
return false;
}
}
@@ -104,6 +129,7 @@ bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
(*outputs)[0].device_id = device_id_;
return true;
}

8 changes: 8 additions & 0 deletions fastdeploy/vision/classification/ppcls/preprocessor.h
@@ -38,11 +38,19 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor {
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);

/** \brief Use GPU to run preprocessing
*
* \param[in] gpu_id GPU device id
*/
void UseGpu(int gpu_id = -1);

private:
bool BuildPreprocessPipelineFromConfig(const std::string& config_file);
std::vector<std::shared_ptr<Processor>> processors_;
bool initialized_ = false;
bool use_cuda_ = false;
// GPU device id
int device_id_ = -1;
};

} // namespace classification
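A rough end-to-end sketch of the new preprocessor API; the config path, the input image, and the construction of FDMat from a cv::Mat are assumptions for illustration and are not defined by this commit:

#include <vector>

#include <opencv2/opencv.hpp>

#include "fastdeploy/vision/classification/ppcls/preprocessor.h"

int main() {
  namespace cls = fastdeploy::vision::classification;
  cls::PaddleClasPreprocessor preprocessor("inference_cls.yaml");  // hypothetical config file
  preprocessor.UseGpu(0);  // NormalizeAndPermute now runs through the CUDA kernel on GPU 0;
                           // builds without WITH_GPU warn and fall back to the CPU path.

  cv::Mat image = cv::imread("test.jpg");  // assumed input image
  std::vector<fastdeploy::vision::FDMat> mats;
  mats.emplace_back(image);  // assumes FDMat can be constructed from a cv::Mat

  std::vector<fastdeploy::FDTensor> outputs;
  if (!preprocessor.Run(&mats, &outputs)) return -1;
  // outputs[0] is the batched NCHW tensor; outputs[0].device_id records the GPU used,
  // which Runtime::Infer later checks against its own device id.
  return 0;
}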
20 changes: 20 additions & 0 deletions fastdeploy/vision/common/processors/base.cc
@@ -30,12 +30,32 @@ bool Processor::operator()(Mat* mat, ProcLib lib) {
return ImplByFlyCV(mat);
#else
FDASSERT(false, "FastDeploy didn't compile with FlyCV.");
#endif
} else if (target == ProcLib::CUDA) {
#ifdef WITH_GPU
return ImplByCuda(mat);
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
}
// DEFAULT & OPENCV
return ImplByOpenCV(mat);
}

FDTensor* Processor::UpdateAndGetReusedBuffer(
const std::vector<int64_t>& new_shape, const int& opencv_dtype,
const std::string& buffer_name, const Device& new_device,
const bool& use_pinned_memory) {
if (reused_buffers_.count(buffer_name) == 0) {
reused_buffers_[buffer_name] = FDTensor();
}
reused_buffers_[buffer_name].is_pinned_memory = use_pinned_memory;
reused_buffers_[buffer_name].Resize(new_shape,
OpenCVDataTypeToFD(opencv_dtype),
buffer_name, new_device);
return &reused_buffers_[buffer_name];
}

void EnableFlyCV() {
#ifdef ENABLE_FLYCV
DefaultProcLib::default_lib = ProcLib::FLYCV;
14 changes: 14 additions & 0 deletions fastdeploy/vision/common/processors/base.h
@@ -18,6 +18,7 @@
#include "fastdeploy/vision/common/processors/mat.h"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include <unordered_map>

namespace fastdeploy {
namespace vision {
@@ -55,7 +56,20 @@ class FASTDEPLOY_DECL Processor {
return ImplByOpenCV(mat);
}

virtual bool ImplByCuda(Mat* mat) {
return ImplByOpenCV(mat);
}

virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT);

protected:
FDTensor* UpdateAndGetReusedBuffer(
const std::vector<int64_t>& new_shape, const int& opencv_dtype,
const std::string& buffer_name, const Device& new_device = Device::CPU,
const bool& use_pinned_memory = false);

private:
std::unordered_map<std::string, FDTensor> reused_buffers_;
};

} // namespace vision
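Since ImplByCuda defaults to the OpenCV implementation, only processors that actually have a CUDA path need to override it, and they can stage device memory through the protected UpdateAndGetReusedBuffer helper. A bare-bones sketch of that extension pattern; the processor below is hypothetical and not part of this commit:

#include <string>

#include "fastdeploy/vision/common/processors/base.h"

// Hypothetical processor showing the intended extension pattern.
class MyCudaOp : public fastdeploy::vision::Processor {
 public:
  std::string Name() { return "MyCudaOp"; }

  bool ImplByOpenCV(fastdeploy::vision::Mat* mat) {
    return true;  // CPU fallback, elided in this sketch
  }

  bool ImplByCuda(fastdeploy::vision::Mat* mat) {
    cv::Mat* im = mat->GetOpenCVMat();
    // Reuse (or allocate) a named GPU staging buffer shaped like the input image.
    fastdeploy::FDTensor* src = UpdateAndGetReusedBuffer(
        {im->rows, im->cols, im->channels()}, im->type(), Name() + "_src",
        fastdeploy::Device::GPU);
    // ... cudaMemcpy the image into src->Data() and launch a kernel (elided) ...
    mat->device = fastdeploy::Device::GPU;
    return true;
  }
};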
2 changes: 1 addition & 1 deletion fastdeploy/vision/common/processors/mat.cc
@@ -34,7 +34,7 @@ void* Mat::Data() {

void Mat::ShareWithTensor(FDTensor* tensor) {
tensor->SetExternalData({Channels(), Height(), Width()}, Type(), Data());
tensor->device = Device::CPU;
tensor->device = device;
if (layout == Layout::HWC) {
tensor->shape = {Height(), Width(), Channels()};
}
1 change: 1 addition & 0 deletions fastdeploy/vision/common/processors/mat.h
@@ -131,6 +131,7 @@ struct FASTDEPLOY_DECL Mat {

ProcLib mat_type = ProcLib::OPENCV;
Layout layout = Layout::HWC;
Device device = Device::CPU;

// Create FD Mat from FD Tensor. This method only create a
// new FD Mat with zero copy and it's data pointer is reference
fastdeploy/vision/common/processors/normalize_and_permute.cc
@@ -73,7 +73,6 @@ bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) {
res.ptr() + i * origin_h * origin_w * 4),
0);
}

mat->SetMat(res);
mat->layout = Layout::CHW;
return true;
82 changes: 82 additions & 0 deletions fastdeploy/vision/common/processors/normalize_and_permute.cu
@@ -0,0 +1,82 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision/common/processors/normalize_and_permute.h"

namespace fastdeploy {
namespace vision {

__global__ void NormalizeAndPermuteKernel(
uint8_t* src, float* dst, const float* alpha, const float* beta,
int num_channel, bool swap_rb, int edge) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx >= edge) return;

if (swap_rb) {
uint8_t tmp = src[num_channel * idx];
src[num_channel * idx] = src[num_channel * idx + 2];
src[num_channel * idx + 2] = tmp;
}

for (int i = 0; i < num_channel; ++i) {
dst[idx + edge * i] = src[num_channel * idx + i] * alpha[i] + beta[i];
}
}

bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
std::string buf_name = Name() + "_src";
std::vector<int64_t> shape = {im->rows, im->cols, im->channels()};
FDTensor* src = UpdateAndGetReusedBuffer(shape, im->type(), buf_name,
Device::GPU);
FDASSERT(cudaMemcpy(src->Data(), im->ptr(), src->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");

buf_name = Name() + "_dst";
FDTensor* dst = UpdateAndGetReusedBuffer(shape, CV_32FC(im->channels()),
buf_name, Device::GPU);
cv::Mat res(im->rows, im->cols, CV_32FC(im->channels()), dst->Data());

buf_name = Name() + "_alpha";
FDTensor* alpha = UpdateAndGetReusedBuffer({(int64_t)alpha_.size()}, CV_32FC1,
buf_name, Device::GPU);
FDASSERT(cudaMemcpy(alpha->Data(), alpha_.data(), alpha->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");

buf_name = Name() + "_beta";
FDTensor* beta = UpdateAndGetReusedBuffer({(int64_t)beta_.size()}, CV_32FC1,
buf_name, Device::GPU);
FDASSERT(cudaMemcpy(beta->Data(), beta_.data(), beta->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");

int jobs = im->cols * im->rows;
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeAndPermuteKernel<<<blocks, threads, 0, NULL>>>(
reinterpret_cast<uint8_t*>(src->Data()),
reinterpret_cast<float*>(dst->Data()),
reinterpret_cast<float*>(alpha->Data()),
reinterpret_cast<float*>(beta->Data()), im->channels(), swap_rb_, jobs);

mat->SetMat(res);
mat->device = Device::GPU;
mat->layout = Layout::CHW;
return true;
}

} // namespace vision
} // namespace fastdeploy
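Note on the launch configuration above: the kernel assigns one thread per pixel rather than per value, so jobs = rows * cols and the per-channel loop inside the kernel writes the channel outputs for that pixel. For example, a 224x224x3 input with the default 256 threads per block uses 50176 threads spread over ceil(50176 / 256) = 196 blocks.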
3 changes: 3 additions & 0 deletions fastdeploy/vision/common/processors/normalize_and_permute.h
@@ -28,6 +28,9 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef WITH_GPU
bool ImplByCuda(Mat* mat);
#endif
std::string Name() { return "NormalizeAndPermute"; }

3 changes: 3 additions & 0 deletions fastdeploy/vision/common/processors/proc_lib.cc
@@ -30,6 +30,9 @@ std::ostream& operator<<(std::ostream& out, const ProcLib& p) {
case ProcLib::FLYCV:
out << "ProcLib::FLYCV";
break;
case ProcLib::CUDA:
out << "ProcLib::CUDA";
break;
default:
FDASSERT(false, "Unknow type of ProcLib.");
}
2 changes: 1 addition & 1 deletion fastdeploy/vision/common/processors/proc_lib.h
@@ -18,7 +18,7 @@
namespace fastdeploy {
namespace vision {

enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV };
enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA };

FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out, const ProcLib& p);

(The remaining changed files were not loaded on this page.)