Introduction
These notes are based on Section 10.3 of 李沐's 《动手学深度学习》 (Dive into Deep Learning).
Main Text
Consider a pair of corresponding key and value matrices
$$
\mathbf{K} \in \mathbb{R}^{n \times k}, ~ \mathbf{V} \in \mathbb{R}^{n \times v}
$$
$$
\mathbf{K}=\begin{pmatrix} \mathbf{k}_1 \\ \mathbf{k}_2 \\ \vdots \\ \mathbf{k}_n \end{pmatrix}, \quad \mathbf{V}=\begin{pmatrix} \mathbf{v}_1 \\ \mathbf{v}_2 \\ \vdots \\ \mathbf{v}_n \end{pmatrix}
$$
where $\mathbf{k}_i \in \mathbb{R}^{1 \times k}$ and $\mathbf{v}_i \in \mathbb{R}^{1 \times v}$ are row vectors.
Each key $\mathbf{k}_i$ maps to its corresponding value $\mathbf{v}_i$. Given a query row vector $\mathbf{q} \in \mathbb{R}^{1 \times k}$, we estimate the value it should map to by computing the similarity between $\mathbf{q}$ and each $\mathbf{k}_i$ and normalizing these similarities into weights with a softmax.
Let $a(\mathbf{q}, \mathbf{k}_i)$ denote the attention scoring function and $\alpha(\mathbf{q}, \mathbf{k}_i)$ the attention weight; both are scalars.
$$
\alpha(\mathbf{q}, \mathbf{k}_i) =
\mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) =
\frac{\mathrm{exp}(a(\mathbf{q}, \mathbf{k}_i))}
{\sum\limits_{j = 1}^{n}\mathrm{exp}(a(\mathbf{q}, \mathbf{k}_j))}$$
Thus the value $f(\mathbf{q})$ predicted for the query row vector $\mathbf{q}$ is:
$$
f(\mathbf{q}) =
\sum\limits_{i=1}^{n} \alpha(\mathbf{q}, \mathbf{k}_i) \cdot \mathbf{v}_i
~
\in \mathbb{R}^{1 \times v}
$$
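This pooling step can be sketched directly in NumPy. The snippet below is a minimal illustration, not the book's implementation; the names `attention_pool` and `score_fn` are chosen here for convenience, and the scoring function is left generic so that any $a(\mathbf{q}, \mathbf{k}_i)$ can be plugged in.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention_pool(q, K, V, score_fn):
    # q: (k,) query; K: (n, k) keys; V: (n, v) values.
    # score_fn(q, k_i) returns the scalar score a(q, k_i).
    a = np.array([score_fn(q, k_i) for k_i in K])  # (n,) scores
    alpha = softmax(a)                             # (n,) attention weights, sum to 1
    return alpha @ V                               # f(q): weighted sum of the rows of V, shape (v,)

# Toy usage with a plain dot-product score (the scaled version is introduced next).
rng = np.random.default_rng(0)
n, k, v = 5, 4, 3
K, V = rng.standard_normal((n, k)), rng.standard_normal((n, v))
q = rng.standard_normal(k)
print(attention_pool(q, K, V, lambda q, k_i: q @ k_i))
```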
Now we introduce the scaled dot-product attention scoring function:
$$
a(\mathbf{q}, \mathbf{k}_i) = \frac{\mathbf{q} \cdot \mathbf{k}_i^\mathrm{T}}{\sqrt{k}}
$$
Assume all elements of the query and the key are independent random variables with zero mean and unit variance. Then the dot product of the two vectors has mean 0 and variance $k$, so dividing by $\sqrt{k}$ keeps the scaled score at unit variance regardless of the dimension $k$.
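A quick numerical sanity check of this claim, sketched here with standard-normal entries (any zero-mean, unit-variance distribution would do):

```python
import numpy as np

rng = np.random.default_rng(0)
k = 64
num_samples = 100_000

q = rng.standard_normal((num_samples, k))    # queries: zero mean, unit variance entries
key = rng.standard_normal((num_samples, k))  # keys: zero mean, unit variance entries

dots = (q * key).sum(axis=1)                 # raw dot products q . k^T
scaled = dots / np.sqrt(k)                   # scaled dot-product scores

print(dots.mean(), dots.var())               # roughly 0 and k (here ~64)
print(scaled.mean(), scaled.var())           # roughly 0 and 1
```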
Now consider $m$ query row vectors stacked into a matrix $\mathbf{Q}$, and the corresponding computation $f(\mathbf{Q})$:
$$
\mathbf{Q} =
\begin{pmatrix}
\mathbf{q}_1 \\ \mathbf{q}_2 \\ \vdots \\ \mathbf{q}_m
\end{pmatrix}
\in
\mathbb{R}^{m \times k}
$$
$$
\frac{\mathbf{Q} \cdot \mathbf{K}^\mathrm{T}}{\sqrt{k}} =
\frac{1}{\sqrt{k}}
\cdot
\begin{pmatrix}
\mathbf{q}_1 \\ \mathbf{q}_2 \\ \vdots \\ \mathbf{q}_m
\end{pmatrix}
\cdot
\begin{pmatrix}
\mathbf{k}_1^\mathrm{T} & \mathbf{k}_2^\mathrm{T} & \cdots & \mathbf{k}_n^\mathrm{T} \end{pmatrix} =
\begin{pmatrix}
\frac{\mathbf{q}_1 \cdot \mathbf{k}_1^\mathrm{T}}{\sqrt{k}} & \frac{\mathbf{q}_1 \cdot \mathbf{k}_2^\mathrm{T}}{\sqrt{k}}& \cdots & \frac{\mathbf{q}_1 \cdot \mathbf{k}_n^\mathrm{T}}{\sqrt{k}}
\\
\frac{\mathbf{q}_2 \cdot \mathbf{k}_1^\mathrm{T}}{\sqrt{k}} & \frac{\mathbf{q}_2 \cdot \mathbf{k}_2^\mathrm{T}}{\sqrt{k}}& \cdots & \frac{\mathbf{q}_2 \cdot \mathbf{k}_n^\mathrm{T}}{\sqrt{k}}
\\
\vdots & \vdots & \ddots & \vdots
\\
\frac{\mathbf{q}_m \cdot \mathbf{k}_1^\mathrm{T}}{\sqrt{k}} & \frac{\mathbf{q}_m \cdot \mathbf{k}_2^\mathrm{T}}{\sqrt{k}}& \cdots & \frac{\mathbf{q}_m \cdot \mathbf{k}_n^\mathrm{T}}{\sqrt{k}}
\end{pmatrix}
$$
It follows immediately that:
$$
\frac{\mathbf{Q} \cdot \mathbf{K}^\mathrm{T}}{\sqrt{k}} =
\begin{pmatrix}
a(\mathbf{q}_1, \mathbf{k}_1) & a(\mathbf{q}_1, \mathbf{k}_2)& \cdots & a(\mathbf{q}_1, \mathbf{k}_n)
\\
a(\mathbf{q}_2, \mathbf{k}_1) & a(\mathbf{q}_2, \mathbf{k}_2)& \cdots & a(\mathbf{q}_2, \mathbf{k}_n)
\\
\vdots & \vdots & \ddots & \vdots
\\
a(\mathbf{q}_m, \mathbf{k}_1) & a(\mathbf{q}_m, \mathbf{k}_2)& \cdots & a(\mathbf{q}_m, \mathbf{k}_n)
\end{pmatrix}$$
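This identity is easy to confirm numerically. Below is a small NumPy sketch; the shapes $m$, $n$, $k$ are arbitrary toy values chosen here for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, k = 4, 6, 8
Q = rng.standard_normal((m, k))   # m query row vectors
K = rng.standard_normal((n, k))   # n key row vectors

scores = Q @ K.T / np.sqrt(k)     # (m, n) matrix of scores a(q_i, k_j)

# Entry (i, j) equals the scaled dot product of q_i and k_j.
i, j = 2, 5
assert np.isclose(scores[i, j], Q[i] @ K[j] / np.sqrt(k))
```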
Applying a row-wise softmax to $\frac{\mathbf{Q} \cdot \mathbf{K}^\mathrm{T}}{\sqrt{k}}$ then gives:
$$
\mathrm{softmax} \bigg (\frac{\mathbf{Q} \cdot \mathbf{K}^\mathrm{T}}{\sqrt{k}} \bigg ) =
\begin{pmatrix}
\alpha(\mathbf{q}_1, \mathbf{k}_1) & \alpha(\mathbf{q}_1, \mathbf{k}_2)& \cdots & \alpha(\mathbf{q}_1, \mathbf{k}_n)
\\
\alpha(\mathbf{q}_2, \mathbf{k}_1) & \alpha(\mathbf{q}_2, \mathbf{k}_2)& \cdots & \alpha(\mathbf{q}_2, \mathbf{k}_n)
\\
\vdots & \vdots & \ddots & \vdots
\\
\alpha(\mathbf{q}_m, \mathbf{k}_1) & \alpha(\mathbf{q}_m, \mathbf{k}_2)& \cdots & \alpha(\mathbf{q}_m, \mathbf{k}_n)
\end{pmatrix}
\in
\mathbb{R}^{m \times n}
$$
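In code, "row-wise softmax" simply means normalizing over the key axis, so that each row becomes a probability distribution over the $n$ keys. A minimal sketch:

```python
import numpy as np

def softmax(x):
    # Row-wise softmax: normalize over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(2)
m, n, k = 4, 6, 8
Q = rng.standard_normal((m, k))
K = rng.standard_normal((n, k))

weights = softmax(Q @ K.T / np.sqrt(k))   # (m, n) matrix of alpha(q_i, k_j)

# The weights for each query sum to 1 across the n keys.
assert np.allclose(weights.sum(axis=1), 1.0)
```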
$$
\mathrm{softmax} \bigg (\frac{\mathbf{Q} \cdot \mathbf{K}^\mathrm{T}}{\sqrt{k}} \bigg )
\cdot
\mathbf{V} =
\begin{pmatrix}
\alpha(\mathbf{q}_1, \mathbf{k}_1) & \alpha(\mathbf{q}_1, \mathbf{k}_2)& \cdots & \alpha(\mathbf{q}_1, \mathbf{k}_n)
\\
\alpha(\mathbf{q}_2, \mathbf{k}_1) & \alpha(\mathbf{q}_2, \mathbf{k}_2)& \cdots & \alpha(\mathbf{q}_2, \mathbf{k}_n)
\\
\vdots & \vdots & \ddots & \vdots
\\
\alpha(\mathbf{q}_m, \mathbf{k}_1) & \alpha(\mathbf{q}_m, \mathbf{k}_2)& \cdots & \alpha(\mathbf{q}_m, \mathbf{k}_n)
\end{pmatrix}
\cdot
\begin{pmatrix}
\mathbf{v}_1 \\ \mathbf{v}_2 \\ \vdots \\ \mathbf{v}_n
\end{pmatrix}
=
\begin{pmatrix}
\sum\limits_{i=1}^{n} \alpha(\mathbf{q}_1, \mathbf{k}_i) \cdot \mathbf{v}_i
\\
\sum\limits_{i=1}^{n} \alpha(\mathbf{q}_2, \mathbf{k}_i) \cdot \mathbf{v}_i
\\
\vdots
\\
\sum\limits_{i=1}^{n} \alpha(\mathbf{q}_m, \mathbf{k}_i) \cdot \mathbf{v}_i
\end{pmatrix}
$$
That is,
$$
\mathrm{softmax} \bigg (\frac{\mathbf{Q} \cdot \mathbf{K}^\mathrm{T}}{\sqrt{k}} \bigg )
\cdot
\mathbf{V} =
\begin{pmatrix}
f(\mathbf{q}_1)
\\
f(\mathbf{q}_2)
\\
\vdots
\\
f(\mathbf{q}_m)
\end{pmatrix}
~
\in
\mathbb{R}^{m \times v}
$$
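Putting the derivation together, the matrix form fits in a single function. The sketch below is a plain NumPy rendering rather than the book's framework implementation (masking and dropout are omitted), and it checks that each row of the matrix result matches the single-query formula $f(\mathbf{q}_i)$.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    # Q: (m, k), K: (n, k), V: (n, v)  ->  (m, v), where row i is f(q_i).
    k = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(k))   # (m, n) attention weights
    return weights @ V

rng = np.random.default_rng(3)
m, n, k, v = 4, 6, 8, 5
Q = rng.standard_normal((m, k))
K = rng.standard_normal((n, k))
V = rng.standard_normal((n, v))

out = scaled_dot_product_attention(Q, K, V)   # (m, v)

# Row i of the matrix form matches the single-query computation of f(q_i).
i = 1
a = Q[i] @ K.T / np.sqrt(k)                   # (n,) scores for q_i
alpha = np.exp(a - a.max())
alpha /= alpha.sum()
assert np.allclose(out[i], alpha @ V)
```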