[Api2.0] add pixel shuffle class #26071

Merged 19 commits on Aug 24, 2020
Changes from 7 commits
96 changes: 66 additions & 30 deletions paddle/fluid/operators/pixel_shuffle_op.cc
@@ -28,40 +28,59 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
"Output(Out) of PixelShuffleOp should not be null."));

auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
input_dims.size()));
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));

auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[1], upscale_factor * upscale_factor));

const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");

if (!channel_last) {
PADDLE_ENFORCE_EQ(
input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channels[%u]",
upscale_factor * upscale_factor, input_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
input_dims[3] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channels[%u]",
upscale_factor * upscale_factor, input_dims[3]));
}
auto output_dims = input_dims;
output_dims[0] = input_dims[0];
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
if (!channel_last) {
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
} else {
output_dims[1] = input_dims[1] * upscale_factor;
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor);
}
ctx->SetOutputDim("Out", output_dims);
}
};
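For reference, the shape inference above maps an NCHW input [N, C, H, W] to [N, C/r^2, H*r, W*r] and an NHWC input [N, H, W, C] to [N, H*r, W*r, C/r^2], where r is upscale_factor. A minimal Python sketch of that mapping (hypothetical helper, not part of this PR):

```python
def infer_pixel_shuffle_shape(input_dims, upscale_factor, data_format="NCHW"):
    """Mirror PixelShuffleOp::InferShape: compute the output shape."""
    n, d1, d2, d3 = input_dims
    r = upscale_factor
    if data_format == "NCHW":
        c, h, w = d1, d2, d3
        assert c % (r * r) == 0, "channels must be divisible by upscale_factor^2"
        return [n, c // (r * r), h * r, w * r]
    # NHWC
    h, w, c = d1, d2, d3
    assert c % (r * r) == 0, "channels must be divisible by upscale_factor^2"
    return [n, h * r, w * r, c // (r * r)]

# infer_pixel_shuffle_shape([2, 9, 4, 4], 3)          -> [2, 1, 12, 12]
# infer_pixel_shuffle_shape([2, 4, 4, 9], 3, "NHWC")  -> [2, 12, 12, 1]
```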

class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N C H W].");
AddOutput(
"Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor].");
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N, C, "
"H, W] or [N, H, W, C].");
AddOutput("Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N, C/factor^2, H*factor, "
"W*factor] or [N, H*factor, W*factor, C/factor^2].");
AddAttr<int>("upscale_factor",
"the factor to increase spatial resolution by.")
.SetDefault(1)
@@ -70,6 +89,12 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
platform::errors::InvalidArgument(
"upscale_factor should be larger than 0."));
});
AddAttr<std::string>(
"data_format",
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NCHW\". Specify the data format of the input data.")
.SetDefault("NCHW");

AddComment(R"DOC(
Pixel Shuffle operator
@@ -114,19 +139,30 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
platform::errors::NotFound("Output(X@Grad) should not be null"));

auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
do_dims.size()));
PADDLE_ENFORCE_EQ(do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
do_dims.size()));

auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");

auto dx_dims = do_dims;
dx_dims[0] = do_dims[0];
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;

if (!channel_last) {
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;
} else {
dx_dims[1] = do_dims[1] / upscale_factor;
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] * (upscale_factor * upscale_factor);
}
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
};
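The gradient op inverts that mapping: Out@Grad carries the output layout and X@Grad recovers the input layout. A matching sketch (same hypothetical helper style as above):

```python
def infer_pixel_shuffle_grad_shape(dout_dims, upscale_factor, data_format="NCHW"):
    """Mirror PixelShuffleGradOp::InferShape: recover the input shape from Out@Grad."""
    n, d1, d2, d3 = dout_dims
    r = upscale_factor
    if data_format == "NCHW":
        return [n, d1 * r * r, d2 // r, d3 // r]
    # NHWC
    return [n, d1 // r, d2 // r, d3 * r * r]

# infer_pixel_shuffle_grad_shape([2, 1, 12, 12], 3)          -> [2, 9, 4, 4]
# infer_pixel_shuffle_grad_shape([2, 12, 12, 1], 3, "NHWC")  -> [2, 4, 4, 9]
```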
40 changes: 32 additions & 8 deletions paddle/fluid/operators/pixel_shuffle_op.h
@@ -11,6 +11,7 @@ limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
@@ -24,23 +25,33 @@ class PixelShuffleOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");

out->mutable_data<T>(ctx.GetPlace());

int factor = ctx.Attr<int>("upscale_factor");

std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");

auto in_dims = in->dims();
auto o_dims = out->dims();

framework::Tensor t;
t.ShareDataWith(*in);
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});

if (!channel_last) {
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
} else {
t.Resize({in_dims[0], in_dims[1], in_dims[2], o_dims[3], factor, factor});
}
std::vector<int> axis = {0, 1, 4, 2, 5, 3};

framework::Tensor o;
o.ShareDataWith(*out);
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});

if (!channel_last) {
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
} else {
o.Resize({in_dims[0], in_dims[1], factor, in_dims[2], factor, o_dims[3]});
}
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
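The forward kernel performs the shuffle with a single 6-D view of the input, a transpose with axis order {0, 1, 4, 2, 5, 3}, and a reshape into the output buffer. A NumPy sketch of the same trick (illustrative only; the op itself runs math::Transpose on the device):

```python
import numpy as np

def pixel_shuffle_ref(x, r, data_format="NCHW"):
    """Reshape -> transpose(0, 1, 4, 2, 5, 3) -> reshape, as in PixelShuffleOpKernel."""
    if data_format == "NCHW":
        n, c, h, w = x.shape
        t = x.reshape(n, c // (r * r), r, r, h, w)   # [N, C/r^2, r, r, H, W]
        t = t.transpose(0, 1, 4, 2, 5, 3)            # [N, C/r^2, H, r, W, r]
        return t.reshape(n, c // (r * r), h * r, w * r)
    # NHWC
    n, h, w, c = x.shape
    t = x.reshape(n, h, w, c // (r * r), r, r)       # [N, H, W, C/r^2, r, r]
    t = t.transpose(0, 1, 4, 2, 5, 3)                # [N, H, r, W, r, C/r^2]
    return t.reshape(n, h * r, w * r, c // (r * r))
```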
@@ -58,19 +69,32 @@ class PixelShuffleGradOpKernel : public framework::OpKernel<T> {

int factor = ctx.Attr<int>("upscale_factor");

std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");

auto do_dims = dout->dims();
auto dx_dims = dx->dims();

framework::Tensor t;
t.ShareDataWith(*dout);
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});

if (!channel_last) {
t.Resize(
{do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
} else {
t.Resize(
{do_dims[0], dx_dims[1], factor, dx_dims[2], factor, do_dims[3]});
}
std::vector<int> axis = {0, 1, 3, 5, 2, 4};

framework::Tensor o;
o.ShareDataWith(*dx);
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});

if (!channel_last) {
o.Resize(
{do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
} else {
o.Resize(
{do_dims[0], dx_dims[1], dx_dims[2], do_dims[3], factor, factor});
}
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
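The backward kernel is the inverse: a 6-D view of Out@Grad, a transpose with axis order {0, 1, 3, 5, 2, 4}, and a reshape into X@Grad. A NumPy sketch under the same caveat:

```python
import numpy as np

def pixel_shuffle_grad_ref(dout, r, data_format="NCHW"):
    """Reshape -> transpose(0, 1, 3, 5, 2, 4) -> reshape, as in PixelShuffleGradOpKernel."""
    if data_format == "NCHW":
        n, c_out, ho, wo = dout.shape
        h, w = ho // r, wo // r
        t = dout.reshape(n, c_out, h, r, w, r)       # [N, C/r^2, H, r, W, r]
        t = t.transpose(0, 1, 3, 5, 2, 4)            # [N, C/r^2, r, r, H, W]
        return t.reshape(n, c_out * r * r, h, w)
    # NHWC
    n, ho, wo, c_out = dout.shape
    h, w = ho // r, wo // r
    t = dout.reshape(n, h, r, w, r, c_out)           # [N, H, r, W, r, C/r^2]
    t = t.transpose(0, 1, 3, 5, 2, 4)                # [N, H, W, C/r^2, r, r]
    return t.reshape(n, h, w, c_out * r * r)
```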
67 changes: 59 additions & 8 deletions python/paddle/fluid/tests/unittests/test_pixel_shuffle.py
@@ -17,15 +17,12 @@
import unittest
import numpy as np
from op_test import OpTest
import paddle


class TestPixelShuffle(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
n, c, h, w = 2, 9, 4, 4
up_factor = 3
shape = [n, c, h, w]
x = np.random.random(shape).astype("float64")
def pixel_shuffle_np(x, up_factor, data_format="NCHW"):
if data_format == "NCHW":
n, c, h, w = x.shape
new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h,
w)
# reshape to (num,output_channel,upscale_factor,upscale_factor,h,w)
@@ -34,10 +31,43 @@ def setUp(self):
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor]
npresult = np.reshape(npresult, oshape)
return npresult
else:
n, h, w, c = x.shape

new_shape = (n, h, w, c // (up_factor * up_factor), up_factor,
up_factor)
# reshape to (num,h,w,output_channel,upscale_factor,upscale_factor)
npresult = np.reshape(x, new_shape)
# transpose to (num,h,upscale_factor,w,upscale_factor,output_channel)
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, h * up_factor, w * up_factor, c // (up_factor * up_factor)]
npresult = np.reshape(npresult, oshape)
return npresult
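A quick shape check for the reference above, using the sizes from setUp below (not part of the test file itself):

```python
x = np.random.random([2, 9, 4, 4]).astype("float64")
print(pixel_shuffle_np(x, 3).shape)                                # (2, 1, 12, 12)
print(pixel_shuffle_np(x.transpose(0, 2, 3, 1), 3, "NHWC").shape)  # (2, 12, 12, 1)
```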


class TestPixelShuffle(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
self.init_data_format()
n, c, h, w = 2, 9, 4, 4

if self.format == "NCHW":
shape = [n, c, h, w]
if self.format == "NHWC":
shape = [n, h, w, c]

up_factor = 3

x = np.random.random(shape).astype("float64")
npresult = pixel_shuffle_np(x, up_factor, self.format)

self.inputs = {'X': x}
self.outputs = {'Out': npresult}
self.attrs = {'upscale_factor': up_factor}
self.attrs = {'upscale_factor': up_factor, "data_format": self.format}

def init_data_format(self):
self.format = "NCHW"

def test_check_output(self):
self.check_output()
@@ -46,5 +76,26 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestChannelLast(TestPixelShuffle):
def init_data_format(self):
self.format = "NHWC"


class TestPixelShuffleDygraph(unittest.TestCase):
def run_pixel_shuffle(self, up_factor):
x = np.random.rand(2, 9, 4, 4).astype(np.float32)

npresult = pixel_shuffle_np(x, up_factor)

paddle.disable_static()
pixel_shuffle = paddle.nn.PixelShuffle(up_factor)
result = pixel_shuffle(paddle.to_variable(x))

self.assertTrue(np.allclose(result.numpy(), npresult))

def test_pixel_shuffle(self):
self.run_pixel_shuffle(3)
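For completeness, a usage sketch of the new layer with the channel-last layout. It assumes paddle.nn.PixelShuffle forwards the data_format attribute added to the op above; the layer implementation itself is not part of this diff excerpt.

```python
import numpy as np
import paddle

paddle.disable_static()
x = np.random.rand(2, 4, 4, 9).astype(np.float32)    # NHWC input

# Assumption: the PixelShuffle layer accepts data_format, mirroring the op attribute.
pixel_shuffle = paddle.nn.PixelShuffle(3, data_format="NHWC")
out = pixel_shuffle(paddle.to_variable(x))

print(out.shape)                                      # [2, 12, 12, 1]
np.testing.assert_allclose(out.numpy(), pixel_shuffle_np(x, 3, "NHWC"))
```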


if __name__ == '__main__':
unittest.main()
3 changes: 3 additions & 0 deletions python/paddle/nn/__init__.py
@@ -98,7 +98,10 @@
# from .layer.rnn import LSTMCell #DEFINE_ALIAS
from .layer.distance import PairwiseDistance #DEFINE_ALIAS

from .layer.vision import PixelShuffle

from .layer import loss #DEFINE_ALIAS
from .layer import conv #DEFINE_ALIAS
from .layer import vision #DEFINE_ALIAS
from ..fluid.dygraph.layers import Layer #DEFINE_ALIAS
from ..fluid.dygraph.container import LayerList, ParameterList, Sequential #DEFINE_ALIAS
68 changes: 34 additions & 34 deletions python/paddle/nn/functional/__init__.py
@@ -158,40 +158,40 @@
# from .rnn import gru_unit #DEFINE_ALIAS
# from .rnn import lstm #DEFINE_ALIAS
# from .rnn import lstm_unit #DEFINE_ALIAS
from .vision import affine_channel #DEFINE_ALIAS
from .vision import affine_grid #DEFINE_ALIAS
from .vision import anchor_generator #DEFINE_ALIAS
from .vision import bipartite_match #DEFINE_ALIAS
from .vision import box_clip #DEFINE_ALIAS
from .vision import box_coder #DEFINE_ALIAS
from .vision import box_decoder_and_assign #DEFINE_ALIAS
from .vision import collect_fpn_proposals #DEFINE_ALIAS
#from .vision import affine_channel #DEFINE_ALIAS
#from .vision import affine_grid #DEFINE_ALIAS
#from .vision import anchor_generator #DEFINE_ALIAS
#from .vision import bipartite_match #DEFINE_ALIAS
#from .vision import box_clip #DEFINE_ALIAS
#from .vision import box_coder #DEFINE_ALIAS
#from .vision import box_decoder_and_assign #DEFINE_ALIAS
#from .vision import collect_fpn_proposals #DEFINE_ALIAS
# from .vision import deformable_conv #DEFINE_ALIAS
from .vision import deformable_roi_pooling #DEFINE_ALIAS
from .vision import density_prior_box #DEFINE_ALIAS
from .vision import detection_output #DEFINE_ALIAS
from .vision import distribute_fpn_proposals #DEFINE_ALIAS
from .vision import fsp_matrix #DEFINE_ALIAS
from .vision import generate_mask_labels #DEFINE_ALIAS
from .vision import generate_proposal_labels #DEFINE_ALIAS
from .vision import generate_proposals #DEFINE_ALIAS
from .vision import grid_sampler #DEFINE_ALIAS
from .vision import image_resize #DEFINE_ALIAS
from .vision import image_resize_short #DEFINE_ALIAS
#from .vision import deformable_roi_pooling #DEFINE_ALIAS
#from .vision import density_prior_box #DEFINE_ALIAS
#from .vision import detection_output #DEFINE_ALIAS
#from .vision import distribute_fpn_proposals #DEFINE_ALIAS
#from .vision import fsp_matrix #DEFINE_ALIAS
#from .vision import generate_mask_labels #DEFINE_ALIAS
#from .vision import generate_proposal_labels #DEFINE_ALIAS
#from .vision import generate_proposals #DEFINE_ALIAS
#from .vision import grid_sampler #DEFINE_ALIAS
#from .vision import image_resize #DEFINE_ALIAS
#from .vision import image_resize_short #DEFINE_ALIAS
# from .vision import multi_box_head #DEFINE_ALIAS
from .vision import pixel_shuffle #DEFINE_ALIAS
from .vision import prior_box #DEFINE_ALIAS
from .vision import prroi_pool #DEFINE_ALIAS
from .vision import psroi_pool #DEFINE_ALIAS
from .vision import resize_bilinear #DEFINE_ALIAS
from .vision import resize_nearest #DEFINE_ALIAS
from .vision import resize_trilinear #DEFINE_ALIAS
from .vision import retinanet_detection_output #DEFINE_ALIAS
from .vision import retinanet_target_assign #DEFINE_ALIAS
from .vision import roi_align #DEFINE_ALIAS
from .vision import roi_perspective_transform #DEFINE_ALIAS
from .vision import roi_pool #DEFINE_ALIAS
from .vision import shuffle_channel #DEFINE_ALIAS
from .vision import space_to_depth #DEFINE_ALIAS
from .vision import yolo_box #DEFINE_ALIAS
from .vision import yolov3_loss #DEFINE_ALIAS
#from .vision import prior_box #DEFINE_ALIAS
#from .vision import prroi_pool #DEFINE_ALIAS
#from .vision import psroi_pool #DEFINE_ALIAS
#from .vision import resize_bilinear #DEFINE_ALIAS
#from .vision import resize_nearest #DEFINE_ALIAS
#from .vision import resize_trilinear #DEFINE_ALIAS
#from .vision import retinanet_detection_output #DEFINE_ALIAS
#from .vision import retinanet_target_assign #DEFINE_ALIAS
#from .vision import roi_align #DEFINE_ALIAS
#from .vision import roi_perspective_transform #DEFINE_ALIAS
#from .vision import roi_pool #DEFINE_ALIAS
#from .vision import shuffle_channel #DEFINE_ALIAS
#from .vision import space_to_depth #DEFINE_ALIAS
#from .vision import yolo_box #DEFINE_ALIAS
#from .vision import yolov3_loss #DEFINE_ALIAS