锐单电子商城 , 一站式电子元器件采购平台!
  • 电话:400-990-0325

pytorch P28 -卷积神经网络demo

时间:2022-11-18 06:00:00 p28j2mqjg密封连接器p28j4mj密封连接器p28k2aqjg连接器

卷积神经网络和 传统神经 网络训练模块基本一致,网络 模型差异很大。

一 读取数据

# 导包 import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torchvision import datasets,transforms import matplotlib.pyplot as plt import numpy as np  #读取 数据 # 定义超参数 input_size = 28 # 图像大小:28 * 28 num_classes = 10 # 标签的种类 num_epochs = 3 # 迭代的次数 batch_size = 64 # 每批的大小,即每64章图片一块进行一次训练  # 加载训练集 train_dataset = datasets.MNIST(                                 root='./data',                                 train=True,                                 transform=transforms.ToTensor(),                                 download=True                                 ) # 记载测试集 test_dataset = datasets.MNIST(root='./data',                              train=False,                              transform=transforms.ToTensor()) # 构建batch数据 train_loader = torch.utils.data.DataLoader(dataset=train_dataset,                                           batch_size=batch_size,                                           shuffle=True) test_loader = torch.utils.data.DataLoader(dataset=test_dataset,                                          batch_size=batch_size,                                          )

前面是导包

数据源还是mnist,分为训练集与测试集。使用DataLoader来构建batch数据。

二 构建卷积神经网络模型

#网络 模型 class CNN(nn.Module):     def __init__(self):         super(CNN, self).__init__()          self.conv1 = nn.Sequential(  # 输入大小 (1,28,28)             nn.Conv2d(                 in_channels=1,  # 说明是灰度图                 out_channels=16,  # 要获得多少个特征图?                 kernel_size=5,  # 卷积核的大小                 stride=1,  # 步长                 padding=2),  # 填充边缘的大小             nn.ReLU(),  #relu层             nn.MaxPool2d(kernel_size=2)  # 池化操作 (2 * 2) 输出结果如下: (16,14,14)          )         self.conv2 = nn.Sequential(             nn.Conv2d(16, 32, 5, 1, 2),#16 是输入,32 是特征图             nn.ReLU(), #relu层              nn.MaxPool2d(2))  # 输出 (32, 7, 7)         self.out = nn.Linear(32 * 7 * 7, 10)  # 全连接输入分类      def forward(self, x):         x = self.conv1(x)         x = self.conv2(x)         x = x.view(x.size(0), -1)  # flatten操作,结果是 (batch_size, 32*7*7)         output = self.out(x)         return output #准确率 def accuracy(predictions,labels):     pred = torch.max(predictions.data,1)[1]     rights = pred.eq(labels.data.view_as(pred)).sum()     return rights,len(labels)

主要内容 参数 ,conv1与conv2 里面的 需要结合老师的视频区理解。还有前传播调用out之前的view 变形操作。结合上节从矩阵降维到全连接的理解 。

三 训练网络模型

# 实例化 net = CNN() # 选择损失函数 criterion = nn.CrossEntropyLoss() # 选择优化器 optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器,采用随机梯度下降算法  # 开始训练 for epoch in range(num_epochs):     train_right = []  # 保存当前epoch以前定义的保存结果loss是一个道理      for batch_idx, (data, target) in enumerate(train_loader):         net.train()         output = net(data)         loss = criterion(output, target)         optimizer.zero_grad()         loss.backward()         optimizer.step()  # 优化器调用step(),不是loss         right = accuracy(output, target)         train_right.append(right)          if batch_idx % 100 == 0:#每100次验证收集效果             net.eval()             val_right = []             for (data, target) in test_loader:                 output = net(data)                 right = accuracy(output, target)                 val_right.append(right)              # 计算精度             train_rate = (sum([tup[0] for tup in train_right]), sum(tup[1] for tup in train_right))             val_rate = (sum([tup[0] for tup in val_right]), sum(tup[1] for tup in val_right))              print('当前epoch:{}.0f}%)]\t 损失:{:.6f}\t 训练集精度:{:.2f}%\t 试验集精度:{:.2f}%'.format(                 epoch, batch_idx * batch_size, len(train_loader.dataset),                        100.0 * batch_idx / len(train_loader),                 loss.data,                        100.0 * train_rate[0].numpy() / train_rate[1],                        100.0 * val_rate[0].numpy() / val_rate[1]             ))

看验证集输出结果100次:

当前epoch:0[0/600000(0%)  损失:2.300367  训练集精度:12.50%  测试集精度:10.14% 当前epoch:0[6400/60000(11%)  损失:0.280655  训练集精度:74.13%  测试集精度:92.48% 当前epoch:0[12800/60000(21%)  损失:0.166317  训练集精度:83.76%  测试集精度:95.60% 当前epoch:0[19200/60000(32%)  损失:0.105674  训练集精度:87.71%  测试集精度:95.65% 当前epoch:0[25600/60000(43%)  损失:0.094606  训练集精度:89.83%  测试集精度:97.22% 当前epoch:0[32000/60000(53%)  损失:0.065384  训练集精度:91.26%  测试集精度:97.60% 当前epoch:0[38400/60000(64%)  损失:0.049964  训练集精度:92.25%  测试集精度:97.51% 当前epoch:0[44800/60000(75%)  损失:0.035163  训练集精度:93.01%  测试集精度:97.83% 当前epoch:0[51200/60000(85%)  损失:0.055695  训练集精度:93.56%  测试集精度:98.14% 当前epoch:0[57600/60000(96%)  损失:0.014890  训练集准确率:94.03%  测试集精度:97.77% 当前epoch:1[0/600000(0%) 损失:0.081240	 训练集准确率:93.75%	 测试集准确率:98.20%
当前epoch:1[6400/60000(11%)]	 损失:0.049458	 训练集准确率:98.04%	 测试集准确率:98.24%
当前epoch:1[12800/60000(21%)]	 损失:0.026402	 训练集准确率:98.12%	 测试集准确率:98.18%
当前epoch:1[19200/60000(32%)]	 损失:0.056982	 训练集准确率:98.11%	 测试集准确率:98.49%
当前epoch:1[25600/60000(43%)]	 损失:0.098775	 训练集准确率:98.13%	 测试集准确率:98.63%
当前epoch:1[32000/60000(53%)]	 损失:0.119748	 训练集准确率:98.15%	 测试集准确率:98.26%
当前epoch:1[38400/60000(64%)]	 损失:0.024341	 训练集准确率:98.18%	 测试集准确率:98.49%
当前epoch:1[44800/60000(75%)]	 损失:0.017717	 训练集准确率:98.20%	 测试集准确率:97.95%
当前epoch:1[51200/60000(85%)]	 损失:0.084650	 训练集准确率:98.20%	 测试集准确率:98.45%
当前epoch:1[57600/60000(96%)]	 损失:0.014650	 训练集准确率:98.18%	 测试集准确率:98.68%
当前epoch:2[0/60000(0%)]	 损失:0.089021	 训练集准确率:96.88%	 测试集准确率:98.54%
当前epoch:2[6400/60000(11%)]	 损失:0.048318	 训练集准确率:98.72%	 测试集准确率:98.68%
当前epoch:2[12800/60000(21%)]	 损失:0.051317	 训练集准确率:98.71%	 测试集准确率:98.62%
当前epoch:2[19200/60000(32%)]	 损失:0.033962	 训练集准确率:98.67%	 测试集准确率:98.53%
当前epoch:2[25600/60000(43%)]	 损失:0.025890	 训练集准确率:98.72%	 测试集准确率:98.79%
当前epoch:2[32000/60000(53%)]	 损失:0.007487	 训练集准确率:98.72%	 测试集准确率:98.57%
当前epoch:2[38400/60000(64%)]	 损失:0.015440	 训练集准确率:98.74%	 测试集准确率:98.81%
当前epoch:2[44800/60000(75%)]	 损失:0.006676	 训练集准确率:98.73%	 测试集准确率:98.84%
当前epoch:2[51200/60000(85%)]	 损失:0.034487	 训练集准确率:98.72%	 测试集准确率:98.85%
当前epoch:2[57600/60000(96%)]	 损失:0.042631	 训练集准确率:98.73%	 测试集准确率:98.73%

锐单商城拥有海量元器件数据手册IC替代型号,打造电子元器件IC百科大全!

相关文章