系统学习CV-Transformer
时间:2023-11-13 04:37:02
系统学习CV-Transformer
- 参考
- 概念
参考
https://www.bilibili.com/video/BV15v411W78M?spm_id_from=333.999.0.0&vd_source=7155082256127a432d5ed516a6423e20
https://www.bilibili.com/video/BV1pu411o7BE?spm_id_from=333.337.search-card.all.click&vd_source=7155082256127a432d5ed516a6423e20
代码参考 – 未扒完
https://zh-v2.d2l.ai/chapter_attention-mechanisms/multihead-attention.html
概念
注意力:自主,非自主 https://zh-v2.d2l.ai/chapter_attention-mechanisms/attention-cues.html
非自主性:全连接层和汇聚层
自注机制
- embedding: 输入通过编码获得向量
- query: 我检查别人,我提供的向量
- key: 我被别人检查,我提供的向量
- value:表达当前词的特征
K\Q\V矩阵是通过训练获得的
x 1 ? W q = q 1 x1 * W^{q}=q_{1} x1?Wq=q1
x 1 ? W k = k 1 x1 * W^{k}=k_{1} x1?Wk=k1
x 1 ∗ W v = v 1 x1 * W^{v}=v_{1} x1∗Wv=v1
x 2 ∗ W q = q 2 x2 * W^{q}=q_{2} x2∗Wq=q2
x 2 ∗ W k = k 2 x2 * W^{k}=k_{2} x2∗Wk=k2
x 2 ∗ W v = v 2 x2 * W^{v}=v_{2} x2∗Wv=v2< q 1 ⃗ , q 1 ⃗ > <\vec{q_{1}},\vec{q_{1}}> <q1,q1>, < q 1 ⃗ , q 2 ⃗ > <\vec{q_{1}},\vec{q_{2}}> <q1,q2> 得到每个词与其他词的关系(建模为注意力权重)
w 11 = < q 1 ⃗ , q 1 ⃗ > w_{11}=<\vec{q_{1}},\vec{q_{1}}> w11=<q1,q1>, w 12 = < q 1 ⃗ , q 2 ⃗ > w_{12}=<\vec{q_{1}},\vec{q_{2}}> w12=<q1,q2>
o u t ( x 1 ) = w 11 ∗ v 1 + w 12 ∗ v 2 out(x_{1})=w_{11}*v_{1}+w_{12}*v_{2} out(x1)=w11∗v1+w12∗v2 再求加权求和进行词的重构(注意权重归一化)
w 21 = < q 2 ⃗ , q 1 ⃗ > w_{21}=<\vec{q_{2}},\vec{q_{1}}> w21=<q2,q1>, w 22 = < q 2 ⃗ , q 2 ⃗ > w_{22}=<\vec{q_{2}},\vec{q_{2}}> w22=<q2,q2>
o u t ( x 2 ) = w 21 ∗ v 1 + w 22 ∗ v 2 out(x_{2})=w_{21}*v_{1}+w_{22}*v_{2} out(x2)=w21∗v1+w22∗v2 再求加权求和进行词的重构(注意权重归一化)
非参注意力池化层–Nadaraya-Watson核回归
https://www.bilibili.com/video/BV1264y1i7R1?spm_id_from=333.999.0.0&vd_source=7155082256127a432d5ed516a6423e20
将查询与键之间的关系(注意力权重)建模为高斯核函数
最终其实就是softmax核函数
一个键(key)与给定的查询(query)越接近,分配给该键(key)对应的值(value)的注意力权重越大,即获得了更多的注意力
参数化注意力机制
https://www.bilibili.com/video/BV1264y1i7R1?spm_id_from=333.999.0.0&vd_source=7155082256127a432d5ed516a6423e20
注意力评分函数
上述的内积是一种评分函数
更通用的写法是 a ( q , k i ) a(q,k_{i}) a(q,ki)
两种常用的注意力评分函数:加性注意力、缩放点积注意力
加性注意力
query和key长度不同
# https://zh-v2.d2l.ai/chapter_attention-mechanisms/attention-scoring-functions.html
class AdditiveAttention(nn.Module):
def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):
super(AdditiveAttention,self).__init__(**kwargs)
self.w_k=nn.Linear(key_size,num_hiddens,bias=False)
self.w_q=nn.Linear(query_size,num_hiddens,bias=False)
self.w_v=nn.Linear(num_hiddens,1,bias=False)
self.dropout=nn.Dropout(dropout)
def forward(self,queries,keys,values,valid_lens):
queries=self.w_q(queries) # (batch_size,查询的个数,num_hiddens)
keys=self.w_k(keys) # (batch_size,键值对的个数,num_hiddens)
# (batch_size,查询的个数,1,num_hiddens)
# (batch_size,1,键值对的个数,num_hiddens)
features=quires,unsqueeze(2)+keys.unsqueeze(1) #
features=torch.tanh(features)
scores=self.w_v(features).squeeze(-1)
# 做softmax
self.attention_weight=masked_softmax(scores,valid_lens)
return torch.bmm(self.dropout(self.attention_weights),values)
缩放点积注意力
query 和 key长度相同
注意:归一化
class DotProductAttention(nn.Module):
def __init___(self,dropout,**kwargs):
super(DotProductAttention,self).__init__(**kwargs)
self.dropout=nn.Dropout(dropout)
def forward(self,queries,keys,values,vaild_lens=None):
d=queries.shape[-1]
scores=torch.bmm(queries,keys.transpose(1,2)/math.sqrt(d))
self.attention_weight=masked_softmax(scores,valid_lens)
return torch.bmm元器件数据手册、IC替代型号,打造电子元器件IC百科大全!