DeepSeek V4 的注意力机制详解

在 DeepSeek-V4 中,使用CSA(压缩稀疏注意力)和 HCA(重度压缩注意力)的混合注意力。CSA(压缩稀疏注意力)和 HCA(重度压缩注意力)并不是在同一层内计算并相加的,而是以交替(Interleaved)的方式在不同的 Transformer 层中使用。例如,第 3 层可能使用 CSA,第 4 层可能使用 HCA。

因此,假设输入向量 A 维度为 [2, 1000, 8192](Batch Size = 2,序列长度 = 1000,隐藏维度 ddd = 8192),它会分别通过 CSA 层或 HCA 层,最终输出维度同样为 [2, 1000, 8192] 的向量 B

为了把每一步的矩阵乘法和维度变化讲透,我们假设以下符合 DeepSeek-V4 设计逻辑的超参数设定:

  • 输入维度 ddd: 8192
  • 注意力头数 nhn_hnh: 128
  • 单头维度 ccc: 512
  • Query 低秩压缩维度 dcd_cdc: 1536
  • CSA 压缩率 mmm: 4
  • 滑动窗口大小 nwinn_{win}nwin: 128
  • 分组投影组数 ggg: 16,每组中间维度 dgd_gdg: 1024

下面我们以 CSA 层 为例,一步步推演输入 A 到输出 B 的全过程。


第一步:KV 生成与序列压缩 (KV Generation & Compression)

传统 Attention 会为每个 Token 生成 KV,而 CSA 首先生成多组 KV 状态,并在序列维度上进行压缩。

  1. 线性映射生成初始 KV
    输入矩阵 HHH 大小为 [2, 1000, 8192]
    乘以权重矩阵 WaKV,WbKV∈R8192×512W^{aKV}, W^{bKV} \in \mathbb{R}^{8192 \times 512}WaKV,WbKVR8192×512

    • Ca=H⋅WaKVC^a = H \cdot W^{aKV}Ca=HWaKV →\rightarrow 激活状态维度:[2, 1000, 512]
    • Cb=H⋅WbKVC^b = H \cdot W^{bKV}Cb=HWbKV →\rightarrow 激活状态维度:[2, 1000, 512]
      (同时生成压缩权重 Za,ZbZ^a, Z^bZa,Zb,维度相同)
  2. 序列级 Token 压缩
    m=4m=4m=4 个 Token 被压缩为一个条目。
    序列长度从 1000 变为 1000 / 4 = 250

    • 得到压缩后的 KV 缓存 CCompC^{Comp}CComp →\rightarrow 激活状态维度:[2, 250, 512]

第二步:Query 的低秩生成 (Low-Rank Query Generation)

为了降低 128 个头的计算开销,Query 是通过低秩投影生成的。

  1. 降维投影 (Down-projection)
    乘以权重矩阵 WDQ∈R8192×1536W^{DQ} \in \mathbb{R}^{8192 \times 1536}WDQR8192×1536

    • cQ=H⋅WDQc^Q = H \cdot W^{DQ}cQ=HWDQ →\rightarrow 激活状态维度:[2, 1000, 1536]
  2. 升维投影得到多头 Query (Up-projection)
    乘以权重矩阵 WUQ∈R1536×(128×512)W^{UQ} \in \mathbb{R}^{1536 \times (128 \times 512)}WUQR1536×(128×512),即 1536×655361536 \times 655361536×65536

    • q=cQ⋅WUQq = c^Q \cdot W^{UQ}q=cQWUQ →\rightarrow 激活状态维度:[2, 1000, 65536]
    • 重塑 (Reshape) 为多头格式 →\rightarrow 维度:[2, 1000, 128, 512]

第三步:位置编码的注入 (Partial RoPE)

在这里,旋转位置编码(RoPE)并不会应用于整个 512 维的头向量,而是仅应用于最后 64 个维度

  • 对 Query [2, 1000, 128, 512] 的最后 64 维进行 RoPE 旋转。
  • 对压缩后的 KV CCompC^{Comp}CComp [2, 250, 512] 的最后 64 维进行 RoPE 旋转。
  • 维度变化:无变化,保持原维度,但向量内部的特征值已被注入位置信息。

第四步:核心注意力计算 (Multi-Query Attention)

现在 Query 准备和 Key/Value 进行点乘。由于这是 MQA 架构,所有的 Query 头共享同一组压缩后的 KV 缓存。

  1. 稀疏选择 (Lightning Indexer)
    在这个例子中,序列总长仅为 1000,压缩后只有 250 个 KV 块。由于通常的 Top-kkk 设置(如 k=1024k=1024k=1024)大于 250,这里索引器会直接选中所有 250 个压缩块参与计算。

  2. 引入滑动窗口 (Sliding Window)
    为了保证局部信息的精确性,模型还会拉取当前 Token 前面的 128 个未压缩的原始 KV 状态(滑动窗口 nwin=128n_{win} = 128nwin=128)加入计算。

  3. 计算 Attention Output
    Query [2, 1000, 128, 512] 与 结合了滑动窗口和压缩块的 KV 进行点乘、Softmax 加权求和。

    • 核心注意力输出 OOO →\rightarrow 激活状态维度:[2, 1000, 128, 512]
    • 重塑拼接后 →\rightarrow 维度:[2, 1000, 65536]
    • (注:此时输出的最后 64 维会做一次反向的 RoPE 抵消绝对位置的影响,转化为相对位置特征)。

第五步:分组输出投影 (Grouped Output Projection)

如果直接把 [2, 1000, 65536] 用一个大矩阵映射回 8192,参数量和计算量会极其恐怖。DeepSeek-V4 使用了分组投影策略。

  1. 分组映射
    将 128 个头分成 g=16g=16g=16 组,每组 8 个头。
    每组的输入维度为 8×512=40968 \times 512 = 40968×512=4096
    每组乘以一个子投影矩阵 Wgroup∈R4096×1024W_{group} \in \mathbb{R}^{4096 \times 1024}WgroupR4096×1024

    • 单组输出维度:[2, 1000, 1024]
    • 16 组拼接后的中间激活维度:[2, 1000, 16384]
  2. 最终映射到输出 B
    将拼接后的中间变量乘以最终的映射矩阵 Wfinal∈R16384×8192W_{final} \in \mathbb{R}^{16384 \times 8192}WfinalR16384×8192

    • 最终输出 BBB →\rightarrow 维度:[2, 1000, 8192]

总结:如果是 HCA 层会有什么不同?

如果这一层交替到了 HCA (重度压缩注意力),整个流程的矩阵维度变换逻辑几乎完全一致,核心区别在于第一步的序列压缩率

  • HCA 的压缩率 m′=128m' = 128m=128
  • 1000 个 Token 会被极端压缩成 1000/128≈81000 / 128 \approx 81000/1288 个 KV 块。
  • 由于序列极短,HCA 完全跳过第四步的稀疏选择,Query 直接与这 8 个全局浓缩块(外加 128 个局部滑动窗口块)进行全局注意力计算。输出同样经过分组投影,完美还原回 [2, 1000, 8192]

DeepSeek V4 的mHC机制详解

一、 终极目的:谱范数与非扩张 (Spectral Norm & Non-expansive)

我们要解决的核心问题是:在残差连接(或者超连接)中,信号在经过成百上千层的矩阵相乘后,为什么会发生梯度爆炸或消失?

1. 谱范数 (Spectral Norm)
想象一个向量 xxx(代表前向传播中的特征信号),当它乘以一个矩阵 BBB(代表网络层的权重映射)时,BxBxBx 的本质是对特征信号进行拉伸或压缩

矩阵 BBB谱范数(通常记为 ∣∣B∣∣2||B||_2∣∣B2)衡量的是这个矩阵对任何非零向量 xxx最大拉伸比例
用数学语言表达就是:
∣∣B∣∣2=max⁡x≠0∣∣Bx∣∣2∣∣x∣∣2||B||_2 = \max_{x \neq 0} \frac{||Bx||_2}{||x||_2}∣∣B2=x=0max∣∣x2∣∣Bx2
在线性代数中,它等于矩阵 BBB 的最大奇异值(σmax\sigma_{max}σmax)。

2. 非扩张 (Non-expansive)
顾名思义,“非扩张”就是 “绝不放大信号”
如果一个矩阵变换是非扩张的,就意味着对于任意输入的信号 xxx,变换后的长度绝不会超过原来的长度:
∣∣Bx∣∣2≤∣∣x∣∣2||Bx||_2 \le ||x||_2∣∣Bx2∣∣x2
这就等价于要求矩阵的谱范数 ∣∣B∣∣2≤1||B||_2 \le 1∣∣B21

在模型中的意义: 如果 ∣∣B∣∣2>1||B||_2 > 1∣∣B2>1,经过 100 层累乘,信号会被放大 1.1100≈137801.1^{100} \approx 137801.110013780 倍,导致数值爆炸(NaN)。如果保证 ∣∣B∣∣2≤1||B||_2 \le 1∣∣B21(即非扩张),无论网络堆叠多深,信号都处于安全边界内。


二、 约束的概念:流形 (Manifold)

既然我们需要 BBB 的谱范数 ≤1\le 11,最粗暴的方法是每次更新完权重就强行把矩阵缩放一下。但这破坏了梯度下降的自然连续性。更优雅的做法是让矩阵 BBB 只能在一个绝对安全的“空间”里更新。这个空间,就是流形。

什么是流形?
流形是一个几何概念,指局部看起来像平坦的欧几里得空间,但整体可能弯曲或有特定形状的空间

  • 通俗比喻: 地球表面。对站在操场上的人来说,地面是二维平坦的(局部欧几里得),你可以前后左右走动(梯度下降);但在宇宙视角看,它是一个被约束在三维空间中的二维球面(流形)。

在 mHC 的语境下,所有可能的 n×nn \times nn×n 矩阵构成了一个巨大的、无边无际的高维空间(充满着会导致信号爆炸的危险矩阵)。我们不想让残差矩阵 BBB 在这个无边无际的空间里乱跑,我们要把它“锁”在一个安全的特定几何表面上。 这个特定的表面,就是接下来要说的 Birkhoff 多胞形。


三、 具体的安全形状:双随机矩阵的 Birkhoff 多胞形

现在我们需要找到一个具体的数学流形,只要矩阵 BBB 在这个流形上,就绝对满足非扩张(∣∣B∣∣2≤1||B||_2 \le 1∣∣B21)的要求。

1. 双随机矩阵 (Doubly Stochastic Matrix)
这是一个非常特殊的方阵,必须同时满足三个条件:

  1. 非负性: 矩阵里的每一个元素都大于等于 0(Bij≥0B_{ij} \ge 0Bij0)。
  2. 行和为 1: 每一行的所有元素加起来等于 1(∑jBij=1\sum_j B_{ij} = 1jBij=1)。
  3. 列和为 1: 每一列的所有元素加起来等于 1(∑iBij=1\sum_i B_{ij} = 1iBij=1)。

2. Birkhoff 多胞形 (Birkhoff Polytope)

  • 多胞形 (Polytope): 你可以把它理解为高维空间中的“多边形”或“多面体”。
  • Birkhoff 多胞形: 它是由所有 n×nn \times nn×n 的双随机矩阵构成的一个几何凸集(Convex Polytope)。
    根据 Birkhoff-von Neumann 定理,这个多胞形的每一个“顶点”,恰好是所有的置换矩阵(Permutation Matrices,即只进行行/列交换,不改变数值大小的矩阵)。

在模型中的意义: 论文中说的“约束在流形上”,其实就是把残差矩阵 BBB 死死地限制在了这个 Birkhoff 多胞形的几何体内。


四、 终极串联:为什么这个设计能保证绝对稳定?

这是最惊艳的一步:只要 BBB 是一个双随机矩阵(身处 Birkhoff 多胞形中),它就必定是非扩张的!

在矩阵理论中,有一个著名的范数不等式(Schur界):
矩阵的谱范数(2-范数),一定小于等于它的 列和最大值(1-范数)行和最大值(∞\infty-范数) 乘积的平方根。
∣∣B∣∣2≤∣∣B∣∣1⋅∣∣B∣∣∞||B||_2 \le \sqrt{||B||_1 \cdot ||B||_\infty}∣∣B2∣∣B1∣∣B

因为我们把 BBB 约束成了双随机矩阵:

  • 所有列的和都是 1 ⇒∣∣B∣∣1=1\Rightarrow ||B||_1 = 1∣∣B1=1
  • 所有行的和都是 1 ⇒∣∣B∣∣∞=1\Rightarrow ||B||_\infty = 1∣∣B=1

代入公式:
∣∣B∣∣2≤1×1=1||B||_2 \le \sqrt{1 \times 1} = 1∣∣B21×1 =1

证明完毕! 没有任何外部的强制截断或粗暴缩放,仅仅通过让矩阵 BBB 的生成过程最终经过 Sinkhorn-Knopp 算法(该算法的作用就是把任意正数矩阵投影变成双随机矩阵),模型在数学机制上获得了两个极其强大的保障:

  1. 单层安全: 无论输入什么信号,经过残差矩阵 BlB_lBl 后,特征方差绝对不会放大。
  2. 无限深度的安全: 两个双随机矩阵相乘,结果仍然是双随机矩阵(Birkhoff 多胞形对乘法封闭)。这意味着即使你堆叠 1000 层 mHC,最终的复合残差映射也依然被完美锁定在这个多胞形内,谱范数永远 ≤1\le 11

在这里插入图片描述

下面用具体的数值走一遍算法,是理解矩阵变换最直观、最硬核的方式。

在开始推演之前,我需要稍微纠正一个小概念:Sinkhorn-Knopp 算法的“收敛”并不是指矩阵里的数值无限变小(趋于 0),而是指矩阵的“每一行的和”与“每一列的和”快速逼近于 1。 当然,由于原始的指数化矩阵数值通常很大,在归一化的过程中,数值确实会缩小并被限制在 (0,1)(0, 1)(0,1) 之间。

为了让你看得很清楚,我们构造一个简单的 2×22 \times 22×2 矩阵来进行一次完整的迭代(t=1t=1t=1)。


第一步:初始化 M(0)M^{(0)}M(0)

假设神经网络生成的原始残差映射参数矩阵 B~l\tilde{B}_lB~l 为:
B~l=[0.6931.0981.3860]\tilde{B}_l = \begin{bmatrix} 0.693 & 1.098 \\ 1.386 & 0 \end{bmatrix}B~l=[0.6931.3861.0980]

根据公式 M(0)=exp⁡(B~l)M^{(0)} = \exp(\tilde{B}_l)M(0)=exp(B~l),我们对矩阵里的每一个元素求自然指数(这里为了计算方便,我特意取了 ln⁡(2),ln⁡(3)\ln(2), \ln(3)ln(2),ln(3) 等近似值):
M(0)≈[2341]M^{(0)} \approx \begin{bmatrix} 2 & 3 \\ 4 & 1 \end{bmatrix}M(0)[2431]

此时的状态检查(距离双随机矩阵有多远?):

  • 列和 (Column Sums): 第 1 列是 2+4=62 + 4 = \textbf{6}2+4=6;第 2 列是 3+1=43 + 1 = \textbf{4}3+1=4
  • 行和 (Row Sums): 第 1 行是 2+3=52 + 3 = \textbf{5}2+3=5;第 2 行是 4+1=54 + 1 = \textbf{5}4+1=5
  • 结论: 行和与列和与目标值 1\textbf{1}1 差得很远。

第二步:列归一化 Tc\mathcal{T}_cTc (第一次迭代的前半步)

公式中的 Tc\mathcal{T}_cTc 表示列归一化,即把每一列的元素除以该列的总和,强行让列和变成 1。

  • 第 1 列处理: 元素除以列和 6 →\rightarrow 2/6=1/32/6 = 1/32/6=1/34/6=2/34/6 = 2/34/6=2/3
  • 第 2 列处理: 元素除以列和 4 →\rightarrow 3/43/43/41/41/41/4

我们得到了列归一化后的中间矩阵:
Mc=[1/33/42/31/4]≈[0.3330.7500.6670.250]M_c = \begin{bmatrix} 1/3 & 3/4 \\ 2/3 & 1/4 \end{bmatrix} \approx \begin{bmatrix} 0.333 & 0.750 \\ 0.667 & 0.250 \end{bmatrix}Mc=[1/32/33/41/4][0.3330.6670.7500.250]

此时的状态检查:

  • 列和: (1/3+2/3)=1(1/3 + 2/3) = \textbf{1}(1/3+2/3)=1(3/4+1/4)=1(3/4 + 1/4) = \textbf{1}(3/4+1/4)=1。(列和完美符合要求!)
  • 行和: 第 1 行是 1/3+3/4=1.0831/3 + 3/4 = \textbf{1.083}1/3+3/4=1.083;第 2 行是 2/3+1/4=0.9172/3 + 1/4 = \textbf{0.917}2/3+1/4=0.917
  • 结论: 列和搞定了,但行和被破坏了,不过比最开始的 555 已经非常接近 111 了。

第三步:行归一化 Tr\mathcal{T}_rTr (第一次迭代的后半步,完成 t=1t=1t=1)

公式中的 Tr\mathcal{T}_rTr 表示行归一化,即把 McM_cMc 每一行的元素除以该行的总和,强行让行和变成 1。

  • 第 1 行处理: 行和是 13/1213/1213/12
    元素除以行和 →\rightarrow (1/3)/(13/12)=4/13≈0.308(1/3) / (13/12) = \textbf{4/13} \approx 0.308(1/3)/(13/12)=4/130.308(3/4)/(13/12)=9/13≈0.692(3/4) / (13/12) = \textbf{9/13} \approx 0.692(3/4)/(13/12)=9/130.692
  • 第 2 行处理: 行和是 11/1211/1211/12
    元素除以行和 →\rightarrow (2/3)/(11/12)=8/11≈0.727(2/3) / (11/12) = \textbf{8/11} \approx 0.727(2/3)/(11/12)=8/110.727(1/4)/(11/12)=3/11≈0.273(1/4) / (11/12) = \textbf{3/11} \approx 0.273(1/4)/(11/12)=3/110.273

我们得到了完成一次完整迭代后的矩阵 M(1)M^{(1)}M(1)
M(1)=[4/139/138/113/11]≈[0.3080.6920.7270.273]M^{(1)} = \begin{bmatrix} 4/13 & 9/13 \\ 8/11 & 3/11 \end{bmatrix} \approx \begin{bmatrix} 0.308 & 0.692 \\ 0.727 & 0.273 \end{bmatrix}M(1)=[4/138/119/133/11][0.3080.7270.6920.273]

此时的状态检查:

  • 行和: (4/13+9/13)=1(4/13 + 9/13) = \textbf{1}(4/13+9/13)=1(8/11+3/11)=1(8/11 + 3/11) = \textbf{1}(8/11+3/11)=1。(行和完美符合要求!)
  • 列和: 第 1 列是 4/13+8/11≈1.0354/13 + 8/11 \approx \textbf{1.035}4/13+8/111.035;第 2 列是 9/13+3/11≈0.9659/13 + 3/11 \approx \textbf{0.965}9/13+3/110.965

总结:我们看到了什么?

通过对比 t=0t=0t=0t=1t=1t=1 的状态,我们可以清晰地看到 Sinkhorn-Knopp 算法的“收敛”魔力:

  1. 初始误差极大: M(0)M^{(0)}M(0) 的行和列和分别是 4, 5, 6,完全不受控。
  2. 迭代一次后误差极小: 仅仅经过一次 t=1t=1t=1 的交替归一化,矩阵的行和已经被强行拉到了 1,而列和变成了 1.0350.965
  3. 数值的变化: 矩阵里的数值从最初的 2,3,4,12, 3, 4, 12,3,4,1 迅速被压缩成了 (0,1)(0, 1)(0,1) 之间的小数。

如果你继续进行 t=2,t=3…t=2, t=3 \dotst=2,t=3 迭代,列和与行和会像钟摆一样越来越小幅地振荡,直到在 t=20t=20t=20 时,行和与列和将无限逼近于 1.00000,最终完美收敛为一个真正的双随机矩阵。此时,这个矩阵参与网络层之间的残差相乘,就绝对不会导致数值爆炸了。

DeepSeek-V4 中引入的流形约束超连接(Manifold-Constrained Hyper-Connections, mHC)是一种用来替代传统残差连接的高级架构设计。在极深的网络中,传统的残差叠加或简单的超连接(Hyper-Connections, HC)极易导致数值不稳定和梯度爆炸。

mHC 的核心思想是将残差映射矩阵约束在特定的数学流形(双随机矩阵的 Birkhoff 多胞形)上 。这保证了残差变换是非扩张的(即谱范数始终 ≤1\le 11),从而在不损失模型表达能力的前提下,极大地提高了深层网络前向传播和反向传播的稳定性。

以下我们将从数学公式、维度变化推演以及代码实现三个维度为你详细拆解。


一、 数学公式推导

标准的 Transformer 某层输入和输出均为 ddd 维,即 x∈Rdx \in \mathbb{R}^dxRd。但在 mHC 中,残差流(Residual Stream)被拓宽了 nhcn_{hc}nhc 倍,变为了矩阵形式 Xl∈Rnhc×dX_l \in \mathbb{R}^{n_{hc} \times d}XlRnhc×d

对于第 lll 层,其状态更新公式为:
Xl+1=BlXl+ClFl(AlXl)X_{l+1} = B_l X_l + C_l \mathcal{F}_l(A_l X_l)Xl+1=BlXl+ClFl(AlXl)

其中 Fl\mathcal{F}_lFl 代表当前层的实际计算模块(如 MoE 或 Attention)。为了实现上述公式,需要动态生成三个关键的线性映射:输入映射 AlA_lAl、残差映射 BlB_lBl 和输出映射 ClC_lCl

1. 动态参数生成

将当前层输入展平并归一化,作为动态生成参数的条件 :
X^l=RMSNorm(vec(Xl))∈R1×nhcd\hat{X}_l = \text{RMSNorm}(\text{vec}(X_l)) \in \mathbb{R}^{1 \times n_{hc}d}X^l=RMSNorm(vec(Xl))R1×nhcd

随后通过线性层和可学习门控因子(α\alphaα),生成无约束的原始参数:
A~l=αlpre⋅(X^lWlpre)+Slpre\tilde{A}_l = \alpha_l^{pre} \cdot (\hat{X}_l W_l^{pre}) + S_l^{pre}A~l=αlpre(X^lWlpre)+Slpre
B~l=αlres⋅Mat(X^lWlres)+Slres\tilde{B}_l = \alpha_l^{res} \cdot \text{Mat}(\hat{X}_l W_l^{res}) + S_l^{res}B~l=αlresMat(X^lWlres)+Slres
C~l=αlpost⋅(X^lWlpost)T+Slpost\tilde{C}_l = \alpha_l^{post} \cdot (\hat{X}_l W_l^{post})^T + S_l^{post}C~l=αlpost(X^lWlpost)T+Slpost

2. 施加流形约束 (Constraints)

为了保证稳定性,需要对上述原始参数进行严格约束:

  • 输入/输出映射 (Sigmoid 约束):保证非负性和有界性 。
    Al=σ(A~l)A_l = \sigma(\tilde{A}_l)Al=σ(A~l)
    Cl=2σ(C~l)C_l = 2\sigma(\tilde{C}_l)Cl=2σ(C~l)
  • 残差映射 (Sinkhorn-Knopp 算法):将 B~l\tilde{B}_lB~l 投影到双随机矩阵流形上,使其每行和每列的元素之和均为 1,且所有元素非负。具体通过迭代实现:
    M(0)=exp⁡(B~l)M^{(0)} = \exp(\tilde{B}_l)M(0)=exp(B~l)
    M(t)=Tr(Tc(M(t−1)))M^{(t)} = \mathcal{T}_r(\mathcal{T}_c(M^{(t-1)}))M(t)=Tr(Tc(M(t1)))
    通过交替进行列归一化 Tc\mathcal{T}_cTc 和行归一化 Tr\mathcal{T}_rTr,通常迭代 tmax=20t_{max} = \textbf{20}tmax=20 次后收敛,得到最终的 Bl=M(20)B_l = M^{(20)}Bl=M(20)

二、 维度变化推演 (以 Batch=2, Seq=1000, d=8192 为例)

我们设定 mHC 的超参数配置遵循 DeepSeek-V4 的设定:扩展因子 nhc=4n_{hc} = \textbf{4}nhc=4,核心隐藏层维度 d=8192d = \textbf{8192}d=8192

处理阶段 操作说明 张量维度 (形状)
0. 初始状态 上一层的输出 XlX_lXl,其在每个 Token 位置维护了 4 个不同的 ddd 维向量。 [2, 1000, 4, 8192]
1. 展平与归一化 nhcn_{hc}nhcddd 两个维度上展平,并进行 RMSNorm 。4×8192=327684 \times 8192 = 327684×8192=32768。得到条件变量 X^l\hat{X}_lX^l [2, 1000, 32768]
2. 生成 A~l\tilde{A}_lA~lAlA_lAl X^l\hat{X}_lX^l 经过线性层 WpreW^{pre}Wpre 投影到 4 维,加上静态偏置,经过 Sigmoid 约束。这就是各个通道的输入加权系数。 [2, 1000, 1, 4]
3. 生成 C~l\tilde{C}_lC~lClC_lCl X^l\hat{X}_lX^l 经过线性层 WpostW^{post}Wpost 投影到 4 维,经过 2×Sigmoid2 \times \text{Sigmoid}2×Sigmoid 约束。这就是各个通道的输出分配系数。 [2, 1000, 4, 1]
4. 生成 B~l\tilde{B}_lB~lBlB_lBl X^l\hat{X}_lX^l 经过线性层 WresW^{res}Wres 投影到 16 维,Reshape 为 4×4 矩阵 。利用 Sinkhorn-Knopp 迭代 20 次,得到双随机矩阵 BlB_lBl [2, 1000, 4, 4]
5. 准备模块输入 输入系数矩阵 AlA_lAlXlX_lXl 进行矩阵乘法 (Al⋅XlA_l \cdot X_lAlXl),将 4 个通道的特征缩减为 1ddd 维输入特征。 [2, 1000, 8192]
6. 模块前向传播 经过实际的 Transformer 计算层 Fl\mathcal{F}_lFl(如 Attention 或 MoE)。 [2, 1000, 8192]
7. 还原输出维度 将模块的输出 Fl\mathcal{F}_lFl 与输出系数矩阵 ClC_lCl 相乘,将 1 维输出重新广播/分配到 4 个通道中。 [2, 1000, 4, 8192]
8. 残差传递 原状态 XlX_lXl 经过双随机矩阵 BlB_lBl 进行通道间的特征融合和流转 (Bl⋅XlB_l \cdot X_lBlXl)。 [2, 1000, 4, 8192]
9. 最终叠加 Xl+1=BlXl+ClFl(AlXl)X_{l+1} = B_l X_l + C_l \mathcal{F}_l(A_l X_l)Xl+1=BlXl+ClFl(AlXl) [2, 1000, 4, 8192]

三、 PyTorch 代码实现

以下代码展示了如何使用 PyTorch 实现 mHC 模块,包含其动态参数生成机制与核心的 Sinkhorn-Knopp 算法。

import torch
import torch.nn as nn
import torch.nn.functional as F

class mHC_Layer(nn.Module):
    def __init__(self, d=8192, n_hc=4, t_max=20):
        super().__init__()
        self.d = d
        self.n_hc = n_hc
        self.t_max = t_max
        self.flat_dim = n_hc * d
        
        # 动态参数的生成权重 (W) 和静态偏置 (S)
        self.W_pre = nn.Linear(self.flat_dim, n_hc, bias=False)
        self.S_pre = nn.Parameter(torch.zeros(1, 1, 1, n_hc))
        
        self.W_res = nn.Linear(self.flat_dim, n_hc * n_hc, bias=False)
        self.S_res = nn.Parameter(torch.zeros(1, 1, n_hc, n_hc))
        
        self.W_post = nn.Linear(self.flat_dim, n_hc, bias=False)
        self.S_post = nn.Parameter(torch.zeros(1, 1, n_hc, 1))
        
        # 可学习门控因子 (初始化为较小的值)
        self.alpha_pre = nn.Parameter(torch.ones(1) * 0.01)
        self.alpha_res = nn.Parameter(torch.ones(1) * 0.01)
        self.alpha_post = nn.Parameter(torch.ones(1) * 0.01)

    def sinkhorn_knopp(self, B_tilde):
        # 确保正值矩阵 M^(0)
        M = torch.exp(B_tilde)
        
        # 迭代 t_max 次,交替进行列归一化和行归一化
        for _ in range(self.t_max):
            M = M / (M.sum(dim=-2, keepdim=True) + 1e-6) # 列归一化
            M = M / (M.sum(dim=-1, keepdim=True) + 1e-6) # 行归一化
        return M

    def forward(self, X_l, F_l_module):
        """
        X_l: 形状 [batch_size, seq_len, n_hc, d]
        F_l_module: 当前层的实际计算逻辑 (如 MoE 层)
        """
        b, seq, _, _ = X_l.shape
        
        # 1. 展平并归一化
        # [b, seq, n_hc * d]
        X_flat = X_l.view(b, seq, -1) 
        X_hat = F.rms_norm(X_flat, (self.flat_dim,))
        
        # 2. 动态生成原始参数
        # [b, seq, 1, n_hc]
        A_tilde = self.alpha_pre * self.W_pre(X_hat).unsqueeze(2) + self.S_pre
        # [b, seq, n_hc, n_hc]
        B_tilde = self.alpha_res * self.W_res(X_hat).view(b, seq, self.n_hc, self.n_hc) + self.S_res
        # [b, seq, n_hc, 1]
        C_tilde = self.alpha_post * self.W_post(X_hat).unsqueeze(3) + self.S_post
        
        # 3. 施加约束
        A_l = torch.sigmoid(A_tilde)
        C_l = 2 * torch.sigmoid(C_tilde)
        B_l = self.sinkhorn_knopp(B_tilde)  # 投影到双随机矩阵流形
        
        # 4. 前向传播计算
        # 输入投影:[b, seq, 1, n_hc] x [b, seq, n_hc, d] -> [b, seq, 1, d] -> [b, seq, d]
        module_input = torch.matmul(A_l, X_l).squeeze(2)
        
        # 经过具体的 Transformer 块
        module_output = F_l_module(module_input)
        
        # 输出重分配:[b, seq, d] -> [b, seq, 1, d]
        module_output = module_output.unsqueeze(2)
        # [b, seq, n_hc, 1] x [b, seq, 1, d] -> [b, seq, n_hc, d]
        output_contribution = torch.matmul(C_l, module_output)
        
        # 残差投影:[b, seq, n_hc, n_hc] x [b, seq, n_hc, d] -> [b, seq, n_hc, d]
        residual_transformation = torch.matmul(B_l, X_l)
        
        # 5. 最终叠加
        X_next = residual_transformation + output_contribution
        
        return X_next

在 DeepSeek-V4 中,将前 3 层的密集前馈网络(Dense FFN)替换为使用 Hash 路由(Hash Routing)的 MoE 层,是一个兼顾参数容量扩展计算效率的精妙工程设计 [cite: 167]。

以下是关于“为什么只在前 3 层使用”的深入剖析,以及 Hash 路由与原本动态路由的机制对比。


DeepSeek V4 的Hash 路由详解

一、 为什么只对前 3 层进行 Hash 路由?

这个决定是由 Transformer 模型不同深度的特征表达性质决定的:

  1. 浅层关注词法(浅层特征),深层关注语义(上下文特征)
    在 Transformer 的最初几层,Token 的隐藏状态还没有充分融合周围的上下文信息,其特征主要由 词汇本身(Token ID) 决定。在这些浅层使用基于上下文的动态路由(计算注意力/特征的相似度)有点“杀鸡用牛刀”,因为此时的特征还不够丰富,路由网络很难学到复杂的语义分配。
  2. 低成本实现参数规模的暴涨
    在 DeepSeek-V3 及更早的版本中,前几层通常是 Dense(密集)层,因为浅层做动态 MoE 收益不高。但 Dense 层的参数量有限。V4 通过引入基于 Token ID 的 Hash 路由 [cite: 168],实际上是在浅层实现了一个极其庞大的、非线性的“扩展词表嵌入(Extended Embedding)”。它让浅层拥有了 MoE 级别的参数容量,却完全不需要承担动态路由网络的计算和通信开销。
  3. 为什么深层不能用?
    到了第 4 层及以后,Token 的特征已经高度上下文化(例如,苹果公司的“苹果”和水果的“苹果”,虽然 Token ID 相同,但此时的隐藏状态 xxx 已经完全不同)。如果此时还在用固定的 Hash 路由,就会阻碍模型根据上下文动态选择专家的能力。因此,Hash 路由只适合前 3 层。

二、 Hash 路由是如何操作的?(具体例子)

Hash 路由的本质是**“静态的、基于身份的分配”**。它直接根据输入的 Token ID 来决定去哪个专家 ,而不看这个 Token 所在的上下文。

举例推演:
假设前 3 层的某一层有 N=256N = 256N=256 个路由专家。

  • 输入:句子 “The apple is red”,假设 “apple” 在词表中的 Token ID 是 4567
  • Hash 计算:系统使用一个预定义的哈希函数(例如最简单的取模运算):
    Target_Expert=Hash(4567)(mod256)\text{Target\_Expert} = \text{Hash}(4567) \pmod{256}Target_Expert=Hash(4567)(mod256)
    假设 4567(mod256)=2154567 \pmod{256} = 2154567(mod256)=215
  • 结果:无论是哪篇文档、哪个语境,只要遇到 “apple”(ID: 4567),在前 3 层它永远且必定会被分配给第 215 号专家进行处理。

三、 Hash 路由与原本路由(DeepSeekMoE)的速度对比

传统的 MoE 路由(如 DeepSeek-V4 第 4 层及以后使用的路由)是**“动态的、基于特征的分配”**。

1. 传统 DeepSeekMoE 路由的计算流程

对于每一个 Token,其隐藏状态向量 x∈Rdx \in \mathbb{R}^dxRd 需要经历以下步骤:

  • 矩阵乘法 (GEMM):与门控权重矩阵 Wgate∈Rd×NW_{gate} \in \mathbb{R}^{d \times N}WgateRd×N 相乘,计算出与所有 NNN 个专家的亲和力分数(Logits)。
  • 激活函数:在 DeepSeek-V4 中,使用 Sqrt(Softplus(⋅))\text{Sqrt}(\text{Softplus}(\cdot))Sqrt(Softplus()) 函数计算最终的分数 [cite: 164]。
  • Top-K 排序:在 NNN 个分数中进行排序,选出得分最高的 KKK 个专家(V4 中 K=6K=6K=6 [cite: 741])。
  • 负载均衡计算:为了防止某些专家被“饿死”或“撑死”,还需要引入偏置更新和序列级的平衡损失计算。
2. 为什么 Hash 路由更快?

对比之下,Hash 路由在速度和系统调度上具有压倒性优势:

  • 计算复杂度从 O(d×N)O(d \times N)O(d×N) 降为 O(1)O(1)O(1):Hash 路由只需要对一个整数(Token ID)进行一次数学哈希运算,完全消除了门控网络的矩阵乘法(GEMM)
  • 零负载均衡开销:由于现代哈希函数在数学上具有均匀分布的特性,词表中的 Token ID 会被自然、均匀地打散到各个专家中。因此,Hash 路由不需要任何辅助损失(Auxiliary Loss)或动态偏置调整来维持负载均衡。
  • 极佳的系统确定性:在传统的 MoE 中,由于路由是动态的,底层系统(如 Expert Parallelism)在运行时需要等待路由结果出来后,才能知道 GPU 之间需要通信多少数据。而 Hash 路由是静态的,只要拿到 Input IDs,系统在进入 Transformer 层之前就可以预先知道并调度所有专家的通信,极大地掩盖了通信延迟。
Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐