锐单电子商城 , 一站式电子元器件采购平台!
  • 电话:400-990-0325

【PyTorch】09深度体验之图像分类

时间:2022-10-25 21:00:01 二极管db220b

9 PyTorch深度体验

图像分类(Image Classification)

【PyTorch】8.1 图像分类

9.1 模型如何完成图像分类?



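图像分类的推理步骤:
1. 读取图像(path → img)
2. 预处理为模型输入的张量(img → tensor)
3. 模型前向推理得到输出向量(tensor → vector)
4. 取输出向量中最大值对应的类别,并可视化预测结果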

9.2 ResNet18模型实例

图像分类ResNet网络结构
参考文献:Deep Residual Learning for Image Recognition


程序:

import os
import time

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image

# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# config
# vis: when True, the main script plots each image with its predicted label.
vis = True
# vis = False
# vis_row: the plots are laid out on a vis_row x vis_row grid of subplots.
vis_row = 4

# Per-channel mean / std passed to transforms.Normalize below.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image before it is fed to the model.
inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Class labels, indexed by the model's predicted class id.
classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """Convert an image into the form the model reads.

    :param img_rgb: PIL Image
    :param transform: torchvision transform (any callable taking the image)
    :return: whatever ``transform`` returns for ``img_rgb`` (a tensor when
        used with ``inference_transform``)
    :raises ValueError: if ``transform`` is None
    """
    # A transform is mandatory: the raw PIL image cannot be fed to the model.
    if transform is None:
        raise ValueError("找不到transform!必须有transform对img进行处理")
    img_t = transform(img_rgb)
    return img_t
        def get_img_name(img_dir, format="jpg"):
            """Return the names of the files in a directory that end with ``format``.

            The match is a plain, case-sensitive ``str.endswith`` test on each
            entry returned by ``os.listdir`` (the directory is not searched
            recursively).

            :param img_dir: str, directory to scan
            :param format: str, file-name suffix to keep, e.g. "jpg"
            :return: list of matching file names (no directory part), sorted
            :raises ValueError: if no entry in ``img_dir`` ends with ``format``
            """
            # NOTE: the parameter name ``format`` shadows the builtin inside this
            # function; it is kept because callers may pass it by keyword.
            # os.listdir returns entries in arbitrary order, so sort them to make
            # the processing order reproducible across runs and platforms.
            img_names = sorted(
                name for name in os.listdir(img_dir) if name.endswith(format)
            )
            if not img_names:
                raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
            return img_names
        def get_model(m_path, vis_model=False):
            """Build a ResNet-18 with a 2-way output layer and load its weights.

            :param m_path: str, path to a checkpoint file readable by ``torch.load``
                that contains the weights under the key 'model_state_dict'
            :param vis_model: bool, when True print a layer-by-layer summary of the
                network using ``torchsummary`` (imported only in that case)
            :return: the ResNet-18 module with the checkpoint weights loaded
            """
            resnet18 = models.resnet18()
            # Replace the final fully connected layer so the network outputs 2 logits.
            num_ftrs = resnet18.fc.in_features
            resnet18.fc = nn.Linear(num_ftrs, 2)

            # map_location="cpu" lets a checkpoint that was saved from a GPU be
            # loaded on a machine without CUDA; the caller moves the model to the
            # target device afterwards.
            # NOTE(review): torch.load unpickles the file, so only load checkpoints
            # from a trusted source.
            checkpoint = torch.load(m_path, map_location="cpu")
            resnet18.load_state_dict(checkpoint['model_state_dict'])

            if vis_model:
                from torchsummary import summary
                # Print the model structure and parameter counts.
                summary(resnet18, input_size=(3, 224, 224), device="cpu")

            return resnet18
        if __name__ == "__main__":
            # Location on disk of the images to classify and of the trained weights.
            # img_dir = os.path.join("..", "..", "data/hymenoptera_data/val/bees")
            img_dir = os.path.join("..", "data_set", "hymenoptera_data/val/bees")
            model_path = "./model_checkpoint/checkpoint_14_epoch.pkl"

            time_total = 0
            # Images and predicted labels waiting to be shown in the current figure.
            img_list, img_pred = list(), list()

            # 1. data
            img_names = get_img_name(img_dir)
            num_img = len(img_names)

            # 2. model
            resnet18 = get_model(model_path, True)
            resnet18.to(device)  # move the model to the selected device
            resnet18.eval()      # switch the model to evaluation mode

            # CUDA kernels run asynchronously; without a synchronize the timer
            # would stop before the forward pass has actually finished.
            use_cuda = device.type == "cuda"

            # No gradients are needed for inference; this saves memory and time.
            with torch.no_grad():
                for idx, img_name in enumerate(img_names):
                    path_img = os.path.join(img_dir, img_name)

                    # step 1/4 : path --> img
                    # The context manager closes the file handle; convert() returns
                    # an independent in-memory copy that stays valid afterwards.
                    with Image.open(path_img) as img_file:
                        img_rgb = img_file.convert('RGB')

                    # step 2/4 : img --> tensor (the format the model expects)
                    img_tensor = img_transform(img_rgb, inference_transform)
                    img_tensor.unsqueeze_(0)            # add a leading batch dimension
                    img_tensor = img_tensor.to(device)  # move the input to the device

                    # step 3/4 : tensor --> vector (timed forward pass)
                    if use_cuda:
                        torch.cuda.synchronize()
                    time_tic = time.perf_counter()
                    outputs = resnet18(img_tensor)
                    if use_cuda:
                        torch.cuda.synchronize()
                    time_toc = time.perf_counter()

                    # step 4/4 : visualization
                    # Index of the largest output along dim 1 is the predicted class.
                    _, pred_int = torch.max(outputs, 1)
                    pred_str = classes[int(pred_int)]

                    if vis:
                        img_list.append(img_rgb)
                        img_pred.append(pred_str)

                        # Show a figure once the grid is full or after the last image.
                        if (idx + 1) % (vis_row * vis_row) == 0 or num_img == idx + 1:
                            for i, (img, pred) in enumerate(zip(img_list, img_pred)):
                                plt.subplot(vis_row, vis_row, i + 1).imshow(img)
                                plt.title("predict:{}".format(pred))
                            plt.show()
                            plt.close()
                            img_list, img_pred = list(), list()

                    time_s = time_toc - time_tic
                    time_total += time_s

                    print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

            # num_img is at least 1 here: get_img_name raises when nothing matches.
            print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".
                  format(device, time_total, time_total / num_img))
            if torch.cuda.is_available():
                print("GPU name:{}".format(torch.cuda.get_device_name()))
       

输出:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
锐单商城拥有海量元器件数据手册IC替代型号,打造电子元器件IC百科大全!

相关文章