fix layout transform #7149
Conversation
paddle/framework/data_transform.cc
Outdated
@@ -64,7 +65,7 @@ void TransDataType(const platform::DeviceContext* ctx,
       framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
       break;
     case proto::DataType::INT64:
-      framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx));
+      framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
Why did you change int64_t to int?
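For context, the usual shape of such a dispatch is sketched below; the struct body and member names are assumptions, not the actual CastDataType, but they show why the template argument has to match the stored element type: with int, an INT64 tensor's 8-byte elements would be read as 4-byte ints.

// Hedged sketch, not the real CastDataType: InT is the type the source buffer
// is read as, OutT is the destination type chosen by VisitDataType.
template <typename InT>
struct CastDataTypeSketch {
  const Tensor& in_;
  Tensor* out_;
  template <typename OutT>
  void operator()() const {
    const InT* src = in_.data<InT>();  // enforces/relies on InT being the stored type
    out_->Resize(in_.dims());
    OutT* dst = out_->mutable_data<OutT>(in_.place());
    for (int64_t i = 0; i < in_.numel(); ++i) {
      dst[i] = static_cast<OutT>(src[i]);
    }
  }
};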
paddle/framework/data_transform.cc
Outdated
@@ -76,12 +77,12 @@ void TransDataType(const platform::DeviceContext* ctx,

 void TransDataLayout(const platform::DeviceContext* ctx,
                      const KernelTypePair& kernel_pair, const Variable& in,
-                     Variable* out) {
+                     Variable* out, const std::vector<int>& axis) {
Parameter order should be: inputs, then outputs.
done
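One possible reading of that convention (a hypothetical signature for illustration, not necessarily what was merged) keeps every input, including axis, ahead of the output:

void TransDataLayout(const std::vector<int>& axis,
                     const platform::DeviceContext* ctx,
                     const KernelTypePair& kernel_pair, const Variable& in,
                     Variable* out);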
paddle/framework/data_transform.cc
Outdated
-  std::vector<int64_t> dst_dim;
+  std::vector<int> dst_dim;
The same question: int64_t -> int.
That was a mistaken change; done.
  EXPECT_TRUE(dst.layout() != src->layout());

  EXPECT_TRUE(dst.layout() == DataLayout::kNCHW);
  EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1}));
A unit test for NCHW -> NHWC is also needed.
done.
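For reference, a minimal sketch of such a check; the ctx and kernel_pair setup is assumed to match the existing NHWC -> NCHW case, and the argument order follows the hunk quoted above, which may differ from the final code.

// Hedged sketch of an NCHW -> NHWC assertion block.
platform::CPUPlace place;
std::vector<int> nchw2nhwc = {0, 2, 3, 1};
Variable in, out;
Tensor* src = in.GetMutable<Tensor>();
src->mutable_data<double>(make_ddim({2, 3, 1, 2}), place);  // N=2, C=3, H=1, W=2
src->set_layout(DataLayout::kNCHW);  // assumes Tensor has a set_layout() setter

TransDataLayout(ctx, kernel_pair, in, &out, nchw2nhwc);

Tensor dst = out.Get<Tensor>();
EXPECT_TRUE(dst.layout() == DataLayout::kNHWC);
EXPECT_TRUE(dst.dims() == make_ddim({2, 1, 2, 3}));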
paddle/framework/data_transform.cc
Outdated
@@ -111,5 +111,18 @@ void TransDataLayout(const platform::DeviceContext* ctx,
 } // namespace paddle

+namespace f = paddle::framework;
+
+std::vector<int> NHWC2NCHW = {0, 3, 1, 2};
+std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
Is it more appropriate to put this in an anonymous namespace?
A .cc file is always safe for translation-unit-local objects. But you are right, an anonymous namespace or the static keyword is better.
Done.
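A sketch of that change, giving the axis tables internal linkage:

namespace {
std::vector<int> NHWC2NCHW = {0, 3, 1, 2};
std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
}  // namespace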
paddle/framework/data_transform.cc
Outdated
                     const KernelTypePair& kernel_pair, const Variable& in,
                     Variable* out) {
  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
  PADDLE_ENFORCE(
      platform::places_are_same_class(kernel_pair.first.place_,
                                      kernel_pair.second.place_),
-     "TransDataType Only Support DataType transform on same place!");
+     "TransDataLayout Only Support DataType transform on same place!");
DataType transform -> DataLayout transform.
And I think you should check whether the data types of in and out are the same. Otherwise, there might be an exception.
done.
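One possible form of that check, assuming OpKernelType exposes data_type_ alongside the place_ member already used above (a sketch, not the merged change):

PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
               "TransDataLayout only supports kernels with the same data type!");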
paddle/framework/data_transform.cc
Outdated
@@ -91,7 +93,6 @@ void TransDataLayout(const platform::DeviceContext* ctx,
   dst->Resize(src_dim);
   auto place = kernel_pair.second.place_;
   CopyFrom(src, place, *ctx, dst);
CopyFrom(src, place, *ctx, dst); is unnecessary.
@@ -111,5 +112,22 @@ void TransDataLayout(const platform::DeviceContext* ctx,
 } // namespace paddle
Parameter order should be inputs then outputs: line 106 CastDataLayout(src, dst, ctx, axis) -> CastDataLayout(ctx, axis, src, dst).
done.
@@ -98,6 +99,7 @@ struct CastDataLayout {
   template <typename T>
   void operator()() {
     auto place = ctx_->GetPlace();

     if (platform::is_cpu_place(place)) {
       operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
       auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
CastDataLayout should also support Transpose from GPU to GPU. in_, out_ and ctx_ each have a place; are they the same?
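A hedged sketch of what the GPU branch could look like; the member names (in_, out_, axis_), the CUDA guard, and the assumption that Transpose is also specialized for CUDADeviceContext are inferred from the CPU branch above, not taken from the PR itself.

    } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
      operators::math::Transpose<platform::CUDADeviceContext, T, 4> trans4;
      auto* context = static_cast<const platform::CUDADeviceContext*>(ctx_);
      trans4(*context, in_, out_, axis_);  // same call pattern as the CPU path
#endif
    } else {
      PADDLE_THROW("Unsupported place for CastDataLayout!");
    }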
@@ -127,7 +127,19 @@ TEST(DataTransform, Layout) {
   }

   Tensor dst = out.Get<Tensor>();
When the Variable is on the GPU, it should also be tested in this unit test.
Thanks for pointing these flaws out! These fixes will be done in the next PR.
  auto src = in.Get<Tensor>();
  auto* dst = out->GetMutable<Tensor>();
  PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");

  auto src_dim = src.dims();
  dst->Resize(src_dim);
  auto place = kernel_pair.second.place_;
  CopyFrom(src, place, *ctx, dst);
CopyFrom should be replaced with dst->mutable_data(place, src.type());
Please address these comments in the next PR.
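In context, that replacement would look roughly like this (a sketch against the hunk quoted above):

  auto src_dim = src.dims();
  dst->Resize(src_dim);
  auto place = kernel_pair.second.place_;
  // Allocate dst with the right place and element type; the transpose that
  // follows overwrites the contents, so copying src here is wasted work.
  dst->mutable_data(place, src.type());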