The Transformer encoder is a stack of \(n\) EncoderBlocks. Each EncoderBlock has two sublayers:
- Multi-Head Attention (Self-Attention)
- Positionwise Feed Forward Network
The output of \(EncoderBlock_i\) is passed on as the input of \(EncoderBlock_{i+1}\) \((i=1,2,\dots,n-1)\).
The output of \(EncoderBlock_{n}\) is fed to every layer of the decoder.
The structure of \(EncoderBlock_{n}\) is shown below.
(Figure: structure of an EncoderBlock)
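In symbols, with \(X\) denoting the block input, each EncoderBlock computes

\[
\begin{aligned}
Y &= \mathrm{LayerNorm}\big(X + \mathrm{Dropout}(\mathrm{MultiHeadAttention}(X, X, X))\big),\\
\mathrm{EncoderBlock}(X) &= \mathrm{LayerNorm}\big(Y + \mathrm{Dropout}(\mathrm{FFN}(Y))\big),
\end{aligned}
\]

where the Dropout comes from the AddNorm modules in the code below; both sublayers preserve the shape \((\text{batch\_size}, \text{num\_steps}, \text{num\_hiddens})\) of their input.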
MultiHeadAttention, AddNorm, and PositionWiseFFN were implemented in earlier sections, so here we start by implementing EncoderBlock.
It contains two sublayers: multi-head self-attention and a positionwise feed-forward network.
A residual connection followed by layer normalization is applied around each of the two sublayers.
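For reference, here is a minimal sketch of AddNorm and PositionWiseFFN following the classic d2l (0.17.x) definitions, together with the imports assumed throughout this section. MultiHeadAttention is assumed to be the implementation from the earlier attention section; the alias from d2l below is an assumption that matches the 0.17.x API.

```python
import math
import torch
from torch import nn
from d2l import torch as d2l

# Assumption: d2l 0.17.x exposes MultiHeadAttention with the
# (key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias)
# signature used below; otherwise use the class defined in the earlier section.
MultiHeadAttention = d2l.MultiHeadAttention

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # X: sublayer input (residual branch); Y: sublayer output
        return self.ln(self.dropout(Y) + X)

class PositionWiseFFN(nn.Module):
    """Positionwise feed-forward network"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        # Applied to every position independently
        return self.dense2(self.relu(self.dense1(X)))
```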
```python
class EncoderBlock(nn.Module):
    """Transformer encoder block"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        # Sublayer 1: multi-head self-attention
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads,
            dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # Sublayer 2: positionwise feed-forward network
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # Residual connection + layer normalization around each sublayer
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
```
As the following check shows, an EncoderBlock does not change the shape of its input.
```python
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape  # torch.Size([2, 100, 24]), same as the input
```
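Because valid_lens is [3, 2], the self-attention inside the block attends only to the first 3 (respectively 2) positions of each sequence. A quick way to see the masking (this relies on the attention weights cached by the underlying scaled dot-product attention, the same attribute chain the TransformerEncoder below uses):

```python
# Shape: (batch_size * num_heads, num_queries, num_keys) = (16, 100, 100)
attn = encoder_blk.attention.attention.attention_weights
print(attn.shape)
# First sequence (valid_len=3): weights beyond key index 2 are zero
print(attn[0, 0, :6])
```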
Stacking \(n\) EncoderBlock instances (num_layers in the code) implements the Transformer encoder.
Because the fixed positional encodings take values between \(-1\) and \(1\), the learned input embeddings are first rescaled by multiplying them by the square root of the embedding dimension before they are added to the positional encodings.
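Concretely, with \(d =\) num_hiddens, \(P\) the fixed positional-encoding matrix, and the Dropout coming from the PositionalEncoding module, the input fed to the first block is

\[
\mathrm{Dropout}\big(\sqrt{d}\cdot\mathrm{Embedding}(X) + P\big).
\]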
```python
class TransformerEncoder(d2l.Encoder):
    """Transformer encoder"""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # Rescale the embeddings by sqrt(num_hiddens) before adding the
        # positional encodings (which take values in [-1, 1]).
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # Cache each block's self-attention weights for later inspection
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
```
```python
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape
# torch.Size([2, 100, 24]): (batch_size, num_steps, num_hiddens)
```
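The encoder also keeps every block's self-attention weights in encoder.attention_weights, which is convenient for visualization later. A quick check with the toy configuration above:

```python
out = encoder(torch.ones((2, 100), dtype=torch.long), valid_lens)
print(len(encoder.attention_weights))      # 2, one entry per EncoderBlock
# (batch_size * num_heads, num_queries, num_keys) = (16, 100, 100)
print(encoder.attention_weights[0].shape)
```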