-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ignore im2col if not necessary in conv 1 * 1 #3163
Conversation
paddle/function/GemmConvOp.cpp
Outdated
@@ -66,16 +66,23 @@ class GemmConvFunction : public ConvFunctionBase { | |||
real* inputData = inputs[0].data<real>(); | |||
real* filterData = inputs[1].data<real>(); | |||
real* outputData = outputs[0].data<real>(); | |||
bool skipIm2col = isSkipIm2col(filter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
skipIm2col
换成needIm2col
是否更好理解?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, sounds great.
paddle/function/GemmConvOp.cpp
Outdated
resizeBuffer<Device>(colShape.getElements()); | ||
real* colData = reinterpret_cast<real*>(memory_->getBuf()); | ||
TensorShape colShape; | ||
real *colBuffer, *colData = NULL; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一行定义一个变量。实际上只需要一个colData变量就够了。
97行写成
if (needIm2col) {
im2col(colData);
} else {
colData = inputData + g * inputOffset;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/function/GemmConvOp.cpp
Outdated
@@ -182,6 +200,12 @@ class GemmConvGradInputFunction : public ConvFunctionBase { | |||
int K = outputChannels / groups_; | |||
int N = outputHeight * outputWidth; | |||
int M = inputChannels / groups_ * filterHeight * filterWidth; | |||
colBuffer = colData; | |||
real scale = 0.0f; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
203-204写到if(skipIm2col)...else里面吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
real scale = 0.0 没必要放在if else 中吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯,把colBuffer去掉就行了。
ignore im2col when this happens in conv:
I did a test of mobilenet forward process on Developing machine(Intel(R) Xeon(R) CPU E5-2620 v2 @ 2.10GHz), the average cost time on 100 times forward pass of one batch is about 15 ms faster than the normal one.