Attention中的scale操作
1. 问题描述
《大规模语言模型:从理论到实践》第16页,“为防止过大的匹配分数在后续Softmax计算过程中导致的梯度爆炸以及收敛效率差的问题,这些得分会除缩放因子√d 以稳定优化。”
原文等价于: √d 的作用为缩减QKT 中元素值的大小,避免在梯度反向传播时导致的梯度爆炸问题。

截图自《大规模语言模型:从理论到实践》
2. √d 的作用应该是防止梯度消失
使用与书中相同的记号
Attention(Q,K,V)=Softmax(QKT√d)V,并定义
W∗=QKT√d,W=Softmax(W∗),其中矩阵W中的元素 Wij=eW∗ij∑dp=1eW∗ip 。
2.1 去除√d 项会使得变量W∗ij 的方差增大
由之前定义可知矩阵W∗ 中元素W∗ij=QiKTi√d=∑dj=1QijKij√d , 此时假设变量Qij,Kij 均服从标准正态分布且互相独立,即Qij,Kij∼N(0,1) ,则变量W∗ij∼N(0,1)。
若去除√d 项,即W∗ij=QiKTi=∑dj=1QijKij ,则W∗ij∼N(0,d) ,此时方差增大,即W∗ 矩阵中元素之间的差异增大。
2.2 变量W∗ij 的方差增大会使得梯度值偏小
在反向传播过程中,会涉及到对Softmax项进行求导,对于 W∗ 中的某一个元素 W∗ij ,求偏导如下
∂Softmax(W∗ij)∂W∗ij=Softmax(W∗ij)(1−Softmax(W∗ij))∂Softmax(W∗ij)∂W∗ip=−Softmax(W∗ip)Softmax(W∗ij),p≠j若变量 W∗ij 的方差增大,则考虑元素 W∗ij 远大于其他元素 W∗ip 的情况,则 Softmax(W∗ij) 趋近于1,而 Softmax(W∗ip) 趋近于0
- 对于式(1), 1−Softmax(W∗ij) 趋近于0,使得 ∂Softmax(W∗ij)∂W∗ij 趋近于0
- 对于式(2), Softmax(W∗ip) 趋近于0,使得 ∂Softmax(W∗ij)∂W∗ip 趋近于0
2.3 归纳
丢弃 √d 项会使得变量 W∗ij 的方差增大,而变量 W∗ij 的方差增大会使得梯度值偏小,进而引发梯度消失问题。
真正解决梯度爆炸问题的应该是 Softmax函数,它直接进行了归一化的操作,避免值过大导致的梯度爆炸问题。而正是因为使用了 Softmax函数,引入了变量方差过大导致梯度消失的问题,所以需要对变量进行除以√d操作降低方差。
Refences
- 参考“Attention Is All You Need”原文,较大的d会导致点积的增大,进而在计算Softmax函数时被推向较小梯度的区域(说的比较含糊,脚注中解释从方差影响的角度进行说明)

截图自"Attention Is All You Need"
- LATEX 公式支持调得感觉要寄了😵
- 这个是好久之前看书发现的bug,只了解并记录了很浅层的原因。最近看见了苏剑林老师对这个问题有更深层次的思考[浅谈Transformer的初始化、参数化与标准化]、[从熵不变性看Attention的Scale操作] orz,reading and learning to death…