KL散度近似方法介绍:从John Schulman的博客到DeepSeek GRPO的应用

KL散度(Kullback-Leibler Divergence)是衡量两个概率分布之间差异的重要指标,广泛应用于机器学习、强化学习等领域。John Schulman在其2020年3月7日的博客中详细探讨了如何通过蒙特卡洛方法近似KL散度,并提出了一种低方差、无偏的估计器。这一方法不仅在理论上具有重要意义,还被DeepSeek的GRPO算法所采用。本文将基于Schulman的博客内容,介绍KL散度的近似方法及其在DeepSeek GRPO中的应用。

blog:http://joschu.net/blog/kl-approx.html

首先看一下John Schulman是谁:PPO的作者,领导创建了ChatGPT,OpenAI的cofounder。

在这里插入图片描述

source:http://joschu.net/index.html

在这里插入图片描述

source:https://scholar.google.com/citations?user=itSa94cAAAAJ

KL散度的基本定义

KL散度衡量两个概率分布 ( ppp ) 和 ( qqq ) 之间的差异,定义为:

KL[q,p]=∑xq(x)log⁡q(x)p(x)=Ex∼q[log⁡q(x)p(x)] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] KL[q,p]=xq(x)logp(x)q(x)=Exq[logp(x)q(x)]

在实际应用中,精确计算KL散度通常不可行,因为:

  1. 计算所有 ( xxx ) 的概率和需要过多的计算或内存。
  2. 分布 ( ppp ) 和 ( qqq ) 可能没有闭合表达式。
  3. 在强化学习等场景中,KL散度常作为诊断工具,仅存储对数概率以简化代码。

因此,蒙特卡洛方法成为估计KL散度的常用策略。假设我们从分布 ( qqq ) 中采样 ( x1,x2,…x_1, x_2, \dotsx1,x2, ),如何构造一个既无偏又低方差的估计器是关键问题。

传统估计器及其局限性

最直接的蒙特卡洛估计器是基于KL散度的定义:

k1=log⁡q(x)p(x)=−log⁡r,其中r=p(x)q(x) k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k1=logp(x)q(x)=logr,其中r=q(x)p(x)

这个估计器 ( k1k_1k1 ) 是无偏的,即其期望等于真实的KL散度。然而,由于 ( log⁡r\log rlogr ) 的值在正负之间变化(当 ( r>1r > 1r>1 ) 时为正,当 ( r<1r < 1r<1 ) 时为负),其方差较高,而KL散度本身始终为正。这种高方差使得 ( k1k_1k1 ) 在实际应用中表现不佳。

低方差的偏倚估计器

Schulman提出了一种替代估计器:

k2=12(log⁡p(x)q(x))2=12(log⁡r)2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2 k2=21(logq(x)p(x))2=21(logr)2

这个估计器 ( k2k_2k2 ) 虽然有偏,但方差显著低于 ( k1k_1k1 )。其优点在于:

  1. 始终为正:每个样本都反映了 ( ppp ) 和 ( qqq ) 之间的差异,且结果非负,与KL散度的性质一致。
  2. 低偏倚:( k2k_2k2 ) 的期望是一个f-散度,其形式为:

Eq[k2]=Eq[12(log⁡r)2] \mathbb{E}_q[k_2] = \mathbb{E}_q \left[ \frac{1}{2} (\log r)^2 \right] Eq[k2]=Eq[21(logr)2]

这是一个以 ( f(x)=12(log⁡x)2f(x) = \frac{1}{2} (\log x)^2f(x)=21(logx)2 ) 定义的f-散度。当 ( ppp ) 和 ( qqq ) 接近时,所有可微的f-散度在二阶近似下与KL散度等价,因此 ( k2k_2k2 ) 的偏倚非常小。

实验表明,当 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) )、( p=N(0.1,1)p = \mathcal{N}(0.1, 1)p=N(0.1,1) ) 时(真实KL散度为0.005),( k2k_2k2 ) 的偏倚仅为0.2%,标准差为真实值的1.42倍,远低于 ( k1k_1k1 ) 的20倍。

无偏且低方差的估计器

为了兼顾无偏和低方差,Schulman引入了控制变量(control variate)方法。利用 ( Eq[r−1]=0\mathbb{E}_q[r - 1] = 0Eq[r1]=0 ),可以构造一个新的估计器:

k3=−log⁡r+λ(r−1) k_3 = -\log r + \lambda (r - 1) k3=logr+λ(r1)

通过选择适当的 ( λ\lambdaλ ),可以降低方差。当 ( λ=1\lambda = 1λ=1 ) 时,估计器变为:

k3=(r−1)−log⁡r k_3 = (r - 1) - \log r k3=(r1)logr

由于对数的凹性,( log⁡r≤r−1\log r \leq r - 1logrr1 ),因此 ( k3k_3k3) 始终为正。这个估计器不仅无偏,而且方差低于 ( k1k_1k1 )。实验表明,当真实KL散度为0.5时,( k3k_3k3 ) 的标准差为真实值的1.7倍,低于 ( k2k_2k2 ) 的1.73倍,且无偏。

推广到其他f-散度

Schulman进一步指出,上述方法可以推广到任意f-散度。对于 ( KL[p,q]KL[p, q]KL[p,q] ),对应的f-散度估计器为:

rlog⁡r−(r−1) r \log r - (r - 1) rlogr(r1)

这个估计器同样基于凸函数与其切线的距离(Bregman散度),保证了非负性和低方差。

DeepSeek GRPO中的应用

DeepSeek的GRPO算法直接采用了Schulman提出的无偏估计器 ( k3k_3k3 )。具体来说,GRPO使用以下估计器来近似 ( DKL(πθ∣∣πref)D_{KL}(\pi_\theta || \pi_{ref})DKL(πθ∣∣πref) ):

DKL(πθ∣∣πref)=πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−log⁡πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−1 D_{KL}(\pi_\theta || \pi_{ref}) = \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} - \log \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} - 1 DKL(πθ∣∣πref)=πθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)logπθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)1

其中,( πθ\pi_\thetaπθ ) 表示策略网络,( πref\pi_{ref}πref ) 表示参考策略,( oi,to_{i,t}oi,t ) 表示在时间 ( ttt ) 的观测,( qqq ) 和 ( oi,<to_{i,<t}oi,<t ) 分别表示上下文和历史观测。令 ( r=πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)r = \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})}r=πθ(oi,tq,oi,<t)πref(oi,tq,oi,<t) ),该估计器可重写为:

k3=(r−1)−log⁡r k_3 = (r - 1) - \log r k3=(r1)logr

这个估计器的优点在于:

  1. 无偏性:保证估计的期望等于真实的KL散度。
  2. 非负性:由于 ( log⁡r≤r−1\log r \leq r - 1logrr1 ),估计值始终为正,与KL散度的性质一致。
  3. 低方差:通过控制变量降低了估计的方差,提升了算法的稳定性。

与传统的KL惩罚项(例如直接使用 (log⁡πθπref\log \frac{\pi_\theta}{\pi_{ref}}logπrefπθ))相比,GRPO的估计器在强化学习中表现更稳定,尤其是在策略优化需要精确控制分布差异时。

实验结果

Schulman的实验验证了 ( k3k_3k3 ) 的优越性。例如,当 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) )、( p=N(1,1)p = \mathcal{N}(1, 1)p=N(1,1) )(真实KL散度为0.5)时:

  • ( k1k_1k1 ): 无偏,标准差为真实值的2倍。
  • ( k2k_2k2 ): 偏倚为25%,标准差为真实值的1.73倍。
  • ( k3k_3k3 ): 无偏,标准差为真实值的1.7倍。

这些结果表明,( k3k_3k3 ) 在保持无偏的同时,显著降低了方差。

总结

John Schulman的博客提供了一种优雅的KL散度近似方法,通过引入控制变量构造了无偏且低方差的估计器 ( k3=(r−1)−log⁡rk_3 = (r - 1) - \log rk3=(r1)logr )。这一方法不仅在理论上具有普适性(可推广到任意f-散度),还在DeepSeek的GRPO算法中得到了实际应用。GRPO利用这一估计器实现了稳定的策略优化,展示了其在强化学习中的重要价值。对于需要在高维分布上估计KL散度的研究者和工程师来说,Schulman的方法提供了一个兼顾理论严谨性和实践效果的解决方案。

蒙特卡洛估计器

什么是蒙特卡洛估计器?

蒙特卡洛方法(Monte Carlo Method)是一种通过随机采样来估计复杂数学量的方法。它特别适合用来近似那些难以直接计算的积分、期望或求和,尤其是在高维空间或解析解不可得的情况下。蒙特卡洛估计器的核心思想是:通过从某个分布中抽取大量随机样本,计算这些样本上的函数值,然后用样本均值来近似目标量的期望。

在KL散度的背景下,KL散度 ( KL[q,p]KL[q, p]KL[q,p] ) 定义为:

KL[q,p]=∑xq(x)log⁡q(x)p(x)=Ex∼q[log⁡q(x)p(x)] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] KL[q,p]=xq(x)logp(x)q(x)=Exq[logp(x)q(x)]

这里的 ( Ex∼q\mathbb{E}_{x \sim q}Exq ) 表示从分布 ( qqq ) 中采样的期望。如果我们无法直接计算整个求和(例如,分布 ( qqq ) 和 ( ppp ) 是高维的,或者没有闭合表达式),就可以用蒙特卡洛方法来估计这个期望。具体来说:

  1. 从分布 ( qqq ) 中抽取 ( NNN ) 个独立样本 ( x1,x2,…,xNx_1, x_2, \dots, x_Nx1,x2,,xN )。
  2. 对每个样本 ( xix_ixi ),计算函数 ( f(xi)=log⁡q(xi)p(xi)f(x_i) = \log \frac{q(x_i)}{p(x_i)}f(xi)=logp(xi)q(xi) )。
  3. 用样本均值来估计期望:

KL[q,p]≈1N∑i=1Nlog⁡q(xi)p(xi) KL[q, p] \approx \frac{1}{N} \sum_{i=1}^N \log \frac{q(x_i)}{p(x_i)} KL[q,p]N1i=1Nlogp(xi)q(xi)

这个均值就是蒙特卡洛估计器,称为 ( k1k_1k1 ) 在你的问题中:

k1=log⁡q(x)p(x)=−log⁡r,其中r=p(x)q(x) k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k1=logp(x)q(x)=logr,其中r=q(x)p(x)

蒙特卡洛估计器的作用:它通过有限样本近似真实的KL散度值。KL散度本质上衡量两个分布 ( qqq ) 和 ( ppp ) 之间的“距离”(尽管严格来说它不是对称的距离),而蒙特卡洛估计器试图通过采样来捕捉这种差异。


蒙特卡洛估计器是否反映分布的距离?

是的,蒙特卡洛估计器 ( k1k_1k1 ) 的期望等于真实的KL散度,因此它确实反映了两个分布之间的“距离”。具体来说:

  • 无偏性:( Ex∼q[k1]=Ex∼q[log⁡q(x)p(x)]=KL[q,p]\mathbb{E}_{x \sim q}[k_1] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = KL[q, p]Exq[k1]=Exq[logp(x)q(x)]=KL[q,p] )。这意味着如果采样次数 ( N→∞N \to \inftyN ),估计器的平均值会收敛到真实的KL散度。
  • 分布差异的体现:( k1=log⁡q(x)p(x)k_1 = \log \frac{q(x)}{p(x)}k1=logp(x)q(x) ) 直接衡量了在样本点 ( xxx ) 上,分布 ( qqq ) 和 ( ppp ) 的概率密度(或概率)的对数差异。如果 ( q(x)≈p(x)q(x) \approx p(x)q(x)p(x) ),则 ( log⁡q(x)p(x)≈0\log \frac{q(x)}{p(x)} \approx 0logp(x)q(x)0 ),表明分布很接近;如果 ( q(x)q(x)q(x) ) 和 ( p(x)p(x)p(x) ) 差异很大,则 ( k1k_1k1 ) 的绝对值会较大,反映分布的差异。

然而,( k1k_1k1 ) 的局限性在于其高方差。由于 ( log⁡r\log rlogr ) (其中 ( r=p(x)q(x)r = \frac{p(x)}{q(x)}r=q(x)p(x) ))可能为正或负,样本值的波动会导致估计值的方差较大,尤其当 ( ppp ) 和 ( qqq ) 差异较大时。这种高方差使得 ( k1k_1k1 ) 在实际应用中不够稳定。


通过例子说明

让我们通过一个简单的例子来说明蒙特卡洛估计器如何工作,以及它如何反映分布的距离。

场景

假设有两个一维正态分布:

  • ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) ) (均值0,标准差1)
  • ( p=N(0.5,1)p = \mathcal{N}(0.5, 1)p=N(0.5,1) ) (均值0.5,标准差1)

真实的KL散度 ( KL[q,p]KL[q, p]KL[q,p] ) 对于正态分布有解析解:

KL[q,p]=12((μq−μp)2σp2)=12(0−0.5)212=0.252=0.125 KL[q, p] = \frac{1}{2} \left( \frac{(\mu_q - \mu_p)^2}{\sigma_p^2} \right) = \frac{1}{2} \frac{(0 - 0.5)^2}{1^2} = \frac{0.25}{2} = 0.125 KL[q,p]=21(σp2(μqμp)2)=2112(00.5)2=20.25=0.125

我们用蒙特卡洛方法来估计这个值。

蒙特卡洛估计步骤
  1. 采样:从 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) ) 中抽取 ( N=1000N = 1000N=1000 ) 个样本 ( x1,x2,…,x1000x_1, x_2, \dots, x_{1000}x1,x2,,x1000 )。假设样本是随机生成的,例如 ( x1=−0.3,x2=1.2,x3=−1.5,…x_1 = -0.3, x_2 = 1.2, x_3 = -1.5, \dotsx1=0.3,x2=1.2,x3=1.5, )。
  2. 计算 ( k1k_1k1 ):对每个样本 ( xix_ixi ),计算:

k1(xi)=log⁡q(xi)p(xi) k_1(x_i) = \log \frac{q(x_i)}{p(x_i)} k1(xi)=logp(xi)q(xi)

正态分布的概率密度函数为:

q(x)=12πexp⁡(−x22),p(x)=12πexp⁡(−(x−0.5)22) q(x) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{x^2}{2}\right), \quad p(x) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x - 0.5)^2}{2}\right) q(x)=2π 1exp(2x2),p(x)=2π 1exp(2(x0.5)2)

因此:

log⁡q(xi)p(xi)=log⁡q(xi)−log⁡p(xi)=(−xi22)−(−(xi−0.5)22) \log \frac{q(x_i)}{p(x_i)} = \log q(x_i) - \log p(x_i) = \left( -\frac{x_i^2}{2} \right) - \left( -\frac{(x_i - 0.5)^2}{2} \right) logp(xi)q(xi)=logq(xi)logp(xi)=(2xi2)(2(xi0.5)2)

=−xi22+(xi−0.5)22=(xi−0.5)2−xi22=xi2−xi+0.25−xi22=−xi+0.252 = -\frac{x_i^2}{2} + \frac{(x_i - 0.5)^2}{2} = \frac{(x_i - 0.5)^2 - x_i^2}{2} = \frac{x_i^2 - x_i + 0.25 - x_i^2}{2} = \frac{-x_i + 0.25}{2} =2xi2+2(xi0.5)2=2(xi0.5)2xi2=2xi2xi+0.25xi2=2xi+0.25

k1(xi)=0.25−xi2 k_1(x_i) = \frac{0.25 - x_i}{2} k1(xi)=20.25xi

  1. 估计KL散度:计算所有样本的均值:

KL^[q,p]=1N∑i=1Nk1(xi)=1N∑i=1N0.25−xi2 \hat{KL}[q, p] = \frac{1}{N} \sum_{i=1}^N k_1(x_i) = \frac{1}{N} \sum_{i=1}^N \frac{0.25 - x_i}{2} KL^[q,p]=N1i=1Nk1(xi)=N1i=1N20.25xi

假设我们采样得到的 ( xix_ixi ) 的均值为 ( xˉ≈0\bar{x} \approx 0xˉ0 )(因为 ( xi∼N(0,1)x_i \sim \mathcal{N}(0, 1)xiN(0,1) )),则:

KL^[q,p]≈11000∑i=110000.25−xi2≈0.25−02=0.125 \hat{KL}[q, p] \approx \frac{1}{1000} \sum_{i=1}^{1000} \frac{0.25 - x_i}{2} \approx \frac{0.25 - 0}{2} = 0.125 KL^[q,p]10001i=1100020.25xi20.250=0.125

这个估计值接近真实的 ( KL[q,p]=0.125KL[q, p] = 0.125KL[q,p]=0.125 ),说明 ( k1k_1k1 ) 是无偏的。

高方差的表现

尽管 ( k1k_1k1 ) 无偏,但它的值在样本间波动较大。例如:

  • 如果 ( xi=0x_i = 0xi=0 ),则 ( k1=0.25−02=0.125k_1 = \frac{0.25 - 0}{2} = 0.125k1=20.250=0.125 )。
  • 如果 ( xi=2x_i = 2xi=2 ),则 ( k1=0.25−22=−0.875k_1 = \frac{0.25 - 2}{2} = -0.875k1=20.252=0.875 )。
  • 如果 ( xi=−2x_i = -2xi=2 ),则 ( k1=0.25−(−2)2=1.125k_1 = \frac{0.25 - (-2)}{2} = 1.125k1=20.25(2)=1.125 ).

这些值的正负波动导致样本均值的方差较高。如果采样次数 ( NNN ) 不够多,估计值可能偏离0.125较远。

反映分布距离

从 ( k1(xi)=0.25−xi2k_1(x_i) = \frac{0.25 - x_i}{2}k1(xi)=20.25xi ) 可以看出,( k1k_1k1 ) 的值直接与样本 ( xix_ixi) 有关,而样本是由 ( qqq ) 生成的。KL散度的值0.125反映了 ( qqq ) 和 ( ppp ) 之间的均值差异(0 vs 0.5)。如果我们改变 ( ppp ) 的均值(例如 ( p=N(1,1)p = \mathcal{N}(1, 1)p=N(1,1) )),KL散度会变大(( KL[q,p]=0.5KL[q, p] = 0.5KL[q,p]=0.5 )),相应的 ( k1k_1k1 ) 的均值也会变大,反映出分布间更大的差异。


总结

  • 蒙特卡洛估计器是通过从分布 ( qqq ) 中采样,用样本上的函数值均值来近似KL散度的期望。
  • 反映分布距离:( k1=log⁡q(x)p(x)k_1 = \log \frac{q(x)}{p(x)}k1=logp(x)q(x) ) 的期望等于KL散度,因此它确实捕捉了 ( qqq ) 和 ( ppp ) 之间的差异。
  • 局限性:高方差使得 ( k1k_1k1 ) 在有限样本下不够稳定,如例子中 ( k1k_1k1 ) 的值在正负间波动。
  • 改进方向:Schulman的博客中提出了 ( k2k_2k2 ) 和 ( k3k_3k3 ) 等估计器,通过降低方差(例如保证非负性或引入控制变量)来改进 ( k1k_1k1 ),在DeepSeek GRPO等应用中表现出色。

log相减就是k1估计器

在许多大语言模型(LLM)强化学习(RL)相关的论文中,尤其是那些涉及策略优化(如PPO、TRPO或RLHF)的论文,使用的KL散度估计通常是基于 ( k1k_1k1 ) 估计器,即直接计算对数概率的差值:

k1=log⁡q(x)p(x)=log⁡q(x)−log⁡p(x) k_1 = \log \frac{q(x)}{p(x)} = \log q(x) - \log p(x) k1=logp(x)q(x)=logq(x)logp(x)

或者在策略优化的上下文中,写作:

k1=log⁡πref(a∣s)πθ(a∣s)=log⁡πref(a∣s)−log⁡πθ(a∣s) k_1 = \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\theta}(a|s)} = \log \pi_{\text{ref}}(a|s) - \log \pi_{\theta}(a|s) k1=logπθ(as)πref(as)=logπref(as)logπθ(as)

其中,( πθ\pi_{\theta}πθ ) 是当前策略,( πref\pi_{\text{ref}}πref ) 是参考策略(如初始策略或SFT模型),( aaa ) 是动作(在LLM中通常是生成的token),( sss ) 是状态(上下文或提示)。这种形式的KL散度估计正是Schulman博客中描述的 ( k1k_1k1 ) 估计器,基于蒙特卡洛方法,通过采样计算对数概率差的均值来近似 ( KL[πθ∣∣πref]KL[\pi_{\theta} || \pi_{\text{ref}}]KL[πθ∣∣πref] ) 或 ( KL[πref∣∣πθ]KL[\pi_{\text{ref}} || \pi_{\theta}]KL[πref∣∣πθ] )。

为什么常用 ( k1k_1k1 ) 估计器?

  1. 简单直接:( k1k_1k1 ) 是KL散度的直接蒙特卡洛估计,形式上就是对数概率的差值,易于实现。LLM通常会输出每个token的对数概率(logits经过softmax后取log),因此计算 ( log⁡πθ(a∣s)\log \pi_{\theta}(a|s)logπθ(as) ) 和 ( log⁡πref(a∣s)\log \pi_{\text{ref}}(a|s)logπref(as) ) 是现成的。

  2. 无偏性:( k1k_1k1 ) 是无偏估计器,其期望等于真实的KL散度:

Ea∼πθ[log⁡πθ(a∣s)πref(a∣s)]=KL[πθ∣∣πref] \mathbb{E}_{a \sim \pi_{\theta}} \left[ \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)} \right] = KL[\pi_{\theta} || \pi_{\text{ref}}] Eaπθ[logπref(as)πθ(as)]=KL[πθ∣∣πref]

这保证了在采样足够多的情况下,估计值会收敛到真实KL散度。

  1. 常见应用场景:在强化学习中,KL散度常用于以下场景:
    • 正则化项:如PPO中,KL散度作为惩罚项,限制新策略 ( πθ\pi_{\theta}πθ ) 偏离参考策略 ( πref\pi_{\text{ref}}πref ) 过远。
    • 诊断工具:监控策略更新的稳定性。
    • 奖励函数:在RLHF中,KL散度常用于奖励函数(如 ( R=Rhuman−βKL[πθ∣∣πref]R = R_{\text{human}} - \beta KL[\pi_{\theta} || \pi_{\text{ref}}]R=RhumanβKL[πθ∣∣πref] )),以平衡偏好和策略稳定性。

这些场景中,( k1k_1k1 ) 的形式简单,且与策略梯度计算兼容,因此被广泛采用。

( k1k_1k1 ) 估计器的局限性在LLM RL中的体现

尽管 ( k1k_1k1 ) 常用,但正如Schulman的博客所指出的,它有高方差的缺点。这在LLM RL中可能表现为:

  • 方差导致的不稳定:由于 ( k1=log⁡πθ(a∣s)πref(a∣s)k_1 = \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)}k1=logπref(as)πθ(as) ) 可能为正或负(当 ( πθ(a∣s)>πref(a∣s)\pi_{\theta}(a|s) > \pi_{\text{ref}}(a|s)πθ(as)>πref(as) ) 时为正,反之为负),样本间的波动会导致KL散度估计不稳定,尤其在序列生成中,token序列较长时方差会累积。
  • 采样效率:LLM生成序列的采样成本高(需要运行整个模型),如果 ( k1k_1k1 ) 的高方差要求更多样本才能得到可靠估计,会增加计算开销。
  • 极端概率的影响:当 ( πθ(a∣s)\pi_{\theta}(a|s)πθ(as) ) 或 ( πref(a∣s)\pi_{\text{ref}}(a|s)πref(as) ) 接近0时,( log⁡πθ(a∣s)πref(a∣s)\log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)}logπref(as)πθ(as) ) 可能变得很大或很小,导致估计值的异常波动。

总结

是的,许多LLM强化学习论文中使用的KL散度估计确实是 ( k1k_1k1 ) 估计器(( log⁡πθπref\log \frac{\pi_{\theta}}{\pi_{\text{ref}}}logπrefπθ )),因为它简单、无偏且易于实现。然而,其高方差可能导致不稳定,尤其在长序列生成或概率差异较大的情况下。一些前沿工作(如DeepSeek GRPO)采用了Schulman提出的 ( k3k_3k3 ) 估计器,以获得更低的方差和更稳定的优化。如果你阅读的论文明确提到KL散度惩罚,通常默认是 ( k1k_1k1 ),但需要检查具体实现或损失函数是否使用了改进的估计器(如 ( k3k_3k3 ) 或其他变体)。

f-散度和k2

什么是f-散度?

f-散度(f-divergence)是一类用于衡量两个概率分布 ( ppp ) 和 ( qqq ) 之间差异的广义距离度量。它通过一个凸函数 ( fff ) 定义,形式如下:

Df(p,q)=Ex∼q[f(p(x)q(x))] D_f(p, q) = \mathbb{E}_{x \sim q} \left[ f\left( \frac{p(x)}{q(x)} \right) \right] Df(p,q)=Exq[f(q(x)p(x))]

其中:

  • ( p(x)p(x)p(x) ) 和 ( q(x)q(x)q(x) ) 是分布 ( ppp ) 和 ( qqq ) 在点 ( xxx ) 上的概率密度(或概率质量,视离散/连续分布而定)。
  • ( r=p(x)q(x)r = \frac{p(x)}{q(x)}r=q(x)p(x) ) 是概率比率。
  • ( f:R+→Rf: \mathbb{R}^+ \to \mathbb{R}f:R+R ) 是一个凸函数,且通常满足 ( f(1)=0f(1) = 0f(1)=0 ),以确保当 ( p=qp = qp=q ) 时散度为零。

f-散度的特点是它包含了许多常见的分布距离度量,通过选择不同的 ( fff ) 函数,可以得到不同的散度。例如:

  • KL散度:当 ( f(u)=ulog⁡uf(u) = u \log uf(u)=ulogu ),对应 ( Df(p,q)=KL(p∣∣q)=Eq[p(x)q(x)log⁡p(x)q(x)]D_f(p, q) = KL(p || q) = \mathbb{E}_{q} \left[ \frac{p(x)}{q(x)} \log \frac{p(x)}{q(x)} \right]Df(p,q)=KL(p∣∣q)=Eq[q(x)p(x)logq(x)p(x)] )。
  • 逆KL散度:当 ( f(u)=−log⁡uf(u) = -\log uf(u)=logu ),对应 ( Df(p,q)=KL(q∣∣p)D_f(p, q) = KL(q || p)Df(p,q)=KL(q∣∣p) )。
  • Jensen-Shannon散度:通过选择适当的 ( fff ) 函数,可以推导JS散度。
  • 总变差距离:当 ( f(u)=∣u−1∣f(u) = |u - 1|f(u)=u1∣ ),对应总变差。

f-散度的优点是它提供了一个统一的框架来研究不同分布距离的性质,且许多f-散度在 ( p≈qp \approx qpq ) 时具有相似的二阶行为(如Schulman博客中提到的与KL散度的关系)。

Schulman博客中的f-散度背景

在John Schulman的博客《Approximating KL Divergence》中,他讨论了如何通过蒙特卡洛方法近似KL散度 ( KL[q,p]KL[q, p]KL[q,p] ),并提出了一种低方差但有偏的估计器 ( k2k_2k2 )。他指出,( k2k_2k2 ) 的期望是一个f-散度,并且在 ( p≈qp \approx qpq ) 时,所有可微的f-散度在二阶近似下与KL散度等价。这种等价性是 ( k2k_2k2 ) 低偏倚的理论基础。

具体来说,Schulman定义了 ( k2k_2k2 ) 估计器为:

k2=12(log⁡p(x)q(x))2=12(log⁡r)2,其中r=p(x)q(x) k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k2=21(logq(x)p(x))2=21(logr)2,其中r=q(x)p(x)

他进一步说明,( k2k_2k2 ) 的期望是一个f-散度:

Ex∼q[k2]=Ex∼q[12(log⁡r)2] \mathbb{E}_{x \sim q}[k_2] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} (\log r)^2 \right] Exq[k2]=Exq[21(logr)2]

这对应于f-散度中的凸函数:

f(u)=12(log⁡u)2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2

为了理解 ( k2k_2k2 ) 的来源和为何它是一个合理的KL散度估计器,我们需要结合f-散度的性质和Schulman的推导思路进行分析。


推导 ( k2k_2k2 ) 的来源

以下是结合Schulman博客内容,推导 ( k2k_2k2 ) 如何作为KL散度估计器的过程:

1. KL散度的蒙特卡洛估计

KL散度 ( KL[q,p]KL[q, p]KL[q,p] ) 的定义为:

KL[q,p]=Ex∼q[log⁡q(x)p(x)]=Ex∼q[−log⁡r],r=p(x)q(x) KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \mathbb{E}_{x \sim q} [-\log r], \quad r = \frac{p(x)}{q(x)} KL[q,p]=Exq[logp(x)q(x)]=Exq[logr],r=q(x)p(x)

最直接的蒙特卡洛估计器是 ( k1k_1k1 ):

k1=log⁡q(x)p(x)=−log⁡r k_1 = \log \frac{q(x)}{p(x)} = -\log r k1=logp(x)q(x)=logr

( k1k_1k1 ) 是无偏的,但由于 ( log⁡r\log rlogr ) 可正可负(( r>1r > 1r>1 ) 时 ( log⁡r>0\log r > 0logr>0 ),( r<1r < 1r<1 ) 时 ( log⁡r<0\log r < 0logr<0 )),其方差较高,而KL散度本身始终非负。这种高方差促使Schulman寻找更稳定的估计器。

2. 构造低方差估计器的动机

Schulman的目标是构造一个估计器,既能反映 ( ppp ) 和 ( qqq ) 之间的差异,又具有较低的方差。一个直观的思路是:既然KL散度衡量分布差异,可以尝试用某种“距离”函数来替代 ( log⁡q(x)p(x)\log \frac{q(x)}{p(x)}logp(x)q(x) ),使得估计值始终非负(与KL散度的性质一致),并且方差较小。

考虑 ( log⁡r=log⁡p(x)q(x)\log r = \log \frac{p(x)}{q(x)}logr=logq(x)p(x)),它直接衡量了概率比率的对数差异。如果我们取其平方:

(log⁡r)2=(log⁡p(x)q(x))2 (\log r)^2 = \left( \log \frac{p(x)}{q(x)} \right)^2 (logr)2=(logq(x)p(x))2

这个量有以下优点:

  • 非负性:( (log⁡r)2≥0(\log r)^2 \geq 0(logr)20 ),与KL散度的非负性质一致。
  • 对称性:它反映了 ( log⁡r\log rlogr ) 的大小,无论 ( log⁡r\log rlogr ) 是正还是负,都表示 ( ppp ) 和 ( qqq ) 的差异。
  • 平滑性:平方函数会放大较大的 ( log⁡r\log rlogr ) 值,强调分布差异较大的样本。

为了使量纲合理(KL散度的量纲类似于对数),Schulman引入一个系数 ( 12\frac{1}{2}21 ),定义:

k2=12(log⁡r)2=12(log⁡p(x)q(x))2 k_2 = \frac{1}{2} (\log r)^2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 k2=21(logr)2=21(logq(x)p(x))2

3. ( k2k_2k2 ) 的期望是f-散度

计算 ( k2k_2k2 ) 的期望:

Ex∼q[k2]=Ex∼q[12(log⁡p(x)q(x))2]=Ex∼q[12(log⁡r)2] \mathbb{E}_{x \sim q}[k_2] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 \right] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} (\log r)^2 \right] Exq[k2]=Exq[21(logq(x)p(x))2]=Exq[21(logr)2]

这正是f-散度的形式,其中:

f(u)=12(log⁡u)2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2

我们需要验证 ( f(u)f(u)f(u) ) 是否满足f-散度的要求:

  • 凸性:计算 ( f(u)f(u)f(u) ) 的二阶导数:

f(u)=12(log⁡u)2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2

f′(u)=12⋅2log⁡u⋅1u=log⁡uu f'(u) = \frac{1}{2} \cdot 2 \log u \cdot \frac{1}{u} = \frac{\log u}{u} f(u)=212loguu1=ulogu

f′′(u)=1u⋅1u+log⁡uu2=1u2+log⁡uu2=1+log⁡uu2 f''(u) = \frac{1}{u} \cdot \frac{1}{u} + \frac{\log u}{u^2} = \frac{1}{u^2} + \frac{\log u}{u^2} = \frac{1 + \log u}{u^2} f′′(u)=u1u1+u2logu=u21+u2logu=u21+logu

由于 ( u>0u > 0u>0 ),( f′′(u)f''(u)f′′(u) ) 不总是正(例如当 ( u<e−1≈0.367u < e^{-1} \approx 0.367u<e10.367 ) 时,( 1+log⁡u<01 + \log u < 01+logu<0 )),但在实际应用中,( r=p(x)q(x)r = \frac{p(x)}{q(x)}r=q(x)p(x) ) 通常在 ( (0,∞)(0, \infty)(0,) ) 范围内变化,且在 ( p≈qp \approx qpq ) 时 ( r≈1r \approx 1r1 ),此时 ( f′′(1)=1>0f''(1) = 1 > 0f′′(1)=1>0 )。因此,在 ( r≈1r \approx 1r1 ) 的局部区域,( f(u)f(u)f(u) ) 是凸的,满足f-散度的要求。

  • ( f(1)=0f(1) = 0f(1)=0 )

f(1)=12(log⁡1)2=0 f(1) = \frac{1}{2} (\log 1)^2 = 0 f(1)=21(log1)2=0

这保证当 ( p=qp = qp=q )(即 ( r=1r = 1r=1 ))时,散度为零。

因此,( Ex∼q[k2]\mathbb{E}_{x \sim q}[k_2]Exq[k2] ) 是一个有效的f-散度。

4. 为何 ( k2k_2k2 ) 是KL散度的合理近似?

Schulman指出,所有可微的f-散度在 ( p≈qp \approx qpq ) 时,在二阶近似下与KL散度等价。让我们推导这一关键结论。

假设 ( pθp_{\theta}pθ ) 是一个参数化的分布,( pθ0=p0p_{\theta_0} = p_0pθ0=p0 ),我们考虑 ( Df(p0,pθ)D_f(p_0, p_{\theta})Df(p0,pθ) ) 在 ( θ≈θ0\theta \approx \theta_0θθ0 ) 时的泰勒展开。Schulman给出了近似:

Df(p0,pθ)=f′′(1)2θTFθ+O(θ3) D_f(p_0, p_{\theta}) = \frac{f''(1)}{2} \theta^T F \theta + O(\theta^3) Df(p0,pθ)=2f′′(1)θT+O(θ3)

其中 ( FFF ) 是 ( pθp_{\theta}pθ ) 在 ( θ=θ0\theta = \theta_0θ=θ0 ) 处的Fisher信息矩阵。关键在于 ( f′′(1)f''(1)f′′(1) )。我们计算 ( f(u)=12(log⁡u)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 ) 的二阶导数:

f′′(u)=1+log⁡uu2,f′′(1)=1+log⁡112=1 f''(u) = \frac{1 + \log u}{u^2}, \quad f''(1) = \frac{1 + \log 1}{1^2} = 1 f′′(u)=u21+logu,f′′(1)=121+log1=1

对于KL散度 ( KL[q,p]KL[q, p]KL[q,p] ),对应的f-散度函数是 ( f(u)=−log⁡uf(u) = -\log uf(u)=logu ):

f′(u)=−1u,f′′(u)=1u2,f′′(1)=1 f'(u) = -\frac{1}{u}, \quad f''(u) = \frac{1}{u^2}, \quad f''(1) = 1 f(u)=u1,f′′(u)=u21,f′′(1)=1

可以看到,( f′′(1)=1f''(1) = 1f′′(1)=1 ) 对于 ( k2k_2k2 ) 的 ( f(u)=12(log⁡u)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 ) 和KL散度的 ( f(u)=−log⁡uf(u) = -\log uf(u)=logu ) 都是相同的。因此,在 ( p≈qp \approx qpq ) 时,( Ex∼q[k2]\mathbb{E}_{x \sim q}[k_2]Exq[k2] ) 和 ( KL[q,p]KL[q, p]KL[q,p] ) 的二阶行为一致:

Ex∼q[k2]≈12θTFθ≈KL[q,p] \mathbb{E}_{x \sim q}[k_2] \approx \frac{1}{2} \theta^T F \theta \approx KL[q, p] Exq[k2]21θTKL[q,p]

这解释了为何 ( k2k_2k2 ) 虽然有偏(( E[k2]≠KL[q,p]\mathbb{E}[k_2] \neq KL[q, p]E[k2]=KL[q,p] )),但偏倚很小,尤其当 ( ppp ) 和 ( qqq ) 接近时。

5. ( k2k_2k2 ) 的直观解释

( k2=12(log⁡p(x)q(x))2k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2k2=21(logq(x)p(x))2 ) 的设计灵感可以理解为:

  • 惩罚对数差异:( log⁡p(x)q(x)\log \frac{p(x)}{q(x)}logq(x)p(x) ) 衡量了 ( ppp ) 和 ( qqq ) 在对数尺度上的相对差异,平方后强调了较大的差异。
  • 低方差:由于 ( (log⁡r)2≥0(\log r)^2 \geq 0(logr)20 ),每个样本的贡献都是非负的,避免了 ( k1=−log⁡rk_1 = -\log rk1=logr ) 中正负抵消导致的高方差。
  • f-散度的自然形式:平方形式 ( f(u)=12(log⁡u)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 ) 是一个自然的凸函数选择,且在 ( r≈1r \approx 1r1 ) 时与KL散度行为相似。

Schulman通过实验验证了 ( k2k_2k2 ) 的低方差和低偏倚。例如,当 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) )、( p=N(0.1,1)p = \mathcal{N}(0.1, 1)p=N(0.1,1) )(真实KL散度为0.005)时,( k2k_2k2 ) 的偏倚仅为0.2%,标准差远低于 ( k1k_1k1 )。


结合Schulman博客的总结

Schulman的博客通过引入 ( k2=12(log⁡p(x)q(x))2k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2k2=21(logq(x)p(x))2 ) 提供了一种低方差的KL散度估计器,其期望是一个f-散度,对应的凸函数为 ( f(u)=12(log⁡u)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 )。推导过程如下:

  1. 从KL散度的蒙特卡洛估计出发,识别 ( k1=log⁡q(x)p(x)k_1 = \log \frac{q(x)}{p(x)}k1=logp(x)q(x) ) 的高方差问题。
  2. 构造 ( k2=12(log⁡r)2k_2 = \frac{1}{2} (\log r)^2k2=21(logr)2 ),使其非负且方差较低。
  3. 证明 ( E[k2]\mathbb{E}[k_2]E[k2] ) 是一个f-散度,且在 ( p≈qp \approx qpq ) 时与KL散度在二阶近似下等价(因为 ( f′′(1)=1f''(1) = 1f′′(1)=1 ))。
  4. 通过实验验证 ( k2k_2k2 ) 的低偏倚和低方差。

f-散度的框架为 ( k2k_2k2 ) 的合理性提供了理论支持,因为它表明不同凸函数 ( fff ) 定义的散度在 ( p≈qp \approx qpq ) 时具有相似的行为。这也为后续 ( k3k_3k3 ) 等无偏估计器的设计提供了灵感(如通过控制变量进一步优化)。


k3的思考

介绍 ( k3k_3k3 ) 的来源及推导

在John Schulman的博客《Approximating KL Divergence》中,他提出了三种KL散度估计器:( k1k_1k1 )、( k2k_2k2 ) 和 ( k3k_3k3 )。我们已经讨论了 ( k1=log⁡q(x)p(x)=−log⁡rk_1 = \log \frac{q(x)}{p(x)} = -\log rk1=logp(x)q(x)=logr )(无偏但高方差)和 ( k2=12(log⁡p(x)q(x))2=12(log⁡r)2k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2k2=21(logq(x)p(x))2=21(logr)2 )(低方差但有偏)。( k3k_3k3 ) 是Schulman提出的第三个估计器,目标是兼顾 无偏性低方差,定义为:

k3=(r−1)−log⁡r,其中r=p(x)q(x) k_3 = \left( r - 1 \right) - \log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k3=(r1)logr,其中r=q(x)p(x)

本节将结合Schulman的博客内容,详细推导 ( k3k_3k3 ) 的来源,分析其与 ( k2k_2k2 ) 推导的联系,并解释其理论基础和优势。


推导 ( k3k_3k3 ) 的来源

1. 从 ( k1k_1k1 ) 的高方差出发

KL散度 ( KL[q,p]KL[q, p]KL[q,p] ) 的定义为:

KL[q,p]=Ex∼q[log⁡q(x)p(x)]=Ex∼q[−log⁡r],r=p(x)q(x) KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \mathbb{E}_{x \sim q} [-\log r], \quad r = \frac{p(x)}{q(x)} KL[q,p]=Exq[logp(x)q(x)]=Exq[logr],r=q(x)p(x)

标准蒙特卡洛估计器 ( k1=−log⁡rk_1 = -\log rk1=logr ) 是无偏的,但由于 ( log⁡r\log rlogr ) 可正可负(( r>1r > 1r>1 ) 时 ( log⁡r>0\log r > 0logr>0 ),( r<1r < 1r<1 ) 时 ( log⁡r<0\log r < 0logr<0 )),其方差较高。Schulman的目标是构造一个估计器,既保留 ( k1k_1k1 ) 的无偏性,又降低方差。

2. 控制变量方法的引入

为了降低 ( k1k_1k1 ) 的方差,Schulman引入了 控制变量(control variate) 方法。控制变量方法是一种常用的方差缩减技术,核心思想是:在估计器中加入一个期望为零的项,这个项与原估计器负相关,从而抵消部分波动。

在KL散度的估计中,我们需要找到一个量 ( g(x)g(x)g(x) ),满足:

Ex∼q[g(x)]=0 \mathbb{E}_{x \sim q} [g(x)] = 0 Exq[g(x)]=0

然后构造新的估计器:

knew=k1+λg(x) k_{\text{new}} = k_1 + \lambda g(x) knew=k1+λg(x)

只要 ( Ex∼q[g(x)]=0\mathbb{E}_{x \sim q} [g(x)] = 0Exq[g(x)]=0 ),新估计器仍然无偏:

Ex∼q[knew]=Ex∼q[k1]+λEx∼q[g(x)]=KL[q,p]+0=KL[q,p] \mathbb{E}_{x \sim q} [k_{\text{new}}] = \mathbb{E}_{x \sim q} [k_1] + \lambda \mathbb{E}_{x \sim q} [g(x)] = KL[q, p] + 0 = KL[q, p] Exq[knew]=Exq[k1]+λExq[g(x)]=KL[q,p]+0=KL[q,p]

Schulman注意到,一个自然的控制变量是:

g(x)=r−1=p(x)q(x)−1 g(x) = r - 1 = \frac{p(x)}{q(x)} - 1 g(x)=r1=q(x)p(x)1

其期望为:

Ex∼q[r−1]=Ex∼q[p(x)q(x)]−1=∫q(x)⋅p(x)q(x) dx−1=∫p(x) dx−1=1−1=0 \mathbb{E}_{x \sim q} [r - 1] = \mathbb{E}_{x \sim q} \left[ \frac{p(x)}{q(x)} \right] - 1 = \int q(x) \cdot \frac{p(x)}{q(x)} \, dx - 1 = \int p(x) \, dx - 1 = 1 - 1 = 0 Exq[r1]=Exq[q(x)p(x)]1=q(x)q(x)p(x)dx1=p(x)dx1=11=0

因此,( r−1r - 1r1 ) 是一个期望为零的量,适合作为控制变量。于是,Schulman构造了新的估计器:

knew=k1+λ(r−1)=−log⁡r+λ(r−1) k_{\text{new}} = k_1 + \lambda (r - 1) = -\log r + \lambda (r - 1) knew=k1+λ(r1)=logr+λ(r1)

这里的 ( λ\lambdaλ ) 是一个可调参数,目标是通过选择合适的 ( λ\lambdaλ ) 最小化估计器的方差。

3. 选择 ( λ\lambdaλ ) 以保证非负性

理论上,可以通过最小化 ( knewk_{\text{new}}knew ) 的方差来求解最优的 ( λ\lambdaλ )。方差为:

Var(knew)=Var(−log⁡r+λ(r−1))=Var(−log⁡r)+λ2Var(r−1)+2λCov(−log⁡r,r−1) \text{Var}(k_{\text{new}}) = \text{Var}(-\log r + \lambda (r - 1)) = \text{Var}(-\log r) + \lambda^2 \text{Var}(r - 1) + 2 \lambda \text{Cov}(-\log r, r - 1) Var(knew)=Var(logr+λ(r1))=Var(logr)+λ2Var(r1)+2λCov(logr,r1)

最小化方差需要计算协方差和方差项,但这些项依赖于 ( ppp ) 和 ( qqq ) 的具体形式,解析求解往往复杂。Schulman提到,这种方法得到的 ( λ\lambdaλ ) 通常难以计算,因此他采用了更简单的策略:选择 ( λ\lambdaλ ) 使估计器 始终非负,从而直观上降低方差并与KL散度的非负性质一致。

Schulman利用对数函数的凹性,注意到:

log⁡u≤u−1,∀u>0 \log u \leq u - 1, \quad \forall u > 0 loguu1,u>0

等号在 ( u=1u = 1u=1 ) 时成立。这意味着:

−log⁡r≥−(r−1) -\log r \geq -(r - 1) logr(r1)

因此,如果选择 ( λ=1\lambda = 1λ=1 ),估计器变为:

k3=−log⁡r+(r−1)=(r−1)−log⁡r k_3 = -\log r + (r - 1) = (r - 1) - \log r k3=logr+(r1)=(r1)logr

我们可以验证 ( k3k_3k3 ) 的非负性:

k3=(r−1)−log⁡r k_3 = (r - 1) - \log r k3=(r1)logr

根据凹性不等式 ( log⁡r≤r−1\log r \leq r - 1logrr1 ),有:

(r−1)−log⁡r≥(r−1)−(r−1)=0 (r - 1) - \log r \geq (r - 1) - (r - 1) = 0 (r1)logr(r1)(r1)=0

因此,( k3≥0k_3 \geq 0k30 ),并且在 ( r=1r = 1r=1 )(即 ( p(x)=q(x)p(x) = q(x)p(x)=q(x) ))时,( k3=0k_3 = 0k3=0 )。这种非负性使得 ( k3k_3k3 ) 的样本值更稳定,避免了 ( k1k_1k1 ) 中正负抵消导致的高方差。

4. 验证 ( k3k_3k3 ) 的无偏性

我们需要确认 ( k3k_3k3 ) 是否无偏:

Ex∼q[k3]=Ex∼q[(r−1)−log⁡r]=Ex∼q[r−1]−Ex∼q[log⁡r] \mathbb{E}_{x \sim q} [k_3] = \mathbb{E}_{x \sim q} [(r - 1) - \log r] = \mathbb{E}_{x \sim q} [r - 1] - \mathbb{E}_{x \sim q} [\log r] Exq[k3]=Exq[(r1)logr]=Exq[r1]Exq[logr]

  • 第一项:( Ex∼q[r−1]=0\mathbb{E}_{x \sim q} [r - 1] = 0Exq[r1]=0 )(如前所述)。
  • 第二项:( Ex∼q[log⁡r]=Ex∼q[log⁡p(x)q(x)]=Ex∼q[log⁡q(x)p(x)]=KL[q,p]\mathbb{E}_{x \sim q} [\log r] = \mathbb{E}_{x \sim q} \left[ \log \frac{p(x)}{q(x)} \right] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = KL[q, p]Exq[logr]=Exq[logq(x)p(x)]=Exq[logp(x)q(x)]=KL[q,p] ).

因此:

Ex∼q[k3]=0−(−KL[q,p])=KL[q,p] \mathbb{E}_{x \sim q} [k_3] = 0 - (-KL[q, p]) = KL[q, p] Exq[k3]=0(KL[q,p])=KL[q,p]

( k3k_3k3 ) 是无偏的,满足我们对KL散度估计器的要求。

5. ( k3k_3k3 ) 与Bregman散度的联系

Schulman指出,( k3=(r−1)−log⁡rk_3 = (r - 1) - \log rk3=(r1)logr ) 的形式可以解释为一个 Bregman散度。Bregman散度衡量一个凸函数与其在某点的切线之间的垂直距离。对于凸函数 ( f(u)=−log⁡uf(u) = -\log uf(u)=logu ),其导数为:

f′(u)=−1u f'(u) = -\frac{1}{u} f(u)=u1

在点 ( u=1u = 1u=1 ) 处的切线为:

f(1)+f′(1)(u−1)=−log⁡1−11(u−1)=−(u−1) f(1) + f'(1)(u - 1) = -\log 1 - \frac{1}{1}(u - 1) = -(u - 1) f(1)+f(1)(u1)=log111(u1)=(u1)

Bregman散度定义为:

Df(u,1)=f(u)−f(1)−f′(1)(u−1) D_f(u, 1) = f(u) - f(1) - f'(1)(u - 1) Df(u,1)=f(u)f(1)f(1)(u1)

代入 ( f(u)=−log⁡uf(u) = -\log uf(u)=logu ):

Df(r,1)=(−log⁡r)−(−log⁡1)−(−11)(r−1)=−log⁡r−0+(r−1)=(r−1)−log⁡r D_f(r, 1) = (-\log r) - (-\log 1) - \left(-\frac{1}{1}\right)(r - 1) = -\log r - 0 + (r - 1) = (r - 1) - \log r Df(r,1)=(logr)(log1)(11)(r1)=logr0+(r1)=(r1)logr

这正是 ( k3k_3k3 ) 的形式!因此,( k3k_3k3 ) 可以看作 ( f(u)=−log⁡uf(u) = -\log uf(u)=logu ) 在 ( r=1r = 1r=1 ) 处的Bregman散度。这种解释提供了理论上的优雅性:( k3k_3k3 ) 不仅无偏,还通过Bregman散度的非负性自然降低了方差。

6. 实验验证

Schulman通过实验比较了 ( k1k_1k1 )、( k2k_2k2 ) 和 ( k3k_3k3 )。例如,当 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) )、( p=N(0.1,1)p = \mathcal{N}(0.1, 1)p=N(0.1,1) )(真实KL散度为0.005)时:

  • ( k1k_1k1 ): 无偏,标准差为真实值的20倍。
  • ( k2k_2k2 ): 偏倚为0.2%,标准差为真实值的1.42倍。
  • ( k3k_3k3 ): 无偏,标准差为真实值的1.42倍。

当KL散度较大(如 ( p=N(1,1)p = \mathcal{N}(1, 1)p=N(1,1) ),真实KL散度为0.5)时:

  • ( k3k_3k3 ): 无偏,标准差为真实值的1.7倍,优于 ( k2k_2k2 ) 的1.73倍和25%偏倚。

这些结果表明,(k3k_3k3 ) 成功结合了无偏性和低方差。


( k3k_3k3 ) 与 ( k2k_2k2 ) 推导的联系

虽然 ( k2k_2k2 ) 和 ( k3k_3k3 ) 的最终形式不同,但它们的推导和思考过程有以下联系:

  1. 共同目标:降低方差

    • ( k2k_2k2 ) 通过平方 ( log⁡r\log rlogr )(即 (12(log⁡r)2\frac{1}{2} (\log r)^221(logr)2 ))使估计值非负,强调分布差异并降低方差。
    • ( k3k_3k3 ) 通过控制变量 ( r−1r - 1r1 ) 和利用 ( log⁡r≤r−1\log r \leq r - 1logrr1 ) 的凹性不等式,同样追求非负性和低方差。
    • 两者都试图解决 ( k1=−log⁡rk_1 = -\log rk1=logr ) 的高方差问题,直观上希望估计器与KL散度的非负性质一致。
  2. f-散度与Bregman散度的理论支持

    • ( k2k_2k2 ) 的期望是一个f-散度(( f(u)=12(log⁡u)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 )),在 ( p≈qp \approx qpq ) 时通过二阶近似(( f′′(1)=1f''(1) = 1f′′(1)=1 ))与KL散度等价。
    • ( k3k_3k3 ) 直接是 ( KL[q,p]KL[q, p]KL[q,p] ) 的无偏估计器,且其形式对应于 ( f(u)=−log⁡uf(u) = -\log uf(u)=logu ) 的Bregman散度。Bregman散度可以看作f-散度的一种特殊形式(当考虑凸函数与其切线的距离时)。
    • 两者都利用了凸函数的性质(( k2k_2k2 ) 通过凸函数 ( f(u)f(u)f(u)),(k3k_3k3 ) 通过凹函数 ( −log⁡u-\log ulogu ) 的Bregman散度),体现了Schulman从广义散度框架出发的统一思路。
  3. 从有偏到无偏的改进

    • ( k2k_2k2 ) 是Schulman的初步尝试,牺牲了无偏性换取低方差,但通过f-散度的二阶等价性保证了低偏倚。
    • ( k3k_3k3 ) 是对 ( k2k_2k2 ) 的进一步改进,通过控制变量方法恢复了无偏性,同时保留了低方差的优点。( k3k_3k3 ) 的非负性(通过 ( log⁡r≤r−1\log r \leq r - 1logrr1 ))与 ( k2k_2k2 ) 的非负性(通过平方)有相似的直观动机。
  4. 设计思路的演进

    • ( k2k_2k2 ) 的设计更像是一种启发式尝试:用平方函数放大差异并确保非负,然后通过f-散度理论验证其合理性。
    • ( k3k_3k3 ) 的设计更有系统性,基于控制变量的统计方法,并通过Bregman散度的数学结构提供理论支撑。Schulman在推导 ( k3k_3k3 ) 时明确考虑了如何在 ( k1k_1k1 ) 的基础上系统性地降低方差,而 ( k2k_2k2 ) 更像是探索过程中的一个中间产物。
  5. 推广到其他散度

    • Schulman在博客中提到,( k2k_2k2 ) 和 ( k3k_3k3 ) 的思路可以推广到其他f-散度。例如,( k3k_3k3 ) 的形式 ( f(r)−f′(1)(r−1)f(r) - f'(1)(r - 1)f(r)f(1)(r1) ) 是一个通用的f-散度估计器,适用于任意凸函数 (fff )。
    • ( k2k_2k2 ) 的f-散度形式为其他估计器的设计提供了灵感,而 ( k3k_3k3 ) 的Bregman散度形式进一步揭示了这种估计器与凸优化理论的深层联系。

总结

( k3=(r−1)−log⁡rk_3 = (r - 1) - \log rk3=(r1)logr ) 的来源是通过控制变量方法改进 ( k1k_1k1 ),利用 ( r−1r - 1r1 ) 的零期望和 ( log⁡r≤r−1\log r \leq r - 1logrr1 ) 的凹性不等式,构造了一个无偏且低方差的KL散度估计器。其推导过程如下:

  1. 从 ( k1=−log⁡rk_1 = -\log rk1=logr ) 的高方差问题出发,引入控制变量 ( r−1r - 1r1 )。
  2. 构造新估计器 ( k3=−log⁡r+λ(r−1)k_3 = -\log r + \lambda (r - 1)k3=logr+λ(r1) ),选择 ( λ=1\lambda = 1λ=1 ) 保证非负性。
  3. 验证 ( k3k_3k3 ) 的无偏性和非负性,并通过Bregman散度提供理论解释。
  4. 实验确认 ( k3k_3k3 ) 兼顾无偏和低方差,优于 ( k2k_2k2 )。

与 ( k2k_2k2 ) 的联系在于:

  • 两者都旨在降低 ( k1k_1k1 ) 的方差,追求非负性以匹配KL散度的性质。
  • ( k2k_2k2 ) 通过f-散度框架启发了对分布差异的度量,( k3k_3k3 ) 则通过控制变量和Bregman散度实现了更系统的优化。
  • ( k3k_3k3 ) 可以看作 ( k2k_2k2 ) 思路的延伸,从有偏但低方差的探索(( k2k_2k2 ))进化到无偏且低方差的解决方案(( k3k_3k3 ))。

这种推导思路不仅在理论上优雅,还在DeepSeek GRPO等强化学习应用中得到了实践验证,展现了其在LLM优化中的重要价值。


后记

2025年4月20日于上海,在grok 3大模型辅助下完成。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐