注意力机制
时间:2022-10-30 22:30:00
注意力机制深度学习网络的载体可以应用于注意力计算规则, 同时一些必要的全连接层和相关张量处理, 与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.
NLP领域中, 目前的注意机制大多用于seq2seq架构, 即编码器和解码器模型.
实现注意机制步骤
第一步: 根据注意力计算规则, 对Q,K,V相应的计算.
第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要拼接Q和第二步的计算结果, 若为转移点积, 一般是自注, Q与V相同, 不需要与Q拼接.
第三步: 最后,让整个attention按指定尺寸输出机制, 在第二步结果中使用线性层进行线性变换, 得到Q的最终注意表示.
实现代码:
import torch import torch.nn as nn import torch.nn.functional as F class Attn(nn.Module): def __init__(self, query_size, key_size, value_size1, value_size2, output_size): """初始化函数中有5个参数, query_size代表query最后一维大小 key_size代表key最后一维大小, value_size1代表value导数的二维大小, value = (1, value_size1, value_size2) value_size2代表value倒数第一维大小, output_size输出的最后一维尺寸""" super(Attn, self).__init__() # 将以下参数传入类别 self.query_size = query_size self.key_size = key_size self.value_size1 = value_size1 self.value_size2 = value_size2 self.output_size = output_size # 初始化注意力机制在第一步实现所需的线性层. self.attn = nn.Linear(self.query_size self.key_size, value_size1) # 实现第三步所需的线性层. self.attn_combine = nn.Linear(self.query_size value_size2, output_size) def forward(self, Q, K, V): """forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的 张量一般为三维张量, 所以这里也假设Q, K, V都是三维张量""" # 第一步, 按计算规则计算, # 我们使用常见的第一个计算规则 # 将Q,K纵轴拼接, 线性变化, 最后使用softmax处理结果 attn_weights = F.softmax( self.attn(torch.cat((Q[0], K[0]), 1)), dim=1) # 然后进行第一步的后半部分, 将获得的权重矩阵和V作为矩阵乘法计算, # 当两者都是三维张量,第一维代表batch条数时, 则做bmm运算 attn_applied = torch.bmm(attn_weights.unsqueeze(0), V) # 第二步之后, 取[0]用于降维, 根据第一步采用的计算方法, # Q和第一步的计算结果需要拼接 output = torch.cat((Q[0], attn_applied[0]), 1) # 最后是第三步, 在第三步的结果上使用线性层进行线性变换和扩展维度,以获得输出 # 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度 output = self.attn_combine(output).unsqueeze(0) return output, attn_weights query_size = 32 key_size = 32 value_size1 = 32 value_size2 = 64 output_size = 64 attn = Attn(query_size, key_size, value_size1, value_size2, output_size) Q = torch.randn(1,1,32) K = torch.randn(1,1,32) V = torch.randn(1,32,64) out = attn(Q, K ,V) print(out[0]) print(out[1])
结果如下:
tensor([[[-0.2961, -0.0948, 0.4384, -0.4684, 0.0987, 0.3926, 0.2671,
1.0258, 0.2068, -0.8418, 0.1220, -0.3244, -0.8128, 0.2292,
0.6818, -0.3369, -0.2666, 0.0036, 0.0643, -0.6318, -0.0867,
0.6521, -0.3778, -0.2478, -0.1729, 0.9106, 0.2469, 0.1512,
0.0736, 0.2501, 0.9162, -0.5796, 0.1865, 0.0234, -0.0553,
0.2651, -0.5230, -0.3136, 0.2308, 0.5429, -0.3149, -0.1805,
0.1518, -0.0573, -0.2517, -0.1196, 0.0647, 0.6827, -0.1228,
-0.2044, 0.0298, 0.2147, -0.3879, -0.0771, -0.1359, -0.1912,
-0.4390, 0.4078, 0.0616, 0.1442, 0.1604, -0.3253, -0.1718,
0.2007]]], grad_fn=)
tensor([[0.0275, 0.0272, 0.0259, 0.0258, 0.0529, 0.0228, 0.0382, 0.0111, 0.0544,
0.0352, 0.0188, 0.0241, 0.0375, 0.0172, 0.0194, 0.0528, 0.0124, 0.0263,
0.0811, 0.0194, 0.0238, 0.0553, 0.0232, 0.0468, 0.0183, 0.0193, 0.0075,
0.0193, 0.0382, 0.0188, 0.0362, 0.0636]])