Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ignore im2col if not necessary in conv 1 * 1 #3163

Merged
merged 2 commits into from
Aug 3, 2017

Conversation

NHZlX
Copy link
Contributor

@NHZlX NHZlX commented Aug 2, 2017

ignore im2col when this happens in conv:

  • filteHeight == 1 && filterWidth == 1
  • paddingH == 0 && paddingW == 0
  • strideH == 1 && strideW == 1

I did a test of mobilenet forward process on Developing machine(Intel(R) Xeon(R) CPU E5-2620 v2 @ 2.10GHz), the average cost time on 100 times forward pass of one batch is about 15 ms faster than the normal one.

withIm2col average time
True 0.2359s
Fase 0.22018s

@NHZlX NHZlX requested a review from hedaoyuan August 2, 2017 07:52
@@ -66,16 +66,23 @@ class GemmConvFunction : public ConvFunctionBase {
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
bool skipIm2col = isSkipIm2col(filter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

skipIm2col换成needIm2col是否更好理解?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, sounds great.

resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
TensorShape colShape;
real *colBuffer, *colData = NULL;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一行定义一个变量。实际上只需要一个colData变量就够了。
97行写成

if (needIm2col) {
  im2col(colData);
} else {
  colData = inputData + g * inputOffset;
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -182,6 +200,12 @@ class GemmConvGradInputFunction : public ConvFunctionBase {
int K = outputChannels / groups_;
int N = outputHeight * outputWidth;
int M = inputChannels / groups_ * filterHeight * filterWidth;
colBuffer = colData;
real scale = 0.0f;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

203-204写到if(skipIm2col)...else里面吧

Copy link
Contributor Author

@NHZlX NHZlX Aug 2, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

real scale = 0.0 没必要放在if else 中吧

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯,把colBuffer去掉就行了。

@NHZlX NHZlX merged commit 94cee3d into PaddlePaddle:develop Aug 3, 2017
@NHZlX NHZlX deleted the fix_conv_1x1 branch August 3, 2017 03:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants