Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Conv3DLayer, DeConv3DLayer #3348

Merged
merged 22 commits into from
Aug 31, 2017
Merged

Conversation

chengduoZH
Copy link
Contributor

@chengduoZH chengduoZH commented Aug 9, 2017

fix #3359
Add Conv3DLayer, DeConv3DLayer.

  • CPU implementation.
  • GPU implementation.
  • Add test_Conv3D and Test_DeConv3D unit test.
  • Pass CPU unit test
  • Pass GPU unit test
  • Add config parser for Conv3D and DeConv3D

@chengduoZH chengduoZH changed the title add Conv3DLayer, DeConv3DLayer, Test_Conv3D Add Conv3DLayer, DeConv3DLayer Aug 10, 2017
@chengduoZH chengduoZH force-pushed the develop branch 2 times, most recently from b9f7008 to defb286 Compare August 13, 2017 01:25
Copy link
Contributor

@hedaoyuan hedaoyuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -1039,6 +1039,42 @@ class Matrix : public BaseMatrix {
LOG(FATAL) << "Not implemented";
}

virtual void vol2Col(real* data,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_matrixCompare.cpp里面增加一下vol2Col,col2Vol的CPU和GPU实现check吧。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

conf.filter_size_y() * conf.filter_size());
weights_[index]->getW()->reshape(
weights_[index]->getW()->getWidth(),
weights_[index]->getW()->getHeight());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

34-36行的reshape是做什么用的?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在ConqBaseLayer中创建的weights_的shape在Conv3DLayer中不适用,所以在Conv3DLayer中需要把weights_的width和height颠倒过来,不过这个在最新的提交中做了做了调整,即把创建weights_的操作写在派生类中

MatrixPtr wMatSub = wMat->subMatrix(g * M, M);
MatrixPtr in = colBuf_->subMatrix(g * K, K);
MatrixPtr out = outMatSub->subMatrix(g * M, M);
out->mul(*wMatSub, *in, 1.0, 0.0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0.0需要修改成1.0,否则inputLayers_.size() > 1时就错了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

isDeconv_ = (config_.type() == "exconv" ||
config_.type() == "cudnn_conv" ||
config_.type() == "conv3d" ||
config_.type() == "deconv3d" )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

27行不能有。

Copy link
Contributor Author

@chengduoZH chengduoZH Aug 18, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个得有,不然会出错,在DeConv3D中filter数据排布方式是: (output * depth * height * weigth) x channel,所以
height = filterPixels_[i] * filterChannels_[i]; width = (!isDeconv_) ? numFilters_ : channels_[i];中width要等与numFilters_

wGradData += K * M;
inData += M * N;
}
weights_[i]->getParameterPtr()->incUpdate(callback);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

171行放到batchSize循环外面。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

@qingqing01 qingqing01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still need:

  1. parser in python/paddle/trainer/config_parser.py
  2. Python interface



__global__ void keMatrixVol2Col(
int num_kernels, real*data, real* data_col,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

realdata -> const real data

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);

keMatrixVol2Col<<<blocks, threads>>>(num_kernels,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the STREAM_DEFAULT:

keMatrixVol2Col <<< blocks, threads, 0, STREAM_DEFAULT>>>

The STREAM_DEFAULT is defined as follows. It's a thread local stream.

./cuda/include/hl_base.h:215:extern __thread cudaStream_t default_stream;
./cuda/include/hl_base.h:216:#define STREAM_DEFAULT default_stream

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);

keMatrixCol2Vol<<<blocks, threads>>>(num_kernels,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above comments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if (nullptr != weights_[index]->getWGrad())
weights_[index]->getWGrad()->reshape(
weights_[index]->getWGrad()->getWidth(),
weights_[index]->getWGrad()->getHeight());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to modify the weight shape in the base layer. Or move the creation in the base class and create weight in the inherited class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

weights_[index]->getWGrad()->getHeight());
++index;
}
CHECK(inputLayers_.size() == parameters_.size());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is not correct. Remove this check. If inputLayers_.size() == 1 and this layer has bias, the parameters_.size() == 2, so it is not correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the 58 line of ConqbaseLayer.cpp should be removed?

} else {
outMat->addBias(*(biases_->getW()), 1.0f);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, the same problem as the implementing 2D-Deconv before. Many similar code with 3D-Conv.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hedaoyuan How do you think of this problem?

REGISTER_LAYER(deconv3d, DeConv3DLayer);

#define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \
(((IN_SIZE)-1) * (STRID)-2 * (PAD) + (KSIZE))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not the use imageSize() in ./paddle/math/MathUtils.h?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Sorry, I'm not familiar with ./paddle/math/MathUtils.h

}

int deConvOutputSize(int inSize, int kSize, int pad, int stride) {
return (inSize - 1) * stride - 2 * pad + kSize;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not the use imageSize() in ./paddle/math/MathUtils.h?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has been modified

@@ -1203,4 +1203,120 @@ TEST(Matrix, warpCTC) {
}
}

int outputSizeCol2Vol(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above comments. ./paddle/math/MathUtils.h has contains the calculation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

for (auto depth : {9, 16, 64, 128}) {
for (auto height : {9, 11, 73, 128, 256}) {
for (auto width : {
9, 32, 100, 512,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For speeding testing, please use the smaller number for depth, height, width. But please test larger number off-line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@chengduoZH chengduoZH force-pushed the develop branch 4 times, most recently from f629486 to bbbedf9 Compare August 21, 2017 06:56
@chengduoZH chengduoZH force-pushed the develop branch 3 times, most recently from fd3da80 to b608d39 Compare August 28, 2017 03:35
@chengduoZH chengduoZH force-pushed the develop branch 8 times, most recently from c5246f1 to fe63eb1 Compare August 29, 2017 14:45
@chengduoZH chengduoZH force-pushed the develop branch 2 times, most recently from 56d669f to 4dcc533 Compare August 30, 2017 01:13
hedaoyuan
hedaoyuan previously approved these changes Aug 30, 2017
qingqing01
qingqing01 previously approved these changes Aug 30, 2017
@chengduoZH chengduoZH dismissed stale reviews from qingqing01 and hedaoyuan via 2ae37a4 August 30, 2017 16:28
hedaoyuan
hedaoyuan previously approved these changes Aug 31, 2017
@chengduoZH chengduoZH merged commit 1e6c992 into PaddlePaddle:develop Aug 31, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support 3-Dimensional Conv3DLayer, DeConv3DLayer
3 participants