锐单电子商城 , 一站式电子元器件采购平台!
  • 电话:400-990-0325

注意力机制

时间:2022-10-30 22:30:00 0228连接器

注意力机制深度学习网络的载体可以应用于注意力计算规则, 同时一些必要的全连接层和相关张量处理, 与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.
NLP领域中, 目前的注意机制大多用于seq2seq架构, 即编码器和解码器模型.
实现注意机制步骤
第一步: 根据注意力计算规则, 对Q,K,V相应的计算.
第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要拼接Q和第二步的计算结果, 若为转移点积, 一般是自注, Q与V相同, 不需要与Q拼接.
第三步: 最后,让整个attention按指定尺寸输出机制, 在第二步结果中使用线性层进行线性变换, 得到Q的最终注意表示.
实现代码:

import torch import torch.nn as nn import torch.nn.functional as F  class Attn(nn.Module):     def __init__(self, query_size, key_size, value_size1, value_size2, output_size):         """初始化函数中有5个参数, query_size代表query最后一维大小            key_size代表key最后一维大小, value_size1代表value导数的二维大小,            value = (1, value_size1, value_size2)            value_size2代表value倒数第一维大小, output_size输出的最后一维尺寸"""         super(Attn, self).__init__()         # 将以下参数传入类别         self.query_size = query_size         self.key_size = key_size         self.value_size1 = value_size1         self.value_size2 = value_size2         self.output_size = output_size          # 初始化注意力机制在第一步实现所需的线性层.         self.attn = nn.Linear(self.query_size   self.key_size, value_size1)          # 实现第三步所需的线性层.         self.attn_combine = nn.Linear(self.query_size   value_size2, output_size)       def forward(self, Q, K, V):         """forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的            张量一般为三维张量, 所以这里也假设Q, K, V都是三维张量"""          # 第一步, 按计算规则计算,         # 我们使用常见的第一个计算规则         # 将Q,K纵轴拼接, 线性变化, 最后使用softmax处理结果         attn_weights = F.softmax(             self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)          # 然后进行第一步的后半部分, 将获得的权重矩阵和V作为矩阵乘法计算,         # 当两者都是三维张量,第一维代表batch条数时, 则做bmm运算         attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)          # 第二步之后, 取[0]用于降维, 根据第一步采用的计算方法,         # Q和第一步的计算结果需要拼接         output = torch.cat((Q[0], attn_applied[0]), 1)          # 最后是第三步, 在第三步的结果上使用线性层进行线性变换和扩展维度,以获得输出         # 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度         output = self.attn_combine(output).unsqueeze(0)         return output, attn_weights   query_size = 32 key_size = 32 value_size1 = 32 value_size2 = 64 output_size = 64 attn = Attn(query_size, key_size, value_size1, value_size2, output_size) Q = torch.randn(1,1,32) K = torch.randn(1,1,32) V = torch.randn(1,32,64) out = attn(Q, K ,V) print(out[0]) print(out[1]) 

结果如下:
tensor([[[-0.2961, -0.0948, 0.4384, -0.4684, 0.0987, 0.3926, 0.2671,
1.0258, 0.2068, -0.8418, 0.1220, -0.3244, -0.8128, 0.2292,
0.6818, -0.3369, -0.2666, 0.0036, 0.0643, -0.6318, -0.0867,
0.6521, -0.3778, -0.2478, -0.1729, 0.9106, 0.2469, 0.1512,
0.0736, 0.2501, 0.9162, -0.5796, 0.1865, 0.0234, -0.0553,
0.2651, -0.5230, -0.3136, 0.2308, 0.5429, -0.3149, -0.1805,
0.1518, -0.0573, -0.2517, -0.1196, 0.0647, 0.6827, -0.1228,
-0.2044, 0.0298, 0.2147, -0.3879, -0.0771, -0.1359, -0.1912,
-0.4390, 0.4078, 0.0616, 0.1442, 0.1604, -0.3253, -0.1718,
0.2007]]], grad_fn=)
tensor([[0.0275, 0.0272, 0.0259, 0.0258, 0.0529, 0.0228, 0.0382, 0.0111, 0.0544,
0.0352, 0.0188, 0.0241, 0.0375, 0.0172, 0.0194, 0.0528, 0.0124, 0.0263,
0.0811, 0.0194, 0.0238, 0.0553, 0.0232, 0.0468, 0.0183, 0.0193, 0.0075,
0.0193, 0.0382, 0.0188, 0.0362, 0.0636]])

锐单商城拥有海量元器件数据手册IC替代型号,打造电子元器件IC百科大全!

相关文章