锐单电子商城 , 一站式电子元器件采购平台!
  • 电话:400-990-0325

PyTorch学习笔记(七)------------------ Vision Transformer

时间:2022-09-21 14:00:00 q18j5a连接器

目录

一、Patch and Linear map

二、Adding classification token

三、Positional encoding

四、LN, MSA and Residual Connection

五、LN、MLP and Residual Connection

六、Classification MLP


前言:vision transformer(vit)自Dosovitskiy自介绍以来,它一直在计算机视觉领域占据主导地位,大多数情况超过了传统的卷积神经网络(cnn)Transformer刚才提出的其实是自然语言处理(NLP)领域,而vit整个思路和NLP大差不异,它把一张完整的图片分成几张token,再将这些token输入网络类似NLP这些单独的中文句子输入token相当于每一个小单词

这是在Vision Transformers for Remote Sensing Image Classification我借用中发表的图片

通过这张照片,我们可以看到a从x1-x9 9张图片,它们等长。这些子图像是线性嵌入的,现在只是一个一维向量,也可以从这些图像中看到x1-x9是按顺序从原图分开的,很重要。之后,在这些token也就是说,在向量中添加位置信息,网络可以通过这些子图恢复图片的原始外观

嵌入位置信息后tokens用于分类的token一起传入到transformer encoder这就是为什么在传输数据时会传输数据 1.这1是分类token。在这个transformer encoder它含有一层归一化(LN),多头自注(MSA)连接残差(resdiual connection),然后再来第二个LN,多层感知器(MLP),一个残差。一般来说,encoder里面的块可以重复很多次,类似于Resnet。最后,一个用于分类MLP块来分类原始输入的特殊分类标记,是一种分类的东西。

现在回顾上面的图片,你觉得思维有点流畅吗?

一、Patch and Linear map

首先,第一个问题是如何将图片变成类似英语句子的图片。作者的方法是将其分为多个子图,并按位置顺序映射到向量上

例如,这里有一张3*224*图片224(3)是通道数 RGB)我们可以把它分成144*14的patch,每一个patch大小为16*16

(N,C,H,W)→(N, 3, 224, 224)→ (N, pathes, patch_dim) → (N, 14*14, 16*16)

现在输入的3*224*图片变成了224 (196, 256),每个patch的维度是16*我们现在的patch每个子图片都可以通过线性映射反馈,线性映射可以映射到任何向量,称为隐藏维度。在这里,我们可以反馈256 映射为 8 256→8.注意可以去除映射的维度

二、Adding classification token

之前说 在tokens传入transformer encoder中时要加一个分类token,它的功能是捕捉其他标记的信息MSA发生在中间。当所有图像传输完成后,我们只能使用这个classification token对图像进行分类

还是刚刚3*224*上面提到的224例子

(N, 196, 256)→(N, 196 1, 256)

这里加的1是分类token

三、Positional encoding

当网络接收到每一个时patch输入,它是如何知道每一个patch在原始图像中的位置?

Vaswani等人的研究表明,只有添加正弦波和余弦波才能实现这一点

同时,标记大小为(N, 197, 前N为(197, 256)这个位置编码重复N次

四、LN, MSA and Residual Connection

LN:给定输入,减去其平均值,除以标准差

MSA:将每一个patch映射三个不同的向量:q,k and v,映射后,通过q和k点乘再除以dim的平方根,softmax这些结果(注意力点)最终将每个注意力线索乘以v,最后加上(感觉无聊)

同时,对每个自注意力头数创建不同的Q,K,V映射函数

或者用例子来解释

(N, 197, 256)→(N, 197, 16, 16)→ nn.Linear(16, 16) → (N, 197, 256)

输入的是(N, 197,256)通过多头注意力(这里用了16个头)将向量变成(N, 197, 16, 16)此时需要一个nn.Linear(16, 16)将其映射成(N, 197, 256)

Residual Connection:残差

之前说过在传入transformer encoder加一个classification token,那这些token如何获得别人token信息在经过LN,MSA和残差操作后,这个classification token还有别的token的信息。

五、LN、MLP and Residual Connection

之前提到在transformer enconder第一步加入块LN, MSA和残差,这是第二步,加入LN、 MLP 和 残差

六、Classification MLP

经过一系列的操作,我们的网络有很多权重指数和数据MLP我们只能从N序列中提取分类标记(token),并使用token来获得分类

例如,我们以前选择的每一个token是16dim我们可以五类,我们可以使用MLP创建一个16*并用5矩阵softmax函数激活

整个vit到目前为止,网络的建设已经全部结束

PY代码如下

class MyViT(nn.Module):     def __init__(self, input_shape, n_patches=14, hidden_d=8, n_heads=2, out_d=5, device=None):                   super(MyViT, self).__init__()         self.device = device                   self.input_shape = input_shape         self.n_patches = n_patches         self.n_heads = n_heads         assert input_shape[1] % n_patches == 0,          assert input_shape[2] % n_patches == 0,          self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)         self.hidden_d = hidden_d          # 1) Linear mapper         self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])         self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)          # 2) Classification token         elf.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # (In forward method)

        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 4b) Multi-head Self Attention (MSA) and classification token
        self.msa = MyMSA(self.hidden_d, n_heads)

        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )

        # 6) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
       
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)

        
        tokens = self.linear_mapper(patches)

       
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

       
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1).to(self.device)

       
        out = tokens + self.msa(self.ln1(tokens))

        
        out = out + self.enc_mlp(self.ln2(out))
       
        
        out = out[:, 0]

        return self.mlp(out)

def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
       
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

本文参考了https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0chttps://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c

不足之处欢迎指正,源码可以私信或评论,看到就会回复

锐单商城拥有海量元器件数据手册IC替代型号,打造电子元器件IC百科大全!

相关文章