【详解】Transformer 的框架结构

在这里插入图片描述


《Attention is All You Need》:https://arxiv.org/pdf/1706.03762v5.pdf
Transformer代码:https://github.com/ViatorSun/Backbone/Transformer

Attention 其实就是计算一种相关程度!


1、Encoder

Figure 1 是一个seq2seq的model,左侧为 Encoder block,右侧为 Decoder block。红色圈中的部分为Multi-Head Attention,是由多个Self-Attention组成的,可以看到 Encoder block 包含一个 Multi-Head Attention,而 Decoder block 包含两个 Multi-Head Attention (其中有一个用到 Masked)。Multi-Head Attention 上方还包括一个 Add & Norm 层,Add 表示残差连接 (Residual Connection) 用于防止网络退化,Norm 表示 Layer Normalization,用于对每一层的激活值进行归一化。比如说在Encoder Input处的输入是机器学习,在Decoder Input处的输入是,输出是machine。再下一个时刻在Decoder Input处的输入是machine,输出是learning。不断重复知道输出是句点(.)代表翻译结束。

接下来我们看看这个Encoder和Decoder里面分别都做了什么事情,先看左半部分的Encoder:首先输入 X ∈ R ( n x , N ) X \in R(n_x,N) XR(nx,N) 通过一个Input Embedding的转移矩阵 W x ∈ R ( d , n x ) W^x \in R(d,n_x) WxR(d,nx) 变为了一个张量,即上文所述的 I ∈ R ( d , N ) I\in R(d,N) IR(d,N) ,再加上一个表示位置的Positional Encoding ,得到一个张量,去往后面的操作。

它进入了这个绿色的block,这个绿色的block会重复 N N N 次。这个绿色的block里面有什么呢?它的第1层是一个上文讲的multi-head的attention。你现在一个sequence I ∈ R ( d , N ) I\in R(d,N) IR(d,N),经过一个multi-head的attention,你会得到另外一个sequence O ∈ R ( d , N ) O\in R(d,N) OR(d,N)

下一个Layer是Add & Norm,这个意思是说:把multi-head的attention的layer的输入 I ∈ R ( d , N ) I\in R(d,N) IR(d,N) O ∈ R ( d , N ) O\in R(d,N) OR(d,N) 输出 进行相加以后,再做Layer Normalization,至于Layer Normalization和我们熟悉的Batch Normalization的区别是什么,请参考图20和21。

在这里插入图片描述

其中,Batch Normalization和Layer Normalization的对比可以概括为图20,Batch Normalization强行让一个batch的数据的某个channel的 μ = 0 , σ = 1 \mu =0,\sigma=1 μ=0,σ=1,而Layer Normalization让一个数据的所有channel的 μ = 0 , σ = 1 \mu =0,\sigma=1 μ=0,σ=1

在这里插入图片描述

接着是一个Feed Forward的前馈网络和一个Add & Norm Layer。

所以,这一个绿色的block的前2个Layer操作的表达式为:

O 1 = L a y e r N o r m a l i z a t i o n ( I + M u l t i − h e a d S e l f − A t t e n t i o n ) (6) O_1=Layer Normalization(I+Multi-head Self-Attention) \tag6 O1=LayerNormalization(I+MultiheadSelfAttention)(6)

这一个绿色的block的后2个Layer操作的表达式为:

O 2 = L a y e r N o r m a l i z a t i o n ( O 1 + F e e d F o r w a r d N e t w o r k ( O 1 ) ) (7) O_2=Layer Normalization(O_1+ Feed Forward Network(O_1)) \tag7 O2=LayerNormalization(O1+FeedForwardNetwork(O1))(7)
B l o c k ( I ) = O 2 (8) Block(I)=O_2 \tag8 Block(I)=O2(8)

所以Transformer的Encoder的整体操作为:

O 1 = L a y e r N o r m a l i z a t i o n ( I + M u l t i − h e a d S e l f − A t t e n t i o n ) (9) O_1=Layer Normalization(I+Multi-head Self-Attention) \tag9 O1=LayerNormalization(I+MultiheadSelfAttention)(9)



2、Decoder:

现在来看Decoder的部分,输入包括2部分,下方是前一个time step的输出的embedding,即上文所述的 I ∈ R ( d , N ) I\in R(d,N) IR(d,N),再加上一个表示位置的Positional Encoding E ∈ R ( d , N ) E\in R(d,N) ER(d,N),得到一个张量,去往后面的操作。它进入了这个绿色的block,这个绿色的block会重复 N N N 次。这个绿色的block里面有什么呢?

首先是Masked Multi-Head Self-attention,masked的意思是使attention只会attend on已经产生的sequence,这个很合理,因为还没有产生出来的东西不存在,就无法做attention。

  • 输出是: 对应 i i i 位置的输出词的概率分布。
  • 输入是: E n c o d e r Encoder Encoder 的输出 和 对应 i − 1 i-1 i1 位置decoder的输出。所以中间的attention不是self-attention,它的Key和Value来自encoder,Query来自上一位置 D e c o d e r Decoder Decoder 的输出。
  • 解码: 这里要特别注意一下,编码可以并行计算,一次性全部Encoding出来,但解码不是一次把所有序列解出来的,而是像 R N N RNN RNN 一样一个一个解出来的,因为要用上一个位置的输入当作attention的query。

明确了解码过程之后最上面的图就很好懂了,这里主要的不同就是新加的另外要说一下新加的attention多加了一个mask,因为训练时的output都是Ground Truth,这样可以确保预测第 个位置时不会接触到未来的信息

  1. 包含两个 Multi-Head Attention 层。
  2. 第一个 Multi-Head Attention 层采用了 Masked 操作。
  3. 第二个 Multi-Head Attention 层的Key,Value矩阵使用 Encoder 的编码信息矩阵 C C C 进行计算,而Query使用上一个 Decoder block 的输出计算。
  4. 最后有一个 Softmax 层计算下一个翻译单词的概率。

下面详细介绍下Masked Multi-Head Self-attention的具体操作,Masked在Scale操作之后,softmax操作之前。

在这里插入图片描述

因为在翻译的过程中是顺序翻译的,即翻译完第 i i i 个单词,才可以翻译第 i + 1 i+1 i+1 个单词。通过 Masked 操作可以防止第 i i i 个单词知道第 i + 1 i+1 i+1 个单词之后的信息。下面以 “我有一只猫” 翻译成 “I have a cat” 为例,了解一下 Masked 操作。在 Decoder 的时候,是需要根据之前的翻译,求解当前最有可能的翻译,如下图所示。首先根据输入 “” 预测出第一个单词为 “I”,然后根据输入 " I" 预测下一个单词 “have”。

Decoder 可以在训练的过程中使用 Teacher Forcing 并且并行化训练,即将正确的单词序列 ( I have a cat) 和对应输出 (I have a cat ) 传递到 Decoder。那么在预测第 i i i 个输出时,就要将第 i + 1 i+1 i+1 之后的单词掩盖住,注意 Mask 操作是在 Self-Attention 的 Softmax 之前使用的,下面用 0 1 2 3 4 5 分别表示 " I have a cat "。

在这里插入图片描述

注意这里transformer模型训练和测试的方法不同:

测试时:

  1. 输入,解码器输出 I 。
  2. 输入前面已经解码的和 I,解码器输出have。
  3. 输入已经解码的,I, have, a, cat,解码器输出解码结束标志位,每次解码都会利用前面已经解码输出的所有单词嵌入信息。

Transformer测试时的解码过程:
在这里插入图片描述


训练时:

不采用上述类似RNN的方法 一个一个目标单词嵌入向量顺序输入训练,想采用类似编码器中的矩阵并行算法,一步就把所有目标单词预测出来。要实现这个功能就可以参考编码器的操作,把目标单词嵌入向量组成矩阵一次输入即可。即:并行化训练。

但是在解码have时候,不能利用到后面单词a和cat的目标单词嵌入向量信息,否则这就是作弊(测试时候不可能能未卜先知)。为此引入mask。具体是:在解码器中,self-attention层只被允许处理输出序列中更靠前的那些位置,在softmax步骤前,它会把后面的位置给隐去。


3、Masked Multi-Head Self-attention的具体操作

Step1: 输入矩阵包含 " I have a cat" (0, 1, 2, 3, 4) 五个单词的表示向量,Mask是一个 5×5 的矩阵。在Mask可以发现单词 0 只能使用单词 0 的信息,而单词 1 可以使用单词 0, 1 的信息,即只能使用之前的信息。输入矩阵 X ∈ R N , d x X\in R_{N,d_x} XRN,dx 经过transformation matrix变为3个矩阵:Query Q ∈ R N , d Q\in R_{N,d} QRN,d,Key K ∈ R N , d K\in R_{N,d} KRN,d 和Value V ∈ R N , d V\in R_{N,d} VRN,d

Step2: Q T ⋅ K Q^T \cdot K QTK得到 Attention矩阵 A ∈ R N , N A\in R_{N,N} ARN,N ,此时先不急于做softmax的操作,而是先于一个 M a s k ∈ R N , N Mask \in R_{N,N} MaskRN,N 矩阵相乘,使得attention矩阵的有些位置 归0,得到Masked Attention矩阵 M a s k A t t e n t i o n ∈ R N , N Mask Attention \in R_{N,N} MaskAttentionRN,N M a s k ∈ R N , N Mask \in R_{N,N} MaskRN,N 矩阵是个下三角矩阵,为什么这样设计?是因为想在计算 Z Z Z 矩阵的某一行时,只考虑它前面token的作用。即:在计算 Z Z Z 的第一行时,刻意地把 A t t e n t i o n Attention Attention 矩阵第一行的后面几个元素屏蔽掉,只考虑 A t t e n t i o n 0 , 0 Attention_{0,0} Attention0,0 。在产生have这个单词时,只考虑 I,不考虑之后的have a cat,即只会attend on已经产生的sequence,这个很合理,因为还没有产生出来的东西不存在,就无法做attention。

Step3: Masked Attention矩阵进行 Softmax,每一行的和都为 1。但是单词 0 在单词 1, 2, 3, 4 上的 attention score 都为 0。得到的结果再与 V V V 矩阵相乘得到最终的self-attention层的输出结果 Z 1 ∈ R N , d Z_1\in R_{N,d} Z1RN,d

Step4: Z 1 ∈ R N , d Z_1\in R_{N,d} Z1RN,d 只是某一个head的结果,将多个head的结果concat在一起之后再最后进行Linear Transformation得到最终的Masked Multi-Head Self-attention的输出结果 Z ∈ R N , d Z\in R_{N,d} ZRN,d

在这里插入图片描述

第1个Masked Multi-Head Self-attention的 Q u e r y , K e y , V a l u e Query,Key,Value QueryKeyValue 均来自Output Embedding。

第2个Multi-Head Self-attention的 Q u e r y Query Query 来自第1个Self-attention layer的输出, K e y , V a l u e Key,Value KeyValue 来自Encoder的输出。


为什么这么设计?

这里提供一种个人的理解:

  • K e y , V a l u e Key,Value KeyValue 来自Transformer Encoder的输出,所以可以看做句子(Sequence)/图片(image)的内容信息(content,比如句意是:“我有一只猫”,图片内容是:“有几辆车,几个人等等”)。
  • Q u e r y Query Query 表达了一种诉求:希望得到什么,可以看做引导信息(guide)。

通过Multi-Head Self-attention结合在一起的过程就相当于是把我们需要的内容信息指导表达出来。

Decoder的最后是Softmax 预测输出单词。因为 Mask 的存在,使得单词 0 的输出 Z ( 0 ) Z(0) Z(0), 只包含单词 0 的信息。Softmax 根据输出矩阵的每一行预测下一个单词,如下图25所示。

在这里插入图片描述

ViatorSun CSDN认证博客专家 深度学习 计算机视觉
研究生在读、Github开源世界贡献者,深度学习 & 计算机视觉分享者;
主要研究【深度学习 & 计算机视觉】相关方向,欢迎感兴趣的小伙伴一起交流、探讨~
相关推荐
<p> <span style="font-size:14px;color:#E53333;">限时福利1:</span><span style="font-size:14px;">购课进答疑群专享柳峰(刘运强)老师答疑服务</span> </p> <p> <br /> </p> <p> <br /> </p> <p> <span style="font-size:14px;"></span> </p> <p> <span style="font-size:14px;color:#337FE5;"><strong>为什么需要掌握高性能MySQL实战?</strong></span> </p> <p> <span><span style="font-size:14px;"><br /> </span></span> <span style="font-size:14px;">由于互联网产品用户量大、高并发请求场景多,因此对MySQL性能、可用性、扩展性都提出了很高要求。使用MySQL解决大量数据以及高并发请求已经是程序员必备技能,也是衡量一个程序员能力和薪资标准之一。</span> </p> <p> <br /> </p> <p> <span style="font-size:14px;">为了让大家快速系统了解高性能MySQL核心知识全貌,我为你总结了</span><span style="font-size:14px;">「高性能 MySQL 知识框架图」</span><span style="font-size:14px;">,帮你梳理学习重点,建议收藏!</span> </p> <p> <br /> </p> <p> <img alt="" src="https://img-bss.csdnimg.cn/202006031401338860.png" /> </p> <p> <br /> </p> <p> <span style="font-size:14px;color:#337FE5;"><strong>课程设计</strong></span> </p> <p> <span style="font-size:14px;"><br /> </span> </p> <p> <span style="font-size:14px;">课程分为四大篇章,将为你建立完整 MySQL 知识体系,同时将重点讲解 MySQL 底层运行原理、数据库性能调优、高并发、海量业务处理、面试解析等。</span> </p> <p> <span style="font-size:14px;"><br /> </span> </p> <p> <span style="font-size:14px;"></span> </p> <p style="text-align:justify;"> <span style="font-size:14px;"><strong>一、性能优化篇:</strong></span> </p> <p style="text-align:justify;"> <span style="font-size:14px;">主要包括经典 MySQL 问题剖析、索引底层原理和事务与锁机制。通过深入理解 MySQL 索引结构 B+Tree ,学员能够从根本上弄懂为什么有些 SQL 走索引、有些不走索引,从而彻底掌握索引使用和优化技巧,能够避开很多实战中遇到“坑”。</span> </p> <p style="text-align:justify;"> <br /> </p> <p style="text-align:justify;"> <span style="font-size:14px;"><strong>二、MySQL 8.0新特性篇:</strong></span> </p> <p style="text-align:justify;"> <span style="font-size:14px;">主要包括窗口函数和通用表表达式。企业中许多报表统计需求,如果不采用窗口函数,用普通 SQL 语句是很难实现。</span> </p> <p style="text-align:justify;"> <br /> </p> <p style="text-align:justify;"> <span style="font-size:14px;"><strong>三、高性能架构篇:</strong></span> </p> <p style="text-align:justify;"> <span style="font-size:14px;">主要包括主从复制和读写分离。在企业生产环境中,很少采用单台MySQL节点情况,因为一旦单个节点发生故障,整个系统都不可用,后果往往不堪设想,因此掌握高可用架构实现是非常有必要。</span> </p> <p style="text-align:justify;"> <br /> </p> <p style="text-align:justify;"> <span style="font-size:14px;"><strong>四、面试篇:</strong></span> </p> <p style="text-align:justify;"> <span style="font-size:14px;">程序员获得工作第一步,就是高效准备面试,面试篇主要从知识点回顾总结角度出发,结合程序员面试高频MySQL问题精讲精练,帮助程序员吊打面试官,获得心仪工作机会。</span> </p>
©️2020 CSDN 皮肤主题: 岁月 设计师:pinMode 返回首页