本来想写一篇KV Cache压缩的综述性博客,结果写到MLA部分的时候发现越写越多,完全值得单独拿出来写篇博客,遂从KV Cache压缩博客中单独揪出MLA进行介绍。

MLA(Multi-query Latent Attention)是国内创业公司deepseek在24年5月份发布的DeepSeek-V2大模型中用到的KV Cache压缩技术,正是在该技术的加持下DeepSeek-V2可以大幅压缩KV Cache的大小,进而大幅提升吞吐量,也正是从该模型开始,大模型推理的价格一下降低到一个很低的水平。MLA是少有的由国内公司做出的硬核创新,感谢deepseek,感谢MLA!我觉得在出现新的KV Cache压缩技术之前后续的大模型可能都会采用MLA,它的压缩效果接近MQA,但是生成效果却还比MHA更好,值得大家跟进。

MLA相比MQA和MHA相比做到了既要又要,着实牛逼,代价就是太难懂了。花了一天的时间仔细研究了一下苏剑林苏神的博客《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA》,对MLA有了一个大概的理解,下面我按照自己的思路尝试解释一下MLA,在解释的过中我抛弃了MLA论文中复杂的符号定义,重新按照自己的理解去定义相关矩阵。

常规的Attention计算会用到 Q 、 K 、 V Q、K、V QKV,这也是我们需要保存KV的原因,而MLA则不保留KV Cache,另外引入了一个C Cache来代替KV Cache。在执行Attention计算的时候,通过一系列等价变换,将公式中出现的 K K K V V V均用 C C C来代替。不仅可以用 C C C来代替 K V KV KV,而且 Q Q Q所有的head都共享同一个C Cache,从这个层面来说MLA和MQA很类似,只需要保留一个head的缓存即可获得非常好的结果。另外一点,现在的大模型一般会在计算Attention之前将 Q Q Q K K K进行RoPE(旋转位置编码),如图1所示。这就会导致单纯的 C C C丢失了位置信息,为了弥补这个缺陷,MLA中给Q和K额外增加了 d r d_r dr个维度用来添加RoPE,其中 K K K新增的维度也是每个Head共享。大家看到这里可能比较懵逼,且听我一一道来。
在这里插入图片描述

图1 LLAMA2单层模型结构

1. KV Cache变C Cache

在MHA的单head计算过程中,输入 X X X(维度为 N ∗ d N * d Nd)会先送入三个projection矩阵 W Q , W K , W V W_Q,W_K,W_V WQWKWV进行线性变换:
Q = X ∗ W Q \begin{equation} Q=X*W_Q \end{equation} Q=XWQ K = X ∗ W K \begin{equation} K=X*W_K \end{equation} K=XWK V = X ∗ W V \begin{equation} V=X*W_V \end{equation} V=XWV这三个矩阵的维度一般相等,通常都是 d ∗ d k d * d_k ddk,所以输出维度还是 N ∗ d k N * d_k Ndk。变换完成之后生成的 K 、 V K、V KV即是传统意义上的KV Cache。Multi-head会将h个head的 Q 、 K 、 V Q、K、V QKV矩阵分别拼接在一起,构成 d ∗ d d * d dd的矩阵,其中 d = h ∗ d k d = h * d_k d=hdk。前面我们提到MLA没有保留KV Cache,而是引入了一个C Cache,也即增加了下面的公式:
C = X ∗ W C \begin{equation} C=X*W_C \end{equation} C=XWC W C W_C WC的维度不再是 d ∗ d k d * d_k ddk,而是变成了 d ∗ d c d * d_c ddc d c d_c dc通常远小于 d d d,但是会比 d k d_k dk大(DeepSeek-V2中 d k d_k dk是128, d c d_c dc是512)。有了 C C C之后,为了保证Attention计算的正确性,我们可以从概念上再引入K和V,也即得到如下两个关于KV的新公式:
K = C ∗ W K \begin{equation} K=C*W_K \end{equation} K=CWK V = C ∗ W V \begin{equation} V=C*W_V \end{equation} V=CWV此时 W K W_K WK W V W_V WV的维度就变成了 d c ∗ d k d_c * d_k dcdk,最终输出的K和V维度依旧是 N ∗ d k N * d_k Ndk。但是请大家切记,MLA最终的公式中没有出现KV,我们现在处于将MHA转换到MLA的过程中,引入 K 和 V K和V KV只是方便进行公式推导。MHA进行Attention计算时会将 Q 和 K T Q和K^T QKT相乘,得到如下公式:
Q K T = X ∗ W Q ∗ ( C ∗ W K ) T = X ∗ ( W Q ∗ W K T ) ∗ C T = Q ′ C T \begin{equation} QK^T=X*W_Q *(C*W_K)^T=X* (W_Q*W_K^T)*C^T=Q'C^T \end{equation} QKT=XWQ(CWK)T=X(WQWKT)CT=QCT从公式(7)可以看出,我们可以将 W Q ∗ W K T W_Q*W_K^T WQWKT当做Q新的projection矩阵 W Q ′ W_Q' WQ(维度为 d ∗ d c d * d_c ddc),此时Attention计算就和K没什么关系了。类似的,Attention计算中的V也可以利用公式(6)进行消除,从而使得整个Attention计算只和 Q ′ 、 C Q'、C QC相关。MLA更进一步,Attention计算完成得到 O O O之后,还会将其与projection矩阵 W O W_O WO进行相乘, W V W_V WV也可以与 W O W_O WO融合在一起得到新的projection矩阵 W O ′ W_O' WO。最终我们可以得到MLA Attention部分的计算公式:
Q ′ = X ∗ W Q ′ \begin{equation} Q'=X*W_Q' \end{equation} Q=XWQ C = X ∗ W C \begin{equation} C=X*W_C \end{equation} C=XWC O = A t t e n t i o n ( Q ′ , C ) = s o f t m a x ( Q ′ C T d k ) C \begin{equation} O=Attention(Q',C)=softmax(\frac{Q'C^T}{\sqrt{d_k}}) C \end{equation} O=Attention(Q,C)=softmax(dk QCT)C O ′ = O ∗ W O ′ \begin{equation} O'=O*W_O' \end{equation} O=OWO其中, X X X维度为 N ∗ d N * d Nd W Q ′ W_Q' WQ W C W_C WC维度为 d ∗ d c d * d_c ddc Q ′ 、 C 和 O Q'、C和O QCO的维度为 N ∗ d c N * d_c Ndc W O ′ W_O' WO的维度为 d c ∗ d k d_c * d_k dcdk,最终输出 O ′ O' O的维度为 N ∗ d k N * d_k Ndk。当单个head变成h个head之后, W Q ′ W_Q' WQ W O ′ W_O' WO会多一个h维,但是 W C W_C WC则始终只保留一个head的参数量。公式(8)~(11)中几个变量使用了上标’用来表示MHA到MLA的转换,如果我们放弃MHA直接考虑MLA,则上标都可以直接去掉。

MLA相比原始的MHA简化了计算公式,压缩了缓存大小,难道天下真有免费的午餐吗?我们仔细对比一下MHA、MQA、GQA和MLA在推理过程中的模型单层参数量和计算量,如下表。我们假定模型的hidden size d d d为8k,head数为64,则得到 d k d_k dk为128,GQA的group数为8,MLA中的 d c d_c dc为512,模型支持的最长上下文N为128k。我们在统计计算量的时候只考虑了计算的大头,较小计算量的部分丢弃掉也不会产生大的影响。在这里插入图片描述

表1 MLA和MHA、MQA、GQA的参数量、计算量、缓存量对比

可以看出,MLA在参数量和计算量上都比另外三种Attention计算方法要大,我理解MLA效果要好于MHA的原因也是因为计算量的增大。但是MLA的好处就是大大降低了memory-bound为主的decoding阶段的缓存大小,从而使得解码速度变快。不过上面表格中的MLA没有考虑下面要介绍的部分,但是增加的参数量和计算量不影响大的结论。

2. Q和K增加RoPE信息

如果MLA到这里就结束那就非常完美了,它在没有大幅增加计算量和参数量的情况下大大提升了解码的推理速度。但是事情总是不遂人意,MLA的上述机制在当下的大模型下存在不可避免的缺陷:当下的大模型一般会在Attention计算之前将 Q 和 K Q和K QK添加RoPE,这就导致上面的公式(7)不再成立。但是我们又不能丢弃RoPE,为了弥补这个缺陷,MLA又在 Q 和 K Q和K QK中额外增加了一些维度专门用来存放RoPE信息。在本博客中我们着重介绍MLA是如何做的,原理性的介绍可以参看上面提到的苏神博客。MLA为了在 Q 和 K Q和K QK中添加RoPE信息,分别将其增加了额外的 d r d_r dr维度。公式(8)变为:
Q = [ X ∗ W Q , X ∗ W Q R ∗ R ] = [ X ∗ W Q , Q R ] \begin{equation} Q=[X*W_Q,X*W_{QR}*R]=[X*W_Q,QR] \end{equation} Q=[XWQ,XWQRR]=[XWQ,QR]其中 W Q R W_{QR} WQR是新增的一个矩阵,维度为 d ∗ d r d * d_r ddr R R R是RoPE矩阵,维度为 d r ∗ d r d_r * d_r drdr,则生成的 Q Q Q维度变为 N ∗ ( d c + d r ) N * (d_c + d_r) N(dc+dr),和 W Q W_Q WQ一样,h个head会使矩阵 W Q R W_{QR} WQR额外增加一个h维。我们用 Q R QR QR来代表 Q Q Q新增的子矩阵,维度为 N ∗ d r N *d_r Ndr。类似的,公式(5)也可以扩展为:
K = [ C ∗ W K , X ∗ W K R ∗ R ] = [ C ∗ W K , K R ] \begin{equation} K=[C*W_K,X*W_{KR}*R]=[C*W_K,K R] \end{equation} K=[CWK,XWKRR]=[CWK,KR] W K R W_{KR} WKR也是新增的矩阵,维度也是 d ∗ d r d * d_r ddr。用 K R KR KR来表示 K K K新增的子矩阵,它保存了RoPE信息,维度也为 N ∗ d r N * d_r Ndr。有了(12)、(13)之后,我们也扩展一下公式(7):
Q K T = [ X ∗ W Q , Q R ] ∗ [ C ∗ W K , K R ] T = [ X ∗ W Q , Q R ] ∗ ( ( C ∗ W K ) T K R T ) = [ X ∗ W Q , Q R ] ∗ ( W K T ∗ C T K R T ) = X ∗ ( W Q ∗ W K T ) ∗ C T + Q R ∗ K R T = X ∗ ( W Q ′ ) ∗ C T + Q R ∗ K R T = Q ′ C T + Q R ∗ K R T \begin{equation} \begin{aligned} &QK^T =[X*W_Q,QR] *[C*W_K,K R] ^T \\ &= [X*W_Q,QR] * \begin{pmatrix} {(C*W_K)}^T\\ {KR}^T\end{pmatrix} \\ &=[X*W_Q,QR] * \begin{pmatrix} W_K^T*C^T\\ {KR}^T\end{pmatrix} \\ &=X*(W_Q*W_K^T)*C^T+QR*{KR}^T \\ &=X*(W_Q')*C^T+QR*{KR}^T \\ &=Q'C^T+QR*{KR}^T \end{aligned} \end{equation} QKT=[XWQ,QR][CWK,KR]T=[XWQ,QR]((CWK)TKRT)=[XWQ,QR](WKTCTKRT)=X(WQWKT)CT+QRKRT=X(WQ)CT+QRKRT=QCT+QRKRT由公式(14)可以看出,在 Q 和 K Q和K QK新增 d r d_r dr维度计算RoPE信息的情况下,Attention的计算公式还可以继续复用上面提到的C Cache压缩技巧,只需要把第一部分提到的结果与新增的两个子矩阵相乘的结果相加即可。不过我们还需要把KR缓存下来,加速后半部分在decoding阶段的计算速度。所以整体的MLA需要保留 N ∗ d c N * d_c Ndc的C Cache,还需要保留 N ∗ d r N * d_r Ndr的KR Cache,但是 d r d_r dr一般不大,在论文中是 64 = d k / 2 64=d_k/2 64=dk/2

正是在上述两个创新点的加持下,MLA在大幅压缩KV Cache的基础上还保证了非常好的推理效果。当然我们也必须要承认,MLA实际上是增加了decoding计算量的,但是幸运的是decoding是一个访存主导的模块,目前看来减少访存的大小和次数还是加速decoding阶段的关键。此外,虽然在上面的介绍中我们是按照X为矩阵情况下的推导,也即对prefilling阶段进行了推导,但是在causal mask机制的影响下,上述的结果对X为向量时也成立,大家感兴趣可以参考我之前写的博客《大模型推理—KV Cache》将X换成向量推导看看。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐