
fix layout transform #7149

Merged 4 commits on Jan 4, 2018

Conversation

dzhwinter (Contributor):

No description provided.

@dzhwinter changed the title from "add shared_layout between ops" to "fix layout transform" on Jan 2, 2018
@@ -64,7 +65,7 @@ void TransDataType(const platform::DeviceContext* ctx,
framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
break;
case proto::DataType::INT64:
- framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx));
+ framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
Contributor:

Why did you change int64_t to int?

@@ -76,12 +77,12 @@ void TransDataType(const platform::DeviceContext* ctx,

void TransDataLayout(const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
- Variable* out) {
+ Variable* out, const std::vector<int>& axis) {
Contributor:

Parameter order should be: inputs, then outputs.

Author:

done


- std::vector<int64_t> dst_dim;
+ std::vector<int> dst_dim;
Contributor:

The same question: int64_t->int.

Author:

That change was a mistake; fixed.

EXPECT_TRUE(dst.layout() != src->layout());

EXPECT_TRUE(dst.layout() == DataLayout::kNCHW);
EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1}));
Contributor:

A unit test for NCHW -> NHWC is also needed.

Author:

done.

@@ -111,5 +111,18 @@ void TransDataLayout(const platform::DeviceContext* ctx,
} // namespace paddle

namespace f = paddle::framework;

std::vector<int> NHWC2NCHW = {0, 3, 1, 2};
std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
Contributor:

Is it more appropriate to put this in an anonymous namespace?

Author:

Objects defined at file scope in a .cc file are generally safe as translation-unit-local state, but you are right: an anonymous namespace or the static keyword is better.
Done.

const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
"TransDataType Only Support DataType transform on same place!");
"TransDataLayout Only Support DataType transform on same place!");
Contributor:

DataType transform -> DataLayout transform
And I think you should check whether the data types of in and out are the same. Otherwise, there might be an exception.

Author:

done.

@@ -91,7 +93,6 @@ void TransDataLayout(const platform::DeviceContext* ctx,
dst->Resize(src_dim);
auto place = kernel_pair.second.place_;
CopyFrom(src, place, *ctx, dst);
Contributor:

CopyFrom(src, place, *ctx, dst); is unnecessary.

@@ -111,5 +112,22 @@ void TransDataLayout(const platform::DeviceContext* ctx,
} // namespace paddle
Contributor:

Parameter order should be inputs then outputs: line 106 CastDataLayout(src, dst, ctx, axis) -> CastDataLayout(ctx, axis, src, dst)

Author:

done.

@@ -98,6 +99,7 @@ struct CastDataLayout {
template <typename T>
void operator()() {
auto place = ctx_->GetPlace();

if (platform::is_cpu_place(place)) {
operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
Contributor:

  1. CastDataLayout should also support Transpose from GPU to GPU.
  2. in_, out_, and ctx_ each have a place; are they the same?

@@ -127,7 +127,19 @@ TEST(DataTransform, Layout) {
}

Tensor dst = out.Get<Tensor>();
Contributor:

When the Variable is on the GPU, it should also be tested in this unit test.

Author:

Thanks for pointing these flaws out! These fixes will be done in the next PR.


auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");

auto src_dim = src.dims();
dst->Resize(src_dim);
auto place = kernel_pair.second.place_;
CopyFrom(src, place, *ctx, dst);
Contributor:

CopyFrom should be replaced with dst->mutable_data(place, src.type());

@chengduoZH (Contributor):

Please follow up on those comments with fixes in the next PR.

@dzhwinter dzhwinter merged commit a893f15 into PaddlePaddle:develop Jan 4, 2018