【PTen】Add Scalar and ScalarArray in pten #37409

Merged: 5 commits, Nov 24, 2021

zyfncg (Contributor) commented Nov 22, 2021

PR types

New features

PR changes

Others

Describe

Add Scalar and ScalarArray in pten

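Scalar represents a 0-dimensional Tensor that holds a single element. It supports only the CPU device and can be constructed from a Tensor or from a variable of a plain type.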
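
As a rough illustration of the idea of a single-value holder with a recorded data type, here is a minimal sketch. `ToyScalar` and `ToyDataType` are hypothetical names invented for this example; they are not the PTen classes, and the real Scalar may store and convert values differently.

```cpp
#include <cstdint>

// Hypothetical data-type tag; not the PTen DataType enum.
enum class ToyDataType { kBool, kInt64, kFloat64 };

// Hypothetical single-value holder that remembers which type it was built from.
class ToyScalar {
 public:
  explicit ToyScalar(bool v) : dtype_(ToyDataType::kBool) { data_.b = v; }
  explicit ToyScalar(int64_t v) : dtype_(ToyDataType::kInt64) { data_.i64 = v; }
  explicit ToyScalar(double v) : dtype_(ToyDataType::kFloat64) { data_.f64 = v; }

  ToyDataType dtype() const { return dtype_; }

  // Convert the stored value to the requested type with a plain static_cast.
  // No range checking is done here.
  template <typename RT>
  RT to() const {
    switch (dtype_) {
      case ToyDataType::kBool:
        return static_cast<RT>(data_.b);
      case ToyDataType::kInt64:
        return static_cast<RT>(data_.i64);
      case ToyDataType::kFloat64:
        return static_cast<RT>(data_.f64);
    }
    return RT();  // Not reached for valid tags; silences compiler warnings.
  }

 private:
  ToyDataType dtype_;
  union {
    bool b;
    int64_t i64;
    double f64;
  } data_;
};
```

A caller would write something like `ToyScalar(int64_t{7}).to<double>()` to read the value back in another type.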

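ScalarArray represents a 1-dimensional Tensor (for now it supports only the CPU and only integer data types). It can be constructed from a vector<int>, a Tensor, or a vector<Tensor>. Its main use at present is to replace, in PTen, the parameters that express a shape or an index in the original framework. In other words, the three data types used for shapes or indices (vector<int>, Tensor, and vector<Tensor>) are unified under ScalarArray.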
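
The unification of three input forms into one integer array could be sketched as follows. `ShapeTensor` and `ToyScalarArray` are hypothetical stand-ins made up for this example, and the assumption that each tensor in a list holds exactly one element is mine; the real ScalarArrayBase reads from actual tensor types.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical CPU-only integer tensor used only for this sketch.
struct ShapeTensor {
  std::vector<int64_t> values;
};

// Hypothetical container that accepts the three input forms and stores
// them as one std::vector<int64_t>.
class ToyScalarArray {
 public:
  // Form 1: a plain vector of integers.
  explicit ToyScalarArray(const std::vector<int64_t>& v) : array_(v) {}

  // Form 2: one 1-D tensor holding all the values.
  explicit ToyScalarArray(const ShapeTensor& t) : array_(t.values) {}

  // Form 3: a list of tensors, each assumed to hold exactly one value.
  explicit ToyScalarArray(const std::vector<ShapeTensor>& list) {
    array_.reserve(list.size());
    for (std::size_t i = 0; i < list.size(); ++i) {
      if (list[i].values.size() != 1) {
        throw std::invalid_argument("each tensor must hold one element");
      }
      array_.push_back(list[i].values[0]);
    }
  }

  const std::vector<int64_t>& GetData() const { return array_; }

 private:
  std::vector<int64_t> array_;
};
```

With this shape, a kernel that needs a shape or an index only has to accept one parameter type, whichever form the caller started from.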

paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -28,6 +29,12 @@ PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
Backend backend = Backend::CPU,
DataLayout layout = DataLayout::NCHW);

PD_DLL_DECL Tensor full_new(const ScalarArray& shape,
Contributor

Why not replace the original full?

Contributor Author

Done, removed the original full.

@@ -214,17 +214,22 @@ const T *Tensor::data() const {
template PD_DLL_DECL const float *Tensor::data<float>() const;
template PD_DLL_DECL const double *Tensor::data<double>() const;
template PD_DLL_DECL const int64_t *Tensor::data<int64_t>() const;
template PD_DLL_DECL const uint64_t *Tensor::data<uint64_t>() const;
Contributor

Why are uint64 and uint32 needed? If there is no current demand, over-design is not recommended.

Contributor Author

Done, removed.


private:
DataType data_type_;
Contributor

use dtype?

Contributor Author

Done

}
}

template <typename TT>
Contributor

TT is not a good name; maybe we can use InT or OtherT?

Contributor Author

Done

}
}

// template <typename TT>
Contributor

If this is needed in the future, add a comment.

Contributor Author

Removed for now. It will be added back if later use shows a need for this kind of interface.

// const ScalarArrayBase<TT>& scalar_array);

private:
std::vector<int64_t> array_;
Contributor

Maybe a TODO is needed here; we will not use vector in the end.

Contributor Author

Done

@@ -33,4 +34,10 @@ void FillConstant(const CPUContext& dev_ctx,
const Scalar& val,
DenseTensor* out);

template <typename T>
void FillConstantNew(const CPUContext& dev_ctx,
Contributor

new is not a good suffix

Contributor Author

Temporarily renamed to FillConstantDynamicShape. A follow-up PR will try to remove the original FillConstant function and rename this one to FillConstant.

namespace paddle {
namespace experimental {

class Scalar {
template <typename T>
Shixiaowei02 (Contributor), Nov 23, 2021

To stay consistent with the earlier design, the inference side would prefer inheritance over templates. If the schedule is tight, the design can be relaxed for now and then revised or documented later.

Contributor Author

The current design is only a temporary solution to the parameter problem. It can later be adjusted, refined, or redesigned, as long as it still solves that problem.

CopyScalar(other, this);
}

template <typename RT>
Contributor

This may have type-safety issues.

Contributor Author

[TODO] To be improved in later iterations of the design.
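
The thread does not say which type-safety issue the reviewer had in mind. One common risk with a templated conversion that uses an unchecked static_cast is silent narrowing, so the sketch below shows a range-checked alternative for that one case. `CheckedNarrow` is a hypothetical helper written for this example and is not part of PTen.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <type_traits>

// Hypothetical helper: convert an int64_t to a signed integer type RT,
// throwing std::out_of_range instead of silently truncating.
template <typename RT>
RT CheckedNarrow(int64_t value) {
  static_assert(std::is_integral<RT>::value && std::is_signed<RT>::value,
                "this sketch only handles signed integer targets");
  if (value < static_cast<int64_t>(std::numeric_limits<RT>::min()) ||
      value > static_cast<int64_t>(std::numeric_limits<RT>::max())) {
    throw std::out_of_range("value does not fit in the requested type");
  }
  return static_cast<RT>(value);
}
```

The cost is one comparison per conversion; whether that trade-off suits a temporary design is a judgment the thread leaves open.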

Comment on lines 68 to 79
for (size_t i = 0; i < n; i++) {
PD_CHECK(tensor_list[i].dtype() == data_type,
"The data_type of tensors in the list isn't consistent."
"the first tensor is`",
data_type,
"` but `",
i,
"`th tensor is`",
tensor_list[i].dtype(),
"`.");
array_.push_back(*tensor_list[i].template data<int32_t>());
}
Contributor

This part could perhaps be factored out, to reduce duplicated code when more types are supported later.

Contributor Author

The current design is a temporary solution, so type extension is not being considered for now. If extension becomes necessary later, the code here will be consolidated.
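
For reference, one way such a loop could be factored out is a helper templated on the element type, with a small dispatch on the first tensor's dtype. This is only a sketch of the reviewer's suggestion, not what the PR does. `ToyDType`, `ToyTensor`, `AppendAll`, and `ToArray` are hypothetical names, and a standard exception replaces PD_CHECK.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical dtype tag and single-element tensor for this sketch.
enum class ToyDType { kInt32, kInt64 };

struct ToyTensor {
  ToyDType dtype;
  int64_t raw;  // Storage wide enough for either supported type.

  template <typename T>
  T value() const {
    return static_cast<T>(raw);
  }
};

// Shared loop: check that every tensor has the expected dtype, then append
// its element. Adding a new type only needs one more dispatch case below.
template <typename T>
void AppendAll(const std::vector<ToyTensor>& list, ToyDType expected,
               std::vector<int64_t>* out) {
  for (std::size_t i = 0; i < list.size(); ++i) {
    if (list[i].dtype != expected) {
      throw std::invalid_argument("inconsistent dtype in tensor list");
    }
    out->push_back(static_cast<int64_t>(list[i].value<T>()));
  }
}

// Dispatch on the dtype of the first tensor.
std::vector<int64_t> ToArray(const std::vector<ToyTensor>& list) {
  std::vector<int64_t> out;
  if (list.empty()) return out;
  out.reserve(list.size());
  switch (list[0].dtype) {
    case ToyDType::kInt32:
      AppendAll<int32_t>(list, ToyDType::kInt32, &out);
      break;
    case ToyDType::kInt64:
      AppendAll<int64_t>(list, ToyDType::kInt64, &out);
      break;
  }
  return out;
}
```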

@@ -57,6 +57,15 @@ void FillConstant(const CPUContext& dev_ctx,
eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
}

template <typename T>
void FillConstantNew(const CPUContext& dev_ctx,
Contributor

***New is an odd name.

Contributor Author

Temporarily renamed to FillConstantDynamicShape. A follow-up PR will try to remove the original FillConstant function and rename this one to FillConstant.

@@ -87,3 +96,19 @@ PT_REGISTER_KERNEL("fill_constant.scalar",
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}

PT_REGISTER_KERNEL("fill_constant.new",
Contributor

The suffix of a registered kernel should ideally reflect some characteristic of its inputs, so that this can later settle into a rule and standardize subsequent kernel migration.

Contributor Author

The .new suffix has been removed.

void AssignData(const TYPE* value_data, int64_t n) {
if (value_data) {
array_.reserve(n);
for (auto i = 0; i < n; i++) {
Contributor

Loops usually use ++i.

Contributor Author

Done
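
A standalone sketch of the loop with a pre-increment counter follows. It is written as a free function with an explicit output vector so it compiles on its own; the guard against a non-positive count and the int64_t loop counter are my additions, not something stated in the thread. For a plain integer counter, `++i` and `i++` behave the same here, so the change is a matter of convention.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical free-function version of AssignData: copies n elements from
// value_data into *array, widening each to int64_t.
template <typename T>
void AssignData(const T* value_data, int64_t n, std::vector<int64_t>* array) {
  if (value_data == nullptr || n <= 0) return;  // Nothing to copy.
  array->reserve(array->size() + static_cast<std::size_t>(n));
  // The counter has the same type as n, and uses pre-increment.
  for (int64_t i = 0; i < n; ++i) {
    array->push_back(static_cast<int64_t>(value_data[i]));
  }
}
```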

// const ScalarArrayBase<T>& scalar_array) {
// return paddle::framework::make_ddim(scalar_array.array_);
// }

Contributor

Same as above.

Contributor Author

Done


class DenseTensor;
using ScalarArray = paddle::experimental::ScalarArrayBase<DenseTensor>;
}
Contributor

Add // namespace pten

Contributor Author

Done


class DenseTensor;
using ScalarArray = paddle::experimental::ScalarArrayBase<DenseTensor>;
}
Contributor

Add // namespace pten

} else {
data_.d = std::stod(str_value);
data_.f64 = std::stod(str_value);
Contributor

Could str_value be an invalid string here?

Contributor Author

str_value exists here to adapt to the parameters of the original fill_constant op. Invalid strings cannot occur at present, and this interface should be removed once the op migration is finished. A comment has been added.
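
For context on the question: `std::stod` throws `std::invalid_argument` when no conversion can be performed and `std::out_of_range` when the value does not fit in a double. If invalid input ever became possible, a guarded parse could look like the sketch below. `TryParseDouble` is a hypothetical helper, not code from this PR, and rejecting trailing characters is a choice made for the example.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical helper: parse a double from str_value and report failure
// through the return value instead of letting std::stod throw.
bool TryParseDouble(const std::string& str_value, double* out) {
  try {
    std::size_t consumed = 0;
    const double parsed = std::stod(str_value, &consumed);
    if (consumed != str_value.size()) {
      return false;  // Trailing characters after the number.
    }
    *out = parsed;
    return true;
  } catch (const std::invalid_argument&) {
    return false;  // No conversion could be performed.
  } catch (const std::out_of_range&) {
    return false;  // Value is outside the range of double.
  }
}
```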

MingMingShangTian (Contributor) left a comment

LGTM

YuanRisheng previously approved these changes Nov 23, 2021
chenwhql previously approved these changes Nov 23, 2021
chenwhql (Contributor) left a comment

LGTM

zyfncg merged commit 0f24de8 into PaddlePaddle:develop on Nov 24, 2021
Zjq9409 pushed a commit to Zjq9409/Paddle that referenced this pull request Dec 10, 2021
* add scalar and scalar_array

* remove DenseTensor include from Scalar and ScalarArray

* remove inner header from scalar_array

* refactor the method of fill_constant and add some comment