diff --git a/rfcs/APIs/20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md b/rfcs/APIs/20231202_api_design_for_AdaptiveLogSoftmaxWithLoss.md
similarity index 59%
rename from rfcs/APIs/20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md
rename to rfcs/APIs/20231202_api_design_for_AdaptiveLogSoftmaxWithLoss.md
index 611561471..c278ff3ac 100644
--- a/rfcs/APIs/20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md
+++ b/rfcs/APIs/20231202_api_design_for_AdaptiveLogSoftmaxWithLoss.md
@@ -1,12 +1,12 @@
# paddle.nn.AdaptiveLogSoftmaxWithLoss 设计文档
-|API名称 | paddle.nn.AdaptiveLogSoftmaxWithLoss |
+|API名称 | paddle.nn.AdaptiveLogSoftmaxWithLoss |
|---|------------------------------------|
-|提交作者 | PeachML |
-|提交时间 | 2022-03-22 |
-|版本号 | V1.0 |
-|依赖飞桨版本 | develop |
-|文件名 | 20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md
|
+|提交作者 | netpunk |
+|提交时间 | 2023-12-02 |
+|版本号 | V1.0 |
+|依赖飞桨版本 | develop |
+|文件名 | 20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md
|
# 一、概述
@@ -17,10 +17,33 @@ Paddle需要扩充API,新增 AdaptiveLogSoftmaxWithLoss API,
调用路径为:`paddle.nn.AdaptiveLogSoftmaxWithLoss` 和 `paddle.nn.functional.adaptive_log_softmax_with_loss`。
实现Softmax快速近似计算的功能。
+## 2、功能目标
+
+为飞桨补充 AdaptiveLogSoftmaxWithLoss API,该API实现 softmax 函数近似计算
+
+adaptive_log_softmax_with_loss的计算分步骤如下
+
+1. ![image](https://github.com/PaddlePaddle/community/assets/69072522/3f17c9fd-212a-444c-9a87-2a975c452940)
+
+ (将输入 `input` 通过线性变换映射到一个高维空间,其中 `head_weight` 是学习到的权重,`head_bias` 是偏置项。这个映射允许模型学习类别之间的复杂关系。)
+
+2. ![image](https://github.com/PaddlePaddle/community/assets/69072522/893286a4-9c78-4e7f-b5f0-ec152ef69267)
+
+ ( 将线性变换后的结果进行 softmax 操作,得到每个类别的概率分布,然后取对数。这有助于解决数值稳定性的问题,并且对数概率更容易处理。)
+
+3. ![image](https://github.com/PaddlePaddle/community/assets/69072522/b6987bfb-e1a6-4193-9c12-818b9cc2a76c)
+
+ (从 `head_logprob` 中选择与给定类别索引 `gather_inds` 相对应的对数概率,然后将其累加到 `output` 中。这一步是为了计算 adaptive softmax 损失,其中仅关注一小部分类别的对数概率。)
+
+4. ![image](https://github.com/PaddlePaddle/community/assets/69072522/e0b6e756-a6d3-46d5-b0bb-240e3847a05f)
+
+ (将累加的对数概率取负值并求平均,得到损失值。这是一个常见的负对数似然损失,用于衡量模型输出与真实标签之间的差异。)
+
+这个函数不止输出`loss`,还输出`output`,表示经过 log softmax 转换后的对数概率的累加值,即每个类别的对数概率的总和。可能用于其他需要基于类别概率进行决策或分析的需求
+
## 3、意义
-在自然语言处理中,当字典维度过大时,embedding 将占据模型大部分参数量。
-例如机器翻译任务中,词表维度大约是2^17,embedding维度取1024,那么就会产生将近1亿参数量,
-如果不共享embedding矩阵和softmax映射的矩阵,将会再多出1亿参数量。
+
+在自然语言处理中,当字典维度过大时,embedding 可能会占据模型较大部分的参数量。
这样会引起常见的两个问题:
@@ -144,18 +167,32 @@ Efficient softmax approximation as described in
2. 训练
```python
-def forward(self, input, target):
- # input的shape为[batch_size * bptt, hidden_size]
- # target的shape为[batch_size * bptt, 1]
- if input.size(0) != target.size(0):
- raise RuntimeError('Input and target should have the same size '
- 'in the batch dimension.')
- # 用来统计多个cluster计算的batch,然后求和,保证最终等于batch_size
+def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
+ targ_dim = target_.dim()
+
+ if targ_dim == 1:
+ if input_.size(0) != target_.size(0):
+ raise RuntimeError('Input and target should have the same size '
+ 'in the batch dimension.')
+ if input_.dim() != 2:
+ raise RuntimeError('1D target tensor expects 2D input tensors, '
+ 'but found inputs with size', input_.size())
+ elif targ_dim == 0:
+ if input_.dim() != 1:
+ raise RuntimeError('0D target tensor expects 1D input tensors, '
+ 'but found inputs with size', input_.size())
+ else:
+ raise RuntimeError('0D or 1D target tensor expected, '
+ 'multi-target not supported')
+
+ is_batched = targ_dim > 0
+ input = input_ if is_batched else input_.unsqueeze(0)
+ target = target_ if is_batched else target_.unsqueeze(0)
+
used_rows = 0
batch_size = target.size(0)
- # 用来记录在target位置的 logprob
+
output = input.new_zeros(batch_size)
- # 用来记录batch样本在第一层对应的类别
gather_inds = target.new_empty(batch_size)
cutoff_values = [0] + self.cutoffs
@@ -163,94 +200,82 @@ def forward(self, input, target):
low_idx = cutoff_values[i]
high_idx = cutoff_values[i + 1]
- # 找到当前cluster的样本对应的index
+
target_mask = (target >= low_idx) & (target < high_idx)
row_indices = target_mask.nonzero().squeeze()
- # 如果当前cluster没有样本,则没有loss
+
if row_indices.numel() == 0:
continue
- # target对应高频词,这里只用来记录batch对应的target,高频词的预测在后面 self.head
+
if i == 0:
gather_inds.index_copy_(0, row_indices, target[target_mask])
- # target对应低频词
+
else:
- # 获取低频cluster对应的target的相对位置
relative_target = target[target_mask] - low_idx
- # 获取对应cluster的input
input_subset = input.index_select(0, row_indices)
- # 经过线性变换 得到 [batch_size_i, target_i]
+
cluster_output = self.tail[i - 1](input_subset)
- # 当前cluster对应第一层权重元素的类别
cluster_index = self.shortlist_size + i - 1
- # 记录对应第一层的类别
+
gather_inds.index_fill_(0, row_indices, cluster_index)
- # 计算当前cluster的log_prob
cluster_logprob = log_softmax(cluster_output, dim=1)
- # 获取对应target位置的log_prob
local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
- # 将结果记录到对应的batch中
output.index_copy_(0, row_indices, local_logprob.squeeze(1))
used_rows += row_indices.numel()
if used_rows != batch_size:
- raise RuntimeError("Target values should be in [0, {}], "
- "but values in range [{}, {}] "
- "were found. ".format(self.n_classes - 1,
- target.min().item(),
- target.max().item()))
- # 第一层的线性变换,因为无论高频和低频词都需要计算第一层,所以放到了这里统一计算
+ raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], "
+ f"but values in range [{target.min().item()}, {target.max().item()}] "
+ "were found. ")
+
head_output = self.head(input)
- # 取log_prob
head_logprob = log_softmax(head_output, dim=1)
- # 这里是第一层的log_prob和第二层的log_prob加起来作为最后的输出
- # tips: 对于属于第一层的样本,只需要计算第一层的log_prob就好
- # 对于属于第二层的样本,需要将第一层计算得到的cluster对应类别的log_prob和
- 第二层cluster内计算得到的log_prob加起来,所以是output +=
output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
loss = (-output).mean()
- # 返回一个nametuple
+
+ if not is_batched:
+ output = output.squeeze(0)
+
return _ASMoutput(output, loss)
```
3. 预测
```python
-def predict(self, input):
- """
+def predict(self, input: Tensor) -> Tensor:
+ r""" This is equivalent to `self.log_prob(input).argmax(dim=1)`,
+ but is more efficient in some cases.
+
Args:
input (Tensor): a minibatch of examples
+
Returns:
output (Tensor): a class with the highest probability for each example
+
Shape:
- - Input: :math:`(N, in\_features)`
+ - Input: :math:`(N, \texttt{in\_features})`
- Output: :math:`(N)`
"""
- # 第一层的线性转化
+
head_output = self.head(input)
- # 记录预测target的位置
output = torch.argmax(head_output, dim=1)
- # 判断预测的位置是否都是低频词
not_in_shortlist = (output >= self.shortlist_size)
- # 获取预测高频词的样本index
all_in_shortlist = not (not_in_shortlist.any())
- # 如果预测的结果都为高频词,则直接返回结果
+
if all_in_shortlist:
return output
- # 如果预测的结果都为低频词
+
elif not_in_shortlist.all():
- # 计算低频词对应cluster中target对应的log_prob
log_prob = self._get_full_log_prob(input, head_output)
return torch.argmax(log_prob, dim=1)
- # 如果预测的结果既有高频词,也有低频词
+
else:
- # 只对低频词进行对应cluser的预测
log_prob = self._get_full_log_prob(input[not_in_shortlist],
head_output[not_in_shortlist])
output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
return output
-# 计算低频词对应cluster中target对应的log_prob
def _get_full_log_prob(self, input, head_output):
""" Given input tensor, and output of `self.head`,
compute the log of the full distribution """
@@ -268,36 +293,84 @@ def _get_full_log_prob(self, input, head_output):
out[:, start_idx:stop_idx] = output_logprob
return out
+
+def log_prob(self, input: Tensor) -> Tensor:
+ r""" Computes log probabilities for all :math:`\texttt{n\_classes}`
+
+ Args:
+ input (Tensor): a minibatch of examples
+
+ Returns:
+ log-probabilities of for each class :math:`c`
+ in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
+ parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
+
+ Shape:
+ - Input: :math:`(N, \texttt{in\_features})`
+ - Output: :math:`(N, \texttt{n\_classes})`
+
+ """
+
+ head_output = self.head(input)
+ return self._get_full_log_prob(input, head_output)
```
# 四、对比分析
无其它框架实现
-# 五、方案设计
+# 五、设计思路与实现方案
## 命名与参数设计
-API设计为`paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`及
-`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label,
-in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`, 返回为`NamedTuple` 包含 `output` 和 `loss`字段
+layer层类API:`paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`,包含两个主要方法:
+ - forward(self, input, label),用于训练,返回为`output` 和 `loss`
+ - predict(self, input),用于预测
+
+- in_features (int): 输入tensor的特征数量。
+- n_classes (int): 数据集中类型的个数。
+- cutoffs (Sequence): 用于将label分配到不同存储桶的截断值。
+- div_value (float, 可选): 用于计算簇大小的指数值. 默认值:4.0。
+- head_bias (bool, 可选): 如果为 ``True``,向自适应 softmax 的头部添加偏置项. 默认值:``False``.
+- name (str, 可选): 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
+function API:`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label, head_weight, tail_weights, cutoffs, head_bias=None)` 用于训练计算
+
+- input (Tensor): 输入张量,数据类型为 float32 或 float64。
+- label (Tensor): 标签张量,数据类型为 float32 或 float64。
+- head_weight (Tensor): 用于线性计算的权重矩阵,数据类型为 float32 或 float64。
+- tail_weights (Tensor): 用于线性计算的权重矩阵,数据类型为 float32 或 float64。
+- cutoffs (Sequence): 用于将label分配到不同存储桶的截断值。
+- head_bias (Tensor, 可选): 用于线性计算的偏置矩阵,数据类型为 float32 或 float64。
+- name (str, 可选): 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
## 底层OP设计
使用已有API组合实现,不再单独设计OP。
## API实现方案
-主要参考pytorch实现,替换掉部分paddle没有的api
-
+
+计算逻辑参考pytorch实现,并基于paddle API进行重组与封装:
+- function API:`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label, head_weight, tail_weights, cutoffs, head_bias=None)`,使用已有api进行组合实现,
+
+- layer API:`paddle.nn.AdaptiveLogSoftmaxWithLoss(self, in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`,包含两个主要方法:
+ - `forward(self, input, label)`,用于训练,返回为`output` 和 `loss`
+ - `predict(self, input)`,用于预测,其计算与forward共享权重但是计算逻辑存在差异,故使用已有API组合实现的方式单独实现
+
# 六、测试和验收的考量
测试考虑的case如下:
-- 数值正确性
+- 数值正确性(CPU、GPU、动态图、静态图)
+ - 对`log_prob`(前置函数),log_prob的各类总和概率为1,即`paddle.exp(logprob_out).sum(1)=paddle.ones([4])`
+ - 对`forward`
+ - `output`为各类别概率即`output=log_prob.gather(y.unsqueeze(1), 1).slice([1], [0], [1]).squeeze()`
+ - `loss`为`loss=nll_loss(log_prob, y)`,其中`nll_loss`已经实现
+ - 对`predict`,有`predict=log_prob.argmax(axis=1)`
- 错误检查:`cutoff`的唯一性,数据类型,数值大于零小于`n_classes - 1`
- 错误检查:`input`尺寸与`in_features`一致
# 七、可行性分析及规划排期
-方案主要依赖paddle现有api组合而成
+
+paddle.gather与torch.gather存在差异,使用paddle.take_along_axis替换实现。实现无明显难点,可以按期完成。
# 八、影响面
为独立新增API,对其他模块没有影响
@@ -305,4 +378,4 @@ in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`,
# 名词解释
无
# 附件及参考资料
-无
+无
\ No newline at end of file