
optimize softmax with cross entropy soft label #32387

Merged
merged 9 commits into PaddlePaddle:develop on Dec 20, 2021
Conversation

@xingfeng01 (Contributor) commented on Apr 20, 2021:

PR types

Performance optimization

PR changes

OPs

Describe

softmax_with_cross_entropy optimization with soft labels. This PR optimizes:

  • "SoftmaxWithCrossEntropySoftLabel": computes log_softmax first and then computes the loss from it.
  • "CrossEntropySoftLabel": computes the loss with softmax as its input.

These optimizations use the following techniques (a minimal sketch follows this list):

  • read data into a local buffer with vectorized loads
  • compute the per-row max and sum with warp-level reductions
  • fix the loop size with a macro
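
A minimal sketch of the vectorized-read plus warp-reduction pattern listed above. This is illustrative only, not the PR's kernel: it assumes float data, a launch with blockDim.x == 32 (one warp per row, no row bounds check), element_count divisible by 4, and 16-byte-aligned rows.

#include <cfloat>          // FLT_MAX
#include <cuda_runtime.h>

__global__ void WarpSoftmaxMaxSumSketch(const float* src, float* row_max,
                                        float* row_sum, int element_count) {
  // One warp handles one row; `lane` is the thread's position inside the warp.
  const int row = blockIdx.x * blockDim.y + threadIdx.y;
  const int lane = threadIdx.x;
  const float* row_ptr = src + static_cast<int64_t>(row) * element_count;

  // 1) Vectorized reads into registers, thread-local max.
  float local_max = -FLT_MAX;
  for (int i = lane * 4; i < element_count; i += 32 * 4) {
    float4 v = reinterpret_cast<const float4*>(row_ptr + i)[0];
    local_max = fmaxf(fmaxf(v.x, v.y), fmaxf(fmaxf(v.z, v.w), local_max));
  }
  // 2) Warp-level max reduction with shuffles.
  for (int offset = 16; offset > 0; offset >>= 1) {
    local_max = fmaxf(local_max, __shfl_xor_sync(0xffffffff, local_max, offset));
  }

  // 3) Same two steps for the sum of exp(x - max).
  float local_sum = 0.f;
  for (int i = lane * 4; i < element_count; i += 32 * 4) {
    float4 v = reinterpret_cast<const float4*>(row_ptr + i)[0];
    local_sum += expf(v.x - local_max) + expf(v.y - local_max) +
                 expf(v.z - local_max) + expf(v.w - local_max);
  }
  for (int offset = 16; offset > 0; offset >>= 1) {
    local_sum += __shfl_xor_sync(0xffffffff, local_sum, offset);
  }

  if (lane == 0) {  // lane 0 writes the per-row results
    row_max[row] = local_max;
    row_sum[row] = local_sum;
  }
}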

Performance (computation time):

  • softmax_with_cross_entropy_0 (forward) : -40.1%
  • softmax_with_cross_entropy_0 (backward): -41%

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@xingfeng01 (Contributor, Author) commented:

Performance optimized; computation time reduced:

softmax_with_cross_entropy_0 (forward) : -40.1%
softmax_with_cross_entropy_0 (backward): -41%

@paddle-bot-old commented:

Sorry to inform you that commit a6385f6's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

@Xreki (Contributor) left a comment:

Some review suggestions from quite a while ago that were never submitted, included here for reference.

__global__ void WarpSoftmaxForwardHardLabel(T* loss, T* softmax, const T* src,
                                            const int64_t* label, const int batch_size,
                                            const int stride, const int element_count,
                                            const int ignore_index) {
Contributor:

By the original design, this function can be used both to compute softmax and log_softmax, and to compute softmax_with_cross_entropy (hard label), so perhaps it does not need to be renamed?

Contributor Author:

Fixed.

__global__ void CrossEntropySoftLabel(T* loss, T* softmaxwrt, const T* softmax,
const T* labels, const int n,
const int dim, const int d,
int Log2Elements) {
Contributor:

  • Since Log2Elements is passed as a regular input argument rather than a template parameter, it should be named log2_elements.
  • InLogMode does not need to be passed into this function template.
  • What is the difference between CrossEntropySoftLabel and WarpSoftmaxForwardSoftLabel? The code structure looks about the same.
  • What does the input parameter softmaxwrt refer to?
  • The call site passes NULL for the third argument, so could it be removed?
  • The call site passes T for both T and VecT, i.e. VecSize is 1. Do you plan to support vectorization later? If not, VecT can be dropped from the template.

These are the questions that came up when I read this function.

After cross-checking all of the implementations in this file and the call sites of this function, I roughly understand that it needs to support two cases:

  • the input is softmax: just use that softmax to compute the loss;
  • the input is log_softmax: use the log_softmax to compute the loss, and also update softmax.

I suggest giving the function a clearer name and clearer input parameter definitions, and adding some explanatory comments to the function.
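
A minimal sketch of those two cases (hypothetical helper, not the PR's actual code; names are illustrative):

// Illustrative per-element view of the two supported inputs.
template <typename T>
__device__ T SoftLabelLossTerm(T val, T label, T* softmax_out, bool in_log_mode) {
  if (in_log_mode) {
    // val holds log_softmax: use it directly for the loss and refresh softmax.
    *softmax_out = exp(val);
    return -label * val;
  } else {
    // val holds softmax: take the log on the fly, nothing to write back.
    return -label * log(val);
  }
}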

Contributor Author:

Comments added.

const int kDimLog2 = static_cast<int>(Log2Ceil(dim));
const int kDimCeil = 1 << kDimLog2;

int kThreadPerBlock = 512;
Contributor:

To be safe, block_size should not exceed 256 on the ROCm platform.
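
A minimal sketch of that cap, assuming Paddle's usual PADDLE_WITH_HIP guard for ROCm builds (the constant name is illustrative):

#ifdef PADDLE_WITH_HIP
constexpr int kThreadPerBlock = 256;  // ROCm: stay at or below 256 for safety
#else
constexpr int kThreadPerBlock = 512;  // CUDA: 512 as in the original code
#endif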

Contributor Author:

All occurrences of 512 have been changed.

}

template <typename T>
static void SoftmaxWithCrossEntropySoftLabel(
Contributor:

This function is very similar to SoftmaxWithCrossEntropyHardLabel; I suggest merging them into one.

Contributor Author:

OK, code cleanup will be handled in the next PR.


if (idx_n < n && idx_dim < dim) {
VecT softmaxdata;
if (InLogMode == true) {
Contributor:

Just use if (InLogMode). Same below.

Contributor Author:

Fixed.

}
VecT labelsdata = reinterpret_cast<const VecT*>(&labels[idx])[0];
T* softmaxptr = reinterpret_cast<T*>(&softmaxdata);
T* labelsptr = reinterpret_cast<T*>(&labelsdata);
Contributor:

template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];
};

A VecT defined as above would feel easier to use, avoiding all the back-and-forth pointer reinterpret_casts.
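
For comparison, a minimal sketch of the same load written with the suggested AlignedVector (illustrative only; VecSize and the variable names are assumptions): the cast happens once at the load, and elements are then read through .val rather than through a second reinterpret_cast back to T*.

  using VecT = AlignedVector<T, VecSize>;
  VecT labelsdata = reinterpret_cast<const VecT*>(&labels[idx])[0];  // one aligned, vectorized load
  T first_label = labelsdata.val[0];  // elements read directly via .val, no extra cast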

Contributor Author:

Personally I think the current approach is more readable; it makes the cast operations clearly visible.

@JiaXiao243 previously approved these changes on Sep 15, 2021:
LGTM. mv3_large_x1_0_distill_mv3_small_x1_0 top1 acc is 67.9%.

@yghstill (Contributor) left a comment:
LGTM

@lanxianghit merged commit f895560 into PaddlePaddle:develop on Dec 20, 2021.