Sirry Chen bio photo

Email

Twitter

Github

Zhihu

Bilibili

Attention中的scale操作


1. 问题描述

《大规模语言模型:从理论到实践》第16页,“为防止过大的匹配分数在后续Softmax计算过程中导致的梯度爆炸以及收敛效率差的问题,这些得分会除缩放因子d 以稳定优化。”

原文等价于: d 的作用为缩减QKT 中元素值的大小,避免在梯度反向传播时导致的梯度爆炸问题。

alt text

截图自《大规模语言模型:从理论到实践》

2. d 的作用应该是防止梯度消失

使用与书中相同的记号

Attention(Q,K,V)=Softmax(QKTd)V,

并定义

W=QKTd,W=Softmax(W),

其中矩阵W中的元素 Wij=eWijdp=1eWip

2.1 去除d 项会使得变量Wij 的方差增大

由之前定义可知矩阵W 中元素Wij=QiKTid=dj=1QijKijd , 此时假设变量Qij,Kij 均服从标准正态分布且互相独立,即Qij,KijN(0,1) ,则变量WijN(0,1)

若去除d 项,即Wij=QiKTi=dj=1QijKij ,则WijN(0,d) ,此时方差增大,即W 矩阵中元素之间的差异增大。

2.2 变量Wij 的方差增大会使得梯度值偏小

在反向传播过程中,会涉及到对Softmax项进行求导,对于 W 中的某一个元素 Wij ,求偏导如下

Softmax(Wij)Wij=Softmax(Wij)(1Softmax(Wij))Softmax(Wij)Wip=Softmax(Wip)Softmax(Wij),pj

若变量 Wij 的方差增大,则考虑元素 Wij 远大于其他元素 Wip 的情况,则 Softmax(Wij) 趋近于1,而 Softmax(Wip) 趋近于0

  • 对于式(1), 1Softmax(Wij) 趋近于0,使得 Softmax(Wij)Wij 趋近于0
  • 对于式(2), Softmax(Wip) 趋近于0,使得 Softmax(Wij)Wip 趋近于0

2.3 归纳

丢弃 d 项会使得变量 Wij 的方差增大,而变量 Wij 的方差增大会使得梯度值偏小,进而引发梯度消失问题。

真正解决梯度爆炸问题的应该是 Softmax函数,它直接进行了归一化的操作,避免值过大导致的梯度爆炸问题。而正是因为使用了 Softmax函数,引入了变量方差过大导致梯度消失的问题,所以需要对变量进行除以d操作降低方差。

Refences

  1. 参考“Attention Is All You Need”原文,较大的d会导致点积的增大,进而在计算Softmax函数时被推向较小梯度的区域(说的比较含糊,脚注中解释从方差影响的角度进行说明)
alt text

截图自"Attention Is All You Need"

  1. 知乎:Transformer学习笔记二:Self-Attention(自注意力机制)
  2. 知乎提问:transformer中的attention为什么scaled?