-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PTen]Reshape Kernel Refactor #37164
[PTen]Reshape Kernel Refactor #37164
Conversation
Thanks for your contribution! |
@@ -1883,6 +1883,10 @@ void OperatorWithKernel::BuildPtenKernelContext( | |||
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); | |||
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { | |||
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); | |||
} else if (attr_defs[i].type_index == | |||
std::type_index(typeid(std::vector<int>))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shape需要使用vector<int64_t>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reshape maker定义中使用的vector,使用vector<int64_t>会有兼容问题
paddle/fluid/operators/reshape_op.cc
Outdated
|
||
// only can include the headers in paddle/top/api dirs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释目录需要修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
// will throw error(can't realloc shared memory) in current DenseTensor | ||
// design. So, codes | ||
// below create a tmp densetensor for output. | ||
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#36916 合入之后,会支持realloc,这里需要修改写法,建议记个TODO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前这样的写法可能会对模型性能造成影响,后面需要尽快修改
general::SetXShape(x, xshape); | ||
} | ||
|
||
void ReshapeFromDT(const CPUContext& dev_ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DT表示什么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DenseTensor,代码中已加注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数命名后面要迭代一下,原则上是kernel注册name的驼峰式写法,如果不完备的话我们需要进一步考虑规则
paddle/fluid/operators/reshape_op.cc
Outdated
// and this | ||
// will throw error(can't realloc shared memory) in current DenseTensor | ||
// design. So, codes | ||
// below create a tmp densetensor for output. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
短的注释合并到一行更整齐一些
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); | ||
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两处的const & 是不是可以去掉了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去掉会有bug
"the element's shape must be [1]. But received the element's shape " | ||
"is [%s]", | ||
tensor.dims())); | ||
vector_shape.push_back(static_cast<int32_t>(*tensor.data<int32_t>())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要static_cast<int32_t>转换吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已去除
} | ||
for (size_t dtype = static_cast<size_t>(DataType::BOOL); | ||
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES); | ||
dtype++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
枚举类可以直接遍历,不需要转为size_t 再遍历
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接不行吧,需要重载++运算符才可以
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那我记错了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
OPs
Describe
reshape kernel refactor