Skip to content

Commit

Permalink
【Hackathon 5th No.12】Add AdaptiveLogSoftmaxWithLoss API to Paddle (#770)
Browse files Browse the repository at this point in the history
* add AdaptiveLogSoftmaxWithLoss to Paddle

* update

* update

* resolve problems
  • Loading branch information
Patrick-Star125 authored Dec 11, 2023
1 parent 8dfef44 commit 8058019
Showing 1 changed file with 137 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# paddle.nn.AdaptiveLogSoftmaxWithLoss 设计文档

|API名称 | paddle.nn.AdaptiveLogSoftmaxWithLoss |
|API名称 | paddle.nn.AdaptiveLogSoftmaxWithLoss |
|---|------------------------------------|
|提交作者<input type="checkbox" class="rowselector hidden"> | PeachML |
|提交时间<input type="checkbox" class="rowselector hidden"> | 2022-03-22 |
|版本号 | V1.0 |
|依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | develop |
|文件名 | 20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md<br> |
|提交作者<input type="checkbox" class="rowselector hidden"> | netpunk |
|提交时间<input type="checkbox" class="rowselector hidden"> | 2023-12-02 |
|版本号 | V1.0 |
|依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | develop |
|文件名 | 20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md<br> |

# 一、概述

Expand All @@ -17,10 +17,33 @@ Paddle需要扩充API,新增 AdaptiveLogSoftmaxWithLoss API,
调用路径为:`paddle.nn.AdaptiveLogSoftmaxWithLoss``paddle.nn.functional.adaptive_log_softmax_with_loss`
实现Softmax快速近似计算的功能。

## 2、功能目标

为飞桨补充 AdaptiveLogSoftmaxWithLoss API,该API实现 softmax 函数近似计算

adaptive_log_softmax_with_loss的计算分步骤如下

1. ![image](https://github.com/PaddlePaddle/community/assets/69072522/3f17c9fd-212a-444c-9a87-2a975c452940)

(将输入 `input` 通过线性变换映射到一个高维空间,其中 `head_weight` 是学习到的权重,`head_bias` 是偏置项。这个映射允许模型学习类别之间的复杂关系。)

2. ![image](https://github.com/PaddlePaddle/community/assets/69072522/893286a4-9c78-4e7f-b5f0-ec152ef69267)

( 将线性变换后的结果进行 softmax 操作,得到每个类别的概率分布,然后取对数。这有助于解决数值稳定性的问题,并且对数概率更容易处理。)

3. ![image](https://github.com/PaddlePaddle/community/assets/69072522/b6987bfb-e1a6-4193-9c12-818b9cc2a76c)

(从 `head_logprob` 中选择与给定类别索引 `gather_inds` 相对应的对数概率,然后将其累加到 `output` 中。这一步是为了计算 adaptive softmax 损失,其中仅关注一小部分类别的对数概率。)

4. ![image](https://github.com/PaddlePaddle/community/assets/69072522/e0b6e756-a6d3-46d5-b0bb-240e3847a05f)

(将累加的对数概率取负值并求平均,得到损失值。这是一个常见的负对数似然损失,用于衡量模型输出与真实标签之间的差异。)

这个函数不止输出`loss`,还输出`output`,表示经过 log softmax 转换后的对数概率的累加值,即每个类别的对数概率的总和。可能用于其他需要基于类别概率进行决策或分析的需求

## 3、意义
在自然语言处理中,当字典维度过大时,embedding 将占据模型大部分参数量。
例如机器翻译任务中,词表维度大约是2^17,embedding维度取1024,那么就会产生将近1亿参数量,
如果不共享embedding矩阵和softmax映射的矩阵,将会再多出1亿参数量。

在自然语言处理中,当字典维度过大时,embedding 可能会占据模型较大部分的参数量。

这样会引起常见的两个问题:

Expand Down Expand Up @@ -144,113 +167,115 @@ Efficient softmax approximation as described in

2. 训练
```python
def forward(self, input, target):
# input的shape为[batch_size * bptt, hidden_size]
# target的shape为[batch_size * bptt, 1]
if input.size(0) != target.size(0):
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
# 用来统计多个cluster计算的batch,然后求和,保证最终等于batch_size
def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
targ_dim = target_.dim()

if targ_dim == 1:
if input_.size(0) != target_.size(0):
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
if input_.dim() != 2:
raise RuntimeError('1D target tensor expects 2D input tensors, '
'but found inputs with size', input_.size())
elif targ_dim == 0:
if input_.dim() != 1:
raise RuntimeError('0D target tensor expects 1D input tensors, '
'but found inputs with size', input_.size())
else:
raise RuntimeError('0D or 1D target tensor expected, '
'multi-target not supported')

is_batched = targ_dim > 0
input = input_ if is_batched else input_.unsqueeze(0)
target = target_ if is_batched else target_.unsqueeze(0)

used_rows = 0
batch_size = target.size(0)
# 用来记录在target位置的 logprob

output = input.new_zeros(batch_size)
# 用来记录batch样本在第一层对应的类别
gather_inds = target.new_empty(batch_size)

cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):

low_idx = cutoff_values[i]
high_idx = cutoff_values[i + 1]
# 找到当前cluster的样本对应的index

target_mask = (target >= low_idx) & (target < high_idx)
row_indices = target_mask.nonzero().squeeze()
# 如果当前cluster没有样本,则没有loss

if row_indices.numel() == 0:
continue
# target对应高频词,这里只用来记录batch对应的target,高频词的预测在后面 self.head

if i == 0:
gather_inds.index_copy_(0, row_indices, target[target_mask])
# target对应低频词

else:
# 获取低频cluster对应的target的相对位置
relative_target = target[target_mask] - low_idx
# 获取对应cluster的input
input_subset = input.index_select(0, row_indices)
# 经过线性变换 得到 [batch_size_i, target_i]

cluster_output = self.tail[i - 1](input_subset)
# 当前cluster对应第一层权重元素的类别
cluster_index = self.shortlist_size + i - 1
# 记录对应第一层的类别

gather_inds.index_fill_(0, row_indices, cluster_index)
# 计算当前cluster的log_prob
cluster_logprob = log_softmax(cluster_output, dim=1)
# 获取对应target位置的log_prob
local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
# 将结果记录到对应的batch中
output.index_copy_(0, row_indices, local_logprob.squeeze(1))

used_rows += row_indices.numel()

if used_rows != batch_size:
raise RuntimeError("Target values should be in [0, {}], "
"but values in range [{}, {}] "
"were found. ".format(self.n_classes - 1,
target.min().item(),
target.max().item()))
# 第一层的线性变换,因为无论高频和低频词都需要计算第一层,所以放到了这里统一计算
raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], "
f"but values in range [{target.min().item()}, {target.max().item()}] "
"were found. ")

head_output = self.head(input)
# 取log_prob
head_logprob = log_softmax(head_output, dim=1)
# 这里是第一层的log_prob和第二层的log_prob加起来作为最后的输出
# tips: 对于属于第一层的样本,只需要计算第一层的log_prob就好
# 对于属于第二层的样本,需要将第一层计算得到的cluster对应类别的log_prob和
第二层cluster内计算得到的log_prob加起来,所以是output +=
output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
loss = (-output).mean()
# 返回一个nametuple

if not is_batched:
output = output.squeeze(0)

return _ASMoutput(output, loss)
```


3. 预测
```python
def predict(self, input):
"""
def predict(self, input: Tensor) -> Tensor:
r""" This is equivalent to `self.log_prob(input).argmax(dim=1)`,
but is more efficient in some cases.
Args:
input (Tensor): a minibatch of examples
Returns:
output (Tensor): a class with the highest probability for each example
Shape:
- Input: :math:`(N, in\_features)`
- Input: :math:`(N, \texttt{in\_features})`
- Output: :math:`(N)`
"""
# 第一层的线性转化

head_output = self.head(input)
# 记录预测target的位置
output = torch.argmax(head_output, dim=1)
# 判断预测的位置是否都是低频词
not_in_shortlist = (output >= self.shortlist_size)
# 获取预测高频词的样本index
all_in_shortlist = not (not_in_shortlist.any())
# 如果预测的结果都为高频词,则直接返回结果

if all_in_shortlist:
return output
# 如果预测的结果都为低频词

elif not_in_shortlist.all():
# 计算低频词对应cluster中target对应的log_prob
log_prob = self._get_full_log_prob(input, head_output)
return torch.argmax(log_prob, dim=1)
# 如果预测的结果既有高频词,也有低频词

else:
# 只对低频词进行对应cluser的预测
log_prob = self._get_full_log_prob(input[not_in_shortlist],
head_output[not_in_shortlist])
output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
return output

# 计算低频词对应cluster中target对应的log_prob
def _get_full_log_prob(self, input, head_output):
""" Given input tensor, and output of `self.head`,
compute the log of the full distribution """
Expand All @@ -268,41 +293,89 @@ def _get_full_log_prob(self, input, head_output):
out[:, start_idx:stop_idx] = output_logprob

return out

def log_prob(self, input: Tensor) -> Tensor:
r""" Computes log probabilities for all :math:`\texttt{n\_classes}`
Args:
input (Tensor): a minibatch of examples
Returns:
log-probabilities of for each class :math:`c`
in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
Shape:
- Input: :math:`(N, \texttt{in\_features})`
- Output: :math:`(N, \texttt{n\_classes})`
"""

head_output = self.head(input)
return self._get_full_log_prob(input, head_output)
```


# 四、对比分析
无其它框架实现

# 五、方案设计
# 五、设计思路与实现方案
## 命名与参数设计
API设计为`paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`
`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label,
in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`, 返回为`NamedTuple` 包含 `output``loss`字段

layer层类API:`paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`,包含两个主要方法:
- forward(self, input, label),用于训练,返回为`output``loss`
- predict(self, input),用于预测

- in_features (int): 输入tensor的特征数量。
- n_classes (int): 数据集中类型的个数。
- cutoffs (Sequence): 用于将label分配到不同存储桶的截断值。
- div_value (float, 可选): 用于计算簇大小的指数值. 默认值:4.0。
- head_bias (bool, 可选): 如果为 ``True``,向自适应 softmax 的头部添加偏置项. 默认值:``False``.
- name (str, 可选): 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

function API:`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label, head_weight, tail_weights, cutoffs, head_bias=None)` 用于训练计算

- input (Tensor): 输入张量,数据类型为 float32 或 float64。
- label (Tensor): 标签张量,数据类型为 float32 或 float64。
- head_weight (Tensor): 用于线性计算的权重矩阵,数据类型为 float32 或 float64。
- tail_weights (Tensor): 用于线性计算的权重矩阵,数据类型为 float32 或 float64。
- cutoffs (Sequence): 用于将label分配到不同存储桶的截断值。
- head_bias (Tensor, 可选): 用于线性计算的偏置矩阵,数据类型为 float32 或 float64。
- name (str, 可选): 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

## 底层OP设计
使用已有API组合实现,不再单独设计OP。

## API实现方案
主要参考pytorch实现,替换掉部分paddle没有的api


计算逻辑参考pytorch实现,并基于paddle API进行重组与封装:
- function API:`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label, head_weight, tail_weights, cutoffs, head_bias=None)`,使用已有api进行组合实现,

- layer API:`paddle.nn.AdaptiveLogSoftmaxWithLoss(self, in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`,包含两个主要方法:
- `forward(self, input, label)`,用于训练,返回为`output``loss`
- `predict(self, input)`,用于预测,其计算与forward共享权重但是计算逻辑存在差异,故使用已有API组合实现的方式单独实现

# 六、测试和验收的考量
测试考虑的case如下:

- 数值正确性
- 数值正确性(CPU、GPU、动态图、静态图)
-`log_prob`(前置函数),log_prob的各类总和概率为1,即`paddle.exp(logprob_out).sum(1)=paddle.ones([4])`
-`forward`
- `output`为各类别概率即`output=log_prob.gather(y.unsqueeze(1), 1).slice([1], [0], [1]).squeeze()`
- `loss``loss=nll_loss(log_prob, y)`,其中`nll_loss`已经实现
-`predict`,有`predict=log_prob.argmax(axis=1)`
- 错误检查:`cutoff`的唯一性,数据类型,数值大于零小于`n_classes - 1`
- 错误检查:`input`尺寸与`in_features`一致


# 七、可行性分析及规划排期
方案主要依赖paddle现有api组合而成

paddle.gather与torch.gather存在差异,使用paddle.take_along_axis替换实现。实现无明显难点,可以按期完成。

# 八、影响面
为独立新增API,对其他模块没有影响

# 名词解释
# 附件及参考资料

0 comments on commit 8058019

Please sign in to comment.