diff --git a/docs/api/paddle/nn/FractionalMaxPool2D_cn.rst b/docs/api/paddle/nn/FractionalMaxPool2D_cn.rst new file mode 100644 index 00000000000..4496976faf0 --- /dev/null +++ b/docs/api/paddle/nn/FractionalMaxPool2D_cn.rst @@ -0,0 +1,61 @@ +.. _cn_api_paddle_nn_FractionalMaxPool2D: + + +FractionalMaxPool2D +------------------------------- + +.. py:class:: paddle.nn.FractionalMaxPool2D(output_size, kernel_size=None, random_u=None, return_mask=False, name=None) + +对输入的 Tensor `x` 采取 `2` 维分数阶最大值池化操作,具体可以参考论文: + +[1] Ben Graham, Fractional Max-Pooling. 2015. http://arxiv.org/abs/1412.6071 + +其中输出的 `H` 和 `W` 由参数 `output_size` 决定。 + +对于各个输出维度,分数阶最大值池化的计算公式为: + +.. math:: + + \alpha &= size_{input} / size_{output} + + index_{start} &= ceil( \alpha * (i + u) - 1) + + index_{end} &= ceil( \alpha * (i + 1 + u) - 1) + + Output &= max(Input[index_{start}:index_{end}]) + + where, u \in (0, 1), i = 0,1,2...size_{output} + +公式中的 `u` 即为函数中的参数 `random_u`。另外,由于 `ceil` 对于正小数的操作最小值为 `1` ,因此这里需要再减去 `1` 使索引可以从 `0` 开始计数。 + +例如,有一个长度为 `7` 的序列 `[2, 4, 3, 1, 5, 2, 3]` , `output_size` 为 `5` , `random_u` 为 `0.3`。 +则由上述公式可得 `alpha = 7/5 = 1.4` , 索引的起始序列为 `[0, 1, 3, 4, 6]` ,索引的截止序列为 `[1, 3, 4, 6, 7]` 。 +进而得到论文中的随机序列为 `index_end - index_start = [1, 2, 1, 2, 1]` 。 +由于池化操作的步长与核尺寸相同,同为此随机序列,最终得到池化输出为 `[2, 4, 1, 5, 3]` 。 + + +参数 +::::::::: + + - **output_size** (int|list|tuple):算子输出图的尺寸,其数据类型为 int 或 list,tuple。如果输出为 tuple 或者 list,则必须包含两个元素, `(H, W)` 。 `H` 和 `W` 可以是 `int` ,也可以是 `None` ,表示与输入保持一致。 + - **kernel_size** (int|list|tuple, 可选) - 池化核大小。如果它是一个元组或列表,它必须包含两个整数值,(pool_size_Height, pool_size_Width)。若为一个整数,则表示 H 和 W 维度上均为该值,比如若 pool_size=2,则池化核大小为 [2,2]。默认为 `None`,表示使用 `disjoint` (`non-overlapping`) 模式。 + - **random_u** (float):分数阶池化操作的浮点随机数,取值范围为 `(0, 1)` 。默认为 `None` ,由框架随机生成,可以使用 `paddle.seed` 设置随机种子。 + - **return_mask** (bool,可选):如果设置为 `True` ,则会与输出一起返回最大值的索引,默认为 `False`。 + - **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 `None`。 + +形状 +::::::::: + + - **x** (Tensor):默认形状为(批大小,通道数,输出特征长度,宽度),即 NCHW 格式的 4-D Tensor。其数据类型为 float16, bfloat16, float32, float64。 + - **output** (Tensor):默认形状为(批大小,通道数,输出特征长度,宽度),即 NCHW 格式的 4-D Tensor。其数据类型与输入 x 相同。 + +返回 +::::::::: + +计算 FractionalMaxPool2D 的可调用对象 + + +代码示例 +::::::::: + +COPY-FROM: paddle.nn.FractionalMaxPool2D diff --git a/docs/api/paddle/nn/FractionalMaxPool3D_cn.rst b/docs/api/paddle/nn/FractionalMaxPool3D_cn.rst new file mode 100644 index 00000000000..6611496699e --- /dev/null +++ b/docs/api/paddle/nn/FractionalMaxPool3D_cn.rst @@ -0,0 +1,59 @@ +.. _cn_api_paddle_nn_FractionalMaxPool3D: + + +FractionalMaxPool3D +------------------------------- + +.. py:function:: paddle.nn.FractionalMaxPool3D(output_size, kernel_size=None, random_u=None, return_mask=False, name=None) + +对输入的 Tensor `x` 采取 `2` 维分数阶最大值池化操作,具体可以参考论文: + +[1] Ben Graham, Fractional Max-Pooling. 2015. http://arxiv.org/abs/1412.6071 + +其中输出的 `H` 和 `W` 由参数 `output_size` 决定。 + +对于各个输出维度,分数阶最大值池化的计算公式为: + +.. math:: + + \alpha &= size_{input} / size_{output} + + index_{start} &= ceil( \alpha * (i + u) - 1) + + index_{end} &= ceil( \alpha * (i + 1 + u) - 1) + + Output &= max(Input[index_{start}:index_{end}]) + + where, u \in (0, 1), i = 0,1,2...size_{output} + +公式中的 `u` 即为函数中的参数 `random_u`。另外,由于 `ceil` 对于正小数的操作最小值为 `1` ,因此这里需要再减去 `1` 使索引可以从 `0` 开始计数。 + +例如,有一个长度为 `7` 的序列 `[2, 4, 3, 1, 5, 2, 3]` , `output_size` 为 `5` , `random_u` 为 `0.3`。 +则由上述公式可得 `alpha = 7/5 = 1.4` , 索引的起始序列为 `[0, 1, 3, 4, 6]` ,索引的截止序列为 `[1, 3, 4, 6, 7]` 。 +进而得到论文中的随机序列为 `index_end - index_start = [1, 2, 1, 2, 1]` 。 +由于池化操作的步长与核尺寸相同,同为此随机序列,最终得到池化输出为 `[2, 4, 1, 5, 3]` 。 + + +参数 +::::::::: + + - **output_size** (int|list|tuple):算子输出图的尺寸,其数据类型为 int 或 list,tuple。如果输出为 tuple 或者 list,则必须包含两个元素, `(H, W)` 。 `H` 和 `W` 可以是 `int` ,也可以是 `None` ,表示与输入保持一致。 + - **kernel_size** (int|list|tuple,可选) - 池化核大小。如果它是一个元组或列表,它必须包含三个整数值,(pool_size_Depth,pool_size_Height, pool_size_Width)。若为一个整数,则表示 D,H 和 W 维度上均为该值,比如若 pool_size=2,则池化核大小为[2,2,2]。默认为 `None`,表示使用 `disjoint` (`non-overlapping`) 模式。 + - **random_u** (float):分数阶池化操作的浮点随机数,取值范围为 `(0, 1)` 。默认为 `None` ,由框架随机生成,可以使用 `paddle.seed` 设置随机种子。 + - **return_mask** (bool,可选):如果设置为 `True` ,则会与输出一起返回最大值的索引,默认为 `False`。 + - **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 `None`。 + +形状 +::::::::: + - **x** (Tensor):默认形状为(批大小,通道数,输出特征深度,高度,宽度),即 NCDHW 格式的 5-D Tensor。其数据类型为 float16, bfloat16, float32, float64。 + - **output** (Tensor):默认形状为(批大小,通道数,输出特征深度,高度,宽度),即 NCDHW 格式的 5-D Tensor。其数据类型与输入 x 相同。 + +返回 +::::::::: +计算 FractionalMaxPool3D 的可调用对象 + + +代码示例 +::::::::: + +COPY-FROM: paddle.nn.FractionalMaxPool3D diff --git a/docs/api/paddle/nn/Overview_cn.rst b/docs/api/paddle/nn/Overview_cn.rst index 084257ef7a0..addf2f21e7f 100644 --- a/docs/api/paddle/nn/Overview_cn.rst +++ b/docs/api/paddle/nn/Overview_cn.rst @@ -92,6 +92,8 @@ pooling 层 " :ref:`paddle.nn.MaxUnPool1D ` ", "一维最大反池化层" " :ref:`paddle.nn.MaxUnPool2D ` ", "二维最大反池化层" " :ref:`paddle.nn.MaxUnPool3D ` ", "三维最大反池化层" + " :ref:`paddle.nn.FractionalMaxPool2D ` ", "二维分数阶最大值池化层" + " :ref:`paddle.nn.FractionalMaxPool3D ` ", "三维分数阶最大值池化层" .. _padding_layers: @@ -359,6 +361,8 @@ Pooling 相关函数 " :ref:`paddle.nn.functional.max_unpool1d ` ", "一维最大反池化层" " :ref:`paddle.nn.functional.max_unpool1d ` ", "二维最大反池化层" " :ref:`paddle.nn.functional.max_unpool3d ` ", "三维最大反池化层" + " :ref:`paddle.nn.functional.fractional_max_pool2d ` ", "二维分数阶最大值池化" + " :ref:`paddle.nn.functional.fractional_max_pool3d ` ", "三维分数阶最大值池化" .. _padding_functional: diff --git a/docs/api/paddle/nn/functional/fractional_max_pool2d_cn.rst b/docs/api/paddle/nn/functional/fractional_max_pool2d_cn.rst new file mode 100644 index 00000000000..414663303b3 --- /dev/null +++ b/docs/api/paddle/nn/functional/fractional_max_pool2d_cn.rst @@ -0,0 +1,51 @@ +.. _cn_api_paddle_nn_functional_fractional_max_pool2d: + +fractional_max_pool2d +------------------------------- + +.. py:function:: paddle.nn.functional.fractional_max_pool2d(x, output_size, kernel_size=None, random_u=None, return_mask=False, name=None) + +对输入的 Tensor `x` 采取 `2` 维分数阶最大值池化操作,具体可以参考论文: + +[1] Ben Graham, Fractional Max-Pooling. 2015. http://arxiv.org/abs/1412.6071 + +其中输出的 `H` 和 `W` 由参数 `output_size` 决定。 + +对于各个输出维度,分数阶最大值池化的计算公式为: + +.. math:: + + \alpha &= size_{input} / size_{output} + + index_{start} &= ceil( \alpha * (i + u) - 1) + + index_{end} &= ceil( \alpha * (i + 1 + u) - 1) + + Output &= max(Input[index_{start}:index_{end}]) + + where, u \in (0, 1), i = 0,1,2...size_{output} + +公式中的 `u` 即为函数中的参数 `random_u`。另外,由于 `ceil` 对于正小数的操作最小值为 `1` ,因此这里需要再减去 `1` 使索引可以从 `0` 开始计数。 + +例如,有一个长度为 `7` 的序列 `[2, 4, 3, 1, 5, 2, 3]` , `output_size` 为 `5` , `random_u` 为 `0.3`。 +则由上述公式可得 `alpha = 7/5 = 1.4` , 索引的起始序列为 `[0, 1, 3, 4, 6]` ,索引的截止序列为 `[1, 3, 4, 6, 7]` 。 +进而得到论文中的随机序列为 `index_end - index_start = [1, 2, 1, 2, 1]` 。 +由于池化操作的步长与核尺寸相同,同为此随机序列,最终得到池化输出为 `[2, 4, 1, 5, 3]` 。 + +参数 +::::::::: + - **x** (Tensor):当前算子的输入,其是一个形状为 `[N, C, H, W]` 的 4-D Tensor。其中 `N` 是 batch size, `C` 是通道数, `H` 是输入特征的高度, `W` 是输入特征的宽度。其数据类型为 `float16`, `bfloat16`, `float32`, `float64` 。 + - **output_size** (int|list|tuple):算子输出图的尺寸,其数据类型为 int 或 list,tuple。如果输出为 tuple 或者 list,则必须包含两个元素, `(H, W)` 。 `H` 和 `W` 可以是 `int` ,也可以是 `None` ,表示与输入保持一致。 + - **kernel_size** (int|list|tuple, 可选) - 池化核大小。如果它是一个元组或列表,它必须包含两个整数值,(pool_size_Height, pool_size_Width)。若为一个整数,则表示 H 和 W 维度上均为该值,比如若 pool_size=2,则池化核大小为 [2,2]。默认为 `None`,表示使用 `disjoint` (`non-overlapping`) 模式。 + - **random_u** (float):分数阶池化操作的浮点随机数,取值范围为 `(0, 1)` 。默认为 `None` ,由框架随机生成,可以使用 `paddle.seed` 设置随机种子。 + - **return_mask** (bool,可选):如果设置为 `True` ,则会与输出一起返回最大值的索引,默认为 `False`。 + - **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 `None`。 + +返回 +::::::::: +`Tensor`,输入 `x` 经过分数阶最大值池化计算得到的目标 4-D Tensor,其数据类型与输入相同。 + +代码示例 +::::::::: + +COPY-FROM: paddle.nn.functional.fractional_max_pool2d diff --git a/docs/api/paddle/nn/functional/fractional_max_pool3d_cn.rst b/docs/api/paddle/nn/functional/fractional_max_pool3d_cn.rst new file mode 100644 index 00000000000..6044f06a40b --- /dev/null +++ b/docs/api/paddle/nn/functional/fractional_max_pool3d_cn.rst @@ -0,0 +1,51 @@ +.. _cn_api_paddle_nn_functional_fractional_max_pool3d: + +fractional_max_pool3d +------------------------------- + +.. py:function:: paddle.nn.functional.fractional_max_pool3d(x, output_size, kernel_size=None, random_u=None, return_mask=False, name=None) + +对输入的 Tensor `x` 采取 `3` 维分数阶最大值池化操作,具体可以参考论文: + +[1] Ben Graham, Fractional Max-Pooling. 2015. http://arxiv.org/abs/1412.6071 + +其中输出的 `D`, `H` 和 `W` 由参数 `output_size` 决定。 + +对于各个输出维度,分数阶最大值池化的计算公式为: + +.. math:: + + \alpha &= size_{input} / size_{output} + + index_{start} &= ceil( \alpha * (i + u) - 1) + + index_{end} &= ceil( \alpha * (i + 1 + u) - 1) + + Output &= max(Input[index_{start}:index_{end}]) + + where, u \in (0, 1), i = 0,1,2...size_{output} + +公式中的 `u` 即为函数中的参数 `random_u`。另外,由于 `ceil` 对于正小数的操作最小值为 `1` ,因此这里需要再减去 `1` 使索引可以从 `0` 开始计数。 + +例如,有一个长度为 `7` 的序列 `[2, 4, 3, 1, 5, 2, 3]` , `output_size` 为 `5` , `random_u` 为 `0.3`。 +则由上述公式可得 `alpha = 7/5 = 1.4` , 索引的起始序列为 `[0, 1, 3, 4, 6]` ,索引的截止序列为 `[1, 3, 4, 6, 7]` 。 +进而得到论文中的随机序列为 `index_end - index_start = [1, 2, 1, 2, 1]` 。 +由于池化操作的步长与核尺寸相同,同为此随机序列,最终得到池化输出为 `[2, 4, 1, 5, 3]` 。 + +参数 +::::::::: + - **x** (Tensor):当前算子的输入,其是一个形状为 `[N, C, D, H, W]` 的 5-D Tensor。其中 `N` 是 batch size, `C` 是通道数, `D` 是输入特征的深度, `H` 是输入特征的高度, `W` 是输入特征的宽度。其数据类型为 `float16`, `bfloat16`, `float32`, `float64` 。 + - **output_size** (int|list|tuple):算子输出图的尺寸,其数据类型为 int 或 list,tuple。如果输出为 tuple 或者 list,则必须包含三个元素, `(D, H, W)` 。 `D`, `H` 和 `W` 可以是 `int` ,也可以是 `None` ,表示与输入保持一致。 + - **kernel_size** (int|list|tuple,可选) - 池化核大小。如果它是一个元组或列表,它必须包含三个整数值,(pool_size_Depth,pool_size_Height, pool_size_Width)。若为一个整数,则表示 D,H 和 W 维度上均为该值,比如若 pool_size=2,则池化核大小为[2,2,2]。默认为 `None`,表示使用 `disjoint` (`non-overlapping`) 模式。 + - **random_u** (float):分数阶池化操作的浮点随机数,取值范围为 `(0, 1)` 。默认为 `None` ,由框架随机生成,可以使用 `paddle.seed` 设置随机种子。 + - **return_mask** (bool,可选):如果设置为 `True` ,则会与输出一起返回最大值的索引,默认为 `False`。 + - **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 `None`。 + +返回 +::::::::: +`Tensor`,输入 `x` 经过分数阶最大值池化计算得到的目标 5-D Tensor,其数据类型与输入相同。 + +代码示例 +::::::::: + +COPY-FROM: paddle.nn.functional.fractional_max_pool3d diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.fractional_max_pool2d.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.fractional_max_pool2d.md new file mode 100644 index 00000000000..65973bb3bc5 --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.fractional_max_pool2d.md @@ -0,0 +1,50 @@ +## [ torch 参数更多 ]torch.nn.functional.fractional_max_pool2d + +### [torch.nn.functional.fractional_max_pool2d](https://pytorch.org/docs/stable/generated/torch.nn.functional.fractional_max_pool2d.html#torch-nn-functional-fractional-max-pool2d) + +```python +torch.nn.functional.fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) +``` + +### [paddle.nn.functional.fractional_max_pool2d](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/nn/functional/fractional_max_pool2d_cn.html) + +```python +paddle.nn.functional.fractional_max_pool2d(x, output_size, kernel_size=None, random_u=None, return_mask=False, name=None) +``` + +PyTorch 参数更多,具体如下: + +### 参数映射 + +| PyTorch | PaddlePaddle | 备注 | +| ------------- | ------------ | ------------------------------------------------------ | +| input | x | 表示输入的 Tensor 。仅参数名不一致。 | +| kernel_size | kernel_size | 表示核大小。参数完全一致。 | +| output_size | output_size | 表示目标输出尺寸,PyTorch 为可选参数,Paddle 为必选参数,仅参数默认值不一致。PyTorch 的 output_size 与 output_ratio 输入二选一,如不输入 output_size,则必须输入 output_ratio,此时需要转写。转写方式与下文 output_ratio 一致。 | +| output_ratio | - | 表示目标输出比例。Paddle 无此参数,需要转写。 | +| return_indices | return_mask | 表示是否返回最大值索引。仅参数名不一致。 | +| _random_samples | random_u | 表示随机数。PyTorch 以列表形式的 Tensor 方式传入,Paddle 以 float 的方式传入,如果 PyTorch 的多个随机数相同,需要转写,如果 PyTorch 的多个随机数不同,暂无转写方式。 | + +### 转写示例 + +#### output_ratio:目标输出比例 + +```python +# 假设 intput 的 with=7, height=7, +# output_ratio = 0.75, 则目标 output 的 width = int(7*0.75) = 5, height = int(7*0.75) = 5 +# Pytorch 写法 +torch.nn.functional.fractional_max_pool2d(input, 2, output_ratio=[0.75, 0.75], return_indices=True) + +# Paddle 写法 +paddle.nn.functional.fractional_max_pool2d(x, output_size=[5, 5], kernel_size=2, return_mask=True) +``` + +#### _random_samples:随机数 + +```python +# Pytorch 写法 +torch.nn.functional.fractional_max_pool2d(input, 2, output_size=[3, 3], return_indices=True, _random_samples=torch.tensor([[[0.3, 0.3]]])) + +# Paddle 写法 +paddle.nn.functional.fractional_max_pool2d(x, output_size=[3, 3], kernel_size=2, return_mask=True, random_u=0.3) +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.fractional_max_pool3d.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.fractional_max_pool3d.md new file mode 100644 index 00000000000..53acc8b4af7 --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.fractional_max_pool3d.md @@ -0,0 +1,50 @@ +## [ torch 参数更多 ]torch.nn.functional.fractional_max_pool3d + +### [torch.nn.functional.fractional_max_pool3d](https://pytorch.org/docs/stable/generated/torch.nn.functional.fractional_max_pool3d.html#torch-nn-functional-fractional-max-pool3d) + +```python +torch.nn.functional.fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) +``` + +### [paddle.nn.functional.fractional_max_pool3d](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/nn/functional/fractional_max_pool3d_cn.html) + +```python +paddle.nn.functional.fractional_max_pool3d(x, output_size, kernel_size=None, random_u=None, return_mask=False, name=None) +``` + +PyTorch 参数更多,具体如下: + +### 参数映射 + +| PyTorch | PaddlePaddle | 备注 | +| ------------- | ------------ | ------------------------------------------------------ | +| input | x | 表示输入的 Tensor 。仅参数名不一致。 | +| kernel_size | kernel_size | 表示核大小。参数完全一致。 | +| output_size | output_size | 表示目标输出尺寸,PyTorch 为可选参数,Paddle 为必选参数,仅参数默认值不一致。PyTorch 的 output_size 与 output_ratio 输入二选一,如不输入 output_size,则必须输入 output_ratio,此时需要转写。转写方式与下文 output_ratio 一致。 | +| output_ratio | - | 表示目标输出比例。Paddle 无此参数,需要转写。 | +| return_indices | return_mask | 表示是否返回最大值索引。仅参数名不一致。 | +| _random_samples | random_u | 表示随机数。PyTorch 以列表形式的 Tensor 方式传入,Paddle 以 float 的方式传入,如果 PyTorch 的多个随机数相同,需要转写,如果 PyTorch 的多个随机数不同,暂无转写方式。 | + +### 转写示例 + +#### output_ratio:目标输出比例 + +```python +# 假设 intput 的 depth=7, with=7, height=7, +# output_ratio = 0.75, 则目标 output 的 depth = int(7*0.75) = 5, width = int(7*0.75) = 5, height = int(7*0.75) = 5 +# Pytorch 写法 +torch.nn.functional.fractional_max_pool3d(input, 2, output_ratio=[0.75, 0.75, 0.75], return_indices=True) + +# Paddle 写法 +paddle.nn.functional.fractional_max_pool3d(x, output_size=[5, 5, 5], kernel_size=2, return_mask=True) +``` + +#### _random_samples:随机数 + +```python +# Pytorch 写法 +torch.nn.functional.fractional_max_pool3d(input, 2, output_size=[3, 3, 3], return_indices=True, _random_samples=torch.tensor([[[0.3, 0.3, 0.3]]])) + +# Paddle 写法 +paddle.nn.functional.fractional_max_pool3d(x, output_size=[3, 3, 3], kernel_size=2, return_mask=True, random_u=0.3) +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.FractionalMaxPool2d.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.FractionalMaxPool2d.md new file mode 100644 index 00000000000..6b106778d3f --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.FractionalMaxPool2d.md @@ -0,0 +1,49 @@ +## [ torch 参数更多 ]torch.nn.FractionalMaxPool2d + +### [torch.nn.FractionalMaxPool2d](https://pytorch.org/docs/stable/generated/torch.nn.FractionalMaxPool2d.html#fractionalmaxpool2d) + +```python +torch.nn.FractionalMaxPool2d(kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) +``` + +### [paddle.nn.FractionalMaxPool2D](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/nn/FractionalMaxPool2D_cn.html) + +```python +paddle.nn.FractionalMaxPool2D(output_size, kernel_size=None, random_u=None, return_mask=False, name=None) +``` + +PyTorch 参数更多,具体如下: + +### 参数映射 + +| PyTorch | PaddlePaddle | 备注 | +| ------------- | ------------ | ------------------------------------------------------ | +| kernel_size | kernel_size | 表示核大小。参数完全一致。 | +| output_size | output_size | 表示目标输出尺寸,PyTorch 为可选参数,Paddle 为必选参数,仅参数默认值不一致。PyTorch 的 output_size 与 output_ratio 输入二选一,如不输入 output_size,则必须输入 output_ratio,此时需要转写。转写方式与下文 output_ratio 一致。 | +| output_ratio | - | 表示目标输出比例。Paddle 无此参数,需要转写。 | +| return_indices | return_mask | 表示是否返回最大值索引。仅参数名不一致。 | +| _random_samples | random_u | 表示随机数。PyTorch 以列表形式的 Tensor 方式传入,Paddle 以 float 的方式传入,如果 PyTorch 的多个随机数相同,需要转写,如果 PyTorch 的多个随机数不同,暂无转写方式。 | + +### 转写示例 + +#### output_ratio:目标输出比例 + +```python +# 假设 intput 的 with=7, height=7, +# output_ratio = 0.75, 则目标 output 的 width = int(7*0.75) = 5, height = int(7*0.75) = 5 +# Pytorch 写法 +torch.nn.FractionalMaxPool2d(2, output_ratio=[0.75, 0.75], return_indices=True) + +# Paddle 写法 +paddle.nn.FractionalMaxPool2D(output_size=[5, 5], kernel_size=2, return_mask=True) +``` + +#### _random_samples:随机数 + +```python +# Pytorch 写法 +torch.nn.FractionalMaxPool2d(2, output_size=[3, 3], return_indices=True, _random_samples=torch.tensor([[[0.3, 0.3]]])) + +# Paddle 写法 +paddle.nn.FractionalMaxPool2D(output_size=[3, 3], kernel_size=2, return_mask=True, random_u=0.3) +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.FractionalMaxPool3d.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.FractionalMaxPool3d.md new file mode 100644 index 00000000000..c54931dbde4 --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.FractionalMaxPool3d.md @@ -0,0 +1,49 @@ +## [ torch 参数更多 ]torch.nn.FractionalMaxPool3d + +### [torch.nn.FractionalMaxPool3d](https://pytorch.org/docs/stable/generated/torch.nn.FractionalMaxPool3d.html#fractionalmaxpool3d) + +```python +torch.nn.FractionalMaxPool3d(kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) +``` + +### [paddle.nn.FractionalMaxPool3D](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/nn/FractionalMaxPool3D_cn.html) + +```python +paddle.nn.FractionalMaxPool3D(output_size, kernel_size=None, random_u=None, return_mask=False, name=None) +``` + +PyTorch 参数更多,具体如下: + +### 参数映射 + +| PyTorch | PaddlePaddle | 备注 | +| ------------- | ------------ | ------------------------------------------------------ | +| kernel_size | kernel_size | 表示核大小。参数完全一致。 | +| output_size | output_size | 表示目标输出尺寸,PyTorch 为可选参数,Paddle 为必选参数,仅参数默认值不一致。PyTorch 的 output_size 与 output_ratio 输入二选一,如不输入 output_size,则必须输入 output_ratio,此时需要转写。转写方式与下文 output_ratio 一致。 | +| output_ratio | - | 表示目标输出比例。Paddle 无此参数,需要转写。 | +| return_indices | return_mask | 表示是否返回最大值索引。仅参数名不一致。 | +| _random_samples | random_u | 表示随机数。PyTorch 以列表形式的 Tensor 方式传入,Paddle 以 float 的方式传入,如果 PyTorch 的多个随机数相同,需要转写,如果 PyTorch 的多个随机数不同,暂无转写方式。 | + +### 转写示例 + +#### output_ratio:目标输出比例 + +```python +# 假设 intput 的 depth=7, with=7, height=7, +# output_ratio = 0.75, 则目标 output 的 depth = int(7*0.75) = 5, width = int(7*0.75) = 5, height = int(7*0.75) = 5 +# Pytorch 写法 +torch.nn.FractionalMaxPool3d(2, output_ratio=[0.75, 0.75, 0.75], return_indices=True) + +# Paddle 写法 +paddle.nn.FractionalMaxPool2D(output_size=[5, 5, 5], kernel_size=2, return_mask=True) +``` + +#### _random_samples:随机数 + +```python +# Pytorch 写法 +torch.nn.FractionalMaxPool3d(2, output_size=[3, 3, 3], return_indices=True, _random_samples=torch.tensor([[[0.3, 0.3, 0.3]]])) + +# Paddle 写法 +paddle.nn.FractionalMaxPool3D(output_size=[3, 3, 3], kernel_size=2, return_mask=True, random_u=0.3) +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.md b/docs/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.md index 98476f53aae..c986d798077 100644 --- a/docs/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.md +++ b/docs/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.md @@ -386,6 +386,8 @@ | REFERENCE-MAPPING-ITEM(`torch.nn.AdaptiveMaxPool1d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.AdaptiveMaxPool1d.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.AdaptiveMaxPool2d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.AdaptiveMaxPool2d.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.AdaptiveMaxPool3d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.AdaptiveMaxPool3d.md) | +| REFERENCE-MAPPING-ITEM(`torch.nn.FractionalMaxPool2d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.FractionalMaxPool2d.md) | +| REFERENCE-MAPPING-ITEM(`torch.nn.FractionalMaxPool3d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.FractionalMaxPool3d.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.AvgPool1d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.AvgPool1d.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.AvgPool2d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.AvgPool2d.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.AvgPool3d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/nn/torch.nn.AvgPool3d.md) | @@ -536,6 +538,8 @@ | REFERENCE-MAPPING-ITEM(`torch.nn.functional.adaptive_max_pool1d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.adaptive_max_pool1d.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.functional.adaptive_max_pool2d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.adaptive_max_pool2d.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.functional.adaptive_max_pool3d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.adaptive_max_pool3d.md) | +| REFERENCE-MAPPING-ITEM(`torch.nn.functional.fractional_max_pool2d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.fractional_max_pool2d.md) | +| REFERENCE-MAPPING-ITEM(`torch.nn.functional.fractional_max_pool3d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.fractional_max_pool3d.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.functional.affine_grid`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.affine_grid.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.functional.bilinear`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.bilinear.md) | | REFERENCE-MAPPING-ITEM(`torch.nn.functional.conv1d`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/functional/torch.nn.functional.conv1d.md) |