Position-wise FFN

Transformer的每个子层的注意力机制之间都包含一个前馈神经网络

image-20240805142347660

对序列中的所有位置的表示进行变换时使用的是同一个多层感知机(MLP),因此该前馈网络是“Position-wise”

并且使用Relu作为激活函数 \[ FFN(x)=max(0,xW_1+b_1)W_2+b_2 \]

实现PositionWiseFFN类, 输入X通过三层的MLP计算得到输入(python默认对最后一个维度进行推理)

1
2
3
4
5
6
7
8
9
10
11
class PositionWiseFFN(nn.Module):
"""基于位置的前馈网络"""
def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
**kwargs):
super(PositionWiseFFN, self).__init__(**kwargs)
self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
self.relu = nn.ReLU()
self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

def forward(self, X):
return self.dense2(self.relu(self.dense1(X)))

输入输出的形状分别为:

X: (batch_size, num_steps, num_hidden)

Output: (batch_size, num_steps, ffn_num_outputs)

1
2
3
ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
ffn(torch.ones((2, 3, 4))).shape # torch.Size([2, 3, 8])

Transformer的一个优点就是在一次Attention计算中,注意力机制会同时考虑到所有位置

RNN在推理的过程中利用隐藏的状态存储信息,这样会导致最后一个token计算时想要参考第一个token时需要通过n次推理

因此Transformer相较于传统的RNN网络提取信息的能力更强!

image-20240805140934105