From 96b4035dd132d419f463bd0341baa2c4a773b8b6 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Tue, 10 Oct 2017 16:08:23 +0800
Subject: [PATCH 01/10] Add conv3d_gemm_op

---
 paddle/operators/CMakeLists.txt | 5 +-
 paddle/operators/conv3d_op.cc | 117 +++++++++++++++
 paddle/operators/conv3d_op.cu | 22 +++
 paddle/operators/conv3d_op.h | 259 ++++++++++++++++++++++++++++++++
 4 files changed, 402 insertions(+), 1 deletion(-)
 create mode 100644 paddle/operators/conv3d_op.cc
 create mode 100644 paddle/operators/conv3d_op.cu
 create mode 100644 paddle/operators/conv3d_op.h

diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 7dae8fe2f99f9..576cd2530d1e5 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -112,7 +112,8 @@ set(DEPS_OPS
     cond_op
     cross_entropy_op
     softmax_with_cross_entropy_op
-    sum_op)
+    sum_op
+    conv3d_op)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
@@ -121,6 +122,8 @@ op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
 op_library(cross_entropy_op DEPS cross_entropy)
 op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(sum_op DEPS net_op)
+op_library(conv3d_op DEPS vol2col)
+
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})
diff --git a/paddle/operators/conv3d_op.cc b/paddle/operators/conv3d_op.cc
new file mode 100644
index 0000000000000..2b34a2671d7af
--- /dev/null
+++ b/paddle/operators/conv3d_op.cc
@@ -0,0 +1,117 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/conv3d_op.h"
+
+namespace paddle {
+namespace operators {
+
+int OutputSizeConv3d(int input_size, int filter_size, int padding, int stride) {
+  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+  return output_size;
+}
+
+void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("Input"),
+                 "Input(Input) of Conv3DOp should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("Filter"),
+                 "Input(Filter) of Conv3DOp should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Output"),
+                 "Output(Output) of Conv3DOp should not be null.");
+
+  auto in_dims = ctx->GetInputDim("Input");
+  auto filter_dims = ctx->GetInputDim("Filter");
+  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
+  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+  int groups = ctx->Attrs().Get<int>("groups");
+  int input_channels = in_dims[1];
+  int output_channels = filter_dims[0];
+
+  PADDLE_ENFORCE_EQ(in_dims.size(), 5, "Conv3DOp input should be 5-D.");
+  PADDLE_ENFORCE_EQ(filter_dims.size(), 5, "Conv3DOp filter should be 5-D.");
+  PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
+                    "The number of input channels should be equal to filter "
+                    "channels * groups.");
+  PADDLE_ENFORCE_EQ(
+      output_channels % groups, 0,
+      "The number of output channels should be divided by groups.");
+
+  std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
+  for (size_t i = 0; i < paddings.size(); ++i) {
+    output_shape.push_back(OutputSizeConv3d(in_dims[i + 2], filter_dims[i],
+                                            paddings[i], strides[i]));
+  }
+  ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+}
+
+void Conv3DOpGrad::InferShape(framework::InferShapeContext* ctx) const {
+  auto in_dims = ctx->GetInputDim("Input");
+  auto filter_dims = ctx->GetInputDim("Filter");
+  if (ctx->HasOutput(framework::GradVarName("Input"))) {
+    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
+  }
+  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
+    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
+  }
+}
+
+Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
+                             framework::OpAttrChecker* op_checker)
+    : OpProtoAndCheckerMaker(proto, op_checker) {
+  AddInput(
+      "Input",
+      "The input tensor of convolution operator. "
+      "The format of input tensor is NCDHW. Where N is batch size, C is the "
+      "number of channels, D, H and W is the depth, height and width of "
+      "image.");
+  AddInput("Filter",
+           "The filter tensor of convolution operator."
+           "The format of the filter tensor is MCDHW, where M is the number of "
+           "output image channels, C is the number of input image channels, "
+           "D, H and W is depth, height and width of filter. "
+           "If the groups attribute is greater than 1, C equal the number of "
+           "input image channels divided by the groups.");
+  AddOutput("Output",
+            "The output tensor of convolution operator."
+            "The format of output tensor is also NCDHW.");
+  AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
+      .SetDefault({1, 1, 1});
+  AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
+      .SetDefault({0, 0, 0});
+  AddAttr<int>(
+      "groups",
+      "group size of convolution operator. "
+      "Refer to grouped convolution in Alex Krizhevsky's paper: "
+      "when group=2, the first half of the filters are only connected to the "
+      "first half of the input channels, and the second half only connected "
+      "to the second half.")
+      .SetDefault(1);
+  AddComment(R"DOC(
+The convolution operation calculates the output based on the input, filter
+and strides, paddings, groups parameters. The size of each dimension of the
+parameters is checked in the infer-shape.
+)DOC");
+}
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(conv3d, ops::Conv3DOp, ops::Conv3DOpMaker, conv3d_grad,
+            ops::Conv3DOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    conv3d, ops::GemmConv3DKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    conv3d_grad, ops::GemmConvGrad3DKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv3d_op.cu b/paddle/operators/conv3d_op.cu
new file mode 100644
index 0000000000000..ec6121d5d5136
--- /dev/null
+++ b/paddle/operators/conv3d_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/conv3d_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_GPU_KERNEL(
+    conv3d, ops::GemmConv3DKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    conv3d_grad, ops::GemmConvGrad3DKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/conv3d_op.h b/paddle/operators/conv3d_op.h
new file mode 100644
index 0000000000000..a22cb34f674ec
--- /dev/null
+++ b/paddle/operators/conv3d_op.h
@@ -0,0 +1,259 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/vol2col.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class Conv3DOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override;
+};
+
+class Conv3DOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override;
+};
+
+class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Conv3DOpMaker(framework::OpProto* proto,
+                framework::OpAttrChecker* op_checker);
+};
+
+template <typename Place, typename T>
+class GemmConv3DKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    // The filter will be reshaped in the calculations,
+    // so here use an assignment operation,
+    // that avoids modifying the variable in the Scope.
+    Tensor filter = *context.Input<Tensor>("Filter");
+    Tensor* output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    int groups = context.Attr<int>("groups");
+
+    int batch_size = input->dims()[0];
+    int input_channels = input->dims()[1];
+    int filter_depth = filter.dims()[filter.dims().size() - 3];
+    int filter_height = filter.dims()[filter.dims().size() - 2];
+    int filter_width = filter.dims()[filter.dims().size() - 1];
+    int output_channels = output->dims()[1];
+    int output_depth = output->dims()[2];
+    int output_height = output->dims()[3];
+    int output_width = output->dims()[4];
+
+    paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
+    // use col_shape in the vol2col calculation
+    framework::DDim col_shape = {input_channels / groups,
+                                 filter_depth,
+                                 filter_height,
+                                 filter_width,
+                                 output_depth,
+                                 output_height,
+                                 output_width};
+    // use col_matrix_shape in the gemm calculation
+    framework::DDim col_matrix_shape = {
+        input_channels / groups * filter_depth * filter_height * filter_width,
+        output_depth * output_height * output_width};
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // col_matrix shares the same piece of data with col,
+    // but will be reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
+    Tensor col_matrix = col;
+    col_matrix.Resize(col_matrix_shape);
+
+    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
+                                   input->dims()[3], input->dims()[4]};
+    framework::DDim filter_matrix_shape = {filter.dims()[0],
+                                           filter.numel() / filter.dims()[0]};
+    filter.Resize(filter_matrix_shape);
+
+    framework::DDim output_matrix_shape = {
+        output_channels, output_depth * output_height * output_width};
+
+    // convolution operator: vol2col + gemm
+    int in_step = input_channels / groups;
+    int out_step = output_channels / groups;
+    for (int i = 0; i < batch_size; i++) {
+      Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+      Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
+      for (int g = 0; g < groups; g++) {
+        // vol2col
+        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+        vol2col(context.device_context(), in_slice, col, strides[0],
+                strides[1], strides[2], paddings[0], paddings[1], paddings[2]);
+
+        // gemm
+        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
+        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+        math::matmul<Place, T>(context.device_context(), filter_slice, false,
+                               col_matrix, false, T(1.0), &out_slice, T(0.0));
+      }
+    }
+  }
+};
+
+template <typename Place, typename T>
+class GemmConvGrad3DKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* output_grad =
+        context.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad =
+        context.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad =
+        context.Output<Tensor>(framework::GradVarName("Filter"));
+
+    // The filter and filter_grad will be reshaped in the calculations,
+    // so here use an assignment operation,
+    // that avoids modifying the variable in the Scope.
+    Tensor filter = *context.Input<Tensor>("Filter");
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    int groups = context.Attr<int>("groups");
+
+    int batch_size = input->dims()[0];
+    int input_channels = input->dims()[1];
+    int filter_depth = filter.dims()[filter.dims().size() - 3];
+    int filter_height = filter.dims()[filter.dims().size() - 2];
+    int filter_width = filter.dims()[filter.dims().size() - 1];
+    int output_channels = output_grad->dims()[1];
+    int output_depth = output_grad->dims()[2];
+    int output_height = output_grad->dims()[3];
+    int output_width = output_grad->dims()[4];
+
+    paddle::operators::math::Col2VolFunctor<Place, T> col2vol;
+    paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
+    // use col_shape in the vol2col and col2vol calculation
+    framework::DDim col_shape = {input_channels / groups,
+                                 filter_depth,
+                                 filter_height,
+                                 filter_width,
+                                 output_depth,
+                                 output_height,
+                                 output_width};
+    // use col_matrix_shape in the gemm calculation
+    framework::DDim col_matrix_shape = {
+        input_channels / groups * filter_depth * filter_height * filter_width,
+        output_depth * output_height * output_width};
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // col_matrix shares the same piece of data with col,
+    // but will be reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
+    Tensor col_matrix = col;
+    col_matrix.Resize(col_matrix_shape);
+
+    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
+                                   input->dims()[3], input->dims()[4]};
+    framework::DDim output_matrix_shape = {output_grad->dims()[1],
+                                           output_grad->dims()[2] *
+                                               output_grad->dims()[3] *
+                                               output_grad->dims()[4]};
+
+    framework::DDim filter_matrix_shape = {filter.dims()[0],
+                                           filter.numel() / filter.dims()[0]};
+    filter.Resize(filter_matrix_shape);
+
+    // convolution backward input operator: gemm + col2vol
+    // convolution backward weight operator: vol2col + gemm
+    int in_step = input_channels / groups;
+    int out_step = output_channels / groups;
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t = framework::EigenVector<T>::Flatten(*input_grad);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
+          // gemm
+          Tensor out_grad_slice =
+              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+          math::matmul<Place, T>(context.device_context(), filter_slice, true,
+                                 out_grad_slice, false, T(1.0), &col_matrix,
+                                 T(0.0));
+
+          // col2vol
+          Tensor in_grad_slice =
+              in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
+          col2vol(context.device_context(), in_grad_slice, col, strides[0],
+                  strides[1], strides[2], paddings[0], paddings[1],
+                  paddings[2]);
+        }
+      }
+    }
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      Tensor filter_grad_ = *filter_grad;
+      filter_grad_.Resize(filter_matrix_shape);
+      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
+          // vol2col
+          Tensor out_grad_slice =
+              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
+          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
+          vol2col(context.device_context(), in_slice, col, strides[0],
+                  strides[1], strides[2], paddings[0], paddings[1],
+                  paddings[2]);
+
+          // gemm
+          Tensor filter_grad_slice =
+              filter_grad_.Slice(g * out_step, (g + 1) * out_step);
+          math::matmul<Place, T>(context.device_context(), out_grad_slice,
+                                 false, col_matrix, true, T(1.0),
+                                 &filter_grad_slice, T(1.0));
+        }
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

From c2fbf8c5a7e3ea299d2ab011b116df7f114c7e4c Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Thu, 12 Oct 2017 09:37:37 +0800
Subject: [PATCH 02/10] Add unit test

---
 .../v2/framework/tests/test_conv3d_op.py | 118 ++++++++++++++++++
 1 file changed, 118 insertions(+)
 create mode 100644 python/paddle/v2/framework/tests/test_conv3d_op.py

diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/framework/tests/test_conv3d_op.py
new file mode 100644
index 0000000000000..cbc60111890c7
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_conv3d_op.py
@@ -0,0 +1,118 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestConv3dOp(OpTest):
+    def setUp(self):
+        self.init_groups()
+        self.op_type = "conv3d"
+        batch_size = 2
+        input_channels = 3
+        input_depth = 5
+        input_height
= 5 + input_width = 5 + output_channels = 6 + filter_depth = 3 + filter_height = 3 + filter_width = 3 + stride = 1 + padding = 0 + output_depth = (input_depth - filter_depth + 2 * padding) / stride + 1 + output_height = (input_height - filter_height + 2 * padding + ) / stride + 1 + output_width = (input_width - filter_width + 2 * padding) / stride + 1 + input = np.random.random((batch_size, input_channels, input_depth, + input_height, input_width)).astype("float32") + + filter = np.random.random( + (output_channels, input_channels / self.groups, filter_depth, + filter_height, filter_width)).astype("float32") + output = np.ndarray((batch_size, output_channels, output_depth, + output_height, output_width)) + + self.inputs = {'Input': input, 'Filter': filter} + self.attrs = { + 'strides': [1, 1, 1], + 'paddings': [0, 0, 0], + 'groups': self.groups + } + + output_group_channels = output_channels / self.groups + input_group_channels = input_channels / self.groups + for batchid in xrange(batch_size): + for group in xrange(self.groups): + for outchannelid in range(group * output_group_channels, + (group + 1) * output_group_channels): + for deepid in xrange(output_depth): + for rowid in xrange(output_height): + for colid in xrange(output_width): + start_d = (deepid * stride) - padding + start_h = (rowid * stride) - padding + start_w = (colid * stride) - padding + output_value = 0.0 + for inchannelid in range( + group * input_group_channels, + (group + 1) * input_group_channels): + for fdeepid in xrange(filter_depth): + for frowid in xrange(filter_height): + for fcolid in xrange(filter_width): + input_value = 0.0 + indeepid = start_d + fdeepid + inrowid = start_h + frowid + incolid = start_w + fcolid + if ((indeepid >= 0 and + indeepid < input_depth) and + (inrowid >= 0 and + inrowid < input_height) and + (incolid >= 0 and + incolid < input_width)): + + input_value = input[ + batchid][inchannelid][ + indeepid][inrowid][ + incolid] + filter_value = filter[ + outchannelid][ + inchannelid % + input_group_channels][ + fdeepid][frowid][ + fcolid] + output_value += input_value * filter_value + output[batchid][outchannelid][deepid][rowid][ + colid] = output_value + + self.outputs = {'Output': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + + def test_check_grad_no_filter(self): + self.check_grad( + ['Input'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + self.check_grad( + ['Filter'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Input'])) + + def init_groups(self): + self.groups = 1 + + +class TestWithGroup(TestConv3dOp): + def init_groups(self): + self.groups = 3 + + +if __name__ == '__main__': + unittest.main() From 4aae1fff78d805ef9c2c08e6fc8702cc3e3ccc25 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 12 Oct 2017 15:13:10 +0800 Subject: [PATCH 03/10] fix conv3d_gemm, unit test and follow comments --- paddle/operators/conv3d_op.cc | 20 +-- paddle/operators/conv3d_op.cu | 18 +-- paddle/operators/conv3d_op.h | 18 +-- .../v2/framework/tests/test_conv3d_op.py | 138 ++++++++---------- 4 files changed, 92 insertions(+), 102 deletions(-) diff --git a/paddle/operators/conv3d_op.cc b/paddle/operators/conv3d_op.cc index 2b34a2671d7af..8477bc5719dd8 100644 --- a/paddle/operators/conv3d_op.cc +++ b/paddle/operators/conv3d_op.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. 
All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #include "paddle/operators/conv3d_op.h" @@ -52,7 +52,7 @@ void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const { output_shape.push_back(OutputSizeConv3d(in_dims[i + 2], filter_dims[i], paddings[i], strides[i])); } - ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } void Conv3DOpGrad::InferShape(framework::InferShapeContext* ctx) const { diff --git a/paddle/operators/conv3d_op.cu b/paddle/operators/conv3d_op.cu index ec6121d5d5136..ec6279f9bbbf9 100644 --- a/paddle/operators/conv3d_op.cu +++ b/paddle/operators/conv3d_op.cu @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #include "paddle/operators/conv3d_op.h" diff --git a/paddle/operators/conv3d_op.h b/paddle/operators/conv3d_op.h index a22cb34f674ec..960d104877d38 100644 --- a/paddle/operators/conv3d_op.h +++ b/paddle/operators/conv3d_op.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #pragma once diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/framework/tests/test_conv3d_op.py index cbc60111890c7..1ec59afcfc28c 100644 --- a/python/paddle/v2/framework/tests/test_conv3d_op.py +++ b/python/paddle/v2/framework/tests/test_conv3d_op.py @@ -3,85 +3,59 @@ from op_test import OpTest +def conv3d_forward_naive(input, filter, group, conv_param): + in_n, in_c, in_d, in_h, in_w = input.shape + out_c, f_c, f_d, f_h, f_w = filter.shape + assert f_c * group == in_c + assert np.mod(out_c, group) == 0 + sub_out_c = out_c / group + + stride, pad = conv_param['stride'], conv_param['pad'] + out_d = 1 + (in_d + 2 * pad[0] - f_h) / stride[0] + out_h = 1 + (in_h + 2 * pad[1] - f_h) / stride[1] + out_w = 1 + (in_w + 2 * pad[2] - f_w) / stride[2] + out = np.zeros((in_n, out_c, out_d, out_h, out_w)) + + input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], ), + (pad[2], )), + mode='constant', + constant_values=0) + for d in range(out_d): + for i in range(out_h): + for j in range(out_w): + for g in range(group): + input_pad_masked = \ + input_pad[:, g * f_c:(g + 1) * f_c, + d * stride[0]:d * stride[0] + f_d, + i * stride[1]:i * stride[1] + f_h, + j * stride[2]:j * stride[2] + f_w] + f_sub = filter[g * sub_out_c:(g + 1) * + sub_out_c, :, :, :, :] + for k in range(sub_out_c): + out[:, g * sub_out_c + k, d, i, j] = \ + np.sum(input_pad_masked * f_sub[k, :, :, :, :], + axis=(1, 2, 3,4)) + + return out + + class TestConv3dOp(OpTest): def setUp(self): - self.init_groups() - self.op_type = "conv3d" - batch_size = 2 - input_channels = 3 - input_depth = 5 - input_height = 5 - input_width = 5 - output_channels = 6 - filter_depth = 3 - filter_height = 3 - filter_width = 3 - stride = 1 - padding = 0 - output_depth = (input_depth - filter_depth + 2 * padding) / stride + 1 - output_height = (input_height - filter_height + 2 * padding - ) / stride + 1 - output_width = (input_width - filter_width + 2 * padding) / stride + 1 - input = np.random.random((batch_size, input_channels, input_depth, - input_height, input_width)).astype("float32") - - filter = np.random.random( - (output_channels, input_channels / self.groups, filter_depth, - filter_height, filter_width)).astype("float32") - output = np.ndarray((batch_size, output_channels, output_depth, - output_height, output_width)) + self.init_group() + self.init_op_type() + self.init_test_case() + + conv3d_param = {'stride': self.stride, 'pad': self.pad} + input = np.random.random(self.input_size).astype("float32") + filter = np.random.random(self.filter_size).astype("float32") + output = conv3d_forward_naive(input, filter, self.groups, conv3d_param) self.inputs = {'Input': input, 'Filter': filter} self.attrs = { - 'strides': [1, 1, 1], - 'paddings': [0, 0, 0], + 'strides': 
self.stride, + 'paddings': self.pad, 'groups': self.groups } - - output_group_channels = output_channels / self.groups - input_group_channels = input_channels / self.groups - for batchid in xrange(batch_size): - for group in xrange(self.groups): - for outchannelid in range(group * output_group_channels, - (group + 1) * output_group_channels): - for deepid in xrange(output_depth): - for rowid in xrange(output_height): - for colid in xrange(output_width): - start_d = (deepid * stride) - padding - start_h = (rowid * stride) - padding - start_w = (colid * stride) - padding - output_value = 0.0 - for inchannelid in range( - group * input_group_channels, - (group + 1) * input_group_channels): - for fdeepid in xrange(filter_depth): - for frowid in xrange(filter_height): - for fcolid in xrange(filter_width): - input_value = 0.0 - indeepid = start_d + fdeepid - inrowid = start_h + frowid - incolid = start_w + fcolid - if ((indeepid >= 0 and - indeepid < input_depth) and - (inrowid >= 0 and - inrowid < input_height) and - (incolid >= 0 and - incolid < input_width)): - - input_value = input[ - batchid][inchannelid][ - indeepid][inrowid][ - incolid] - filter_value = filter[ - outchannelid][ - inchannelid % - input_group_channels][ - fdeepid][frowid][ - fcolid] - output_value += input_value * filter_value - output[batchid][outchannelid][deepid][rowid][ - colid] = output_value - self.outputs = {'Output': output} def test_check_output(self): @@ -105,14 +79,30 @@ def test_check_grad_no_input(self): max_relative_error=0.05, no_grad_set=set(['Input'])) - def init_groups(self): + def init_test_case(self): + # self.groups = 1 + # self.op_type = "conv3d" + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 5, 5, 5] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_group(self): self.groups = 1 + def init_op_type(self): + self.op_type = "conv3d" + class TestWithGroup(TestConv3dOp): - def init_groups(self): + def init_group(self): self.groups = 3 + def init_op_type(self): + self.op_type = "conv3d" + if __name__ == '__main__': unittest.main() From 91db457fc0f8409f5c05995482289d7386f3e986 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 18 Oct 2017 18:40:29 +0800 Subject: [PATCH 04/10] follow comments --- paddle/operators/conv3d_op.cc | 4 ++-- paddle/operators/conv3d_op.h | 22 ++++++++++++------- .../v2/framework/tests/test_conv3d_op.py | 10 ++------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/paddle/operators/conv3d_op.cc b/paddle/operators/conv3d_op.cc index 714cf8abbf5c5..f86ed86a5022c 100644 --- a/paddle/operators/conv3d_op.cc +++ b/paddle/operators/conv3d_op.cc @@ -87,11 +87,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, "The format of output tensor is also NCDHW."); AddAttr>("strides", "strides of convolution operator.") .SetDefault({1, 1, 1}); - AddAttr>("paddings", "paddings of convolution operator.") + AddAttr>("paddings", "The paddings of convolution operator.") .SetDefault({0, 0, 0}); AddAttr( "groups", - "group size of convolution operator. " + "The group size of convolution operator. 
" "Refer to grouped convolution in Alex Krizhevsky's paper: " "when group=2, the first half of the filters are only connected to the " "first half of the input channels, and the second half only connected " diff --git a/paddle/operators/conv3d_op.h b/paddle/operators/conv3d_op.h index 960d104877d38..0bc0673967947 100644 --- a/paddle/operators/conv3d_op.h +++ b/paddle/operators/conv3d_op.h @@ -93,10 +93,13 @@ class GemmConv3DKernel : public framework::OpKernel { Tensor col_matrix = col; col_matrix.Resize(col_matrix_shape); - framework::DDim input_shape = {input->dims()[1], input->dims()[2], - input->dims()[3], input->dims()[4]}; - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; + framework::DDim input_shape = { + input->dims()[1], input->dims()[2], input->dims()[3], + input->dims()[4]}; // channel, depth, height, width + framework::DDim filter_matrix_shape = { + filter.dims()[0], + filter.numel() / filter.dims()[0]}; // filter_out_channel, + // filter_in_channel*filter_depth*filter_height*filter_width filter.Resize(filter_matrix_shape); framework::DDim output_matrix_shape = { @@ -177,15 +180,18 @@ class GemmConvGrad3DKernel : public framework::OpKernel { Tensor col_matrix = col; col_matrix.Resize(col_matrix_shape); - framework::DDim input_shape = {input->dims()[1], input->dims()[2], - input->dims()[3], input->dims()[4]}; + framework::DDim input_shape = { + input->dims()[1], input->dims()[2], input->dims()[3], + input->dims()[4]}; // channel, depth, height, width framework::DDim output_matrix_shape = {output_grad->dims()[1], output_grad->dims()[2] * output_grad->dims()[3] * output_grad->dims()[4]}; - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; + framework::DDim filter_matrix_shape = { + filter.dims()[0], + filter.numel() / filter.dims()[0]}; // filter_out_channel, + // filter_in_channel*filter_depth*filter_height*filter_width filter.Resize(filter_matrix_shape); // convolution backward input operator: gemm + col2vol diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/framework/tests/test_conv3d_op.py index e81f2a166caa4..4e12b1a0c89d6 100644 --- a/python/paddle/v2/framework/tests/test_conv3d_op.py +++ b/python/paddle/v2/framework/tests/test_conv3d_op.py @@ -34,7 +34,7 @@ def conv3d_forward_naive(input, filter, group, conv_param): for k in range(sub_out_c): out[:, g * sub_out_c + k, d, i, j] = \ np.sum(input_pad_masked * f_sub[k, :, :, :, :], - axis=(1, 2, 3,4)) + axis=(1, 2, 3, 4)) return out @@ -65,7 +65,6 @@ def test_check_grad(self): self.check_grad( set(['Input', 'Filter']), 'Output', max_relative_error=0.05) - def test_check_grad_no_filter(self): self.check_grad( ['Input'], 'Output', @@ -80,8 +79,6 @@ def test_check_grad_no_input(self): no_grad_set=set(['Input'])) def init_test_case(self): - # self.groups = 1 - # self.op_type = "conv3d" self.pad = [0, 0, 0] self.stride = [1, 1, 1] self.input_size = [2, 3, 5, 5, 5] # NCDHW @@ -98,8 +95,6 @@ def init_op_type(self): class TestCase1(TestConv3dOp): def init_test_case(self): - # self.groups = 1 - # self.op_type = "conv3d" self.pad = [1, 1, 1] self.stride = [1, 1, 1] self.input_size = [2, 3, 5, 5, 5] # NCDHW @@ -114,7 +109,6 @@ def init_op_type(self): self.op_type = "conv3d" -''' class TestWithGroup1(TestConv3dOp): def init_group(self): self.groups = 3 @@ -129,7 +123,7 @@ def init_group(self): def init_op_type(self): self.op_type = "conv3d" -''' + if __name__ == '__main__': unittest.main() From 
08a7b1ded7cd7c1b021c06f3dcf427dd9c3a71d9 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 18 Oct 2017 19:28:15 +0800 Subject: [PATCH 05/10] fix unit test --- python/paddle/v2/framework/tests/test_conv3d_op.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/framework/tests/test_conv3d_op.py index 4e12b1a0c89d6..010217cbf87d3 100644 --- a/python/paddle/v2/framework/tests/test_conv3d_op.py +++ b/python/paddle/v2/framework/tests/test_conv3d_op.py @@ -65,6 +65,7 @@ def test_check_grad(self): self.check_grad( set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + def test_check_grad_no_filter(self): self.check_grad( ['Input'], 'Output', @@ -81,7 +82,7 @@ def test_check_grad_no_input(self): def init_test_case(self): self.pad = [0, 0, 0] self.stride = [1, 1, 1] - self.input_size = [2, 3, 5, 5, 5] # NCDHW + self.input_size = [2, 3, 4, 4, 4] # NCDHW assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3, 3] @@ -97,7 +98,7 @@ class TestCase1(TestConv3dOp): def init_test_case(self): self.pad = [1, 1, 1] self.stride = [1, 1, 1] - self.input_size = [2, 3, 5, 5, 5] # NCDHW + self.input_size = [2, 3, 4, 4, 4] # NCDHW assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3, 3] From 9f7c9875a9cabc5b4298ecff93c106e005987099 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 25 Oct 2017 11:34:35 +0800 Subject: [PATCH 06/10] fix doc --- paddle/operators/conv3d_op.cc | 39 +++++++++++++++++++++++++++-------- paddle/operators/pool_op.cc | 2 -- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/paddle/operators/conv3d_op.cc b/paddle/operators/conv3d_op.cc index f86ed86a5022c..fb3f1265f3af0 100644 --- a/paddle/operators/conv3d_op.cc +++ b/paddle/operators/conv3d_op.cc @@ -38,11 +38,12 @@ void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const { int input_channels = in_dims[1]; int output_channels = filter_dims[0]; - PADDLE_ENFORCE_EQ(in_dims.size(), 5, "Conv3DOp input should be 5-D."); - PADDLE_ENFORCE_EQ(filter_dims.size(), 5, "Conv3DOp filter should be 5-D."); + PADDLE_ENFORCE_EQ(in_dims.size(), 5, "Conv3DOp input should be 5-D tensor."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 5, + "Conv3DOp filter should be 5-D tensor."); PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, "The number of input channels should be equal to filter " - "channels * groups."); + "(channels * groups)."); PADDLE_ENFORCE_EQ( output_channels % groups, 0, "The number of output channels should be divided by groups."); @@ -71,27 +72,31 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", - "The input tensor of convolution operator. " + "(Tensor), the input tensor of convolution operator. " "The format of input tensor is NCDHW. Where N is batch size, C is the " "number of channels, D, H and W is the depth, height and width of " "image."); AddInput("Filter", - "The filter tensor of convolution operator." + "(Tensor), the filter tensor of convolution operator." "The format of the filter tensor is MCDHW, where M is the number of " "output image channels, C is the number of input image channels, " "D, H and W is depth, height and width of filter. 
" "If the groups attribute is greater than 1, C equal the number of " "input image channels divided by the groups."); AddOutput("Output", - "The output tensor of convolution operator." + "(Tensor), the output tensor of convolution operator." "The format of output tensor is also NCDHW."); - AddAttr>("strides", "strides of convolution operator.") + AddAttr>( + "strides", + "(vector, default {0,0,0}), the strides of convolution operator.") .SetDefault({1, 1, 1}); - AddAttr>("paddings", "The paddings of convolution operator.") + AddAttr>( + "paddings", + "(vector, default {0,0,0}), the paddings of convolution operator.") .SetDefault({0, 0, 0}); AddAttr( "groups", - "The group size of convolution operator. " + "(int, default 1) the group size of convolution operator. " "Refer to grouped convolution in Alex Krizhevsky's paper: " "when group=2, the first half of the filters are only connected to the " "first half of the input channels, and the second half only connected " @@ -101,6 +106,22 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, The convolution operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. +Input(Input, Filter) and output(Output) are in NCDHW format. Where N is batch +size, C is the number of channels, D, H and W is the depth, height and +width of feature. Parameters(ksize, strides, paddings) are three elements. +These three elements represent depth, height and width, respectively. +The input(X) size and output(Out) size may be different. + +Example: + Input: + Input shape: (N, C_in, D_in, H_in, W_in) + Filter shape: (C_out, C_in, D_f, H_f, W_f) + Output: + Output shape: (N, C_out, D_out, H_out, W_out) + where + D_out = (D_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1; + H_out = (H_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1; + W_out = (W_in - filter_size[2] + 2 * paddings[2]) / strides[2] + 1; )DOC"); } diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index a326839c0f9ad..898ae2fb62799 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -123,7 +123,6 @@ The input(X) size and output(Out) size may be different. X shape: (N, C, H_in, W_in) Output: Out shape: (N, C, H_out, W_out) - Mask shape: (N, C, H_out, W_out) where H_out = (H_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1; W_out = (W_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1; @@ -190,7 +189,6 @@ The input(X) size and output(Out) size may be different. 
X shape: (N, C, D_in, H_in, W_in) Output: Out shape: (N, C, D_out, H_out, W_out) - Mask shape: (N, C, D_out, H_out, W_out) where D_out = (D_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1; H_out = (H_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1; From eafbbc11a0bb1f347f7917552d46c2944b5f3bb2 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 26 Oct 2017 10:21:05 +0800 Subject: [PATCH 07/10] write conv2d and conv3d together --- paddle/operators/CMakeLists.txt | 11 +- paddle/operators/conv2d_op.cc | 111 -------- paddle/operators/conv3d_op.cu | 22 -- paddle/operators/conv3d_op.h | 263 ------------------ paddle/operators/conv_cudnn_op.cc | 7 +- paddle/operators/conv_cudnn_op.cu | 2 +- paddle/operators/{conv3d_op.cc => conv_op.cc} | 100 +++++-- paddle/operators/{conv2d_op.cu => conv_op.cu} | 7 +- paddle/operators/{conv2d_op.h => conv_op.h} | 224 ++++++++++++++- 9 files changed, 315 insertions(+), 432 deletions(-) delete mode 100644 paddle/operators/conv2d_op.cc delete mode 100644 paddle/operators/conv3d_op.cu delete mode 100644 paddle/operators/conv3d_op.h rename paddle/operators/{conv3d_op.cc => conv_op.cc} (61%) rename paddle/operators/{conv2d_op.cu => conv_op.cu} (78%) rename paddle/operators/{conv2d_op.h => conv_op.h} (51%) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 4d1fb3b96e306..39250480db37f 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -69,6 +69,13 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") endif() + # conv_op contains several operators + if ("${TARGET}" STREQUAL "conv_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(conv2d);\n") + endif() + # save_restore_op contains several operators if ("${TARGET}" STREQUAL "save_restore_op") set(pybind_flag 1) @@ -123,7 +130,7 @@ set(DEPS_OPS sum_op pool_op pool_with_index_op - conv3d_op + conv_op lstm_op) @@ -133,7 +140,7 @@ op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op) -op_library(conv3d_op DEPS vol2col) +op_library(conv_op DEPS vol2col) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) op_library(lstm_op DEPS sequence2batch lstm_compute) diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc deleted file mode 100644 index 1acb8415d0691..0000000000000 --- a/paddle/operators/conv2d_op.cc +++ /dev/null @@ -1,111 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
*/ - -#include "paddle/operators/conv2d_op.h" - -namespace paddle { -namespace operators { - -void Conv2DOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of Conv2DOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of Conv2DOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of Conv2DOp should not be null."); - - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - std::vector strides = ctx->Attrs().Get>("strides"); - std::vector paddings = ctx->Attrs().Get>("paddings"); - int groups = ctx->Attrs().Get("groups"); - int input_channels = in_dims[1]; - int output_channels = filter_dims[0]; - - PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D."); - PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D."); - PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, - "The number of input channels should be equal to filter " - "channels * groups."); - PADDLE_ENFORCE_EQ( - output_channels % groups, 0, - "The number of output channels should be divided by groups."); - - auto output_height = - OutputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]); - auto output_width = - OutputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]); - ctx->SetOutputDim("Output", - {in_dims[0], filter_dims[0], output_height, output_width}); -} - -Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "Input", - "The input tensor of convolution operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of channels, H and W is the height and width of image."); - AddInput("Filter", - "The filter tensor of convolution operator." - "The format of the filter tensor is MCHW, where M is the number of " - "output image channels, C is the number of input image channels, " - "H and W is height and width of filter. " - "If the groups attribute is greater than 1, C equal the number of " - "input image channels divided by the groups."); - AddOutput("Output", - "The output tensor of convolution operator." - "The format of output tensor is also NCHW."); - AddAttr>("strides", "strides of convolution operator.") - .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of convolution operator.") - .SetDefault({0, 0}); - AddAttr( - "groups", - "group size of convolution operator. " - "Refer to grouped convolution in Alex Krizhevsky's paper: " - "when group=2, the first half of the filters are only connected to the " - "first half of the input channels, and the second half only connected " - "to the second half.") - .SetDefault(1); - AddComment(R"DOC( -The convolution operation calculates the output based on the input, filter -and strides, paddings, groups parameters. The size of each dimension of the -parameters is checked in the infer-shape. 
-)DOC"); -} - -void Conv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const { - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - if (ctx->HasOutput(framework::GradVarName("Input"))) { - ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); - } - if (ctx->HasOutput(framework::GradVarName("Filter"))) { - ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); - } -} - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, - ops::Conv2DOpGrad); - -REGISTER_OP_CPU_KERNEL( - conv2d, ops::GemmConv2DKernel); -REGISTER_OP_CPU_KERNEL( - conv2d_grad, ops::GemmConvGrad2DKernel); diff --git a/paddle/operators/conv3d_op.cu b/paddle/operators/conv3d_op.cu deleted file mode 100644 index ec6279f9bbbf9..0000000000000 --- a/paddle/operators/conv3d_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/operators/conv3d_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_GPU_KERNEL( - conv3d, ops::GemmConv3DKernel); -REGISTER_OP_GPU_KERNEL( - conv3d_grad, ops::GemmConvGrad3DKernel); diff --git a/paddle/operators/conv3d_op.h b/paddle/operators/conv3d_op.h deleted file mode 100644 index c5aaf019f3b4c..0000000000000 --- a/paddle/operators/conv3d_op.h +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
*/ - -#pragma once - -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" -#include "paddle/operators/math/vol2col.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -class Conv3DOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override; -}; - -class Conv3DOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override; -}; - -class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { - public: - Conv3DOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker); -}; - -template -class GemmConv3DKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - // The filter will be reshaped in the calculations, - // so here use an assignment operation, - // that avoids modifying the variable in the Scope. - Tensor filter = *context.Input("Filter"); - Tensor* output = context.Output("Output"); - output->mutable_data(context.GetPlace()); - - std::vector strides = context.Attr>("strides"); - std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); - - int batch_size = input->dims()[0]; - int input_channels = input->dims()[1]; - int filter_depth = filter.dims()[filter.dims().size() - 3]; - int filter_height = filter.dims()[filter.dims().size() - 2]; - int filter_width = filter.dims()[filter.dims().size() - 1]; - int output_channels = output->dims()[1]; - int output_depth = output->dims()[2]; - int output_height = output->dims()[3]; - int output_width = output->dims()[4]; - - paddle::operators::math::Vol2ColFunctor vol2col; - // use col_shape in the vol2col calculation - framework::DDim col_shape = {input_channels / groups, - filter_depth, - filter_height, - filter_width, - output_depth, - output_height, - output_width}; - // use col_matrix_shape in the gemm calculation - framework::DDim col_matrix_shape = { - input_channels / groups * filter_depth * filter_height * filter_width, - output_depth * output_height * output_width}; - Tensor col; - col.mutable_data(col_shape, context.GetPlace()); - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. 
- Tensor col_matrix = col; - col_matrix.Resize(col_matrix_shape); - - framework::DDim input_shape = { - input->dims()[1], input->dims()[2], input->dims()[3], - input->dims()[4]}; // channel, depth, height, width - framework::DDim filter_matrix_shape = { - filter.dims()[0], - filter.numel() / filter.dims()[0]}; // filter_out_channel, - // filter_in_channel*filter_depth*filter_height*filter_width - filter.Resize(filter_matrix_shape); - - framework::DDim output_matrix_shape = { - output_channels, output_depth * output_height * output_width}; - - // convolution operator: vol2col + gemm - int in_step = input_channels / groups; - int out_step = output_channels / groups; - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - // vol2col - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - vol2col(context.device_context(), in_slice, col, strides[0], strides[1], - strides[2], paddings[0], paddings[1], paddings[2]); - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); - } - } - } -}; - -template -class GemmConvGrad3DKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - const Tensor* output_grad = - context.Input(framework::GradVarName("Output")); - Tensor* input_grad = - context.Output(framework::GradVarName("Input")); - Tensor* filter_grad = - context.Output(framework::GradVarName("Filter")); - - // The filter and filter_grad will be reshaped in the calculations, - // so here use an assignment operation, - // that avoids modifying the variable in the Scope. - Tensor filter = *context.Input("Filter"); - - std::vector strides = context.Attr>("strides"); - std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); - - int batch_size = input->dims()[0]; - int input_channels = input->dims()[1]; - int filter_depth = filter.dims()[filter.dims().size() - 3]; - int filter_height = filter.dims()[filter.dims().size() - 2]; - int filter_width = filter.dims()[filter.dims().size() - 1]; - int output_channels = output_grad->dims()[1]; - int output_depth = output_grad->dims()[2]; - int output_height = output_grad->dims()[3]; - int output_width = output_grad->dims()[4]; - - paddle::operators::math::Col2VolFunctor col2vol; - paddle::operators::math::Vol2ColFunctor vol2col; - // use col_shape in the vol2col and col2vol calculation - framework::DDim col_shape = {input_channels / groups, - filter_depth, - filter_height, - filter_width, - output_depth, - output_height, - output_width}; - // use col_matrix_shape in the gemm calculation - framework::DDim col_matrix_shape = { - input_channels / groups * filter_depth * filter_height * filter_width, - output_depth * output_height * output_width}; - Tensor col; - col.mutable_data(col_shape, context.GetPlace()); - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. 
- Tensor col_matrix = col; - col_matrix.Resize(col_matrix_shape); - - framework::DDim input_shape = { - input->dims()[1], input->dims()[2], input->dims()[3], - input->dims()[4]}; // channel, depth, height, width - framework::DDim output_matrix_shape = {output_grad->dims()[1], - output_grad->dims()[2] * - output_grad->dims()[3] * - output_grad->dims()[4]}; - - framework::DDim filter_matrix_shape = { - filter.dims()[0], - filter.numel() / filter.dims()[0]}; // filter_out_channel, - // filter_in_channel*filter_depth*filter_height*filter_width - filter.Resize(filter_matrix_shape); - - // convolution backward input operator: gemm + col2vol - // convolution backward weight operator: vol2col + gemm - int in_step = input_channels / groups; - int out_step = output_channels / groups; - - if (input_grad) { - input_grad->mutable_data(context.GetPlace()); - auto t = framework::EigenVector::Flatten(*input_grad); - t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); - - // col2vol - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - col2vol(context.device_context(), in_grad_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); - } - } - } - - if (filter_grad) { - filter_grad->mutable_data(context.GetPlace()); - Tensor filter_grad_ = *filter_grad; - filter_grad_.Resize(filter_matrix_shape); - auto t = framework::EigenVector::Flatten(filter_grad_); - t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // vol2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc index 4288f300dd5b0..37bba3a1a1ef0 100644 --- a/paddle/operators/conv_cudnn_op.cc +++ b/paddle/operators/conv_cudnn_op.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/operators/conv2d_op.h" +#include "paddle/operators/conv_op.h" namespace paddle { namespace operators { @@ -38,8 +38,9 @@ class CudnnConvOpMaker : public Conv2DOpMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv_cudnn, ops::Conv2DOp, ops::CudnnConvOpMaker, conv_cudnn_grad, - ops::Conv2DOpGrad); +REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad, + ops::ConvOpGrad); + REGISTER_OP_CPU_KERNEL( conv_cudnn, ops::GemmConv2DKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/conv_cudnn_op.cu b/paddle/operators/conv_cudnn_op.cu index 366d0323b840c..e34d593740754 100644 --- a/paddle/operators/conv_cudnn_op.cu +++ b/paddle/operators/conv_cudnn_op.cu @@ -15,7 +15,7 @@ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/memory/memory.h" -#include "paddle/operators/conv2d_op.h" +#include "paddle/operators/conv_op.h" #include "paddle/platform/assert.h" #include "paddle/platform/cudnn_helper.h" diff --git a/paddle/operators/conv3d_op.cc b/paddle/operators/conv_op.cc similarity index 61% rename from paddle/operators/conv3d_op.cc rename to paddle/operators/conv_op.cc index fb3f1265f3af0..5e264d730c457 100644 --- a/paddle/operators/conv3d_op.cc +++ b/paddle/operators/conv_op.cc @@ -12,23 +12,18 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv3d_op.h" +#include "paddle/operators/conv_op.h" namespace paddle { namespace operators { -int OutputSizeConv3d(int input_size, int filter_size, int padding, int stride) { - int output_size = (input_size - filter_size + 2 * padding) / stride + 1; - return output_size; -} - -void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const { +void ConvOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of Conv3DOp should not be null."); + "Input(Input) of ConvOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of Conv3DOp should not be null."); + "Input(Filter) of ConvOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of Conv3DOp should not be null."); + "Output(Output) of ConvOp should not be null."); auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); @@ -38,33 +33,65 @@ void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const { int input_channels = in_dims[1]; int output_channels = filter_dims[0]; - PADDLE_ENFORCE_EQ(in_dims.size(), 5, "Conv3DOp input should be 5-D tensor."); - PADDLE_ENFORCE_EQ(filter_dims.size(), 5, - "Conv3DOp filter should be 5-D tensor."); + PADDLE_ENFORCE_EQ( + in_dims.size(), filter_dims.size(), + "Conv input dimension and filter dimension should be the same."); + PADDLE_ENFORCE( + in_dims.size() - strides.size() == 2U, + "Conv input dimension and strides dimension should be consistent."); + PADDLE_ENFORCE_EQ( + paddings.size(), strides.size(), + "Conv paddings dimension and Conv strides dimension should be the same."); PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, "The number of input channels should be equal to filter " - "(channels * groups)."); + "channels * groups."); PADDLE_ENFORCE_EQ( output_channels % groups, 0, "The number of output channels should be divided by groups."); std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < paddings.size(); ++i) { - output_shape.push_back(OutputSizeConv3d(in_dims[i + 2], 
filter_dims[i + 2], - paddings[i], strides[i])); + output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], + paddings[i], strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } -void Conv3DOpGrad::InferShape(framework::InferShapeContext* ctx) const { - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - if (ctx->HasOutput(framework::GradVarName("Input"))) { - ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); - } - if (ctx->HasOutput(framework::GradVarName("Filter"))) { - ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); - } +Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "The input tensor of convolution operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddInput("Filter", + "The filter tensor of convolution operator." + "The format of the filter tensor is MCHW, where M is the number of " + "output image channels, C is the number of input image channels, " + "H and W is height and width of filter. " + "If the groups attribute is greater than 1, C equal the number of " + "input image channels divided by the groups."); + AddOutput("Output", + "The output tensor of convolution operator." + "The format of output tensor is also NCHW."); + AddAttr>("strides", "strides of convolution operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", "paddings of convolution operator.") + .SetDefault({0, 0}); + AddAttr( + "groups", + "group size of convolution operator. " + "Refer to grouped convolution in Alex Krizhevsky's paper: " + "when group=2, the first half of the filters are only connected to the " + "first half of the input channels, and the second half only connected " + "to the second half.") + .SetDefault(1); + AddComment(R"DOC( +The convolution operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. +)DOC"); } Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, @@ -125,12 +152,31 @@ The input(X) size and output(Out) size may be different. 
)DOC"); } +void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } +} + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(conv3d, ops::Conv3DOp, ops::Conv3DOpMaker, conv3d_grad, - ops::Conv3DOpGrad); +REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, + ops::ConvOpGrad); +namespace ops = paddle::operators; +REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, + ops::ConvOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2d, ops::GemmConv2DKernel); +REGISTER_OP_CPU_KERNEL( + conv2d_grad, ops::GemmConvGrad2DKernel); REGISTER_OP_CPU_KERNEL( conv3d, ops::GemmConv3DKernel); diff --git a/paddle/operators/conv2d_op.cu b/paddle/operators/conv_op.cu similarity index 78% rename from paddle/operators/conv2d_op.cu rename to paddle/operators/conv_op.cu index c697c9466d34c..d8c0bd9326bb9 100644 --- a/paddle/operators/conv2d_op.cu +++ b/paddle/operators/conv_op.cu @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/conv2d_op.h" +#include "paddle/operators/conv_op.h" namespace ops = paddle::operators; @@ -20,3 +20,8 @@ REGISTER_OP_GPU_KERNEL( conv2d, ops::GemmConv2DKernel); REGISTER_OP_GPU_KERNEL( conv2d_grad, ops::GemmConvGrad2DKernel); + +REGISTER_OP_GPU_KERNEL( + conv3d, ops::GemmConv3DKernel); +REGISTER_OP_GPU_KERNEL( + conv3d_grad, ops::GemmConvGrad3DKernel); diff --git a/paddle/operators/conv2d_op.h b/paddle/operators/conv_op.h similarity index 51% rename from paddle/operators/conv2d_op.h rename to paddle/operators/conv_op.h index 0621389a79eee..e39b1ffeb6d88 100644 --- a/paddle/operators/conv2d_op.h +++ b/paddle/operators/conv_op.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/framework/op_registry.h" #include "paddle/operators/math/im2col.h" #include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/vol2col.h" namespace paddle { namespace operators { @@ -40,14 +41,20 @@ class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker* op_checker); }; -class Conv2DOp : public framework::OperatorWithKernel { +class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Conv3DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + +class ConvOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override; }; -class Conv2DOpGrad : public framework::OperatorWithKernel { +class ConvOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -251,5 +258,218 @@ class GemmConvGrad2DKernel : public framework::OpKernel { } }; +template +class GemmConv3DKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. 
+ Tensor filter = *context.Input("Filter"); + Tensor* output = context.Output("Output"); + output->mutable_data(context.GetPlace()); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + int groups = context.Attr("groups"); + + int batch_size = input->dims()[0]; + int input_channels = input->dims()[1]; + int filter_depth = filter.dims()[filter.dims().size() - 3]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output->dims()[1]; + int output_depth = output->dims()[2]; + int output_height = output->dims()[3]; + int output_width = output->dims()[4]; + + paddle::operators::math::Vol2ColFunctor vol2col; + // use col_shape in the vol2col calculation + framework::DDim col_shape = {input_channels / groups, + filter_depth, + filter_height, + filter_width, + output_depth, + output_height, + output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels / groups * filter_depth * filter_height * filter_width, + output_depth * output_height * output_width}; + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); + + framework::DDim input_shape = { + input->dims()[1], input->dims()[2], input->dims()[3], + input->dims()[4]}; // channel, depth, height, width + framework::DDim filter_matrix_shape = { + filter.dims()[0], + filter.numel() / filter.dims()[0]}; // filter_out_channel, + // filter_in_channel*filter_depth*filter_height*filter_width + filter.Resize(filter_matrix_shape); + + framework::DDim output_matrix_shape = { + output_channels, output_depth * output_height * output_width}; + + // convolution operator: vol2col + gemm + int in_step = input_channels / groups; + int out_step = output_channels / groups; + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + for (int g = 0; g < groups; g++) { + // vol2col + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + vol2col(context.device_context(), in_slice, col, strides[0], strides[1], + strides[2], paddings[0], paddings[1], paddings[2]); + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); + } + } + } +}; + +template +class GemmConvGrad3DKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + + // The filter and filter_grad will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. 
+ Tensor filter = *context.Input("Filter"); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + int groups = context.Attr("groups"); + + int batch_size = input->dims()[0]; + int input_channels = input->dims()[1]; + int filter_depth = filter.dims()[filter.dims().size() - 3]; + int filter_height = filter.dims()[filter.dims().size() - 2]; + int filter_width = filter.dims()[filter.dims().size() - 1]; + int output_channels = output_grad->dims()[1]; + int output_depth = output_grad->dims()[2]; + int output_height = output_grad->dims()[3]; + int output_width = output_grad->dims()[4]; + + paddle::operators::math::Col2VolFunctor col2vol; + paddle::operators::math::Vol2ColFunctor vol2col; + // use col_shape in the vol2col and col2vol calculation + framework::DDim col_shape = {input_channels / groups, + filter_depth, + filter_height, + filter_width, + output_depth, + output_height, + output_width}; + // use col_matrix_shape in the gemm calculation + framework::DDim col_matrix_shape = { + input_channels / groups * filter_depth * filter_height * filter_width, + output_depth * output_height * output_width}; + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. + Tensor col_matrix = col; + col_matrix.Resize(col_matrix_shape); + + framework::DDim input_shape = { + input->dims()[1], input->dims()[2], input->dims()[3], + input->dims()[4]}; // channel, depth, height, width + framework::DDim output_matrix_shape = {output_grad->dims()[1], + output_grad->dims()[2] * + output_grad->dims()[3] * + output_grad->dims()[4]}; + + framework::DDim filter_matrix_shape = { + filter.dims()[0], + filter.numel() / filter.dims()[0]}; // filter_out_channel, + // filter_in_channel*filter_depth*filter_height*filter_width + filter.Resize(filter_matrix_shape); + + // convolution backward input operator: gemm + col2vol + // convolution backward weight operator: vol2col + gemm + int in_step = input_channels / groups; + int out_step = output_channels / groups; + + if (input_grad) { + input_grad->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); + + // col2vol + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + col2vol(context.device_context(), in_grad_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } + } + } + + if (filter_grad) { + filter_grad->mutable_data(context.GetPlace()); + Tensor filter_grad_ = *filter_grad; + filter_grad_.Resize(filter_matrix_shape); + auto t = framework::EigenVector::Flatten(filter_grad_); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + 
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // vol2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + vol2col(context.device_context(), in_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); + } + } + } + } +}; + } // namespace operators } // namespace paddle From 172481534ddde5de01e2b6b2603f17c36c26e294 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 30 Oct 2017 17:27:14 +0800 Subject: [PATCH 08/10] fix code format and doc --- paddle/operators/conv_op.cc | 41 ++++++++++++++----- paddle/operators/conv_op.h | 18 +++----- .../v2/framework/tests/test_conv2d_op.py | 3 ++ 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 5e264d730c457..1250900d154c4 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -33,6 +33,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { int input_channels = in_dims[1]; int output_channels = filter_dims[0]; + PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, + "Conv intput should be 4-D or 5-D tensor."); PADDLE_ENFORCE_EQ( in_dims.size(), filter_dims.size(), "Conv input dimension and filter dimension should be the same."); @@ -62,26 +64,30 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", - "The input tensor of convolution operator. " + "(Tensor), the input tensor of convolution operator. " "The format of input tensor is NCHW. Where N is batch size, C is the " "number of channels, H and W is the height and width of image."); AddInput("Filter", - "The filter tensor of convolution operator." + "(Tensor), the filter tensor of convolution operator." "The format of the filter tensor is MCHW, where M is the number of " "output image channels, C is the number of input image channels, " "H and W is height and width of filter. " "If the groups attribute is greater than 1, C equal the number of " "input image channels divided by the groups."); AddOutput("Output", - "The output tensor of convolution operator." - "The format of output tensor is also NCHW."); - AddAttr>("strides", "strides of convolution operator.") + "(Tensor), the output tensor of convolution operator." + "The format of output tensor is also NCHW. Where N is batch size, " + "C is the " + "number of channels, H and W is the height and width of image."); + AddAttr>( + "strides", "(vector default:{1, 1}), strides of convolution operator.") .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of convolution operator.") + AddAttr>( + "paddings", "(vector default:{0, 0}), paddings of convolution operator.") .SetDefault({0, 0}); AddAttr( "groups", - "group size of convolution operator. " + "(int, default:1), group size of convolution operator. 
" "Refer to grouped convolution in Alex Krizhevsky's paper: " "when group=2, the first half of the filters are only connected to the " "first half of the input channels, and the second half only connected " @@ -91,6 +97,21 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, The convolution operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. +Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch +size, C is the number of channels, H and W is the height and +width of feature. Parameters(ksize, strides, paddings) are two elements. +These two elements represent height and width, respectively. +The input(X) size and output(Out) size may be different. + +Example: + Input: + Input shape: (N, C_in, H_in, W_in) + Filter shape: (C_out, C_in, H_f, W_f) + Output: + Output shape: (N, C_out, H_out, W_out) + where + H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1; + W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1; )DOC"); } @@ -115,15 +136,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, "The format of output tensor is also NCDHW."); AddAttr>( "strides", - "(vector, default {0,0,0}), the strides of convolution operator.") + "(vector, default:{0, 0, 0}), the strides of convolution operator.") .SetDefault({1, 1, 1}); AddAttr>( "paddings", - "(vector, default {0,0,0}), the paddings of convolution operator.") + "(vector, default:{0, 0, 0}), the paddings of convolution operator.") .SetDefault({0, 0, 0}); AddAttr( "groups", - "(int, default 1) the group size of convolution operator. " + "(int, default:1) the group size of convolution operator. " "Refer to grouped convolution in Alex Krizhevsky's paper: " "when group=2, the first half of the filters are only connected to the " "first half of the input channels, and the second half only connected " diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 7e8f5d75bb6be..198e51e4ad4c4 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -85,9 +85,7 @@ class GemmConv2DKernel : public framework::OpKernel { int output_height = output->dims()[2]; int output_width = output->dims()[3]; - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, Place, T> - im2col; + math::Im2ColFunctor im2col; // use col_shape in the im2col calculation framework::DDim col_shape = {input_channels / groups, filter_height, filter_width, output_height, output_width}; @@ -162,12 +160,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { int output_height = output_grad->dims()[2]; int output_width = output_grad->dims()[3]; - paddle::operators::math::Col2ImFunctor< - paddle::operators::math::ColFormat::kCFO, Place, T> - col2im; - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, Place, T> - im2col; + math::Col2ImFunctor col2im; + math::Im2ColFunctor im2col; // use col_shape in the im2col and col2im calculation framework::DDim col_shape = {input_channels / groups, filter_height, filter_width, output_height, output_width}; @@ -283,7 +277,7 @@ class GemmConv3DKernel : public framework::OpKernel { int output_height = output->dims()[3]; int output_width = output->dims()[4]; - paddle::operators::math::Vol2ColFunctor vol2col; + math::Vol2ColFunctor vol2col; // use col_shape in the vol2col calculation framework::DDim col_shape = {input_channels / groups, filter_depth, @@ -369,8 +363,8 @@ 
class GemmConvGrad3DKernel : public framework::OpKernel { int output_height = output_grad->dims()[3]; int output_width = output_grad->dims()[4]; - paddle::operators::math::Col2VolFunctor col2vol; - paddle::operators::math::Vol2ColFunctor vol2col; + math::Col2VolFunctor col2vol; + math::Vol2ColFunctor vol2col; // use col_shape in the vol2col and col2vol calculation framework::DDim col_shape = {input_channels / groups, filter_depth, diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index f58b96463cf78..6bd4bad8e2db5 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -103,6 +103,9 @@ def init_op_type(self): self.op_type = "conv2d" +#----------------Conv2dCudnn---------------- + + class TestCudnn(TestConv2dOp): def init_group(self): self.groups = 1 From 8ac1178707fed50d3061445ee410d6987e3b70de Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 6 Nov 2017 10:26:01 +0800 Subject: [PATCH 09/10] fix doc --- paddle/operators/conv_op.cc | 76 +++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 1250900d154c4..54ac4f4111445 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -64,42 +64,41 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", - "(Tensor), the input tensor of convolution operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of channels, H and W is the height and width of image."); + "(Tensor) The input tensor of convolution operator. " + "The format of input tensor is NCHW, where N is batch size, C is the " + "number of channels, H is the height of the feature, " + "and W is the width of the feature."); AddInput("Filter", - "(Tensor), the filter tensor of convolution operator." + "(Tensor) The filter tensor of convolution operator. " "The format of the filter tensor is MCHW, where M is the number of " "output image channels, C is the number of input image channels, " - "H and W is height and width of filter. " - "If the groups attribute is greater than 1, C equal the number of " + "H is the height of the filter, and W is the width of the filter. " + "If the groups attribute is greater than 1, C equals the number of " "input image channels divided by the groups."); AddOutput("Output", - "(Tensor), the output tensor of convolution operator." - "The format of output tensor is also NCHW. Where N is batch size, " - "C is the " - "number of channels, H and W is the height and width of image."); - AddAttr>( - "strides", "(vector default:{1, 1}), strides of convolution operator.") + "(Tensor) The output tensor of convolution operator. " + "The format of output tensor is also NCHW."); + AddAttr>("strides", "strides of convolution operator.") .SetDefault({1, 1}); - AddAttr>( - "paddings", "(vector default:{0, 0}), paddings of convolution operator.") + AddAttr>("paddings", "paddings of convolution operator.") .SetDefault({0, 0}); AddAttr( "groups", - "(int, default:1), group size of convolution operator. " - "Refer to grouped convolution in Alex Krizhevsky's paper: " - "when group=2, the first half of the filters are only connected to the " - "first half of the input channels, and the second half only connected " - "to the second half.") + "(int default:1), the group size of convolution operator. 
" + "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: " + "when group=2, the first half of the filters is only connected to the " + "first half of the input channels, while the second half of the filters " + "is only connected to the second half of the input channels.") .SetDefault(1); AddComment(R"DOC( +Convolution Operator. + The convolution operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch -size, C is the number of channels, H and W is the height and -width of feature. Parameters(ksize, strides, paddings) are two elements. +size, C is the number of channels, H is the height of the feature, and W is +the width of the feature. Parameters(ksize, strides, paddings) are two elements. These two elements represent height and width, respectively. The input(X) size and output(Out) size may be different. @@ -120,19 +119,21 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", - "(Tensor), the input tensor of convolution operator. " + "(Tensor) The input tensor of convolution operator. " "The format of input tensor is NCDHW. Where N is batch size, C is the " - "number of channels, D, H and W is the depth, height and width of " - "image."); + "number of channels, D is the depth of the feature, H is the height of " + "the feature, " + "and W is the width of the feature."); AddInput("Filter", - "(Tensor), the filter tensor of convolution operator." + "(Tensor) The filter tensor of convolution operator. " "The format of the filter tensor is MCDHW, where M is the number of " "output image channels, C is the number of input image channels, " - "D, H and W is depth, height and width of filter. " - "If the groups attribute is greater than 1, C equal the number of " + "D is the depth of the filter, H is the height of the filter, and W " + "is the width of the filter." + "If the groups attribute is greater than 1, C equals the number of " "input image channels divided by the groups."); AddOutput("Output", - "(Tensor), the output tensor of convolution operator." + "(Tensor) The output tensor of convolution operator." "The format of output tensor is also NCDHW."); AddAttr>( "strides", @@ -144,20 +145,23 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, .SetDefault({0, 0, 0}); AddAttr( "groups", - "(int, default:1) the group size of convolution operator. " - "Refer to grouped convolution in Alex Krizhevsky's paper: " - "when group=2, the first half of the filters are only connected to the " - "first half of the input channels, and the second half only connected " - "to the second half.") + "(int default:1), the group size of convolution operator. " + "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: " + "when group=2, the first half of the filters is only connected to the " + "first half of the input channels, while the second half of the filters " + "is only connected to the second half of the input channels.") .SetDefault(1); + AddComment(R"DOC( +Convolution3D Operator. + The convolution operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. Input(Input, Filter) and output(Output) are in NCDHW format. 
Where N is batch -size, C is the number of channels, D, H and W is the depth, height and -width of feature. Parameters(ksize, strides, paddings) are three elements. -These three elements represent depth, height and width, respectively. +size, C is the number of channels,D is the depth of the feature, H is the height of +the feature, and W is the width of the feature. Parameters(ksize, strides, paddings) +are three elements. These three elements represent depth, height and width, respectively. The input(X) size and output(Out) size may be different. Example: From f302c6a3b4582cc3305940406a77bd437025512c Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 6 Nov 2017 15:01:01 +0800 Subject: [PATCH 10/10] write conv2d and conv3d together --- paddle/operators/conv_cudnn_op.cc | 6 +- paddle/operators/conv_op.cc | 12 +- paddle/operators/conv_op.cu | 12 +- paddle/operators/conv_op.h | 395 ++++++------------ .../v2/framework/tests/test_conv2d_op.py | 8 +- .../v2/framework/tests/test_conv3d_op.py | 6 +- 6 files changed, 145 insertions(+), 294 deletions(-) diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc index a068daf9a812f..97f31bf22d707 100644 --- a/paddle/operators/conv_cudnn_op.cc +++ b/paddle/operators/conv_cudnn_op.cc @@ -41,8 +41,8 @@ namespace ops = paddle::operators; REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad, ops::ConvOpGrad); -REGISTER_OP_CPU_KERNEL( - conv_cudnn, ops::GemmConv2DKernel); +REGISTER_OP_CPU_KERNEL(conv_cudnn, + ops::GemmConvKernel); REGISTER_OP_CPU_KERNEL( conv_cudnn_grad, - ops::GemmConvGrad2DKernel); + ops::GemmConvGradKernel); diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 54ac4f4111445..a6f65f1016592 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -198,12 +198,12 @@ namespace ops = paddle::operators; REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, ops::ConvOpGrad); +REGISTER_OP_CPU_KERNEL(conv2d, + ops::GemmConvKernel); REGISTER_OP_CPU_KERNEL( - conv2d, ops::GemmConv2DKernel); -REGISTER_OP_CPU_KERNEL( - conv2d_grad, ops::GemmConvGrad2DKernel); + conv2d_grad, ops::GemmConvGradKernel); +REGISTER_OP_CPU_KERNEL(conv3d, + ops::GemmConvKernel); REGISTER_OP_CPU_KERNEL( - conv3d, ops::GemmConv3DKernel); -REGISTER_OP_CPU_KERNEL( - conv3d_grad, ops::GemmConvGrad3DKernel); + conv3d_grad, ops::GemmConvGradKernel); diff --git a/paddle/operators/conv_op.cu b/paddle/operators/conv_op.cu index d8c0bd9326bb9..8e6f9da455b72 100644 --- a/paddle/operators/conv_op.cu +++ b/paddle/operators/conv_op.cu @@ -16,12 +16,12 @@ namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(conv2d, + ops::GemmConvKernel); REGISTER_OP_GPU_KERNEL( - conv2d, ops::GemmConv2DKernel); -REGISTER_OP_GPU_KERNEL( - conv2d_grad, ops::GemmConvGrad2DKernel); + conv2d_grad, ops::GemmConvGradKernel); +REGISTER_OP_GPU_KERNEL(conv3d, + ops::GemmConvKernel); REGISTER_OP_GPU_KERNEL( - conv3d, ops::GemmConv3DKernel); -REGISTER_OP_GPU_KERNEL( - conv3d_grad, ops::GemmConvGrad3DKernel); + conv3d_grad, ops::GemmConvGradKernel); diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 198e51e4ad4c4..7c1729213bf3f 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -62,7 +62,7 @@ class ConvOpGrad : public framework::OperatorWithKernel { }; template -class GemmConv2DKernel : public framework::OpKernel { +class GemmConvKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const 
Tensor* input = context.Input("Input"); @@ -77,49 +77,78 @@ class GemmConv2DKernel : public framework::OpKernel { std::vector paddings = context.Attr>("paddings"); int groups = context.Attr("groups"); - int batch_size = input->dims()[0]; - int input_channels = input->dims()[1]; - int filter_height = filter.dims()[filter.dims().size() - 2]; - int filter_width = filter.dims()[filter.dims().size() - 1]; - int output_channels = output->dims()[1]; - int output_height = output->dims()[2]; - int output_width = output->dims()[3]; + const int batch_size = static_cast(input->dims()[0]); + + // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w} + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + filter_shape_vec.erase(filter_shape_vec.begin(), + filter_shape_vec.begin() + 2); + + // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w} + std::vector output_shape_vec(framework::vectorize(output->dims())); + output_shape_vec.erase(output_shape_vec.begin(), + output_shape_vec.begin() + 2); - math::Im2ColFunctor im2col; // use col_shape in the im2col calculation - framework::DDim col_shape = {input_channels / groups, filter_height, - filter_width, output_height, output_width}; + // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d, + // o_h, o_w} + std::vector col_shape_vec; + col_shape_vec.push_back(input->dims()[1] / groups); + col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(), + filter_shape_vec.end()); + col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(), + output_shape_vec.end()); + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + // use col_matrix_shape in the gemm calculation - framework::DDim col_matrix_shape = { - input_channels / groups * filter_height * filter_width, - output_height * output_width}; + // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d * + // o_h * o_w) + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); + Tensor col; col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. 
- Tensor col_matrix = col; + Tensor col_matrix; + col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - framework::DDim input_shape = {input->dims()[1], input->dims()[2], - input->dims()[3]}; + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + framework::DDim filter_matrix_shape = {filter.dims()[0], filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); - framework::DDim output_matrix_shape = {output_channels, - output_height * output_width}; - // convolution operator: im2col + gemm - int in_step = input_channels / groups; - int out_step = output_channels / groups; + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output->dims()[1]) / groups; + for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); for (int g = 0; g < groups; g++) { - // im2col Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - im2col(context.device_context(), in_slice, col, strides[0], strides[1], - paddings[0], paddings[0], paddings[1], paddings[1]); + + if (filter_shape_vec.size() == 2) { + // im2col + math::Im2ColFunctor im2col; + im2col(context.device_context(), in_slice, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); + } else if (filter_shape_vec.size() == 3) { + // vol2col + math::Vol2ColFunctor vol2col; + vol2col(context.device_context(), in_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); @@ -132,7 +161,7 @@ class GemmConv2DKernel : public framework::OpKernel { }; template -class GemmConvGrad2DKernel : public framework::OpKernel { +class GemmConvGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); @@ -142,267 +171,74 @@ class GemmConvGrad2DKernel : public framework::OpKernel { context.Output(framework::GradVarName("Input")); Tensor* filter_grad = context.Output(framework::GradVarName("Filter")); - // The filter and filter_grad will be reshaped in the calculations, // so here use an assignment operation, // that avoids modifying the variable in the Scope. 
Tensor filter = *context.Input("Filter"); + if (!input_grad && !filter_grad) return; + std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); int groups = context.Attr("groups"); - int batch_size = input->dims()[0]; - int input_channels = input->dims()[1]; - int filter_height = filter.dims()[filter.dims().size() - 2]; - int filter_width = filter.dims()[filter.dims().size() - 1]; - int output_channels = output_grad->dims()[1]; - int output_height = output_grad->dims()[2]; - int output_width = output_grad->dims()[3]; - - math::Col2ImFunctor col2im; - math::Im2ColFunctor im2col; - // use col_shape in the im2col and col2im calculation - framework::DDim col_shape = {input_channels / groups, filter_height, - filter_width, output_height, output_width}; - // use col_matrix_shape in the gemm calculation - framework::DDim col_matrix_shape = { - input_channels / groups * filter_height * filter_width, - output_height * output_width}; - Tensor col; - col.mutable_data(col_shape, context.GetPlace()); - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. - Tensor col_matrix = col; - col_matrix.Resize(col_matrix_shape); - - framework::DDim input_shape = {input->dims()[1], input->dims()[2], - input->dims()[3]}; - framework::DDim output_matrix_shape = { - output_grad->dims()[1], - output_grad->dims()[2] * output_grad->dims()[3]}; - - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; - filter.Resize(filter_matrix_shape); - - // convolution backward input operator: gemm + col2im - // convolution backward weight operator: im2col + gemm - int in_step = input_channels / groups; - int out_step = output_channels / groups; - math::SetConstant set_zero; - - if (input_grad) { - input_grad->mutable_data(context.GetPlace()); - set_zero(context.device_context(), input_grad, static_cast(0)); - - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); - - // col2im - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - col2im(context.device_context(), in_grad_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); - } - } - } - - if (filter_grad) { - filter_grad->mutable_data(context.GetPlace()); - Tensor filter_grad_ = *filter_grad; - filter_grad_.Resize(filter_matrix_shape); - set_zero(context.device_context(), filter_grad, static_cast(0)); + const int batch_size = static_cast(input->dims()[0]); - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - 
paddings[1]); - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); - } - } - } - } -}; + // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w} + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + filter_shape_vec.erase(filter_shape_vec.begin(), + filter_shape_vec.begin() + 2); -template -class GemmConv3DKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - // The filter will be reshaped in the calculations, - // so here use an assignment operation, - // that avoids modifying the variable in the Scope. - Tensor filter = *context.Input("Filter"); - Tensor* output = context.Output("Output"); - output->mutable_data(context.GetPlace()); + // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w} + std::vector output_shape_vec( + framework::vectorize(output_grad->dims())); + output_shape_vec.erase(output_shape_vec.begin(), + output_shape_vec.begin() + 2); - std::vector strides = context.Attr>("strides"); - std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); + // use col_shape in the im2col calculation + // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d, + // o_h, o_w} + std::vector col_shape_vec; + col_shape_vec.push_back(input->dims()[1] / groups); + col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(), + filter_shape_vec.end()); + col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(), + output_shape_vec.end()); + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); - int batch_size = input->dims()[0]; - int input_channels = input->dims()[1]; - int filter_depth = filter.dims()[filter.dims().size() - 3]; - int filter_height = filter.dims()[filter.dims().size() - 2]; - int filter_width = filter.dims()[filter.dims().size() - 1]; - int output_channels = output->dims()[1]; - int output_depth = output->dims()[2]; - int output_height = output->dims()[3]; - int output_width = output->dims()[4]; - - math::Vol2ColFunctor vol2col; - // use col_shape in the vol2col calculation - framework::DDim col_shape = {input_channels / groups, - filter_depth, - filter_height, - filter_width, - output_depth, - output_height, - output_width}; // use col_matrix_shape in the gemm calculation - framework::DDim col_matrix_shape = { - input_channels / groups * filter_depth * filter_height * filter_width, - output_depth * output_height * output_width}; - Tensor col; - col.mutable_data(col_shape, context.GetPlace()); - // col_matrix shares the same piece of data with col, - // but will be reshaped into a two-dimensional matrix shape - // to call the matrix multiplication interface. 
- Tensor col_matrix = col; - col_matrix.Resize(col_matrix_shape); + // size: (i_c/g * k_h * k_w, o_h * o_w) + // or + // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w) + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1); + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); - framework::DDim input_shape = { - input->dims()[1], input->dims()[2], input->dims()[3], - input->dims()[4]}; // channel, depth, height, width - framework::DDim filter_matrix_shape = { - filter.dims()[0], - filter.numel() / filter.dims()[0]}; // filter_out_channel, - // filter_in_channel*filter_depth*filter_height*filter_width + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); framework::DDim output_matrix_shape = { - output_channels, output_depth * output_height * output_width}; - - // convolution operator: vol2col + gemm - int in_step = input_channels / groups; - int out_step = output_channels / groups; - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - // vol2col - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - vol2col(context.device_context(), in_slice, col, strides[0], strides[1], - strides[2], paddings[0], paddings[1], paddings[2]); - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); - } - } - } -}; - -template -class GemmConvGrad3DKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - const Tensor* output_grad = - context.Input(framework::GradVarName("Output")); - Tensor* input_grad = - context.Output(framework::GradVarName("Input")); - Tensor* filter_grad = - context.Output(framework::GradVarName("Filter")); - - // The filter and filter_grad will be reshaped in the calculations, - // so here use an assignment operation, - // that avoids modifying the variable in the Scope. 
- Tensor filter = *context.Input("Filter"); + output_grad->dims()[1], + output_grad->numel() / + (output_grad->dims()[0] * output_grad->dims()[1])}; - std::vector strides = context.Attr>("strides"); - std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); + // convolution backward input operator: gemm + col2im(or col2vol) + // convolution backward weight operator: im2col(or vol2col) + gemm + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output_grad->dims()[1]) / groups; - int batch_size = input->dims()[0]; - int input_channels = input->dims()[1]; - int filter_depth = filter.dims()[filter.dims().size() - 3]; - int filter_height = filter.dims()[filter.dims().size() - 2]; - int filter_width = filter.dims()[filter.dims().size() - 1]; - int output_channels = output_grad->dims()[1]; - int output_depth = output_grad->dims()[2]; - int output_height = output_grad->dims()[3]; - int output_width = output_grad->dims()[4]; - - math::Col2VolFunctor col2vol; - math::Vol2ColFunctor vol2col; - // use col_shape in the vol2col and col2vol calculation - framework::DDim col_shape = {input_channels / groups, - filter_depth, - filter_height, - filter_width, - output_depth, - output_height, - output_width}; - // use col_matrix_shape in the gemm calculation - framework::DDim col_matrix_shape = { - input_channels / groups * filter_depth * filter_height * filter_width, - output_depth * output_height * output_width}; Tensor col; - col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. - Tensor col_matrix = col; + Tensor col_matrix; + col.mutable_data(col_shape, context.GetPlace()); + col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - framework::DDim input_shape = { - input->dims()[1], input->dims()[2], input->dims()[3], - input->dims()[4]}; // channel, depth, height, width - framework::DDim output_matrix_shape = {output_grad->dims()[1], - output_grad->dims()[2] * - output_grad->dims()[3] * - output_grad->dims()[4]}; - - framework::DDim filter_matrix_shape = { - filter.dims()[0], - filter.numel() / filter.dims()[0]}; // filter_out_channel, - // filter_in_channel*filter_depth*filter_height*filter_width - filter.Resize(filter_matrix_shape); - - // convolution backward input operator: gemm + col2vol - // convolution backward weight operator: vol2col + gemm - int in_step = input_channels / groups; - int out_step = output_channels / groups; math::SetConstant set_zero; if (input_grad) { @@ -421,13 +257,22 @@ class GemmConvGrad3DKernel : public framework::OpKernel { math::matmul(context.device_context(), filter_slice, true, out_grad_slice, false, T(1.0), &col_matrix, T(0.0)); - - // col2vol + // col2im Tensor in_grad_slice = in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - col2vol(context.device_context(), in_grad_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + + if (filter_shape_vec.size() == 2) { + math::Col2ImFunctor col2im; + col2im(context.device_context(), in_grad_slice, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); + + } else if (filter_shape_vec.size() == 3) { + math::Col2VolFunctor col2vol; + col2vol(context.device_context(), in_grad_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } } } } @@ -443,13 +288,22 @@ class GemmConvGrad3DKernel : 
public framework::OpKernel { output_grad->Slice(i, i + 1).Resize(output_matrix_shape); Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); for (int g = 0; g < groups; g++) { - // vol2col + // im2col Tensor out_grad_slice = out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + + if (filter_shape_vec.size() == 2) { + math::Im2ColFunctor im2col; + im2col(context.device_context(), in_slice, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); + } else if (filter_shape_vec.size() == 3) { + math::Vol2ColFunctor vol2col; + vol2col(context.device_context(), in_slice, col, strides[0], + strides[1], strides[2], paddings[0], paddings[1], + paddings[2]); + } // gemm Tensor filter_grad_slice = @@ -462,6 +316,5 @@ class GemmConvGrad3DKernel : public framework::OpKernel { } } }; - } // namespace operators } // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index 6bd4bad8e2db5..04ae7f294c27f 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -61,25 +61,23 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + set(['Input', 'Filter']), 'Output', max_relative_error=0.02) def test_check_grad_no_filter(self): self.check_grad( ['Input'], 'Output', - max_relative_error=0.05, + max_relative_error=0.02, no_grad_set=set(['Filter'])) def test_check_grad_no_input(self): self.check_grad( ['Filter'], 'Output', - max_relative_error=0.05, + max_relative_error=0.02, no_grad_set=set(['Input'])) def init_test_case(self): - # self.groups = 1 - # self.op_type = "conv2d" self.pad = [0, 0] self.stride = [1, 1] self.dilations = [1, 1] diff --git a/python/paddle/v2/framework/tests/test_conv3d_op.py b/python/paddle/v2/framework/tests/test_conv3d_op.py index f8e07fc562602..44c192f58d25f 100644 --- a/python/paddle/v2/framework/tests/test_conv3d_op.py +++ b/python/paddle/v2/framework/tests/test_conv3d_op.py @@ -64,20 +64,20 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + set(['Input', 'Filter']), 'Output', max_relative_error=0.03) def test_check_grad_no_filter(self): self.check_grad( ['Input'], 'Output', - max_relative_error=0.05, + max_relative_error=0.03, no_grad_set=set(['Filter'])) def test_check_grad_no_input(self): self.check_grad( ['Filter'], 'Output', - max_relative_error=0.05, + max_relative_error=0.03, no_grad_set=set(['Input'])) def init_test_case(self):
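As a reference for what the unified GemmConvKernel computes, the short NumPy sketch below performs the same grouped 3-D convolution directly (zero-pad, slide the filter, accumulate) and derives the spatial output size from the same rule used in ConvOp::InferShape, out = (in - k + 2 * pad) / stride + 1. It is only an illustration under assumed shapes; the function and variable names are not code from this patch or from the Paddle test suite.

import numpy as np


def conv3d_forward_naive(x, w, groups, strides, paddings):
    # x: Input in NCDHW layout; w: Filter in MCDHW layout, i.e. M output
    # channels, C/groups input channels per group, then kernel D, H, W.
    N, C, D, H, W = x.shape
    M, _, kD, kH, kW = w.shape
    sD, sH, sW = strides
    pD, pH, pW = paddings
    # Same output-size rule as ConvOp::InferShape: (in - k + 2 * pad) / stride + 1
    oD = (D - kD + 2 * pD) // sD + 1
    oH = (H - kH + 2 * pH) // sH + 1
    oW = (W - kW + 2 * pW) // sW + 1

    xp = np.pad(x, ((0, 0), (0, 0), (pD, pD), (pH, pH), (pW, pW)),
                mode='constant')
    out = np.zeros((N, M, oD, oH, oW), dtype=x.dtype)
    in_step = C // groups    # input channels per group
    out_step = M // groups   # output channels per group
    for m in range(M):
        g = m // out_step    # group this output channel belongs to
        for d in range(oD):
            for i in range(oH):
                for j in range(oW):
                    # patch: (N, C/groups, kD, kH, kW) window of the padded input
                    patch = xp[:, g * in_step:(g + 1) * in_step,
                               d * sD:d * sD + kD,
                               i * sH:i * sH + kH,
                               j * sW:j * sW + kW]
                    out[:, m, d, i, j] = np.sum(patch * w[m], axis=(1, 2, 3, 4))
    return out


# Example: N=2, C=4, 5x5x5 input, 6 output channels, 3x3x3 filter, groups=2.
x = np.random.rand(2, 4, 5, 5, 5).astype('float32')
w = np.random.rand(6, 2, 3, 3, 3).astype('float32')
y = conv3d_forward_naive(x, w, groups=2, strides=[1, 1, 1], paddings=[1, 1, 1])
print(y.shape)  # (2, 6, 5, 5, 5), matching (in - k + 2 * pad) / stride + 1 per axis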