【Transformer】Model Summary

Model Architecture

The model adopts an encoder-decoder architecture.

The encoder maps an input sequence \((x_1, \dots, x_n)\) to a sequence of continuous intermediate representations \((z_1, \dots, z_n)\).

The decoder then generates the output sequence \((y_1, \dots, y_m)\) autoregressively.

(Each previously generated symbol is appended to the input when generating the next one, similar to an RNN.)
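To make the autoregressive procedure concrete, here is a minimal greedy-decoding sketch; it assumes a trained d2l-style EncoderDecoder net and a target vocabulary tgt_vocab like the ones built later in these notes (an illustration, not the d2l predict_seq2seq implementation).

import torch

def greedy_decode(net, src_tokens, tgt_vocab, num_steps, device):
    """src_tokens: list of source token indices. Returns predicted target token indices."""
    net.eval()                                            # disable dropout; the decoder caches keys/values in eval mode
    enc_X = torch.tensor([src_tokens], device=device)     # (1, src_len)
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    state = net.decoder.init_state(enc_outputs, enc_valid_len)
    dec_X = torch.tensor([[tgt_vocab['<bos>']]], device=device)   # start with <bos>
    output = []
    for _ in range(num_steps):
        Y, state = net.decoder(dec_X, state)              # Y: (1, 1, vocab_size)
        dec_X = Y.argmax(dim=2)                           # feed the prediction back in as the next input
        token = int(dec_X.squeeze())
        if token == tgt_vocab['<eos>']:
            break
        output.append(token)
    return output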


Decoder

The Transformer decoder is a stack of N DecoderBlocks; each DecoderBlock contains three sub-layers:

  • Masked Multi-Head Attention (self-attention)
    • Add (residual connection)
    • Norm (layer normalization)
  • Multi-Head Attention (encoder-decoder attention, i.e. co-attention)
    • Add (residual connection)
    • Norm (layer normalization)
  • Position-wise Feed-Forward Network
    • Add (residual connection)
    • Norm (layer normalization)

The output of \(DecoderBlock_i\) is fed to the next layer as the input of \(DecoderBlock_{i+1}\) \((i = 1, 2, \dots, N-1)\).

Self-Attention Mechanism

The model's input serves simultaneously as the Key, the Value, and the Query.

Each position also computes similarity with itself (vector inner product), which is the largest value (equal to 1 for normalized vectors, i.e. cosine similarity).
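A minimal scaled dot-product attention sketch (illustrative, not the d2l implementation); self-attention simply passes the same tensor as Query, Key, and Value:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Q: (batch, q_len, d); K, V: (batch, k_len, d)."""
    d = Q.shape[-1]
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d)   # (batch, q_len, k_len)
    weights = F.softmax(scores, dim=-1)                       # each row sums to 1
    return torch.bmm(weights, V)                              # (batch, q_len, d)

X = torch.randn(2, 5, 16)                        # a toy input sequence
out = scaled_dot_product_attention(X, X, X)      # self-attention: Q = K = V = X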


A mask is applied in the decoder's self-attention computation.

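A minimal sketch of such a causal mask applied to raw attention scores (illustrative; the d2l code later in these notes instead passes dec_valid_lens into its masked softmax):

import torch
import torch.nn.functional as F

def causal_mask(scores):
    """scores: (batch, seq_len, seq_len). Forbid position t from attending to positions > t."""
    seq_len = scores.shape[-1]
    mask = torch.triu(torch.ones(seq_len, seq_len, device=scores.device), diagonal=1).bool()
    return scores.masked_fill(mask, float('-inf'))   # -inf becomes 0 after softmax

scores = torch.randn(2, 4, 4)
weights = F.softmax(causal_mask(scores), dim=-1)     # upper-triangular weights are 0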

In the decoder's second sub-layer, the encoder's output is used as the Key and Value, while the output of the previous sub-layer is used as the Query.

This is how information is passed between the encoder and the decoder.
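Reusing the illustrative scaled_dot_product_attention sketch above, encoder-decoder (cross) attention only changes where the queries, keys, and values come from:

# Cross-attention sketch: queries from the decoder's previous sub-layer, keys/values from the encoder output.
enc_outputs = torch.randn(2, 7, 16)     # (batch, src_len, d) from the encoder
dec_hidden = torch.randn(2, 5, 16)      # (batch, tgt_len, d) from the decoder's first sub-layer
cross = scaled_dot_product_attention(dec_hidden, enc_outputs, enc_outputs)   # (2, 5, 16)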


Encoder

The encoder uses a stack of N = 6 identical layers, each containing two sub-layers:

  • multi-head self-attention mechanism
  • position-wise fully connected feed-forward network (essentially a per-position fully connected network / MLP; see the sketch after this list)
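PositionWiseFFN is used by the code below but never defined in these notes; a minimal sketch consistent with how it is called (the same two-layer MLP applied independently at every position) could be:

import torch
from torch import nn

class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: the same two-layer MLP applied at every position."""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        # X: (batch_size, num_steps, ffn_num_input) -> (batch_size, num_steps, ffn_num_outputs)
        return self.dense2(self.relu(self.dense1(X)))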

A residual connection is applied around each of the two sub-layers,

followed by layer normalization: \(LayerNorm(x + Sublayer(x))\).

The output of every layer is fixed to dimension \(d_{model} = 512\).
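Likewise, AddNorm is used below without a definition; a minimal sketch implementing \(LayerNorm(x + Sublayer(x))\), with dropout applied to the sub-layer output, is:

import torch
from torch import nn

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization: LayerNorm(X + Sublayer(X))."""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # Y is the sub-layer output; dropout is applied to it before the residual addition.
        return self.ln(self.dropout(Y) + X)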

import math
import torch
from torch import nn
from d2l import torch as d2l

class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # Self-attention (Q = K = V = X) with Add & Norm, then the position-wise FFN with Add & Norm.
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

The encoder stacks num_layers instances of the EncoderBlock class.

class TransformerEncoder(d2l.Encoder):
    """Transformer encoder."""

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # The positional encoding values lie in [-1, 1], so the embeddings are
        # scaled by the square root of the embedding dimension before the
        # positional encodings are added.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[
                i] = blk.attention.attention.attention_weights
        return X
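A quick shape check (the hyperparameter values below are illustrative, following the small configuration in the d2l chapter): the encoder turns a batch of token indices into a tensor of shape (batch_size, num_steps, num_hiddens).

encoder = TransformerEncoder(
    200, key_size=24, query_size=24, value_size=24, num_hiddens=24,
    norm_shape=[24], ffn_num_input=24, ffn_num_hiddens=48, num_heads=8,
    num_layers=2, dropout=0.5)
encoder.eval()
X = torch.ones((2, 100), dtype=torch.long)   # (batch_size, num_steps) of token indices
valid_lens = torch.tensor([3, 2])            # number of non-padding tokens in each sequence
print(encoder(X, valid_lens).shape)          # expected: torch.Size([2, 100, 24])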

Decoder

The decoder likewise uses a stack of N = 6 identical layers.

Compared with the encoder, each layer has an additional (third) sub-layer: masked multi-head attention.

(The mask controls which part of the input sequence is visible: when predicting time step t, the model must not see inputs after time step t.)

Masked Multi-head Attention

Each layer, implemented in the DecoderBlock class, contains three sub-layers:

  1. decoder self-attention
  2. encoder-decoder attention
  3. position-wise feed-forward network

The sub-layers are connected through residual connections and layer normalization.

class DecoderBlock(nn.Module):
    """The i-th block in the decoder."""

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all tokens of the output sequence are processed at the
        # same time, so state[2][self.i] is initialized as None.
        # During prediction, the output sequence is decoded token by token, so
        # state[2][self.i] holds the representations decoded by the i-th block
        # up to the current time step.
        # state = [encoder outputs, mask lengths, per-block decoded representations so far]
        # X.shape = (batch_size, num_steps, d)
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values

        if self.training:
            batch_size, num_steps, _ = X.shape
            # dec_valid_lens: (batch_size, num_steps), where every row is [1, 2, ..., num_steps]
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)

        # Encoder-decoder attention.
        # enc_outputs shape: (batch_size, num_steps, num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)

        return self.addnorm3(Z, self.ffn(Z)), state
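To see how state[2][i] accumulates during prediction, the following illustrative snippet (example hyperparameters, random tensors standing in for real embeddings) runs one DecoderBlock in eval mode for three steps:

# Illustrative: how state[2][i] caches decoded representations at prediction time.
decoder_blk = DecoderBlock(key_size=24, query_size=24, value_size=24, num_hiddens=24,
                           norm_shape=[24], ffn_num_input=24, ffn_num_hiddens=48,
                           num_heads=8, dropout=0.5, i=0)
decoder_blk.eval()
enc_outputs = torch.randn(1, 10, 24)                 # pretend encoder output (batch=1, src_len=10)
state = [enc_outputs, torch.tensor([10]), [None]]    # [enc outputs, enc valid lens, per-block cache]
for t in range(3):
    X = torch.randn(1, 1, 24)                        # one new target-token representation per step
    _, state = decoder_blk(X, state)
    print(state[2][0].shape)                         # (1, 1, 24), then (1, 2, 24), then (1, 3, 24)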

The complete Transformer decoder is built from num_layers DecoderBlock instances.

Finally, a fully connected layer computes the prediction scores for all vocab_size possible output tokens.

class TransformerDecoder(d2l.AttentionDecoder):
    """Transformer decoder."""

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        # Embedding + positional encoding (scaled by sqrt(num_hiddens), as in the encoder)
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
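The hyperparameters (key_size, num_hiddens, ...) and the vocabularies src_vocab / tgt_vocab are not defined in these notes. A minimal setup sketch, assuming the d2l data and training utilities (load_data_nmt, train_seq2seq) and illustrative values similar to the small d2l configuration:

# Illustrative values; adjust as needed.
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
# ...construct encoder, decoder, and net as above, then train:
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)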

Model Analysis

Here n is the sequence length, d is the model dimension \(d_{model}\), k is the convolution kernel size, and r is the size of the restricted attention neighborhood. For example, when n < d (typical for sentence-level tasks), self-attention is cheaper per layer than a recurrent layer (\(O(n^2 d)\) vs. \(O(n d^2)\)) while requiring only \(O(1)\) sequential operations.

| Layer Type | Complexity per Layer | Sequential Operations | Maximum Path Length |
| --- | --- | --- | --- |
| Self-Attention | \(O(n^2 d)\) | \(O(1)\) | \(O(1)\) |
| Recurrent | \(O(n d^2)\) | \(O(n)\) | \(O(n)\) |
| Convolutional | \(O(k n d^2)\) | \(O(1)\) | \(O(\log_k(n))\) |
| Self-Attention (restricted) | \(O(r n d)\) | \(O(1)\) | \(O(n/r)\) |