-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Conv3DLayer, DeConv3DLayer #3348
Conversation
b9f7008
to
defb286
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -1039,6 +1039,42 @@ class Matrix : public BaseMatrix { | |||
LOG(FATAL) << "Not implemented"; | |||
} | |||
|
|||
virtual void vol2Col(real* data, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在test_matrixCompare.cpp
里面增加一下vol2Col,col2Vol的CPU和GPU实现check吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
conf.filter_size_y() * conf.filter_size()); | ||
weights_[index]->getW()->reshape( | ||
weights_[index]->getW()->getWidth(), | ||
weights_[index]->getW()->getHeight()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
34-36行的reshape是做什么用的?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在ConqBaseLayer中创建的weights_的shape在Conv3DLayer中不适用,所以在Conv3DLayer中需要把weights_的width和height颠倒过来,不过这个在最新的提交中做了做了调整,即把创建weights_的操作写在派生类中
MatrixPtr wMatSub = wMat->subMatrix(g * M, M); | ||
MatrixPtr in = colBuf_->subMatrix(g * K, K); | ||
MatrixPtr out = outMatSub->subMatrix(g * M, M); | ||
out->mul(*wMatSub, *in, 1.0, 0.0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
0.0需要修改成1.0,否则inputLayers_.size() > 1时就错了。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
isDeconv_ = (config_.type() == "exconv" || | ||
config_.type() == "cudnn_conv" || | ||
config_.type() == "conv3d" || | ||
config_.type() == "deconv3d" ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
27行不能有。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个得有,不然会出错,在DeConv3D中filter数据排布方式是: (output * depth * height * weigth) x channel,所以
height = filterPixels_[i] * filterChannels_[i]; width = (!isDeconv_) ? numFilters_ : channels_[i];
中width要等与numFilters_
wGradData += K * M; | ||
inData += M * N; | ||
} | ||
weights_[i]->getParameterPtr()->incUpdate(callback); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
171行放到batchSize循环外面。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still need:
- parser in
python/paddle/trainer/config_parser.py
- Python interface
paddle/cuda/src/hl_cuda_matrix.cu
Outdated
|
||
|
||
__global__ void keMatrixVol2Col( | ||
int num_kernels, real*data, real* data_col, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
realdata -> const real data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/cuda/src/hl_cuda_matrix.cu
Outdated
const int threads = 512; | ||
const int blocks = DIVUP(num_kernels, threads); | ||
|
||
keMatrixVol2Col<<<blocks, threads>>>(num_kernels, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the STREAM_DEFAULT:
keMatrixVol2Col <<< blocks, threads, 0, STREAM_DEFAULT>>>
The STREAM_DEFAULT is defined as follows. It's a thread local stream.
./cuda/include/hl_base.h:215:extern __thread cudaStream_t default_stream;
./cuda/include/hl_base.h:216:#define STREAM_DEFAULT default_stream
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/cuda/src/hl_cuda_matrix.cu
Outdated
const int threads = 512; | ||
const int blocks = DIVUP(num_kernels, threads); | ||
|
||
keMatrixCol2Vol<<<blocks, threads>>>(num_kernels, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if (nullptr != weights_[index]->getWGrad()) | ||
weights_[index]->getWGrad()->reshape( | ||
weights_[index]->getWGrad()->getWidth(), | ||
weights_[index]->getWGrad()->getHeight()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's better to modify the weight shape in the base layer. Or move the creation in the base class and create weight in the inherited class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
weights_[index]->getWGrad()->getHeight()); | ||
++index; | ||
} | ||
CHECK(inputLayers_.size() == parameters_.size()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line is not correct. Remove this check. If inputLayers_.size() == 1
and this layer has bias, the parameters_.size() == 2
, so it is not correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the 58 line of ConqbaseLayer.cpp should be removed?
} else { | ||
outMat->addBias(*(biases_->getW()), 1.0f); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact, the same problem as the implementing 2D-Deconv before. Many similar code with 3D-Conv.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@hedaoyuan How do you think of this problem?
REGISTER_LAYER(deconv3d, DeConv3DLayer); | ||
|
||
#define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \ | ||
(((IN_SIZE)-1) * (STRID)-2 * (PAD) + (KSIZE)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not the use imageSize()
in ./paddle/math/MathUtils.h
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Sorry, I'm not familiar with ./paddle/math/MathUtils.h
} | ||
|
||
int deConvOutputSize(int inSize, int kSize, int pad, int stride) { | ||
return (inSize - 1) * stride - 2 * pad + kSize; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not the use imageSize()
in ./paddle/math/MathUtils.h
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has been modified
@@ -1203,4 +1203,120 @@ TEST(Matrix, warpCTC) { | |||
} | |||
} | |||
|
|||
int outputSizeCol2Vol( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above comments. ./paddle/math/MathUtils.h
has contains the calculation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
for (auto depth : {9, 16, 64, 128}) { | ||
for (auto height : {9, 11, 73, 128, 256}) { | ||
for (auto width : { | ||
9, 32, 100, 512, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For speeding testing, please use the smaller number for depth, height, width. But please test larger number off-line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
f629486
to
bbbedf9
Compare
fd3da80
to
b608d39
Compare
c5246f1
to
fe63eb1
Compare
56d669f
to
4dcc533
Compare
fix #3359
Add Conv3DLayer, DeConv3DLayer.