Feature/transform #7111
Conversation
paddle/framework/data_transform.cc
Outdated
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
// TODO(dzhwinter): CPU <-> GPU need a copy here
My view is that in TransDataType, the in and out should be on the same device.
Done.
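A minimal sketch of such a same-place guard at the top of TransDataType; the message string is an assumption, while kernel_pair and place_ follow the diffs in this PR:

// Hypothetical guard: TransDataType only converts the data type, so both
// kernel types must name the same place (no implicit device copy).
PADDLE_ENFORCE(kernel_pair.first.place_ == kernel_pair.second.place_,
               "TransDataType requires in and out on the same place!");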
auto dst_type = kernel_pair.second.data_type_;
auto src_type = kernel_pair.first.data_type_;

switch (src_type) {
You should check whether src_type and dst_type are the same. If they are the same, we only need to copy in to out.
No, this check should be done before the function is called.
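A minimal sketch of that caller-side check, assuming the names from the diff above; ShareDataWith is the usual Tensor aliasing call, but the surrounding dispatch code is hypothetical:

// Hypothetical dispatch site: skip the cast entirely when source and
// destination data types already match.
if (kernel_pair.first.data_type_ == kernel_pair.second.data_type_) {
  out->GetMutable<Tensor>()->ShareDataWith(in.Get<Tensor>());
  return;
}
// ... otherwise fall through to the registered TransDataType functor.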
paddle/framework/data_transform.cc
Outdated
namespace f = paddle::framework;
REGISTER_DATA_TRANSFORM_FN(f::kernel_FP32, f::kernel_FP64, f::TransDataType);
REGISTER_DATA_TRANSFORM_FN(f::kernel_Layout0, f::kernel_Layout1,
kernel_Layout0 and kernel_Layout1 are not intuitive. I suggest that kernel_Layout0 be written as "kernel_Layout_NHWC", or something better.
Done.
paddle/framework/data_transform.h
Outdated
@@ -45,6 +48,42 @@ struct KernelTypePairHash {
  }
};

template <typename InType, typename OutType>
struct CastDataTypeFunctor {
  HOSTDEVICE OutType operator()(InType in) const {
Add inline
Done.
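For reference, a hedged sketch of the functor with inline applied; the static_cast body is assumed from the functor's name rather than copied from the PR:

template <typename InType, typename OutType>
struct CastDataTypeFunctor {
  // inline as requested; HOSTDEVICE marks it callable on both CPU and GPU.
  HOSTDEVICE inline OutType operator()(InType in) const {
    return static_cast<OutType>(in);
  }
};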
      CastDataTypeFunctor<InType, OutType>());
} else {
  // TODO(dzhwinter): enhance CopyFrom CPU<->GPU with different data type?
  PADDLE_THROW("Unsupport CPU <-> GPU!");
OpKernelType is made up of four members (type, layout, place, library). Except for library, the other three members are related to the data itself. Following the single-responsibility principle, I suggest that CastDataType only handle the data type conversion on the same device.
Yes, we have re-designed the registrars.
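For context, a hedged sketch of the four members being discussed; data_type_ and place_ appear in the diffs above, while the other two field names are assumptions:

struct OpKernelType {
  proto::DataType data_type_;  // the "type" member: FP32, FP64, ...
  DataLayout data_layout_;     // kNHWC, kNCHW, ... (name assumed)
  platform::Place place_;      // CPUPlace, GPUPlace, ...
  LibraryType library_type_;   // kPlain, kMKLDNN, ... (name assumed)
};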
In the very beginning, we hadn't figured out a method to handle multiple transforms. For example, if we need to transform 0000 -> 1111, should it be done in 4 steps or in 1 step?
- For DataType, only FP32 (float) and FP64 (double) are tested.
- e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain
  1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
So this PR looks ugly when it deals with CPU <-> GPU and transforms the data type at the same time.
Now we have made the decision: we will do it in 4 steps. One functor will only do its own duty, so there will be only four kinds of transform functors. We will reach that final state, sketched below.
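A hedged sketch of that final state, chaining one single-purpose functor per property; all function names and signatures here are illustrative, not this PR's actual API:

// Illustrative 4-step pipeline: each functor changes exactly one property
// of the kernel type. The real functors are wired up through
// REGISTER_DATA_TRANSFORM_FN as in the diffs above.
void TransformAll(const OpKernelType& from, const OpKernelType& to,
                  const Variable& in, Variable* out) {
  Variable t1, t2;
  TransDataType(from, to, in, &t1);    // FP32     -> FP64
  TransDataLayout(from, to, t1, &t2);  // kNHWC    -> kNCHW
  TransDataPlace(from, to, t2, out);   // CPUPlace -> GPUPlace (device copy)
  // LibraryType (kPlain -> kMKLDNN) needs no data transform (see the
  // discussion below), so only three steps actually touch the data.
}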
For LibraryType, Variable does not need a transform functor, right?
Yes, strictly speaking, there are only 3 kinds of transform functors.
paddle/framework/data_transform.cc
Outdated
auto* dst = out->GetMutable<Tensor>();
dst->Resize(src.dims());
auto place = kernel_pair.second.place_;
CopyFrom(src, place, *ctx, dst);
Maybe TransDataLayout is wrong. E.g. arr's layout is NCHW format, and the data is:
[[[[ 0, 1, 2],
   [ 3, 4, 5]],
  [[ 6, 7, 8],
   [ 9, 10, 11]]]]
Obviously, arr's shape is (1, 2, 2, 3), i.e. N=1, C=2, H=2, W=3. If we change NCHW to NHWC, its shape will be (1, 2, 3, 2), and the data is:
[[[[ 0, 6],
   [ 1, 7],
   [ 2, 8]],
  [[ 3, 9],
   [ 4, 10],
   [ 5, 11]]]]
Done.
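For reference, a minimal sketch of the permutation the comment describes; NCHWToNHWC is a hypothetical helper over raw contiguous float buffers, not this PR's Tensor-based implementation:

// NCHW -> NHWC: dst[n][h][w][c] = src[n][c][h][w].
void NCHWToNHWC(const float* src, float* dst, int N, int C, int H, int W) {
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          dst[((n * H + h) * W + w) * C + c] =
              src[((n * C + c) * H + h) * W + w];
}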
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
arity(src.dims()) can be replaced by src.dims().size(). This is just a minor issue.
}

Tensor dst = out.Get<Tensor>();
EXPECT_TRUE(dst.layout() != src->layout());
EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1}));
Good point, we will fix the test in the next PR.
Add DataType and Layout transforms.