介绍MultiHead Attention的计算过程。
说明 (0)
MultiHead Attention是Self Attention的推广,或者说后者是前者的特例。
Self Attention是这样的:
- 把单词的512维的embedding向量在512维空间投影3次,分别生成Query投影、Key投影和Value投影。这些投影都还是512维的,可以理解为没有丢失信息。然后就和单词的embedding向量没有关系了。而是使用这3个投影进行运算。
- 在一个上下文中,使用Query投影和Key投影计算单词两两之间的注意力权重;注意力权重可以理解为单词之间的相关性大小。
- 使用注意力权重加权求和Value投影,得到各个单词的SelfAttentionOutput;SelfAttentionOutput和embedding向量是同构的(都是512维)。
MultiHead Attention:
- 把单词的512维的embedding向量在低维空间(例如64维,可以选择其它维度)投影3次,分别生成Query投影、Key投影和Value投影。这些投影都是64维的,可以理解为丢失了部分信息。然后暂时和单词的embedding向量没有关系了。而是使用这3个投影进行运算。
- 在一个上下文中,使用Query投影和Key投影计算单词两两之间的注意力权重;注意力权重可以理解为单词之间的相关性大小。
- 使用注意力权重加权求和Value投影,得到各个单词的Output;Output和embedding向量不是同构的,embedding是512维,而Output是64维。不过到这里还没结束。
到现在为止,除了投影到低维(64维)且生成的Output也是低维的(64维)之外,和Self Attention没有任何区别。但是,MultiHead Attention要把面的过程重复多次。重复几次呢?
答案是$\frac{512}{64}=8$次。如此以来,把8次的结果拼起来,又得到512维的Output,和embedding保持同构!这也是“multi”的由来。
最后,使用一个$512 \times 512$的线性变换矩阵$W^O$(见1.4节),对512维的Output进行一个线性变换,就得到最终的MultiHeadOutput!
可见,如果在选择低维空间进行投影的时候,还选择512维(把512维看作一个特殊的低维),那么重复的次数就是$\frac{512}{512}=1$次!再选择$512 \times 512$的单位矩阵作为线性变换矩阵$W^O$,MultiHead Attention就退化成Self Attention了。
我是这么理解的:**从高维(512)到低维(64)投影时丢失了部分信息,但是我们换了8个角度,投了8次(把Query投影-Key投影-Value投影算1次),这样就尽可能的保留了本体的特征。并且和Self Attention的1次同维投影相比,多角度的低维投影可能更能抓住本体的特征。这里,本体的特征就是单词间的关系!
输入 (1)
假设当前上下文是“humpty dumpty sat on”,我们要预测下一个单词。有以下输入:
Eembedding矩阵 (1.1)
假设每个单词的embedding向量是512维,并且假设:
- “humpty”的embedding向量是:
$$
E_1 = [e_{1,1}, e_{1,2}, e_{1,3}, \ldots, e_{1,512}]
$$
- “dumpty”的embedding向量是:
$$
E_2 = [e_{2,1}, e_{2,2}, e_{2,3}, \ldots, e_{2,512}]
$$
- “sat”的embedding向量是:
$$
E_3 = [e_{3,1}, e_{3,2}, e_{3,3}, \ldots, e_{3,512}]
$$
- “on”的embedding向量是:
$$
E_4 = [e_{4,1}, e_{4,2}, e_{4,3}, \ldots, e_{4,512}]
$$
其实就是一个$4 \times 512$的矩阵:
$$
X = \begin{bmatrix}
e_{1,1} & e_{1,2} & e_{1,3} & \cdots & e_{1,512} \\
e_{2,1} & e_{2,2} & e_{2,3} & \cdots & e_{2,512} \\
e_{3,1} & e_{3,2} & e_{3,3} & \cdots & e_{3,512} \\
e_{4,1} & e_{4,2} & e_{4,3} & \cdots & e_{4,512}
\end{bmatrix}
$$
头数 (1.2)
从前面的说明可知,选择头数和选择低维空间的维数是同一个事儿!假设头数为8:
$$ h = 8 $$
那么每个头的维度为:
$$ d_k = \frac{512}{8} = 64 $$
投影矩阵 (1.3)
对于每个头$i \in {1, 2, \ldots, 8}$,有3个投影矩阵,它们都是$512 \times 64$的矩阵(如第0节所述)。每个头的投影矩阵不同,当然也不能相同!否则的话,就是同样的运算重复8次!
$$
W_i^Q = \begin{bmatrix}
q_{1,1} & q_{1,2} & q_{1,3} & \cdots & q_{1,64} \\
q_{2,1} & q_{2,2} & q_{2,3} & \cdots & q_{2,64} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
q_{512,1} & q_{512,2} & q_{512,3} & \cdots & q_{512,64}
\end{bmatrix}
$$
$$
W_i^K = \begin{bmatrix}
k_{1,1} & k_{1,2} & k_{1,3} & \cdots & k_{1,64} \\
k_{2,1} & k_{2,2} & k_{2,3} & \cdots & k_{2,64} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
k_{512,1} & k_{512,2} & k_{512,3} & \cdots & k_{512,64}
\end{bmatrix}
$$
$$
W_i^V = \begin{bmatrix}
v_{1,1} & v_{1,2} & v_{1,3} & \cdots & v_{1,64} \\
v_{2,1} & v_{2,2} & v_{2,3} & \cdots & v_{2,64} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
v_{512,1} & v_{512,2} & v_{512,3} & \cdots & v_{512,64}
\end{bmatrix}
$$
这几个矩阵都可以通过学习得到,其中的意义也不好解释。但它们叫做投影矩阵,意思是把单词的512维embedding向量投影到64维空间,见第2节生成Query, Key, Value矩阵。
线性变换矩阵 (1.4)
这是个$512 \times 512$的矩阵,也通过学习得到的。在MultiHead Attention输出结果之前,对结果进行一个线性变换。显然,若它是单位矩阵,则相当于没有变换。我们把它记作$W^O$。具体形状不必展开。
生成 Query、Key、Value 矩阵 (2)
对于每个头$i \in {1, 2, \ldots, 8}$:
- Query矩阵 $Q_i = XW_i^Q =$
$$
\begin{bmatrix}
\sum\limits_{n=1}^{512}e_{1,n}q_{n,1} & \sum\limits_{n=1}^{512}e_{1,n}q_{n,2} & \sum\limits_{n=1}^{512}e_{1,n}q_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{1,n}q_{n,64} \\
\sum\limits_{n=1}^{512}e_{2,n}q_{n,1} & \sum\limits_{n=1}^{512}e_{2,n}q_{n,2} & \sum\limits_{n=1}^{512}e_{2,n}q_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{2,n}q_{n,64} \\
\sum\limits_{n=1}^{512}e_{3,n}q_{n,1} & \sum\limits_{n=1}^{512}e_{3,n}q_{n,2} & \sum\limits_{n=1}^{512}e_{3,n}q_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{3,n}q_{n,64} \\
\sum\limits_{n=1}^{512}e_{4,n}q_{n,1} & \sum\limits_{n=1}^{512}e_{4,n}q_{n,2} & \sum\limits_{n=1}^{512}e_{4,n}q_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{4,n}q_{n,64}
\end{bmatrix}
$$
审视一下这个矩阵:首先,第1行是一个64维的向量,每个元素都是第1个单词”humpty”的embedding向量($E_1$)的线性组合。注意到embedding向量是512维,而这一行是64维,所以我们说,它是”humpty”的512维的embedding向量在64维空间的投影(参考第1.3节的投影矩阵$W_i^Q$)。由于是高维(512)到低维(64)的投影,所以一定会丢失一些信息,就像3维到2维投影只保留了形状而丢失了深度信息。总之,第1行就是第1个单词”humpty”的一些特征,来自原始的512维的embedding向量。同理,第2行是第2个单词特征,来自512维的embedding向量;第3行是第3个单词的特征;第4行是第4个单词的特征。
下面的Key矩阵和Value矩阵也是一样,每一行都是一个单词的特征:只是投影矩阵不同($W_i^K$、$W_i^V$、$W_i^Q$互不相同),得到的特征也不同。
- Key矩阵 $K_i = XW_i^K =$
$$
\begin{bmatrix}
\sum\limits_{n=1}^{512}e_{1,n}k_{n,1} & \sum\limits_{n=1}^{512}e_{1,n}k_{n,2} & \sum\limits_{n=1}^{512}e_{1,n}k_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{1,n}k_{n,64} \\
\sum\limits_{n=1}^{512}e_{2,n}k_{n,1} & \sum\limits_{n=1}^{512}e_{2,n}k_{n,2} & \sum\limits_{n=1}^{512}e_{2,n}k_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{2,n}k_{n,64} \\
\sum\limits_{n=1}^{512}e_{3,n}k_{n,1} & \sum\limits_{n=1}^{512}e_{3,n}k_{n,2} & \sum\limits_{n=1}^{512}e_{3,n}k_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{3,n}k_{n,64} \\
\sum\limits_{n=1}^{512}e_{4,n}k_{n,1} & \sum\limits_{n=1}^{512}e_{4,n}k_{n,2} & \sum\limits_{n=1}^{512}e_{4,n}k_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{4,n}k_{n,64}
\end{bmatrix}
$$
- Value矩阵 $V_i = XW_i^V =$
$$
\begin{bmatrix}
\sum\limits_{n=1}^{512}e_{1,n}v_{n,1} & \sum\limits_{n=1}^{512}e_{1,n}v_{n,2} & \sum\limits_{n=1}^{512}e_{1,n}v_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{1,n}v_{n,64} \\
\sum\limits_{n=1}^{512}e_{2,n}v_{n,1} & \sum\limits_{n=1}^{512}e_{2,n}v_{n,2} & \sum\limits_{n=1}^{512}e_{2,n}v_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{2,n}v_{n,64} \\
\sum\limits_{n=1}^{512}e_{3,n}v_{n,1} & \sum\limits_{n=1}^{512}e_{3,n}v_{n,2} & \sum\limits_{n=1}^{512}e_{3,n}v_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{3,n}v_{n,64} \\
\sum\limits_{n=1}^{512}e_{4,n}v_{n,1} & \sum\limits_{n=1}^{512}e_{4,n}v_{n,2} & \sum\limits_{n=1}^{512}e_{4,n}v_{n,3} & \cdots & \sum\limits_{n=1}^{512}e_{4,n}v_{n,64}
\end{bmatrix}
$$
到目前,我们只是把每个单词的embedding向量在64维空间投影3次:Query投影,Key投影,Value投影。我是这么理解的:
- Value投影代表embedding本身(自己的影子);
- Query投影用于查询自己和别的单词的关系;
- Key投影用于别的单词查询和自己的关系;
- 单词之间的关系就是单词之间的相关性,也就是“注意力分数”;“注意力分数”经过缩放、归一化处理就是单词之间的权重;
- 用一个单词相对各个单词的权重对各个单词的Value向量进行加权求和,得到一个新的64维向量(因为Value向量是64维,所以新向量也是64维),就是这个头的Output。
我们使用$Q_x$,$K_x$,$V_x$分别表示单词$x$的Query投影,Key投影和Value投影(它们都是64维的向量),$x \in [humpty, dumpty, sat, on]$,则:
$$
Q_i = XW_i^Q = \begin{bmatrix}
Q_{humpty} \\
Q_{dumpty} \\
Q_{sat} \\
Q_{on}
\end{bmatrix}
$$
$$
K_i = XW_i^K = \begin{bmatrix}
K_{humpty} \\
K_{dumpty} \\
K_{sat} \\
K_{on}
\end{bmatrix}
$$
$$
V_i = XW_i^V = \begin{bmatrix}
V_{humpty} \\
V_{dumpty} \\
V_{sat} \\
V_{on}
\end{bmatrix}
$$
计算单头注意力 (3)
以第$i$个头为例:
计算注意力分数矩阵 (3.1)
$$
AttentionScore_i = \frac{Q_iK_i^T}{\sqrt{d_k}}
$$
这个矩阵展开就很吓人了。但不要恐慌,分析一下,还是比较简单的。
首先把分母忽略掉,它就是除以标量$\sqrt{d_k} = \sqrt{64} = 8$,不改变矩阵的结构。这是一个缩放因子,用于防止点积结果过大导致梯度消失或爆炸问题。
然后重点看分子部分:它是$Q_i$乘以$K_i$的转置($K_i^T$)。$Q_i$和$K_i$都是$4 \times 64$的矩阵,不转置肯定没法相乘。$K_i$的转置($K_i^T$)就是$64 \times 4$的矩阵,这样以来,结果就是$4 \times 4$的矩阵。
再看怎么相乘的:$K_i$的转置($K_i^T$),就是把第1行变成第1列,第2行变成第2列,第3行变成第3列……;然后$Q_i$和它相乘。用肉眼看的话,刚好也不用转置了:
直接拿$Q_i$的第1行和$K_i$的第1行对应相乘再相加(点积),得到$AttentionScore_i[1][1]$;即第1个单词的Query投影点乘第1个单词Key投影;
直接拿$Q_i$的第1行和$K_i$的第2行对应相乘再相加(点积),得到$AttentionScore_i[1][2]$;即第1个单词的Query投影点乘第2个单词Key投影;
直接拿$Q_i$的第1行和$K_i$的第3行对应相乘再相加(点积),得到$AttentionScore_i[1][3]$;即第1个单词的Query投影点乘第3个单词Key投影;
直接拿$Q_i$的第1行和$K_i$的第4行对应相乘再相加(点积),得到$AttentionScore_i[1][4]$;即第1个单词的Query投影点乘第4个单词Key投影;
直接拿$Q_i$的第2行和$K_i$的第1行对应相乘再相加(点积),得到$AttentionScore_i[2][1]$;即第2个单词的Query投影点乘第1个单词Key投影;
直接拿$Q_i$的第2行和$K_i$的第2行对应相乘再相加(点积),得到$AttentionScore_i[2][2]$;即第2个单词的Query投影点乘第2个单词Key投影;
直接拿$Q_i$的第2行和$K_i$的第3行对应相乘再相加(点积),得到$AttentionScore_i[2][3]$;即第2个单词的Query投影点乘第3个单词Key投影;
直接拿$Q_i$的第2行和$K_i$的第4行对应相乘再相加(点积),得到$AttentionScore_i[2][4]$;即第2个单词的Query投影点乘第4个单词Key投影;
…
前面我们使用$Q_x$,$K_x$分别表示Query投影和Key投影,$x \in [humpty, dumpty, sat, on]$,则:
$AttentionScore_i = $
$$
\begin{bmatrix}
Q_{humpty} \cdot K_{humpty}, & Q_{humpty} \cdot K_{dumpty}, & Q_{humpty} \cdot K_{sat}, & Q_{humpty} \cdot K_{on} \\
Q_{dumpty} \cdot K_{humpty}, & Q_{dumpty} \cdot K_{dumpty}, & Q_{dumpty} \cdot K_{sat}, & Q_{dumpty} \cdot K_{on} \\
Q_{sat} \cdot K_{humpty}, & Q_{sat} \cdot K_{dumpty}, & Q_{sat} \cdot K_{sat}, & Q_{sat} \cdot K_{on} \\
Q_{on} \cdot K_{humpty}, & Q_{on} \cdot K_{dumpty}, & Q_{on} \cdot K_{sat}, & Q_{on} \cdot K_{on}
\end{bmatrix}
$$
$AttentionScore_i[m][n]$到底是什么意义呢?它其实表示单词m对单词n的注意力分数;
应用Softmax归一化 (3.2)
$$
AttentionWeight_i = softmax(AttentionScore_i)
$$
注意$softmax$不是针对整个矩阵的,而是逐行归一化,即每行的所有元素独立进行softmax归一化,确保每行的和为1。
将所有元素视为一个向量,整体归一化(所有元素和为1),这种方式几乎不使用:因为这种操作会破坏矩阵的语义结构,导致行/列间依赖关系丢失。
所以,这一步比较简单,没有改变矩阵的结构,$AttentionWeight_i$还是一个$4 \times 4$的矩阵。就是把$AttentionScore_i$归一化,转化成概率分布。即$AttentionWeight_i[m][n]$是单词m对单词n的weight。我们表示为:
$$
AttentionWeight = \begin{bmatrix}
w_{humpty,humpty} & w_{humpty,dumpty} & w_{humpty,sat} & w_{humpty,on} \\
w_{dumpty,humpty} & w_{dumpty,dumpty} & w_{dumpty,sat} & w_{dumpty,on} \\
w_{sat,humpty} & w_{sat,dumpty} & w_{sat,sat} & w_{sat,on} \\
w_{on,humpty} & w_{on,dumpty} & w_{on,sat} & w_{on,on}
\end{bmatrix}
$$
加权聚合Value (3.3)
前面说过,矩阵$V_i$是一个$4 \times 64$的矩阵,也就是每个单词的Value投影,每行是一个单词。每列呢?可以看作是每个单词的一个投影分量,即:
$$
V_i = \begin{bmatrix}
V_{humpty} \\
V_{dumpty} \\
V_{sat} \\
V_{on}
\end{bmatrix} = \begin{bmatrix}
V_{humpty}分量1, & V_{humpty}分量2, & \cdots, & V_{humpty}分量64 \\
V_{dumpty}分量1, & V_{dumpty}分量2, & \cdots, & V_{dumpty}分量64 \\
V_{sat}分量1, & V_{sat}分量2, & \cdots, & V_{sat}分量64 \\
V_{on}分量1, & V_{on}分量2, & \cdots, & V_{on}分量64
\end{bmatrix}
$$
上一小节我们得到$AttentionWeight_i$是一个$4 \times 4$的矩阵,它乘以$V_i$ ($4 \times 64$),就得到$4 \times 64$的加权聚合Value:
$Head_i = AttentionWeight_i \cdot V_i = $
$$
\begin{bmatrix}
\sum\limits_{x}w_{humpty,x}V_x分量1, & \sum\limits_{x}w_{humpty,x}V_x分量2, & \cdots, & \sum\limits_{x}w_{humpty,x}V_x分量64 \\
\sum\limits_{x}w_{dumpty,x}V_x分量1, & \sum\limits_{x}w_{dumpty,x}V_x分量2, & \cdots, & \sum\limits_{x}w_{dumpty,x}V_x分量64 \\
\sum\limits_{x}w_{sat,x}V_x分量1, & \sum\limits_{x}w_{sat,x}V_x分量2, & \cdots, & \sum\limits_{x}w_{sat,x}V_x分量64 \\
\sum\limits_{x}w_{on,x}V_x分量1, & \sum\limits_{x}w_{on,x}V_x分量2, & \cdots, & \sum\limits_{x}w_{on,x}V_x分量64
\end{bmatrix}
$$
- $Head_i[humpty][1]$:所有单词的Value的分量1加权求和;这里的权是指humpty对其它单词的AttentionWeight。
- $Head_i[humpty][2]$:所有单词的Value的分量2加权求和;这里的权是指humpty对其它单词的AttentionWeight。
- …
- $Head_i[sat][1]$:所有单词的Value的分量1加权求和;这里的权是指sat对其它单词的AttentionWeight。
- $Head_i[sat][2]$:所有单词的Value的分量2加权求和;这里的权是指sat对其它单词的AttentionWeight。
- …
就是说,第1行是humpty的($Head_i$的)Output;第2行是dumpty的Output; ……
每个单词的Output又是所有单词的Value的加权求和。因为Value是一个向量,所以,加权求和是指各个分量对应加权求和!
这里的权是指当前单词对于所有单词的AttentionWeight。
总之,$Head_i$是个$4 \times 64$的矩阵,其中第1行中编码了第1个单词与其它单词之间的关联信息;第2行中编码了第2个单词与其它单词的关联信息,……
合并多头输出 (4)
到目前为止,我们都在看一个头;得到的$Head_i$也是这一个头$i$的Output。它是一个$4 \times 64$的矩阵。
总共有8个头,重复上面的过程,得到8个$4 \times 64$的Output!拼起来就是$4 \times 512$的矩阵,和输入矩阵$X$同构。
$$
MultiHead(X) = \begin{bmatrix}
Head_1; & Head_2; & Head_3; & \cdots; & Head_8
\end{bmatrix}
$$
其实这8个头,只是输入的投影矩阵不同,其它计算都一模一样。
输出线性变换 (5)
上面已经得到和输入$X$的同构矩阵$MultiHead(X)$ ($4 \times 512$),再乘以线性变换矩阵$W^O$ ($512 \times 512$,见第1.4节),得到的还是$4 \times 512$的矩阵。即:
$$
MultiHeadOutput = MultiHead(X)W^O \in R^{4 \times 512}
$$
总结 (6)
- 输入与输出对齐:输出维度与输入一致,便于残差连接和后续层处理。
- 多头协作:每个头学习不同模式,最终输出融合多视角信息。
- 上下文编码:每个输出位置聚合了全局依赖关系,为预测提供丰富特征。