KL散度近似方法介绍:从John Schulman的博客到DeepSeek GRPO的应用

KL散度(Kullback-Leibler Divergence)是衡量两个概率分布之间差异的重要指标,广泛应用于机器学习、强化学习等领域。John Schulman在其2020年3月7日的博客中详细探讨了如何通过蒙特卡洛方法近似KL散度,并提出了一种低方差、无偏的估计器。这一方法不仅在理论上具有重要意义,还被DeepSeek的GRPO算法所采用。本文将基于Schulman的博客内容,介绍KL散度的近似方法及其在DeepSeek GRPO中的应用。

blog:http://joschu.net/blog/kl-approx.html

首先看一下John Schulman是谁:PPO的作者,领导创建了ChatGPT,OpenAI的cofounder。

在这里插入图片描述

source:http://joschu.net/index.html

在这里插入图片描述

source:https://scholar.google.com/citations?user=itSa94cAAAAJ

KL散度的基本定义

KL散度衡量两个概率分布 ( p p p ) 和 ( q q q ) 之间的差异,定义为:

K L [ q , p ] = ∑ x q ( x ) log ⁡ q ( x ) p ( x ) = E x ∼ q [ log ⁡ q ( x ) p ( x ) ] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] KL[q,p]=xq(x)logp(x)q(x)=Exq[logp(x)q(x)]

在实际应用中,精确计算KL散度通常不可行,因为:

  1. 计算所有 ( x x x ) 的概率和需要过多的计算或内存。
  2. 分布 ( p p p ) 和 ( q q q ) 可能没有闭合表达式。
  3. 在强化学习等场景中,KL散度常作为诊断工具,仅存储对数概率以简化代码。

因此,蒙特卡洛方法成为估计KL散度的常用策略。假设我们从分布 ( q q q ) 中采样 ( x 1 , x 2 , … x_1, x_2, \dots x1,x2, ),如何构造一个既无偏又低方差的估计器是关键问题。

传统估计器及其局限性

最直接的蒙特卡洛估计器是基于KL散度的定义:

k 1 = log ⁡ q ( x ) p ( x ) = − log ⁡ r , 其中 r = p ( x ) q ( x ) k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k1=logp(x)q(x)=logr,其中r=q(x)p(x)

这个估计器 ( k 1 k_1 k1 ) 是无偏的,即其期望等于真实的KL散度。然而,由于 ( log ⁡ r \log r logr ) 的值在正负之间变化(当 ( r > 1 r > 1 r>1 ) 时为正,当 ( r < 1 r < 1 r<1 ) 时为负),其方差较高,而KL散度本身始终为正。这种高方差使得 ( k 1 k_1 k1 ) 在实际应用中表现不佳。

低方差的偏倚估计器

Schulman提出了一种替代估计器:

k 2 = 1 2 ( log ⁡ p ( x ) q ( x ) ) 2 = 1 2 ( log ⁡ r ) 2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2 k2=21(logq(x)p(x))2=21(logr)2

这个估计器 ( k 2 k_2 k2 ) 虽然有偏,但方差显著低于 ( k 1 k_1 k1 )。其优点在于:

  1. 始终为正:每个样本都反映了 ( p p p ) 和 ( q q q ) 之间的差异,且结果非负,与KL散度的性质一致。
  2. 低偏倚:( k 2 k_2 k2 ) 的期望是一个f-散度,其形式为:

E q [ k 2 ] = E q [ 1 2 ( log ⁡ r ) 2 ] \mathbb{E}_q[k_2] = \mathbb{E}_q \left[ \frac{1}{2} (\log r)^2 \right] Eq[k2]=Eq[21(logr)2]

这是一个以 ( f ( x ) = 1 2 ( log ⁡ x ) 2 f(x) = \frac{1}{2} (\log x)^2 f(x)=21(logx)2 ) 定义的f-散度。当 ( p p p ) 和 ( q q q ) 接近时,所有可微的f-散度在二阶近似下与KL散度等价,因此 ( k 2 k_2 k2 ) 的偏倚非常小。

实验表明,当 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) )、( p = N ( 0.1 , 1 ) p = \mathcal{N}(0.1, 1) p=N(0.1,1) ) 时(真实KL散度为0.005),( k 2 k_2 k2 ) 的偏倚仅为0.2%,标准差为真实值的1.42倍,远低于 ( k 1 k_1 k1 ) 的20倍。

无偏且低方差的估计器

为了兼顾无偏和低方差,Schulman引入了控制变量(control variate)方法。利用 ( E q [ r − 1 ] = 0 \mathbb{E}_q[r - 1] = 0 Eq[r1]=0 ),可以构造一个新的估计器:

k 3 = − log ⁡ r + λ ( r − 1 ) k_3 = -\log r + \lambda (r - 1) k3=logr+λ(r1)

通过选择适当的 ( λ \lambda λ ),可以降低方差。当 ( λ = 1 \lambda = 1 λ=1 ) 时,估计器变为:

k 3 = ( r − 1 ) − log ⁡ r k_3 = (r - 1) - \log r k3=(r1)logr

由于对数的凹性,( log ⁡ r ≤ r − 1 \log r \leq r - 1 logrr1 ),因此 ( k 3 k_3 k3) 始终为正。这个估计器不仅无偏,而且方差低于 ( k 1 k_1 k1 )。实验表明,当真实KL散度为0.5时,( k 3 k_3 k3 ) 的标准差为真实值的1.7倍,低于 ( k 2 k_2 k2 ) 的1.73倍,且无偏。

推广到其他f-散度

Schulman进一步指出,上述方法可以推广到任意f-散度。对于 ( K L [ p , q ] KL[p, q] KL[p,q] ),对应的f-散度估计器为:

r log ⁡ r − ( r − 1 ) r \log r - (r - 1) rlogr(r1)

这个估计器同样基于凸函数与其切线的距离(Bregman散度),保证了非负性和低方差。

DeepSeek GRPO中的应用

DeepSeek的GRPO算法直接采用了Schulman提出的无偏估计器 ( k 3 k_3 k3 )。具体来说,GRPO使用以下估计器来近似 ( D K L ( π θ ∣ ∣ π r e f ) D_{KL}(\pi_\theta || \pi_{ref}) DKL(πθ∣∣πref) ):

D K L ( π θ ∣ ∣ π r e f ) = π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − log ⁡ π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − 1 D_{KL}(\pi_\theta || \pi_{ref}) = \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} - \log \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} - 1 DKL(πθ∣∣πref)=πθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)logπθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)1

其中,( π θ \pi_\theta πθ ) 表示策略网络,( π r e f \pi_{ref} πref ) 表示参考策略,( o i , t o_{i,t} oi,t ) 表示在时间 ( t t t ) 的观测,( q q q ) 和 ( o i , < t o_{i,<t} oi,<t ) 分别表示上下文和历史观测。令 ( r = π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) r = \frac{\pi_{ref}(o_{i,t} | q, o_{i,<t})}{\pi_\theta(o_{i,t} | q, o_{i,<t})} r=πθ(oi,tq,oi,<t)πref(oi,tq,oi,<t) ),该估计器可重写为:

k 3 = ( r − 1 ) − log ⁡ r k_3 = (r - 1) - \log r k3=(r1)logr

这个估计器的优点在于:

  1. 无偏性:保证估计的期望等于真实的KL散度。
  2. 非负性:由于 ( log ⁡ r ≤ r − 1 \log r \leq r - 1 logrr1 ),估计值始终为正,与KL散度的性质一致。
  3. 低方差:通过控制变量降低了估计的方差,提升了算法的稳定性。

与传统的KL惩罚项(例如直接使用 ( log ⁡ π θ π r e f \log \frac{\pi_\theta}{\pi_{ref}} logπrefπθ))相比,GRPO的估计器在强化学习中表现更稳定,尤其是在策略优化需要精确控制分布差异时。

实验结果

Schulman的实验验证了 ( k 3 k_3 k3 ) 的优越性。例如,当 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) )、( p = N ( 1 , 1 ) p = \mathcal{N}(1, 1) p=N(1,1) )(真实KL散度为0.5)时:

  • ( k 1 k_1 k1 ): 无偏,标准差为真实值的2倍。
  • ( k 2 k_2 k2 ): 偏倚为25%,标准差为真实值的1.73倍。
  • ( k 3 k_3 k3 ): 无偏,标准差为真实值的1.7倍。

这些结果表明,( k 3 k_3 k3 ) 在保持无偏的同时,显著降低了方差。

总结

John Schulman的博客提供了一种优雅的KL散度近似方法,通过引入控制变量构造了无偏且低方差的估计器 ( k 3 = ( r − 1 ) − log ⁡ r k_3 = (r - 1) - \log r k3=(r1)logr )。这一方法不仅在理论上具有普适性(可推广到任意f-散度),还在DeepSeek的GRPO算法中得到了实际应用。GRPO利用这一估计器实现了稳定的策略优化,展示了其在强化学习中的重要价值。对于需要在高维分布上估计KL散度的研究者和工程师来说,Schulman的方法提供了一个兼顾理论严谨性和实践效果的解决方案。

蒙特卡洛估计器

什么是蒙特卡洛估计器?

蒙特卡洛方法(Monte Carlo Method)是一种通过随机采样来估计复杂数学量的方法。它特别适合用来近似那些难以直接计算的积分、期望或求和,尤其是在高维空间或解析解不可得的情况下。蒙特卡洛估计器的核心思想是:通过从某个分布中抽取大量随机样本,计算这些样本上的函数值,然后用样本均值来近似目标量的期望。

在KL散度的背景下,KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ) 定义为:

K L [ q , p ] = ∑ x q ( x ) log ⁡ q ( x ) p ( x ) = E x ∼ q [ log ⁡ q ( x ) p ( x ) ] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] KL[q,p]=xq(x)logp(x)q(x)=Exq[logp(x)q(x)]

这里的 ( E x ∼ q \mathbb{E}_{x \sim q} Exq ) 表示从分布 ( q q q ) 中采样的期望。如果我们无法直接计算整个求和(例如,分布 ( q q q ) 和 ( p p p ) 是高维的,或者没有闭合表达式),就可以用蒙特卡洛方法来估计这个期望。具体来说:

  1. 从分布 ( q q q ) 中抽取 ( N N N ) 个独立样本 ( x 1 , x 2 , … , x N x_1, x_2, \dots, x_N x1,x2,,xN )。
  2. 对每个样本 ( x i x_i xi ),计算函数 ( f ( x i ) = log ⁡ q ( x i ) p ( x i ) f(x_i) = \log \frac{q(x_i)}{p(x_i)} f(xi)=logp(xi)q(xi) )。
  3. 用样本均值来估计期望:

K L [ q , p ] ≈ 1 N ∑ i = 1 N log ⁡ q ( x i ) p ( x i ) KL[q, p] \approx \frac{1}{N} \sum_{i=1}^N \log \frac{q(x_i)}{p(x_i)} KL[q,p]N1i=1Nlogp(xi)q(xi)

这个均值就是蒙特卡洛估计器,称为 ( k 1 k_1 k1 ) 在你的问题中:

k 1 = log ⁡ q ( x ) p ( x ) = − log ⁡ r , 其中 r = p ( x ) q ( x ) k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k1=logp(x)q(x)=logr,其中r=q(x)p(x)

蒙特卡洛估计器的作用:它通过有限样本近似真实的KL散度值。KL散度本质上衡量两个分布 ( q q q ) 和 ( p p p ) 之间的“距离”(尽管严格来说它不是对称的距离),而蒙特卡洛估计器试图通过采样来捕捉这种差异。


蒙特卡洛估计器是否反映分布的距离?

是的,蒙特卡洛估计器 ( k 1 k_1 k1 ) 的期望等于真实的KL散度,因此它确实反映了两个分布之间的“距离”。具体来说:

  • 无偏性:( E x ∼ q [ k 1 ] = E x ∼ q [ log ⁡ q ( x ) p ( x ) ] = K L [ q , p ] \mathbb{E}_{x \sim q}[k_1] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = KL[q, p] Exq[k1]=Exq[logp(x)q(x)]=KL[q,p] )。这意味着如果采样次数 ( N → ∞ N \to \infty N ),估计器的平均值会收敛到真实的KL散度。
  • 分布差异的体现:( k 1 = log ⁡ q ( x ) p ( x ) k_1 = \log \frac{q(x)}{p(x)} k1=logp(x)q(x) ) 直接衡量了在样本点 ( x x x ) 上,分布 ( q q q ) 和 ( p p p ) 的概率密度(或概率)的对数差异。如果 ( q ( x ) ≈ p ( x ) q(x) \approx p(x) q(x)p(x) ),则 ( log ⁡ q ( x ) p ( x ) ≈ 0 \log \frac{q(x)}{p(x)} \approx 0 logp(x)q(x)0 ),表明分布很接近;如果 ( q ( x ) q(x) q(x) ) 和 ( p ( x ) p(x) p(x) ) 差异很大,则 ( k 1 k_1 k1 ) 的绝对值会较大,反映分布的差异。

然而,( k 1 k_1 k1 ) 的局限性在于其高方差。由于 ( log ⁡ r \log r logr ) (其中 ( r = p ( x ) q ( x ) r = \frac{p(x)}{q(x)} r=q(x)p(x) ))可能为正或负,样本值的波动会导致估计值的方差较大,尤其当 ( p p p ) 和 ( q q q ) 差异较大时。这种高方差使得 ( k 1 k_1 k1 ) 在实际应用中不够稳定。


通过例子说明

让我们通过一个简单的例子来说明蒙特卡洛估计器如何工作,以及它如何反映分布的距离。

场景

假设有两个一维正态分布:

  • ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) ) (均值0,标准差1)
  • ( p = N ( 0.5 , 1 ) p = \mathcal{N}(0.5, 1) p=N(0.5,1) ) (均值0.5,标准差1)

真实的KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ) 对于正态分布有解析解:

K L [ q , p ] = 1 2 ( ( μ q − μ p ) 2 σ p 2 ) = 1 2 ( 0 − 0.5 ) 2 1 2 = 0.25 2 = 0.125 KL[q, p] = \frac{1}{2} \left( \frac{(\mu_q - \mu_p)^2}{\sigma_p^2} \right) = \frac{1}{2} \frac{(0 - 0.5)^2}{1^2} = \frac{0.25}{2} = 0.125 KL[q,p]=21(σp2(μqμp)2)=2112(00.5)2=20.25=0.125

我们用蒙特卡洛方法来估计这个值。

蒙特卡洛估计步骤
  1. 采样:从 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) ) 中抽取 ( N = 1000 N = 1000 N=1000 ) 个样本 ( x 1 , x 2 , … , x 1000 x_1, x_2, \dots, x_{1000} x1,x2,,x1000 )。假设样本是随机生成的,例如 ( x 1 = − 0.3 , x 2 = 1.2 , x 3 = − 1.5 , … x_1 = -0.3, x_2 = 1.2, x_3 = -1.5, \dots x1=0.3,x2=1.2,x3=1.5, )。
  2. 计算 ( k 1 k_1 k1 ):对每个样本 ( x i x_i xi ),计算:

k 1 ( x i ) = log ⁡ q ( x i ) p ( x i ) k_1(x_i) = \log \frac{q(x_i)}{p(x_i)} k1(xi)=logp(xi)q(xi)

正态分布的概率密度函数为:

q ( x ) = 1 2 π exp ⁡ ( − x 2 2 ) , p ( x ) = 1 2 π exp ⁡ ( − ( x − 0.5 ) 2 2 ) q(x) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{x^2}{2}\right), \quad p(x) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x - 0.5)^2}{2}\right) q(x)=2π 1exp(2x2),p(x)=2π 1exp(2(x0.5)2)

因此:

log ⁡ q ( x i ) p ( x i ) = log ⁡ q ( x i ) − log ⁡ p ( x i ) = ( − x i 2 2 ) − ( − ( x i − 0.5 ) 2 2 ) \log \frac{q(x_i)}{p(x_i)} = \log q(x_i) - \log p(x_i) = \left( -\frac{x_i^2}{2} \right) - \left( -\frac{(x_i - 0.5)^2}{2} \right) logp(xi)q(xi)=logq(xi)logp(xi)=(2xi2)(2(xi0.5)2)

= − x i 2 2 + ( x i − 0.5 ) 2 2 = ( x i − 0.5 ) 2 − x i 2 2 = x i 2 − x i + 0.25 − x i 2 2 = − x i + 0.25 2 = -\frac{x_i^2}{2} + \frac{(x_i - 0.5)^2}{2} = \frac{(x_i - 0.5)^2 - x_i^2}{2} = \frac{x_i^2 - x_i + 0.25 - x_i^2}{2} = \frac{-x_i + 0.25}{2} =2xi2+2(xi0.5)2=2(xi0.5)2xi2=2xi2xi+0.25xi2=2xi+0.25

k 1 ( x i ) = 0.25 − x i 2 k_1(x_i) = \frac{0.25 - x_i}{2} k1(xi)=20.25xi

  1. 估计KL散度:计算所有样本的均值:

K L ^ [ q , p ] = 1 N ∑ i = 1 N k 1 ( x i ) = 1 N ∑ i = 1 N 0.25 − x i 2 \hat{KL}[q, p] = \frac{1}{N} \sum_{i=1}^N k_1(x_i) = \frac{1}{N} \sum_{i=1}^N \frac{0.25 - x_i}{2} KL^[q,p]=N1i=1Nk1(xi)=N1i=1N20.25xi

假设我们采样得到的 ( x i x_i xi ) 的均值为 ( x ˉ ≈ 0 \bar{x} \approx 0 xˉ0 )(因为 ( x i ∼ N ( 0 , 1 ) x_i \sim \mathcal{N}(0, 1) xiN(0,1) )),则:

K L ^ [ q , p ] ≈ 1 1000 ∑ i = 1 1000 0.25 − x i 2 ≈ 0.25 − 0 2 = 0.125 \hat{KL}[q, p] \approx \frac{1}{1000} \sum_{i=1}^{1000} \frac{0.25 - x_i}{2} \approx \frac{0.25 - 0}{2} = 0.125 KL^[q,p]10001i=1100020.25xi20.250=0.125

这个估计值接近真实的 ( K L [ q , p ] = 0.125 KL[q, p] = 0.125 KL[q,p]=0.125 ),说明 ( k 1 k_1 k1 ) 是无偏的。

高方差的表现

尽管 ( k 1 k_1 k1 ) 无偏,但它的值在样本间波动较大。例如:

  • 如果 ( x i = 0 x_i = 0 xi=0 ),则 ( k 1 = 0.25 − 0 2 = 0.125 k_1 = \frac{0.25 - 0}{2} = 0.125 k1=20.250=0.125 )。
  • 如果 ( x i = 2 x_i = 2 xi=2 ),则 ( k 1 = 0.25 − 2 2 = − 0.875 k_1 = \frac{0.25 - 2}{2} = -0.875 k1=20.252=0.875 )。
  • 如果 ( x i = − 2 x_i = -2 xi=2 ),则 ( k 1 = 0.25 − ( − 2 ) 2 = 1.125 k_1 = \frac{0.25 - (-2)}{2} = 1.125 k1=20.25(2)=1.125 ).

这些值的正负波动导致样本均值的方差较高。如果采样次数 ( N N N ) 不够多,估计值可能偏离0.125较远。

反映分布距离

从 ( k 1 ( x i ) = 0.25 − x i 2 k_1(x_i) = \frac{0.25 - x_i}{2} k1(xi)=20.25xi ) 可以看出,( k 1 k_1 k1 ) 的值直接与样本 ( x i x_i xi) 有关,而样本是由 ( q q q ) 生成的。KL散度的值0.125反映了 ( q q q ) 和 ( p p p ) 之间的均值差异(0 vs 0.5)。如果我们改变 ( p p p ) 的均值(例如 ( p = N ( 1 , 1 ) p = \mathcal{N}(1, 1) p=N(1,1) )),KL散度会变大(( K L [ q , p ] = 0.5 KL[q, p] = 0.5 KL[q,p]=0.5 )),相应的 ( k 1 k_1 k1 ) 的均值也会变大,反映出分布间更大的差异。


总结

  • 蒙特卡洛估计器是通过从分布 ( q q q ) 中采样,用样本上的函数值均值来近似KL散度的期望。
  • 反映分布距离:( k 1 = log ⁡ q ( x ) p ( x ) k_1 = \log \frac{q(x)}{p(x)} k1=logp(x)q(x) ) 的期望等于KL散度,因此它确实捕捉了 ( q q q ) 和 ( p p p ) 之间的差异。
  • 局限性:高方差使得 ( k 1 k_1 k1 ) 在有限样本下不够稳定,如例子中 ( k 1 k_1 k1 ) 的值在正负间波动。
  • 改进方向:Schulman的博客中提出了 ( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 ) 等估计器,通过降低方差(例如保证非负性或引入控制变量)来改进 ( k 1 k_1 k1 ),在DeepSeek GRPO等应用中表现出色。

log相减就是k1估计器

在许多大语言模型(LLM)强化学习(RL)相关的论文中,尤其是那些涉及策略优化(如PPO、TRPO或RLHF)的论文,使用的KL散度估计通常是基于 ( k 1 k_1 k1 ) 估计器,即直接计算对数概率的差值:

k 1 = log ⁡ q ( x ) p ( x ) = log ⁡ q ( x ) − log ⁡ p ( x ) k_1 = \log \frac{q(x)}{p(x)} = \log q(x) - \log p(x) k1=logp(x)q(x)=logq(x)logp(x)

或者在策略优化的上下文中,写作:

k 1 = log ⁡ π ref ( a ∣ s ) π θ ( a ∣ s ) = log ⁡ π ref ( a ∣ s ) − log ⁡ π θ ( a ∣ s ) k_1 = \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\theta}(a|s)} = \log \pi_{\text{ref}}(a|s) - \log \pi_{\theta}(a|s) k1=logπθ(as)πref(as)=logπref(as)logπθ(as)

其中,( π θ \pi_{\theta} πθ ) 是当前策略,( π ref \pi_{\text{ref}} πref ) 是参考策略(如初始策略或SFT模型),( a a a ) 是动作(在LLM中通常是生成的token),( s s s ) 是状态(上下文或提示)。这种形式的KL散度估计正是Schulman博客中描述的 ( k 1 k_1 k1 ) 估计器,基于蒙特卡洛方法,通过采样计算对数概率差的均值来近似 ( K L [ π θ ∣ ∣ π ref ] KL[\pi_{\theta} || \pi_{\text{ref}}] KL[πθ∣∣πref] ) 或 ( K L [ π ref ∣ ∣ π θ ] KL[\pi_{\text{ref}} || \pi_{\theta}] KL[πref∣∣πθ] )。

为什么常用 ( k 1 k_1 k1 ) 估计器?

  1. 简单直接:( k 1 k_1 k1 ) 是KL散度的直接蒙特卡洛估计,形式上就是对数概率的差值,易于实现。LLM通常会输出每个token的对数概率(logits经过softmax后取log),因此计算 ( log ⁡ π θ ( a ∣ s ) \log \pi_{\theta}(a|s) logπθ(as) ) 和 ( log ⁡ π ref ( a ∣ s ) \log \pi_{\text{ref}}(a|s) logπref(as) ) 是现成的。

  2. 无偏性:( k 1 k_1 k1 ) 是无偏估计器,其期望等于真实的KL散度:

E a ∼ π θ [ log ⁡ π θ ( a ∣ s ) π ref ( a ∣ s ) ] = K L [ π θ ∣ ∣ π ref ] \mathbb{E}_{a \sim \pi_{\theta}} \left[ \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)} \right] = KL[\pi_{\theta} || \pi_{\text{ref}}] Eaπθ[logπref(as)πθ(as)]=KL[πθ∣∣πref]

这保证了在采样足够多的情况下,估计值会收敛到真实KL散度。

  1. 常见应用场景:在强化学习中,KL散度常用于以下场景:
    • 正则化项:如PPO中,KL散度作为惩罚项,限制新策略 ( π θ \pi_{\theta} πθ ) 偏离参考策略 ( π ref \pi_{\text{ref}} πref ) 过远。
    • 诊断工具:监控策略更新的稳定性。
    • 奖励函数:在RLHF中,KL散度常用于奖励函数(如 ( R = R human − β K L [ π θ ∣ ∣ π ref ] R = R_{\text{human}} - \beta KL[\pi_{\theta} || \pi_{\text{ref}}] R=RhumanβKL[πθ∣∣πref] )),以平衡偏好和策略稳定性。

这些场景中,( k 1 k_1 k1 ) 的形式简单,且与策略梯度计算兼容,因此被广泛采用。

( k 1 k_1 k1 ) 估计器的局限性在LLM RL中的体现

尽管 ( k 1 k_1 k1 ) 常用,但正如Schulman的博客所指出的,它有高方差的缺点。这在LLM RL中可能表现为:

  • 方差导致的不稳定:由于 ( k 1 = log ⁡ π θ ( a ∣ s ) π ref ( a ∣ s ) k_1 = \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)} k1=logπref(as)πθ(as) ) 可能为正或负(当 ( π θ ( a ∣ s ) > π ref ( a ∣ s ) \pi_{\theta}(a|s) > \pi_{\text{ref}}(a|s) πθ(as)>πref(as) ) 时为正,反之为负),样本间的波动会导致KL散度估计不稳定,尤其在序列生成中,token序列较长时方差会累积。
  • 采样效率:LLM生成序列的采样成本高(需要运行整个模型),如果 ( k 1 k_1 k1 ) 的高方差要求更多样本才能得到可靠估计,会增加计算开销。
  • 极端概率的影响:当 ( π θ ( a ∣ s ) \pi_{\theta}(a|s) πθ(as) ) 或 ( π ref ( a ∣ s ) \pi_{\text{ref}}(a|s) πref(as) ) 接近0时,( log ⁡ π θ ( a ∣ s ) π ref ( a ∣ s ) \log \frac{\pi_{\theta}(a|s)}{\pi_{\text{ref}}(a|s)} logπref(as)πθ(as) ) 可能变得很大或很小,导致估计值的异常波动。

总结

是的,许多LLM强化学习论文中使用的KL散度估计确实是 ( k 1 k_1 k1 ) 估计器(( log ⁡ π θ π ref \log \frac{\pi_{\theta}}{\pi_{\text{ref}}} logπrefπθ )),因为它简单、无偏且易于实现。然而,其高方差可能导致不稳定,尤其在长序列生成或概率差异较大的情况下。一些前沿工作(如DeepSeek GRPO)采用了Schulman提出的 ( k 3 k_3 k3 ) 估计器,以获得更低的方差和更稳定的优化。如果你阅读的论文明确提到KL散度惩罚,通常默认是 ( k 1 k_1 k1 ),但需要检查具体实现或损失函数是否使用了改进的估计器(如 ( k 3 k_3 k3 ) 或其他变体)。

f-散度和k2

什么是f-散度?

f-散度(f-divergence)是一类用于衡量两个概率分布 ( p p p ) 和 ( q q q ) 之间差异的广义距离度量。它通过一个凸函数 ( f f f ) 定义,形式如下:

D f ( p , q ) = E x ∼ q [ f ( p ( x ) q ( x ) ) ] D_f(p, q) = \mathbb{E}_{x \sim q} \left[ f\left( \frac{p(x)}{q(x)} \right) \right] Df(p,q)=Exq[f(q(x)p(x))]

其中:

  • ( p ( x ) p(x) p(x) ) 和 ( q ( x ) q(x) q(x) ) 是分布 ( p p p ) 和 ( q q q ) 在点 ( x x x ) 上的概率密度(或概率质量,视离散/连续分布而定)。
  • ( r = p ( x ) q ( x ) r = \frac{p(x)}{q(x)} r=q(x)p(x) ) 是概率比率。
  • ( f : R + → R f: \mathbb{R}^+ \to \mathbb{R} f:R+R ) 是一个凸函数,且通常满足 ( f ( 1 ) = 0 f(1) = 0 f(1)=0 ),以确保当 ( p = q p = q p=q ) 时散度为零。

f-散度的特点是它包含了许多常见的分布距离度量,通过选择不同的 ( f f f ) 函数,可以得到不同的散度。例如:

  • KL散度:当 ( f ( u ) = u log ⁡ u f(u) = u \log u f(u)=ulogu ),对应 ( D f ( p , q ) = K L ( p ∣ ∣ q ) = E q [ p ( x ) q ( x ) log ⁡ p ( x ) q ( x ) ] D_f(p, q) = KL(p || q) = \mathbb{E}_{q} \left[ \frac{p(x)}{q(x)} \log \frac{p(x)}{q(x)} \right] Df(p,q)=KL(p∣∣q)=Eq[q(x)p(x)logq(x)p(x)] )。
  • 逆KL散度:当 ( f ( u ) = − log ⁡ u f(u) = -\log u f(u)=logu ),对应 ( D f ( p , q ) = K L ( q ∣ ∣ p ) D_f(p, q) = KL(q || p) Df(p,q)=KL(q∣∣p) )。
  • Jensen-Shannon散度:通过选择适当的 ( f f f ) 函数,可以推导JS散度。
  • 总变差距离:当 ( f ( u ) = ∣ u − 1 ∣ f(u) = |u - 1| f(u)=u1∣ ),对应总变差。

f-散度的优点是它提供了一个统一的框架来研究不同分布距离的性质,且许多f-散度在 ( p ≈ q p \approx q pq ) 时具有相似的二阶行为(如Schulman博客中提到的与KL散度的关系)。

Schulman博客中的f-散度背景

在John Schulman的博客《Approximating KL Divergence》中,他讨论了如何通过蒙特卡洛方法近似KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ),并提出了一种低方差但有偏的估计器 ( k 2 k_2 k2 )。他指出,( k 2 k_2 k2 ) 的期望是一个f-散度,并且在 ( p ≈ q p \approx q pq ) 时,所有可微的f-散度在二阶近似下与KL散度等价。这种等价性是 ( k 2 k_2 k2 ) 低偏倚的理论基础。

具体来说,Schulman定义了 ( k 2 k_2 k2 ) 估计器为:

k 2 = 1 2 ( log ⁡ p ( x ) q ( x ) ) 2 = 1 2 ( log ⁡ r ) 2 , 其中 r = p ( x ) q ( x ) k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k2=21(logq(x)p(x))2=21(logr)2,其中r=q(x)p(x)

他进一步说明,( k 2 k_2 k2 ) 的期望是一个f-散度:

E x ∼ q [ k 2 ] = E x ∼ q [ 1 2 ( log ⁡ r ) 2 ] \mathbb{E}_{x \sim q}[k_2] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} (\log r)^2 \right] Exq[k2]=Exq[21(logr)2]

这对应于f-散度中的凸函数:

f ( u ) = 1 2 ( log ⁡ u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2

为了理解 ( k 2 k_2 k2 ) 的来源和为何它是一个合理的KL散度估计器,我们需要结合f-散度的性质和Schulman的推导思路进行分析。


推导 ( k 2 k_2 k2 ) 的来源

以下是结合Schulman博客内容,推导 ( k 2 k_2 k2 ) 如何作为KL散度估计器的过程:

1. KL散度的蒙特卡洛估计

KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ) 的定义为:

K L [ q , p ] = E x ∼ q [ log ⁡ q ( x ) p ( x ) ] = E x ∼ q [ − log ⁡ r ] , r = p ( x ) q ( x ) KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \mathbb{E}_{x \sim q} [-\log r], \quad r = \frac{p(x)}{q(x)} KL[q,p]=Exq[logp(x)q(x)]=Exq[logr],r=q(x)p(x)

最直接的蒙特卡洛估计器是 ( k 1 k_1 k1 ):

k 1 = log ⁡ q ( x ) p ( x ) = − log ⁡ r k_1 = \log \frac{q(x)}{p(x)} = -\log r k1=logp(x)q(x)=logr

( k 1 k_1 k1 ) 是无偏的,但由于 ( log ⁡ r \log r logr ) 可正可负(( r > 1 r > 1 r>1 ) 时 ( log ⁡ r > 0 \log r > 0 logr>0 ),( r < 1 r < 1 r<1 ) 时 ( log ⁡ r < 0 \log r < 0 logr<0 )),其方差较高,而KL散度本身始终非负。这种高方差促使Schulman寻找更稳定的估计器。

2. 构造低方差估计器的动机

Schulman的目标是构造一个估计器,既能反映 ( p p p ) 和 ( q q q ) 之间的差异,又具有较低的方差。一个直观的思路是:既然KL散度衡量分布差异,可以尝试用某种“距离”函数来替代 ( log ⁡ q ( x ) p ( x ) \log \frac{q(x)}{p(x)} logp(x)q(x) ),使得估计值始终非负(与KL散度的性质一致),并且方差较小。

考虑 ( log ⁡ r = log ⁡ p ( x ) q ( x ) \log r = \log \frac{p(x)}{q(x)} logr=logq(x)p(x)),它直接衡量了概率比率的对数差异。如果我们取其平方:

( log ⁡ r ) 2 = ( log ⁡ p ( x ) q ( x ) ) 2 (\log r)^2 = \left( \log \frac{p(x)}{q(x)} \right)^2 (logr)2=(logq(x)p(x))2

这个量有以下优点:

  • 非负性:( ( log ⁡ r ) 2 ≥ 0 (\log r)^2 \geq 0 (logr)20 ),与KL散度的非负性质一致。
  • 对称性:它反映了 ( log ⁡ r \log r logr ) 的大小,无论 ( log ⁡ r \log r logr ) 是正还是负,都表示 ( p p p ) 和 ( q q q ) 的差异。
  • 平滑性:平方函数会放大较大的 ( log ⁡ r \log r logr ) 值,强调分布差异较大的样本。

为了使量纲合理(KL散度的量纲类似于对数),Schulman引入一个系数 ( 1 2 \frac{1}{2} 21 ),定义:

k 2 = 1 2 ( log ⁡ r ) 2 = 1 2 ( log ⁡ p ( x ) q ( x ) ) 2 k_2 = \frac{1}{2} (\log r)^2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 k2=21(logr)2=21(logq(x)p(x))2

3. ( k 2 k_2 k2 ) 的期望是f-散度

计算 ( k 2 k_2 k2 ) 的期望:

E x ∼ q [ k 2 ] = E x ∼ q [ 1 2 ( log ⁡ p ( x ) q ( x ) ) 2 ] = E x ∼ q [ 1 2 ( log ⁡ r ) 2 ] \mathbb{E}_{x \sim q}[k_2] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 \right] = \mathbb{E}_{x \sim q} \left[ \frac{1}{2} (\log r)^2 \right] Exq[k2]=Exq[21(logq(x)p(x))2]=Exq[21(logr)2]

这正是f-散度的形式,其中:

f ( u ) = 1 2 ( log ⁡ u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2

我们需要验证 ( f ( u ) f(u) f(u) ) 是否满足f-散度的要求:

  • 凸性:计算 ( f ( u ) f(u) f(u) ) 的二阶导数:

f ( u ) = 1 2 ( log ⁡ u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2

f ′ ( u ) = 1 2 ⋅ 2 log ⁡ u ⋅ 1 u = log ⁡ u u f'(u) = \frac{1}{2} \cdot 2 \log u \cdot \frac{1}{u} = \frac{\log u}{u} f(u)=212loguu1=ulogu

f ′ ′ ( u ) = 1 u ⋅ 1 u + log ⁡ u u 2 = 1 u 2 + log ⁡ u u 2 = 1 + log ⁡ u u 2 f''(u) = \frac{1}{u} \cdot \frac{1}{u} + \frac{\log u}{u^2} = \frac{1}{u^2} + \frac{\log u}{u^2} = \frac{1 + \log u}{u^2} f′′(u)=u1u1+u2logu=u21+u2logu=u21+logu

由于 ( u > 0 u > 0 u>0 ),( f ′ ′ ( u ) f''(u) f′′(u) ) 不总是正(例如当 ( u < e − 1 ≈ 0.367 u < e^{-1} \approx 0.367 u<e10.367 ) 时,( 1 + log ⁡ u < 0 1 + \log u < 0 1+logu<0 )),但在实际应用中,( r = p ( x ) q ( x ) r = \frac{p(x)}{q(x)} r=q(x)p(x) ) 通常在 ( ( 0 , ∞ ) (0, \infty) (0,) ) 范围内变化,且在 ( p ≈ q p \approx q pq ) 时 ( r ≈ 1 r \approx 1 r1 ),此时 ( f ′ ′ ( 1 ) = 1 > 0 f''(1) = 1 > 0 f′′(1)=1>0 )。因此,在 ( r ≈ 1 r \approx 1 r1 ) 的局部区域,( f ( u ) f(u) f(u) ) 是凸的,满足f-散度的要求。

  • ( f ( 1 ) = 0 f(1) = 0 f(1)=0 )

f ( 1 ) = 1 2 ( log ⁡ 1 ) 2 = 0 f(1) = \frac{1}{2} (\log 1)^2 = 0 f(1)=21(log1)2=0

这保证当 ( p = q p = q p=q )(即 ( r = 1 r = 1 r=1 ))时,散度为零。

因此,( E x ∼ q [ k 2 ] \mathbb{E}_{x \sim q}[k_2] Exq[k2] ) 是一个有效的f-散度。

4. 为何 ( k 2 k_2 k2 ) 是KL散度的合理近似?

Schulman指出,所有可微的f-散度在 ( p ≈ q p \approx q pq ) 时,在二阶近似下与KL散度等价。让我们推导这一关键结论。

假设 ( p θ p_{\theta} pθ ) 是一个参数化的分布,( p θ 0 = p 0 p_{\theta_0} = p_0 pθ0=p0 ),我们考虑 ( D f ( p 0 , p θ ) D_f(p_0, p_{\theta}) Df(p0,pθ) ) 在 ( θ ≈ θ 0 \theta \approx \theta_0 θθ0 ) 时的泰勒展开。Schulman给出了近似:

D f ( p 0 , p θ ) = f ′ ′ ( 1 ) 2 θ T F θ + O ( θ 3 ) D_f(p_0, p_{\theta}) = \frac{f''(1)}{2} \theta^T F \theta + O(\theta^3) Df(p0,pθ)=2f′′(1)θT+O(θ3)

其中 ( F F F ) 是 ( p θ p_{\theta} pθ ) 在 ( θ = θ 0 \theta = \theta_0 θ=θ0 ) 处的Fisher信息矩阵。关键在于 ( f ′ ′ ( 1 ) f''(1) f′′(1) )。我们计算 ( f ( u ) = 1 2 ( log ⁡ u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 ) 的二阶导数:

f ′ ′ ( u ) = 1 + log ⁡ u u 2 , f ′ ′ ( 1 ) = 1 + log ⁡ 1 1 2 = 1 f''(u) = \frac{1 + \log u}{u^2}, \quad f''(1) = \frac{1 + \log 1}{1^2} = 1 f′′(u)=u21+logu,f′′(1)=121+log1=1

对于KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ),对应的f-散度函数是 ( f ( u ) = − log ⁡ u f(u) = -\log u f(u)=logu ):

f ′ ( u ) = − 1 u , f ′ ′ ( u ) = 1 u 2 , f ′ ′ ( 1 ) = 1 f'(u) = -\frac{1}{u}, \quad f''(u) = \frac{1}{u^2}, \quad f''(1) = 1 f(u)=u1,f′′(u)=u21,f′′(1)=1

可以看到,( f ′ ′ ( 1 ) = 1 f''(1) = 1 f′′(1)=1 ) 对于 ( k 2 k_2 k2 ) 的 ( f ( u ) = 1 2 ( log ⁡ u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 ) 和KL散度的 ( f ( u ) = − log ⁡ u f(u) = -\log u f(u)=logu ) 都是相同的。因此,在 ( p ≈ q p \approx q pq ) 时,( E x ∼ q [ k 2 ] \mathbb{E}_{x \sim q}[k_2] Exq[k2] ) 和 ( K L [ q , p ] KL[q, p] KL[q,p] ) 的二阶行为一致:

E x ∼ q [ k 2 ] ≈ 1 2 θ T F θ ≈ K L [ q , p ] \mathbb{E}_{x \sim q}[k_2] \approx \frac{1}{2} \theta^T F \theta \approx KL[q, p] Exq[k2]21θTKL[q,p]

这解释了为何 ( k 2 k_2 k2 ) 虽然有偏(( E [ k 2 ] ≠ K L [ q , p ] \mathbb{E}[k_2] \neq KL[q, p] E[k2]=KL[q,p] )),但偏倚很小,尤其当 ( p p p ) 和 ( q q q ) 接近时。

5. ( k 2 k_2 k2 ) 的直观解释

( k 2 = 1 2 ( log ⁡ p ( x ) q ( x ) ) 2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 k2=21(logq(x)p(x))2 ) 的设计灵感可以理解为:

  • 惩罚对数差异:( log ⁡ p ( x ) q ( x ) \log \frac{p(x)}{q(x)} logq(x)p(x) ) 衡量了 ( p p p ) 和 ( q q q ) 在对数尺度上的相对差异,平方后强调了较大的差异。
  • 低方差:由于 ( ( log ⁡ r ) 2 ≥ 0 (\log r)^2 \geq 0 (logr)20 ),每个样本的贡献都是非负的,避免了 ( k 1 = − log ⁡ r k_1 = -\log r k1=logr ) 中正负抵消导致的高方差。
  • f-散度的自然形式:平方形式 ( f ( u ) = 1 2 ( log ⁡ u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 ) 是一个自然的凸函数选择,且在 ( r ≈ 1 r \approx 1 r1 ) 时与KL散度行为相似。

Schulman通过实验验证了 ( k 2 k_2 k2 ) 的低方差和低偏倚。例如,当 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) )、( p = N ( 0.1 , 1 ) p = \mathcal{N}(0.1, 1) p=N(0.1,1) )(真实KL散度为0.005)时,( k 2 k_2 k2 ) 的偏倚仅为0.2%,标准差远低于 ( k 1 k_1 k1 )。


结合Schulman博客的总结

Schulman的博客通过引入 ( k 2 = 1 2 ( log ⁡ p ( x ) q ( x ) ) 2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 k2=21(logq(x)p(x))2 ) 提供了一种低方差的KL散度估计器,其期望是一个f-散度,对应的凸函数为 ( f ( u ) = 1 2 ( log ⁡ u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 )。推导过程如下:

  1. 从KL散度的蒙特卡洛估计出发,识别 ( k 1 = log ⁡ q ( x ) p ( x ) k_1 = \log \frac{q(x)}{p(x)} k1=logp(x)q(x) ) 的高方差问题。
  2. 构造 ( k 2 = 1 2 ( log ⁡ r ) 2 k_2 = \frac{1}{2} (\log r)^2 k2=21(logr)2 ),使其非负且方差较低。
  3. 证明 ( E [ k 2 ] \mathbb{E}[k_2] E[k2] ) 是一个f-散度,且在 ( p ≈ q p \approx q pq ) 时与KL散度在二阶近似下等价(因为 ( f ′ ′ ( 1 ) = 1 f''(1) = 1 f′′(1)=1 ))。
  4. 通过实验验证 ( k 2 k_2 k2 ) 的低偏倚和低方差。

f-散度的框架为 ( k 2 k_2 k2 ) 的合理性提供了理论支持,因为它表明不同凸函数 ( f f f ) 定义的散度在 ( p ≈ q p \approx q pq ) 时具有相似的行为。这也为后续 ( k 3 k_3 k3 ) 等无偏估计器的设计提供了灵感(如通过控制变量进一步优化)。


k3的思考

介绍 ( k 3 k_3 k3 ) 的来源及推导

在John Schulman的博客《Approximating KL Divergence》中,他提出了三种KL散度估计器:( k 1 k_1 k1 )、( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 )。我们已经讨论了 ( k 1 = log ⁡ q ( x ) p ( x ) = − log ⁡ r k_1 = \log \frac{q(x)}{p(x)} = -\log r k1=logp(x)q(x)=logr )(无偏但高方差)和 ( k 2 = 1 2 ( log ⁡ p ( x ) q ( x ) ) 2 = 1 2 ( log ⁡ r ) 2 k_2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 = \frac{1}{2} (\log r)^2 k2=21(logq(x)p(x))2=21(logr)2 )(低方差但有偏)。( k 3 k_3 k3 ) 是Schulman提出的第三个估计器,目标是兼顾 无偏性低方差,定义为:

k 3 = ( r − 1 ) − log ⁡ r , 其中 r = p ( x ) q ( x ) k_3 = \left( r - 1 \right) - \log r, \quad \text{其中} \quad r = \frac{p(x)}{q(x)} k3=(r1)logr,其中r=q(x)p(x)

本节将结合Schulman的博客内容,详细推导 ( k 3 k_3 k3 ) 的来源,分析其与 ( k 2 k_2 k2 ) 推导的联系,并解释其理论基础和优势。


推导 ( k 3 k_3 k3 ) 的来源

1. 从 ( k 1 k_1 k1 ) 的高方差出发

KL散度 ( K L [ q , p ] KL[q, p] KL[q,p] ) 的定义为:

K L [ q , p ] = E x ∼ q [ log ⁡ q ( x ) p ( x ) ] = E x ∼ q [ − log ⁡ r ] , r = p ( x ) q ( x ) KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \mathbb{E}_{x \sim q} [-\log r], \quad r = \frac{p(x)}{q(x)} KL[q,p]=Exq[logp(x)q(x)]=Exq[logr],r=q(x)p(x)

标准蒙特卡洛估计器 ( k 1 = − log ⁡ r k_1 = -\log r k1=logr ) 是无偏的,但由于 ( log ⁡ r \log r logr ) 可正可负(( r > 1 r > 1 r>1 ) 时 ( log ⁡ r > 0 \log r > 0 logr>0 ),( r < 1 r < 1 r<1 ) 时 ( log ⁡ r < 0 \log r < 0 logr<0 )),其方差较高。Schulman的目标是构造一个估计器,既保留 ( k 1 k_1 k1 ) 的无偏性,又降低方差。

2. 控制变量方法的引入

为了降低 ( k 1 k_1 k1 ) 的方差,Schulman引入了 控制变量(control variate) 方法。控制变量方法是一种常用的方差缩减技术,核心思想是:在估计器中加入一个期望为零的项,这个项与原估计器负相关,从而抵消部分波动。

在KL散度的估计中,我们需要找到一个量 ( g ( x ) g(x) g(x) ),满足:

E x ∼ q [ g ( x ) ] = 0 \mathbb{E}_{x \sim q} [g(x)] = 0 Exq[g(x)]=0

然后构造新的估计器:

k new = k 1 + λ g ( x ) k_{\text{new}} = k_1 + \lambda g(x) knew=k1+λg(x)

只要 ( E x ∼ q [ g ( x ) ] = 0 \mathbb{E}_{x \sim q} [g(x)] = 0 Exq[g(x)]=0 ),新估计器仍然无偏:

E x ∼ q [ k new ] = E x ∼ q [ k 1 ] + λ E x ∼ q [ g ( x ) ] = K L [ q , p ] + 0 = K L [ q , p ] \mathbb{E}_{x \sim q} [k_{\text{new}}] = \mathbb{E}_{x \sim q} [k_1] + \lambda \mathbb{E}_{x \sim q} [g(x)] = KL[q, p] + 0 = KL[q, p] Exq[knew]=Exq[k1]+λExq[g(x)]=KL[q,p]+0=KL[q,p]

Schulman注意到,一个自然的控制变量是:

g ( x ) = r − 1 = p ( x ) q ( x ) − 1 g(x) = r - 1 = \frac{p(x)}{q(x)} - 1 g(x)=r1=q(x)p(x)1

其期望为:

E x ∼ q [ r − 1 ] = E x ∼ q [ p ( x ) q ( x ) ] − 1 = ∫ q ( x ) ⋅ p ( x ) q ( x )   d x − 1 = ∫ p ( x )   d x − 1 = 1 − 1 = 0 \mathbb{E}_{x \sim q} [r - 1] = \mathbb{E}_{x \sim q} \left[ \frac{p(x)}{q(x)} \right] - 1 = \int q(x) \cdot \frac{p(x)}{q(x)} \, dx - 1 = \int p(x) \, dx - 1 = 1 - 1 = 0 Exq[r1]=Exq[q(x)p(x)]1=q(x)q(x)p(x)dx1=p(x)dx1=11=0

因此,( r − 1 r - 1 r1 ) 是一个期望为零的量,适合作为控制变量。于是,Schulman构造了新的估计器:

k new = k 1 + λ ( r − 1 ) = − log ⁡ r + λ ( r − 1 ) k_{\text{new}} = k_1 + \lambda (r - 1) = -\log r + \lambda (r - 1) knew=k1+λ(r1)=logr+λ(r1)

这里的 ( λ \lambda λ ) 是一个可调参数,目标是通过选择合适的 ( λ \lambda λ ) 最小化估计器的方差。

3. 选择 ( λ \lambda λ ) 以保证非负性

理论上,可以通过最小化 ( k new k_{\text{new}} knew ) 的方差来求解最优的 ( λ \lambda λ )。方差为:

Var ( k new ) = Var ( − log ⁡ r + λ ( r − 1 ) ) = Var ( − log ⁡ r ) + λ 2 Var ( r − 1 ) + 2 λ Cov ( − log ⁡ r , r − 1 ) \text{Var}(k_{\text{new}}) = \text{Var}(-\log r + \lambda (r - 1)) = \text{Var}(-\log r) + \lambda^2 \text{Var}(r - 1) + 2 \lambda \text{Cov}(-\log r, r - 1) Var(knew)=Var(logr+λ(r1))=Var(logr)+λ2Var(r1)+2λCov(logr,r1)

最小化方差需要计算协方差和方差项,但这些项依赖于 ( p p p ) 和 ( q q q ) 的具体形式,解析求解往往复杂。Schulman提到,这种方法得到的 ( λ \lambda λ ) 通常难以计算,因此他采用了更简单的策略:选择 ( λ \lambda λ ) 使估计器 始终非负,从而直观上降低方差并与KL散度的非负性质一致。

Schulman利用对数函数的凹性,注意到:

log ⁡ u ≤ u − 1 , ∀ u > 0 \log u \leq u - 1, \quad \forall u > 0 loguu1,u>0

等号在 ( u = 1 u = 1 u=1 ) 时成立。这意味着:

− log ⁡ r ≥ − ( r − 1 ) -\log r \geq -(r - 1) logr(r1)

因此,如果选择 ( λ = 1 \lambda = 1 λ=1 ),估计器变为:

k 3 = − log ⁡ r + ( r − 1 ) = ( r − 1 ) − log ⁡ r k_3 = -\log r + (r - 1) = (r - 1) - \log r k3=logr+(r1)=(r1)logr

我们可以验证 ( k 3 k_3 k3 ) 的非负性:

k 3 = ( r − 1 ) − log ⁡ r k_3 = (r - 1) - \log r k3=(r1)logr

根据凹性不等式 ( log ⁡ r ≤ r − 1 \log r \leq r - 1 logrr1 ),有:

( r − 1 ) − log ⁡ r ≥ ( r − 1 ) − ( r − 1 ) = 0 (r - 1) - \log r \geq (r - 1) - (r - 1) = 0 (r1)logr(r1)(r1)=0

因此,( k 3 ≥ 0 k_3 \geq 0 k30 ),并且在 ( r = 1 r = 1 r=1 )(即 ( p ( x ) = q ( x ) p(x) = q(x) p(x)=q(x) ))时,( k 3 = 0 k_3 = 0 k3=0 )。这种非负性使得 ( k 3 k_3 k3 ) 的样本值更稳定,避免了 ( k 1 k_1 k1 ) 中正负抵消导致的高方差。

4. 验证 ( k 3 k_3 k3 ) 的无偏性

我们需要确认 ( k 3 k_3 k3 ) 是否无偏:

E x ∼ q [ k 3 ] = E x ∼ q [ ( r − 1 ) − log ⁡ r ] = E x ∼ q [ r − 1 ] − E x ∼ q [ log ⁡ r ] \mathbb{E}_{x \sim q} [k_3] = \mathbb{E}_{x \sim q} [(r - 1) - \log r] = \mathbb{E}_{x \sim q} [r - 1] - \mathbb{E}_{x \sim q} [\log r] Exq[k3]=Exq[(r1)logr]=Exq[r1]Exq[logr]

  • 第一项:( E x ∼ q [ r − 1 ] = 0 \mathbb{E}_{x \sim q} [r - 1] = 0 Exq[r1]=0 )(如前所述)。
  • 第二项:( E x ∼ q [ log ⁡ r ] = E x ∼ q [ log ⁡ p ( x ) q ( x ) ] = E x ∼ q [ log ⁡ q ( x ) p ( x ) ] = K L [ q , p ] \mathbb{E}_{x \sim q} [\log r] = \mathbb{E}_{x \sim q} \left[ \log \frac{p(x)}{q(x)} \right] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = KL[q, p] Exq[logr]=Exq[logq(x)p(x)]=Exq[logp(x)q(x)]=KL[q,p] ).

因此:

E x ∼ q [ k 3 ] = 0 − ( − K L [ q , p ] ) = K L [ q , p ] \mathbb{E}_{x \sim q} [k_3] = 0 - (-KL[q, p]) = KL[q, p] Exq[k3]=0(KL[q,p])=KL[q,p]

( k 3 k_3 k3 ) 是无偏的,满足我们对KL散度估计器的要求。

5. ( k 3 k_3 k3 ) 与Bregman散度的联系

Schulman指出,( k 3 = ( r − 1 ) − log ⁡ r k_3 = (r - 1) - \log r k3=(r1)logr ) 的形式可以解释为一个 Bregman散度。Bregman散度衡量一个凸函数与其在某点的切线之间的垂直距离。对于凸函数 ( f ( u ) = − log ⁡ u f(u) = -\log u f(u)=logu ),其导数为:

f ′ ( u ) = − 1 u f'(u) = -\frac{1}{u} f(u)=u1

在点 ( u = 1 u = 1 u=1 ) 处的切线为:

f ( 1 ) + f ′ ( 1 ) ( u − 1 ) = − log ⁡ 1 − 1 1 ( u − 1 ) = − ( u − 1 ) f(1) + f'(1)(u - 1) = -\log 1 - \frac{1}{1}(u - 1) = -(u - 1) f(1)+f(1)(u1)=log111(u1)=(u1)

Bregman散度定义为:

D f ( u , 1 ) = f ( u ) − f ( 1 ) − f ′ ( 1 ) ( u − 1 ) D_f(u, 1) = f(u) - f(1) - f'(1)(u - 1) Df(u,1)=f(u)f(1)f(1)(u1)

代入 ( f ( u ) = − log ⁡ u f(u) = -\log u f(u)=logu ):

D f ( r , 1 ) = ( − log ⁡ r ) − ( − log ⁡ 1 ) − ( − 1 1 ) ( r − 1 ) = − log ⁡ r − 0 + ( r − 1 ) = ( r − 1 ) − log ⁡ r D_f(r, 1) = (-\log r) - (-\log 1) - \left(-\frac{1}{1}\right)(r - 1) = -\log r - 0 + (r - 1) = (r - 1) - \log r Df(r,1)=(logr)(log1)(11)(r1)=logr0+(r1)=(r1)logr

这正是 ( k 3 k_3 k3 ) 的形式!因此,( k 3 k_3 k3 ) 可以看作 ( f ( u ) = − log ⁡ u f(u) = -\log u f(u)=logu ) 在 ( r = 1 r = 1 r=1 ) 处的Bregman散度。这种解释提供了理论上的优雅性:( k 3 k_3 k3 ) 不仅无偏,还通过Bregman散度的非负性自然降低了方差。

6. 实验验证

Schulman通过实验比较了 ( k 1 k_1 k1 )、( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 )。例如,当 ( q = N ( 0 , 1 ) q = \mathcal{N}(0, 1) q=N(0,1) )、( p = N ( 0.1 , 1 ) p = \mathcal{N}(0.1, 1) p=N(0.1,1) )(真实KL散度为0.005)时:

  • ( k 1 k_1 k1 ): 无偏,标准差为真实值的20倍。
  • ( k 2 k_2 k2 ): 偏倚为0.2%,标准差为真实值的1.42倍。
  • ( k 3 k_3 k3 ): 无偏,标准差为真实值的1.42倍。

当KL散度较大(如 ( p = N ( 1 , 1 ) p = \mathcal{N}(1, 1) p=N(1,1) ),真实KL散度为0.5)时:

  • ( k 3 k_3 k3 ): 无偏,标准差为真实值的1.7倍,优于 ( k 2 k_2 k2 ) 的1.73倍和25%偏倚。

这些结果表明,( k 3 k_3 k3 ) 成功结合了无偏性和低方差。


( k 3 k_3 k3 ) 与 ( k 2 k_2 k2 ) 推导的联系

虽然 ( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 ) 的最终形式不同,但它们的推导和思考过程有以下联系:

  1. 共同目标:降低方差

    • ( k 2 k_2 k2 ) 通过平方 ( log ⁡ r \log r logr )(即 ( 1 2 ( log ⁡ r ) 2 \frac{1}{2} (\log r)^2 21(logr)2 ))使估计值非负,强调分布差异并降低方差。
    • ( k 3 k_3 k3 ) 通过控制变量 ( r − 1 r - 1 r1 ) 和利用 ( log ⁡ r ≤ r − 1 \log r \leq r - 1 logrr1 ) 的凹性不等式,同样追求非负性和低方差。
    • 两者都试图解决 ( k 1 = − log ⁡ r k_1 = -\log r k1=logr ) 的高方差问题,直观上希望估计器与KL散度的非负性质一致。
  2. f-散度与Bregman散度的理论支持

    • ( k 2 k_2 k2 ) 的期望是一个f-散度(( f ( u ) = 1 2 ( log ⁡ u ) 2 f(u) = \frac{1}{2} (\log u)^2 f(u)=21(logu)2 )),在 ( p ≈ q p \approx q pq ) 时通过二阶近似(( f ′ ′ ( 1 ) = 1 f''(1) = 1 f′′(1)=1 ))与KL散度等价。
    • ( k 3 k_3 k3 ) 直接是 ( K L [ q , p ] KL[q, p] KL[q,p] ) 的无偏估计器,且其形式对应于 ( f ( u ) = − log ⁡ u f(u) = -\log u f(u)=logu ) 的Bregman散度。Bregman散度可以看作f-散度的一种特殊形式(当考虑凸函数与其切线的距离时)。
    • 两者都利用了凸函数的性质(( k 2 k_2 k2 ) 通过凸函数 ( f ( u ) f(u) f(u)),( k 3 k_3 k3 ) 通过凹函数 ( − log ⁡ u -\log u logu ) 的Bregman散度),体现了Schulman从广义散度框架出发的统一思路。
  3. 从有偏到无偏的改进

    • ( k 2 k_2 k2 ) 是Schulman的初步尝试,牺牲了无偏性换取低方差,但通过f-散度的二阶等价性保证了低偏倚。
    • ( k 3 k_3 k3 ) 是对 ( k 2 k_2 k2 ) 的进一步改进,通过控制变量方法恢复了无偏性,同时保留了低方差的优点。( k 3 k_3 k3 ) 的非负性(通过 ( log ⁡ r ≤ r − 1 \log r \leq r - 1 logrr1 ))与 ( k 2 k_2 k2 ) 的非负性(通过平方)有相似的直观动机。
  4. 设计思路的演进

    • ( k 2 k_2 k2 ) 的设计更像是一种启发式尝试:用平方函数放大差异并确保非负,然后通过f-散度理论验证其合理性。
    • ( k 3 k_3 k3 ) 的设计更有系统性,基于控制变量的统计方法,并通过Bregman散度的数学结构提供理论支撑。Schulman在推导 ( k 3 k_3 k3 ) 时明确考虑了如何在 ( k 1 k_1 k1 ) 的基础上系统性地降低方差,而 ( k 2 k_2 k2 ) 更像是探索过程中的一个中间产物。
  5. 推广到其他散度

    • Schulman在博客中提到,( k 2 k_2 k2 ) 和 ( k 3 k_3 k3 ) 的思路可以推广到其他f-散度。例如,( k 3 k_3 k3 ) 的形式 ( f ( r ) − f ′ ( 1 ) ( r − 1 ) f(r) - f'(1)(r - 1) f(r)f(1)(r1) ) 是一个通用的f-散度估计器,适用于任意凸函数 ( f f f )。
    • ( k 2 k_2 k2 ) 的f-散度形式为其他估计器的设计提供了灵感,而 ( k 3 k_3 k3 ) 的Bregman散度形式进一步揭示了这种估计器与凸优化理论的深层联系。

总结

( k 3 = ( r − 1 ) − log ⁡ r k_3 = (r - 1) - \log r k3=(r1)logr ) 的来源是通过控制变量方法改进 ( k 1 k_1 k1 ),利用 ( r − 1 r - 1 r1 ) 的零期望和 ( log ⁡ r ≤ r − 1 \log r \leq r - 1 logrr1 ) 的凹性不等式,构造了一个无偏且低方差的KL散度估计器。其推导过程如下:

  1. 从 ( k 1 = − log ⁡ r k_1 = -\log r k1=logr ) 的高方差问题出发,引入控制变量 ( r − 1 r - 1 r1 )。
  2. 构造新估计器 ( k 3 = − log ⁡ r + λ ( r − 1 ) k_3 = -\log r + \lambda (r - 1) k3=logr+λ(r1) ),选择 ( λ = 1 \lambda = 1 λ=1 ) 保证非负性。
  3. 验证 ( k 3 k_3 k3 ) 的无偏性和非负性,并通过Bregman散度提供理论解释。
  4. 实验确认 ( k 3 k_3 k3 ) 兼顾无偏和低方差,优于 ( k 2 k_2 k2 )。

与 ( k 2 k_2 k2 ) 的联系在于:

  • 两者都旨在降低 ( k 1 k_1 k1 ) 的方差,追求非负性以匹配KL散度的性质。
  • ( k 2 k_2 k2 ) 通过f-散度框架启发了对分布差异的度量,( k 3 k_3 k3 ) 则通过控制变量和Bregman散度实现了更系统的优化。
  • ( k 3 k_3 k3 ) 可以看作 ( k 2 k_2 k2 ) 思路的延伸,从有偏但低方差的探索(( k 2 k_2 k2 ))进化到无偏且低方差的解决方案(( k 3 k_3 k3 ))。

这种推导思路不仅在理论上优雅,还在DeepSeek GRPO等强化学习应用中得到了实践验证,展现了其在LLM优化中的重要价值。


后记

2025年4月20日于上海,在grok 3大模型辅助下完成。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐