optimize softmax with cross entropy soft label #32387
Thanks for your contribution!
Performance optimized, computation time reduced: softmax_with_cross_entropy_0 (forward): -40.1%
Sorry to inform you that a6385f6's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
These are some review suggestions I wrote a long time ago but never submitted; I am posting them here for reference.
const int64_t* label, const int batch_size,
const int stride, const int element_count,
const int ignore_index) {
__global__ void WarpSoftmaxForwardHardLabel(T* loss, T* softmax, const T* src,
As originally designed, this function can be used to compute softmax and log_softmax as well as softmax_with_cross_entropy (hard label), so perhaps it does not need to be renamed?
Done.
__global__ void CrossEntropySoftLabel(T* loss, T* softmaxwrt, const T* softmax,
                                      const T* labels, const int n,
                                      const int dim, const int d,
                                      int Log2Elements) {
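- `Log2Elements` is passed as an input argument rather than as a template parameter, so it should be named `log2_elements`.
- Does this function template not need `InLogMode` to be passed in?
- What is the difference between `CrossEntropySoftLabel` and `WarpSoftmaxForwardSoftLabel`? Their code structure looks much the same.
- What does the input parameter `softmaxwrt` refer to?
- The call site passes `NULL` as the third argument, so can that parameter be removed?
- The call site passes `T` for both `T` and `VecT`, which means `VecSize` is 1. Do you plan to support vectorization later? If not, `VecT` can be removed from the template.
These are the questions that came up when I read this function.
After comparing the full implementation in the file with the calls to this function, I roughly understand that it has to support two cases:
- The input is softmax, and the loss is simply computed from that softmax.
- The input is log_softmax, and the function has to compute the loss from log_softmax and also update softmax.
I suggest making the function name and the parameter definitions clearer. Please also add some explanatory comments to the function.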
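For readers following the two cases listed above, here is a minimal CPU reference sketch of what a soft-label cross entropy computes for a single row. It is not the PR's kernel: the function name `SoftLabelCrossEntropyRowSketch`, the `in_log_mode` flag, and the `softmax_out` parameter are hypothetical, and the non-log mode assumes strictly positive probabilities.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for one row (not the PR's CUDA kernel).
// in_log_mode == false: `input` holds softmax probabilities (assumed > 0);
//                       only the loss is computed.
// in_log_mode == true:  `input` holds log_softmax values; the loss is computed
//                       from them and, if `softmax_out` is non-null, it is
//                       filled with exp(input), i.e. the softmax itself.
inline double SoftLabelCrossEntropyRowSketch(const std::vector<double>& input,
                                             const std::vector<double>& labels,
                                             bool in_log_mode,
                                             std::vector<double>* softmax_out) {
  assert(input.size() == labels.size());
  if (in_log_mode && softmax_out != nullptr) {
    softmax_out->resize(input.size());
  }
  double loss = 0.0;
  for (std::size_t i = 0; i < input.size(); ++i) {
    // loss = -sum_i labels[i] * log(softmax[i])
    const double log_prob = in_log_mode ? input[i] : std::log(input[i]);
    loss -= labels[i] * log_prob;
    if (in_log_mode && softmax_out != nullptr) {
      (*softmax_out)[i] = std::exp(input[i]);
    }
  }
  return loss;
}
```

Passing `nullptr` for `softmax_out` corresponds to the case where only the loss is needed.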
Comments added.
const int kDimLog2 = static_cast<int>(Log2Ceil(dim));
const int kDimCeil = 1 << kDimLog2;

int kThreadPerBlock = 512;
To be safe, the block_size on the ROCm platform should not exceed 256.
All occurrences of 512 have been changed.
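As a rough illustration of the launch-size arithmetic in this thread, the sketch below rounds a dimension up to a power of two and caps the block size on ROCm. Both helpers are hypothetical stand-ins: `Log2CeilSketch` only mimics what the `Log2Ceil` call in the diff appears to do, and `ClampBlockSizeSketch` takes the platform as a plain `bool` instead of whatever build macro the real code uses.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for Log2Ceil: the smallest k with (1 << k) >= value.
inline int Log2CeilSketch(std::int64_t value) {
  assert(value > 0);
  int log2_value = 0;
  while ((std::int64_t{1} << log2_value) < value) {
    ++log2_value;
  }
  return log2_value;
}

// Hypothetical helper: keep the requested block size, but never exceed 256
// threads per block when targeting ROCm, as the review comment advises.
inline int ClampBlockSizeSketch(int requested, bool is_rocm) {
  const int kRocmMaxThreads = 256;
  return (is_rocm && requested > kRocmMaxThreads) ? kRocmMaxThreads : requested;
}
```

With these assumptions, `dim = 1000` gives a log2 of 10 and a padded width of `1 << 10 = 1024`, and a requested block size of 512 becomes 256 on ROCm.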
}

template <typename T>
static void SoftmaxWithCrossEntropySoftLabel(
This function is too similar to `SoftmaxWithCrossEntropyHardLabel`; I suggest merging the two into one.
OK, the code cleanup will be handled in the next PR.
if (idx_n < n && idx_dim < dim) {
  VecT softmaxdata;
  if (InLogMode == true) {
Just use `if (InLogMode)`. The same applies below.
Done.
}
VecT labelsdata = reinterpret_cast<const VecT*>(&labels[idx])[0];
T* softmaxptr = reinterpret_cast<T*>(&softmaxdata);
T* labelsptr = reinterpret_cast<T*>(&labelsdata);
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
};
A `VecT` defined as above seems easier to use, and it avoids casting pointers back and forth with `reinterpret_cast`.
Personally, I find the current approach more readable, because it makes the cast operations clearly visible.
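To make the trade-off in this thread concrete, here is a small host-side sketch of how the suggested `AlignedVector` could be used so that elements are reached through `val[i]` rather than through a pointer cast. The struct is the one quoted in the review; `LoadVecSketch` and `StoreVecSketch` are hypothetical helpers, and `std::memcpy` is used here only to keep the host example free of aliasing issues, which is not necessarily how a device kernel would load the data.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Struct as quoted in the review comment.
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];
};

// Hypothetical helper: read Size contiguous elements from an aligned address.
template <typename T, int Size>
AlignedVector<T, Size> LoadVecSketch(const T* src) {
  assert(reinterpret_cast<std::uintptr_t>(src) %
             alignof(AlignedVector<T, Size>) == 0);
  AlignedVector<T, Size> vec;
  std::memcpy(&vec, src, sizeof(vec));
  return vec;
}

// Hypothetical helper: write Size contiguous elements to an aligned address.
template <typename T, int Size>
void StoreVecSketch(const AlignedVector<T, Size>& vec, T* dst) {
  assert(reinterpret_cast<std::uintptr_t>(dst) %
             alignof(AlignedVector<T, Size>) == 0);
  std::memcpy(dst, &vec, sizeof(vec));
}
```

After a load, individual elements are available as `vec.val[i]`, so the per-element loop needs no `T*` alias of the vector.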
LGTM
PR types
Performance optimization
PR changes
OPs
Describe
softmax_with_cross_entropy optimization with soft label. This PR includes optimization of
These optimizations include the following techniques:
Performance (computation time):