Attention的汇总与辨析_Additive、Multiplication、Scaled dot-product、Self Attention、Multi-head Self-Attention

(29) 2024-06-20 16:01:01

一、Seq2Seq

1.1 Seq2Seq(Encoder-Decoder)

  • 简介:使用Encoder将input编码为一个固定长度的context向量,使用Decoder将context解码为output。input、output长度不一定相同。
  • 奠基论文:Sequence to Sequence Learning with Neural Networks
    Attention的汇总与辨析_Additive、Multiplication、Scaled dot-product、Self Attention、Multi-head Self-Attention (https://mushiming.com/)  第1张

1.2 Seq2Seq(Decoder过程)

  • 解码过程:根据Encoder获得上下文向量context,在以后的每个解码的时间步,都要用到这个统一的context、上一时间步的输出、当前时间步的隐藏状态,这3个信息作为输入计算当前时间步的输出。
    与Attention机制辨析:普通Seq2Seq的context计算是前面所有输入的编码,也就是说在Decoder过程中,每一个时间步的context都是相同的,也就是所有输入部分的attention是相同的。引入Attention机制后,Decoder过程中每一个时间步的context都会不一样了,会关注到输入序列的不同部分。
  • 解码方式:贪心法、beam search(k=1时退化为贪心法)。

1.3 Seq2Seq存在问题

  • 忽略了输入序列的长度:当输入序列很长时,模型能力急剧下降。
  • 缺少输入序列的区分度:编码成固定上下文向量context时,句子中每个词都赋予相同的权重,没有区分度。

1.4 Seq2Seq改进

  • 长序列的问题:随着序列增长,句子越前面的信息丢失就越严重。比如一个源序列有100个词,解码时目标序列的第一个词很大概率上就和源序列的第一个词相对应,这就意味着第一步的解码要考虑到100步之前的信息。
  • 改进:逆序输入(双向序列)、重复序列输入、LSTM/GRU、Attention机制

二、Attention机制

2.1 Attention机制作用

  • 帮助模型对输入序列的每部分赋予不同的权重,也就是区分输入的不同部分对输出的影响。

2.2 Attention计算过程

  • 参考文章:遍地开花的 Attention ,你真的懂吗?
  • 可以把Attention计算过程抽象为以下3个步骤:
    1、计算相似性 score function:打分函数,度量输入序列隐藏状态与当前解码时间步的隐藏状态的相似性
    2、归一化 alignment function:对齐函数(通常使用softmax归一化),得到当前解码时间步的输出对应输入序列隐藏状态的权重,总和为1
    3、得到Attention context function:输入的隐藏状态与当前时刻的权重加权求和,得到当前解码时间步的context
    Attention的汇总与辨析_Additive、Multiplication、Scaled dot-product、Self Attention、Multi-head Self-Attention (https://mushiming.com/)  第2张

三、Attention种类

3.1 Bahdanau Attention/Additive Attention

由上图所示,源序列长度为 N N N t t t针对的是目标序列解码过程中的每个时间步。
1、score function: e t , i = a ( s t − 1 , h i ) = v a T t a n h ( W a ∗ s t − 1 + U a ∗ h i ) e_{t,i}=a(s_{t-1,h_i})=v_a^Ttanh(W_a*s_{t-1}+U_a*h_i) et,i=a(st1,hi)=vaTtanh(Wast1+Uahi),计算相似程度
2、alignment function: α t , i = e x p ( e t , i ) ∑ k = 1 N e x p ( e t , k ) \alpha_{t,i}=\frac{exp(e_{t,i})}{\sum_{k=1}^{N}{exp(e_{t,k})}} αt,i=k=1Nexp(et,k)exp(et,i)
3、context function: c t = ∑ i N α t , i h i c_t=\sum_i^N\alpha_{t,i}h_i ct=iNαt,ihi
这里可以和传统seq2seq架构进行对比,如果不引入Attention机制,则 y t y_t yt的条件概率公式: p ( y t ∣ y 1 , . . . , y t − 1 , c ) = g ( y t − 1 , s t , c ) p(y_t|y_1,...,y_{t-1},c)=g(y_{t-1},s_t,c) p(yty1,...,yt1,c)=g(yt1,st,c)引入Attention之后: p ( y t ∣ y 1 , . . . , y t − 1 , X ) = g ( y t − 1 , s t , c t ) p(y_t|y_1,...,y_{t-1},X)=g(y_{t-1},s_t,c_t) p(yty1,...,yt1,X)=g(yt1,st,ct)也就是解码过程中每个时间步的context都不一样。

3.2 Loung Attention/Multiplication Attention

由上图所示,源序列长度为 N N N t t t针对的是目标序列解码过程中的每个时间步。
1、score function: e t , i = a ( s t − 1 , h i ) = s t − 1 T W a h i e_{t,i}=a(s_{t-1,h_i})=s_{t-1}^TW_ah_i et,i=a(st1,hi)=st1TWahi
2、alignment function: α t , i = e x p ( e t , i ) ∑ k = 1 N e x p ( e t , k ) \alpha_{t,i}=\frac{exp(e_{t,i})}{\sum_{k=1}^{N}{exp(e_{t,k})}} αt,i=k=1Nexp(et,k)exp(et,i)
3、context function: c t = ∑ i N α t , i h i c_t=\sum_i^N\alpha_{t,i}h_i ct=iNαt,ihi
加性注意力和乘法注意力在复杂度上是相似的,但是乘法注意力在实践中往往要更快速、具有更高效的存储,因为它可以使用矩阵操作更高效地实现。

3.3 Scaled dot-product Attention

Transformer采用这种方式构建Self-Attention模块。KV代表Key-Value,不同的Key对应不同的Value。Q为Query向量,Query与所有的Key度量一个相似性,找到和它最相似的Key,仿照这个Key对应的Value,产生这个Query对应的Value。
Attention的汇总与辨析_Additive、Multiplication、Scaled dot-product、Self Attention、Multi-head Self-Attention (https://mushiming.com/)  第3张Scaled dot-product Attention计算公式: s o f t m a x ( Q K T i n _ d i m ) V softmax(\frac{QK^T}{\sqrt {in\_dim}})V softmax(in_dim
QKT
)V
也可以类似地拆成以下三步:
先要明确一下源序列为 x 1 , x 2 , . . . , x m x_1,x_2,...,x_m x1,x2,...,xm,而目标序列为 y 1 , y 2 , . . . , y t , . . . , y n y_1,y_2,...,y_t,...,y_n y1,y2,...,yt,...,yn
其中 Q = ( q 1 , q 2 , . . . , q m ) , 1 ≤ i ≤ m , 1 ≤ j ≤ n Q=(q_1,q_2,...,q_m),1≤i≤m,1≤j≤n Q=(q1,q2,...,qm),1im,1jn
1、score function: e i , j = a ( q i , k j ) e_{i,j}=a(q_i,k_j) ei,j=a(qi,kj)
2、alignment function: α i , j = e x p ( e i , j ) ∑ k = 1 n e x p ( e i , k ) \alpha_{i,j}=\frac{exp(e_{i,j})}{\sum_{k=1}^{n}{exp(e_{i,k})}} αi,j=k=1nexp(ei,k)exp(ei,j)
3、context function: c i = ∑ k = 1 n α i , k v k c_i=\sum_{k=1}^n\alpha_{i,k}v_k ci=k=1nαi,kvk
首先宏观上他也是计算相似性、归一化、Attention计算这三个步骤。
但和前面传统Attention计算有一些不同,但也只是计算的先后顺序不同,宏观上都是一回事。

  • 传统Attention:固定 y t y_t yt,将它与源序列所有 x 1 , x 2 , . . . , x m x_1,x_2,...,x_m x1,x2,...,xm做相似性计算,这里用的是对应的隐藏状态。换句话说,这里你要计算多少个 c t c_t ct?是不是你的目标序列有多少个解码的时间步?理想情况如果你都翻译对了,那你需要 n n n个时间步的上下文向量。
  • Scaled dot-product Attention:固定了 x i x_i xi,将它与 y 1 , y 2 , . . . , y n y_1,y_2,...,y_n y1,y2,...,yn做相似性计算,这里用到的是 Q K V QKV QKV矩阵。换句话说,这里你计算出来了多少个 c t c_t ct?是不是源序列有多长,你就计算出了几个 c t c_t ct?也就是 m m m个上下文向量。

3.4 Self-Attention

Self-Attention就是Scaled dot-product Attention。区别就是此时的源序列就是目标序列
Attention的汇总与辨析_Additive、Multiplication、Scaled dot-product、Self Attention、Multi-head Self-Attention (https://mushiming.com/)  第4张
再次明确一下源序列为 x 1 , x 2 , . . . , x m x_1,x_2,...,x_m x1,x2,...,xm,而目标序列为 y 1 , y 2 , . . . , y t , . . . , y n y_1,y_2,...,y_t,...,y_n y1,y2,...,yt,...,yn

  • Scaled dot-product Attention:源序列 x 1 , x 2 , . . . , x m x_1,x_2,...,x_m x1,x2,...,xm那么对应的Q矩阵就有 q 1 , q 2 , . . . , q m q_1,q_2,...,q_m q1,q2,...,qm。目标序列 y 1 , y 2 , . . . , y n y_1,y_2,...,y_n y1,y2,...,yn那么对应的K矩阵、V矩阵长度都是相同的,也就有 k 1 , k 2 , . . . , k n k_1,k_2,...,k_n k1,k2,...,kn以及 v 1 , v 2 , . . . , v n v_1,v_2,...,v_n v1,v2,...,vn
  • Self-Attention:序列 x 1 , x 2 , . . . , x m x_1,x_2,...,x_m x1,x2,...,xm既是源序列又是目标序列,因为是自己和自己做Attention的计算。那么对应的Q矩阵、K矩阵、V矩阵长度都是相同的,也就有 q 1 , q 2 , . . . , q m q_1,q_2,...,q_m q1,q2,...,qm k 1 , k 2 , . . . , k m k_1,k_2,...,k_m k1,k2,...,km v 1 , v 2 , . . . , v m v_1,v_2,...,v_m v1,v2,...,vm
  • 二者区别:Scaled dot-product Attention、Self-Attention辨析

3.5 Multi-head Self-Attention

就是将Self-Attention中配套的QKV映射到了不同的子空间当中。假设head=2,计算时每个head之间的数据独立,不产生交叉。最终 b i b_i bi的输出需要一个权重矩阵 W O W^O WO模型学习得到。
Attention的汇总与辨析_Additive、Multiplication、Scaled dot-product、Self Attention、Multi-head Self-Attention (https://mushiming.com/)  第5张当head=h时,最终Output由h个head产生的 O u t p u t i Output_i Outputi与权重矩阵相乘得到:
Attention的汇总与辨析_Additive、Multiplication、Scaled dot-product、Self Attention、Multi-head Self-Attention (https://mushiming.com/)  第6张

  • 二者辨析:Self-Attention原理、Multi-head Self-Attention原理及Pytorch实现
THE END

发表回复