multi_head_attention_matrix_example 多头注意力 算例子
用具体参数矩阵例子说明多头注意力 MHA
下面用一个极小的多头注意力例子,完整说明 Multi-Head Attention,多头注意力 的计算过程。
我们设:
B = 1 # 只看一条序列
T = 3 # 序列里有 3 个 token
d = 4 # 每个 token 是 4 维向量
H = 2 # 分成 2 个 head
d_h = 2 # 每个 head 是 2 维
因为:
d=H⋅dh=2×2=4 d = H \cdot d_h = 2 \times 2 = 4 d=H⋅dh=2×2=4
1. 输入矩阵 X
忽略 batch 维度,输入是:
X∈R3×4 X \in \mathbb{R}^{3 \times 4} X∈R3×4
设:
X=[101001011100] X= \begin{bmatrix} 1 & 0 & 1 & 0\\ 0 & 1 & 0 & 1\\ 1 & 1 & 0 & 0 \end{bmatrix} X= 101011100010
每一行是一个 token:
x1=[1,0,1,0] x_1 = [1,0,1,0] x1=[1,0,1,0]
x2=[0,1,0,1] x_2 = [0,1,0,1] x2=[0,1,0,1]
x3=[1,1,0,0] x_3 = [1,1,0,0] x3=[1,1,0,0]
2. 简化设定:令 W_Q, W_K, W_V 都是单位矩阵
为了专注看 attention 的计算,先令:
WQ=WK=WV=I4 W_Q=W_K=W_V=I_4 WQ=WK=WV=I4
所以:
Q=XWQ=X Q=XW_Q=X Q=XWQ=X
K=XWK=X K=XW_K=X K=XWK=X
V=XWV=X V=XW_V=X V=XWV=X
因此:
Q=K=V=[101001011100] Q=K=V= \begin{bmatrix} 1 & 0 & 1 & 0\\ 0 & 1 & 0 & 1\\ 1 & 1 & 0 & 0 \end{bmatrix} Q=K=V= 101011100010
维度都是:
Q,K,V∈R3×4 Q,K,V \in \mathbb{R}^{3 \times 4} Q,K,V∈R3×4
3. 拆成 2 个 head
因为:
d=4,H=2,dh=2 d=4,\quad H=2,\quad d_h=2 d=4,H=2,dh=2
所以每个 token 的 4 维向量拆成两段:
前 2 维 -> head 1
后 2 维 -> head 2
3.1 Head 1
取前两维:
Q(1)=K(1)=V(1)=[100111] Q^{(1)}=K^{(1)}=V^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 1 & 1 \end{bmatrix} Q(1)=K(1)=V(1)= 101011
维度:
Q(1),K(1),V(1)∈R3×2 Q^{(1)},K^{(1)},V^{(1)} \in \mathbb{R}^{3 \times 2} Q(1),K(1),V(1)∈R3×2
注意这里:
3 行 = 3 个 token
2 列 = 每个 head 的维度 d_h
也就是:
v1(1)=[1,0] v_1^{(1)}=[1,0] v1(1)=[1,0]
v2(1)=[0,1] v_2^{(1)}=[0,1] v2(1)=[0,1]
v3(1)=[1,1] v_3^{(1)}=[1,1] v3(1)=[1,1]
3.2 Head 2
取后两维:
Q(2)=K(2)=V(2)=[100100] Q^{(2)}=K^{(2)}=V^{(2)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0 \end{bmatrix} Q(2)=K(2)=V(2)= 100010
维度:
Q(2),K(2),V(2)∈R3×2 Q^{(2)},K^{(2)},V^{(2)} \in \mathbb{R}^{3 \times 2} Q(2),K(2),V(2)∈R3×2
4. 计算 Head 1
多头注意力中每个 head 的公式是:
O(h)=softmax(Q(h)(K(h))Tdh)V(h) O^{(h)}= \operatorname{softmax} \left( \frac{Q^{(h)}(K^{(h)})^T}{\sqrt{d_h}} \right) V^{(h)} O(h)=softmax(dhQ(h)(K(h))T)V(h)
对 head 1:
O(1)=softmax(Q(1)(K(1))T2)V(1) O^{(1)}= \operatorname{softmax} \left( \frac{Q^{(1)}(K^{(1)})^T}{\sqrt{2}} \right) V^{(1)} O(1)=softmax(2Q(1)(K(1))T)V(1)
4.1 计算 Q(1)(K(1))^T
Q(1)=[100111] Q^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 1 & 1 \end{bmatrix} Q(1)= 101011
(K(1))T=[101011] (K^{(1)})^T= \begin{bmatrix} 1 & 0 & 1\\ 0 & 1 & 1 \end{bmatrix} (K(1))T=[100111]
所以:
Q(1)(K(1))T=[101011112] Q^{(1)}(K^{(1)})^T= \begin{bmatrix} 1 & 0 & 1\\ 0 & 1 & 1\\ 1 & 1 & 2 \end{bmatrix} Q(1)(K(1))T= 101011112
维度:
[3,2]×[2,3]=[3,3] [3,2]\times[2,3]=[3,3] [3,2]×[2,3]=[3,3]
这个 3×33\times33×3 矩阵表示:
每个 token 对每个 token 的相似度
例如第一行:
[1,0,1] [1,0,1] [1,0,1]
表示:
token 1 和 token 1 相似度 = 1
token 1 和 token 2 相似度 = 0
token 1 和 token 3 相似度 = 1
4.2 缩放
因为:
dh=2 d_h=2 dh=2
所以除以:
2 \sqrt{2} 2
得到:
S(1)=Q(1)(K(1))T2=[0.707100.707100.70710.70710.70710.70711.4142] S^{(1)}= \frac{Q^{(1)}(K^{(1)})^T}{\sqrt{2}}= \begin{bmatrix} 0.7071 & 0 & 0.7071\\ 0 & 0.7071 & 0.7071\\ 0.7071 & 0.7071 & 1.4142 \end{bmatrix} S(1)=2Q(1)(K(1))T= 0.707100.707100.70710.70710.70710.70711.4142
4.3 对每一行做 softmax
A(1)=softmax(S(1)) A^{(1)}= \operatorname{softmax}(S^{(1)}) A(1)=softmax(S(1))
得到:
A(1)≈[0.40110.19780.40110.19780.40110.40110.24830.24830.5035] A^{(1)} \approx \begin{bmatrix} 0.4011 & 0.1978 & 0.4011\\ 0.1978 & 0.4011 & 0.4011\\ 0.2483 & 0.2483 & 0.5035 \end{bmatrix} A(1)≈ 0.40110.19780.24830.19780.40110.24830.40110.40110.5035
维度:
A(1)∈R3×3 A^{(1)} \in \mathbb{R}^{3 \times 3} A(1)∈R3×3
每一行加起来等于 1。
第一行表示:
token 1 在 head 1 中:
关注 token 1 的权重 = 0.4011
关注 token 2 的权重 = 0.1978
关注 token 3 的权重 = 0.4011
第二行表示:
token 2 在 head 1 中:
关注 token 1 的权重 = 0.1978
关注 token 2 的权重 = 0.4011
关注 token 3 的权重 = 0.4011
第三行表示:
token 3 在 head 1 中:
关注 token 1 的权重 = 0.2483
关注 token 2 的权重 = 0.2483
关注 token 3 的权重 = 0.5035
4.4 乘以 V^(1)
现在计算:
O(1)=A(1)V(1) O^{(1)}=A^{(1)}V^{(1)} O(1)=A(1)V(1)
其中:
A(1)=[0.40110.19780.40110.19780.40110.40110.24830.24830.5035] A^{(1)}= \begin{bmatrix} 0.4011 & 0.1978 & 0.4011\\ 0.1978 & 0.4011 & 0.4011\\ 0.2483 & 0.2483 & 0.5035 \end{bmatrix} A(1)= 0.40110.19780.24830.19780.40110.24830.40110.40110.5035
V(1)=[100111] V^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 1 & 1 \end{bmatrix} V(1)= 101011
维度:
[3,3]×[3,2]=[3,2] [3,3]\times[3,2]=[3,2] [3,3]×[3,2]=[3,2]
所以可以乘。
第一行输出
$$
o_1^{(1)}
0.4011[1,0]
+
0.1978[0,1]
+
0.4011[1,1]
$$
第一维:
0.4011×1+0.1978×0+0.4011×1=0.8022 0.4011\times1+0.1978\times0+0.4011\times1=0.8022 0.4011×1+0.1978×0+0.4011×1=0.8022
第二维:
0.4011×0+0.1978×1+0.4011×1=0.5989 0.4011\times0+0.1978\times1+0.4011\times1=0.5989 0.4011×0+0.1978×1+0.4011×1=0.5989
所以:
o1(1)=[0.8022,0.5989] o_1^{(1)}=[0.8022,0.5989] o1(1)=[0.8022,0.5989]
第二行输出
o2(1)=0.1978[1,0]+0.4011[0,1]+0.4011[1,1] o_2^{(1)}= 0.1978[1,0] + 0.4011[0,1] + 0.4011[1,1] o2(1)=0.1978[1,0]+0.4011[0,1]+0.4011[1,1]
第一维:
0.1978+0+0.4011=0.5989 0.1978+0+0.4011=0.5989 0.1978+0+0.4011=0.5989
第二维:
0+0.4011+0.4011=0.8022 0+0.4011+0.4011=0.8022 0+0.4011+0.4011=0.8022
所以:
o2(1)=[0.5989,0.8022] o_2^{(1)}=[0.5989,0.8022] o2(1)=[0.5989,0.8022]
第三行输出
o3(1)=0.2483[1,0]+0.2483[0,1]+0.5035[1,1] o_3^{(1)}= 0.2483[1,0] + 0.2483[0,1] + 0.5035[1,1] o3(1)=0.2483[1,0]+0.2483[0,1]+0.5035[1,1]
第一维:
0.2483+0+0.5035=0.7518 0.2483+0+0.5035=0.7518 0.2483+0+0.5035=0.7518
第二维:
0+0.2483+0.5035=0.7518 0+0.2483+0.5035=0.7518 0+0.2483+0.5035=0.7518
所以:
o3(1)=[0.7518,0.7518] o_3^{(1)}=[0.7518,0.7518] o3(1)=[0.7518,0.7518]
因此:
O(1)≈[0.80220.59890.59890.80220.75180.7518] O^{(1)} \approx \begin{bmatrix} 0.8022 & 0.5989\\ 0.5989 & 0.8022\\ 0.7518 & 0.7518 \end{bmatrix} O(1)≈ 0.80220.59890.75180.59890.80220.7518
维度:
O(1)∈R3×2 O^{(1)}\in\mathbb{R}^{3\times2} O(1)∈R3×2
含义:
3 个 token
每个 token 在 head 1 中得到 2 维新表示
5. 计算 Head 2
Head 2:
Q(2)=K(2)=V(2)=[100100] Q^{(2)}=K^{(2)}=V^{(2)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0 \end{bmatrix} Q(2)=K(2)=V(2)= 100010
5.1 计算相似度
Q(2)(K(2))T=[100010000] Q^{(2)}(K^{(2)})^T= \begin{bmatrix} 1 & 0 & 0\\ 0 & 1 & 0\\ 0 & 0 & 0 \end{bmatrix} Q(2)(K(2))T= 100010000
除以 2\sqrt{2}2:
S(2)=[0.70710000.70710000] S^{(2)}= \begin{bmatrix} 0.7071 & 0 & 0\\ 0 & 0.7071 & 0\\ 0 & 0 & 0 \end{bmatrix} S(2)= 0.70710000.70710000
5.2 softmax
$$
A^{(2)}
\operatorname{softmax}(S^{(2)})
$$
得到:
A(2)≈[0.50350.24830.24830.24830.50350.24830.33330.33330.3333] A^{(2)} \approx \begin{bmatrix} 0.5035 & 0.2483 & 0.2483\\ 0.2483 & 0.5035 & 0.2483\\ 0.3333 & 0.3333 & 0.3333 \end{bmatrix} A(2)≈ 0.50350.24830.33330.24830.50350.33330.24830.24830.3333
第三行是:
[0,0,0] [0,0,0] [0,0,0]
softmax 后就是平均:
[1/3,1/3,1/3] [1/3,1/3,1/3] [1/3,1/3,1/3]
5.3 乘以 V^(2)
O(2)=A(2)V(2) O^{(2)}=A^{(2)}V^{(2)} O(2)=A(2)V(2)
其中:
V(2)=[100100] V^{(2)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0 \end{bmatrix} V(2)= 100010
计算得到:
O(2)≈[0.50350.24830.24830.50350.33330.3333] O^{(2)} \approx \begin{bmatrix} 0.5035 & 0.2483\\ 0.2483 & 0.5035\\ 0.3333 & 0.3333 \end{bmatrix} O(2)≈ 0.50350.24830.33330.24830.50350.3333
维度:
O(2)∈R3×2 O^{(2)}\in\mathbb{R}^{3\times2} O(2)∈R3×2
6. 拼接两个 head
现在两个 head 的输出分别是:
O(1)≈[0.80220.59890.59890.80220.75180.7518] O^{(1)} \approx \begin{bmatrix} 0.8022 & 0.5989\\ 0.5989 & 0.8022\\ 0.7518 & 0.7518 \end{bmatrix} O(1)≈ 0.80220.59890.75180.59890.80220.7518
O(2)≈[0.50350.24830.24830.50350.33330.3333] O^{(2)} \approx \begin{bmatrix} 0.5035 & 0.2483\\ 0.2483 & 0.5035\\ 0.3333 & 0.3333 \end{bmatrix} O(2)≈ 0.50350.24830.33330.24830.50350.3333
把它们在最后一维拼接:
$$
O_{\text{concat}}
\operatorname{Concat}(O{(1)},O{(2)})
$$
得到:
Oconcat≈[0.80220.59890.50350.24830.59890.80220.24830.50350.75180.75180.33330.3333] O_{\text{concat}} \approx \begin{bmatrix} 0.8022 & 0.5989 & 0.5035 & 0.2483\\ 0.5989 & 0.8022 & 0.2483 & 0.5035\\ 0.7518 & 0.7518 & 0.3333 & 0.3333 \end{bmatrix} Oconcat≈ 0.80220.59890.75180.59890.80220.75180.50350.24830.33330.24830.50350.3333
维度:
Oconcat∈R3×4 O_{\text{concat}}\in\mathbb{R}^{3\times4} Oconcat∈R3×4
也就是:
3 个 token
每个 token 重新变回 4 维
7. 输出投影 W_O
真实 MHA 最后还有一个输出矩阵:
WO∈R4×4 W_O \in \mathbb{R}^{4\times4} WO∈R4×4
最终输出:
Y=OconcatWO Y=O_{\text{concat}}W_O Y=OconcatWO
这里为了简化,设:
WO=I4 W_O=I_4 WO=I4
所以:
Y=Oconcat Y=O_{\text{concat}} Y=Oconcat
即:
Y≈[0.80220.59890.50350.24830.59890.80220.24830.50350.75180.75180.33330.3333] Y \approx \begin{bmatrix} 0.8022 & 0.5989 & 0.5035 & 0.2483\\ 0.5989 & 0.8022 & 0.2483 & 0.5035\\ 0.7518 & 0.7518 & 0.3333 & 0.3333 \end{bmatrix} Y≈ 0.80220.59890.75180.59890.80220.75180.50350.24830.33330.24830.50350.3333
8. 维度总表
| 步骤 | 矩阵 | 维度 |
|---|---|---|
| 输入 | XXX | [3,4][3,4][3,4] |
| Q/K/V | Q,K,VQ,K,VQ,K,V | [3,4][3,4][3,4] |
| head 1 | Q(1),K(1),V(1)Q^{(1)},K^{(1)},V^{(1)}Q(1),K(1),V(1) | [3,2][3,2][3,2] |
| head 2 | Q(2),K(2),V(2)Q^{(2)},K^{(2)},V^{(2)}Q(2),K(2),V(2) | [3,2][3,2][3,2] |
| head 1 score | Q(1)(K(1))TQ^{(1)}(K^{(1)})^TQ(1)(K(1))T | [3,3][3,3][3,3] |
| head 1 attention | A(1)A^{(1)}A(1) | [3,3][3,3][3,3] |
| head 1 output | O(1)=A(1)V(1)O^{(1)}=A^{(1)}V^{(1)}O(1)=A(1)V(1) | [3,2][3,2][3,2] |
| head 2 output | O(2)=A(2)V(2)O^{(2)}=A^{(2)}V^{(2)}O(2)=A(2)V(2) | [3,2][3,2][3,2] |
| 拼接 | OconcatO_{\text{concat}}Oconcat | [3,4][3,4][3,4] |
| 输出 | YYY | [3,4][3,4][3,4] |
9. 最关键的理解
每个 head 里面:
A(h)∈RT×T A^{(h)} \in \mathbb{R}^{T\times T} A(h)∈RT×T
表示:
每个 token 对所有 token 的注意力权重
而:
V(h)∈RT×dh V^{(h)} \in \mathbb{R}^{T\times d_h} V(h)∈RT×dh
表示:
每个 token 对外提供的 value 信息
所以:
O(h)=A(h)V(h) O^{(h)}=A^{(h)}V^{(h)} O(h)=A(h)V(h)
就是:
[T,T]×[T,dh]=[T,dh] [T,T]\times[T,d_h]=[T,d_h] [T,T]×[T,dh]=[T,dh]
含义是:
每个 token 根据注意力权重,从所有 token 的 value 中加权汇聚信息。
也就是:
oi(h)=∑j=1TAij(h)vj(h) o_i^{(h)}= \sum_{j=1}^{T} A_{ij}^{(h)}v_j^{(h)} oi(h)=j=1∑TAij(h)vj(h)
一句话:
Attention 的本质是:每个 token 用自己的 Query 找到所有 Key 的相关性,再用这个相关性加权所有 Value。多头就是在多个子空间里重复这个过程,最后拼接。
更多推荐



所有评论(0)