diff --git a/rfcs/APIs/20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md b/rfcs/APIs/20231202_api_design_for_AdaptiveLogSoftmaxWithLoss.md similarity index 59% rename from rfcs/APIs/20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md rename to rfcs/APIs/20231202_api_design_for_AdaptiveLogSoftmaxWithLoss.md index 611561471..c278ff3ac 100644 --- a/rfcs/APIs/20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md +++ b/rfcs/APIs/20231202_api_design_for_AdaptiveLogSoftmaxWithLoss.md @@ -1,12 +1,12 @@ # paddle.nn.AdaptiveLogSoftmaxWithLoss 设计文档 -|API名称 | paddle.nn.AdaptiveLogSoftmaxWithLoss | +|API名称 | paddle.nn.AdaptiveLogSoftmaxWithLoss | |---|------------------------------------| -|提交作者 | PeachML | -|提交时间 | 2022-03-22 | -|版本号 | V1.0 | -|依赖飞桨版本 | develop | -|文件名 | 20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md
| +|提交作者 | netpunk | +|提交时间 | 2023-12-02 | +|版本号 | V1.0 | +|依赖飞桨版本 | develop | +|文件名 | 20200322_api_design_for_AdaptiveLogSoftmaxWithLoss.md
| # 一、概述 @@ -17,10 +17,33 @@ Paddle需要扩充API,新增 AdaptiveLogSoftmaxWithLoss API, 调用路径为:`paddle.nn.AdaptiveLogSoftmaxWithLoss` 和 `paddle.nn.functional.adaptive_log_softmax_with_loss`。 实现Softmax快速近似计算的功能。 +## 2、功能目标 + +为飞桨补充 AdaptiveLogSoftmaxWithLoss API,该API实现 softmax 函数近似计算 + +adaptive_log_softmax_with_loss的计算分步骤如下 + +1. ![image](https://github.com/PaddlePaddle/community/assets/69072522/3f17c9fd-212a-444c-9a87-2a975c452940) + + (将输入 `input` 通过线性变换映射到一个高维空间,其中 `head_weight` 是学习到的权重,`head_bias` 是偏置项。这个映射允许模型学习类别之间的复杂关系。) + +2. ![image](https://github.com/PaddlePaddle/community/assets/69072522/893286a4-9c78-4e7f-b5f0-ec152ef69267) + + ( 将线性变换后的结果进行 softmax 操作,得到每个类别的概率分布,然后取对数。这有助于解决数值稳定性的问题,并且对数概率更容易处理。) + +3. ![image](https://github.com/PaddlePaddle/community/assets/69072522/b6987bfb-e1a6-4193-9c12-818b9cc2a76c) + + (从 `head_logprob` 中选择与给定类别索引 `gather_inds` 相对应的对数概率,然后将其累加到 `output` 中。这一步是为了计算 adaptive softmax 损失,其中仅关注一小部分类别的对数概率。) + +4. ![image](https://github.com/PaddlePaddle/community/assets/69072522/e0b6e756-a6d3-46d5-b0bb-240e3847a05f) + + (将累加的对数概率取负值并求平均,得到损失值。这是一个常见的负对数似然损失,用于衡量模型输出与真实标签之间的差异。) + +这个函数不止输出`loss`,还输出`output`,表示经过 log softmax 转换后的对数概率的累加值,即每个类别的对数概率的总和。可能用于其他需要基于类别概率进行决策或分析的需求 + ## 3、意义 -在自然语言处理中,当字典维度过大时,embedding 将占据模型大部分参数量。 -例如机器翻译任务中,词表维度大约是2^17,embedding维度取1024,那么就会产生将近1亿参数量, -如果不共享embedding矩阵和softmax映射的矩阵,将会再多出1亿参数量。 + +在自然语言处理中,当字典维度过大时,embedding 可能会占据模型较大部分的参数量。 这样会引起常见的两个问题: @@ -144,18 +167,32 @@ Efficient softmax approximation as described in 2. 训练 ```python -def forward(self, input, target): - # input的shape为[batch_size * bptt, hidden_size] - # target的shape为[batch_size * bptt, 1] - if input.size(0) != target.size(0): - raise RuntimeError('Input and target should have the same size ' - 'in the batch dimension.') - # 用来统计多个cluster计算的batch,然后求和,保证最终等于batch_size +def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: + targ_dim = target_.dim() + + if targ_dim == 1: + if input_.size(0) != target_.size(0): + raise RuntimeError('Input and target should have the same size ' + 'in the batch dimension.') + if input_.dim() != 2: + raise RuntimeError('1D target tensor expects 2D input tensors, ' + 'but found inputs with size', input_.size()) + elif targ_dim == 0: + if input_.dim() != 1: + raise RuntimeError('0D target tensor expects 1D input tensors, ' + 'but found inputs with size', input_.size()) + else: + raise RuntimeError('0D or 1D target tensor expected, ' + 'multi-target not supported') + + is_batched = targ_dim > 0 + input = input_ if is_batched else input_.unsqueeze(0) + target = target_ if is_batched else target_.unsqueeze(0) + used_rows = 0 batch_size = target.size(0) - # 用来记录在target位置的 logprob + output = input.new_zeros(batch_size) - # 用来记录batch样本在第一层对应的类别 gather_inds = target.new_empty(batch_size) cutoff_values = [0] + self.cutoffs @@ -163,94 +200,82 @@ def forward(self, input, target): low_idx = cutoff_values[i] high_idx = cutoff_values[i + 1] - # 找到当前cluster的样本对应的index + target_mask = (target >= low_idx) & (target < high_idx) row_indices = target_mask.nonzero().squeeze() - # 如果当前cluster没有样本,则没有loss + if row_indices.numel() == 0: continue - # target对应高频词,这里只用来记录batch对应的target,高频词的预测在后面 self.head + if i == 0: gather_inds.index_copy_(0, row_indices, target[target_mask]) - # target对应低频词 + else: - # 获取低频cluster对应的target的相对位置 relative_target = target[target_mask] - low_idx - # 获取对应cluster的input input_subset = input.index_select(0, row_indices) - # 经过线性变换 得到 [batch_size_i, target_i] + cluster_output = self.tail[i - 1](input_subset) - # 当前cluster对应第一层权重元素的类别 cluster_index = self.shortlist_size + i - 1 - # 记录对应第一层的类别 + gather_inds.index_fill_(0, row_indices, cluster_index) - # 计算当前cluster的log_prob cluster_logprob = log_softmax(cluster_output, dim=1) - # 获取对应target位置的log_prob local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)) - # 将结果记录到对应的batch中 output.index_copy_(0, row_indices, local_logprob.squeeze(1)) used_rows += row_indices.numel() if used_rows != batch_size: - raise RuntimeError("Target values should be in [0, {}], " - "but values in range [{}, {}] " - "were found. ".format(self.n_classes - 1, - target.min().item(), - target.max().item())) - # 第一层的线性变换,因为无论高频和低频词都需要计算第一层,所以放到了这里统一计算 + raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], " + f"but values in range [{target.min().item()}, {target.max().item()}] " + "were found. ") + head_output = self.head(input) - # 取log_prob head_logprob = log_softmax(head_output, dim=1) - # 这里是第一层的log_prob和第二层的log_prob加起来作为最后的输出 - # tips: 对于属于第一层的样本,只需要计算第一层的log_prob就好 - # 对于属于第二层的样本,需要将第一层计算得到的cluster对应类别的log_prob和 - 第二层cluster内计算得到的log_prob加起来,所以是output += output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze() loss = (-output).mean() - # 返回一个nametuple + + if not is_batched: + output = output.squeeze(0) + return _ASMoutput(output, loss) ``` 3. 预测 ```python -def predict(self, input): - """ +def predict(self, input: Tensor) -> Tensor: + r""" This is equivalent to `self.log_prob(input).argmax(dim=1)`, + but is more efficient in some cases. + Args: input (Tensor): a minibatch of examples + Returns: output (Tensor): a class with the highest probability for each example + Shape: - - Input: :math:`(N, in\_features)` + - Input: :math:`(N, \texttt{in\_features})` - Output: :math:`(N)` """ - # 第一层的线性转化 + head_output = self.head(input) - # 记录预测target的位置 output = torch.argmax(head_output, dim=1) - # 判断预测的位置是否都是低频词 not_in_shortlist = (output >= self.shortlist_size) - # 获取预测高频词的样本index all_in_shortlist = not (not_in_shortlist.any()) - # 如果预测的结果都为高频词,则直接返回结果 + if all_in_shortlist: return output - # 如果预测的结果都为低频词 + elif not_in_shortlist.all(): - # 计算低频词对应cluster中target对应的log_prob log_prob = self._get_full_log_prob(input, head_output) return torch.argmax(log_prob, dim=1) - # 如果预测的结果既有高频词,也有低频词 + else: - # 只对低频词进行对应cluser的预测 log_prob = self._get_full_log_prob(input[not_in_shortlist], head_output[not_in_shortlist]) output[not_in_shortlist] = torch.argmax(log_prob, dim=1) return output -# 计算低频词对应cluster中target对应的log_prob def _get_full_log_prob(self, input, head_output): """ Given input tensor, and output of `self.head`, compute the log of the full distribution """ @@ -268,36 +293,84 @@ def _get_full_log_prob(self, input, head_output): out[:, start_idx:stop_idx] = output_logprob return out + +def log_prob(self, input: Tensor) -> Tensor: + r""" Computes log probabilities for all :math:`\texttt{n\_classes}` + + Args: + input (Tensor): a minibatch of examples + + Returns: + log-probabilities of for each class :math:`c` + in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a + parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. + + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N, \texttt{n\_classes})` + + """ + + head_output = self.head(input) + return self._get_full_log_prob(input, head_output) ``` # 四、对比分析 无其它框架实现 -# 五、方案设计 +# 五、设计思路与实现方案 ## 命名与参数设计 -API设计为`paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`及 -`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label, -in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`, 返回为`NamedTuple` 包含 `output` 和 `loss`字段 +layer层类API:`paddle.nn.AdaptiveLogSoftmaxWithLoss(in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`,包含两个主要方法: + - forward(self, input, label),用于训练,返回为`output` 和 `loss` + - predict(self, input),用于预测 + +- in_features (int): 输入tensor的特征数量。 +- n_classes (int): 数据集中类型的个数。 +- cutoffs (Sequence): 用于将label分配到不同存储桶的截断值。 +- div_value (float, 可选): 用于计算簇大小的指数值. 默认值:4.0。 +- head_bias (bool, 可选): 如果为 ``True``,向自适应 softmax 的头部添加偏置项. 默认值:``False``. +- name (str, 可选): 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。 +function API:`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label, head_weight, tail_weights, cutoffs, head_bias=None)` 用于训练计算 + +- input (Tensor): 输入张量,数据类型为 float32 或 float64。 +- label (Tensor): 标签张量,数据类型为 float32 或 float64。 +- head_weight (Tensor): 用于线性计算的权重矩阵,数据类型为 float32 或 float64。 +- tail_weights (Tensor): 用于线性计算的权重矩阵,数据类型为 float32 或 float64。 +- cutoffs (Sequence): 用于将label分配到不同存储桶的截断值。 +- head_bias (Tensor, 可选): 用于线性计算的偏置矩阵,数据类型为 float32 或 float64。 +- name (str, 可选): 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。 ## 底层OP设计 使用已有API组合实现,不再单独设计OP。 ## API实现方案 -主要参考pytorch实现,替换掉部分paddle没有的api - + +计算逻辑参考pytorch实现,并基于paddle API进行重组与封装: +- function API:`paddle.nn.functional.adaptive_log_softmax_with_loss(input, label, head_weight, tail_weights, cutoffs, head_bias=None)`,使用已有api进行组合实现, + +- layer API:`paddle.nn.AdaptiveLogSoftmaxWithLoss(self, in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`,包含两个主要方法: + - `forward(self, input, label)`,用于训练,返回为`output` 和 `loss` + - `predict(self, input)`,用于预测,其计算与forward共享权重但是计算逻辑存在差异,故使用已有API组合实现的方式单独实现 + # 六、测试和验收的考量 测试考虑的case如下: -- 数值正确性 +- 数值正确性(CPU、GPU、动态图、静态图) + - 对`log_prob`(前置函数),log_prob的各类总和概率为1,即`paddle.exp(logprob_out).sum(1)=paddle.ones([4])` + - 对`forward` + - `output`为各类别概率即`output=log_prob.gather(y.unsqueeze(1), 1).slice([1], [0], [1]).squeeze()` + - `loss`为`loss=nll_loss(log_prob, y)`,其中`nll_loss`已经实现 + - 对`predict`,有`predict=log_prob.argmax(axis=1)` - 错误检查:`cutoff`的唯一性,数据类型,数值大于零小于`n_classes - 1` - 错误检查:`input`尺寸与`in_features`一致 # 七、可行性分析及规划排期 -方案主要依赖paddle现有api组合而成 + +paddle.gather与torch.gather存在差异,使用paddle.take_along_axis替换实现。实现无明显难点,可以按期完成。 # 八、影响面 为独立新增API,对其他模块没有影响 @@ -305,4 +378,4 @@ in_features, n_classes, cutoffs, div_value=4.0, head_bias=False, name=None)`, # 名词解释 无 # 附件及参考资料 -无 +无 \ No newline at end of file