[Cherry pick] fix fold for big bs (#49491)
* fix fold for large bs

* fix fold for large bs

* fix pre-commit
tink2123 authored Jan 3, 2023
1 parent d7855fe commit 2a438b0
Showing 3 changed files with 108 additions and 60 deletions.
7 changes: 2 additions & 5 deletions paddle/phi/kernels/impl/fold_grad_kernel_impl.h

@@ -54,11 +54,8 @@ void FoldGradKernel(const Context& ctx,

   DDim out_shape =
       make_ddim({n_output_plane, output_sizes[0], output_sizes[1]});
-  DDim input_matrix_shape = make_ddim({x_dims[0],
-                                       kernel_sizes[0],
-                                       kernel_sizes[1],
-                                       output_height,
-                                       output_width});
+  DDim input_matrix_shape = make_ddim(
+      {1, kernel_sizes[0], kernel_sizes[1], output_height, output_width});

   paddle::operators::math::
       Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, Context, T>
8 changes: 3 additions & 5 deletions paddle/phi/kernels/impl/fold_kernel_impl.h

@@ -56,18 +56,16 @@ void FoldKernel(const Context& ctx,
   DDim output_shape =
       make_ddim({n_output_plane, output_sizes[0], output_sizes[1]});

-  DDim input_matrix_shape = make_ddim({x_dims[0],
-                                       kernel_sizes[0],
-                                       kernel_sizes[1],
-                                       output_height,
-                                       output_width});
+  DDim input_matrix_shape = make_ddim(
+      {1, kernel_sizes[0], kernel_sizes[1], output_height, output_width});

   phi::funcs::SetConstant<Context, T> set_zero;
   set_zero(ctx, out, static_cast<T>(0));

   for (int i = 0; i < batch_size; i++) {
     DenseTensor out_batch =
         out->Slice(i, i + 1).Resize(output_shape);  // im size=3
+
     DenseTensor in_batch =
         x.Slice(i, i + 1).Resize(input_matrix_shape);  // col size=5
     col2im(ctx, in_batch, dilations, strides, paddings, &out_batch);
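Both kernels get the same fix. The batch loop above already slices out one sample at a time (x.Slice(i, i + 1)), so the shape used to reinterpret each slice should not scale with the batch size: the old code put x_dims[0] (the full batch size) in the leading dimension, which stops being consistent with what a single slice holds once the batch size grows, while the new code pins that dimension to 1. For intuition only, here is a minimal NumPy sketch of the per-sample col2im that fold performs — not the Paddle implementation — assuming unit strides and dilations and zero padding; naive_fold is an invented name:

import numpy as np

def naive_fold(x, output_sizes, kernel_sizes):
    # x: (N, C*kH*kW, L) -> (N, C, out_h, out_w); stride 1, no padding/dilation.
    n, ckk, length = x.shape
    kh, kw = kernel_sizes
    out_h, out_w = output_sizes
    c = ckk // (kh * kw)
    col_h = out_h - kh + 1  # sliding positions along height
    col_w = out_w - kw + 1  # sliding positions along width
    assert length == col_h * col_w
    out = np.zeros((n, c, out_h, out_w), dtype=x.dtype)
    for b in range(n):
        # Each sample is folded independently, mirroring the Slice(i, i + 1)
        # loop above: the per-sample view carries no batch dimension.
        col = x[b].reshape(c, kh, kw, col_h, col_w)
        for i in range(kh):
            for j in range(kw):
                # Kernel offset (i, j) scatters its patch values onto the
                # output window it covers.
                out[b, :, i:i + col_h, j:j + col_w] += col[:, i, j]
    return out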
153 changes: 103 additions & 50 deletions python/paddle/fluid/tests/unittests/test_fold_op.py

@@ -14,6 +14,7 @@

 from __future__ import print_function

+
 import math
 import numpy as np
 import unittest
@@ -45,34 +46,64 @@ def init_data(self):
     def calc_fold(self):
         output_shape = [0] * 4
         output_shape[0] = self.batch_size
-        output_shape[1] = int(self.input_channels /
-                              (self.kernel_sizes[0] * self.kernel_sizes[1]))
+        output_shape[1] = int(
+            self.input_channels / (self.kernel_sizes[0] * self.kernel_sizes[1])
+        )
         output_shape[2] = self.output_sizes[0]
         output_shape[3] = self.output_sizes[1]
         dkernel_h = self.dilations[0] * (self.kernel_sizes[0] - 1) + 1
         dkernel_w = self.dilations[1] * (self.kernel_sizes[1] - 1) + 1
-        col_height = int((self.output_sizes[0] + self.paddings[0] +
-                          self.paddings[2] - dkernel_h) / self.strides[0]) + 1
-        col_width = int((self.output_sizes[1] + self.paddings[1] +
-                         self.paddings[3] - dkernel_w) / self.strides[1]) + 1
+        col_height = (
+            int(
+                (
+                    self.output_sizes[0]
+                    + self.paddings[0]
+                    + self.paddings[2]
+                    - dkernel_h
+                )
+                / self.strides[0]
+            )
+            + 1
+        )
+        col_width = (
+            int(
+                (
+                    self.output_sizes[1]
+                    + self.paddings[1]
+                    + self.paddings[3]
+                    - dkernel_w
+                )
+                / self.strides[1]
+            )
+            + 1
+        )
         output = np.zeros(output_shape).astype(np.float64)
         ############ calculate output ##############
         for b in range(output_shape[0]):
             for c in range(self.input_channels):
                 w_offset = int(c % self.kernel_sizes[1])
                 h_offset = int(
-                    (c / self.kernel_sizes[1]) % self.kernel_sizes[0])
+                    (c / self.kernel_sizes[1]) % self.kernel_sizes[0]
+                )
                 c_out = int(c / self.kernel_sizes[0] / self.kernel_sizes[1])
                 for h in range(col_height):
-                    h_out = int(h * self.strides[0] - self.paddings[0] +
-                                h_offset * self.dilations[0])
+                    h_out = int(
+                        h * self.strides[0]
+                        - self.paddings[0]
+                        + h_offset * self.dilations[0]
+                    )
                     for w in range(col_width):
-                        w_out = int(w * self.strides[1] - self.paddings[1] +
-                                    w_offset * self.dilations[1])
+                        w_out = int(
+                            w * self.strides[1]
+                            - self.paddings[1]
+                            + w_offset * self.dilations[1]
+                        )
                         if (h_out >= 0 and h_out < self.output_sizes[0]) and (
-                                w_out >= 0 and w_out < self.output_sizes[1]):
-                            output[b, c_out, h_out,
-                                   w_out] += self.x[b, c, w + col_width * h]
+                            w_out >= 0 and w_out < self.output_sizes[1]
+                        ):
+                            output[b, c_out, h_out, w_out] += self.x[
+                                b, c, w + col_width * h
+                            ]

         self.outputs = output
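As a quick sanity check of the col_height / col_width formulas above, plug in the shapes used by the new TestFoldshape case added further down (kernel 3x3, stride 1, no padding or dilation, output_sizes [4, 5]):

dkernel_h = 1 * (3 - 1) + 1            # dilated kernel height = 3
col_height = (4 + 0 + 0 - 3) // 1 + 1  # = 2 sliding positions vertically
col_width = (5 + 0 + 0 - 3) // 1 + 1   # = 3 sliding positions horizontally
assert col_height * col_width == 6     # matches self.length = 6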

@@ -85,7 +116,7 @@ def set_data(self):
             'paddings': self.paddings,
             'dilations': self.dilations,
             'strides': self.strides,
-            'output_sizes': self.output_sizes
+            'output_sizes': self.output_sizes,
         }
         self.outputs = {'Y': self.outputs}

@@ -101,9 +132,23 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Y', check_eager=True)


+class TestFoldshape(TestFoldOp):
+    def init_data(self):
+        self.batch_size = 8
+        self.input_channels = 3 * 3 * 3
+        self.length = 6
+        self.kernel_sizes = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0, 0, 0]
+        self.dilations = [1, 1]
+        self.output_sizes = [4, 5]
+        input_shape = [self.batch_size, self.input_channels, self.length]
+        self.x = np.random.rand(*input_shape).astype(np.float64)
+
+
 class TestFoldAPI(TestFoldOp):

-    #This is for test on paddle.nn.Fold
+    # This is for test on paddle.nn.Fold

     def setUp(self):
         self.op_type = 'fold'
@@ -120,19 +165,19 @@ def test_api(self):
         m = paddle.nn.Fold(**self.attrs)
         m.eval()
         result = m(input)
-        np.testing.assert_allclose(result.numpy(),
-                                   self.outputs['Y'],
-                                   rtol=1e-05)
+        np.testing.assert_allclose(
+            result.numpy(), self.outputs['Y'], rtol=1e-05
+        )

     def test_info(self):
         str(paddle.nn.Fold(**self.attrs))


 class TestFoldOpError(unittest.TestCase):

     def test_errors(self):
         from paddle.nn.functional import fold
         from paddle.fluid.framework import Program, program_guard

         with program_guard(Program(), Program()):

             def test_input_shape():
@@ -148,59 +193,67 @@ def test_kernel_shape():
             def test_padding_shape():
                 # padding_size must be 2 or 4
                 x = paddle.randn(shape=[2, 6, 6], dtype="float32")
-                out = fold(x,
-                           output_sizes=[2, 3],
-                           kernel_sizes=[2, 2],
-                           paddings=[2, 2, 3])
+                out = fold(
+                    x,
+                    output_sizes=[2, 3],
+                    kernel_sizes=[2, 2],
+                    paddings=[2, 2, 3],
+                )

             def test_dilations_shape():
                 # dialtions_size must be 2
                 x = paddle.randn(shape=[2, 6, 6], dtype="float32")
-                out = fold(x,
-                           output_sizes=[2, 3],
-                           kernel_sizes=[2, 2],
-                           dilations=[2, 2, 3])
+                out = fold(
+                    x,
+                    output_sizes=[2, 3],
+                    kernel_sizes=[2, 2],
+                    dilations=[2, 2, 3],
+                )

             def test_strides_shape():
                 # strids_size must be 2
                 x = paddle.randn(shape=[2, 6, 6], dtype="float32")
-                out = fold(x,
-                           output_sizes=[2, 3],
-                           kernel_sizes=[2, 2],
-                           strides=[2, 2, 3])
+                out = fold(
+                    x,
+                    output_sizes=[2, 3],
+                    kernel_sizes=[2, 2],
+                    strides=[2, 2, 3],
+                )

             def test_output_size():
                 # im_h * im_w must be L
                 x = paddle.randn(shape=[2, 6, 6], dtype="float32")
-                out = fold(x,
-                           output_sizes=[6, 6],
-                           kernel_sizes=[2, 2],
-                           strides=[1, 1])
+                out = fold(
+                    x, output_sizes=[6, 6], kernel_sizes=[2, 2], strides=[1, 1]
+                )

             def test_output_size_2():
                 # out_size must GT 1
                 x = paddle.randn(shape=[2, 6, 6], dtype="float32")
-                out = fold(x,
-                           output_sizes=[0.1, 0.2],
-                           kernel_sizes=[2, 2],
-                           strides=[1, 1])
+                out = fold(
+                    x,
+                    output_sizes=[0.1, 0.2],
+                    kernel_sizes=[2, 2],
+                    strides=[1, 1],
+                )

             def test_block_h_w():
                 # test_block_h_w GT 0
                 x = paddle.randn(shape=[2, 1, 1], dtype="float32")
-                out = fold(x,
-                           output_sizes=[1, 1],
-                           kernel_sizes=[2, 2],
-                           strides=1)
+                out = fold(
+                    x, output_sizes=[1, 1], kernel_sizes=[2, 2], strides=1
+                )

             def test_GT_0():
                 x = paddle.randn(shape=[2, 1, 1], dtype="float32")
-                out = fold(x,
-                           output_sizes=[0, 0],
-                           kernel_sizes=[0, 0],
-                           dilations=0,
-                           paddings=[0, 0],
-                           strides=0)
+                out = fold(
+                    x,
+                    output_sizes=[0, 0],
+                    kernel_sizes=[0, 0],
+                    dilations=0,
+                    paddings=[0, 0],
+                    strides=0,
+                )

             self.assertRaises(AssertionError, test_input_shape)
             self.assertRaises(AssertionError, test_kernel_shape)
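For reference, a minimal eager-mode check along the lines of the new TestFoldshape case (shapes taken from the test; this assumes a build that includes the fix):

import paddle

# batch_size 8, C*kH*kW = 27, L = 6, as in TestFoldshape
x = paddle.randn([8, 27, 6], dtype='float32')
fold = paddle.nn.Fold(
    output_sizes=[4, 5],
    kernel_sizes=[3, 3],
    strides=[1, 1],
    paddings=[0, 0, 0, 0],
    dilations=[1, 1],
)
y = fold(x)
print(y.shape)  # [8, 3, 4, 5] -- the multi-sample case this commit fixes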
