【Transformer】Attention

注意力机制简单来说就是已知一个Query,以及一些Key-Value对

通过计算Query与每个Key之间的关联度,并通过Softmax计算得到每个Key对应Value的权重

利用权重计算Value的加权和作为Query的输出

如下图,Q1与K2之间的关联度比较高,从而对应V2在计算输出时会给出更大的权重。

与K5关联度比较小,从而V5的权重也会小一些

image-20240805113648793

Masked Softmax Operation

注意力机制通过计算关联度的Softmax值得到一个概率分布作为注意力权重。

但是在Transformer训练过程时,Decoder可以获得对应的所有输出,例如推理“I love you”中的love词时不应该将"love you"放入到注意力的计算中。因此在某些情况下,并非所有的值都应该被放入到注意力计算中。

首先定义sequence_mask函数,传入2Dtensor以及对应valid_len(1D)将超过指定范围的元素设置为指定的value

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def sequence_mask(X, valid_len, value=0):
# X: 2D tensor [num_steps, d_model]
# valid_len: 1D tensor [num_step]
maxlen = X.size(1)

# torch.arange(5)[None, :] 在tensor最前面维度上拓展一个维度 tensor([[0, 1, 2, 3, 4]])
# torch.tensor([2, 3])[:, None] 相当于.reshape([-1, 1])
# 广播机制 torch.arange(5)[None, :] < torch.tensor([2, 3]).reshape([-1, 1])
# tensor([[ True, True, False, False, False],
# [ True, True, True, False, False]])

mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] >= valid_len[:, None]
X[mask] = value
return X

sequence_mask(torch.rand(2, 5), torch.tensor([2, 3]))
"""
tensor([[0.4102, 0.1974, 0.0000, 0.0000, 0.0000],
[0.6009, 0.8539, 0.3014, 0.0000, 0.0000]])
"""

定义masked_softmax函数,完成掩码以及Softmax操作

通过指定一个有效序列长度, 在计算Softmax时去掉超出指定范围的值(权重置为零)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def masked_softmax(X, valid_lens):
# X: 3D [batch_size, num_queries, num_kv_pairs]
# valid_lens: 1D [batch_size] tensor([2, 3]) 表示第1、2个batch分别使用2、3的valid_lens
# 2D [batch_size, num_steps] torch.tensor([[1, 3], [2, 4]]) 第1个batch的第1、2句分别使用1、3的valid_len

if valid_lens is None:
return nn.functional.softmax(X, dim=-1) # 在最后一个维度上做softmax
else:
shape = X.shape
if valid_lens.dim() == 1:
# 给一个tensor将其展开n次
# torch.repeat_interleave(torch.tensor([10]), 5) -> tensor([10, 10, 10, 10, 10])
valid_lens = torch.repeat_interleave(valid_lens, shape[1])

else:
# 展开成一个维度
valid_lens = valid_lens.reshape(-1)

# 在最后一个维度将masked元素替换成负无穷,经过exp后变为0
# 合并batch_size和num_steps两个维度再传入sequence_mask处理
X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)

# 转回3D tensor并在最后一个维度做softmax
return nn.functional.softmax(X.reshape(shape), dim=-1)

例如定义两个2×4矩阵表示的样本,这两个样本的有效长度分别为2和3。

经过掩蔽softmax操作,超出有效长度的值都被掩蔽为0。

1
2
3
4
5
6
7
8
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
"""
tensor([[[0.3718, 0.6282, 0.0000, 0.0000],
[0.3768, 0.6232, 0.0000, 0.0000]],

[[0.3184, 0.2974, 0.3841, 0.0000],
[0.3008, 0.3684, 0.3308, 0.0000]]])
"""

同样也可以使用二维tensor,为矩阵样本中的每一行指定有效长度。

1
2
3
4
5
6
7
8
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
"""
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
[0.5771, 0.4229, 0.0000, 0.0000]],

[[0.2957, 0.4654, 0.2389, 0.0000],
[0.2381, 0.2024, 0.2127, 0.3467]]])
"""

Scaled Dot-Product Attention

Transformer中所采用的是Scaled Dot-Product Attention作为注意力评分函数

要求query和key相同的维度\(d_k\), value的维度可以和key不同\(d_v\)

计算query与每个key之间的点乘,从几何的角度来看就是key在query上的投影,作为query与该key之间的相似程度

内积相当于计算余弦值,越大表示相似度越高,如果两个向量正交余弦值=0,表示不相关

image-20240805114456973

假设查询和键的所有元素都是独立的随机变量, 并且都满足零均值和单位方差, 那么两个向量的点积的均值为0,方差为𝑑。

为确保无论向量长度如何, 点积的方差在不考虑向量长度的情况下仍然是1, 再将点积除以𝑑(即“缩放”)

因此计算query与每个key之间的点乘后,除以\(\sqrt{d_k}\)(为了避免exp后方差过大,归一化操作)

再通过Softmax函数计算得到权重,最后乘上value得到注意力分数

一般情况是多个query和多个key-value对计算注意力分数,为了并行计算通常采用矩阵的乘法 \[ Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V \] 计算的流程图如下

image-20240805115604743

从矩阵的角度上看,Q矩阵首先和K的转置相乘,然后每个元素除以\(\sqrt{d_k}\),使得方差变为1

在Decoder中可能会使用到Mask的操作,将Mask范围以外的元素置为一个非常小的数

再通过Softmax计算权重(负无穷的exp为0)

乘上V矩阵得到注意力分数

image-20240805120913349

实现DotProductAttention类完成点乘注意力分数计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class DotProductAttention(nn.Module):
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)

def forward(self, queries, keys, values, valid_lens=None):
"""
:param queries: [batch_size, num_queries, d_query]
:param keys: [batch_size, num_kv_pairs, d_key] # d_query = d_key
:param values: [batch_size, num_kv_pairs, d_value]
:param valid_lens: [batch_size] or [batch_size, num_query]
:return: [batch_size, num_queries, d_value]
"""
d = queries.shape[-1]

# bmm 批量矩阵乘法, 对分别对每个batch中的矩阵相乘
# a, b = torch.rand((2, 3, 4)), torch.rand((2, 4, 5))
# torch.bmm(a, b).shape -> torch.Size([2, 3, 5])
scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)

# 掩码操作+Softmax
self.attention_weights = masked_softmax(scores, valid_lens)

# Dropout后乘上value矩阵
return torch.bmm(self.dropout(self.attention_weights), values)

从测试数据中可以看出输入输出的维度分别为:

Q: (batch_size, num_queries, d_key)

K: (batch_size, num_kv_qairs, d_key)

V: (batch_size, num_kv_qairs, d_value)

Ouput: (batch_size, num_queries, d_value)

1
2
3
4
5
6
7
8
batch_size, num_q, num_kv,d_k, d_v = 2, 3, 10, 5, 6 
queries = torch.rand((batch_size, num_q, d_k))
keys = torch.rand((batch_size, num_kv, d_k))
values = torch.rand((batch_size, num_kv, d_v))
valid_lens = torch.tensor([2, 6])
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens).shape # torch.Size([2, 3, 6])

Multi-Head Attention

在给定一个的查询、键和值时,如果仅仅考虑使用一个注意力函数计算往往会表现不太好

参考CNN网络中的多通道,使用多个卷积核识别不同的图像模式以提取多个维度的特征

因此Transformer希望使用相同的注意力机制学习到不同的特征, 然后将不同的特征组合起来以获得更多的序列信息

在前面的点乘注意力机制中没有可以学习的参数,而线性层中有可以学习的参数

Transformer通过独立学习得到的ℎ组不同的Linear层,将Q,K,V投影到低维空间

将h组的Q,K,V分别计算注意力(并行实现)

最后将h个注意力的输出拼接到一起,再通过一个线性层(可学习)产生最终输出。

image-20240805124645884

多头注意力可以使模型从不同的表达空间中获得特征,并组合在一起,相较于普通的单头注意力机制有更好的特征提取能力

\[ MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^o \]

\[ where\ head_i=Attention(QW^Q_i,KW^K_i,VW^V_i) \]

其中\(W^Q_i∈R^{d_{model}×d_k},\ W^K_i∈R^{d_{model}×d_k},\ W^V_i∈R^{d_{model}\ ×d_v},W^O∈R^{hd_v×d_{model}}\)

MultiHeadAttention类实现中考虑到涉及多个注意力头的计算,为了并行实现提高效率,因此将tensor进行转换,将head的维度放到batch的维度,这样可以一次实现对所有注意力头的计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class MultiHeadAttention(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = DotProductAttention(dropout)

# Linear 默认对最后一个维度做计算
# 这里的num_hiddens是指多个头合在一起的维度,单个头是num_hiddens/num_heads
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

def forward(self, queries, keys, values, valid_lens):
"""
:param queries: [batch_size, num_queries, d_query]
:param keys: [batch_size, num_kv_pairs, d_key] # d_query = d_key
:param values: [batch_size, num_kv_pairs, d_value]
:param valid_lens: [batch_size] or [batch_size, num_query]
:return: [batch_size, num_queries, num_hiddens]
"""

# 变换后的形状 [batch_size*num_heads, num_queries, num_hiddens/num_heads]
queries = transpose_qkv(self.W_q(queries), self.num_heads)

# 变换后的形状 [batch_size*num_heads, num_kv_pairs, num_hiddens/num_heads]
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)

if valid_lens is not None:
# 在第0个维度上,将第一个标量复制num_heads次,然后复制第二个。。
valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

# 将多头注意力的计算合并到一次完成
# [batch_size*num_heads, num_queries, num_hiddens/num_heads]
output = self.attention(queries, keys, values, valid_lens)

# [batch_size, num_queries, num_hiddens]
output_concat = transpose_output(output, self.num_heads)

# 最后再通过一个线性层
return self.W_o(output_concat)

def transpose_qkv(X, num_heads):
"""通过变换矩阵形状实现多头注意力的并行计算"""
# X: [batch_size, num_q/num_kv, num_hidden]
# --> [batch_size, num_q/num_kv, num_heads, num_hidden/num_heads]
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

# 将第1、2维交换位置
# [batch_size, num_heads, num_q/num_kv, num_hidden/num_heads]
X = X.permute(0, 2, 1, 3)

# 合并0、1维
# [batch_size*num_heads, num_q/num_kv, num_hidden/num_heads]
return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
"""逆转transpose_qkv操作"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)

可以看出多头注意力的输入输出维度:

Q: (batch_size, num_queries, d_key)

K: (batch_size, num_kv_qairs, d_key)

V: (batch_size, num_kv_qairs, d_value)

Ouput: (batch_size, num_queries, num_hiddens)

1
2
3
4
5
6
7
8
9
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape # torch.Size([2, 4, 100])