Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-square Filters and Separated Stride and Padding #505

Merged
merged 5 commits into from
Jul 29, 2014

Conversation

shelhamer
Copy link
Member

  • add non-square kernel size, padding, and stride fields
  • check these parameters
  • add rectangular im2col test
  • fix checks -- see Rectangular pooling #614
  • fix GPU rectangular im2col (thanks @ejaz-izy!)
  • add separable filter test case for convolution layer.

Accept pairs of height/Y and width/X values for kernel size, stride, and pad in lieu of a single, shared value.

Of course the square/equal case still works and the old defaults (stride = 1 and padding = 0) are kept.

Open to comments on implementation and style.

stride_ = this->layer_param_.convolution_param().stride();
pad_ = this->layer_param_.convolution_param().pad();
ConvolutionParameter conv_param = this->layer_param_.convolution_param();
kernel_height_ = conv_param.kernel_size(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should probably CHECK_GT(conv_param.kernel_size_size(), 0) first. Also CHECK_GT(kernel_height_ * kernel_width_, 0) after setting those -- we should have had a similar check all along really (since kernel_size defaults to 0).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Please see my comments regarding protobuf fields)

Maybe we can save old format by having:

if (conv_param.has_kernel_size()) {
kernel_height_ = conv_param.kernel_size();
kernel_height_ = conv_param.kernel_size();
} else {
CHECK_EQ(conv_param.rectangular_kernel_size_size(), 2)
<< "Must specify either kernel_size or rectangular_kernel_size (2 numbers).";
kernel_height_ = conv_param.rectangular_kernel_size(0);
kernel_height_ = conv_param.rectangular_kernel_size(1);
}

This may provide maximum backward compatibility.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I've added these checks.

Le mardi 17 juin 2014, Jeff Donahue notifications@github.com a écrit :

In src/caffe/layers/im2col_layer.cpp:

@@ -13,15 +13,28 @@ template
void Im2colLayer::SetUp(const vector<Blob>& bottom,
vector<Blob
>* top) {
Layer::SetUp(bottom, top);

  • kernel_size_ = this->layer_param_.convolution_param().kernel_size();
  • stride_ = this->layer_param_.convolution_param().stride();
  • pad_ = this->layer_param_.convolution_param().pad();
  • ConvolutionParameter conv_param = this->layer_param_.convolution_param();
  • kernel_height_ = conv_param.kernel_size(0);

should probably CHECK_GT(conv_param.kernel_size_size(), 0) first. Also CHECK_GT(kernel_height_

  • kernel_width_, 0) after setting those -- we should have had a similar
    check all along really (since kernel_size defaults to 0).


Reply to this email directly or view it on GitHub
https://github.com/BVLC/caffe/pull/505/files#r13839327.

@Yangqing
Copy link
Member

Have fun in Beijing Jeff and Evan :)

@shelhamer
Copy link
Member Author

Thanks Yangqing, and thanks for the advice about the fields.

I'll introduce kernel_size_y, kernel_size_x, etc. I've added tests too, but
I still have work to do because my GPU im2col seems to be wrong...

Le jeudi 19 juin 2014, Yangqing Jia notifications@github.com a écrit :

Have fun in Beijing Jeff and Evan :)


Reply to this email directly or view it on GitHub
#505 (comment).

@shelhamer
Copy link
Member Author

Rebased to improve interface per #505 (comment).

However, this isn't done. See todo list in the PR message. I'll give this another shot after I finish some other work.

@buaaliyi you might try to debug the rectangular im2col test failure if you are still interested in this feature.

@rmanor
Copy link
Contributor

rmanor commented Jul 2, 2014

Hi, I'd be happy to help if needed.
Thanks.

@ronghanghu ronghanghu mentioned this pull request Jul 4, 2014
ronghanghu referenced this pull request in ronghanghu/caffe Jul 5, 2014
@rmanor
Copy link
Contributor

rmanor commented Jul 10, 2014

Hey guys, I don't want to push, but are there any plans for this? Just to know if I should wait or not.
Thanks!

@shelhamer
Copy link
Member Author

This has stalled for now. If you debug the GPU mode im2col please
contribute the fix!

Le jeudi 10 juillet 2014, rmanor notifications@github.com a écrit :

Hey guys, I don't want to push, but are there any plans for this? Just to
know if I should wait or not.
Thanks!


Reply to this email directly or view it on GitHub
#505 (comment).

@rmanor
Copy link
Contributor

rmanor commented Jul 12, 2014

@shelhamer I spent a few hours on this and still haven't figured it out. Still trying...
Quick question though, was there a good reason to change the order of iteration between the cpu and gpu implementations of im2col?
Also, how is it possible that the stride value doesn't change the values in the test?
I don't see it taken into consideration when comparing the output to the input and the test succeeds if I change the stride values.
Shouldn't the values be different since the input index progresses in jumps of [stride]?

@ejaz-izy
Copy link

Hello everyone,

I was looking to implement this functionality myself and came across this thread. Thanks @shelhamer for starting this. I found your implementation very helpful. As you mentioned it was failing few tests. I found some minor issues with your implementation which was causing the problem. They are listed below -

  1. In function col2im_gpu_kernel (file - im2col.cu) -
    // int offset = (c * patch_h * patch_w + h * patch_h + w) * height_col * width_col;
    should be changed to
    int offset = (c * patch_h * patch_w + h * patch_w + w) * height_col * width_col;
    Explanation - let us assume for simplicity that c = 0 and stride_w = stride_h = 1; then to find where h,w is found in data_col for location hcol, wcol we need to do the following -
    data_col + hcol_width_out + wcol + ((h-hcol)_patch_w+(w-wcol))*height_out * width_out
    If we expand it and rearrange we will find that it should be patch_w instead of patch_h
  2. In function im2col_cpu (file im2col.cpp)
    // int w_offset = c % patch_h;
    // int h_offset = (c / patch_h) % patch_h;
    should be changed to
    int w_offset = c % patch_w;
    int h_offset = (c / patch_w) % patch_h;
  3. In function col2im_cpu (file im2col.cpp)
    // int w_offset = c % patch_h;
    // int h_offset = (c / patch_h) % patch_h;
    should be changed to
    int w_offset = c % patch_w;
    int h_offset = (c / patch_w) % patch_h;
  4. TestRectCPU (file - test_im2col_layer.cpp)
    // EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
    // this->blob_bottom_->data_at(0, (c / 15), (c / 5) % 5, c % 5));
    should be changed to
    EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
    this->blob_bottom_->data_at(0, (c / 15), (c / 3) % 5, c % 3));
  5. TestRectGPU (file test_im2col_layer.cpp)
    // EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
    // this->blob_bottom_->data_at(0, (c / 15), (c / 5) % 5, c % 5));
    should be changed to
    EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
    this->blob_bottom_->data_at(0, (c / 15), (c / 3) % 5, c % 3));
    Explanation for 2,3,4 and 5 is same, for w_offset mod with patch_w should be taken and the rest follows accordingly.

With the above changes I ran the tests and it passed all the tests!

Sorry for writing a long response and thanks again for opening this issue.

@shelhamer
Copy link
Member Author

@ejaz-izy thank you for taking a look and figuring out the issues! I will incorporate your fixes and finish up this PR with a separable convolution test that convolves with rectangular filters. I'll be sure to credit you in the commit message once this is done and rebased for merge.

I can confirm your fixes are correct -- I'm happy many-eyes came to the rescue here since I couldn't spot my mistake.

while keeping everything working as-is.
Compute the G_x kernel of the Sobel operator as a full filter and as
separable filters to check the rectangular filter output.
shelhamer added a commit that referenced this pull request Jul 29, 2014
Non-square Filters and Separated Stride and Padding
@shelhamer shelhamer merged commit 5542cf7 into BVLC:dev Jul 29, 2014
@shelhamer shelhamer deleted the non-square-filters branch July 29, 2014 08:37
@rmanor
Copy link
Contributor

rmanor commented Jul 29, 2014

Thanks @ejaz-izy!

  • Ran

On Sun, Jul 27, 2014 at 3:13 PM, ejaz-izy notifications@github.com wrote:

Hello everyone,

I was looking to implement this functionality myself and came across this
thread. Thanks @shelhamer https://github.com/shelhamer for starting
this. I found your implementation very helpful. As you mentioned it was
failing few tests. I found some minor issues with your implementation which
was causing the problem. They are listed below -

  1. In function col2im_gpu_kernel (file - im2col.cu) - // int offset =
    (c * patch_h * patch_w + h * patch_h + w) * height_col * width_col; should
    be changed to int offset = (c * patch_h * patch_w + h * patch_w + w) *
    height_col * width_col;

Explanation - let us assume for simplicity that c = 0 and stride_w =
stride_h = 1; then to find where h,w is found in data_col for location
hcol, wcol we need to do the following -
data_col + hcol_width_out + wcol + ((h-hcol)_patch_w+(w-wcol))*height_out

  • width_out
    If we expand it and rearrange we will find that it should be patch_w
    instead of patch_h

    In function im2col_cpu (file im2col.cpp)
    // int w_offset = c % patch_h;
    // int h_offset = (c / patch_h) % patch_h;
    should be changed to
    int w_offset = c % patch_w;
    int h_offset = (c / patch_w) % patch_h;
    2.

    In function col2im_cpu (file im2col.cpp)
    // int w_offset = c % patch_h;
    // int h_offset = (c / patch_h) % patch_h;
    should be changed to
    int w_offset = c % patch_w;
    int h_offset = (c / patch_w) % patch_h;
    3.

    TestRectCPU (file - test_im2col_layer.cpp)
    // EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
    // this->blob_bottom_->data_at(0, (c / 15), (c / 5) % 5, c % 5));
    should be changed to
    EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
    this->blob_bottom_->data_at(0, (c / 15), (c / 3) % 5, c % 3));
    4.

    TestRectGPU (file test_im2col_layer.cpp)
    // EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
    // this->blob_bottom_->data_at(0, (c / 15), (c / 5) % 5, c % 5));
    should be changed to
    EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
    this->blob_bottom_->data_at(0, (c / 15), (c / 3) % 5, c % 3));

Explanation for 2,3,4 and 5 is same, for w_offset mod with patch_w should
be taken and the rest follows accordingly.

With the above changes I ran the tests and it passed all the tests!

Sorry for writing a long response and thanks again for opening this issue.


Reply to this email directly or view it on GitHub
#505 (comment).

@shelhamer shelhamer mentioned this pull request Aug 7, 2014
mitmul pushed a commit to mitmul/caffe that referenced this pull request Sep 30, 2014
mitmul pushed a commit to mitmul/caffe that referenced this pull request Sep 30, 2014
Non-square Filters and Separated Stride and Padding
RazvanRanca pushed a commit to RazvanRanca/caffe that referenced this pull request Nov 4, 2014
RazvanRanca pushed a commit to RazvanRanca/caffe that referenced this pull request Nov 4, 2014
Non-square Filters and Separated Stride and Padding
aiworld pushed a commit to aiworld/aiworld.github.com that referenced this pull request Feb 28, 2015
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants