
KL散度近似方法介绍:从John Schulman的博客到DeepSeek GRPO的应用
John Schulman在其2020年3月7日的博客中详细探讨了如何通过蒙特卡洛方法近似KL散度,并提出了一种低方差、无偏的估计器。这一方法不仅在理论上具有重要意义,还被DeepSeek的GRPO算法所采用。
KL散度近似方法介绍:从John Schulman的博客到DeepSeek GRPO的应用
KL散度(Kullback-Leibler Divergence)是衡量两个概率分布之间差异的重要指标,广泛应用于机器学习、强化学习等领域。John Schulman在其2020年3月7日的博客中详细探讨了如何通过蒙特卡洛方法近似KL散度,并提出了一种低方差、无偏的估计器。这一方法不仅在理论上具有重要意义,还被DeepSeek的GRPO算法所采用。本文将基于Schulman的博客内容,介绍KL散度的近似方法及其在DeepSeek GRPO中的应用。
blog:http://joschu.net/blog/kl-approx.html
首先看一下John Schulman是谁:PPO的作者,领导创建了ChatGPT,OpenAI的cofounder。
source:http://joschu.net/index.html
source:https://scholar.google.com/citations?user=itSa94cAAAAJ
KL散度的基本定义
KL散度衡量两个概率分布 ( p p p ) 和 ( q q q ) 之间的差异,定义为:
K L [ q , p ] = ∑ x q ( x ) log q ( x ) p ( x ) = E x ∼ q [ log q ( x ) p ( x ) ] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] KL[q,p]=x∑q(x)logp(x)q(x)=Ex∼q[logp(x)q(x)]
在实际应用中,精确计算KL散度通常不可行,因为:
- 计算所有 ( x x x ) 的概率和需要过多的计算或内存。
- 分布 ( p p p ) 和 ( q q q ) 可能没有闭合表达式。
- 在强化学习等场景中,KL散度常作为诊断工具,仅存储对数概率以简化代码。
因此,蒙特卡洛方法成为估计KL散度的常用策略。假设我们从分布 ( q q q ) 中采样 ( x 1 , x 2 , … x_1, x_2, \dots x1,x2,… ),如何构造一个既无偏又低方差的估计器是关键问题。
传统估计器及其局限性
最直接的蒙特卡洛估计器是基于KL散度的定义:
k 1 = log q ( x ) p ( x ) = − log r , 其中 r = p ( x ) q ( x ) k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k1=logp(x)q(x)=−logr,其中r=q(x)p(x)
这个估计器 ( k 1 k_1 k1 ) 是无偏的,即其期望等于真实的KL散度。然而,由于 ( log r \log r logr ) 的值在正负之间变化(当 ( r > 1 r > 1 r>1 ) 时为正,当 ( r < 1 r < 1 r<1 ) 时为负),其方差较高,而KL散度本身始终为正。这种高方差使得 ( k 1 k_1 k1 ) 在实际应用中表现不佳。
低方差的偏倚估计器
Schulman提出了一种替代估计器:
k 2 = 1 2 ( log p ( x ) q ( x ) ) 2 = 1 2 ( log r ) 2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2 k2=21(logq(x)p(x))2=21(logr)2
这个估计器 ( k 2 k_2 k2 ) 虽然有偏,但方差显著低于 ( k 1 k_1 k1 )。其优点在于:
- 始终为正:每个样本都反映了 ( p p p ) 和 ( q q q ) 之间的差异,且结果非负,与KL散度的性质一致。
- 低偏倚:( k 2 k_2 k2 ) 的期望是一个f-散度,其形式为:
E q [ k 2 ] = E q [ 1 2 ( log r ) 2 ] \mathbb{E}_q[k_2] = \mathbb{E}_q \left[ \frac{1}{2} (\log r)^2 \right] Eq[k2]=Eq[21(logr)2]
这是一个以 ( f ( x ) = 1 2 ( log x ) 2 f(x) = \frac{1}{2} (\log x)^2 f(x)=21(logx)2 ) 定义的f-散度。当 ( p p p ) 和 ( q q q ) 接近时,所有可微的f-散度在二阶近似下与KL散度等价,因此 ( k 2 k_2 k2 ) 的偏倚非常小。
实验表明,当 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) )、( p = N ( 0.1 , 1 ) p = \mathcal{N}(0.1, 1) p=N(0.1,1) ) 时(真实KL散度为0.005),( k 2 k_2 k2 ) 的偏倚仅为0.2%,标准差为真实值的1.42倍,远低于 ( k 1 k_1 k1 ) 的20倍。
无偏且低方差的估计器
为了兼顾无偏和低方差,Schulman引入了控制变量(control variate)方法。利用 ( E q [ r − 1 ] = 0 \mathbb{E}_q[r - 1] = 0 Eq[r−1]=0 ),可以构造一个新的估计器:
k 3 = − log r + λ ( r − 1 ) k_3 = -\log r + \lambda (r - 1) k3=−logr+λ(r−1)
通过选择适当的 ( λ \lambda λ ),可以降低方差。当 ( λ = 1 \lambda = 1 λ=1 ) 时,估计器变为:
k 3 = ( r − 1 ) − log r k_3 = (r - 1) - \log r k3=(r−1)−logr
由于对数的凹性,( log r ≤ r − 1 \log r \leq r - 1 logr≤r−1 ),因此 ( k 3 k_3 k3) 始终为正。这个估计器不仅无偏,而且方差低于 ( k 1 k_1 k1 )。实验表明,当真实KL散度为0.5时,( k 3 k_3 k3 ) 的标准差为真实值的1.7倍,低于 ( k 2 k_2 k2 ) 的1.73倍,且无偏。
推广到其他f-散度
Schulman进一步指出,上述方法可以推广到任意f-散度。对于 ( K L [ p , q ] KL[p, q] KL[p,q] ),对应的f-散度估计器为:
r log r − ( r − 1 ) r \log r - (r - 1) rlogr−(r−1)
这个估计器同样基于凸函数与其切线的距离(Bregman散度),保证了非负性和低方差。
DeepSeek GRPO中的应用
DeepSeek的GRPO算法直接采用了Schulman提出的无偏估计器 ( k 3 k_3 k3 )。具体来说,GRPO使用以下估计器来近似 ( D K L ( π θ ∣ ∣ π r e f ) D_{KL}(\pi_\theta || \pi_{ref}) DKL(πθ∣∣πref) ):
D K L ( π θ ∣ ∣ π r e f ) = π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − log π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − 1 D_{KL}(\pi_\theta || \pi_{ref}) = \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} - \log \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} - 1 DKL(πθ∣∣πref)=πθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−logπθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−1
其中,( π θ \pi_\theta πθ ) 表示策略网络,( π r e f \pi_{ref} πref ) 表示参考策略,( o i , t o_{i,t} oi,t ) 表示在时间 ( t t t ) 的观测,( q q q ) 和 ( o i , < t o_{i,<t} oi,<t ) 分别表示上下文和历史观测。令 ( r = π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) r = \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} r=πθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t) ),该估计器可重写为:
k 3 = ( r − 1 ) − log r k_3 = (r - 1) - \log r k3=(r−1)−logr
这个估计器的优点在于:
- 无偏性:保证估计的期望等于真实的KL散度。
- 非负性:由于 ( log r ≤ r − 1 \log r \leq r - 1 logr≤r−1 ),估计值始终为正,与KL散度的性质一致。
- 低方差:通过控制变量降低了估计的方差,提升了算法的稳定性。
与传统的KL惩罚项(例如直接使用 ( log π θ π r e f \log \frac{\pi_\theta}{\pi_{ref}} logπrefπθ))相比,GRPO的估计器在强化学习中表现更稳定,尤其是在策略优化需要精确控制分布差异时。
实验结果
Schulman的实验验证了 ( k 3 k_3 k3 ) 的优越性。例如,当 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) )、( p = N ( 1 , 1 ) p = \mathcal{N}(1, 1) p=N(1,1) )(真实KL散度为0.5)时:
- ( k 1 k_1 k1 ): 无偏,标准差为真实值的2倍。
- ( k 2 k_2 k2 ): 偏倚为25%,标准差为真实值的1.73倍。
- ( k 3 k_3 k3 ): 无偏,标准差为真实值的1.7倍。
这些结果表明,( k 3 k_3 k3 ) 在保持无偏的同时,显著降低了方差。
总结
John Schulman的博客提供了一种优雅的KL散度近似方法,通过引入控制变量构造了无偏且低方差的估计器 ( k 3 = ( r − 1 ) − log r k_3 = (r - 1) - \log r k3=(r−1)−logr )。这一方法不仅在理论上具有普适性(可推广到任意f-散度),还在DeepSeek的GRPO算法中得到了实际应用。GRPO利用这一估计器实现了稳定的策略优化,展示了其在强化学习中的重要价值。对于需要在高维分布上估计KL散度的研究者和工程师来说,Schulman的方法提供了一个兼顾理论严谨性和实践效果的解决方案。
蒙特卡洛估计器
什么是蒙特卡洛估计器?
蒙特卡洛方法(Monte Carlo Method)是一种通过随机采样来估计复杂数学量的方法。它特别适合用来近似那些难以直接计算的积分、期望或求和,尤其是在高维空间或解析解不可得的情况下。蒙特卡洛估计器的核心思想是:通过从某个分布中抽取大量随机样本,计算这些样本上的函数值,然后用样本均值来近似目标量的期望。
在KL散度的背景下,KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ) 定义为:
K L [ q , p ] = ∑ x q ( x ) log q ( x ) p ( x ) = E x ∼ q [ log q ( x ) p ( x ) ] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] KL[q,p]=x∑q(x)logp(x)q(x)=Ex∼q[logp(x)q(x)]
这里的 ( E x ∼ q \mathbb{E}_{x \sim q} Ex∼q ) 表示从分布 ( q q q ) 中采样的期望。如果我们无法直接计算整个求和(例如,分布 ( q q q ) 和 ( p p p ) 是高维的,或者没有闭合表达式),就可以用蒙特卡洛方法来估计这个期望。具体来说:
- 从分布 ( q q q ) 中抽取 ( N N N ) 个独立样本 ( x 1 , x 2 , … , x N x_1, x_2, \dots, x_N x1,x2,…,xN )。
- 对每个样本 ( x i x_i xi ),计算函数 ( f ( x i ) = log q ( x i ) p ( x i ) f(x_i) = \log \frac{q(x_i)}{p(x_i)} f(xi)=logp(xi)q(xi) )。
- 用样本均值来估计期望:
K L [ q , p ] ≈ 1 N ∑ i = 1 N log q ( x i ) p ( x i ) KL[q, p] \approx \frac{1}{N} \sum_{i=1}^N \log \frac{q(x_i)}{p(x_i)} KL[q,p]≈N1i=1∑Nlogp(xi)q(xi)
这个均值就是蒙特卡洛估计器,称为 ( k 1 k_1 k1 ) 在你的问题中:
k 1 = log q ( x ) p ( x ) = − log r , 其中 r = p ( x ) q ( x ) k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k1=logp(x)q(x)=−logr,其中r=q(x)p(x)
蒙特卡洛估计器的作用:它通过有限样本近似真实的KL散度值。KL散度本质上衡量两个分布 ( q q q ) 和 ( p p p ) 之间的“距离”(尽管严格来说它不是对称的距离),而蒙特卡洛估计器试图通过采样来捕捉这种差异。
蒙特卡洛估计器是否反映分布的距离?
是的,蒙特卡洛估计器 ( k 1 k_1 k1 ) 的期望等于真实的KL散度,因此它确实反映了两个分布之间的“距离”。具体来说:
- 无偏性:( E x ∼ q [ k 1 ] = E x ∼ q [ log q ( x ) p ( x ) ] = K L [ q , p ] \mathbb{E}_{x \sim q}[k_1] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = KL[q, p] Ex∼q[k1]=Ex∼q[logp(x)q(x)]=KL[q,p] )。这意味着如果采样次数 ( N → ∞ N \to \infty N→∞ ),估计器的平均值会收敛到真实的KL散度。
- 分布差异的体现:( k 1 = log q ( x ) p ( x ) k_1 = \log \frac{q(x)}{p(x)} k1=logp(x)q(x) ) 直接衡量了在样本点 ( x x x ) 上,分布 ( q q q ) 和 ( p p p ) 的概率密度(或概率)的对数差异。如果 ( q ( x ) ≈ p ( x ) q(x) \approx p(x) q(x)≈p(x) ),则 ( log q ( x ) p ( x ) ≈ 0 \log \frac{q(x)}{p(x)} \approx 0 logp(x)q(x)≈0 ),表明分布很接近;如果 ( q ( x ) q(x) q(x) ) 和 ( p ( x ) p(x) p(x) ) 差异很大,则 ( k 1 k_1 k1 ) 的绝对值会较大,反映分布的差异。
然而,( k 1 k_1 k1 ) 的局限性在于其高方差。由于 ( log r \log r logr ) (其中 ( r = p ( x ) q ( x ) r = \frac{p(x)}{q(x)} r=q(x)p(x) ))可能为正或负,样本值的波动会导致估计值的方差较大,尤其当 ( p p p ) 和 ( q q q ) 差异较大时。这种高方差使得 ( k 1 k_1 k1 ) 在实际应用中不够稳定。
通过例子说明
让我们通过一个简单的例子来说明蒙特卡洛估计器如何工作,以及它如何反映分布的距离。
场景
假设有两个一维正态分布:
- ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) ) (均值0,标准差1)
- ( p = N ( 0.5 , 1 ) p = \mathcal{N}(0.5, 1) p=N(0.5,1) ) (均值0.5,标准差1)
真实的KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ) 对于正态分布有解析解:
K L [ q , p ] = 1 2 ( ( μ q − μ p ) 2 σ p 2 ) = 1 2 ( 0 − 0.5 ) 2 1 2 = 0.25 2 = 0.125 KL[q, p] = \frac{1}{2} \left( \frac{(\mu_q - \mu_p)^2}{\sigma_p^2} \right) = \frac{1}{2} \frac{(0 - 0.5)^2}{1^2} = \frac{0.25}{2} = 0.125 KL[q,p]=21(σp2(μq−μp)2)=2112(0−0.5)2=20.25=0.125
我们用蒙特卡洛方法来估计这个值。
蒙特卡洛估计步骤
- 采样:从 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) ) 中抽取 ( N = 1000 N = 1000 N=1000 ) 个样本 ( x 1 , x 2 , … , x 1000 x_1, x_2, \dots, x_{1000} x1,x2,…,x1000 )。假设样本是随机生成的,例如 ( x 1 = − 0.3 , x 2 = 1.2 , x 3 = − 1.5 , … x_1 = -0.3, x_2 = 1.2, x_3 = -1.5, \dots x1=−0.3,x2=1.2,x3=−1.5,… )。
- 计算 ( k 1 k_1 k1 ):对每个样本 ( x i x_i xi ),计算:
k 1 ( x i ) = log q ( x i ) p ( x i ) k_1(x_i) = \log \frac{q(x_i)}{p(x_i)} k1(xi)=logp(xi)q(xi)
正态分布的概率密度函数为:
q ( x ) = 1 2 π exp ( − x 2 2 ) , p ( x ) = 1 2 π exp ( − ( x − 0.5 ) 2 2 ) q(x) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{x^2}{2}\right), \quad p(x) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x - 0.5)^2}{2}\right) q(x)=2π1exp(−2x2),p(x)=2π1exp(−2(x−0.5)2)
因此:
log q ( x i ) p ( x i ) = log q ( x i ) − log p ( x i ) = ( − x i 2 2 ) − ( − ( x i − 0.5 ) 2 2 ) \log \frac{q(x_i)}{p(x_i)} = \log q(x_i) - \log p(x_i) = \left( -\frac{x_i^2}{2} \right) - \left( -\frac{(x_i - 0.5)^2}{2} \right) logp(xi)q(xi)=logq(xi)−logp(xi)=(−2xi2)−(−2(xi−0.5)2)
= − x i 2 2 + ( x i − 0.5 ) 2 2 = ( x i − 0.5 ) 2 − x i 2 2 = x i 2 − x i + 0.25 − x i 2 2 = − x i + 0.25 2 = -\frac{x_i^2}{2} + \frac{(x_i - 0.5)^2}{2} = \frac{(x_i - 0.5)^2 - x_i^2}{2} = \frac{x_i^2 - x_i + 0.25 - x_i^2}{2} = \frac{-x_i + 0.25}{2} =−2xi2+2(xi−0.5)2=2(xi−0.5)2−xi2=2xi2−xi+0.25−xi2=2−xi+0.25
k 1 ( x i ) = 0.25 − x i 2 k_1(x_i) = \frac{0.25 - x_i}{2} k1(xi)=20.25−xi
- 估计KL散度:计算所有样本的均值:
K L ^ [ q , p ] = 1 N ∑ i = 1 N k 1 ( x i ) = 1 N ∑ i = 1 N 0.25 − x i 2 \hat{KL}[q, p] = \frac{1}{N} \sum_{i=1}^N k_1(x_i) = \frac{1}{N} \sum_{i=1}^N \frac{0.25 - x_i}{2} KL^[q,p]=N1i=1∑Nk1(xi)=N1i=1∑N20.25−xi
假设我们采样得到的 ( x i x_i xi ) 的均值为 ( x ˉ ≈ 0 \bar{x} \approx 0 xˉ≈0 )(因为 ( x i ∼ N ( 0 , 1 ) x_i \sim \mathcal{N}(0, 1) xi∼N(0,1) )),则:
K L ^ [ q , p ] ≈ 1 1000 ∑ i = 1 1000 0.25 − x i 2 ≈ 0.25 − 0 2 = 0.125 \hat{KL}[q, p] \approx \frac{1}{1000} \sum_{i=1}^{1000} \frac{0.25 - x_i}{2} \approx \frac{0.25 - 0}{2} = 0.125 KL^[q,p]≈10001i=1∑100020.25−xi≈20.25−0=0.125
这个估计值接近真实的 ( K L [ q , p ] = 0.125 KL[q, p] = 0.125 KL[q,p]=0.125 ),说明 ( k 1 k_1 k1 ) 是无偏的。
高方差的表现
尽管 ( k 1 k_1 k1 ) 无偏,但它的值在样本间波动较大。例如:
- 如果 ( x i = 0 x_i = 0 xi=0 ),则 ( k 1 = 0.25 − 0 2 = 0.125 k_1 = \frac{0.25 - 0}{2} = 0.125 k1=20.25−0=0.125 )。
- 如果 ( x i = 2 x_i = 2 xi=2 ),则 ( k 1 = 0.25 − 2 2 = − 0.875 k_1 = \frac{0.25 - 2}{2} = -0.875 k1=20.25−2=−0.875 )。
- 如果 ( x i = − 2 x_i = -2 xi=−2 ),则 ( k 1 = 0.25 − ( − 2 ) 2 = 1.125 k_1 = \frac{0.25 - (-2)}{2} = 1.125 k1=20.25−(−2)=1.125 ).
这些值的正负波动导致样本均值的方差较高。如果采样次数 ( N N N ) 不够多,估计值可能偏离0.125较远。
反映分布距离
从 ( k 1 ( x i ) = 0.25 − x i 2 k_1(x_i) = \frac{0.25 - x_i}{2} k1(xi)=20.25−xi ) 可以看出,( k 1 k_1 k1 ) 的值直接与样本 ( x i x_i xi) 有关,而样本是由 ( q q q ) 生成的。KL散度的值0.125反映了 ( q q q ) 和 ( p p p ) 之间的均值差异(0 vs 0.5)。如果我们改变 ( p p p ) 的均值(例如 ( p = N ( 1 , 1 ) p = \mathcal{N}(1, 1) p=N(1,1) )),KL散度会变大(( K L [ q , p ] = 0.5 KL[q, p] = 0.5 KL[q,p]=0.5 )),相应的 ( k 1 k_1 k1 ) 的均值也会变大,反映出分布间更大的差异。
总结
- 蒙特卡洛估计器是通过从分布 ( q q q ) 中采样,用样本上的函数值均值来近似KL散度的期望。
- 反映分布距离:( k 1 = log q ( x ) p ( x ) k_1 = \log \frac{q(x)}{p(x)} k1=logp(x)q(x) ) 的期望等于KL散度,因此它确实捕捉了 ( q q q ) 和 ( p p p ) 之间的差异。
- 局限性:高方差使得 ( k 1 k_1 k1 ) 在有限样本下不够稳定,如例子中 ( k 1 k_1 k1 ) 的值在正负间波动。
- 改进方向:Schulman的博客中提出了 ( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 ) 等估计器,通过降低方差(例如保证非负性或引入控制变量)来改进 ( k 1 k_1 k1 ),在DeepSeek GRPO等应用中表现出色。
log相减就是k1估计器
在许多大语言模型(LLM)强化学习(RL)相关的论文中,尤其是那些涉及策略优化(如PPO、TRPO或RLHF)的论文,使用的KL散度估计通常是基于 ( k 1 k_1 k1 ) 估计器,即直接计算对数概率的差值:
k 1 = log q ( x ) p ( x ) = log q ( x ) − log p ( x ) k_1 = \log \frac{q(x)}{p(x)} = \log q(x) - \log p(x) k1=logp(x)q(x)=logq(x)−logp(x)
或者在策略优化的上下文中,写作:
k 1 = log π ref ( a ∣ s ) π θ ( a ∣ s ) = log π ref ( a ∣ s ) − log π θ ( a ∣ s ) k_1 = \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\theta}(a|s)} = \log \pi_{\text{ref}}(a|s) - \log \pi_{\theta}(a|s) k1=logπθ(a∣s)πref(a∣s)=logπref(a∣s)−logπθ(a∣s)
其中,( π θ \pi_{\theta} πθ ) 是当前策略,( π ref \pi_{\text{ref}} πref ) 是参考策略(如初始策略或SFT模型),( a a a ) 是动作(在LLM中通常是生成的token),( s s s ) 是状态(上下文或提示)。这种形式的KL散度估计正是Schulman博客中描述的 ( k 1 k_1 k1 ) 估计器,基于蒙特卡洛方法,通过采样计算对数概率差的均值来近似 ( K L [ π θ ∣ ∣ π ref ] KL[\pi_{\theta} || \pi_{\text{ref}}] KL[πθ∣∣πref] ) 或 ( K L [ π ref ∣ ∣ π θ ] KL[\pi_{\text{ref}} || \pi_{\theta}] KL[πref∣∣πθ] )。
为什么常用 ( k 1 k_1 k1 ) 估计器?
-
简单直接:( k 1 k_1 k1 ) 是KL散度的直接蒙特卡洛估计,形式上就是对数概率的差值,易于实现。LLM通常会输出每个token的对数概率(logits经过softmax后取log),因此计算 ( log π θ ( a ∣ s ) \log \pi_{\theta}(a|s) logπθ(a∣s) ) 和 ( log π ref ( a ∣ s ) \log \pi_{\text{ref}}(a|s) logπref(a∣s) ) 是现成的。
-
无偏性:( k 1 k_1 k1 ) 是无偏估计器,其期望等于真实的KL散度:
E a ∼ π θ [ log π θ ( a ∣ s ) π ref ( a ∣ s ) ] = K L [ π θ ∣ ∣ π ref ] \mathbb{E}_{a \sim \pi_{\theta}} \left[ \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)} \right] = KL[\pi_{\theta} || \pi_{\text{ref}}] Ea∼πθ[logπref(a∣s)πθ(a∣s)]=KL[πθ∣∣πref]
这保证了在采样足够多的情况下,估计值会收敛到真实KL散度。
- 常见应用场景:在强化学习中,KL散度常用于以下场景:
- 正则化项:如PPO中,KL散度作为惩罚项,限制新策略 ( π θ \pi_{\theta} πθ ) 偏离参考策略 ( π ref \pi_{\text{ref}} πref ) 过远。
- 诊断工具:监控策略更新的稳定性。
- 奖励函数:在RLHF中,KL散度常用于奖励函数(如 ( R = R human − β K L [ π θ ∣ ∣ π ref ] R = R_{\text{human}} - \beta KL[\pi_{\theta} || \pi_{\text{ref}}] R=Rhuman−βKL[πθ∣∣πref] )),以平衡偏好和策略稳定性。
这些场景中,( k 1 k_1 k1 ) 的形式简单,且与策略梯度计算兼容,因此被广泛采用。
( k 1 k_1 k1 ) 估计器的局限性在LLM RL中的体现
尽管 ( k 1 k_1 k1 ) 常用,但正如Schulman的博客所指出的,它有高方差的缺点。这在LLM RL中可能表现为:
- 方差导致的不稳定:由于 ( k 1 = log π θ ( a ∣ s ) π ref ( a ∣ s ) k_1 = \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)} k1=logπref(a∣s)πθ(a∣s) ) 可能为正或负(当 ( π θ ( a ∣ s ) > π ref ( a ∣ s ) \pi_{\theta}(a|s) > \pi_{\text{ref}}(a|s) πθ(a∣s)>πref(a∣s) ) 时为正,反之为负),样本间的波动会导致KL散度估计不稳定,尤其在序列生成中,token序列较长时方差会累积。
- 采样效率:LLM生成序列的采样成本高(需要运行整个模型),如果 ( k 1 k_1 k1 ) 的高方差要求更多样本才能得到可靠估计,会增加计算开销。
- 极端概率的影响:当 ( π θ ( a ∣ s ) \pi_{\theta}(a|s) πθ(a∣s) ) 或 ( π ref ( a ∣ s ) \pi_{\text{ref}}(a|s) πref(a∣s) ) 接近0时,( log π θ ( a ∣ s ) π ref ( a ∣ s ) \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)} logπref(a∣s)πθ(a∣s) ) 可能变得很大或很小,导致估计值的异常波动。
总结
是的,许多LLM强化学习论文中使用的KL散度估计确实是 ( k 1 k_1 k1 ) 估计器(( log π θ π ref \log \frac{\pi_{\theta}}{\pi_{\text{ref}}} logπrefπθ )),因为它简单、无偏且易于实现。然而,其高方差可能导致不稳定,尤其在长序列生成或概率差异较大的情况下。一些前沿工作(如DeepSeek GRPO)采用了Schulman提出的 ( k 3 k_3 k3 ) 估计器,以获得更低的方差和更稳定的优化。如果你阅读的论文明确提到KL散度惩罚,通常默认是 ( k 1 k_1 k1 ),但需要检查具体实现或损失函数是否使用了改进的估计器(如 ( k 3 k_3 k3 ) 或其他变体)。
f-散度和k2
什么是f-散度?
f-散度(f-divergence)是一类用于衡量两个概率分布 ( p p p ) 和 ( q q q ) 之间差异的广义距离度量。它通过一个凸函数 ( f f f ) 定义,形式如下:
D f ( p , q ) = E x ∼ q [ f ( p ( x ) q ( x ) ) ] D_f(p, q) = \mathbb{E}_{x \sim q} \left[ f\left( \frac{p(x)}{q(x)} \right) \right] Df(p,q)=Ex∼q[f(q(x)p(x))]
其中:
- ( p ( x ) p(x) p(x) ) 和 ( q ( x ) q(x) q(x) ) 是分布 ( p p p ) 和 ( q q q ) 在点 ( x x x ) 上的概率密度(或概率质量,视离散/连续分布而定)。
- ( r = p ( x ) q ( x ) r = \frac{p(x)}{q(x)} r=q(x)p(x) ) 是概率比率。
- ( f : R + → R f: \mathbb{R}^+ \to \mathbb{R} f:R+→R ) 是一个凸函数,且通常满足 ( f ( 1 ) = 0 f(1) = 0 f(1)=0 ),以确保当 ( p = q p = q p=q ) 时散度为零。
f-散度的特点是它包含了许多常见的分布距离度量,通过选择不同的 ( f f f ) 函数,可以得到不同的散度。例如:
- KL散度:当 ( f ( u ) = u log u f(u) = u \log u f(u)=ulogu ),对应 ( D f ( p , q ) = K L ( p ∣ ∣ q ) = E q [ p ( x ) q ( x ) log p ( x ) q ( x ) ] D_f(p, q) = KL(p || q) = \mathbb{E}_{q} \left[ \frac{p(x)}{q(x)} \log \frac{p(x)}{q(x)} \right] Df(p,q)=KL(p∣∣q)=Eq[q(x)p(x)logq(x)p(x)] )。
- 逆KL散度:当 ( f ( u ) = − log u f(u) = -\log u f(u)=−logu ),对应 ( D f ( p , q ) = K L ( q ∣ ∣ p ) D_f(p, q) = KL(q || p) Df(p,q)=KL(q∣∣p) )。
- Jensen-Shannon散度:通过选择适当的 ( f f f ) 函数,可以推导JS散度。
- 总变差距离:当 ( f ( u ) = ∣ u − 1 ∣ f(u) = |u - 1| f(u)=∣u−1∣ ),对应总变差。
f-散度的优点是它提供了一个统一的框架来研究不同分布距离的性质,且许多f-散度在 ( p ≈ q p \approx q p≈q ) 时具有相似的二阶行为(如Schulman博客中提到的与KL散度的关系)。
Schulman博客中的f-散度背景
在John Schulman的博客《Approximating KL Divergence》中,他讨论了如何通过蒙特卡洛方法近似KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ),并提出了一种低方差但有偏的估计器 ( k 2 k_2 k2 )。他指出,( k 2 k_2 k2 ) 的期望是一个f-散度,并且在 ( p ≈ q p \approx q p≈q ) 时,所有可微的f-散度在二阶近似下与KL散度等价。这种等价性是 ( k 2 k_2 k2 ) 低偏倚的理论基础。
具体来说,Schulman定义了 ( k 2 k_2 k2 ) 估计器为:
k 2 = 1 2 ( log p ( x ) q ( x ) ) 2 = 1 2 ( log r ) 2 , 其中 r = p ( x ) q ( x ) k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k2=21(logq(x)p(x))2=21(logr)2,其中r=q(x)p(x)
他进一步说明,( k 2 k_2 k2 ) 的期望是一个f-散度:
E x ∼ q [ k 2 ] = E x ∼ q [ 1 2 ( log r ) 2 ] \mathbb{E}_{x \sim q}[k_2] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} (\log r)^2 \right] Ex∼q[k2]=Ex∼q[21(logr)2]
这对应于f-散度中的凸函数:
f ( u ) = 1 2 ( log u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2
为了理解 ( k 2 k_2 k2 ) 的来源和为何它是一个合理的KL散度估计器,我们需要结合f-散度的性质和Schulman的推导思路进行分析。
推导 ( k 2 k_2 k2 ) 的来源
以下是结合Schulman博客内容,推导 ( k 2 k_2 k2 ) 如何作为KL散度估计器的过程:
1. KL散度的蒙特卡洛估计
KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ) 的定义为:
K L [ q , p ] = E x ∼ q [ log q ( x ) p ( x ) ] = E x ∼ q [ − log r ] , r = p ( x ) q ( x ) KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \mathbb{E}_{x \sim q} [-\log r], \quad r = \frac{p(x)}{q(x)} KL[q,p]=Ex∼q[logp(x)q(x)]=Ex∼q[−logr],r=q(x)p(x)
最直接的蒙特卡洛估计器是 ( k 1 k_1 k1 ):
k 1 = log q ( x ) p ( x ) = − log r k_1 = \log \frac{q(x)}{p(x)} = -\log r k1=logp(x)q(x)=−logr
( k 1 k_1 k1 ) 是无偏的,但由于 ( log r \log r logr ) 可正可负(( r > 1 r > 1 r>1 ) 时 ( log r > 0 \log r > 0 logr>0 ),( r < 1 r < 1 r<1 ) 时 ( log r < 0 \log r < 0 logr<0 )),其方差较高,而KL散度本身始终非负。这种高方差促使Schulman寻找更稳定的估计器。
2. 构造低方差估计器的动机
Schulman的目标是构造一个估计器,既能反映 ( p p p ) 和 ( q q q ) 之间的差异,又具有较低的方差。一个直观的思路是:既然KL散度衡量分布差异,可以尝试用某种“距离”函数来替代 ( log q ( x ) p ( x ) \log \frac{q(x)}{p(x)} logp(x)q(x) ),使得估计值始终非负(与KL散度的性质一致),并且方差较小。
考虑 ( log r = log p ( x ) q ( x ) \log r = \log \frac{p(x)}{q(x)} logr=logq(x)p(x)),它直接衡量了概率比率的对数差异。如果我们取其平方:
( log r ) 2 = ( log p ( x ) q ( x ) ) 2 (\log r)^2 = \left( \log \frac{p(x)}{q(x)} \right)^2 (logr)2=(logq(x)p(x))2
这个量有以下优点:
- 非负性:( ( log r ) 2 ≥ 0 (\log r)^2 \geq 0 (logr)2≥0 ),与KL散度的非负性质一致。
- 对称性:它反映了 ( log r \log r logr ) 的大小,无论 ( log r \log r logr ) 是正还是负,都表示 ( p p p ) 和 ( q q q ) 的差异。
- 平滑性:平方函数会放大较大的 ( log r \log r logr ) 值,强调分布差异较大的样本。
为了使量纲合理(KL散度的量纲类似于对数),Schulman引入一个系数 ( 1 2 \frac{1}{2} 21 ),定义:
k 2 = 1 2 ( log r ) 2 = 1 2 ( log p ( x ) q ( x ) ) 2 k_2 = \frac{1}{2} (\log r)^2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 k2=21(logr)2=21(logq(x)p(x))2
3. ( k 2 k_2 k2 ) 的期望是f-散度
计算 ( k 2 k_2 k2 ) 的期望:
E x ∼ q [ k 2 ] = E x ∼ q [ 1 2 ( log p ( x ) q ( x ) ) 2 ] = E x ∼ q [ 1 2 ( log r ) 2 ] \mathbb{E}_{x \sim q}[k_2] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 \right] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} (\log r)^2 \right] Ex∼q[k2]=Ex∼q[21(logq(x)p(x))2]=Ex∼q[21(logr)2]
这正是f-散度的形式,其中:
f ( u ) = 1 2 ( log u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2
我们需要验证 ( f ( u ) f(u) f(u) ) 是否满足f-散度的要求:
- 凸性:计算 ( f ( u ) f(u) f(u) ) 的二阶导数:
f ( u ) = 1 2 ( log u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2
f ′ ( u ) = 1 2 ⋅ 2 log u ⋅ 1 u = log u u f'(u) = \frac{1}{2} \cdot 2 \log u \cdot \frac{1}{u} = \frac{\log u}{u} f′(u)=21⋅2logu⋅u1=ulogu
f ′ ′ ( u ) = 1 u ⋅ 1 u + log u u 2 = 1 u 2 + log u u 2 = 1 + log u u 2 f''(u) = \frac{1}{u} \cdot \frac{1}{u} + \frac{\log u}{u^2} = \frac{1}{u^2} + \frac{\log u}{u^2} = \frac{1 + \log u}{u^2} f′′(u)=u1⋅u1+u2logu=u21+u2logu=u21+logu
由于 ( u > 0 u > 0 u>0 ),( f ′ ′ ( u ) f''(u) f′′(u) ) 不总是正(例如当 ( u < e − 1 ≈ 0.367 u < e^{-1} \approx 0.367 u<e−1≈0.367 ) 时,( 1 + log u < 0 1 + \log u < 0 1+logu<0 )),但在实际应用中,( r = p ( x ) q ( x ) r = \frac{p(x)}{q(x)} r=q(x)p(x) ) 通常在 ( ( 0 , ∞ ) (0, \infty) (0,∞) ) 范围内变化,且在 ( p ≈ q p \approx q p≈q ) 时 ( r ≈ 1 r \approx 1 r≈1 ),此时 ( f ′ ′ ( 1 ) = 1 > 0 f''(1) = 1 > 0 f′′(1)=1>0 )。因此,在 ( r ≈ 1 r \approx 1 r≈1 ) 的局部区域,( f ( u ) f(u) f(u) ) 是凸的,满足f-散度的要求。
- ( f ( 1 ) = 0 f(1) = 0 f(1)=0 ):
f ( 1 ) = 1 2 ( log 1 ) 2 = 0 f(1) = \frac{1}{2} (\log 1)^2 = 0 f(1)=21(log1)2=0
这保证当 ( p = q p = q p=q )(即 ( r = 1 r = 1 r=1 ))时,散度为零。
因此,( E x ∼ q [ k 2 ] \mathbb{E}_{x \sim q}[k_2] Ex∼q[k2] ) 是一个有效的f-散度。
4. 为何 ( k 2 k_2 k2 ) 是KL散度的合理近似?
Schulman指出,所有可微的f-散度在 ( p ≈ q p \approx q p≈q ) 时,在二阶近似下与KL散度等价。让我们推导这一关键结论。
假设 ( p θ p_{\theta} pθ ) 是一个参数化的分布,( p θ 0 = p 0 p_{\theta_0} = p_0 pθ0=p0 ),我们考虑 ( D f ( p 0 , p θ ) D_f(p_0, p_{\theta}) Df(p0,pθ) ) 在 ( θ ≈ θ 0 \theta \approx \theta_0 θ≈θ0 ) 时的泰勒展开。Schulman给出了近似:
D f ( p 0 , p θ ) = f ′ ′ ( 1 ) 2 θ T F θ + O ( θ 3 ) D_f(p_0, p_{\theta}) = \frac{f''(1)}{2} \theta^T F \theta + O(\theta^3) Df(p0,pθ)=2f′′(1)θTFθ+O(θ3)
其中 ( F F F ) 是 ( p θ p_{\theta} pθ ) 在 ( θ = θ 0 \theta = \theta_0 θ=θ0 ) 处的Fisher信息矩阵。关键在于 ( f ′ ′ ( 1 ) f''(1) f′′(1) )。我们计算 ( f ( u ) = 1 2 ( log u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 ) 的二阶导数:
f ′ ′ ( u ) = 1 + log u u 2 , f ′ ′ ( 1 ) = 1 + log 1 1 2 = 1 f''(u) = \frac{1 + \log u}{u^2}, \quad f''(1) = \frac{1 + \log 1}{1^2} = 1 f′′(u)=u21+logu,f′′(1)=121+log1=1
对于KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ),对应的f-散度函数是 ( f ( u ) = − log u f(u) = -\log u f(u)=−logu ):
f ′ ( u ) = − 1 u , f ′ ′ ( u ) = 1 u 2 , f ′ ′ ( 1 ) = 1 f'(u) = -\frac{1}{u}, \quad f''(u) = \frac{1}{u^2}, \quad f''(1) = 1 f′(u)=−u1,f′′(u)=u21,f′′(1)=1
可以看到,( f ′ ′ ( 1 ) = 1 f''(1) = 1 f′′(1)=1 ) 对于 ( k 2 k_2 k2 ) 的 ( f ( u ) = 1 2 ( log u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 ) 和KL散度的 ( f ( u ) = − log u f(u) = -\log u f(u)=−logu ) 都是相同的。因此,在 ( p ≈ q p \approx q p≈q ) 时,( E x ∼ q [ k 2 ] \mathbb{E}_{x \sim q}[k_2] Ex∼q[k2] ) 和 ( K L [ q , p ] KL[q, p] KL[q,p] ) 的二阶行为一致:
E x ∼ q [ k 2 ] ≈ 1 2 θ T F θ ≈ K L [ q , p ] \mathbb{E}_{x \sim q}[k_2] \approx \frac{1}{2} \theta^T F \theta \approx KL[q, p] Ex∼q[k2]≈21θTFθ≈KL[q,p]
这解释了为何 ( k 2 k_2 k2 ) 虽然有偏(( E [ k 2 ] ≠ K L [ q , p ] \mathbb{E}[k_2] \neq KL[q, p] E[k2]=KL[q,p] )),但偏倚很小,尤其当 ( p p p ) 和 ( q q q ) 接近时。
5. ( k 2 k_2 k2 ) 的直观解释
( k 2 = 1 2 ( log p ( x ) q ( x ) ) 2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 k2=21(logq(x)p(x))2 ) 的设计灵感可以理解为:
- 惩罚对数差异:( log p ( x ) q ( x ) \log \frac{p(x)}{q(x)} logq(x)p(x) ) 衡量了 ( p p p ) 和 ( q q q ) 在对数尺度上的相对差异,平方后强调了较大的差异。
- 低方差:由于 ( ( log r ) 2 ≥ 0 (\log r)^2 \geq 0 (logr)2≥0 ),每个样本的贡献都是非负的,避免了 ( k 1 = − log r k_1 = -\log r k1=−logr ) 中正负抵消导致的高方差。
- f-散度的自然形式:平方形式 ( f ( u ) = 1 2 ( log u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 ) 是一个自然的凸函数选择,且在 ( r ≈ 1 r \approx 1 r≈1 ) 时与KL散度行为相似。
Schulman通过实验验证了 ( k 2 k_2 k2 ) 的低方差和低偏倚。例如,当 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) )、( p = N ( 0.1 , 1 ) p = \mathcal{N}(0.1, 1) p=N(0.1,1) )(真实KL散度为0.005)时,( k 2 k_2 k2 ) 的偏倚仅为0.2%,标准差远低于 ( k 1 k_1 k1 )。
结合Schulman博客的总结
Schulman的博客通过引入 ( k 2 = 1 2 ( log p ( x ) q ( x ) ) 2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 k2=21(logq(x)p(x))2 ) 提供了一种低方差的KL散度估计器,其期望是一个f-散度,对应的凸函数为 ( f ( u ) = 1 2 ( log u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 )。推导过程如下:
- 从KL散度的蒙特卡洛估计出发,识别 ( k 1 = log q ( x ) p ( x ) k_1 = \log \frac{q(x)}{p(x)} k1=logp(x)q(x) ) 的高方差问题。
- 构造 ( k 2 = 1 2 ( log r ) 2 k_2 = \frac{1}{2} (\log r)^2 k2=21(logr)2 ),使其非负且方差较低。
- 证明 ( E [ k 2 ] \mathbb{E}[k_2] E[k2] ) 是一个f-散度,且在 ( p ≈ q p \approx q p≈q ) 时与KL散度在二阶近似下等价(因为 ( f ′ ′ ( 1 ) = 1 f''(1) = 1 f′′(1)=1 ))。
- 通过实验验证 ( k 2 k_2 k2 ) 的低偏倚和低方差。
f-散度的框架为 ( k 2 k_2 k2 ) 的合理性提供了理论支持,因为它表明不同凸函数 ( f f f ) 定义的散度在 ( p ≈ q p \approx q p≈q ) 时具有相似的行为。这也为后续 ( k 3 k_3 k3 ) 等无偏估计器的设计提供了灵感(如通过控制变量进一步优化)。
k3的思考
介绍 ( k 3 k_3 k3 ) 的来源及推导
在John Schulman的博客《Approximating KL Divergence》中,他提出了三种KL散度估计器:( k 1 k_1 k1 )、( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 )。我们已经讨论了 ( k 1 = log q ( x ) p ( x ) = − log r k_1 = \log \frac{q(x)}{p(x)} = -\log r k1=logp(x)q(x)=−logr )(无偏但高方差)和 ( k 2 = 1 2 ( log p ( x ) q ( x ) ) 2 = 1 2 ( log r ) 2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2 k2=21(logq(x)p(x))2=21(logr)2 )(低方差但有偏)。( k 3 k_3 k3 ) 是Schulman提出的第三个估计器,目标是兼顾 无偏性 和 低方差,定义为:
k 3 = ( r − 1 ) − log r , 其中 r = p ( x ) q ( x ) k_3 = \left( r - 1 \right) - \log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k3=(r−1)−logr,其中r=q(x)p(x)
本节将结合Schulman的博客内容,详细推导 ( k 3 k_3 k3 ) 的来源,分析其与 ( k 2 k_2 k2 ) 推导的联系,并解释其理论基础和优势。
推导 ( k 3 k_3 k3 ) 的来源
1. 从 ( k 1 k_1 k1 ) 的高方差出发
KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ) 的定义为:
K L [ q , p ] = E x ∼ q [ log q ( x ) p ( x ) ] = E x ∼ q [ − log r ] , r = p ( x ) q ( x ) KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \mathbb{E}_{x \sim q} [-\log r], \quad r = \frac{p(x)}{q(x)} KL[q,p]=Ex∼q[logp(x)q(x)]=Ex∼q[−logr],r=q(x)p(x)
标准蒙特卡洛估计器 ( k 1 = − log r k_1 = -\log r k1=−logr ) 是无偏的,但由于 ( log r \log r logr ) 可正可负(( r > 1 r > 1 r>1 ) 时 ( log r > 0 \log r > 0 logr>0 ),( r < 1 r < 1 r<1 ) 时 ( log r < 0 \log r < 0 logr<0 )),其方差较高。Schulman的目标是构造一个估计器,既保留 ( k 1 k_1 k1 ) 的无偏性,又降低方差。
2. 控制变量方法的引入
为了降低 ( k 1 k_1 k1 ) 的方差,Schulman引入了 控制变量(control variate) 方法。控制变量方法是一种常用的方差缩减技术,核心思想是:在估计器中加入一个期望为零的项,这个项与原估计器负相关,从而抵消部分波动。
在KL散度的估计中,我们需要找到一个量 ( g ( x ) g(x) g(x) ),满足:
E x ∼ q [ g ( x ) ] = 0 \mathbb{E}_{x \sim q} [g(x)] = 0 Ex∼q[g(x)]=0
然后构造新的估计器:
k new = k 1 + λ g ( x ) k_{\text{new}} = k_1 + \lambda g(x) knew=k1+λg(x)
只要 ( E x ∼ q [ g ( x ) ] = 0 \mathbb{E}_{x \sim q} [g(x)] = 0 Ex∼q[g(x)]=0 ),新估计器仍然无偏:
E x ∼ q [ k new ] = E x ∼ q [ k 1 ] + λ E x ∼ q [ g ( x ) ] = K L [ q , p ] + 0 = K L [ q , p ] \mathbb{E}_{x \sim q} [k_{\text{new}}] = \mathbb{E}_{x \sim q} [k_1] + \lambda \mathbb{E}_{x \sim q} [g(x)] = KL[q, p] + 0 = KL[q, p] Ex∼q[knew]=Ex∼q[k1]+λEx∼q[g(x)]=KL[q,p]+0=KL[q,p]
Schulman注意到,一个自然的控制变量是:
g ( x ) = r − 1 = p ( x ) q ( x ) − 1 g(x) = r - 1 = \frac{p(x)}{q(x)} - 1 g(x)=r−1=q(x)p(x)−1
其期望为:
E x ∼ q [ r − 1 ] = E x ∼ q [ p ( x ) q ( x ) ] − 1 = ∫ q ( x ) ⋅ p ( x ) q ( x ) d x − 1 = ∫ p ( x ) d x − 1 = 1 − 1 = 0 \mathbb{E}_{x \sim q} [r - 1] = \mathbb{E}_{x \sim q} \left[ \frac{p(x)}{q(x)} \right] - 1 = \int q(x) \cdot \frac{p(x)}{q(x)} \, dx - 1 = \int p(x) \, dx - 1 = 1 - 1 = 0 Ex∼q[r−1]=Ex∼q[q(x)p(x)]−1=∫q(x)⋅q(x)p(x)dx−1=∫p(x)dx−1=1−1=0
因此,( r − 1 r - 1 r−1 ) 是一个期望为零的量,适合作为控制变量。于是,Schulman构造了新的估计器:
k new = k 1 + λ ( r − 1 ) = − log r + λ ( r − 1 ) k_{\text{new}} = k_1 + \lambda (r - 1) = -\log r + \lambda (r - 1) knew=k1+λ(r−1)=−logr+λ(r−1)
这里的 ( λ \lambda λ ) 是一个可调参数,目标是通过选择合适的 ( λ \lambda λ ) 最小化估计器的方差。
3. 选择 ( λ \lambda λ ) 以保证非负性
理论上,可以通过最小化 ( k new k_{\text{new}} knew ) 的方差来求解最优的 ( λ \lambda λ )。方差为:
Var ( k new ) = Var ( − log r + λ ( r − 1 ) ) = Var ( − log r ) + λ 2 Var ( r − 1 ) + 2 λ Cov ( − log r , r − 1 ) \text{Var}(k_{\text{new}}) = \text{Var}(-\log r + \lambda (r - 1)) = \text{Var}(-\log r) + \lambda^2 \text{Var}(r - 1) + 2 \lambda \text{Cov}(-\log r, r - 1) Var(knew)=Var(−logr+λ(r−1))=Var(−logr)+λ2Var(r−1)+2λCov(−logr,r−1)
最小化方差需要计算协方差和方差项,但这些项依赖于 ( p p p ) 和 ( q q q ) 的具体形式,解析求解往往复杂。Schulman提到,这种方法得到的 ( λ \lambda λ ) 通常难以计算,因此他采用了更简单的策略:选择 ( λ \lambda λ ) 使估计器 始终非负,从而直观上降低方差并与KL散度的非负性质一致。
Schulman利用对数函数的凹性,注意到:
log u ≤ u − 1 , ∀ u > 0 \log u \leq u - 1, \quad \forall u > 0 logu≤u−1,∀u>0
等号在 ( u = 1 u = 1 u=1 ) 时成立。这意味着:
− log r ≥ − ( r − 1 ) -\log r \geq -(r - 1) −logr≥−(r−1)
因此,如果选择 ( λ = 1 \lambda = 1 λ=1 ),估计器变为:
k 3 = − log r + ( r − 1 ) = ( r − 1 ) − log r k_3 = -\log r + (r - 1) = (r - 1) - \log r k3=−logr+(r−1)=(r−1)−logr
我们可以验证 ( k 3 k_3 k3 ) 的非负性:
k 3 = ( r − 1 ) − log r k_3 = (r - 1) - \log r k3=(r−1)−logr
根据凹性不等式 ( log r ≤ r − 1 \log r \leq r - 1 logr≤r−1 ),有:
( r − 1 ) − log r ≥ ( r − 1 ) − ( r − 1 ) = 0 (r - 1) - \log r \geq (r - 1) - (r - 1) = 0 (r−1)−logr≥(r−1)−(r−1)=0
因此,( k 3 ≥ 0 k_3 \geq 0 k3≥0 ),并且在 ( r = 1 r = 1 r=1 )(即 ( p ( x ) = q ( x ) p(x) = q(x) p(x)=q(x) ))时,( k 3 = 0 k_3 = 0 k3=0 )。这种非负性使得 ( k 3 k_3 k3 ) 的样本值更稳定,避免了 ( k 1 k_1 k1 ) 中正负抵消导致的高方差。
4. 验证 ( k 3 k_3 k3 ) 的无偏性
我们需要确认 ( k 3 k_3 k3 ) 是否无偏:
E x ∼ q [ k 3 ] = E x ∼ q [ ( r − 1 ) − log r ] = E x ∼ q [ r − 1 ] − E x ∼ q [ log r ] \mathbb{E}_{x \sim q} [k_3] = \mathbb{E}_{x \sim q} [(r - 1) - \log r] = \mathbb{E}_{x \sim q} [r - 1] - \mathbb{E}_{x \sim q} [\log r] Ex∼q[k3]=Ex∼q[(r−1)−logr]=Ex∼q[r−1]−Ex∼q[logr]
- 第一项:( E x ∼ q [ r − 1 ] = 0 \mathbb{E}_{x \sim q} [r - 1] = 0 Ex∼q[r−1]=0 )(如前所述)。
- 第二项:( E x ∼ q [ log r ] = E x ∼ q [ log p ( x ) q ( x ) ] = E x ∼ q [ log q ( x ) p ( x ) ] = K L [ q , p ] \mathbb{E}_{x \sim q} [\log r] = \mathbb{E}_{x \sim q} \left[ \log \frac{p(x)}{q(x)} \right] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = KL[q, p] Ex∼q[logr]=Ex∼q[logq(x)p(x)]=Ex∼q[logp(x)q(x)]=KL[q,p] ).
因此:
E x ∼ q [ k 3 ] = 0 − ( − K L [ q , p ] ) = K L [ q , p ] \mathbb{E}_{x \sim q} [k_3] = 0 - (-KL[q, p]) = KL[q, p] Ex∼q[k3]=0−(−KL[q,p])=KL[q,p]
( k 3 k_3 k3 ) 是无偏的,满足我们对KL散度估计器的要求。
5. ( k 3 k_3 k3 ) 与Bregman散度的联系
Schulman指出,( k 3 = ( r − 1 ) − log r k_3 = (r - 1) - \log r k3=(r−1)−logr ) 的形式可以解释为一个 Bregman散度。Bregman散度衡量一个凸函数与其在某点的切线之间的垂直距离。对于凸函数 ( f ( u ) = − log u f(u) = -\log u f(u)=−logu ),其导数为:
f ′ ( u ) = − 1 u f'(u) = -\frac{1}{u} f′(u)=−u1
在点 ( u = 1 u = 1 u=1 ) 处的切线为:
f ( 1 ) + f ′ ( 1 ) ( u − 1 ) = − log 1 − 1 1 ( u − 1 ) = − ( u − 1 ) f(1) + f'(1)(u - 1) = -\log 1 - \frac{1}{1}(u - 1) = -(u - 1) f(1)+f′(1)(u−1)=−log1−11(u−1)=−(u−1)
Bregman散度定义为:
D f ( u , 1 ) = f ( u ) − f ( 1 ) − f ′ ( 1 ) ( u − 1 ) D_f(u, 1) = f(u) - f(1) - f'(1)(u - 1) Df(u,1)=f(u)−f(1)−f′(1)(u−1)
代入 ( f ( u ) = − log u f(u) = -\log u f(u)=−logu ):
D f ( r , 1 ) = ( − log r ) − ( − log 1 ) − ( − 1 1 ) ( r − 1 ) = − log r − 0 + ( r − 1 ) = ( r − 1 ) − log r D_f(r, 1) = (-\log r) - (-\log 1) - \left(-\frac{1}{1}\right)(r - 1) = -\log r - 0 + (r - 1) = (r - 1) - \log r Df(r,1)=(−logr)−(−log1)−(−11)(r−1)=−logr−0+(r−1)=(r−1)−logr
这正是 ( k 3 k_3 k3 ) 的形式!因此,( k 3 k_3 k3 ) 可以看作 ( f ( u ) = − log u f(u) = -\log u f(u)=−logu ) 在 ( r = 1 r = 1 r=1 ) 处的Bregman散度。这种解释提供了理论上的优雅性:( k 3 k_3 k3 ) 不仅无偏,还通过Bregman散度的非负性自然降低了方差。
6. 实验验证
Schulman通过实验比较了 ( k 1 k_1 k1 )、( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 )。例如,当 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) )、( p = N ( 0.1 , 1 ) p = \mathcal{N}(0.1, 1) p=N(0.1,1) )(真实KL散度为0.005)时:
- ( k 1 k_1 k1 ): 无偏,标准差为真实值的20倍。
- ( k 2 k_2 k2 ): 偏倚为0.2%,标准差为真实值的1.42倍。
- ( k 3 k_3 k3 ): 无偏,标准差为真实值的1.42倍。
当KL散度较大(如 ( p = N ( 1 , 1 ) p = \mathcal{N}(1, 1) p=N(1,1) ),真实KL散度为0.5)时:
- ( k 3 k_3 k3 ): 无偏,标准差为真实值的1.7倍,优于 ( k 2 k_2 k2 ) 的1.73倍和25%偏倚。
这些结果表明,( k 3 k_3 k3 ) 成功结合了无偏性和低方差。
( k 3 k_3 k3 ) 与 ( k 2 k_2 k2 ) 推导的联系
虽然 ( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 ) 的最终形式不同,但它们的推导和思考过程有以下联系:
-
共同目标:降低方差
- ( k 2 k_2 k2 ) 通过平方 ( log r \log r logr )(即 ( 1 2 ( log r ) 2 \frac{1}{2} (\log r)^2 21(logr)2 ))使估计值非负,强调分布差异并降低方差。
- ( k 3 k_3 k3 ) 通过控制变量 ( r − 1 r - 1 r−1 ) 和利用 ( log r ≤ r − 1 \log r \leq r - 1 logr≤r−1 ) 的凹性不等式,同样追求非负性和低方差。
- 两者都试图解决 ( k 1 = − log r k_1 = -\log r k1=−logr ) 的高方差问题,直观上希望估计器与KL散度的非负性质一致。
-
f-散度与Bregman散度的理论支持
- ( k 2 k_2 k2 ) 的期望是一个f-散度(( f ( u ) = 1 2 ( log u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 )),在 ( p ≈ q p \approx q p≈q ) 时通过二阶近似(( f ′ ′ ( 1 ) = 1 f''(1) = 1 f′′(1)=1 ))与KL散度等价。
- ( k 3 k_3 k3 ) 直接是 ( K L [ q , p ] KL[q, p] KL[q,p] ) 的无偏估计器,且其形式对应于 ( f ( u ) = − log u f(u) = -\log u f(u)=−logu ) 的Bregman散度。Bregman散度可以看作f-散度的一种特殊形式(当考虑凸函数与其切线的距离时)。
- 两者都利用了凸函数的性质(( k 2 k_2 k2 ) 通过凸函数 ( f ( u ) f(u) f(u)),( k 3 k_3 k3 ) 通过凹函数 ( − log u -\log u −logu ) 的Bregman散度),体现了Schulman从广义散度框架出发的统一思路。
-
从有偏到无偏的改进
- ( k 2 k_2 k2 ) 是Schulman的初步尝试,牺牲了无偏性换取低方差,但通过f-散度的二阶等价性保证了低偏倚。
- ( k 3 k_3 k3 ) 是对 ( k 2 k_2 k2 ) 的进一步改进,通过控制变量方法恢复了无偏性,同时保留了低方差的优点。( k 3 k_3 k3 ) 的非负性(通过 ( log r ≤ r − 1 \log r \leq r - 1 logr≤r−1 ))与 ( k 2 k_2 k2 ) 的非负性(通过平方)有相似的直观动机。
-
设计思路的演进
- ( k 2 k_2 k2 ) 的设计更像是一种启发式尝试:用平方函数放大差异并确保非负,然后通过f-散度理论验证其合理性。
- ( k 3 k_3 k3 ) 的设计更有系统性,基于控制变量的统计方法,并通过Bregman散度的数学结构提供理论支撑。Schulman在推导 ( k 3 k_3 k3 ) 时明确考虑了如何在 ( k 1 k_1 k1 ) 的基础上系统性地降低方差,而 ( k 2 k_2 k2 ) 更像是探索过程中的一个中间产物。
-
推广到其他散度
- Schulman在博客中提到,( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 ) 的思路可以推广到其他f-散度。例如,( k 3 k_3 k3 ) 的形式 ( f ( r ) − f ′ ( 1 ) ( r − 1 ) f(r) - f'(1)(r - 1) f(r)−f′(1)(r−1) ) 是一个通用的f-散度估计器,适用于任意凸函数 ( f f f )。
- ( k 2 k_2 k2 ) 的f-散度形式为其他估计器的设计提供了灵感,而 ( k 3 k_3 k3 ) 的Bregman散度形式进一步揭示了这种估计器与凸优化理论的深层联系。
总结
( k 3 = ( r − 1 ) − log r k_3 = (r - 1) - \log r k3=(r−1)−logr ) 的来源是通过控制变量方法改进 ( k 1 k_1 k1 ),利用 ( r − 1 r - 1 r−1 ) 的零期望和 ( log r ≤ r − 1 \log r \leq r - 1 logr≤r−1 ) 的凹性不等式,构造了一个无偏且低方差的KL散度估计器。其推导过程如下:
- 从 ( k 1 = − log r k_1 = -\log r k1=−logr ) 的高方差问题出发,引入控制变量 ( r − 1 r - 1 r−1 )。
- 构造新估计器 ( k 3 = − log r + λ ( r − 1 ) k_3 = -\log r + \lambda (r - 1) k3=−logr+λ(r−1) ),选择 ( λ = 1 \lambda = 1 λ=1 ) 保证非负性。
- 验证 ( k 3 k_3 k3 ) 的无偏性和非负性,并通过Bregman散度提供理论解释。
- 实验确认 ( k 3 k_3 k3 ) 兼顾无偏和低方差,优于 ( k 2 k_2 k2 )。
与 ( k 2 k_2 k2 ) 的联系在于:
- 两者都旨在降低 ( k 1 k_1 k1 ) 的方差,追求非负性以匹配KL散度的性质。
- ( k 2 k_2 k2 ) 通过f-散度框架启发了对分布差异的度量,( k 3 k_3 k3 ) 则通过控制变量和Bregman散度实现了更系统的优化。
- ( k 3 k_3 k3 ) 可以看作 ( k 2 k_2 k2 ) 思路的延伸,从有偏但低方差的探索(( k 2 k_2 k2 ))进化到无偏且低方差的解决方案(( k 3 k_3 k3 ))。
这种推导思路不仅在理论上优雅,还在DeepSeek GRPO等强化学习应用中得到了实践验证,展现了其在LLM优化中的重要价值。
后记
2025年4月20日于上海,在grok 3大模型辅助下完成。
更多推荐
所有评论(0)