用具体参数矩阵例子说明多头注意力 MHA

下面用一个极小的多头注意力例子,完整说明 Multi-Head Attention,多头注意力 的计算过程。

我们设:

B = 1      # 只看一条序列
T = 3      # 序列里有 3 个 token
d = 4      # 每个 token 是 4 维向量
H = 2      # 分成 2 个 head
d_h = 2    # 每个 head 是 2 维

因为:

d=H⋅dh=2×2=4 d = H \cdot d_h = 2 \times 2 = 4 d=Hdh=2×2=4


1. 输入矩阵 X

忽略 batch 维度,输入是:

X∈R3×4 X \in \mathbb{R}^{3 \times 4} XR3×4

设:

X=[101001011100] X= \begin{bmatrix} 1 & 0 & 1 & 0\\ 0 & 1 & 0 & 1\\ 1 & 1 & 0 & 0 \end{bmatrix} X= 101011100010

每一行是一个 token:

x1=[1,0,1,0] x_1 = [1,0,1,0] x1=[1,0,1,0]

x2=[0,1,0,1] x_2 = [0,1,0,1] x2=[0,1,0,1]

x3=[1,1,0,0] x_3 = [1,1,0,0] x3=[1,1,0,0]


2. 简化设定:令 W_Q, W_K, W_V 都是单位矩阵

为了专注看 attention 的计算,先令:

WQ=WK=WV=I4 W_Q=W_K=W_V=I_4 WQ=WK=WV=I4

所以:

Q=XWQ=X Q=XW_Q=X Q=XWQ=X

K=XWK=X K=XW_K=X K=XWK=X

V=XWV=X V=XW_V=X V=XWV=X

因此:

Q=K=V=[101001011100] Q=K=V= \begin{bmatrix} 1 & 0 & 1 & 0\\ 0 & 1 & 0 & 1\\ 1 & 1 & 0 & 0 \end{bmatrix} Q=K=V= 101011100010

维度都是:

Q,K,V∈R3×4 Q,K,V \in \mathbb{R}^{3 \times 4} Q,K,VR3×4


3. 拆成 2 个 head

因为:

d=4,H=2,dh=2 d=4,\quad H=2,\quad d_h=2 d=4,H=2,dh=2

所以每个 token 的 4 维向量拆成两段:

前 2 维 -> head 1
后 2 维 -> head 2

3.1 Head 1

取前两维:

Q(1)=K(1)=V(1)=[100111] Q^{(1)}=K^{(1)}=V^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 1 & 1 \end{bmatrix} Q(1)=K(1)=V(1)= 101011

维度:

Q(1),K(1),V(1)∈R3×2 Q^{(1)},K^{(1)},V^{(1)} \in \mathbb{R}^{3 \times 2} Q(1),K(1),V(1)R3×2

注意这里:

3 行 = 3 个 token
2 列 = 每个 head 的维度 d_h

也就是:

v1(1)=[1,0] v_1^{(1)}=[1,0] v1(1)=[1,0]

v2(1)=[0,1] v_2^{(1)}=[0,1] v2(1)=[0,1]

v3(1)=[1,1] v_3^{(1)}=[1,1] v3(1)=[1,1]


3.2 Head 2

取后两维:

Q(2)=K(2)=V(2)=[100100] Q^{(2)}=K^{(2)}=V^{(2)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0 \end{bmatrix} Q(2)=K(2)=V(2)= 100010

维度:

Q(2),K(2),V(2)∈R3×2 Q^{(2)},K^{(2)},V^{(2)} \in \mathbb{R}^{3 \times 2} Q(2),K(2),V(2)R3×2


4. 计算 Head 1

多头注意力中每个 head 的公式是:

O(h)=softmax⁡(Q(h)(K(h))Tdh)V(h) O^{(h)}= \operatorname{softmax} \left( \frac{Q^{(h)}(K^{(h)})^T}{\sqrt{d_h}} \right) V^{(h)} O(h)=softmax(dh Q(h)(K(h))T)V(h)

对 head 1:

O(1)=softmax⁡(Q(1)(K(1))T2)V(1) O^{(1)}= \operatorname{softmax} \left( \frac{Q^{(1)}(K^{(1)})^T}{\sqrt{2}} \right) V^{(1)} O(1)=softmax(2 Q(1)(K(1))T)V(1)


4.1 计算 Q(1)(K(1))^T

Q(1)=[100111] Q^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 1 & 1 \end{bmatrix} Q(1)= 101011

(K(1))T=[101011] (K^{(1)})^T= \begin{bmatrix} 1 & 0 & 1\\ 0 & 1 & 1 \end{bmatrix} (K(1))T=[100111]

所以:

Q(1)(K(1))T=[101011112] Q^{(1)}(K^{(1)})^T= \begin{bmatrix} 1 & 0 & 1\\ 0 & 1 & 1\\ 1 & 1 & 2 \end{bmatrix} Q(1)(K(1))T= 101011112

维度:

[3,2]×[2,3]=[3,3] [3,2]\times[2,3]=[3,3] [3,2]×[2,3]=[3,3]

这个 3×33\times33×3 矩阵表示:

每个 token 对每个 token 的相似度

例如第一行:

[1,0,1] [1,0,1] [1,0,1]

表示:

token 1 和 token 1 相似度 = 1
token 1 和 token 2 相似度 = 0
token 1 和 token 3 相似度 = 1

4.2 缩放

因为:

dh=2 d_h=2 dh=2

所以除以:

2 \sqrt{2} 2

得到:

S(1)=Q(1)(K(1))T2=[0.707100.707100.70710.70710.70710.70711.4142] S^{(1)}= \frac{Q^{(1)}(K^{(1)})^T}{\sqrt{2}}= \begin{bmatrix} 0.7071 & 0 & 0.7071\\ 0 & 0.7071 & 0.7071\\ 0.7071 & 0.7071 & 1.4142 \end{bmatrix} S(1)=2 Q(1)(K(1))T= 0.707100.707100.70710.70710.70710.70711.4142


4.3 对每一行做 softmax

A(1)=softmax⁡(S(1)) A^{(1)}= \operatorname{softmax}(S^{(1)}) A(1)=softmax(S(1))

得到:

A(1)≈[0.40110.19780.40110.19780.40110.40110.24830.24830.5035] A^{(1)} \approx \begin{bmatrix} 0.4011 & 0.1978 & 0.4011\\ 0.1978 & 0.4011 & 0.4011\\ 0.2483 & 0.2483 & 0.5035 \end{bmatrix} A(1) 0.40110.19780.24830.19780.40110.24830.40110.40110.5035

维度:

A(1)∈R3×3 A^{(1)} \in \mathbb{R}^{3 \times 3} A(1)R3×3

每一行加起来等于 1。

第一行表示:

token 1 在 head 1 中:
关注 token 1 的权重 = 0.4011
关注 token 2 的权重 = 0.1978
关注 token 3 的权重 = 0.4011

第二行表示:

token 2 在 head 1 中:
关注 token 1 的权重 = 0.1978
关注 token 2 的权重 = 0.4011
关注 token 3 的权重 = 0.4011

第三行表示:

token 3 在 head 1 中:
关注 token 1 的权重 = 0.2483
关注 token 2 的权重 = 0.2483
关注 token 3 的权重 = 0.5035

4.4 乘以 V^(1)

现在计算:

O(1)=A(1)V(1) O^{(1)}=A^{(1)}V^{(1)} O(1)=A(1)V(1)

其中:

A(1)=[0.40110.19780.40110.19780.40110.40110.24830.24830.5035] A^{(1)}= \begin{bmatrix} 0.4011 & 0.1978 & 0.4011\\ 0.1978 & 0.4011 & 0.4011\\ 0.2483 & 0.2483 & 0.5035 \end{bmatrix} A(1)= 0.40110.19780.24830.19780.40110.24830.40110.40110.5035

V(1)=[100111] V^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 1 & 1 \end{bmatrix} V(1)= 101011

维度:

[3,3]×[3,2]=[3,2] [3,3]\times[3,2]=[3,2] [3,3]×[3,2]=[3,2]

所以可以乘。


第一行输出

$$
o_1^{(1)}

0.4011[1,0]
+
0.1978[0,1]
+
0.4011[1,1]
$$

第一维:

0.4011×1+0.1978×0+0.4011×1=0.8022 0.4011\times1+0.1978\times0+0.4011\times1=0.8022 0.4011×1+0.1978×0+0.4011×1=0.8022

第二维:

0.4011×0+0.1978×1+0.4011×1=0.5989 0.4011\times0+0.1978\times1+0.4011\times1=0.5989 0.4011×0+0.1978×1+0.4011×1=0.5989

所以:

o1(1)=[0.8022,0.5989] o_1^{(1)}=[0.8022,0.5989] o1(1)=[0.8022,0.5989]


第二行输出

o2(1)=0.1978[1,0]+0.4011[0,1]+0.4011[1,1] o_2^{(1)}= 0.1978[1,0] + 0.4011[0,1] + 0.4011[1,1] o2(1)=0.1978[1,0]+0.4011[0,1]+0.4011[1,1]

第一维:

0.1978+0+0.4011=0.5989 0.1978+0+0.4011=0.5989 0.1978+0+0.4011=0.5989

第二维:

0+0.4011+0.4011=0.8022 0+0.4011+0.4011=0.8022 0+0.4011+0.4011=0.8022

所以:

o2(1)=[0.5989,0.8022] o_2^{(1)}=[0.5989,0.8022] o2(1)=[0.5989,0.8022]


第三行输出

o3(1)=0.2483[1,0]+0.2483[0,1]+0.5035[1,1] o_3^{(1)}= 0.2483[1,0] + 0.2483[0,1] + 0.5035[1,1] o3(1)=0.2483[1,0]+0.2483[0,1]+0.5035[1,1]

第一维:

0.2483+0+0.5035=0.7518 0.2483+0+0.5035=0.7518 0.2483+0+0.5035=0.7518

第二维:

0+0.2483+0.5035=0.7518 0+0.2483+0.5035=0.7518 0+0.2483+0.5035=0.7518

所以:

o3(1)=[0.7518,0.7518] o_3^{(1)}=[0.7518,0.7518] o3(1)=[0.7518,0.7518]


因此:

O(1)≈[0.80220.59890.59890.80220.75180.7518] O^{(1)} \approx \begin{bmatrix} 0.8022 & 0.5989\\ 0.5989 & 0.8022\\ 0.7518 & 0.7518 \end{bmatrix} O(1) 0.80220.59890.75180.59890.80220.7518

维度:

O(1)∈R3×2 O^{(1)}\in\mathbb{R}^{3\times2} O(1)R3×2

含义:

3 个 token
每个 token 在 head 1 中得到 2 维新表示

5. 计算 Head 2

Head 2:

Q(2)=K(2)=V(2)=[100100] Q^{(2)}=K^{(2)}=V^{(2)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0 \end{bmatrix} Q(2)=K(2)=V(2)= 100010


5.1 计算相似度

Q(2)(K(2))T=[100010000] Q^{(2)}(K^{(2)})^T= \begin{bmatrix} 1 & 0 & 0\\ 0 & 1 & 0\\ 0 & 0 & 0 \end{bmatrix} Q(2)(K(2))T= 100010000

除以 2\sqrt{2}2

S(2)=[0.70710000.70710000] S^{(2)}= \begin{bmatrix} 0.7071 & 0 & 0\\ 0 & 0.7071 & 0\\ 0 & 0 & 0 \end{bmatrix} S(2)= 0.70710000.70710000


5.2 softmax

$$
A^{(2)}

\operatorname{softmax}(S^{(2)})
$$

得到:

A(2)≈[0.50350.24830.24830.24830.50350.24830.33330.33330.3333] A^{(2)} \approx \begin{bmatrix} 0.5035 & 0.2483 & 0.2483\\ 0.2483 & 0.5035 & 0.2483\\ 0.3333 & 0.3333 & 0.3333 \end{bmatrix} A(2) 0.50350.24830.33330.24830.50350.33330.24830.24830.3333

第三行是:

[0,0,0] [0,0,0] [0,0,0]

softmax 后就是平均:

[1/3,1/3,1/3] [1/3,1/3,1/3] [1/3,1/3,1/3]


5.3 乘以 V^(2)

O(2)=A(2)V(2) O^{(2)}=A^{(2)}V^{(2)} O(2)=A(2)V(2)

其中:

V(2)=[100100] V^{(2)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0 \end{bmatrix} V(2)= 100010

计算得到:

O(2)≈[0.50350.24830.24830.50350.33330.3333] O^{(2)} \approx \begin{bmatrix} 0.5035 & 0.2483\\ 0.2483 & 0.5035\\ 0.3333 & 0.3333 \end{bmatrix} O(2) 0.50350.24830.33330.24830.50350.3333

维度:

O(2)∈R3×2 O^{(2)}\in\mathbb{R}^{3\times2} O(2)R3×2


6. 拼接两个 head

现在两个 head 的输出分别是:

O(1)≈[0.80220.59890.59890.80220.75180.7518] O^{(1)} \approx \begin{bmatrix} 0.8022 & 0.5989\\ 0.5989 & 0.8022\\ 0.7518 & 0.7518 \end{bmatrix} O(1) 0.80220.59890.75180.59890.80220.7518

O(2)≈[0.50350.24830.24830.50350.33330.3333] O^{(2)} \approx \begin{bmatrix} 0.5035 & 0.2483\\ 0.2483 & 0.5035\\ 0.3333 & 0.3333 \end{bmatrix} O(2) 0.50350.24830.33330.24830.50350.3333

把它们在最后一维拼接:

$$
O_{\text{concat}}

\operatorname{Concat}(O{(1)},O{(2)})
$$

得到:

Oconcat≈[0.80220.59890.50350.24830.59890.80220.24830.50350.75180.75180.33330.3333] O_{\text{concat}} \approx \begin{bmatrix} 0.8022 & 0.5989 & 0.5035 & 0.2483\\ 0.5989 & 0.8022 & 0.2483 & 0.5035\\ 0.7518 & 0.7518 & 0.3333 & 0.3333 \end{bmatrix} Oconcat 0.80220.59890.75180.59890.80220.75180.50350.24830.33330.24830.50350.3333

维度:

Oconcat∈R3×4 O_{\text{concat}}\in\mathbb{R}^{3\times4} OconcatR3×4

也就是:

3 个 token
每个 token 重新变回 4 维

7. 输出投影 W_O

真实 MHA 最后还有一个输出矩阵:

WO∈R4×4 W_O \in \mathbb{R}^{4\times4} WOR4×4

最终输出:

Y=OconcatWO Y=O_{\text{concat}}W_O Y=OconcatWO

这里为了简化,设:

WO=I4 W_O=I_4 WO=I4

所以:

Y=Oconcat Y=O_{\text{concat}} Y=Oconcat

即:

Y≈[0.80220.59890.50350.24830.59890.80220.24830.50350.75180.75180.33330.3333] Y \approx \begin{bmatrix} 0.8022 & 0.5989 & 0.5035 & 0.2483\\ 0.5989 & 0.8022 & 0.2483 & 0.5035\\ 0.7518 & 0.7518 & 0.3333 & 0.3333 \end{bmatrix} Y 0.80220.59890.75180.59890.80220.75180.50350.24830.33330.24830.50350.3333


8. 维度总表

步骤 矩阵 维度
输入 XXX [3,4][3,4][3,4]
Q/K/V Q,K,VQ,K,VQ,K,V [3,4][3,4][3,4]
head 1 Q(1),K(1),V(1)Q^{(1)},K^{(1)},V^{(1)}Q(1),K(1),V(1) [3,2][3,2][3,2]
head 2 Q(2),K(2),V(2)Q^{(2)},K^{(2)},V^{(2)}Q(2),K(2),V(2) [3,2][3,2][3,2]
head 1 score Q(1)(K(1))TQ^{(1)}(K^{(1)})^TQ(1)(K(1))T [3,3][3,3][3,3]
head 1 attention A(1)A^{(1)}A(1) [3,3][3,3][3,3]
head 1 output O(1)=A(1)V(1)O^{(1)}=A^{(1)}V^{(1)}O(1)=A(1)V(1) [3,2][3,2][3,2]
head 2 output O(2)=A(2)V(2)O^{(2)}=A^{(2)}V^{(2)}O(2)=A(2)V(2) [3,2][3,2][3,2]
拼接 OconcatO_{\text{concat}}Oconcat [3,4][3,4][3,4]
输出 YYY [3,4][3,4][3,4]

9. 最关键的理解

每个 head 里面:

A(h)∈RT×T A^{(h)} \in \mathbb{R}^{T\times T} A(h)RT×T

表示:

每个 token 对所有 token 的注意力权重

而:

V(h)∈RT×dh V^{(h)} \in \mathbb{R}^{T\times d_h} V(h)RT×dh

表示:

每个 token 对外提供的 value 信息

所以:

O(h)=A(h)V(h) O^{(h)}=A^{(h)}V^{(h)} O(h)=A(h)V(h)

就是:

[T,T]×[T,dh]=[T,dh] [T,T]\times[T,d_h]=[T,d_h] [T,T]×[T,dh]=[T,dh]

含义是:

每个 token 根据注意力权重,从所有 token 的 value 中加权汇聚信息。

也就是:

oi(h)=∑j=1TAij(h)vj(h) o_i^{(h)}= \sum_{j=1}^{T} A_{ij}^{(h)}v_j^{(h)} oi(h)=j=1TAij(h)vj(h)

一句话:

Attention 的本质是:每个 token 用自己的 Query 找到所有 Key 的相关性,再用这个相关性加权所有 Value。多头就是在多个子空间里重复这个过程,最后拼接。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐