
API Improvement for paddle.nn.initializer.TruncatedNormal (usability improvement) #64110

Merged: 15 commits merged into PaddlePaddle:develop on May 16, 2024

Conversation

NKNaN (Contributor) commented May 8, 2024

PR Category

User Experience

PR Types

Improvements

Description

Add truncation parameters a (default -2.0) and b (default 2.0) to the truncated Gaussian distribution initializer class, aligning its behavior with PyTorch.

paddle-bot commented May 8, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first; see the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label May 8, 2024
@NKNaN NKNaN changed the title API Improvement for paddle.nn.initializer.TruncatedNormal API Improvement for paddle.nn.initializer.TruncatedNormal 易用性提升 May 8, 2024
@@ -79,7 +83,8 @@ struct TruncatedNormalOffset {
   rng.discard(n + offset_);
   T value = dist(rng);
   auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
-  return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
+  T ret = std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean;
+  return std::clamp(ret, a, b);
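As a small illustration (a Python stand-in, not the Paddle kernel), `std::clamp` pins any out-of-range result back into [a, b], including the ±inf values the inverse CDF can produce at the endpoints as well as small floating-point overshoot:

```python
def clamp(x, a, b):
    # Float equivalent of C++ std::clamp(x, a, b).
    return min(max(x, a), b)

a, b = -2.0, 2.0
print(clamp(1.5, a, b))           # 1.5: in-range values pass through unchanged
print(clamp(float("inf"), a, b))  # 2.0: endpoint divergence is pinned to b
print(clamp(-2.0075064, a, b))    # -2.0: floating-point undershoot is pinned to a
```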
Contributor:

Is this how torch implements it? Could truncating directly like this cause problems, e.g. the probability density over [a, b] no longer integrating to 1?

Contributor Author:

torch's last step applies a clamp:

def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.,
    std: float = 1.,
    a: float = -2.,
    b: float = 2.,
    generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)

def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

Updated. When computing a_normal_cdf and b_normal_cdf above, a and b should each be standardized first, corresponding to l and u in the torch code. In theory, random numbers sampled this way fall between a and b in the vast majority of cases; only when l = 0 or u = 1 can the uniform sampling step produce the endpoint values -1 or 1, which become -inf or inf after erfinv(). The clamp then prevents the result from coming out as inf or -inf. The clamp also handles floating-point overshoot. (See the example below.)
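The scheme just described (standardize the bounds before taking their CDF values, sample uniformly between l and u, invert the CDF, clamp) can be sketched in pure Python; the bisection `erfinv` here is an illustrative stand-in, not the real `erfinvf` kernel:

```python
import math
import random

def norm_cdf(x):
    # Standard normal cumulative distribution function.
    return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

def erfinv(y):
    # Inverse error function via bisection (illustration only;
    # real implementations use a rational approximation).
    lo, hi = -6.0, 6.0
    for _ in range(200):
        mid = (lo + hi) / 2.0
        if math.erf(mid) < y:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2.0

def trunc_normal_sample(mean=0.0, std=1.0, a=-2.0, b=2.0):
    # Standardize the bounds before taking their CDF values
    # (the l and u of the torch code above), sample uniformly
    # between them, invert the CDF, then clamp for the endpoints.
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)
    p = random.uniform(l, u)
    ret = math.sqrt(2.0) * erfinv(2.0 * p - 1.0) * std + mean
    return min(max(ret, a), b)

samples = [trunc_normal_sample(mean=0.5, std=2.0) for _ in range(1000)]
print(all(-2.0 <= s <= 2.0 for s in samples))  # True: every draw lies in [a, b]
```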

@@ -165,21 +165,31 @@ def __init__(self, mean=0.0, std=1.0, name=None):
class TruncatedNormalInitializer(Initializer):
"""Implements the Random TruncatedNormal(Gaussian) distribution initializer

Note:
It is better to set `mean` in the range of [`a`, `b`].
If `mean` is more than 2 `std` from [`a`, `b`], the distribution of values may be incorrect.
Contributor:

What does this mean?

Contributor Author:

It follows the hint in torch's API:

if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

I think it means: when mean falls outside [a - 2·std, b + 2·std], i.e. when a > mean + 2·std or b < mean - 2·std, the generated random numbers are still confined to [a, b], but that interval now lies in the tail of the normal distribution. The p-values of a and b under the normal CDF are then extremely close, i.e. the l and u below are nearly equal:

l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)

so the sampled values become less and less spread out. Moreover, when a and b sit in the tail, floating-point precision easily rounds their p-values to exactly 0 or 1, as in the following case:

import math
import numpy as np
import torch

def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        
        print(l)
        print(u)
        
        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)
        
        print(np.array(tensor))
        
        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()
        
        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        
        print(np.array(tensor))
        
        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

_no_grad_trunc_normal_(w, mean=-7, std=1, a=-2, b=2)  # w: a pre-allocated torch float tensor

Here l = 0.9999997133484282 and u = 1.0,
and after uniform_ the tensor may be
[[0.9999999 0.99999994 0.99999946 ... 0.9999997 0.9999997 0.9999999 ]
[0.9999994 0.9999997 0.9999995 ... 0.99999994 0.99999964 1. ]
[0.9999997 0.99999946 0.9999999 ... 0.9999994 0.99999964 0.9999998 ]
...
[0.9999998 0.99999964 0.99999976 ... 0.9999999 0.9999997 0.9999999 ]
[0.9999995 1. 0.99999964 ... 0.9999996 0.9999995 1. ]
[0.9999999 0.99999994 0.99999994 ... 0.9999997 0.9999997 1. ]]

After the linear transform, some inf values then appear:
[[-1.705296 -1.5800171 -1.9872022 ... -1.8753071 -1.8753071 -1.705296 ]
[-2.0075064 -1.8753071 -1.9645939 ... -1.5800171 -1.909769 inf]
[-1.8753071 -1.9872022 -1.705296 ... -2.0075064 -1.909769 -1.7798872]
...
[-1.7798872 -1.909769 -1.8334217 ... -1.705296 -1.8753071 -1.705296 ]
[-1.9645939 inf -1.909769 ... -1.9390774 -1.9645939 inf]
[-1.705296 -1.5800171 -1.5800171 ... -1.8753071 -1.8753071 inf]]

Finally we get:
tensor([[-1.7053, -1.5800, -1.9872, ..., -1.8753, -1.8753, -1.7053],
[-2.0000, -1.8753, -1.9646, ..., -1.5800, -1.9098, 2.0000],
[-1.8753, -1.9872, -1.7053, ..., -2.0000, -1.9098, -1.7799],
...,
[-1.7799, -1.9098, -1.8334, ..., -1.7053, -1.8753, -1.7053],
[-1.9646, 2.0000, -1.9098, ..., -1.9391, -1.9646, 2.0000],
[-1.7053, -1.5800, -1.5800, ..., -1.8753, -1.8753, 2.0000]])

The positions where inf appeared are truncated by clamp to b (2.0). With enough floating-point precision, u would be strictly less than 1.0, so uniform_ would never produce the endpoint 1.0 and the resulting values would all be around -1.xxx, close to -2.0. So the distribution of the sampled data is somewhat off in this case.

(The clamp also fixes floating-point overshoot from the final linear transform; e.g. the value -2.0075064 here has overshot the bound.)
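These numbers can be reproduced with the same float64 arithmetic: for mean = -7 and std = 1 the upper bound's CDF value rounds to exactly 1.0, which is what allows uniform_ to hit the endpoint:

```python
import math

def norm_cdf(x):
    # Standard normal CDF, as in the torch snippet above.
    return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

mean, std, a, b = -7.0, 1.0, -2.0, 2.0
l = norm_cdf((a - mean) / std)  # norm_cdf(5.0)
u = norm_cdf((b - mean) / std)  # norm_cdf(9.0)
print(l)  # 0.9999997133484282
print(u)  # 1.0 -- the true value, 1 - ~1.1e-19, is below float64 resolution
```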

@@ -165,21 +165,31 @@ def __init__(self, mean=0.0, std=1.0, name=None):
class TruncatedNormalInitializer(Initializer):
"""Implements the Random TruncatedNormal(Gaussian) distribution initializer

Note:
It is better to set `mean` in the range of [`a`, `b`].
Contributor:

I suggest stating the recommendation about mean as a concrete mathematical formula here; otherwise it is not easy to understand.

Contributor Author:

Would this wording work:
It is better to set a <= mean <= b.
If mean is set to be > b + 2 · std or < a - 2 · std , the distribution of values may be incorrect.

zhwesky2010 (Contributor) commented May 15, 2024:

> Would this wording work: It is better to set a <= mean <= b. If mean is set to be > b + 2 · std or < a - 2 · std , the distribution of values may be incorrect.

It is better to set a <= mean <= b.
If mean < a - 2*std or mean > b + 2*std, the distribution of values may be incorrect.
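The agreed wording matches torch's runtime check. A minimal sketch of that guard (hypothetical helper name, not part of Paddle's API):

```python
import warnings

def check_trunc_normal_args(mean, std, a, b):
    # Warn when mean lies more than 2 std outside [a, b],
    # mirroring the check quoted from nn.init.trunc_normal_.
    if mean < a - 2 * std or mean > b + 2 * std:
        warnings.warn(
            "mean is more than 2 std from [a, b]; "
            "the distribution of values may be incorrect."
        )
        return False
    return True

print(check_trunc_normal_args(mean=0.0, std=1.0, a=-2.0, b=2.0))   # True
print(check_trunc_normal_args(mean=-7.0, std=1.0, a=-2.0, b=2.0))  # False (warns)
```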

Contributor Author:

Updated.

@@ -266,9 +280,15 @@ def forward(self, var, block=None):
class TruncatedNormal(TruncatedNormalInitializer):
"""The truncated normal distribution (Gaussian distribution) initializer.

Note:
It is better to set `mean` in the range of [`a`, `b`].
Contributor:

Same as above.

Contributor Author:

Updated.

zhwesky2010 (Contributor):

@NKNaN please update it as suggested above.

zhwesky2010 (Contributor) left a review comment:
LGTM

@luotao1 luotao1 merged commit 657bb7e into PaddlePaddle:develop May 16, 2024
31 checks passed