一、FlashMLA概述

FlashMLA是DeepSeek专为H架构的GPU设计的的高效MLA解码内核,优化了可变长度序列的多头潜在注意力机制。

(官方开源代码链接:https://github.com/deepseek-ai/FlashMLA)

1. FlashMLA三大核心特点

  1. MLA解码内核:MLA机制相比传统的多头注意力(MHA)机制,在处理长序列时具有更高的效率和更低的计算复杂度。它通过优化KV缓存的使用,减少了内存访问开销,同时提升了模型对长距离依赖关系的捕捉能力。
  2. Hopper GPU优化:专为NVIDIA Hopper架构GPU进行了深度优化,实现高达3000 GB/s的内存带宽和580 TFLOPS的计算性能,能够显著降低延迟,提升推理效率。
  3. 针对可变长序列优化:采用了分页KV缓存机制,块大小为64,有效解决了传统KV缓存的内存碎片化问题,提升了显存利用率。能够动态调整资源分配,高效处理不同长度的序列数据。特别适用于长对话、文档分析等场景。

2. 性能表现与对比

指标 FlashMLA 传统方法(如 MHA)
内存带宽利用率 3000 GB/s 约 2000 GB/s
计算峰值 580 TFLOPS 约 350 TFLOPS
KV 缓存内存占用 减少93.3% 100%
长序列处理效率 线性复杂度( O ( n k ) O(nk) O(nk)) 平方复杂度( O ( n 2 ) O(n^2) O(n2))

二、MLA机制详解

1. 传统的MHA机制简述

在这里插入图片描述

这张图展示了传统多头注意力机制(MHA,Multi-Head Attention)的工作流程。MHA是Transformer模型中的关键组件,用于捕捉序列中不同位置之间的依赖关系。下面是对图中流程的简单分析:

  1. 输入分解
    输入序列首先通过线性变换生成查询(Q)、键(K)和值(V)矩阵。这些矩阵的维度通常是 d × s d \times s d×s,其中 d d d 是每个头的维度, s s s 是序列长度。
  2. 多头机制
    查询、键和值矩阵被分解为多个头(图中显示为头1到头n)。每个头处理一部分信息,这样可以并行处理不同的特征子空间。每个头的维度被缩小到 d / n d/n d/n,以减少计算量和参数数量。
  3. 缩放点积注意力
    对于每个头,计算查询和键的点积,然后除以 d \sqrt{d} d 来缩放,最后应用softmax函数将点积转换为权重,这些权重表示每个值向量对当前查询的重要性。
  4. 加权求和
    使用softmax得到的权重对值向量进行加权求和,得到每个头的输出。
  5. 拼接与变换
    将所有头的输出拼接在一起,形成一个维度为 d × s d \times s d×s 的矩阵,再通过一个线性层对拼接后的矩阵进行变换,恢复到原始维度输出。

总之,MHA通过并行处理多个头,能够捕捉序列中不同位置之间的多种依赖关系。但是每个头都需要存储完整的K和V矩阵,这在处理长序列时会导致大量的内存消耗

2. DeepSeek-V3的MLA机制

MLA 是多头部潜在注意力机制(Multi-head Latent Attention),其本质是对KV的有损压缩,提高存储信息密度的同时尽可能保留关键细节。 它的核心原理是对注意力键(Key)和值(Value)进行低秩压缩,使用两个线性层和来代替一个大的 Key/Value 投影矩阵,将输入投影到一个低维空间,然后将其投影回原始维度,从而减少存储和计算量。此外,MLA 还可对查询(Query)进行低秩压缩,以进一步减少激活内存。

在这里插入图片描述

MLA对Key和Value的进行的处理如公式(1-5)所示:
c t K V = W D K V h t , [ k t , 1 C ; k t , 2 C ; … ; k t , n h C ] = k t C = W U K c t K V , k t R = RoPE ( W K R h t ) , k t , i = [ k t , i C ; k t R ] , [ v t , 1 C ; v t , 2 C ; … ; v t , n h C ] = v t C = W U V c t K V , \begin{align} \textcolor{blue}{\mathbf{c}_t^{KV}} &= \mathbf{W}^{DKV} \mathbf{h}_t, \\ \left[ \mathbf{k}_{t,1}^C; \mathbf{k}_{t,2}^C; \ldots; \mathbf{k}_{t,n_h}^C \right] = \mathbf{k}_t^C &= \mathbf{W}^{UK} \mathbf{c}_t^{KV}, \\ \textcolor{blue}{\mathbf{k}_t^R} &= \text{RoPE}(\mathbf{W}^{KR} \mathbf{h}_t), \\ \mathbf{k}_{t,i} &= \left[ \mathbf{k}_{t,i}^C; \mathbf{k}_t^R \right], \\ \left[ \mathbf{v}_{t,1}^C; \mathbf{v}_{t,2}^C; \ldots; \mathbf{v}_{t,n_h}^C \right] = \mathbf{v}_t^C &= \mathbf{W}^{UV} \mathbf{c}_t^{KV}, \end{align} ctKV[kt,1C;kt,2C;;kt,nhC]=ktCktRkt,i[vt,1C;vt,2C;;vt,nhC]=vtC=WDKVht,=WUKctKV,=RoPE(WKRht),=[kt,iC;ktR],=WUVctKV,

  • 公式(1-2)先通过 W D K V \mathbf{W}^{DKV} WDKV矩阵对 h t \mathbf{h}_t ht实现降维;又通过 W U K \mathbf{W}^{UK} WUK矩阵对 h t \mathbf{h}_t ht实现升维,这样对 h t \mathbf{h}_t ht维度的一降一升,可以大幅降低了 h t \mathbf{h}_t ht本身的权重矩阵参数。
  • 公式(3-4)通过 W K R \mathbf{W}^{KR} WKR矩阵对 h t \mathbf{h}_t ht进行映射计算,然后对其做RoPE位置编码;并将 k t C \mathbf{k}_t^C ktC的每个头的计算结果分别与RoPE位置编码后的 k t R \mathbf{k}_t^R ktR进行拼接得到 k \mathbf{k} k,这样得到了MHA中的 K \mathbf{K} K
  • 公式(5)用于计算 V \mathbf{V} V矩阵。

类似的,MLA对于Query也做了低秩分解,用来减少训练时的激活内存,参见公式(6-9)。

c t Q = W D Q h t , [ q t , 1 C ; q t , 2 C ; … ; q t , n h C ] = q t C = W U Q c t Q , [ q t , 1 R ; q t , 2 R ; … ; q t , n h R ] = q t R = RoPE ( W Q R c t Q ) , q t , i = [ q t , i C ; q t , i R ] , \begin{align} \mathbf{c}_t^Q &= \mathbf{W}^{DQ} \mathbf{h}_t, \\ \left[ \mathbf{q}_{t,1}^C; \mathbf{q}_{t,2}^C; \ldots; \mathbf{q}_{t,n_h}^C \right] = \mathbf{q}_t^C &= \mathbf{W}^{UQ} \mathbf{c}_t^Q, \\ \left[ \mathbf{q}_{t,1}^R; \mathbf{q}_{t,2}^R; \ldots; \mathbf{q}_{t,n_h}^R \right] = \mathbf{q}_t^R &= \text{RoPE}(\mathbf{W}^{QR} \mathbf{c}_t^Q), \\ \mathbf{q}_{t,i} &= \left[ \mathbf{q}_{t,i}^C; \mathbf{q}_{t,i}^R \right], \end{align} ctQ[qt,1C;qt,2C;;qt,nhC]=qtC[qt,1R;qt,2R;;qt,nhR]=qtRqt,i=WDQht,=WUQctQ,=RoPE(WQRctQ),=[qt,iC;qt,iR],

在完成上述转换后的Q、K、V输入MHA,分别计算每个头的注意力,然后拼接到一块,接着利用 W O \mathbf{W}^O WO做个映射,完成Attention计算,公式见下:

o t , i = ∑ j = 1 t Softmax j ( q t , i T k j , i d h + d h R ) v j , i C , u t = W O [ o t , 1 ; o t , 2 ; … ; o t , n h ] , \begin{align*} {\mathbf{o}_{t,i}} &= \sum_{j=1}^{t} \text{Softmax}_j \left( \frac{\mathbf{q}_{t,i}^T \mathbf{k}_{j,i}}{\sqrt{d_h + d_h^R}} \right) \mathbf{v}_{j,i}^C, \\ \mathbf{u}_t &= \mathbf{W}^O \left[ {\mathbf{o}_{t,1}}; \mathbf{o}_{t,2}; \ldots; {\mathbf{o}_{t,n_h}} \right], \end{align*} ot,iut=j=1tSoftmaxj dh+dhR qt,iTkj,i vj,iC,=WO[ot,1;ot,2;;ot,nh],

综合上述分析在整个MLA机制处理过程中,只有标注蓝色的变量( c t K V \textcolor{blue}{\mathbf{c}_t^{KV}} ctKV k t R \textcolor{blue}{\mathbf{k}_t^R} ktR)需要被缓存,其它的都可以利用“矩阵吸收”,重新恢复过来。

3. 针对MLA与MHA的通俗解释

MLA(Multi-head Latent Attention)和 MHA(Multi-Head Attention)都可以用来实现注意力机制,但它们并不是完全并行的“思考方式”选择,而是针对不同目标优化的设计。

我们可以用图书馆查阅的比喻来理解两者的核心差异:

  • MHA 的思考方式:想象你在一个传统的巨型图书馆查阅资料,每本书(每个注意力头)都完整保留所有原始内容(完整 K/V 缓存),书架按主题(头数 h)严格分区。查阅时每次需要回答问题(计算注意力),必须跑到每个主题区(每个头)翻遍对应书架的所有书籍(加载全部 K/V)
  • MLA 的思考方式:更像是现代化数字图书馆,书籍压缩归档。将相似主题的书籍(h 个头)打包成精华合辑(若干组潜在状态),只保留核心摘要(潜在 K/V),同时为每个合辑建立关键词标签(潜在变量),实时更新内容概要。查阅时先通过关键词标签(潜在状态)定位相关合辑(组),快速浏览摘要(全局注意力),必要时调取合辑内的原始书籍(局部注意力),但频次大幅降低。

相比之下,MHA方式中每个主题区存在重复内容,且随着藏书量增加(序列变长),找书耗时将会快速增长(O(n²)复杂度),这样造成了空间浪费和效率的低下。而MLA找书时间从翻遍全馆变为先看目录再精准查阅(复杂度从降为 O(n) 主导),显著提升了查找效率。


三、Hopper GPU优化详解

1. 内存子系统

  • 分页KV缓存与合并访问
    FlashMLA采用分页KV缓存机制(块大小64),将长序列切分为固定大小的块,对齐Hopper GPU的HBM3内存总线(256位宽),实现连续内存访问。这种设计不仅减少了显存碎片化,还通过合并内存事务将内存带宽提升至3000 GB/s。
  • BF16混合精度支持
    采用Brain Float 16(BF16)格式进行KV缓存和计算,相比FP32减少50%内存占用,同时保持模型精度。这一优化在Hopper的第三代Tensor Core上进一步加速矩阵运算,使计算性能达到580 TFLOPS。

2. 计算性能

  • Tensor Core与低秩分解
    通过将全局注意力计算分解为低秩矩阵乘法(如 Q i K ^ j T Q_i \hat{K}_j^T QiK^jT),利用Tensor Core的单周期8x8x16矩阵乘加(MMA)指令,最大化并行计算效率。在H800 GPU上,这一优化使计算性能接近理论峰值。
  • 算子融合与流水线并行
    FlashMLA将QKV投影、分组Pooling和注意力计算融合为单一内核,减少中间结果回写显存的次数。同时采用双缓冲流水线技术,异步预加载下一Token的Q向量至寄存器,实现计算与数据搬运的无缝重叠,降低端到端延迟。

3. 硬件架构

  • SM资源配置优化
    每个线程块处理2个注意力头(共128线程),匹配Hopper SM的128 FP32核心设计,确保计算单元满载。共享内存分配64KB(48KB用于潜在状态缓存,16KB用于局部KV),减少全局内存访问。
  • DPX指令加速动态逻辑
    利用Hopper新增的Dynamic Programming Extensions(DPX)指令,加速分页缓存的地址偏移计算和动态序列调度,例如通过dp4a指令快速聚合零散内存访问。

四、可变长序列优化详解

1. 对KV Cache的理解

(此部分学习参考知乎文章: https://zhuanlan.zhihu.com/p/26911261250)

由于decoder是有因果性的(即一个token的注意力attention只依赖于它前面的token),当每生成一个新的 token 就会把这个新的 token 添加进之前的序列中,在将这个序列当作新的输入进行新的 token 生成,直到 e o s _ t o k e n eos\_token eos_token 结束。这使得每次新序列输入时都需要取重复计算前面的 ( n − 1 ) (n-1) (n1) 个 token 的 ( q , k , v ) (q, k, v) (q,k,v),浪费了很多资源。

KV Cache 就是在这里使用的,我们在每次处理新的序列时,不需要对之前已经计算过的 Token 的 K 和 V 重新进行计算。因为对于之前的 Token 可以复用上一轮计算的结果,避免了重复计算,只需要计算当前 Token 的 Q、K、V。

在这里插入图片描述

我们从Key Cache中提取先前计算的Key向量,并计算注意力分数矩阵的最后一行作为新Query向量与每个Key向量的点积:

在这里插入图片描述

与Key向量一样,每次迭代时只需要计算最后一个Value向量。所有其他Value向量都可以从Value Cache中提取并重复使用:

在这里插入图片描述

2. FlashMLA 对 KV Cache的优化

(1) Paged KV Cache 实现

  • 显存分块:以64为单位(block_size = 64),通过block_table维护逻辑块到物理显存的映射。
  • 流水线:分离数据加载与计算阶段,通过cp.async实现异步数据预取。
    • flash_fwd_splitkv_mla_kernel:用于并行计算Flash Attention的前向传播。
    • flash_fwd_splitkv_mla_combine_kernel:用于合并多个分割的计算结果。

(2) 优化特性

  • 分页KV缓存管理
    针对长序列推理中显存碎片严重问题,Flash-MLA实现基于64-block Paged KV Cache,极大提高了显存利用率,缓解内存访问瓶颈。
  • 异步内存拷贝
    利用NVIDIA Hopper SM90架构特性,借助Tensor Memory Accelerator(TMA)异步内存拷贝指令,实现显存(HBM/GDDR)到SRAM零拷贝传输,接近理论峰值带宽。
  • 双模式执行引擎
    为适应不同输入序列长度场景,FlashMLA采用动态负载均衡算法,设计了双缓冲模式,短序列下采用计算优先模式,长序列下采用内存优先模式,使得整体延迟大幅降低。

五、小结

综上所述,FlashMLA通过一系列创新的优化措施,显著提升了NVIDIA Hopper GPU在处理可变长序列时的性能和效率。其分页KV缓存管理、异步内存拷贝和双模式执行引擎等特性,不仅提高了显存利用率和内存带宽,还降低了整体延迟。与传统的MHA机制相比,FlashMLA通过低秩压缩和旋转位置编码等技术,减少了内存访问开销,同时提升了模型对长距离依赖关系的捕捉能力,这些优化使得FlashMLA在处理大规模语言模型推理任务时表现出色,特别是在长对话和文档分析等场景中。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐