KL散度近似方法介绍:从John Schulman的博客到DeepSeek GRPO的应用
John Schulman在其2020年3月7日的博客中详细探讨了如何通过蒙特卡洛方法近似KL散度,并提出了一种低方差、无偏的估计器。这一方法不仅在理论上具有重要意义,还被DeepSeek的GRPO算法所采用。
KL散度近似方法介绍:从John Schulman的博客到DeepSeek GRPO的应用
KL散度(Kullback-Leibler Divergence)是衡量两个概率分布之间差异的重要指标,广泛应用于机器学习、强化学习等领域。John Schulman在其2020年3月7日的博客中详细探讨了如何通过蒙特卡洛方法近似KL散度,并提出了一种低方差、无偏的估计器。这一方法不仅在理论上具有重要意义,还被DeepSeek的GRPO算法所采用。本文将基于Schulman的博客内容,介绍KL散度的近似方法及其在DeepSeek GRPO中的应用。
blog:http://joschu.net/blog/kl-approx.html
首先看一下John Schulman是谁:PPO的作者,领导创建了ChatGPT,OpenAI的cofounder。

source:http://joschu.net/index.html

source:https://scholar.google.com/citations?user=itSa94cAAAAJ
KL散度的基本定义
KL散度衡量两个概率分布 ( ppp ) 和 ( qqq ) 之间的差异,定义为:
KL[q,p]=∑xq(x)logq(x)p(x)=Ex∼q[logq(x)p(x)] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] KL[q,p]=x∑q(x)logp(x)q(x)=Ex∼q[logp(x)q(x)]
在实际应用中,精确计算KL散度通常不可行,因为:
- 计算所有 ( xxx ) 的概率和需要过多的计算或内存。
- 分布 ( ppp ) 和 ( qqq ) 可能没有闭合表达式。
- 在强化学习等场景中,KL散度常作为诊断工具,仅存储对数概率以简化代码。
因此,蒙特卡洛方法成为估计KL散度的常用策略。假设我们从分布 ( qqq ) 中采样 ( x1,x2,…x_1, x_2, \dotsx1,x2,… ),如何构造一个既无偏又低方差的估计器是关键问题。
传统估计器及其局限性
最直接的蒙特卡洛估计器是基于KL散度的定义:
k1=logq(x)p(x)=−logr,其中r=p(x)q(x) k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k1=logp(x)q(x)=−logr,其中r=q(x)p(x)
这个估计器 ( k1k_1k1 ) 是无偏的,即其期望等于真实的KL散度。然而,由于 ( logr\log rlogr ) 的值在正负之间变化(当 ( r>1r > 1r>1 ) 时为正,当 ( r<1r < 1r<1 ) 时为负),其方差较高,而KL散度本身始终为正。这种高方差使得 ( k1k_1k1 ) 在实际应用中表现不佳。
低方差的偏倚估计器
Schulman提出了一种替代估计器:
k2=12(logp(x)q(x))2=12(logr)2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2 k2=21(logq(x)p(x))2=21(logr)2
这个估计器 ( k2k_2k2 ) 虽然有偏,但方差显著低于 ( k1k_1k1 )。其优点在于:
- 始终为正:每个样本都反映了 ( ppp ) 和 ( qqq ) 之间的差异,且结果非负,与KL散度的性质一致。
- 低偏倚:( k2k_2k2 ) 的期望是一个f-散度,其形式为:
Eq[k2]=Eq[12(logr)2] \mathbb{E}_q[k_2] = \mathbb{E}_q \left[ \frac{1}{2} (\log r)^2 \right] Eq[k2]=Eq[21(logr)2]
这是一个以 ( f(x)=12(logx)2f(x) = \frac{1}{2} (\log x)^2f(x)=21(logx)2 ) 定义的f-散度。当 ( ppp ) 和 ( qqq ) 接近时,所有可微的f-散度在二阶近似下与KL散度等价,因此 ( k2k_2k2 ) 的偏倚非常小。
实验表明,当 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) )、( p=N(0.1,1)p = \mathcal{N}(0.1, 1)p=N(0.1,1) ) 时(真实KL散度为0.005),( k2k_2k2 ) 的偏倚仅为0.2%,标准差为真实值的1.42倍,远低于 ( k1k_1k1 ) 的20倍。
无偏且低方差的估计器
为了兼顾无偏和低方差,Schulman引入了控制变量(control variate)方法。利用 ( Eq[r−1]=0\mathbb{E}_q[r - 1] = 0Eq[r−1]=0 ),可以构造一个新的估计器:
k3=−logr+λ(r−1) k_3 = -\log r + \lambda (r - 1) k3=−logr+λ(r−1)
通过选择适当的 ( λ\lambdaλ ),可以降低方差。当 ( λ=1\lambda = 1λ=1 ) 时,估计器变为:
k3=(r−1)−logr k_3 = (r - 1) - \log r k3=(r−1)−logr
由于对数的凹性,( logr≤r−1\log r \leq r - 1logr≤r−1 ),因此 ( k3k_3k3) 始终为正。这个估计器不仅无偏,而且方差低于 ( k1k_1k1 )。实验表明,当真实KL散度为0.5时,( k3k_3k3 ) 的标准差为真实值的1.7倍,低于 ( k2k_2k2 ) 的1.73倍,且无偏。
推广到其他f-散度
Schulman进一步指出,上述方法可以推广到任意f-散度。对于 ( KL[p,q]KL[p, q]KL[p,q] ),对应的f-散度估计器为:
rlogr−(r−1) r \log r - (r - 1) rlogr−(r−1)
这个估计器同样基于凸函数与其切线的距离(Bregman散度),保证了非负性和低方差。
DeepSeek GRPO中的应用
DeepSeek的GRPO算法直接采用了Schulman提出的无偏估计器 ( k3k_3k3 )。具体来说,GRPO使用以下估计器来近似 ( DKL(πθ∣∣πref)D_{KL}(\pi_\theta || \pi_{ref})DKL(πθ∣∣πref) ):
DKL(πθ∣∣πref)=πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−logπref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)−1 D_{KL}(\pi_\theta || \pi_{ref}) = \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} - \log \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} - 1 DKL(πθ∣∣πref)=πθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−logπθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−1
其中,( πθ\pi_\thetaπθ ) 表示策略网络,( πref\pi_{ref}πref ) 表示参考策略,( oi,to_{i,t}oi,t ) 表示在时间 ( ttt ) 的观测,( qqq ) 和 ( oi,<to_{i,<t}oi,<t ) 分别表示上下文和历史观测。令 ( r=πref(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)r = \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})}r=πθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t) ),该估计器可重写为:
k3=(r−1)−logr k_3 = (r - 1) - \log r k3=(r−1)−logr
这个估计器的优点在于:
- 无偏性:保证估计的期望等于真实的KL散度。
- 非负性:由于 ( logr≤r−1\log r \leq r - 1logr≤r−1 ),估计值始终为正,与KL散度的性质一致。
- 低方差:通过控制变量降低了估计的方差,提升了算法的稳定性。
与传统的KL惩罚项(例如直接使用 (logπθπref\log \frac{\pi_\theta}{\pi_{ref}}logπrefπθ))相比,GRPO的估计器在强化学习中表现更稳定,尤其是在策略优化需要精确控制分布差异时。
实验结果
Schulman的实验验证了 ( k3k_3k3 ) 的优越性。例如,当 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) )、( p=N(1,1)p = \mathcal{N}(1, 1)p=N(1,1) )(真实KL散度为0.5)时:
- ( k1k_1k1 ): 无偏,标准差为真实值的2倍。
- ( k2k_2k2 ): 偏倚为25%,标准差为真实值的1.73倍。
- ( k3k_3k3 ): 无偏,标准差为真实值的1.7倍。
这些结果表明,( k3k_3k3 ) 在保持无偏的同时,显著降低了方差。
总结
John Schulman的博客提供了一种优雅的KL散度近似方法,通过引入控制变量构造了无偏且低方差的估计器 ( k3=(r−1)−logrk_3 = (r - 1) - \log rk3=(r−1)−logr )。这一方法不仅在理论上具有普适性(可推广到任意f-散度),还在DeepSeek的GRPO算法中得到了实际应用。GRPO利用这一估计器实现了稳定的策略优化,展示了其在强化学习中的重要价值。对于需要在高维分布上估计KL散度的研究者和工程师来说,Schulman的方法提供了一个兼顾理论严谨性和实践效果的解决方案。
蒙特卡洛估计器
什么是蒙特卡洛估计器?
蒙特卡洛方法(Monte Carlo Method)是一种通过随机采样来估计复杂数学量的方法。它特别适合用来近似那些难以直接计算的积分、期望或求和,尤其是在高维空间或解析解不可得的情况下。蒙特卡洛估计器的核心思想是:通过从某个分布中抽取大量随机样本,计算这些样本上的函数值,然后用样本均值来近似目标量的期望。
在KL散度的背景下,KL散度 ( KL[q,p]KL[q, p]KL[q,p] ) 定义为:
KL[q,p]=∑xq(x)logq(x)p(x)=Ex∼q[logq(x)p(x)] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] KL[q,p]=x∑q(x)logp(x)q(x)=Ex∼q[logp(x)q(x)]
这里的 ( Ex∼q\mathbb{E}_{x \sim q}Ex∼q ) 表示从分布 ( qqq ) 中采样的期望。如果我们无法直接计算整个求和(例如,分布 ( qqq ) 和 ( ppp ) 是高维的,或者没有闭合表达式),就可以用蒙特卡洛方法来估计这个期望。具体来说:
- 从分布 ( qqq ) 中抽取 ( NNN ) 个独立样本 ( x1,x2,…,xNx_1, x_2, \dots, x_Nx1,x2,…,xN )。
- 对每个样本 ( xix_ixi ),计算函数 ( f(xi)=logq(xi)p(xi)f(x_i) = \log \frac{q(x_i)}{p(x_i)}f(xi)=logp(xi)q(xi) )。
- 用样本均值来估计期望:
KL[q,p]≈1N∑i=1Nlogq(xi)p(xi) KL[q, p] \approx \frac{1}{N} \sum_{i=1}^N \log \frac{q(x_i)}{p(x_i)} KL[q,p]≈N1i=1∑Nlogp(xi)q(xi)
这个均值就是蒙特卡洛估计器,称为 ( k1k_1k1 ) 在你的问题中:
k1=logq(x)p(x)=−logr,其中r=p(x)q(x) k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k1=logp(x)q(x)=−logr,其中r=q(x)p(x)
蒙特卡洛估计器的作用:它通过有限样本近似真实的KL散度值。KL散度本质上衡量两个分布 ( qqq ) 和 ( ppp ) 之间的“距离”(尽管严格来说它不是对称的距离),而蒙特卡洛估计器试图通过采样来捕捉这种差异。
蒙特卡洛估计器是否反映分布的距离?
是的,蒙特卡洛估计器 ( k1k_1k1 ) 的期望等于真实的KL散度,因此它确实反映了两个分布之间的“距离”。具体来说:
- 无偏性:( Ex∼q[k1]=Ex∼q[logq(x)p(x)]=KL[q,p]\mathbb{E}_{x \sim q}[k_1] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = KL[q, p]Ex∼q[k1]=Ex∼q[logp(x)q(x)]=KL[q,p] )。这意味着如果采样次数 ( N→∞N \to \inftyN→∞ ),估计器的平均值会收敛到真实的KL散度。
- 分布差异的体现:( k1=logq(x)p(x)k_1 = \log \frac{q(x)}{p(x)}k1=logp(x)q(x) ) 直接衡量了在样本点 ( xxx ) 上,分布 ( qqq ) 和 ( ppp ) 的概率密度(或概率)的对数差异。如果 ( q(x)≈p(x)q(x) \approx p(x)q(x)≈p(x) ),则 ( logq(x)p(x)≈0\log \frac{q(x)}{p(x)} \approx 0logp(x)q(x)≈0 ),表明分布很接近;如果 ( q(x)q(x)q(x) ) 和 ( p(x)p(x)p(x) ) 差异很大,则 ( k1k_1k1 ) 的绝对值会较大,反映分布的差异。
然而,( k1k_1k1 ) 的局限性在于其高方差。由于 ( logr\log rlogr ) (其中 ( r=p(x)q(x)r = \frac{p(x)}{q(x)}r=q(x)p(x) ))可能为正或负,样本值的波动会导致估计值的方差较大,尤其当 ( ppp ) 和 ( qqq ) 差异较大时。这种高方差使得 ( k1k_1k1 ) 在实际应用中不够稳定。
通过例子说明
让我们通过一个简单的例子来说明蒙特卡洛估计器如何工作,以及它如何反映分布的距离。
场景
假设有两个一维正态分布:
- ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) ) (均值0,标准差1)
- ( p=N(0.5,1)p = \mathcal{N}(0.5, 1)p=N(0.5,1) ) (均值0.5,标准差1)
真实的KL散度 ( KL[q,p]KL[q, p]KL[q,p] ) 对于正态分布有解析解:
KL[q,p]=12((μq−μp)2σp2)=12(0−0.5)212=0.252=0.125 KL[q, p] = \frac{1}{2} \left( \frac{(\mu_q - \mu_p)^2}{\sigma_p^2} \right) = \frac{1}{2} \frac{(0 - 0.5)^2}{1^2} = \frac{0.25}{2} = 0.125 KL[q,p]=21(σp2(μq−μp)2)=2112(0−0.5)2=20.25=0.125
我们用蒙特卡洛方法来估计这个值。
蒙特卡洛估计步骤
- 采样:从 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) ) 中抽取 ( N=1000N = 1000N=1000 ) 个样本 ( x1,x2,…,x1000x_1, x_2, \dots, x_{1000}x1,x2,…,x1000 )。假设样本是随机生成的,例如 ( x1=−0.3,x2=1.2,x3=−1.5,…x_1 = -0.3, x_2 = 1.2, x_3 = -1.5, \dotsx1=−0.3,x2=1.2,x3=−1.5,… )。
- 计算 ( k1k_1k1 ):对每个样本 ( xix_ixi ),计算:
k1(xi)=logq(xi)p(xi) k_1(x_i) = \log \frac{q(x_i)}{p(x_i)} k1(xi)=logp(xi)q(xi)
正态分布的概率密度函数为:
q(x)=12πexp(−x22),p(x)=12πexp(−(x−0.5)22) q(x) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{x^2}{2}\right), \quad p(x) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x - 0.5)^2}{2}\right) q(x)=2π1exp(−2x2),p(x)=2π1exp(−2(x−0.5)2)
因此:
logq(xi)p(xi)=logq(xi)−logp(xi)=(−xi22)−(−(xi−0.5)22) \log \frac{q(x_i)}{p(x_i)} = \log q(x_i) - \log p(x_i) = \left( -\frac{x_i^2}{2} \right) - \left( -\frac{(x_i - 0.5)^2}{2} \right) logp(xi)q(xi)=logq(xi)−logp(xi)=(−2xi2)−(−2(xi−0.5)2)
=−xi22+(xi−0.5)22=(xi−0.5)2−xi22=xi2−xi+0.25−xi22=−xi+0.252 = -\frac{x_i^2}{2} + \frac{(x_i - 0.5)^2}{2} = \frac{(x_i - 0.5)^2 - x_i^2}{2} = \frac{x_i^2 - x_i + 0.25 - x_i^2}{2} = \frac{-x_i + 0.25}{2} =−2xi2+2(xi−0.5)2=2(xi−0.5)2−xi2=2xi2−xi+0.25−xi2=2−xi+0.25
k1(xi)=0.25−xi2 k_1(x_i) = \frac{0.25 - x_i}{2} k1(xi)=20.25−xi
- 估计KL散度:计算所有样本的均值:
KL^[q,p]=1N∑i=1Nk1(xi)=1N∑i=1N0.25−xi2 \hat{KL}[q, p] = \frac{1}{N} \sum_{i=1}^N k_1(x_i) = \frac{1}{N} \sum_{i=1}^N \frac{0.25 - x_i}{2} KL^[q,p]=N1i=1∑Nk1(xi)=N1i=1∑N20.25−xi
假设我们采样得到的 ( xix_ixi ) 的均值为 ( xˉ≈0\bar{x} \approx 0xˉ≈0 )(因为 ( xi∼N(0,1)x_i \sim \mathcal{N}(0, 1)xi∼N(0,1) )),则:
KL^[q,p]≈11000∑i=110000.25−xi2≈0.25−02=0.125 \hat{KL}[q, p] \approx \frac{1}{1000} \sum_{i=1}^{1000} \frac{0.25 - x_i}{2} \approx \frac{0.25 - 0}{2} = 0.125 KL^[q,p]≈10001i=1∑100020.25−xi≈20.25−0=0.125
这个估计值接近真实的 ( KL[q,p]=0.125KL[q, p] = 0.125KL[q,p]=0.125 ),说明 ( k1k_1k1 ) 是无偏的。
高方差的表现
尽管 ( k1k_1k1 ) 无偏,但它的值在样本间波动较大。例如:
- 如果 ( xi=0x_i = 0xi=0 ),则 ( k1=0.25−02=0.125k_1 = \frac{0.25 - 0}{2} = 0.125k1=20.25−0=0.125 )。
- 如果 ( xi=2x_i = 2xi=2 ),则 ( k1=0.25−22=−0.875k_1 = \frac{0.25 - 2}{2} = -0.875k1=20.25−2=−0.875 )。
- 如果 ( xi=−2x_i = -2xi=−2 ),则 ( k1=0.25−(−2)2=1.125k_1 = \frac{0.25 - (-2)}{2} = 1.125k1=20.25−(−2)=1.125 ).
这些值的正负波动导致样本均值的方差较高。如果采样次数 ( NNN ) 不够多,估计值可能偏离0.125较远。
反映分布距离
从 ( k1(xi)=0.25−xi2k_1(x_i) = \frac{0.25 - x_i}{2}k1(xi)=20.25−xi ) 可以看出,( k1k_1k1 ) 的值直接与样本 ( xix_ixi) 有关,而样本是由 ( qqq ) 生成的。KL散度的值0.125反映了 ( qqq ) 和 ( ppp ) 之间的均值差异(0 vs 0.5)。如果我们改变 ( ppp ) 的均值(例如 ( p=N(1,1)p = \mathcal{N}(1, 1)p=N(1,1) )),KL散度会变大(( KL[q,p]=0.5KL[q, p] = 0.5KL[q,p]=0.5 )),相应的 ( k1k_1k1 ) 的均值也会变大,反映出分布间更大的差异。
总结
- 蒙特卡洛估计器是通过从分布 ( qqq ) 中采样,用样本上的函数值均值来近似KL散度的期望。
- 反映分布距离:( k1=logq(x)p(x)k_1 = \log \frac{q(x)}{p(x)}k1=logp(x)q(x) ) 的期望等于KL散度,因此它确实捕捉了 ( qqq ) 和 ( ppp ) 之间的差异。
- 局限性:高方差使得 ( k1k_1k1 ) 在有限样本下不够稳定,如例子中 ( k1k_1k1 ) 的值在正负间波动。
- 改进方向:Schulman的博客中提出了 ( k2k_2k2 ) 和 ( k3k_3k3 ) 等估计器,通过降低方差(例如保证非负性或引入控制变量)来改进 ( k1k_1k1 ),在DeepSeek GRPO等应用中表现出色。
log相减就是k1估计器
在许多大语言模型(LLM)强化学习(RL)相关的论文中,尤其是那些涉及策略优化(如PPO、TRPO或RLHF)的论文,使用的KL散度估计通常是基于 ( k1k_1k1 ) 估计器,即直接计算对数概率的差值:
k1=logq(x)p(x)=logq(x)−logp(x) k_1 = \log \frac{q(x)}{p(x)} = \log q(x) - \log p(x) k1=logp(x)q(x)=logq(x)−logp(x)
或者在策略优化的上下文中,写作:
k1=logπref(a∣s)πθ(a∣s)=logπref(a∣s)−logπθ(a∣s) k_1 = \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\theta}(a|s)} = \log \pi_{\text{ref}}(a|s) - \log \pi_{\theta}(a|s) k1=logπθ(a∣s)πref(a∣s)=logπref(a∣s)−logπθ(a∣s)
其中,( πθ\pi_{\theta}πθ ) 是当前策略,( πref\pi_{\text{ref}}πref ) 是参考策略(如初始策略或SFT模型),( aaa ) 是动作(在LLM中通常是生成的token),( sss ) 是状态(上下文或提示)。这种形式的KL散度估计正是Schulman博客中描述的 ( k1k_1k1 ) 估计器,基于蒙特卡洛方法,通过采样计算对数概率差的均值来近似 ( KL[πθ∣∣πref]KL[\pi_{\theta} || \pi_{\text{ref}}]KL[πθ∣∣πref] ) 或 ( KL[πref∣∣πθ]KL[\pi_{\text{ref}} || \pi_{\theta}]KL[πref∣∣πθ] )。
为什么常用 ( k1k_1k1 ) 估计器?
-
简单直接:( k1k_1k1 ) 是KL散度的直接蒙特卡洛估计,形式上就是对数概率的差值,易于实现。LLM通常会输出每个token的对数概率(logits经过softmax后取log),因此计算 ( logπθ(a∣s)\log \pi_{\theta}(a|s)logπθ(a∣s) ) 和 ( logπref(a∣s)\log \pi_{\text{ref}}(a|s)logπref(a∣s) ) 是现成的。
-
无偏性:( k1k_1k1 ) 是无偏估计器,其期望等于真实的KL散度:
Ea∼πθ[logπθ(a∣s)πref(a∣s)]=KL[πθ∣∣πref] \mathbb{E}_{a \sim \pi_{\theta}} \left[ \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)} \right] = KL[\pi_{\theta} || \pi_{\text{ref}}] Ea∼πθ[logπref(a∣s)πθ(a∣s)]=KL[πθ∣∣πref]
这保证了在采样足够多的情况下,估计值会收敛到真实KL散度。
- 常见应用场景:在强化学习中,KL散度常用于以下场景:
- 正则化项:如PPO中,KL散度作为惩罚项,限制新策略 ( πθ\pi_{\theta}πθ ) 偏离参考策略 ( πref\pi_{\text{ref}}πref ) 过远。
- 诊断工具:监控策略更新的稳定性。
- 奖励函数:在RLHF中,KL散度常用于奖励函数(如 ( R=Rhuman−βKL[πθ∣∣πref]R = R_{\text{human}} - \beta KL[\pi_{\theta} || \pi_{\text{ref}}]R=Rhuman−βKL[πθ∣∣πref] )),以平衡偏好和策略稳定性。
这些场景中,( k1k_1k1 ) 的形式简单,且与策略梯度计算兼容,因此被广泛采用。
( k1k_1k1 ) 估计器的局限性在LLM RL中的体现
尽管 ( k1k_1k1 ) 常用,但正如Schulman的博客所指出的,它有高方差的缺点。这在LLM RL中可能表现为:
- 方差导致的不稳定:由于 ( k1=logπθ(a∣s)πref(a∣s)k_1 = \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)}k1=logπref(a∣s)πθ(a∣s) ) 可能为正或负(当 ( πθ(a∣s)>πref(a∣s)\pi_{\theta}(a|s) > \pi_{\text{ref}}(a|s)πθ(a∣s)>πref(a∣s) ) 时为正,反之为负),样本间的波动会导致KL散度估计不稳定,尤其在序列生成中,token序列较长时方差会累积。
- 采样效率:LLM生成序列的采样成本高(需要运行整个模型),如果 ( k1k_1k1 ) 的高方差要求更多样本才能得到可靠估计,会增加计算开销。
- 极端概率的影响:当 ( πθ(a∣s)\pi_{\theta}(a|s)πθ(a∣s) ) 或 ( πref(a∣s)\pi_{\text{ref}}(a|s)πref(a∣s) ) 接近0时,( logπθ(a∣s)πref(a∣s)\log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)}logπref(a∣s)πθ(a∣s) ) 可能变得很大或很小,导致估计值的异常波动。
总结
是的,许多LLM强化学习论文中使用的KL散度估计确实是 ( k1k_1k1 ) 估计器(( logπθπref\log \frac{\pi_{\theta}}{\pi_{\text{ref}}}logπrefπθ )),因为它简单、无偏且易于实现。然而,其高方差可能导致不稳定,尤其在长序列生成或概率差异较大的情况下。一些前沿工作(如DeepSeek GRPO)采用了Schulman提出的 ( k3k_3k3 ) 估计器,以获得更低的方差和更稳定的优化。如果你阅读的论文明确提到KL散度惩罚,通常默认是 ( k1k_1k1 ),但需要检查具体实现或损失函数是否使用了改进的估计器(如 ( k3k_3k3 ) 或其他变体)。
f-散度和k2
什么是f-散度?
f-散度(f-divergence)是一类用于衡量两个概率分布 ( ppp ) 和 ( qqq ) 之间差异的广义距离度量。它通过一个凸函数 ( fff ) 定义,形式如下:
Df(p,q)=Ex∼q[f(p(x)q(x))] D_f(p, q) = \mathbb{E}_{x \sim q} \left[ f\left( \frac{p(x)}{q(x)} \right) \right] Df(p,q)=Ex∼q[f(q(x)p(x))]
其中:
- ( p(x)p(x)p(x) ) 和 ( q(x)q(x)q(x) ) 是分布 ( ppp ) 和 ( qqq ) 在点 ( xxx ) 上的概率密度(或概率质量,视离散/连续分布而定)。
- ( r=p(x)q(x)r = \frac{p(x)}{q(x)}r=q(x)p(x) ) 是概率比率。
- ( f:R+→Rf: \mathbb{R}^+ \to \mathbb{R}f:R+→R ) 是一个凸函数,且通常满足 ( f(1)=0f(1) = 0f(1)=0 ),以确保当 ( p=qp = qp=q ) 时散度为零。
f-散度的特点是它包含了许多常见的分布距离度量,通过选择不同的 ( fff ) 函数,可以得到不同的散度。例如:
- KL散度:当 ( f(u)=uloguf(u) = u \log uf(u)=ulogu ),对应 ( Df(p,q)=KL(p∣∣q)=Eq[p(x)q(x)logp(x)q(x)]D_f(p, q) = KL(p || q) = \mathbb{E}_{q} \left[ \frac{p(x)}{q(x)} \log \frac{p(x)}{q(x)} \right]Df(p,q)=KL(p∣∣q)=Eq[q(x)p(x)logq(x)p(x)] )。
- 逆KL散度:当 ( f(u)=−loguf(u) = -\log uf(u)=−logu ),对应 ( Df(p,q)=KL(q∣∣p)D_f(p, q) = KL(q || p)Df(p,q)=KL(q∣∣p) )。
- Jensen-Shannon散度:通过选择适当的 ( fff ) 函数,可以推导JS散度。
- 总变差距离:当 ( f(u)=∣u−1∣f(u) = |u - 1|f(u)=∣u−1∣ ),对应总变差。
f-散度的优点是它提供了一个统一的框架来研究不同分布距离的性质,且许多f-散度在 ( p≈qp \approx qp≈q ) 时具有相似的二阶行为(如Schulman博客中提到的与KL散度的关系)。
Schulman博客中的f-散度背景
在John Schulman的博客《Approximating KL Divergence》中,他讨论了如何通过蒙特卡洛方法近似KL散度 ( KL[q,p]KL[q, p]KL[q,p] ),并提出了一种低方差但有偏的估计器 ( k2k_2k2 )。他指出,( k2k_2k2 ) 的期望是一个f-散度,并且在 ( p≈qp \approx qp≈q ) 时,所有可微的f-散度在二阶近似下与KL散度等价。这种等价性是 ( k2k_2k2 ) 低偏倚的理论基础。
具体来说,Schulman定义了 ( k2k_2k2 ) 估计器为:
k2=12(logp(x)q(x))2=12(logr)2,其中r=p(x)q(x) k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k2=21(logq(x)p(x))2=21(logr)2,其中r=q(x)p(x)
他进一步说明,( k2k_2k2 ) 的期望是一个f-散度:
Ex∼q[k2]=Ex∼q[12(logr)2] \mathbb{E}_{x \sim q}[k_2] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} (\log r)^2 \right] Ex∼q[k2]=Ex∼q[21(logr)2]
这对应于f-散度中的凸函数:
f(u)=12(logu)2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2
为了理解 ( k2k_2k2 ) 的来源和为何它是一个合理的KL散度估计器,我们需要结合f-散度的性质和Schulman的推导思路进行分析。
推导 ( k2k_2k2 ) 的来源
以下是结合Schulman博客内容,推导 ( k2k_2k2 ) 如何作为KL散度估计器的过程:
1. KL散度的蒙特卡洛估计
KL散度 ( KL[q,p]KL[q, p]KL[q,p] ) 的定义为:
KL[q,p]=Ex∼q[logq(x)p(x)]=Ex∼q[−logr],r=p(x)q(x) KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \mathbb{E}_{x \sim q} [-\log r], \quad r = \frac{p(x)}{q(x)} KL[q,p]=Ex∼q[logp(x)q(x)]=Ex∼q[−logr],r=q(x)p(x)
最直接的蒙特卡洛估计器是 ( k1k_1k1 ):
k1=logq(x)p(x)=−logr k_1 = \log \frac{q(x)}{p(x)} = -\log r k1=logp(x)q(x)=−logr
( k1k_1k1 ) 是无偏的,但由于 ( logr\log rlogr ) 可正可负(( r>1r > 1r>1 ) 时 ( logr>0\log r > 0logr>0 ),( r<1r < 1r<1 ) 时 ( logr<0\log r < 0logr<0 )),其方差较高,而KL散度本身始终非负。这种高方差促使Schulman寻找更稳定的估计器。
2. 构造低方差估计器的动机
Schulman的目标是构造一个估计器,既能反映 ( ppp ) 和 ( qqq ) 之间的差异,又具有较低的方差。一个直观的思路是:既然KL散度衡量分布差异,可以尝试用某种“距离”函数来替代 ( logq(x)p(x)\log \frac{q(x)}{p(x)}logp(x)q(x) ),使得估计值始终非负(与KL散度的性质一致),并且方差较小。
考虑 ( logr=logp(x)q(x)\log r = \log \frac{p(x)}{q(x)}logr=logq(x)p(x)),它直接衡量了概率比率的对数差异。如果我们取其平方:
(logr)2=(logp(x)q(x))2 (\log r)^2 = \left( \log \frac{p(x)}{q(x)} \right)^2 (logr)2=(logq(x)p(x))2
这个量有以下优点:
- 非负性:( (logr)2≥0(\log r)^2 \geq 0(logr)2≥0 ),与KL散度的非负性质一致。
- 对称性:它反映了 ( logr\log rlogr ) 的大小,无论 ( logr\log rlogr ) 是正还是负,都表示 ( ppp ) 和 ( qqq ) 的差异。
- 平滑性:平方函数会放大较大的 ( logr\log rlogr ) 值,强调分布差异较大的样本。
为了使量纲合理(KL散度的量纲类似于对数),Schulman引入一个系数 ( 12\frac{1}{2}21 ),定义:
k2=12(logr)2=12(logp(x)q(x))2 k_2 = \frac{1}{2} (\log r)^2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 k2=21(logr)2=21(logq(x)p(x))2
3. ( k2k_2k2 ) 的期望是f-散度
计算 ( k2k_2k2 ) 的期望:
Ex∼q[k2]=Ex∼q[12(logp(x)q(x))2]=Ex∼q[12(logr)2] \mathbb{E}_{x \sim q}[k_2] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 \right] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} (\log r)^2 \right] Ex∼q[k2]=Ex∼q[21(logq(x)p(x))2]=Ex∼q[21(logr)2]
这正是f-散度的形式,其中:
f(u)=12(logu)2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2
我们需要验证 ( f(u)f(u)f(u) ) 是否满足f-散度的要求:
- 凸性:计算 ( f(u)f(u)f(u) ) 的二阶导数:
f(u)=12(logu)2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2
f′(u)=12⋅2logu⋅1u=loguu f'(u) = \frac{1}{2} \cdot 2 \log u \cdot \frac{1}{u} = \frac{\log u}{u} f′(u)=21⋅2logu⋅u1=ulogu
f′′(u)=1u⋅1u+loguu2=1u2+loguu2=1+loguu2 f''(u) = \frac{1}{u} \cdot \frac{1}{u} + \frac{\log u}{u^2} = \frac{1}{u^2} + \frac{\log u}{u^2} = \frac{1 + \log u}{u^2} f′′(u)=u1⋅u1+u2logu=u21+u2logu=u21+logu
由于 ( u>0u > 0u>0 ),( f′′(u)f''(u)f′′(u) ) 不总是正(例如当 ( u<e−1≈0.367u < e^{-1} \approx 0.367u<e−1≈0.367 ) 时,( 1+logu<01 + \log u < 01+logu<0 )),但在实际应用中,( r=p(x)q(x)r = \frac{p(x)}{q(x)}r=q(x)p(x) ) 通常在 ( (0,∞)(0, \infty)(0,∞) ) 范围内变化,且在 ( p≈qp \approx qp≈q ) 时 ( r≈1r \approx 1r≈1 ),此时 ( f′′(1)=1>0f''(1) = 1 > 0f′′(1)=1>0 )。因此,在 ( r≈1r \approx 1r≈1 ) 的局部区域,( f(u)f(u)f(u) ) 是凸的,满足f-散度的要求。
- ( f(1)=0f(1) = 0f(1)=0 ):
f(1)=12(log1)2=0 f(1) = \frac{1}{2} (\log 1)^2 = 0 f(1)=21(log1)2=0
这保证当 ( p=qp = qp=q )(即 ( r=1r = 1r=1 ))时,散度为零。
因此,( Ex∼q[k2]\mathbb{E}_{x \sim q}[k_2]Ex∼q[k2] ) 是一个有效的f-散度。
4. 为何 ( k2k_2k2 ) 是KL散度的合理近似?
Schulman指出,所有可微的f-散度在 ( p≈qp \approx qp≈q ) 时,在二阶近似下与KL散度等价。让我们推导这一关键结论。
假设 ( pθp_{\theta}pθ ) 是一个参数化的分布,( pθ0=p0p_{\theta_0} = p_0pθ0=p0 ),我们考虑 ( Df(p0,pθ)D_f(p_0, p_{\theta})Df(p0,pθ) ) 在 ( θ≈θ0\theta \approx \theta_0θ≈θ0 ) 时的泰勒展开。Schulman给出了近似:
Df(p0,pθ)=f′′(1)2θTFθ+O(θ3) D_f(p_0, p_{\theta}) = \frac{f''(1)}{2} \theta^T F \theta + O(\theta^3) Df(p0,pθ)=2f′′(1)θTFθ+O(θ3)
其中 ( FFF ) 是 ( pθp_{\theta}pθ ) 在 ( θ=θ0\theta = \theta_0θ=θ0 ) 处的Fisher信息矩阵。关键在于 ( f′′(1)f''(1)f′′(1) )。我们计算 ( f(u)=12(logu)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 ) 的二阶导数:
f′′(u)=1+loguu2,f′′(1)=1+log112=1 f''(u) = \frac{1 + \log u}{u^2}, \quad f''(1) = \frac{1 + \log 1}{1^2} = 1 f′′(u)=u21+logu,f′′(1)=121+log1=1
对于KL散度 ( KL[q,p]KL[q, p]KL[q,p] ),对应的f-散度函数是 ( f(u)=−loguf(u) = -\log uf(u)=−logu ):
f′(u)=−1u,f′′(u)=1u2,f′′(1)=1 f'(u) = -\frac{1}{u}, \quad f''(u) = \frac{1}{u^2}, \quad f''(1) = 1 f′(u)=−u1,f′′(u)=u21,f′′(1)=1
可以看到,( f′′(1)=1f''(1) = 1f′′(1)=1 ) 对于 ( k2k_2k2 ) 的 ( f(u)=12(logu)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 ) 和KL散度的 ( f(u)=−loguf(u) = -\log uf(u)=−logu ) 都是相同的。因此,在 ( p≈qp \approx qp≈q ) 时,( Ex∼q[k2]\mathbb{E}_{x \sim q}[k_2]Ex∼q[k2] ) 和 ( KL[q,p]KL[q, p]KL[q,p] ) 的二阶行为一致:
Ex∼q[k2]≈12θTFθ≈KL[q,p] \mathbb{E}_{x \sim q}[k_2] \approx \frac{1}{2} \theta^T F \theta \approx KL[q, p] Ex∼q[k2]≈21θTFθ≈KL[q,p]
这解释了为何 ( k2k_2k2 ) 虽然有偏(( E[k2]≠KL[q,p]\mathbb{E}[k_2] \neq KL[q, p]E[k2]=KL[q,p] )),但偏倚很小,尤其当 ( ppp ) 和 ( qqq ) 接近时。
5. ( k2k_2k2 ) 的直观解释
( k2=12(logp(x)q(x))2k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2k2=21(logq(x)p(x))2 ) 的设计灵感可以理解为:
- 惩罚对数差异:( logp(x)q(x)\log \frac{p(x)}{q(x)}logq(x)p(x) ) 衡量了 ( ppp ) 和 ( qqq ) 在对数尺度上的相对差异,平方后强调了较大的差异。
- 低方差:由于 ( (logr)2≥0(\log r)^2 \geq 0(logr)2≥0 ),每个样本的贡献都是非负的,避免了 ( k1=−logrk_1 = -\log rk1=−logr ) 中正负抵消导致的高方差。
- f-散度的自然形式:平方形式 ( f(u)=12(logu)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 ) 是一个自然的凸函数选择,且在 ( r≈1r \approx 1r≈1 ) 时与KL散度行为相似。
Schulman通过实验验证了 ( k2k_2k2 ) 的低方差和低偏倚。例如,当 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) )、( p=N(0.1,1)p = \mathcal{N}(0.1, 1)p=N(0.1,1) )(真实KL散度为0.005)时,( k2k_2k2 ) 的偏倚仅为0.2%,标准差远低于 ( k1k_1k1 )。
结合Schulman博客的总结
Schulman的博客通过引入 ( k2=12(logp(x)q(x))2k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2k2=21(logq(x)p(x))2 ) 提供了一种低方差的KL散度估计器,其期望是一个f-散度,对应的凸函数为 ( f(u)=12(logu)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 )。推导过程如下:
- 从KL散度的蒙特卡洛估计出发,识别 ( k1=logq(x)p(x)k_1 = \log \frac{q(x)}{p(x)}k1=logp(x)q(x) ) 的高方差问题。
- 构造 ( k2=12(logr)2k_2 = \frac{1}{2} (\log r)^2k2=21(logr)2 ),使其非负且方差较低。
- 证明 ( E[k2]\mathbb{E}[k_2]E[k2] ) 是一个f-散度,且在 ( p≈qp \approx qp≈q ) 时与KL散度在二阶近似下等价(因为 ( f′′(1)=1f''(1) = 1f′′(1)=1 ))。
- 通过实验验证 ( k2k_2k2 ) 的低偏倚和低方差。
f-散度的框架为 ( k2k_2k2 ) 的合理性提供了理论支持,因为它表明不同凸函数 ( fff ) 定义的散度在 ( p≈qp \approx qp≈q ) 时具有相似的行为。这也为后续 ( k3k_3k3 ) 等无偏估计器的设计提供了灵感(如通过控制变量进一步优化)。
k3的思考
介绍 ( k3k_3k3 ) 的来源及推导
在John Schulman的博客《Approximating KL Divergence》中,他提出了三种KL散度估计器:( k1k_1k1 )、( k2k_2k2 ) 和 ( k3k_3k3 )。我们已经讨论了 ( k1=logq(x)p(x)=−logrk_1 = \log \frac{q(x)}{p(x)} = -\log rk1=logp(x)q(x)=−logr )(无偏但高方差)和 ( k2=12(logp(x)q(x))2=12(logr)2k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2k2=21(logq(x)p(x))2=21(logr)2 )(低方差但有偏)。( k3k_3k3 ) 是Schulman提出的第三个估计器,目标是兼顾 无偏性 和 低方差,定义为:
k3=(r−1)−logr,其中r=p(x)q(x) k_3 = \left( r - 1 \right) - \log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k3=(r−1)−logr,其中r=q(x)p(x)
本节将结合Schulman的博客内容,详细推导 ( k3k_3k3 ) 的来源,分析其与 ( k2k_2k2 ) 推导的联系,并解释其理论基础和优势。
推导 ( k3k_3k3 ) 的来源
1. 从 ( k1k_1k1 ) 的高方差出发
KL散度 ( KL[q,p]KL[q, p]KL[q,p] ) 的定义为:
KL[q,p]=Ex∼q[logq(x)p(x)]=Ex∼q[−logr],r=p(x)q(x) KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \mathbb{E}_{x \sim q} [-\log r], \quad r = \frac{p(x)}{q(x)} KL[q,p]=Ex∼q[logp(x)q(x)]=Ex∼q[−logr],r=q(x)p(x)
标准蒙特卡洛估计器 ( k1=−logrk_1 = -\log rk1=−logr ) 是无偏的,但由于 ( logr\log rlogr ) 可正可负(( r>1r > 1r>1 ) 时 ( logr>0\log r > 0logr>0 ),( r<1r < 1r<1 ) 时 ( logr<0\log r < 0logr<0 )),其方差较高。Schulman的目标是构造一个估计器,既保留 ( k1k_1k1 ) 的无偏性,又降低方差。
2. 控制变量方法的引入
为了降低 ( k1k_1k1 ) 的方差,Schulman引入了 控制变量(control variate) 方法。控制变量方法是一种常用的方差缩减技术,核心思想是:在估计器中加入一个期望为零的项,这个项与原估计器负相关,从而抵消部分波动。
在KL散度的估计中,我们需要找到一个量 ( g(x)g(x)g(x) ),满足:
Ex∼q[g(x)]=0 \mathbb{E}_{x \sim q} [g(x)] = 0 Ex∼q[g(x)]=0
然后构造新的估计器:
knew=k1+λg(x) k_{\text{new}} = k_1 + \lambda g(x) knew=k1+λg(x)
只要 ( Ex∼q[g(x)]=0\mathbb{E}_{x \sim q} [g(x)] = 0Ex∼q[g(x)]=0 ),新估计器仍然无偏:
Ex∼q[knew]=Ex∼q[k1]+λEx∼q[g(x)]=KL[q,p]+0=KL[q,p] \mathbb{E}_{x \sim q} [k_{\text{new}}] = \mathbb{E}_{x \sim q} [k_1] + \lambda \mathbb{E}_{x \sim q} [g(x)] = KL[q, p] + 0 = KL[q, p] Ex∼q[knew]=Ex∼q[k1]+λEx∼q[g(x)]=KL[q,p]+0=KL[q,p]
Schulman注意到,一个自然的控制变量是:
g(x)=r−1=p(x)q(x)−1 g(x) = r - 1 = \frac{p(x)}{q(x)} - 1 g(x)=r−1=q(x)p(x)−1
其期望为:
Ex∼q[r−1]=Ex∼q[p(x)q(x)]−1=∫q(x)⋅p(x)q(x) dx−1=∫p(x) dx−1=1−1=0 \mathbb{E}_{x \sim q} [r - 1] = \mathbb{E}_{x \sim q} \left[ \frac{p(x)}{q(x)} \right] - 1 = \int q(x) \cdot \frac{p(x)}{q(x)} \, dx - 1 = \int p(x) \, dx - 1 = 1 - 1 = 0 Ex∼q[r−1]=Ex∼q[q(x)p(x)]−1=∫q(x)⋅q(x)p(x)dx−1=∫p(x)dx−1=1−1=0
因此,( r−1r - 1r−1 ) 是一个期望为零的量,适合作为控制变量。于是,Schulman构造了新的估计器:
knew=k1+λ(r−1)=−logr+λ(r−1) k_{\text{new}} = k_1 + \lambda (r - 1) = -\log r + \lambda (r - 1) knew=k1+λ(r−1)=−logr+λ(r−1)
这里的 ( λ\lambdaλ ) 是一个可调参数,目标是通过选择合适的 ( λ\lambdaλ ) 最小化估计器的方差。
3. 选择 ( λ\lambdaλ ) 以保证非负性
理论上,可以通过最小化 ( knewk_{\text{new}}knew ) 的方差来求解最优的 ( λ\lambdaλ )。方差为:
Var(knew)=Var(−logr+λ(r−1))=Var(−logr)+λ2Var(r−1)+2λCov(−logr,r−1) \text{Var}(k_{\text{new}}) = \text{Var}(-\log r + \lambda (r - 1)) = \text{Var}(-\log r) + \lambda^2 \text{Var}(r - 1) + 2 \lambda \text{Cov}(-\log r, r - 1) Var(knew)=Var(−logr+λ(r−1))=Var(−logr)+λ2Var(r−1)+2λCov(−logr,r−1)
最小化方差需要计算协方差和方差项,但这些项依赖于 ( ppp ) 和 ( qqq ) 的具体形式,解析求解往往复杂。Schulman提到,这种方法得到的 ( λ\lambdaλ ) 通常难以计算,因此他采用了更简单的策略:选择 ( λ\lambdaλ ) 使估计器 始终非负,从而直观上降低方差并与KL散度的非负性质一致。
Schulman利用对数函数的凹性,注意到:
logu≤u−1,∀u>0 \log u \leq u - 1, \quad \forall u > 0 logu≤u−1,∀u>0
等号在 ( u=1u = 1u=1 ) 时成立。这意味着:
−logr≥−(r−1) -\log r \geq -(r - 1) −logr≥−(r−1)
因此,如果选择 ( λ=1\lambda = 1λ=1 ),估计器变为:
k3=−logr+(r−1)=(r−1)−logr k_3 = -\log r + (r - 1) = (r - 1) - \log r k3=−logr+(r−1)=(r−1)−logr
我们可以验证 ( k3k_3k3 ) 的非负性:
k3=(r−1)−logr k_3 = (r - 1) - \log r k3=(r−1)−logr
根据凹性不等式 ( logr≤r−1\log r \leq r - 1logr≤r−1 ),有:
(r−1)−logr≥(r−1)−(r−1)=0 (r - 1) - \log r \geq (r - 1) - (r - 1) = 0 (r−1)−logr≥(r−1)−(r−1)=0
因此,( k3≥0k_3 \geq 0k3≥0 ),并且在 ( r=1r = 1r=1 )(即 ( p(x)=q(x)p(x) = q(x)p(x)=q(x) ))时,( k3=0k_3 = 0k3=0 )。这种非负性使得 ( k3k_3k3 ) 的样本值更稳定,避免了 ( k1k_1k1 ) 中正负抵消导致的高方差。
4. 验证 ( k3k_3k3 ) 的无偏性
我们需要确认 ( k3k_3k3 ) 是否无偏:
Ex∼q[k3]=Ex∼q[(r−1)−logr]=Ex∼q[r−1]−Ex∼q[logr] \mathbb{E}_{x \sim q} [k_3] = \mathbb{E}_{x \sim q} [(r - 1) - \log r] = \mathbb{E}_{x \sim q} [r - 1] - \mathbb{E}_{x \sim q} [\log r] Ex∼q[k3]=Ex∼q[(r−1)−logr]=Ex∼q[r−1]−Ex∼q[logr]
- 第一项:( Ex∼q[r−1]=0\mathbb{E}_{x \sim q} [r - 1] = 0Ex∼q[r−1]=0 )(如前所述)。
- 第二项:( Ex∼q[logr]=Ex∼q[logp(x)q(x)]=Ex∼q[logq(x)p(x)]=KL[q,p]\mathbb{E}_{x \sim q} [\log r] = \mathbb{E}_{x \sim q} \left[ \log \frac{p(x)}{q(x)} \right] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = KL[q, p]Ex∼q[logr]=Ex∼q[logq(x)p(x)]=Ex∼q[logp(x)q(x)]=KL[q,p] ).
因此:
Ex∼q[k3]=0−(−KL[q,p])=KL[q,p] \mathbb{E}_{x \sim q} [k_3] = 0 - (-KL[q, p]) = KL[q, p] Ex∼q[k3]=0−(−KL[q,p])=KL[q,p]
( k3k_3k3 ) 是无偏的,满足我们对KL散度估计器的要求。
5. ( k3k_3k3 ) 与Bregman散度的联系
Schulman指出,( k3=(r−1)−logrk_3 = (r - 1) - \log rk3=(r−1)−logr ) 的形式可以解释为一个 Bregman散度。Bregman散度衡量一个凸函数与其在某点的切线之间的垂直距离。对于凸函数 ( f(u)=−loguf(u) = -\log uf(u)=−logu ),其导数为:
f′(u)=−1u f'(u) = -\frac{1}{u} f′(u)=−u1
在点 ( u=1u = 1u=1 ) 处的切线为:
f(1)+f′(1)(u−1)=−log1−11(u−1)=−(u−1) f(1) + f'(1)(u - 1) = -\log 1 - \frac{1}{1}(u - 1) = -(u - 1) f(1)+f′(1)(u−1)=−log1−11(u−1)=−(u−1)
Bregman散度定义为:
Df(u,1)=f(u)−f(1)−f′(1)(u−1) D_f(u, 1) = f(u) - f(1) - f'(1)(u - 1) Df(u,1)=f(u)−f(1)−f′(1)(u−1)
代入 ( f(u)=−loguf(u) = -\log uf(u)=−logu ):
Df(r,1)=(−logr)−(−log1)−(−11)(r−1)=−logr−0+(r−1)=(r−1)−logr D_f(r, 1) = (-\log r) - (-\log 1) - \left(-\frac{1}{1}\right)(r - 1) = -\log r - 0 + (r - 1) = (r - 1) - \log r Df(r,1)=(−logr)−(−log1)−(−11)(r−1)=−logr−0+(r−1)=(r−1)−logr
这正是 ( k3k_3k3 ) 的形式!因此,( k3k_3k3 ) 可以看作 ( f(u)=−loguf(u) = -\log uf(u)=−logu ) 在 ( r=1r = 1r=1 ) 处的Bregman散度。这种解释提供了理论上的优雅性:( k3k_3k3 ) 不仅无偏,还通过Bregman散度的非负性自然降低了方差。
6. 实验验证
Schulman通过实验比较了 ( k1k_1k1 )、( k2k_2k2 ) 和 ( k3k_3k3 )。例如,当 ( q=N(0,1)q = \mathcal{N}(0, 1)q=N(0,1) )、( p=N(0.1,1)p = \mathcal{N}(0.1, 1)p=N(0.1,1) )(真实KL散度为0.005)时:
- ( k1k_1k1 ): 无偏,标准差为真实值的20倍。
- ( k2k_2k2 ): 偏倚为0.2%,标准差为真实值的1.42倍。
- ( k3k_3k3 ): 无偏,标准差为真实值的1.42倍。
当KL散度较大(如 ( p=N(1,1)p = \mathcal{N}(1, 1)p=N(1,1) ),真实KL散度为0.5)时:
- ( k3k_3k3 ): 无偏,标准差为真实值的1.7倍,优于 ( k2k_2k2 ) 的1.73倍和25%偏倚。
这些结果表明,(k3k_3k3 ) 成功结合了无偏性和低方差。
( k3k_3k3 ) 与 ( k2k_2k2 ) 推导的联系
虽然 ( k2k_2k2 ) 和 ( k3k_3k3 ) 的最终形式不同,但它们的推导和思考过程有以下联系:
-
共同目标:降低方差
- ( k2k_2k2 ) 通过平方 ( logr\log rlogr )(即 (12(logr)2\frac{1}{2} (\log r)^221(logr)2 ))使估计值非负,强调分布差异并降低方差。
- ( k3k_3k3 ) 通过控制变量 ( r−1r - 1r−1 ) 和利用 ( logr≤r−1\log r \leq r - 1logr≤r−1 ) 的凹性不等式,同样追求非负性和低方差。
- 两者都试图解决 ( k1=−logrk_1 = -\log rk1=−logr ) 的高方差问题,直观上希望估计器与KL散度的非负性质一致。
-
f-散度与Bregman散度的理论支持
- ( k2k_2k2 ) 的期望是一个f-散度(( f(u)=12(logu)2f(u) = \frac{1}{2} (\log u)^2f(u)=21(logu)2 )),在 ( p≈qp \approx qp≈q ) 时通过二阶近似(( f′′(1)=1f''(1) = 1f′′(1)=1 ))与KL散度等价。
- ( k3k_3k3 ) 直接是 ( KL[q,p]KL[q, p]KL[q,p] ) 的无偏估计器,且其形式对应于 ( f(u)=−loguf(u) = -\log uf(u)=−logu ) 的Bregman散度。Bregman散度可以看作f-散度的一种特殊形式(当考虑凸函数与其切线的距离时)。
- 两者都利用了凸函数的性质(( k2k_2k2 ) 通过凸函数 ( f(u)f(u)f(u)),(k3k_3k3 ) 通过凹函数 ( −logu-\log u−logu ) 的Bregman散度),体现了Schulman从广义散度框架出发的统一思路。
-
从有偏到无偏的改进
- ( k2k_2k2 ) 是Schulman的初步尝试,牺牲了无偏性换取低方差,但通过f-散度的二阶等价性保证了低偏倚。
- ( k3k_3k3 ) 是对 ( k2k_2k2 ) 的进一步改进,通过控制变量方法恢复了无偏性,同时保留了低方差的优点。( k3k_3k3 ) 的非负性(通过 ( logr≤r−1\log r \leq r - 1logr≤r−1 ))与 ( k2k_2k2 ) 的非负性(通过平方)有相似的直观动机。
-
设计思路的演进
- ( k2k_2k2 ) 的设计更像是一种启发式尝试:用平方函数放大差异并确保非负,然后通过f-散度理论验证其合理性。
- ( k3k_3k3 ) 的设计更有系统性,基于控制变量的统计方法,并通过Bregman散度的数学结构提供理论支撑。Schulman在推导 ( k3k_3k3 ) 时明确考虑了如何在 ( k1k_1k1 ) 的基础上系统性地降低方差,而 ( k2k_2k2 ) 更像是探索过程中的一个中间产物。
-
推广到其他散度
- Schulman在博客中提到,( k2k_2k2 ) 和 ( k3k_3k3 ) 的思路可以推广到其他f-散度。例如,( k3k_3k3 ) 的形式 ( f(r)−f′(1)(r−1)f(r) - f'(1)(r - 1)f(r)−f′(1)(r−1) ) 是一个通用的f-散度估计器,适用于任意凸函数 (fff )。
- ( k2k_2k2 ) 的f-散度形式为其他估计器的设计提供了灵感,而 ( k3k_3k3 ) 的Bregman散度形式进一步揭示了这种估计器与凸优化理论的深层联系。
总结
( k3=(r−1)−logrk_3 = (r - 1) - \log rk3=(r−1)−logr ) 的来源是通过控制变量方法改进 ( k1k_1k1 ),利用 ( r−1r - 1r−1 ) 的零期望和 ( logr≤r−1\log r \leq r - 1logr≤r−1 ) 的凹性不等式,构造了一个无偏且低方差的KL散度估计器。其推导过程如下:
- 从 ( k1=−logrk_1 = -\log rk1=−logr ) 的高方差问题出发,引入控制变量 ( r−1r - 1r−1 )。
- 构造新估计器 ( k3=−logr+λ(r−1)k_3 = -\log r + \lambda (r - 1)k3=−logr+λ(r−1) ),选择 ( λ=1\lambda = 1λ=1 ) 保证非负性。
- 验证 ( k3k_3k3 ) 的无偏性和非负性,并通过Bregman散度提供理论解释。
- 实验确认 ( k3k_3k3 ) 兼顾无偏和低方差,优于 ( k2k_2k2 )。
与 ( k2k_2k2 ) 的联系在于:
- 两者都旨在降低 ( k1k_1k1 ) 的方差,追求非负性以匹配KL散度的性质。
- ( k2k_2k2 ) 通过f-散度框架启发了对分布差异的度量,( k3k_3k3 ) 则通过控制变量和Bregman散度实现了更系统的优化。
- ( k3k_3k3 ) 可以看作 ( k2k_2k2 ) 思路的延伸,从有偏但低方差的探索(( k2k_2k2 ))进化到无偏且低方差的解决方案(( k3k_3k3 ))。
这种推导思路不仅在理论上优雅,还在DeepSeek GRPO等强化学习应用中得到了实践验证,展现了其在LLM优化中的重要价值。
后记
2025年4月20日于上海,在grok 3大模型辅助下完成。
更多推荐



所有评论(0)