Same question. This leads to a strange situation. The final KL loss is computed as: `kl_penalty = -self.kl_penalty_weight * (logprobs - ref_logprob)`
However, `ref_logprob` does not require grad, so it is effectively a constant and could just as well be dropped from the computation graph. As implemented, the regularization behaves more like "keep the label logit from growing too large" than like a proper KL divergence.
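For reference, a minimal sketch of the quoted penalty (the standalone function and the `kl_penalty_weight` argument are illustrative, not the repo's actual API). Because the reference log-prob carries no gradient, the whole term reduces to a penalty on the policy's own label log-prob:

```python
import torch

def kl_penalty_k1(logprobs: torch.Tensor,
                  ref_logprob: torch.Tensor,
                  kl_penalty_weight: float) -> torch.Tensor:
    # ref_logprob comes from a frozen reference model, so it carries no grad.
    # The gradient of (logprobs - ref_logprob) w.r.t. the policy is therefore
    # just the gradient of logprobs: the penalty only pushes down the
    # log-probability of the sampled/label token, as noted above.
    return -kl_penalty_weight * (logprobs - ref_logprob.detach())
```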
Hello!
I looked through the code and noticed that the KL penalty added to the token-level reward does not seem to follow the standard KL-divergence definition, which is computed over two full distributions. Instead, the code only uses the ratio of the probabilities of the single label token (the standard KL divergence is guaranteed to be non-negative, but the current implementation can produce a negative value). Why is that? I also see that `approx_kl` is computed the same way.
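To make the contrast concrete, here is a small sketch (shapes and names such as `logits`, `ref_logits`, and `labels` are assumed for illustration, not taken from this repo) comparing the exact per-token KL over the full vocabulary, the single-token log-ratio used here (which can be negative for an individual token), and Schulman's non-negative k3 estimator:

```python
import torch
import torch.nn.functional as F

# Assumed shapes: logits / ref_logits are (batch, seq_len, vocab) from the
# policy and the frozen reference model; labels is (batch, seq_len) long.

def full_kl(logits, ref_logits):
    """Exact per-token KL(pi || ref) summed over the vocabulary; always >= 0."""
    logp = F.log_softmax(logits, dim=-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1)
    return (logp.exp() * (logp - ref_logp)).sum(dim=-1)

def k1_estimate(logits, ref_logits, labels):
    """Single-sample estimator log pi(y) - log ref(y) on the label token.
    Unbiased when y ~ pi, but any individual value can be negative."""
    logp = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return logp - ref_logp

def k3_estimate(logits, ref_logits, labels):
    """Schulman's k3 estimator: exp(r) - 1 - r with r = log ref(y) - log pi(y).
    Non-negative for every sample and still unbiased in expectation."""
    r = -k1_estimate(logits, ref_logits, labels)
    return r.exp() - 1.0 - r
```

The single-token form is the common k1 Monte-Carlo approximation of the KL: it matches the full KL only in expectation over tokens sampled from the policy, which is presumably why individual values (and `approx_kl`) can go negative.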