【Transformer】Encoder

The Transformer encoder is a stack of \(n\) EncoderBlocks, and each EncoderBlock consists of two sublayers:

  • Multi-Head Attention (Self-Attention)
    • Add (residual connection)
    • Norm (layer normalization)
  • Position-wise Feed-Forward Network
    • Add (residual connection)
    • Norm (layer normalization)

The output of \(EncoderBlock_i\) is passed on as the input of \(EncoderBlock_{i+1}\) \((i = 1, 2, \dots, n-1)\).

The output of \(EncoderBlock_n\) is fed to every layer of the decoder.

The structure of an \(EncoderBlock\) is shown below.

[Figure: structure of an EncoderBlock]
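
In formulas, writing the input of a block as \(X\), each EncoderBlock computes (the dropout inside AddNorm is omitted for clarity):

\[
\begin{aligned}
Y &= \mathrm{LayerNorm}\bigl(X + \mathrm{MultiHeadAttention}(X, X, X)\bigr),\\
\mathrm{EncoderBlock}(X) &= \mathrm{LayerNorm}\bigl(Y + \mathrm{FFN}(Y)\bigr),
\end{aligned}
\]

which is exactly the forward computation implemented below.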

MultiHeadAttention, AddNorm, and PositionWiseFFN have already been implemented in earlier sections, so the first step here is to implement EncoderBlock.

It contains two sublayers: multi-head attention and the position-wise feed-forward network.

A residual connection followed by layer normalization is applied around each of the two sublayers.

import math
import torch
from torch import nn
from d2l import torch as d2l

# MultiHeadAttention, AddNorm and PositionWiseFFN are the modules
# implemented in the earlier sections.
class EncoderBlock(nn.Module):
    """Transformer encoder block."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        # Sublayer 1: multi-head self-attention
        self.attention = MultiHeadAttention(key_size, query_size, value_size,
                                            num_hiddens, num_heads, dropout,
                                            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # Sublayer 2: position-wise feed-forward network
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # Self-attention: queries, keys and values all come from X
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

As the following check shows, an EncoderBlock does not change the shape of its input: the residual connections require each sublayer's output to have the same shape as its input.

X = torch.ones((2, 100, 24))           # (batch_size, num_steps, num_hiddens)
valid_lens = torch.tensor([3, 2])      # number of valid tokens in each sequence
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape       # torch.Size([2, 100, 24])

Stacking n instances of the EncoderBlock class yields the Transformer encoder.

Since the fixed positional encoding used here takes values in \([-1, 1]\), the learned input embeddings are first rescaled by multiplying them by the square root of the embedding dimension before being added to the positional encoding.
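
A minimal sketch of just this rescaling step, using the same num_hiddens = 24 as in the running example (the embedding layer and token tensor below are illustrative placeholders):

import math
import torch
from torch import nn

num_hiddens = 24
embedding = nn.Embedding(200, num_hiddens)        # vocab_size = 200, as below
tokens = torch.ones((2, 100), dtype=torch.long)   # (batch_size, num_steps)
emb = embedding(tokens)                           # entries initialized roughly N(0, 1)
scaled = emb * math.sqrt(num_hiddens)             # rescale by sqrt(24) ≈ 4.9
# `scaled` is what gets added to the positional encoding, whose values lie in [-1, 1]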


class TransformerEncoder(d2l.Encoder):
    """Transformer encoder."""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # Since the positional encoding values are in [-1, 1], the embeddings
        # are rescaled by the square root of the embedding dimension before
        # the positional encoding is added.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # Keep each block's attention weights for later inspection
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
# valid_lens is reused from the EncoderBlock example above
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape  # torch.Size([2, 100, 24])
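
After a forward pass, the attention weights of every block are available in encoder.attention_weights. A quick check (the shape noted below is an expectation based on the d2l MultiHeadAttention implementation, which folds the heads into the batch dimension):

_ = encoder(torch.ones((2, 100), dtype=torch.long), valid_lens)
len(encoder.attention_weights)        # 2, one entry per EncoderBlock
encoder.attention_weights[0].shape    # expected: torch.Size([16, 100, 100]),
                                      # i.e. (batch_size * num_heads, num_queries, num_keys)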