Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deconv cudnn forward passed #5235

Merged
merged 22 commits into from
Nov 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
2d44a2e
deconv cudnn
zchen0211 Oct 31, 2017
e80489a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 Oct 31, 2017
a349bee
deconv2d cudnn
zchen0211 Oct 31, 2017
b77f9fb
deconv2d cudnn
zchen0211 Oct 31, 2017
8013328
Refine evaluator op types (#5208)
typhoonzero Oct 31, 2017
873ee9a
add test_Expand and simply the gserver/tests/CMakeLists
luotao1 Oct 25, 2017
c2f6aa9
add comments in test_Expand.cpp
luotao1 Oct 31, 2017
1e12796
correct the index of cluster_train_cn/en.md
luotao1 Oct 31, 2017
2113d6e
fix bug (#5233)
JiayiFeng Oct 31, 2017
ddde829
Fix/sequence pool (#5229)
dzhwinter Oct 31, 2017
e41f28c
Adding a framework for variable initializers (#5232)
abhinavarora Oct 31, 2017
9b65acd
memory log level change from 3 to 10 (#5231)
jacquesqiao Oct 31, 2017
f354bd9
AddBiasOp does not care num_flatten_dims (#5200)
reyoung Oct 31, 2017
8d3324d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 Oct 31, 2017
a0acfc6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 Oct 31, 2017
4e22802
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 Nov 1, 2017
b720f28
deconv modify
zchen0211 Nov 1, 2017
634face
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 Nov 1, 2017
31187e7
deconv fix
zchen0211 Nov 1, 2017
7e34b8e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zchen0211 Nov 2, 2017
2d956b8
deconv cudnn
zchen0211 Nov 2, 2017
0efac25
deconv small fix
zchen0211 Nov 2, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions paddle/operators/conv2d_transpose_cudnn_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/conv2d_transpose_op.h"

namespace paddle {
namespace operators {

class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
public:
CudnnConv2DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: Conv2DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(conv2d_transpose_cudnn, ops::Conv2DTransposeOp,
ops::CudnnConv2DTransposeOpMaker, conv2d_transpose_cudnn_grad,
ops::Conv2DTransposeOpGrad);

REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn,
ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
240 changes: 240 additions & 0 deletions paddle/operators/conv2d_transpose_cudnn_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memory.h"
#include "paddle/operators/conv2d_transpose_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
using CUDADeviceContext = platform::CUDADeviceContext;

static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024;

template <typename T>
class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");

std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");

const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;

// N, M, H, W
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// N, C, O_h, O_w
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
// M, C, K_h, K_w
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);

// ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionBwdDataAlgo_t algo;
auto handle = ctx.cuda_device_context().cudnn_handle();
// Get the algorithm
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
// dxDesc: Handle to the previously initialized output tensor
// descriptor.
cudnn_output_desc, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));

// get workspace size able to allocate
PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));

// Allocate on GPU memory
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);

// ------------------- cudnn conv transpose forward ---------------------
T alpha = 1.0f, beta = 0.0f;
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data, cudnn_input_desc,
input_data, cudnn_conv_desc, algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));

// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};

template <typename T>
class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto input = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
const T* input_data = input->data<T>();
const T* output_grad_data = output_grad->data<T>();
const T* filter_data = filter->data<T>();

std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
// cudnn v5 does not support dilations
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");

// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW;

// Input: (N, M, H, W)
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
// Output: (N, C, O_H, O_W)
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output_grad->dims()));
// Filter (M, C, K_H, K_W)
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));

cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);

// ------------------- cudnn backward algorithm ---------------------
cudnnConvolutionFwdAlgo_t data_algo;
cudnnConvolutionBwdFilterAlgo_t filter_algo;
size_t bwd_filter_ws_size, fwd_ws_size;
size_t workspace_size_in_bytes = 0;
size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}

auto handle = ctx.cuda_device_context().cudnn_handle();
if (input_grad) {
// choose backward algorithm for data
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_input_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo));
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_output_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_input_desc, data_algo, &fwd_ws_size));
workspace_size_in_bytes = std::max(workspace_size_in_bytes, fwd_ws_size);
}

if (filter_grad) {
// choose backward algorithm for filter
PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &filter_algo));

// get workspace for backwards filter algorithm
PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle, cudnn_output_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_filter_desc, filter_algo, &bwd_filter_ws_size));
workspace_size_in_bytes =
std::max(workspace_size_in_bytes, bwd_filter_ws_size);
}

// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));

PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, output_grad_data,
cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data));
}

// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*filter_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
// Gradient with respect to the filter
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
input_data, cudnn_conv_desc, filter_algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data));
}
// Release the cudnn workspace
paddle::memory::Free(gpu, cudnn_workspace);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why name it conv2d_transpose?

ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/conv2dtranspose_op.h"
#include "paddle/operators/conv2d_transpose_op.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -95,13 +95,13 @@ void Conv2DTransposeOpGrad::InferShape(
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp,
ops::Conv2DTransposeOpMaker, conv2dtranspose_grad,
REGISTER_OP(conv2d_transpose, ops::Conv2DTransposeOp,
ops::Conv2DTransposeOpMaker, conv2d_transpose_grad,
ops::Conv2DTransposeOpGrad);

REGISTER_OP_CPU_KERNEL(
conv2dtranspose,
conv2d_transpose,
ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2dtranspose_grad,
conv2d_transpose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/conv2dtranspose_op.h"
#include "paddle/operators/conv2d_transpose_op.h"

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(
conv2dtranspose,
conv2d_transpose,
ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2dtranspose_grad,
conv2d_transpose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");

// TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose.
// groups will alway be disabled in conv2d_transpose.

const int batch_size = input->dims()[0];
const int m = input->dims()[1];
Expand Down
Loading