From 29f25fbe033e97f74123f2380d6e384ba840d0da Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Mon, 10 Jul 2017 12:26:35 +0800
Subject: [PATCH 1/9] Add pixel softmax layer for FCN model

1. Add switch function for switching image dimensions order
2. Add CpuMatrix::backwardSoftmax function
3. Add pixel softmax layer, python wrapper and grad_test

---
 paddle/function/CMakeLists.txt              |   1 +
 paddle/function/SwitchOp.cpp                | 132 ++++++++++++++++++
 paddle/function/SwitchOp.h                  |  62 ++++++++
 paddle/function/SwitchOpGpu.cu              |  80 +++++++++++
 paddle/function/SwitchOpTest.cpp            |  44 ++++++
 paddle/gserver/layers/PixelSoftmaxLayer.cpp |  89 ++++++++++++
 paddle/gserver/layers/PixelSoftmaxLayer.h   |  44 ++++++
 paddle/gserver/tests/test_LayerGrad.cpp     |  19 +++
 paddle/math/Matrix.cpp                      |  21 +++
 paddle/math/Matrix.h                        |   5 +
 python/paddle/trainer/config_parser.py      |  16 +++
 .../paddle/trainer_config_helpers/layers.py |  38 +++++
 12 files changed, 551 insertions(+)
 create mode 100644 paddle/function/SwitchOp.cpp
 create mode 100644 paddle/function/SwitchOp.h
 create mode 100644 paddle/function/SwitchOpGpu.cu
 create mode 100644 paddle/function/SwitchOpTest.cpp
 create mode 100644 paddle/gserver/layers/PixelSoftmaxLayer.cpp
 create mode 100644 paddle/gserver/layers/PixelSoftmaxLayer.h

diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index 1518a8a654cfb..138f7dcf1680d 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -37,6 +37,7 @@ if(WITH_GPU)
     add_simple_unittest(MulOpTest)
     add_simple_unittest(CosSimOpTest)
     add_simple_unittest(RowConvOpTest)
+    add_simple_unittest(SwitchOpTest)
 endif()
 
 add_simple_unittest(ConvOpTest)
diff --git a/paddle/function/SwitchOp.cpp b/paddle/function/SwitchOp.cpp
new file mode 100644
index 0000000000000..4667c4e01d52a
--- /dev/null
+++ b/paddle/function/SwitchOp.cpp
@@ -0,0 +1,132 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "SwitchOp.h"
+#include "paddle/math/Vector.h"
+
+namespace paddle {
+
+template <>
+void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
+                                const real* inputs,
+                                const int num,
+                                const int inC,
+                                const int inH,
+                                const int inW) {
+  for (int n = 0; n < num; ++n) {
+    for (int c = 0; c < inC; ++c) {
+      for (int h = 0; h < inH; ++h) {
+        for (int w = 0; w < inW; ++w) {
+          outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
+        }
+      }
+    }
+  }
+}
+
+template <>
+void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
+                                const real* inputs,
+                                const int num,
+                                const int inH,
+                                const int inW,
+                                const int inC) {
+  for (int n = 0; n < num; ++n) {
+    for (int h = 0; h < inH; ++h) {
+      for (int w = 0; w < inW; ++w) {
+        for (int c = 0; c < inC; ++c) {
+          outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
+        }
+      }
+    }
+  }
+}
+
+/**
+ * \brief Padding zeros to input according to the specify dimension.
+ *        The struct pad_ contains the padding size in each dimension.
+ *        The input and output is a 4D tensor. In PadFunc, we only
+ *        pad zeros to the 2nd to 4th dimension.
+ *
+ * Argument in this Function:
+ * \param pad_ A struct object contains the padding size in each dimension.
+ *             It has six integers. The channelStart and channelEnd indicate
+ *             how many zeros to add before and after the input in channel
+ *             dimension. And the heightStart and heightEnd indicate padding
+ *             in height dimension. The widthStart and widthEnd indicate the
+ *             padding in width dimension.
+ * \param inputs A 4D tensor, only one input.
+ * \param outputs A 4D tensor, the output value after padding.
+ *
+ */
+
+template <DeviceType Device>
+class NCHW2NHWCFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {}
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(1UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+
+    size_t num = inputs[0].shape()[0];
+    size_t inC = inputs[0].shape()[1];
+    size_t inH = inputs[0].shape()[2];
+    size_t inW = inputs[0].shape()[3];
+    typename Tensor<real, Device>::Vector vec(outputs[0].shape().getElements(),
+                                              outputs[0].data<real>());
+    vec.zero();
+
+    NCHW2NHWC<Device>(
+        outputs[0].data<real>(), inputs[0].data<real>(), num, inC, inH, inW);
+  }
+};
+
+/**
+ * \brief The backward propagation of padding Function. Remove the elements
+ *        in the padding positions of forward.
+ *
+ * Argument in this Function:
+ * \param pad_ The same meaning as it in PadFunc.
+ * \param inputs The gradient with respect to the output value of PadFunc.
+ * \param outputs The gradient with respect to the input value of PadFunc.
+ */
+
+template <DeviceType Device>
+class NHWC2NCHWFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {}
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(1UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+
+    size_t num = inputs[0].shape()[0];
+    size_t inH = inputs[0].shape()[1];
+    size_t inW = inputs[0].shape()[2];
+    size_t inC = inputs[0].shape()[3];
+
+    NHWC2NCHW<Device>(
+        outputs[0].data<real>(), inputs[0].data<real>(), num, inH, inW, inC);
+  }
+};
+
+REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
+REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
+#ifndef PADDLE_ONLY_CPU
+REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
+REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
+#endif
+
+}  // namespace paddle
diff --git a/paddle/function/SwitchOp.h b/paddle/function/SwitchOp.h
new file mode 100644
index 0000000000000..5a2418a703e51
--- /dev/null
+++ b/paddle/function/SwitchOp.h
@@ -0,0 +1,62 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief This function switches the dimension order of the image input.
+ *        The input and output is a 4D tensor. Switch order
+ *        'batch_size, channels, height, width' to
+ *        order 'batch_size, height, width, channels'.
+ *
+ * \param[out] outputs save results.
+ * \param[in]  inputs  input data.
+ * \param[in]  num     batch size of input data.
+ * \param[in]  inC     channel number of input data.
+ * \param[in]  inH     height of input data.
+ * \param[in]  inW     width of input data.
+ */
+template <DeviceType Device>
+void NCHW2NHWC(real* outputs,
+               const real* inputs,
+               const int num,
+               const int inC,
+               const int inH,
+               const int inW);
+
+/**
+ * \brief This function switches the dimension order of the image input.
+ *        The input and output is a 4D tensor. Switch order
+ *        'batch_size, height, width, channels' to
+ *        order 'batch_size, channels, height, width'.
+ *
+ * \param[out] inGrad  gradients of previous layer.
+ * \param[in]  outGrad output gradients.
+ * \param[in]  num     batch size of input data.
+ * \param[in]  inH     height of input data.
+ * \param[in]  inW     width of input data.
+ * \param[in]  inC     channel number of input data.
+ */
+template <DeviceType Device>
+void NHWC2NCHW(real* inGrad,
+               const real* outGrad,
+               const int num,
+               const int inH,
+               const int inW,
+               const int inC);
+}  // namespace paddle
diff --git a/paddle/function/SwitchOpGpu.cu b/paddle/function/SwitchOpGpu.cu
new file mode 100644
index 0000000000000..c2020cb2ab1cd
--- /dev/null
+++ b/paddle/function/SwitchOpGpu.cu
@@ -0,0 +1,80 @@
+/* Copyright (c) 2016 Paddle
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hl_base.h"
+#include "SwitchOp.h"
+
+namespace paddle {
+
+__global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
+                            int inC, int inH, int inW,
+                            int nthreads) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int w = idx % inW;
+    const int h = (idx / inW) % inH;
+    const int c = (idx / inW / inH) % inC;
+    const int n = idx / inW / inH / inC;
+
+    const int off = ((n * inH + h) * inW + w) * inC +c;
+    outputs[off] = inputs[idx];
+  }
+}
+
+template <>
+void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
+                                const real* inputs,
+                                const int num,
+                                const int inC,
+                                const int inH,
+                                const int inW) {
+  size_t nth = num * inC * inH * inW;
+  int blockSize = 1024;
+  int gridSize = (nth + 1024 - 1) / 1024;
+  KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+    (outputs, inputs, inC, inH, inW, nth);
+  CHECK_SYNC("NCHW2NHWC");
+}
+
+__global__ void KeNHWC2NCHW(real* outputs, const real* inputs,
+                            int inH, int inW, int inC,
+                            int nthreads) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int c = idx % inC;
+    const int w = (idx / inC) % inW;
+    const int h = (idx / inC / inW) % inH;
+    const int n = idx / inW / inH / inC;
+
+    const int off = ((n * inC + c) * inH + h) * inW + w;
+    outputs[off] = inputs[idx];
+  }
+}
+
+template <>
+void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
+                                const real* inputs,
+                                const int num,
+                                const int inH,
+                                const int inW,
+                                const int inC) {
+  int nth = num * inC * inH * inW;
+  int blockSize = 1024;
+  int gridSize = (nth + 1024 - 1) / 1024;
+  KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
+    (outputs, inputs, inH, inW, inC, nth);
+  CHECK_SYNC("NHWC2NCHW");
+}
+
+}  // namespace paddle
diff --git a/paddle/function/SwitchOpTest.cpp b/paddle/function/SwitchOpTest.cpp
new file mode 100644
index 0000000000000..03b0dd66ddcba
--- /dev/null
+++ b/paddle/function/SwitchOpTest.cpp
@@ -0,0 +1,44 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "FunctionTest.h"
+
+namespace paddle {
+
+TEST(Pad, real) {
+  for (size_t numSamples : {1, 4, 8, 16}) {
+    for (size_t channels : {1, 4, 8, 16}) {
+      for (size_t imgSizeH : {1, 4, 8, 16}) {
+        for (size_t imgSizeW : {1, 4, 8, 16}) {
+          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
+                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
+          for (bool test_grad : {true, false}) {
+            CpuGpuFuncCompare compare(test_grad ? "NHWC2NCHW" : "NCHW2NHWC",
+                                      FuncConfig());
+            TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
+            TensorShape outDims{numSamples, imgSizeH, imgSizeW, channels};
+            compare.addInputs(
+                BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
+            compare.addOutputs(BufferArg(
+                VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
+            compare.run();
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/PixelSoftmaxLayer.cpp b/paddle/gserver/layers/PixelSoftmaxLayer.cpp
new file mode 100644
index 0000000000000..6da84a6303102
--- /dev/null
+++ b/paddle/gserver/layers/PixelSoftmaxLayer.cpp
@@ -0,0 +1,89 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "PixelSoftmaxLayer.h"
+#include "paddle/utils/Stat.h"
+
+namespace paddle {
+
+REGISTER_LAYER(pixel_softmax, PixelSoftmaxLayer);
+
+bool PixelSoftmaxLayer::init(const LayerMap& layerMap,
+                             const ParameterMap& parameterMap) {
+  /* Initialize the basic parent class */
+  Layer::init(layerMap, parameterMap);
+  auto& img_conf = config_.inputs(0).image_conf();
+  inH_ =
+      img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
+  inW_ = img_conf.img_size();
+  inC_ = img_conf.channels();
+  createFunction(forward_, "NCHW2NHWC", FuncConfig());
+  createFunction(backward_, "NHWC2NCHW", FuncConfig());
+  inDims_ = TensorShape({0, inH_, inW_, inC_});
+  outDims_ = TensorShape({0, inC_, inH_, inW_});
+  return true;
+}
+
+void PixelSoftmaxLayer::forward(PassType passType) {
+  Layer::forward(passType);
+  MatrixPtr input = inputLayers_[0]->getOutputValue();
+  size_t batchSize = input->getHeight();
+  // cout<<"useGpu:"<<useGpu_<<endl;
+  Matrix::resizeOrCreate(
+      tmpInput_, batchSize, inH_ * inW_ * inC_, false, useGpu_);
+  Matrix::resizeOrCreate(
+      tmpOutput_, batchSize, inH_ * inW_ * inC_, false, useGpu_);
+  tmpOutput_->zeroMem();
+  resetOutput(batchSize, inH_ * inW_ * inC_);
+  inDims_.setDim(0, batchSize);
+  outDims_.setDim(0, batchSize);
+
+  // switch NCHW to NHWC
+  BufferArgs inputs;
+  BufferArgs outputs;
+  inputs.addArg(*getInputValue(0), inDims_);
+  outputs.addArg(*tmpInput_, outDims_);
+  forward_[0]->calc(inputs, outputs);
+  // softmax forward and save softmax result into tmpMatrix_
+  tmpInput_->softmax(*tmpOutput_);
+
+  // switch NHWC to NCHW
+  BufferArgs inputs_1;
+  BufferArgs outputs_1;
+  inputs_1.addArg(*tmpOutput_, outDims_);
+  outputs_1.addArg(*getOutputValue(), inDims_);
+  backward_[0]->calc(inputs_1, outputs_1);
+}
+
+void PixelSoftmaxLayer::backward(const UpdateCallback& callback) {
+  (void)callback;
+  REGISTER_TIMER_INFO("PixelSoftmaxBackward", getName().c_str());
+
+  // switch NCHW to NHWC
+  BufferArgs inputs;
+  BufferArgs outputs;
+  inputs.addArg(*getOutputGrad(), inDims_);
+  outputs.addArg(*tmpInput_, outDims_);
+  forward_[0]->calc(inputs, outputs);
+  // softmax backward and save grad result into tmpOutput_
+  tmpInput_->softmaxBackward(*tmpOutput_);
+
+  // switch NHWC to NCHW
+  BufferArgs inputs_1;
+  BufferArgs outputs_1;
+  inputs_1.addArg(*tmpInput_, outDims_);
+  outputs_1.addArg(*getInputGrad(0), inDims_);
+  backward_[0]->calc(inputs_1, outputs_1);
+}
+}  // namespace paddle
diff --git a/paddle/gserver/layers/PixelSoftmaxLayer.h b/paddle/gserver/layers/PixelSoftmaxLayer.h
new file mode 100644
index 0000000000000..80a4ddad5a692
--- /dev/null
+++ b/paddle/gserver/layers/PixelSoftmaxLayer.h
@@ -0,0 +1,44 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Layer.h"
+
+namespace paddle {
+
+/**
+ * \brief This layer calculates softmax in the image channel dimension.
+ */
+class PixelSoftmaxLayer : public Layer {
+public:
+  explicit PixelSoftmaxLayer(const LayerConfig& config) : Layer(config) {}
+
+  ~PixelSoftmaxLayer() {}
+
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback = nullptr) override;
+
+protected:
+  uint32_t inC_;
+  uint32_t inH_;
+  uint32_t inW_;
+  TensorShape inDims_;
+  TensorShape outDims_;
+  MatrixPtr tmpInput_;
+  MatrixPtr tmpOutput_;
+};
+}  // namespace paddle
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 59d1e9273d42d..8a9904087e192 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1792,6 +1792,25 @@ TEST(Layer, RowConvLayer) {
   }
 }
 
+TEST(Layer, PixelSoftmaxLayer) {
+  TestConfig config;
+  // config input_0
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0});
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  ImageConfig* img = input->mutable_image_conf();
+  img->set_channels(4);
+  img->set_img_size(16);
+  img->set_img_size_y(16);
+
+  // config softmax layer
+  config.layerConfig.set_type("pixel_softmax");
+  config.layerConfig.set_name("pixelSoftmaxLayer");
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "pixel_softmax", 100, false, useGpu, true, 2);
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 4431d613f655c..2c18df3732f3a 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -3385,6 +3385,27 @@ void CpuMatrix::oneHotCrossEntropyWithSelfNormBp(Matrix& output,
   real* out = output.getData();                          \
   for (size_t i = 0; i < numSamples; ++i, grad += dim, out += dim)
 
+void CpuMatrix::softmaxBackward(Matrix& outputV) {
+  CHECK(!outputV.useGpu()) << "Matrix types are not equal";
+  size_t height = getHeight();
+  size_t width = getWidth();
+  CHECK(height == outputV.getHeight() && width == outputV.getWidth())
+      << "Matrix dimensions are not equal";
+  Matrix::resizeOrCreate(sftmaxDot_,
+                         height_,
+                         width_,
+                         /* trans */ false,
+                         useGpu_);
+  Matrix::resizeOrCreate(sftmaxSum_,
+                         height_,
+                         1,
+                         /* trans */ false,
+                         useGpu_);
+  sftmaxDot_->dotMul(*this, outputV);
+  sftmaxSum_->colMerge(*sftmaxDot_);
+  softmaxDerivative(outputV, *sftmaxSum_);
+}
+
 void CpuMatrix::softmax(Matrix& output) {
   CHECK(!output.useGpu());
 
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 7dfd593225065..dcb63a2d3fcd4 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -1456,6 +1456,10 @@ class GpuMatrix : public Matrix {
 };
 
 class CpuMatrix : public Matrix {
+private:
+  MatrixPtr sftmaxSum_;
+  MatrixPtr sftmaxDot_;
+
 public:
   CpuMatrix(size_t height, size_t width, bool trans = false);
   CpuMatrix(real* data, size_t height, size_t width, bool trans = false)
@@ -1728,6 +1732,7 @@ class CpuMatrix : public Matrix {
                       Matrix& prevGrad2);
 
   void softmax(Matrix& output);
+  void softmaxBackward(Matrix& outputV);
   void sequenceSoftmax(Matrix& output, const IVector& index);
 
   void softmaxDerivative(Matrix& output, Matrix& sftmaxSum);
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 370529ed97b1f..dc9c503e0b51b 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3171,6 +3171,22 @@ def __init__(self, name, device=None):
             name, 'recurrent_layer_group', 0, inputs=[], device=device)
+
+
+@config_layer('pixel_softmax')
+class PixelSoftmaxLayer(LayerBase):
+    def __init__(self, input, name, **xargs):
+        super(PixelSoftmaxLayer, self).__init__(
+            name, 'pixel_softmax', 0, inputs=inputs, **xargs)
+
+        input_layer = self.get_input_layer(0)
+        image_conf = self.config.inputs[0].image_conf
+        image_conf.img_size = input_layer.width
+        image_conf.img_size_y = input_layer.height
+        image_conf.channels = input_layer.size / (input_layer.width *
+                                                  input_layer.height)
+        self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
+                           image_conf.channels)
+
 
 # Deprecated, use a new layer specific class instead
 @config_func
 def Layer(name, type, **xargs):
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 206de1f8e1c7d..fdac5984b08f9 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -217,6 +217,7 @@ class LayerType(object):
     SMOOTH_L1 = 'smooth_l1'
 
     PRELU = 'prelu'
+    PIXEL_SOFTMAX_LAYER = 'pixel_softmax'
 
     @staticmethod
     def is_layer_type(type_name):
@@ -5853,3 +5854,40 @@ def prelu_layer(input,
         layer_type=LayerType.PRELU,
         parents=input,
         size=l.config.size)
+
+
+@layer_support()
+@wrap_name_default('pixel_softmax')
+def pixel_softmax_layer(input, name=None, layer_attr=None):
+    """
+    This layer calculates softmax in the image channel dimension.
+
+    The example usage is:
+
+    .. code-block:: python
+
+       softmax = pixel_softmax(input=layer, name='softmax')
+
+    :param name: Name of this layer.
+    :type name: basestring
+    :param input: The input layer.
+    :type input: LayerOutput
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+    if isinstance(input, LayerOutput):
+        input = [input]
+    elif isinstance(input, Projection):
+        input = [input]
+    else:
+        assert isinstance(input, collections.Sequence)
+    l = Layer(
+        inputs=[x.name for x in input],
+        name=name,
+        type=LayerType.PIXEL_SOFTMAX_LAYER,
+        **ExtraLayerAttribute.to_kwargs(layer_attr))
+    return LayerOutput(
+        name=name,
+        layer_type=LayerType.PIXEL_SOFTMAX_LAYER,
+        parents=input,
+        size=l.config.size)
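Note: as a sketch of what patch 1 computes end to end — not part of the patch, and all names are invented — the layer switches NCHW to NHWC using the same index formula as NCHW2NHWC above, takes one softmax over the now-contiguous channel values of each pixel, and switches the result back. A minimal pure-Python rendering:

    import math

    def nchw2nhwc(x, n, c, h, w):
        # Mirrors NCHW2NHWC<DEVICE_TYPE_CPU>:
        # out[((ni*h + hi)*w + wi)*c + ci] = x[((ni*c + ci)*h + hi)*w + wi]
        out = [0.0] * (n * h * w * c)
        src = 0
        for ni in range(n):
            for ci in range(c):
                for hi in range(h):
                    for wi in range(w):
                        out[((ni * h + hi) * w + wi) * c + ci] = x[src]
                        src += 1
        return out

    def pixel_softmax(x, n, c, h, w):
        nhwc = nchw2nhwc(x, n, c, h, w)
        for p in range(n * h * w):  # one softmax per pixel over its c channels
            row = nhwc[p * c:(p + 1) * c]
            m = max(row)
            exps = [math.exp(v - m) for v in row]
            s = sum(exps)
            nhwc[p * c:(p + 1) * c] = [e / s for e in exps]
        return nhwc  # the layer switches this back to NCHW before output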
From 0152d97e6344fbf866d75bf24f6f6034a81f5e81 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Tue, 11 Jul 2017 10:23:29 +0800
Subject: [PATCH 2/9] fix pixel softmax python wrapper bug

---
 python/paddle/trainer/config_parser.py         | 2 +-
 python/paddle/trainer_config_helpers/layers.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index c24af47c4ba31..261e834e11846 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3176,7 +3176,7 @@ def __init__(self, name, device=None):
 
 @config_layer('pixel_softmax')
 class PixelSoftmaxLayer(LayerBase):
-    def __init__(self, input, name, **xargs):
+    def __init__(self, name, inputs, **xargs):
         super(PixelSoftmaxLayer, self).__init__(
             name, 'pixel_softmax', 0, inputs=inputs, **xargs)
 
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index d8cc52d4098ac..2f8b0d1002453 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -126,6 +126,7 @@
     'row_conv_layer',
     'dropout_layer',
     'prelu_layer',
+    'pixel_softmax_layer',
 ]
 
 
@@ -5905,8 +5906,8 @@ def pixel_softmax_layer(input, name=None, layer_attr=None):
     else:
         assert isinstance(input, collections.Sequence)
     l = Layer(
-        inputs=[x.name for x in input],
         name=name,
+        inputs=[x.name for x in input],
         type=LayerType.PIXEL_SOFTMAX_LAYER,
         **ExtraLayerAttribute.to_kwargs(layer_attr))
     return LayerOutput(
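Note: SwitchOpTest in patch 1 drives the two functions as forward and inverse of each other, so they must round-trip exactly. A hedged pure-Python check of that property, reusing nchw2nhwc from the sketch above (illustrative only, independent of the FunctionTest harness):

    def nhwc2nchw(x, n, h, w, c):
        # Mirrors NHWC2NCHW<DEVICE_TYPE_CPU>:
        # out[((ni*c + ci)*h + hi)*w + wi] = x[((ni*h + hi)*w + wi)*c + ci]
        out = [0.0] * (n * c * h * w)
        src = 0
        for ni in range(n):
            for hi in range(h):
                for wi in range(w):
                    for ci in range(c):
                        out[((ni * c + ci) * h + hi) * w + wi] = x[src]
                        src += 1
        return out

    n, c, h, w = 2, 3, 4, 5
    x = [float(i) for i in range(n * c * h * w)]
    assert nhwc2nchw(nchw2nhwc(x, n, c, h, w), n, h, w, c) == x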
From 1cdf149b6fccf4fba030f0bb847965500960fa9b Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 19 Jul 2017 12:50:45 +0800
Subject: [PATCH 3/9] 1. delete PixelSoftmaxLayer and add SwitchOrderLayer
 2. Make SwitchOrderLayer support for softmax activation 3. Fix bugs

---
 CMakeLists.txt                                |   2 +-
 paddle/function/SwitchOp.cpp                  |  72 ++++++-----
 paddle/function/SwitchOp.h                    |   8 +-
 paddle/function/SwitchOpGpu.cu                |  26 ++--
 paddle/gserver/layers/PixelSoftmaxLayer.cpp   |  89 --------------
 paddle/gserver/layers/SwitchOrderLayer.cpp    | 112 ++++++++++++++++++
 ...PixelSoftmaxLayer.h => SwitchOrderLayer.h} |  19 +--
 paddle/gserver/tests/test_LayerGrad.cpp       |  14 ++-
 paddle/math/Matrix.cpp                        |  21 ----
 paddle/math/Matrix.h                          |   1 -
 proto/ModelConfig.proto                       |   8 ++
 python/paddle/trainer/config_parser.py        |  21 ++--
 .../paddle/trainer_config_helpers/layers.py   |  36 +++---
 13 files changed, 231 insertions(+), 198 deletions(-)
 delete mode 100644 paddle/gserver/layers/PixelSoftmaxLayer.cpp
 create mode 100644 paddle/gserver/layers/SwitchOrderLayer.cpp
 rename paddle/gserver/layers/{PixelSoftmaxLayer.h => SwitchOrderLayer.h} (71%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 15a7c6b07417a..fdc62b31511c4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,7 +13,7 @@
 # limitations under the License
 
 cmake_minimum_required(VERSION 3.0)
-
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ldl -lpthread")
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
 set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})
diff --git a/paddle/function/SwitchOp.cpp b/paddle/function/SwitchOp.cpp
index 4667c4e01d52a..01e252a8dc0cd 100644
--- a/paddle/function/SwitchOp.cpp
+++ b/paddle/function/SwitchOp.cpp
@@ -23,12 +23,17 @@ void NCHW2NHWC(real* outputs,
                const int num,
                const int inC,
                const int inH,
-               const int inW) {
+               const int inW,
+               const int argType) {
   for (int n = 0; n < num; ++n) {
     for (int c = 0; c < inC; ++c) {
       for (int h = 0; h < inH; ++h) {
         for (int w = 0; w < inW; ++w) {
-          outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
+          if (argType == ADD_TO) {
+            outputs[((n * inH + h) * inW + w) * inC + c] += *(inputs++);
+          } else {
+            outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
+          }
         }
       }
     }
@@ -41,12 +46,17 @@ void NHWC2NCHW(real* outputs,
                const int num,
                const int inH,
                const int inW,
-               const int inC) {
+               const int inC,
+               const int argType) {
   for (int n = 0; n < num; ++n) {
     for (int h = 0; h < inH; ++h) {
       for (int w = 0; w < inW; ++w) {
         for (int c = 0; c < inC; ++c) {
-          outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
+          if (argType == ADD_TO) {
+            outputs[((n * inC + c) * inH + h) * inW + w] += *(inputs++);
+          } else {
+            outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
+          }
         }
       }
     }
@@ -54,23 +64,15 @@ void NHWC2NCHW(real* outputs,
 }
 
 /**
- * \brief Padding zeros to input according to the specify dimension.
- *        The struct pad_ contains the padding size in each dimension.
- *        The input and output is a 4D tensor. In PadFunc, we only
- *        pad zeros to the 2nd to 4th dimension.
+ * \brief Switch dimension order of image input.
+ *        The input and output is a 4D tensor. Switch order
+ *        'batch_size,channels, height, width' to
+ *        order 'batch_size, height, width, channels'.
  *
  * Argument in this Function:
- * \param pad_ A struct object contains the padding size in each dimension.
- *             It has six integers. The channelStart and channelEnd indicate
- *             how many zeros to add before and after the input in channel
- *             dimension. And the heightStart and heightEnd indicate padding
- *             in height dimension. The widthStart and widthEnd indicate the
- *             padding in width dimension.
- * \param inputs A 4D tensor, only one input.
- * \param outputs A 4D tensor, the output value after padding.
- *
+ * \param inputs input data with order 'batch_size,channels, height, width'.
+ * \param outputs output data with order 'batch_size, height, width, channels'.
  */
-
 template <DeviceType Device>
 class NCHW2NHWCFunc : public FunctionBase {
 public:
@@ -84,25 +86,26 @@ class NCHW2NHWCFunc : public FunctionBase {
     size_t inC = inputs[0].shape()[1];
     size_t inH = inputs[0].shape()[2];
     size_t inW = inputs[0].shape()[3];
-    typename Tensor<real, Device>::Vector vec(outputs[0].shape().getElements(),
-                                              outputs[0].data<real>());
-    vec.zero();
-
-    NCHW2NHWC<Device>(
-        outputs[0].data<real>(), inputs[0].data<real>(), num, inC, inH, inW);
+    NCHW2NHWC<Device>(outputs[0].data<real>(),
+                      inputs[0].data<real>(),
+                      num,
+                      inC,
+                      inH,
+                      inW,
+                      outputs[0].getArgType());
   }
 };
 
 /**
- * \brief The backward propagation of padding Function. Remove the elements
- *        in the padding positions of forward.
+ * \brief Switch dimension order of image input.
+ *        The input and output is a 4D tensor. Switch order
+ *        'batch_size, height, width, channels' to
+ *        order 'batch_size, channels, height, width'.
  *
  * Argument in this Function:
- * \param pad_ The same meaning as it in PadFunc.
- * \param inputs The gradient with respect to the output value of PadFunc.
- * \param outputs The gradient with respect to the input value of PadFunc.
+ * \param inputs input data with order 'batch_size, height, width, channels'.
+ * \param outputs output data with order 'batch_size, channels, height, width'.
  */
-
 template <DeviceType Device>
 class NHWC2NCHWFunc : public FunctionBase {
 public:
@@ -117,8 +120,13 @@ class NHWC2NCHWFunc : public FunctionBase {
     size_t inW = inputs[0].shape()[2];
     size_t inC = inputs[0].shape()[3];
 
-    NHWC2NCHW<Device>(
-        outputs[0].data<real>(), inputs[0].data<real>(), num, inH, inW, inC);
+    NHWC2NCHW<Device>(outputs[0].data<real>(),
+                      inputs[0].data<real>(),
+                      num,
+                      inH,
+                      inW,
+                      inC,
+                      outputs[0].getArgType());
   }
 };
 
diff --git a/paddle/function/SwitchOp.h b/paddle/function/SwitchOp.h
index 5a2418a703e51..e4c1c3ac922f8 100644
--- a/paddle/function/SwitchOp.h
+++ b/paddle/function/SwitchOp.h
@@ -30,6 +30,7 @@ namespace paddle {
  * \param[in]  inC     channel number of input data.
  * \param[in]  inH     height of input data.
  * \param[in]  inW     width of input data.
+ * \param[in]  argType type of output argument.
  */
 template <DeviceType Device>
 void NCHW2NHWC(real* outputs,
@@ -37,7 +38,8 @@ void NCHW2NHWC(real* outputs,
                const int num,
                const int inC,
                const int inH,
-               const int inW);
+               const int inW,
+               const int argType);
 
 /**
  * \brief This function switches the dimension order of the image input.
  *        The input and output is a 4D tensor. Switch order
  *        'batch_size, height, width, channels' to
  *        order 'batch_size, channels, height, width'.
  *
  * \param[out] inGrad  gradients of previous layer.
  * \param[in]  outGrad output gradients.
  * \param[in]  num     batch size of input data.
  * \param[in]  inH     height of input data.
  * \param[in]  inW     width of input data.
  * \param[in]  inC     channel number of input data.
+ * \param[in]  argType type of output argument.
*/ template void NHWC2NCHW(real* inGrad, @@ -58,5 +61,6 @@ void NHWC2NCHW(real* inGrad, const int num, const int inH, const int inW, - const int inC); + const int inC, + const int argType); } // namespace paddle diff --git a/paddle/function/SwitchOpGpu.cu b/paddle/function/SwitchOpGpu.cu index c2020cb2ab1cd..0b9401dea1fea 100644 --- a/paddle/function/SwitchOpGpu.cu +++ b/paddle/function/SwitchOpGpu.cu @@ -19,7 +19,7 @@ namespace paddle { __global__ void KeNCHW2NHWC(real* outputs, const real* inputs, int inC, int inH, int inW, - int nthreads) { + int nthreads, int argType) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < nthreads) { const int w = idx % inW; @@ -28,7 +28,11 @@ __global__ void KeNCHW2NHWC(real* outputs, const real* inputs, const int n = idx / inW / inH / inC; const int off = ((n * inH + h) * inW + w) * inC +c; - outputs[off] = inputs[idx]; + if (argType == ADD_TO) { + outputs[off] += inputs[idx]; + } else { + outputs[off] = inputs[idx]; + } } } @@ -38,18 +42,19 @@ void NCHW2NHWC(real* outputs, const int num, const int inC, const int inH, - const int inW) { + const int inW, + const int argType) { size_t nth = num * inC * inH * inW; int blockSize = 1024; int gridSize = (nth + 1024 - 1) / 1024; KeNCHW2NHWC<<>> - (outputs, inputs, inC, inH, inW, nth); + (outputs, inputs, inC, inH, inW, nth, argType); CHECK_SYNC("NCHW2NHWC"); } __global__ void KeNHWC2NCHW(real* outputs, const real* inputs, int inH, int inW, int inC, - int nthreads) { + int nthreads, int argType) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < nthreads) { const int c = idx % inC; @@ -58,7 +63,11 @@ __global__ void KeNHWC2NCHW(real* outputs, const real* inputs, const int n = idx / inW / inH / inC; const int off = ((n * inC + c) * inH + h) * inW + w; - outputs[off] = inputs[idx]; + if (argType == ADD_TO) { + outputs[off] += inputs[idx]; + } else { + outputs[off] = inputs[idx]; + } } } @@ -68,12 +77,13 @@ void NHWC2NCHW(real* outputs, const int num, const int inH, const int inW, - const int inC) { + const int inC, + const int argType) { int nth = num * inC * inH * inW; int blockSize = 1024; int gridSize = (nth + 1024 - 1) / 1024; KeNHWC2NCHW<<>> - (outputs, inputs, inH, inW, inC, nth); + (outputs, inputs, inH, inW, inC, nth, argType); CHECK_SYNC("NHWC2NCHW"); } diff --git a/paddle/gserver/layers/PixelSoftmaxLayer.cpp b/paddle/gserver/layers/PixelSoftmaxLayer.cpp deleted file mode 100644 index 6da84a6303102..0000000000000 --- a/paddle/gserver/layers/PixelSoftmaxLayer.cpp +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-
-#include "PixelSoftmaxLayer.h"
-#include "paddle/utils/Stat.h"
-
-namespace paddle {
-
-REGISTER_LAYER(pixel_softmax, PixelSoftmaxLayer);
-
-bool PixelSoftmaxLayer::init(const LayerMap& layerMap,
-                             const ParameterMap& parameterMap) {
-  /* Initialize the basic parent class */
-  Layer::init(layerMap, parameterMap);
-  auto& img_conf = config_.inputs(0).image_conf();
-  inH_ =
-      img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
-  inW_ = img_conf.img_size();
-  inC_ = img_conf.channels();
-  createFunction(forward_, "NCHW2NHWC", FuncConfig());
-  createFunction(backward_, "NHWC2NCHW", FuncConfig());
-  inDims_ = TensorShape({0, inH_, inW_, inC_});
-  outDims_ = TensorShape({0, inC_, inH_, inW_});
-  return true;
-}
-
-void PixelSoftmaxLayer::forward(PassType passType) {
-  Layer::forward(passType);
-  MatrixPtr input = inputLayers_[0]->getOutputValue();
-  size_t batchSize = input->getHeight();
-  // cout<<"useGpu:"<<useGpu_<<endl;
-  Matrix::resizeOrCreate(
-      tmpInput_, batchSize, inH_ * inW_ * inC_, false, useGpu_);
-  Matrix::resizeOrCreate(
-      tmpOutput_, batchSize, inH_ * inW_ * inC_, false, useGpu_);
-  tmpOutput_->zeroMem();
-  resetOutput(batchSize, inH_ * inW_ * inC_);
-  inDims_.setDim(0, batchSize);
-  outDims_.setDim(0, batchSize);
-
-  // switch NCHW to NHWC
-  BufferArgs inputs;
-  BufferArgs outputs;
-  inputs.addArg(*getInputValue(0), inDims_);
-  outputs.addArg(*tmpInput_, outDims_);
-  forward_[0]->calc(inputs, outputs);
-  // softmax forward and save softmax result into tmpMatrix_
-  tmpInput_->softmax(*tmpOutput_);
-
-  // switch NHWC to NCHW
-  BufferArgs inputs_1;
-  BufferArgs outputs_1;
-  inputs_1.addArg(*tmpOutput_, outDims_);
-  outputs_1.addArg(*getOutputValue(), inDims_);
-  backward_[0]->calc(inputs_1, outputs_1);
-}
-
-void PixelSoftmaxLayer::backward(const UpdateCallback& callback) {
-  (void)callback;
-  REGISTER_TIMER_INFO("PixelSoftmaxBackward", getName().c_str());
-
-  // switch NCHW to NHWC
-  BufferArgs inputs;
-  BufferArgs outputs;
-  inputs.addArg(*getOutputGrad(), inDims_);
-  outputs.addArg(*tmpInput_, outDims_);
-  forward_[0]->calc(inputs, outputs);
-  // softmax backward and save grad result into tmpOutput_
-  tmpInput_->softmaxBackward(*tmpOutput_);
-
-  // switch NHWC to NCHW
-  BufferArgs inputs_1;
-  BufferArgs outputs_1;
-  inputs_1.addArg(*tmpInput_, outDims_);
-  outputs_1.addArg(*getInputGrad(0), inDims_);
-  backward_[0]->calc(inputs_1, outputs_1);
-}
-}  // namespace paddle
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
new file mode 100644
index 0000000000000..2a8a9500faef3
--- /dev/null
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -0,0 +1,112 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "SwitchOrderLayer.h"
+#include "paddle/utils/Stat.h"
+
+namespace paddle {
+
+REGISTER_LAYER(switch_order, SwitchOrderLayer);
+
+bool SwitchOrderLayer::init(const LayerMap& layerMap,
+                            const ParameterMap& parameterMap) {
+  /* Initialize the basic parent class */
+  Layer::init(layerMap, parameterMap);
+  auto& img_conf = config_.inputs(0).image_conf();
+  size_t inH =
+      img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
+  size_t inW = img_conf.img_size();
+  size_t inC = img_conf.channels();
+  inDims_ = TensorShape({0, inC, inH, inW});
+  outDims_ = TensorShape(4);
+
+  auto& reshape_conf = config_.reshape_conf();
+  for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
+    LOG(INFO) << "reshape height axis: " << reshape_conf.heightaxis(i);
+    heightAxis_.push_back(reshape_conf.heightaxis(i));
+  }
+  for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
+    LOG(INFO) << "reshape width axis: " << reshape_conf.widthaxis(i);
+    widthAxis_.push_back(reshape_conf.widthaxis(i));
+  }
+  createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
+  createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
+  return true;
+}
+
+void SwitchOrderLayer::setOutDims() {
+  outDims_.setDim(0, inDims_[0]);
+  outDims_.setDim(1, inDims_[2]);
+  outDims_.setDim(2, inDims_[3]);
+  outDims_.setDim(3, inDims_[1]);
+  reshapeHeight_ = 1;
+  for (size_t i = 0; i < heightAxis_.size(); i++) {
+    reshapeHeight_ *= outDims_[heightAxis_[i]];
+  }
+  output_.setFrameHeight(reshapeHeight_);
+  reshapeWidth_ = 1;
+  for (size_t i = 0; i < widthAxis_.size(); i++) {
+    reshapeWidth_ *= outDims_[widthAxis_[i]];
+  }
+  output_.setFrameWidth(reshapeWidth_);
+  LOG(INFO) << "outDims: " << outDims_[0] << "; " << outDims_[1] << ";"
+            << outDims_[2] << ";" << outDims_[3];
+}
+
+void SwitchOrderLayer::setInDims() {
+  MatrixPtr input = inputLayers_[0]->getOutputValue();
+  size_t batchSize = input->getHeight();
+  inDims_.setDim(0, batchSize);
+
+  int h = inputLayers_[0]->getOutput().getFrameHeight();
+  if (h != 0) inDims_.setDim(2, h);
+  int w = inputLayers_[0]->getOutput().getFrameWidth();
+  if (w != 0) inDims_.setDim(3, w);
+  int totalCount = input->getElementCnt();
+  int channels = totalCount / (inDims_[0] * inDims_[2] * inDims_[3]);
+  if (channels != 0) inDims_.setDim(1, channels);
+  LOG(INFO) << "inDims: " << inDims_[0] << "; " << inDims_[1] << ";"
+            << inDims_[2] << ";" << inDims_[3];
+}
+
+void SwitchOrderLayer::forward(PassType passType) {
+  Layer::forward(passType);
+  setInDims();
+  setOutDims();
+  resetOutput(outDims_[0], outDims_[1] * outDims_[2] * outDims_[3]);
+  if (heightAxis_.size() > 0) {
+    getOutputValue()->reshape(reshapeHeight_, reshapeWidth_);
+  }
+
+  // switch NCHW to NHWC
+  BufferArgs inputs;
+  BufferArgs outputs;
+  inputs.addArg(*getInputValue(0), inDims_);
+  outputs.addArg(*getOutputValue(), outDims_);
+  nchw2nhwc_[0]->calc(inputs, outputs);
+  // forwardActivation();
+}
+
+void SwitchOrderLayer::backward(const UpdateCallback& callback) {
+  (void)callback;
+  // backwardActivation();
+
+  // switch NHWC to NCHW
+  BufferArgs inputs;
+  BufferArgs outputs;
+  inputs.addArg(*getOutputGrad(), outDims_);
+  outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
+  nhwc2nchw_[0]->calc(inputs, outputs);
+}
+}  // namespace paddle
diff --git a/paddle/gserver/layers/PixelSoftmaxLayer.h b/paddle/gserver/layers/SwitchOrderLayer.h
similarity index 71%
rename from paddle/gserver/layers/PixelSoftmaxLayer.h
rename to paddle/gserver/layers/SwitchOrderLayer.h
index 80a4ddad5a692..47b1f7f73ee78 100644
--- a/paddle/gserver/layers/PixelSoftmaxLayer.h
+++ b/paddle/gserver/layers/SwitchOrderLayer.h
@@ -21,24 +21,27 @@ namespace paddle {
 
 /**
  * \brief This layer calculates softmax in the image channel dimension.
  */
-class PixelSoftmaxLayer : public Layer {
+class SwitchOrderLayer : public Layer {
 public:
-  explicit PixelSoftmaxLayer(const LayerConfig& config) : Layer(config) {}
+  explicit SwitchOrderLayer(const LayerConfig& config) : Layer(config) {}
 
-  ~PixelSoftmaxLayer() {}
+  ~SwitchOrderLayer() {}
 
   bool init(const LayerMap& layerMap,
             const ParameterMap& parameterMap) override;
   void forward(PassType passType) override;
   void backward(const UpdateCallback& callback = nullptr) override;
+  void setInDims();
+  void setOutDims();
 
 protected:
-  uint32_t inC_;
-  uint32_t inH_;
-  uint32_t inW_;
+  std::vector<std::shared_ptr<FunctionBase>> nchw2nhwc_;
+  std::vector<std::shared_ptr<FunctionBase>> nhwc2nchw_;
   TensorShape inDims_;
   TensorShape outDims_;
-  MatrixPtr tmpInput_;
-  MatrixPtr tmpOutput_;
+  std::vector<int> heightAxis_;
+  std::vector<int> widthAxis_;
+  size_t reshapeHeight_;
+  size_t reshapeWidth_;
 };
 }  // namespace paddle
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 98c9cbe9f5d0a..42c23f02264cf 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1802,7 +1802,7 @@ TEST(Layer, RowConvLayer) {
   }
 }
 
-TEST(Layer, PixelSoftmaxLayer) {
+TEST(Layer, SwitchOrderLayer) {
   TestConfig config;
   // config input_0
   config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0});
@@ -1812,12 +1812,18 @@ TEST(Layer, PixelSoftmaxLayer) {
   img->set_img_size(16);
   img->set_img_size_y(16);
 
+  ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf();
+  reshape->add_heightaxis(0);
+  reshape->add_heightaxis(1);
+  reshape->add_heightaxis(2);
+  reshape->add_widthaxis(3);
+
   // config softmax layer
-  config.layerConfig.set_type("pixel_softmax");
-  config.layerConfig.set_name("pixelSoftmaxLayer");
+  config.layerConfig.set_type("switch_order");
+  config.layerConfig.set_name("switchOrderLayer");
 
   for (auto useGpu : {false, true}) {
-    testLayerGrad(config, "pixel_softmax", 100, false, useGpu, true, 2);
+    testLayerGrad(config, "switch_order", 100, false, useGpu, true);
   }
 }
 
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 2c18df3732f3a..4431d613f655c 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -3385,27 +3385,6 @@ void CpuMatrix::oneHotCrossEntropyWithSelfNormBp(Matrix& output,
   real* out = output.getData();                          \
   for (size_t i = 0; i < numSamples; ++i, grad += dim, out += dim)
 
-void CpuMatrix::softmaxBackward(Matrix& outputV) {
-  CHECK(!outputV.useGpu()) << "Matrix types are not equal";
-  size_t height = getHeight();
-  size_t width = getWidth();
-  CHECK(height == outputV.getHeight() && width == outputV.getWidth())
-      << "Matrix dimensions are not equal";
-  Matrix::resizeOrCreate(sftmaxDot_,
-                         height_,
-                         width_,
-                         /* trans */ false,
-                         useGpu_);
-  Matrix::resizeOrCreate(sftmaxSum_,
-                         height_,
-                         1,
-                         /* trans */ false,
-                         useGpu_);
-  sftmaxDot_->dotMul(*this, outputV);
-  sftmaxSum_->colMerge(*sftmaxDot_);
-  softmaxDerivative(outputV, *sftmaxSum_);
-}
-
 void CpuMatrix::softmax(Matrix& output) {
   CHECK(!output.useGpu());
 
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index dcb63a2d3fcd4..20f97a5060bbf 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -1732,7 +1732,6 @@ class CpuMatrix : public Matrix {
                       Matrix& prevGrad2);
 
   void softmax(Matrix& output);
-  void softmaxBackward(Matrix& outputV);
   void sequenceSoftmax(Matrix& output, const IVector& index);
 
   void softmaxDerivative(Matrix& output, Matrix& sftmaxSum);
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 37cd16c798907..9fd017b23e4ae 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -266,6 +266,11 @@ message PadConfig {
   repeated uint32 pad_w = 4;
 }
 
+message ReshapeConfig {
+    repeated uint32 heightAxis = 1;
+    repeated uint32 widthAxis = 2;
+}
+
 message MultiBoxLossConfig {
   required uint32 num_classes = 1;
   required float overlap_threshold = 2;
@@ -476,6 +481,9 @@ message LayerConfig {
   // controls the scope of pooling operation. can be set > 0.
   // leave empty or set to -1 to disable this stride pooling.
   optional int32 seq_pool_stride = 53 [default = -1];
+
+  // for switch order layer
+  optional ReshapeConfig reshape_conf = 54;
 }
 
 message EvaluatorConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 261e834e11846..fe06dd812edec 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3174,20 +3174,13 @@ def __init__(self, name, device=None):
             name, 'recurrent_layer_group', 0, inputs=[], device=device)
 
 
-@config_layer('pixel_softmax')
-class PixelSoftmaxLayer(LayerBase):
-    def __init__(self, name, inputs, **xargs):
-        super(PixelSoftmaxLayer, self).__init__(
-            name, 'pixel_softmax', 0, inputs=inputs, **xargs)
-
-        input_layer = self.get_input_layer(0)
-        image_conf = self.config.inputs[0].image_conf
-        image_conf.img_size = input_layer.width
-        image_conf.img_size_y = input_layer.height
-        image_conf.channels = input_layer.size / (input_layer.width *
-                                                  input_layer.height)
-        self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
-                           image_conf.channels)
+@config_layer('switch_order')
+class SwitchOrderLayer(LayerBase):
+    def __init__(self, name, inputs, reshape, **xargs):
+        super(SwitchOrderLayer, self).__init__(
+            name, 'switch_order', 0, inputs=inputs, **xargs)
+        self.conf.reshape_conf.heightAxis_ = reshape['height']
+        self.conf.reshape_conf.widthAxis_ = reshape['width']
 
 
 # Deprecated, use a new layer specific class instead
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 2f8b0d1002453..6980a31679b5a 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -126,7 +126,7 @@
     'row_conv_layer',
     'dropout_layer',
     'prelu_layer',
-    'pixel_softmax_layer',
+    'switch_order_layer',
 ]
 
 
@@ -218,7 +218,7 @@ class LayerType(object):
     SMOOTH_L1 = 'smooth_l1'
 
     PRELU = 'prelu'
-    PIXEL_SOFTMAX_LAYER = 'pixel_softmax'
+    SWITCH_ORDER_LAYER = 'switch_order'
 
     @staticmethod
     def is_layer_type(type_name):
@@ -5881,37 +5881,37 @@ def prelu_layer(input,
 
 
 @layer_support()
-@wrap_name_default('pixel_softmax')
-def pixel_softmax_layer(input, name=None, layer_attr=None):
+@wrap_name_default('switch_order')
+def switch_order_layer(input, name=None, reshape=None, layer_attr=None):
     """
-    This layer calculates softmax in the image channel dimension.
+    This layer switch dimension order of image input.
+    From order "batchSize, channels, height, width"
+    to order "batchSize, height, width, channels".
 
     The example usage is:
 
     .. code-block:: python
+       reshape = {'height':[ 0, 1, 2], 'width':[3]}
+       switch = switch_order(input=layer, name='switch', reshape=reshape)
 
-       softmax = pixel_softmax(input=layer, name='softmax')
-
-    :param name: Name of this layer.
-    :type name: basestring
     :param input: The input layer.
     :type input: LayerOutput
+    :param name: Name of this layer.
+    :type name: basestring
+    :param reshape: reshape matrix by axes.
+    :type reshape: Dict
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
-    if isinstance(input, LayerOutput):
-        input = [input]
-    elif isinstance(input, Projection):
-        input = [input]
-    else:
-        assert isinstance(input, collections.Sequence)
+    assert isinstance(input, LayerOutput)
     l = Layer(
         name=name,
-        inputs=[x.name for x in input],
-        type=LayerType.PIXEL_SOFTMAX_LAYER,
+        inputs=input,
+        reshape=reshape,
+        type=LayerType.SWITCH_ORDER_LAYER,
         **ExtraLayerAttribute.to_kwargs(layer_attr))
     return LayerOutput(
         name=name,
-        layer_type=LayerType.PIXEL_SOFTMAX_LAYER,
+        layer_type=LayerType.SWITCH_ORDER_LAYER,
         parents=input,
         size=l.config.size)
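Note: the reshape dictionary consumed by SwitchOrderLayer groups the switched dimensions — indexed 0:N, 1:H, 2:W, 3:C after the NCHW-to-NHWC switch — into the height and width of the output matrix, mirroring setOutDims() in patch 3. A hedged pure-Python sketch of that bookkeeping (illustrative only; all names invented):

    def switched_matrix_shape(n, c, h, w, height_axes, width_axes):
        dims = [n, h, w, c]  # dimension order after the NCHW -> NHWC switch
        height = 1
        for axis in height_axes:
            height *= dims[axis]
        width = 1
        for axis in width_axes:
            width *= dims[axis]
        return height, width

    # With reshape = {'height': [0, 1, 2], 'width': [3]}, the output is an
    # (N*H*W) x C matrix, so a row-wise softmax activation reproduces the
    # per-pixel softmax that PixelSoftmaxLayer used to compute directly.
    assert switched_matrix_shape(2, 4, 16, 16, [0, 1, 2], [3]) == (2 * 16 * 16, 4)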
From fa02963659239fbbd61594b61073802cc9ab4513 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 19 Jul 2017 13:15:03 +0800
Subject: [PATCH 4/9] Delete debug log

---
 paddle/gserver/layers/SwitchOrderLayer.cpp | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 2a8a9500faef3..8d337611b98d0 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -33,11 +33,9 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
 
   auto& reshape_conf = config_.reshape_conf();
   for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
-    LOG(INFO) << "reshape height axis: " << reshape_conf.heightaxis(i);
     heightAxis_.push_back(reshape_conf.heightaxis(i));
   }
   for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
-    LOG(INFO) << "reshape width axis: " << reshape_conf.widthaxis(i);
     widthAxis_.push_back(reshape_conf.widthaxis(i));
   }
   createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
@@ -60,8 +58,6 @@ void SwitchOrderLayer::setOutDims() {
     reshapeWidth_ *= outDims_[widthAxis_[i]];
   }
   output_.setFrameWidth(reshapeWidth_);
-  LOG(INFO) << "outDims: " << outDims_[0] << "; " << outDims_[1] << ";"
-            << outDims_[2] << ";" << outDims_[3];
 }
 
 void SwitchOrderLayer::setInDims() {
@@ -76,8 +72,6 @@ void SwitchOrderLayer::setInDims() {
   int totalCount = input->getElementCnt();
   int channels = totalCount / (inDims_[0] * inDims_[2] * inDims_[3]);
   if (channels != 0) inDims_.setDim(1, channels);
-  LOG(INFO) << "inDims: " << inDims_[0] << "; " << inDims_[1] << ";"
-            << inDims_[2] << ";" << inDims_[3];
 }
 
 void SwitchOrderLayer::forward(PassType passType) {

From e23acb4e6f7b12f1b61faf3cf8d74872b7df5b39 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 19 Jul 2017 14:09:32 +0800
Subject: [PATCH 5/9] fix cmake

---
 CMakeLists.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9a852248432e5..2a6b0a20e4416 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,7 +13,6 @@
 # limitations under the License
 
 cmake_minimum_required(VERSION 3.0)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ldl -lpthread")
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
 set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})

From a6c53fc2fcef380784829cfb29764e1a6458827d Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 19 Jul 2017 17:32:05 +0800
Subject: [PATCH 6/9] fix python wrapper bugs
---
 python/paddle/trainer/config_parser.py         | 4 ++--
 python/paddle/trainer_config_helpers/layers.py | 9 +++++++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 6e2f21823405c..0a466380aeb6d 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3187,8 +3187,8 @@ class SwitchOrderLayer(LayerBase):
     def __init__(self, name, inputs, reshape, **xargs):
         super(SwitchOrderLayer, self).__init__(
             name, 'switch_order', 0, inputs=inputs, **xargs)
-        self.conf.reshape_conf.heightAxis_ = reshape['height']
-        self.conf.reshape_conf.widthAxis_ = reshape['width']
+        self.config.reshape_conf.heightAxis.extend(reshape['height'])
+        self.config.reshape_conf.widthAxis.extend(reshape['width'])
 
 
 # Deprecated, use a new layer specific class instead
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 1f5b9e999c740..0bcfbe1e0c7b6 100755
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -5976,7 +5976,11 @@ def gated_unit_layer(input,
 
 @layer_support()
 @wrap_name_default('switch_order')
-def switch_order_layer(input, name=None, reshape=None, layer_attr=None):
+def switch_order_layer(input,
+                       name=None,
+                       reshape=None,
+                       act=None,
+                       layer_attr=None):
     """
     This layer switch dimension order of image input.
     From order "batchSize, channels, height, width"
     to order "batchSize, height, width, channels".
@@ -6000,9 +6004,10 @@ def switch_order_layer(input, name=None, reshape=None, layer_attr=None):
     assert isinstance(input, LayerOutput)
     l = Layer(
         name=name,
-        inputs=input,
+        inputs=input.name,
         reshape=reshape,
         type=LayerType.SWITCH_ORDER_LAYER,
+        active_type=act.name,
         **ExtraLayerAttribute.to_kwargs(layer_attr))
     return LayerOutput(
         name=name,

From baae8447ac936b29fb2b14981851bb502f5193cd Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 19 Jul 2017 18:53:32 +0800
Subject: [PATCH 7/9] Fix SwitchOrderLayer grad bugs by reshape output.grad

---
 paddle/gserver/layers/SwitchOrderLayer.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 8d337611b98d0..6a91042f62892 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -81,6 +81,7 @@ void SwitchOrderLayer::forward(PassType passType) {
   resetOutput(outDims_[0], outDims_[1] * outDims_[2] * outDims_[3]);
   if (heightAxis_.size() > 0) {
     getOutputValue()->reshape(reshapeHeight_, reshapeWidth_);
+    getOutputGrad()->reshape(reshapeHeight_, reshapeWidth_);
   }
 
   // switch NCHW to NHWC

From eb3c774b8308e030407a113e8206f200899c7492 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 7 Sep 2017 07:12:18 +0800
Subject: [PATCH 8/9] Fix format error

---
 paddle/function/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index 4076e20de2ab9..4fd72d64a90ae 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -42,9 +42,9 @@ if(WITH_GPU)
     add_simple_unittest(MulOpTest)
     add_simple_unittest(CosSimOpTest)
     add_simple_unittest(RowConvOpTest)
-    add_simple_unittest(SwitchOpTest)
     add_simple_unittest(BlockExpandOpTest)
     add_simple_unittest(CropOpTest)
+    add_simple_unittest(SwitchOpTest)
 endif()
 
 add_simple_unittest(Im2ColTest)
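Note: the argType threaded through patches 3-8 decides whether a switch writes or accumulates, which matters in backward(): the input's gradient buffer may already hold contributions from other paths, so SwitchOrderLayer::backward passes ADD_TO. A hedged pure-Python rendering of the two modes (illustrative only; the constants are invented stand-ins for Paddle's buffer argument types):

    ASSIGN_TO, ADD_TO = 0, 1

    def nhwc2nchw_grad(grad_out, grad_in, n, h, w, c, arg_type):
        src = 0
        for ni in range(n):
            for hi in range(h):
                for wi in range(w):
                    for ci in range(c):
                        off = ((ni * c + ci) * h + hi) * w + wi
                        if arg_type == ADD_TO:
                            grad_in[off] += grad_out[src]  # keep prior gradients
                        else:
                            grad_in[off] = grad_out[src]   # overwrite
                        src += 1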
From e033569dd649c08a986a4d97608692f354003c78 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 7 Sep 2017 11:22:21 +0800
Subject: [PATCH 9/9] Fix format

---
 paddle/function/SwitchOpGpu.cu | 56 +++++++++++++++++++---------------
 proto/ModelConfig.proto        | 11 +++----
 2 files changed, 37 insertions(+), 30 deletions(-)

diff --git a/paddle/function/SwitchOpGpu.cu b/paddle/function/SwitchOpGpu.cu
index 0b9401dea1fea..45390a56c3f77 100644
--- a/paddle/function/SwitchOpGpu.cu
+++ b/paddle/function/SwitchOpGpu.cu
@@ -12,14 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "hl_base.h"
 #include "SwitchOp.h"
+#include "hl_base.h"
 
 namespace paddle {
 
-__global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
-                            int inC, int inH, int inW,
-                            int nthreads, int argType) {
+__global__ void KeNCHW2NHWC(real* outputs,
+                            const real* inputs,
+                            int inC,
+                            int inH,
+                            int inW,
+                            int nthreads,
+                            int argType) {
   const int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < nthreads) {
     const int w = idx % inW;
@@ -27,7 +31,7 @@ __global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
     const int c = (idx / inW / inH) % inC;
     const int n = idx / inW / inH / inC;
 
-    const int off = ((n * inH + h) * inW + w) * inC +c;
+    const int off = ((n * inH + h) * inW + w) * inC + c;
     if (argType == ADD_TO) {
       outputs[off] += inputs[idx];
     } else {
@@ -38,23 +42,27 @@ __global__ void KeNCHW2NHWC(real* outputs, const real* inputs,
 
 template <>
 void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
-                const real* inputs,
-                const int num,
-                const int inC,
-                const int inH,
-                const int inW,
-                const int argType) {
+                                const real* inputs,
+                                const int num,
+                                const int inC,
+                                const int inH,
+                                const int inW,
+                                const int argType) {
   size_t nth = num * inC * inH * inW;
   int blockSize = 1024;
   int gridSize = (nth + 1024 - 1) / 1024;
-  KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
-    (outputs, inputs, inC, inH, inW, nth, argType);
+  KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
+      outputs, inputs, inC, inH, inW, nth, argType);
   CHECK_SYNC("NCHW2NHWC");
 }
 
-__global__ void KeNHWC2NCHW(real* outputs, const real* inputs,
-                            int inH, int inW, int inC,
-                            int nthreads, int argType) {
+__global__ void KeNHWC2NCHW(real* outputs,
+                            const real* inputs,
+                            int inH,
+                            int inW,
+                            int inC,
+                            int nthreads,
+                            int argType) {
   const int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < nthreads) {
     const int c = idx % inC;
@@ -73,17 +81,17 @@ __global__ void KeNHWC2NCHW(real* outputs, const real* inputs,
 
 template <>
 void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
-                const real* inputs,
-                const int num,
-                const int inH,
-                const int inW,
-                const int inC,
-                const int argType) {
+                                const real* inputs,
+                                const int num,
+                                const int inH,
+                                const int inW,
+                                const int inC,
+                                const int argType) {
   int nth = num * inC * inH * inW;
   int blockSize = 1024;
   int gridSize = (nth + 1024 - 1) / 1024;
-  KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
-    (outputs, inputs, inH, inW, inC, nth, argType);
+  KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
+      outputs, inputs, inH, inW, inC, nth, argType);
   CHECK_SYNC("NHWC2NCHW");
 }
 
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index f5b15c3adb96a..0f44d8cb8d78e 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -288,8 +288,8 @@ message PadConfig {
 }
 
 message ReshapeConfig {
-    repeated uint32 heightAxis = 1;
-    repeated uint32 widthAxis = 2;
+  repeated uint32 heightAxis = 1;
+  repeated uint32 widthAxis = 2;
 }
 
 message MultiBoxLossConfig {
@@ -344,7 +344,6 @@ message LayerInputConfig {
 }
 
 message LayerConfig {
-
   required string name = 1;
   required string type = 2;
   optional uint64 size = 3;
@@ -516,13 +515,13 @@ message LayerConfig {
   optional int32 axis = 54 [ default = 2 ];
   repeated uint32 offset = 55;
   repeated uint32 shape = 56;
-
+
   // for HuberRegressionLoss
   optional double delta = 57 [ default = 1.0 ];
 
   optional uint64 depth = 58 [ default = 1 ];
-
-  // for switch order layer
+
+  // for switch order layer
   optional ReshapeConfig reshape_conf = 59;
 }
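Note: putting the series together, a hedged configuration sketch of how the finished layer might be used (illustrative only: it follows the switch_order_layer docstring, and the layer names, the data size, and the assumption that the input carries 4x16x16 image geometry — as in the grad test — are all invented):

    from paddle.trainer_config_helpers import *

    data = data_layer(name='image', size=4 * 16 * 16)  # C=4, H=16, W=16
    # Switch NCHW to NHWC and view the result as an (N*H*W) x C matrix so
    # that the softmax activation normalizes each pixel across its channels.
    pixel_softmax = switch_order_layer(
        input=data,
        name='pixel_softmax',
        reshape={'height': [0, 1, 2], 'width': [3]},
        act=SoftmaxActivation())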