【PyTorch】09深度体验之图像分类
时间:2022-10-25 21:00:01
9 PyTorch深度体验
图像分类(Image Classification)
【PyTorch】8.1 图像分类
9.1 模型如何完成图像分类?
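图像分类的推理步骤:(1)读取图像(path --> img);(2)预处理并转换为张量(img --> tensor);(3)模型前向计算得到输出向量(tensor --> vector);(4)取输出向量中最大值的索引作为预测类别,并进行可视化。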
9.2 ResNet18模型实例
图像分类ResNet网络结构
参考文献:Deep Residual Learning for Image Recognition
程序:
import os
import time

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU inference

# config
vis = True  # set to False to skip the matplotlib prediction grids
vis_row = 4  # each figure shows up to vis_row * vis_row images
norm_mean = [0.485, 0.456, 0.406]  # per-channel mean passed to Normalize
norm_std = [0.229, 0.224, 0.225]  # per-channel std passed to Normalize

# Data preprocessing applied to every image before inference.
inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Class labels; the index of the largest model output selects the label.
classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """Convert an image into the form the model reads.

    :param img_rgb: PIL Image
    :param transform: torchvision transform (callable) applied to the image
    :return: tensor produced by ``transform``
    :raises ValueError: if ``transform`` is None
    """
    if transform is None:
        raise ValueError("找不到transform!必须有transform对img进行处理")
    return transform(img_rgb)


def get_img_name(img_dir, format="jpg"):
    """Return the names of the files in ``img_dir`` that end with ``format``.

    The names are sorted so that the processing order is reproducible
    (``os.listdir`` returns entries in arbitrary order).

    :param img_dir: str, directory to scan
    :param format: str, required file-name suffix
    :return: list of matching file names (not full paths)
    :raises ValueError: if no file in ``img_dir`` matches
    """
    img_names = sorted(
        name for name in os.listdir(img_dir) if name.endswith(format)
    )
    if not img_names:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False):
    """Build a 2-class ResNet18 and load its weights from a checkpoint.

    :param m_path: str, path to a checkpoint containing 'model_state_dict'
    :param vis_model: bool, when True print a torchsummary layer summary
    :return: the ResNet18 model, on the CPU, with the checkpoint weights
    """
    resnet18 = models.resnet18()
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, 2)

    # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
    # machine without CUDA; the caller moves the model to `device` later.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(m_path, map_location="cpu")
    resnet18.load_state_dict(checkpoint['model_state_dict'])

    if vis_model:
        # Local import: torchsummary is only needed for this optional path.
        from torchsummary import summary
        # Print the model structure and parameter counts.
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


def _sync_device():
    """Wait for pending CUDA work so that wall-clock timings are accurate."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def _show_predictions(images, labels):
    """Show one figure with each image titled by its predicted label."""
    for i, (image, label) in enumerate(zip(images, labels)):
        plt.subplot(vis_row, vis_row, i + 1).imshow(image)
        plt.title("predict:{}".format(label))
    plt.show()
    plt.close()


def main():
    """Classify every image in the validation folder and report timings."""
    # Paths of the data on disk.
    # img_dir = os.path.join("..", "..", "data/hymenoptera_data/val/bees")
    img_dir = os.path.join("..", "data_set", "hymenoptera_data/val/bees")
    model_path = "./model_checkpoint/checkpoint_14_epoch.pkl"

    time_total = 0
    img_list, img_pred = [], []

    # 1. data
    img_names = get_img_name(img_dir)
    num_img = len(img_names)

    # 2. model
    resnet18 = get_model(model_path, True)
    resnet18.to(device)  # move the model to the selected device
    resnet18.eval()  # switch the model to evaluation mode

    # No gradients are needed for inference; this reduces memory use and
    # speeds up the computation.
    with torch.no_grad():
        for idx, img_name in enumerate(img_names):
            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor (the model's input format)
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0)  # add a batch dimension -> 4D input
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            # CUDA kernels run asynchronously, so synchronize before reading
            # the clock on both sides of the forward pass.
            _sync_device()
            time_tic = time.perf_counter()
            outputs = resnet18(img_tensor)
            _sync_device()
            time_toc = time.perf_counter()

            # step 4/4 : visualization
            _, pred_int = torch.max(outputs, 1)
            pred_str = classes[int(pred_int)]

            if vis:
                img_list.append(img_rgb)
                img_pred.append(pred_str)

                # Flush a figure when the grid is full or on the last image.
                grid_full = (idx + 1) % (vis_row * vis_row) == 0
                if grid_full or num_img == idx + 1:
                    _show_predictions(img_list, img_pred)
                    img_list, img_pred = [], []

            time_s = time_toc - time_tic
            time_total += time_s

            print('{:d}/{:d}: {} {:.3f}s '.format(
                idx + 1, num_img, img_name, time_s))

    print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".format(
        device, time_total, time_total / num_img))
    if torch.cuda.is_available():
        print("GPU name:{}".format(torch.cuda.get_device_name()))


if __name__ == "__main__":
    main()
输出:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 36,864
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
BasicBlock-11 [-1, 64, 56, 56] 0
Conv2d-12 [-1, 64, 56, 56] 36,864
BatchNorm2d-13 [-1, 64, 56, 56] 128
ReLU-14 [-1, 64, 56, 56] 0
Conv2d-15 [-1, 64, 56, 56] 36,864
BatchNorm2d-16 [-1, 64, 56, 56] 128
ReLU-17 [-1, 64, 56, 56] 0
BasicBlock-18 [-1, 64, 56, 56] 0
Conv2d-19 [-1, 128, 28, 28] 73,728
BatchNorm2d-20 [-1, 128, 28, 28] 256
ReLU-21 [-1, 128, 28, 28] 0
Conv2d-22 [-1, 128, 28, 28] 147,456
BatchNorm2d-23 [-1, 128, 28, 28] 256
Conv2d-24 [-1, 128, 28, 28] 8,192
BatchNorm2d-25 [-1, 128, 28, 28] 256
ReLU-26 [-1, 128, 28, 28] 0
BasicBlock-27 [-1, 128, 28, 28] 0
Conv2d-28 [-1, 128, 28, 28] 147,456
BatchNorm2d-29 [-1, 128, 28, 28] 256
ReLU-30 [-1, 128, 28, 28] 0
Conv2d-31 [-1, 128, 28, 28] 147,456
BatchNorm2d-32 [-1, 128, 28, 28] 256
ReLU-33 [-1, 128, 28, 28] 0
BasicBlock-34 [-1, 128, 28, 28] 0
Conv2d-35 [-1, 256, 14, 14] 294,912
BatchNorm2d-36 [-1, 256, 14, 14] 512
ReLU-37 [-1, 256, 14, 14] 0
Conv2d-38 [-1, 256, 14, 14] 589,824
BatchNorm2d-39 [-1, 256, 14, 14] 512
Conv2d-40 [-1, 256, 14, 14] 32,768
BatchNorm2d-41 [-1, 256, 14, 14] 512
ReLU-42 [-1, 256, 14, 14] 0
BasicBlock-43 [-1, 256, 14, 14] 0
Conv2d-44 [-1, 256, 14, 14] 589,824
BatchNorm2d-45 [-1, 256, 14, 14] 512
ReLU-46 [-1, 256, 14, 14] 0
Conv2d-47 [-1, 256, 14, 14] 589,824
BatchNorm2d-48 [-1, 256, 14, 14] 512
ReLU-49 [-1, 256, 14, 14] 0
BasicBlock-50 [-1, 256, 14, 14] 0
Conv2d-51 [-1, 512, 7, 7] 1,179,648
BatchNorm2d-52 [-1, 512, 7, 7] 1,024
ReLU-53 [-1, 512, 7, 7] 0
Conv2d-54 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-55 [-1, 512, 7, 7] 1,024
Conv2d-56 [-1, 512, 7, 7] 131,072
BatchNorm2d-57 [-1, 512, 7, 7] 1,024
ReLU-58 [-1, 512, 7, 7] 0
BasicBlock-59 [-1, 512, 7, 7] 0
Conv2d-60 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-61 [-1, 512, 7, 7] 1,024
ReLU-62 [-1, 512, 7, 7] 0
Conv2d-63 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-64 [-1, 512, 7, 7] 1,024
ReLU-65 [-1, 512, 7,<