注意力机制简单来说就是已知一个Query,以及一些Key-Value对
通过计算Query与每个Key之间的关联度,并通过Softmax计算得到每个Key对应Value的权重
利用权重计算Value的加权和作为Query的输出
如下图,Q1与K2之间的关联度比较高,从而对应V2在计算输出时会给出更大的权重。
与K5关联度比较小,从而V5的权重也会小一些
image-20240805113648793
Masked Softmax Operation
注意力机制通过计算关联度的Softmax值得到一个概率分布作为注意力权重。
但是在Transformer训练过程时,Decoder可以获得对应的所有输出,例如推理“I
love you”中的love词时不应该将"love
you"放入到注意力的计算中。因此在某些情况下,并非所有的值都应该被放入到注意力计算中。
首先定义sequence_mask函数,传入2Dtensor以及对应valid_len(1D)将超过指定范围的元素设置为指定的value
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| def sequence_mask(X, valid_len, value=0): maxlen = X.size(1) mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] >= valid_len[:, None] X[mask] = value return X
sequence_mask(torch.rand(2, 5), torch.tensor([2, 3])) """ tensor([[0.4102, 0.1974, 0.0000, 0.0000, 0.0000], [0.6009, 0.8539, 0.3014, 0.0000, 0.0000]]) """
|
定义masked_softmax函数,完成掩码以及Softmax操作
通过指定一个有效序列长度,
在计算Softmax时去掉超出指定范围的值(权重置为零)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| def masked_softmax(X, valid_lens): if valid_lens is None: return nn.functional.softmax(X, dim=-1) else: shape = X.shape if valid_lens.dim() == 1: valid_lens = torch.repeat_interleave(valid_lens, shape[1]) else: valid_lens = valid_lens.reshape(-1) X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) return nn.functional.softmax(X.reshape(shape), dim=-1)
|
例如定义两个2×4矩阵表示的样本,这两个样本的有效长度分别为2和3。
经过掩蔽softmax操作,超出有效长度的值都被掩蔽为0。
1 2 3 4 5 6 7 8
| masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])) """ tensor([[[0.3718, 0.6282, 0.0000, 0.0000], [0.3768, 0.6232, 0.0000, 0.0000]],
[[0.3184, 0.2974, 0.3841, 0.0000], [0.3008, 0.3684, 0.3308, 0.0000]]]) """
|
同样也可以使用二维tensor,为矩阵样本中的每一行指定有效长度。
1 2 3 4 5 6 7 8
| masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])) """ tensor([[[1.0000, 0.0000, 0.0000, 0.0000], [0.5771, 0.4229, 0.0000, 0.0000]],
[[0.2957, 0.4654, 0.2389, 0.0000], [0.2381, 0.2024, 0.2127, 0.3467]]]) """
|
Scaled Dot-Product Attention
Transformer中所采用的是Scaled Dot-Product
Attention作为注意力评分函数
要求query和key相同的维度\(d_k\),
value的维度可以和key不同\(d_v\)
计算query与每个key之间的点乘,从几何的角度来看就是key在query上的投影,作为query与该key之间的相似程度
内积相当于计算余弦值,越大表示相似度越高,如果两个向量正交余弦值=0,表示不相关
image-20240805114456973
假设查询和键的所有元素都是独立的随机变量,
并且都满足零均值和单位方差, 那么两个向量的点积的均值为0,方差为𝑑。
为确保无论向量长度如何, 点积的方差在不考虑向量长度的情况下仍然是1,
再将点积除以𝑑(即“缩放”)
因此计算query与每个key之间的点乘后,除以\(\sqrt{d_k}\)(为了避免exp后方差过大,归一化操作)
再通过Softmax函数计算得到权重,最后乘上value得到注意力分数
一般情况是多个query和多个key-value对计算注意力分数,为了并行计算通常采用矩阵的乘法
\[
Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V
\] 计算的流程图如下
image-20240805115604743
从矩阵的角度上看,Q矩阵首先和K的转置相乘,然后每个元素除以\(\sqrt{d_k}\),使得方差变为1
在Decoder中可能会使用到Mask的操作,将Mask范围以外的元素置为一个非常小的数
再通过Softmax计算权重(负无穷的exp为0)
乘上V矩阵得到注意力分数
image-20240805120913349
实现DotProductAttention类完成点乘注意力分数计算。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| class DotProductAttention(nn.Module): def __init__(self, dropout, **kwargs): super(DotProductAttention, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens=None): """ :param queries: [batch_size, num_queries, d_query] :param keys: [batch_size, num_kv_pairs, d_key] # d_query = d_key :param values: [batch_size, num_kv_pairs, d_value] :param valid_lens: [batch_size] or [batch_size, num_query] :return: [batch_size, num_queries, d_value] """ d = queries.shape[-1] scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d) self.attention_weights = masked_softmax(scores, valid_lens) return torch.bmm(self.dropout(self.attention_weights), values)
|
从测试数据中可以看出输入输出的维度分别为:
Q: (batch_size, num_queries, d_key)
K: (batch_size, num_kv_qairs, d_key)
V: (batch_size, num_kv_qairs, d_value)
Ouput: (batch_size, num_queries, d_value)
1 2 3 4 5 6 7 8
| batch_size, num_q, num_kv,d_k, d_v = 2, 3, 10, 5, 6 queries = torch.rand((batch_size, num_q, d_k)) keys = torch.rand((batch_size, num_kv, d_k)) values = torch.rand((batch_size, num_kv, d_v)) valid_lens = torch.tensor([2, 6]) attention = DotProductAttention(dropout=0.5) attention.eval() attention(queries, keys, values, valid_lens).shape
|
Multi-Head Attention
在给定一个的查询、键和值时,如果仅仅考虑使用一个注意力函数计算往往会表现不太好
参考CNN网络中的多通道,使用多个卷积核识别不同的图像模式以提取多个维度的特征
因此Transformer希望使用相同的注意力机制学习到不同的特征,
然后将不同的特征组合起来以获得更多的序列信息
在前面的点乘注意力机制中没有可以学习的参数,而线性层中有可以学习的参数
Transformer通过独立学习得到的ℎ组不同的Linear层,将Q,K,V投影到低维空间
将h组的Q,K,V分别计算注意力(并行实现)
最后将h个注意力的输出拼接到一起,再通过一个线性层(可学习)产生最终输出。
image-20240805124645884
多头注意力可以使模型从不同的表达空间中获得特征,并组合在一起,相较于普通的单头注意力机制有更好的特征提取能力
\[
MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^o
\]
\[
where\ head_i=Attention(QW^Q_i,KW^K_i,VW^V_i)
\]
其中\(W^Q_i∈R^{d_{model}×d_k},\
W^K_i∈R^{d_{model}×d_k},\ W^V_i∈R^{d_{model}\
×d_v},W^O∈R^{hd_v×d_{model}}\)
MultiHeadAttention类实现中考虑到涉及多个注意力头的计算,为了并行实现提高效率,因此将tensor进行转换,将head的维度放到batch的维度,这样可以一次实现对所有注意力头的计算。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
| class MultiHeadAttention(nn.Module): def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs): super(MultiHeadAttention, self).__init__(**kwargs) self.num_heads = num_heads self.attention = DotProductAttention(dropout) self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens): """ :param queries: [batch_size, num_queries, d_query] :param keys: [batch_size, num_kv_pairs, d_key] # d_query = d_key :param values: [batch_size, num_kv_pairs, d_value] :param valid_lens: [batch_size] or [batch_size, num_query] :return: [batch_size, num_queries, num_hiddens] """
queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None: valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
output = self.attention(queries, keys, values, valid_lens)
output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat)
def transpose_qkv(X, num_heads): """通过变换矩阵形状实现多头注意力的并行计算""" X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
X = X.permute(0, 2, 1, 3)
return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads): """逆转transpose_qkv操作""" X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) X = X.permute(0, 2, 1, 3) return X.reshape(X.shape[0], X.shape[1], -1)
|
可以看出多头注意力的输入输出维度:
Q: (batch_size, num_queries, d_key)
K: (batch_size, num_kv_qairs, d_key)
V: (batch_size, num_kv_qairs, d_value)
Ouput: (batch_size, num_queries, num_hiddens)
1 2 3 4 5 6 7 8 9
| num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) attention.eval() batch_size, num_queries = 2, 4 num_kvpairs, valid_lens = 6, torch.tensor([3, 2]) X = torch.ones((batch_size, num_queries, num_hiddens)) Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens).shape
|