-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PTen】Add Scalar and ScalarArray in pten #37409
Conversation
Thanks for your contribution! |
paddle/pten/api/include/creation.h
Outdated
@@ -28,6 +29,12 @@ PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape, | |||
Backend backend = Backend::CPU, | |||
DataLayout layout = DataLayout::NCHW); | |||
|
|||
PD_DLL_DECL Tensor full_new(const ScalarArray& shape, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
while not replace original full?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, remove original full.
paddle/pten/api/lib/tensor.cc
Outdated
@@ -214,17 +214,22 @@ const T *Tensor::data() const { | |||
template PD_DLL_DECL const float *Tensor::data<float>() const; | |||
template PD_DLL_DECL const double *Tensor::data<double>() const; | |||
template PD_DLL_DECL const int64_t *Tensor::data<int64_t>() const; | |||
template PD_DLL_DECL const uint64_t *Tensor::data<uint64_t>() const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why need uint64 and uint32? If there is no current demand, over-design is not recommended
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, removed.
paddle/pten/common/scalar.h
Outdated
|
||
private: | ||
DataType data_type_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use dtype?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/common/scalar_array.h
Outdated
} | ||
} | ||
|
||
template <typename TT> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TT is not a good name, maybe we can use InT, or OtherT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/common/scalar_array.h
Outdated
} | ||
} | ||
|
||
// template <typename TT> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if need in the future, add some comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里暂时先去掉了,如果在后续使用中有提供此类接口的需求再进行添加
// const ScalarArrayBase<TT>& scalar_array); | ||
|
||
private: | ||
std::vector<int64_t> array_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe here need a TODO, we will not use vector finally
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/kernels/cpu/creation.h
Outdated
@@ -33,4 +34,10 @@ void FillConstant(const CPUContext& dev_ctx, | |||
const Scalar& val, | |||
DenseTensor* out); | |||
|
|||
template <typename T> | |||
void FillConstantNew(const CPUContext& dev_ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new is not a good suffix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
暂时修改为FillConstantDynamicShape
,后续PR将尝试去掉原来的FillConstant函数,把这里的函数名改为FillConstant
namespace paddle { | ||
namespace experimental { | ||
|
||
class Scalar { | ||
template <typename T> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为了与之前的设计统一,推理希望优先使用继承而非模板,如果限于排期可暂时放松设计,后续再进行修改或给出文档说明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前设计仅作为解决参数问题的临时方案,后续可在解决问题的前提下对方案进行调整完善或者重新设计
CopyScalar(other, this); | ||
} | ||
|
||
template <typename RT> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可能有类型安全问题
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[TODO] 在后续方案的优化迭代中进行完善
paddle/pten/common/scalar_array.h
Outdated
for (size_t i = 0; i < n; i++) { | ||
PD_CHECK(tensor_list[i].dtype() == data_type, | ||
"The data_type of tensors in the list isn't consistent." | ||
"the first tensor is`", | ||
data_type, | ||
"` but `", | ||
i, | ||
"`th tensor is`", | ||
tensor_list[i].dtype(), | ||
"`."); | ||
array_.push_back(*tensor_list[i].template data<int32_t>()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分或许可以抽出来,减少后续拓展类型产生重复代码
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前设计属于临时方案,暂不考虑类型扩展,如果后续有扩展的必要,再将这里的代码进行合并整理
paddle/pten/kernels/cpu/creation.cc
Outdated
@@ -57,6 +57,15 @@ void FillConstant(const CPUContext& dev_ctx, | |||
eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>()); | |||
} | |||
|
|||
template <typename T> | |||
void FillConstantNew(const CPUContext& dev_ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
***New,奇怪的命名
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
暂时修改为FillConstantDynamicShape
,后续PR将尝试去掉原来的FillConstant函数,把这里的函数名改为FillConstant
paddle/pten/kernels/cpu/creation.cc
Outdated
@@ -87,3 +96,19 @@ PT_REGISTER_KERNEL("fill_constant.scalar", | |||
paddle::platform::bfloat16, | |||
paddle::platform::complex<float>, | |||
paddle::platform::complex<double>) {} | |||
|
|||
PT_REGISTER_KERNEL("fill_constant.new", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前的注册kernel后缀最好能体现出一些输入特点,以便后边沉淀变成规则,规范化后续kernel迁移
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已将.new
后缀去掉
paddle/pten/common/scalar_array.h
Outdated
void AssignData(const TYPE* value_data, int64_t n) { | ||
if (value_data) { | ||
array_.reserve(n); | ||
for (auto i = 0; i < n; i++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
循环遍历一般是 ++i
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/common/scalar_array.h
Outdated
// const ScalarArrayBase<T>& scalar_array) { | ||
// return paddle::framework::make_ddim(scalar_array.array_); | ||
// } | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/common/scalar_array.h
Outdated
|
||
class DenseTensor; | ||
using ScalarArray = paddle::experimental::ScalarArrayBase<DenseTensor>; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加上 // namespace pten
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/common/scalar_array.h
Outdated
|
||
class DenseTensor; | ||
using ScalarArray = paddle::experimental::ScalarArrayBase<DenseTensor>; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加上 //namespace pten
} else { | ||
data_.d = std::stod(str_value); | ||
data_.f64 = std::stod(str_value); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里str_value 会是非法字符串吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的str_value是为了适配原来fill_constant op的参数,目前不会出现非法字符串的情况,完成op的迁移后这个接口应该会去掉,已补充注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
72aa162
* add scalar and scalar_array * remove DenseTensor include from Scalar and ScalarArray * remove inner header from scalar_array * refactor the method of fill_constant and add some comment
PR types
New features
PR changes
Others
Describe
Add Scalar and ScalarArray in pten
Scalar用来表示只有一个元素的0维Tensor,只支持CPU设备,可以通过
Tensor
或普通类型变量进行构造。ScalarArray用来表示1维Tensor(现阶段只支持使用CPU,且只支持整数类型数据),可以通过
vector<int>
,Tensor
以及vector<Tensor>
进行ScalarArray的构造。当前ScalarArray的主要用途是在PTen中代替原框架中表示shape或者索引类型的参数,即将vector<int>
,Tensor
和vector<Tensor>
三种表示shape或者索引的数据类型统一为使用ScalarArray表示。