
大模型推理--MLA
本来想写一篇KV Cache压缩的综述性博客,结果写到MLA部分的时候发现越写越多,完全值得单独拿出来写篇博客,遂从KV Cache压缩博客中单独揪出MLA进行介绍。MLA(Multi-query Latent Attention)是国内创业公司deepseek在24年5月份发布的大模型中用到的KV Cache压缩技术,正是在该技术的加持下DeepSeek-V2可以大幅压缩KV Cache的大小,
本来想写一篇KV Cache压缩的综述性博客,结果写到MLA部分的时候发现越写越多,完全值得单独拿出来写篇博客,遂从KV Cache压缩博客中单独揪出MLA进行介绍。
MLA(Multi-query Latent Attention)是国内创业公司deepseek在24年5月份发布的DeepSeek-V2大模型中用到的KV Cache压缩技术,正是在该技术的加持下DeepSeek-V2可以大幅压缩KV Cache的大小,进而大幅提升吞吐量,也正是从该模型开始,大模型推理的价格一下降低到一个很低的水平。MLA是少有的由国内公司做出的硬核创新,感谢deepseek,感谢MLA!我觉得在出现新的KV Cache压缩技术之前后续的大模型可能都会采用MLA,它的压缩效果接近MQA,但是生成效果却还比MHA更好,值得大家跟进。
MLA相比MQA和MHA相比做到了既要又要,着实牛逼,代价就是太难懂了。花了一天的时间仔细研究了一下苏剑林苏神的博客《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA》,对MLA有了一个大概的理解,下面我按照自己的思路尝试解释一下MLA,在解释的过中我抛弃了MLA论文中复杂的符号定义,重新按照自己的理解去定义相关矩阵。
常规的Attention计算会用到
Q
、
K
、
V
Q、K、V
Q、K、V,这也是我们需要保存KV的原因,而MLA则不保留KV Cache,另外引入了一个C Cache来代替KV Cache。在执行Attention计算的时候,通过一系列等价变换,将公式中出现的
K
K
K和
V
V
V均用
C
C
C来代替。不仅可以用
C
C
C来代替
K
V
KV
KV,而且
Q
Q
Q所有的head都共享同一个C Cache,从这个层面来说MLA和MQA很类似,只需要保留一个head的缓存即可获得非常好的结果。另外一点,现在的大模型一般会在计算Attention之前将
Q
Q
Q和
K
K
K进行RoPE(旋转位置编码),如图1所示。这就会导致单纯的
C
C
C丢失了位置信息,为了弥补这个缺陷,MLA中给Q和K额外增加了
d
r
d_r
dr个维度用来添加RoPE,其中
K
K
K新增的维度也是每个Head共享。大家看到这里可能比较懵逼,且听我一一道来。
1. KV Cache变C Cache
在MHA的单head计算过程中,输入
X
X
X(维度为
N
∗
d
N * d
N∗d)会先送入三个projection矩阵
W
Q
,
W
K
,
W
V
W_Q,W_K,W_V
WQ,WK,WV进行线性变换:
Q
=
X
∗
W
Q
\begin{equation} Q=X*W_Q \end{equation}
Q=X∗WQ
K
=
X
∗
W
K
\begin{equation} K=X*W_K \end{equation}
K=X∗WK
V
=
X
∗
W
V
\begin{equation} V=X*W_V \end{equation}
V=X∗WV这三个矩阵的维度一般相等,通常都是
d
∗
d
k
d * d_k
d∗dk,所以输出维度还是
N
∗
d
k
N * d_k
N∗dk。变换完成之后生成的
K
、
V
K、V
K、V即是传统意义上的KV Cache。Multi-head会将h个head的
Q
、
K
、
V
Q、K、V
Q、K、V矩阵分别拼接在一起,构成
d
∗
d
d * d
d∗d的矩阵,其中
d
=
h
∗
d
k
d = h * d_k
d=h∗dk。前面我们提到MLA没有保留KV Cache,而是引入了一个C Cache,也即增加了下面的公式:
C
=
X
∗
W
C
\begin{equation} C=X*W_C \end{equation}
C=X∗WC
W
C
W_C
WC的维度不再是
d
∗
d
k
d * d_k
d∗dk,而是变成了
d
∗
d
c
d * d_c
d∗dc,
d
c
d_c
dc通常远小于
d
d
d,但是会比
d
k
d_k
dk大(DeepSeek-V2中
d
k
d_k
dk是128,
d
c
d_c
dc是512)。有了
C
C
C之后,为了保证Attention计算的正确性,我们可以从概念上再引入K和V,也即得到如下两个关于KV的新公式:
K
=
C
∗
W
K
\begin{equation} K=C*W_K \end{equation}
K=C∗WK
V
=
C
∗
W
V
\begin{equation} V=C*W_V \end{equation}
V=C∗WV此时
W
K
W_K
WK和
W
V
W_V
WV的维度就变成了
d
c
∗
d
k
d_c * d_k
dc∗dk,最终输出的K和V维度依旧是
N
∗
d
k
N * d_k
N∗dk。但是请大家切记,MLA最终的公式中没有出现KV,我们现在处于将MHA转换到MLA的过程中,引入
K
和
V
K和V
K和V只是方便进行公式推导。MHA进行Attention计算时会将
Q
和
K
T
Q和K^T
Q和KT相乘,得到如下公式:
Q
K
T
=
X
∗
W
Q
∗
(
C
∗
W
K
)
T
=
X
∗
(
W
Q
∗
W
K
T
)
∗
C
T
=
Q
′
C
T
\begin{equation} QK^T=X*W_Q *(C*W_K)^T=X* (W_Q*W_K^T)*C^T=Q'C^T \end{equation}
QKT=X∗WQ∗(C∗WK)T=X∗(WQ∗WKT)∗CT=Q′CT从公式(7)可以看出,我们可以将
W
Q
∗
W
K
T
W_Q*W_K^T
WQ∗WKT当做Q新的projection矩阵
W
Q
′
W_Q'
WQ′(维度为
d
∗
d
c
d * d_c
d∗dc),此时Attention计算就和K没什么关系了。类似的,Attention计算中的V也可以利用公式(6)进行消除,从而使得整个Attention计算只和
Q
′
、
C
Q'、C
Q′、C相关。MLA更进一步,Attention计算完成得到
O
O
O之后,还会将其与projection矩阵
W
O
W_O
WO进行相乘,
W
V
W_V
WV也可以与
W
O
W_O
WO融合在一起得到新的projection矩阵
W
O
′
W_O'
WO′。最终我们可以得到MLA Attention部分的计算公式:
Q
′
=
X
∗
W
Q
′
\begin{equation} Q'=X*W_Q' \end{equation}
Q′=X∗WQ′
C
=
X
∗
W
C
\begin{equation} C=X*W_C \end{equation}
C=X∗WC
O
=
A
t
t
e
n
t
i
o
n
(
Q
′
,
C
)
=
s
o
f
t
m
a
x
(
Q
′
C
T
d
k
)
C
\begin{equation} O=Attention(Q',C)=softmax(\frac{Q'C^T}{\sqrt{d_k}}) C \end{equation}
O=Attention(Q′,C)=softmax(dkQ′CT)C
O
′
=
O
∗
W
O
′
\begin{equation} O'=O*W_O' \end{equation}
O′=O∗WO′其中,
X
X
X维度为
N
∗
d
N * d
N∗d,
W
Q
′
W_Q'
WQ′和
W
C
W_C
WC维度为
d
∗
d
c
d * d_c
d∗dc,
Q
′
、
C
和
O
Q'、C和O
Q′、C和O的维度为
N
∗
d
c
N * d_c
N∗dc,
W
O
′
W_O'
WO′的维度为
d
c
∗
d
k
d_c * d_k
dc∗dk,最终输出
O
′
O'
O′的维度为
N
∗
d
k
N * d_k
N∗dk。当单个head变成h个head之后,
W
Q
′
W_Q'
WQ′和
W
O
′
W_O'
WO′会多一个h维,但是
W
C
W_C
WC则始终只保留一个head的参数量。公式(8)~(11)中几个变量使用了上标’用来表示MHA到MLA的转换,如果我们放弃MHA直接考虑MLA,则上标都可以直接去掉。
MLA相比原始的MHA简化了计算公式,压缩了缓存大小,难道天下真有免费的午餐吗?我们仔细对比一下MHA、MQA、GQA和MLA在推理过程中的模型单层参数量和计算量,如下表。我们假定模型的hidden size
d
d
d为8k,head数为64,则得到
d
k
d_k
dk为128,GQA的group数为8,MLA中的
d
c
d_c
dc为512,模型支持的最长上下文N为128k。我们在统计计算量的时候只考虑了计算的大头,较小计算量的部分丢弃掉也不会产生大的影响。
可以看出,MLA在参数量和计算量上都比另外三种Attention计算方法要大,我理解MLA效果要好于MHA的原因也是因为计算量的增大。但是MLA的好处就是大大降低了memory-bound为主的decoding阶段的缓存大小,从而使得解码速度变快。不过上面表格中的MLA没有考虑下面要介绍的部分,但是增加的参数量和计算量不影响大的结论。
2. Q和K增加RoPE信息
如果MLA到这里就结束那就非常完美了,它在没有大幅增加计算量和参数量的情况下大大提升了解码的推理速度。但是事情总是不遂人意,MLA的上述机制在当下的大模型下存在不可避免的缺陷:当下的大模型一般会在Attention计算之前将
Q
和
K
Q和K
Q和K添加RoPE,这就导致上面的公式(7)不再成立。但是我们又不能丢弃RoPE,为了弥补这个缺陷,MLA又在
Q
和
K
Q和K
Q和K中额外增加了一些维度专门用来存放RoPE信息。在本博客中我们着重介绍MLA是如何做的,原理性的介绍可以参看上面提到的苏神博客。MLA为了在
Q
和
K
Q和K
Q和K中添加RoPE信息,分别将其增加了额外的
d
r
d_r
dr维度。公式(8)变为:
Q
=
[
X
∗
W
Q
,
X
∗
W
Q
R
∗
R
]
=
[
X
∗
W
Q
,
Q
R
]
\begin{equation} Q=[X*W_Q,X*W_{QR}*R]=[X*W_Q,QR] \end{equation}
Q=[X∗WQ,X∗WQR∗R]=[X∗WQ,QR]其中
W
Q
R
W_{QR}
WQR是新增的一个矩阵,维度为
d
∗
d
r
d * d_r
d∗dr,
R
R
R是RoPE矩阵,维度为
d
r
∗
d
r
d_r * d_r
dr∗dr,则生成的
Q
Q
Q维度变为
N
∗
(
d
c
+
d
r
)
N * (d_c + d_r)
N∗(dc+dr),和
W
Q
W_Q
WQ一样,h个head会使矩阵
W
Q
R
W_{QR}
WQR额外增加一个h维。我们用
Q
R
QR
QR来代表
Q
Q
Q新增的子矩阵,维度为
N
∗
d
r
N *d_r
N∗dr。类似的,公式(5)也可以扩展为:
K
=
[
C
∗
W
K
,
X
∗
W
K
R
∗
R
]
=
[
C
∗
W
K
,
K
R
]
\begin{equation} K=[C*W_K,X*W_{KR}*R]=[C*W_K,K R] \end{equation}
K=[C∗WK,X∗WKR∗R]=[C∗WK,KR]
W
K
R
W_{KR}
WKR也是新增的矩阵,维度也是
d
∗
d
r
d * d_r
d∗dr。用
K
R
KR
KR来表示
K
K
K新增的子矩阵,它保存了RoPE信息,维度也为
N
∗
d
r
N * d_r
N∗dr。有了(12)、(13)之后,我们也扩展一下公式(7):
Q
K
T
=
[
X
∗
W
Q
,
Q
R
]
∗
[
C
∗
W
K
,
K
R
]
T
=
[
X
∗
W
Q
,
Q
R
]
∗
(
(
C
∗
W
K
)
T
K
R
T
)
=
[
X
∗
W
Q
,
Q
R
]
∗
(
W
K
T
∗
C
T
K
R
T
)
=
X
∗
(
W
Q
∗
W
K
T
)
∗
C
T
+
Q
R
∗
K
R
T
=
X
∗
(
W
Q
′
)
∗
C
T
+
Q
R
∗
K
R
T
=
Q
′
C
T
+
Q
R
∗
K
R
T
\begin{equation} \begin{aligned} &QK^T =[X*W_Q,QR] *[C*W_K,K R] ^T \\ &= [X*W_Q,QR] * \begin{pmatrix} {(C*W_K)}^T\\ {KR}^T\end{pmatrix} \\ &=[X*W_Q,QR] * \begin{pmatrix} W_K^T*C^T\\ {KR}^T\end{pmatrix} \\ &=X*(W_Q*W_K^T)*C^T+QR*{KR}^T \\ &=X*(W_Q')*C^T+QR*{KR}^T \\ &=Q'C^T+QR*{KR}^T \end{aligned} \end{equation}
QKT=[X∗WQ,QR]∗[C∗WK,KR]T=[X∗WQ,QR]∗((C∗WK)TKRT)=[X∗WQ,QR]∗(WKT∗CTKRT)=X∗(WQ∗WKT)∗CT+QR∗KRT=X∗(WQ′)∗CT+QR∗KRT=Q′CT+QR∗KRT由公式(14)可以看出,在
Q
和
K
Q和K
Q和K新增
d
r
d_r
dr维度计算RoPE信息的情况下,Attention的计算公式还可以继续复用上面提到的C Cache压缩技巧,只需要把第一部分提到的结果与新增的两个子矩阵相乘的结果相加即可。不过我们还需要把KR缓存下来,加速后半部分在decoding阶段的计算速度。所以整体的MLA需要保留
N
∗
d
c
N * d_c
N∗dc的C Cache,还需要保留
N
∗
d
r
N * d_r
N∗dr的KR Cache,但是
d
r
d_r
dr一般不大,在论文中是
64
=
d
k
/
2
64=d_k/2
64=dk/2。
正是在上述两个创新点的加持下,MLA在大幅压缩KV Cache的基础上还保证了非常好的推理效果。当然我们也必须要承认,MLA实际上是增加了decoding计算量的,但是幸运的是decoding是一个访存主导的模块,目前看来减少访存的大小和次数还是加速decoding阶段的关键。此外,虽然在上面的介绍中我们是按照X为矩阵情况下的推导,也即对prefilling阶段进行了推导,但是在causal mask机制的影响下,上述的结果对X为向量时也成立,大家感兴趣可以参考我之前写的博客《大模型推理—KV Cache》将X换成向量推导看看。
更多推荐
所有评论(0)