SINABS使用
时间:2023-09-10 08:07:01
SINABS是瑞士脑芯片公司aiCTX开源的SNN仿真库,https://sinabs.ai详细介绍了该库。
1.Spiking Neurons简介
1.1人工神经元(Artificial Neuron)模型
人工神经元模型可以表示y=f(Wx b),f是非线性激活函数。
人工神经元的输出直接依赖于输入,神经元不影响其输出的内部状态;脉冲神经网络SNN增加了神经元对当前状态的额外依赖。
1.2 LIF(Leaky Integrate and Fire)模型
LIF模型是一个简单的脉冲神经元(Spiking Neuron)神经元的状态通常是膜电位v,神经元模型取决于输入和以前的状态,表示如下:
其中tau作为膜时间常数,定义了神经元对以前状态和输入的依赖;该模型描述了泄漏动力学系统,随着时间的推移,膜电位V缓慢泄漏为0LIF(Leaky Integrate and Fire)模型的起源。公式中Isyn=Wx,是所有输入突触的加权和,Ibias输入偏置。R用于膜电位与电流电位的匹配,表示常数电阻。
当v达到阈值时,神经元的输出是二值和瞬时的vth当神经元输出输出1时,膜电位立即重置vreset(该值小于阈值vth)。所以v可以表示为:
1被称为时间t的尖峰,一系列的尖峰被称为s(t):
随着时间的推移,膜电位的轨迹如下:
1.3 IAF(Constant-leak Integrate and Fire)模型
LIF模型描述了动态系统,v是不断更新的,为了便于计算,假设膜电位只有一个恒定的泄漏,为IAF模型,表述如下:
vleak常数可与偏差结合,上式可简化为:
IAF随着时间的推移,膜电位的轨迹是:
1.4 激活函数
对于脉冲神经网络,激活函数的定义更为复杂,因为脉冲神经网络的输出与输入x和自身状态有关,神经元的输出是一系列尖峰和狄拉克三角函数系列。神经元对输入的反应可以通过观察神经元在一个时间窗口中产生的尖峰数来解释,因此神经元的峰值速率被称为速度编码。
基于脉冲神经元的速率编码解释,IAF神经元模型的传递函数如下:
因此IAF神经元的传递函数等于ReLU激活函数。
2.SINABS提供的例程
2.1 MNIST
本例程主要介绍如何建立脉冲神经网络:
- 使用Pytorch将一个卷积神经网络定义为三个卷积层,一个全连接层和一个输出层
import torch.nn as nn ann = nn.Sequential( nn.Conv2d(1, 20, 5, 1, bias=False), nn.ReLU(), nn.AvgPool2d(2,2), nn.Conv2d(20, 32, 5, 1, bias=False), nn.ReLU(), nn.AvgPool2d(2,2), nn.Conv2d(32, 128, 3, 1, bias=False), nn.ReLU(), nn.AvgPool2d(2,2), nn.Flatten(), nn.Linear(128, 500, bias=False), nn.ReLU(), nn.Linear(500, 10, bias=False), )
- 加载数据(例程目的是模拟一个SNN网络,网络返回一系列脉冲)
import numpy as np from PIL import Image from torchvision import datasets class MNIST_Dataset(datasets.MNIST): def __init__(self, root, train = True, spiking=False, tWindow=100): datasets.MNIST.__init__(self, root, train=train, download=True) self.spiking=spiking self.tWindow = tWindow def __getitem__(self, index): img, target = self.data[index], self.targets[index] if self.spiking: img = (np.random.rand(self.tWindow, 1, *img.size()) < img.numpy()/255.0).astype(float) img = torch.from_numpy(img).float() else: # Convert PIL image to tensor img = torch.from_numpy(img.numpy()).float() img.unsqueeze_(0) return img, target
-
加载训练网络train_loader时设置spiking=False,即使用常规的费脉冲图像训练模型,使用Adam优化器,学习率为0.0001,损失函数为交叉熵损失;这个例程只训练了3个epochs以作示意。
from torch.utils.data import DataLoader # Define test dataset loader train_loader = DataLoader( MNIST_Dataset('./data', train=True, spiking=False), batch_size=128, shuffle=True) import tqdm import torch import torch.nn.functional as F import torch.optim as optim try: # Load a pre-trained model to save time if you have already have one. ann.load_state_dict(torch.load("mnist_params.pt")) except: # Train the model ann.train() optim = torch.optim.Adam(ann.parameters(), lr=1e-4) n_epochs = 3 for n in tqdm.notebook.tqdm(range(n_epochs)): pbar = tqdm.notebook.tqdm(train_loader) # Iterate over data for data, target in pbar: data, target = data.to(device), target.to(device) output = ann(data) optim.zero_grad() # Add loss to the total loss loss = F.cross_entropy(output, target) # Propagate loss backwards loss.backward() # Update weights optim.step() # get the index of the max log-probability pred = output.argmax(dim=, keepdim=True)
# Compute the total correct predictions
correct = pred.eq(target.view_as(pred)).sum().item()
pbar.set_postfix({"loss": loss.item(), "accuracy": correct/(len(target))})
# Save model parameters
torch.save(ann.state_dict(), "mnist_params.pt")
- 下面介绍如何将CNN结构转换为SCNN结构,sinabs中的from_model函数可以将CNN转换为SCNN
from sinabs.from_torch import from_model
input_shape = (1, 28, 28)
sinabs_model = from_model(ann, input_shape=input_shape, add_spiking_output=True)
- 对sinabs_model进行可视化
sinabs_model.spiking_model
- 进行一个简单的测试,其中数据集加载函数中的spiking应当设置为True,其中tWindow是一个十分重要的参数,应当多次进行调试
# Time window per sample
tWindow = 200 # ms (or) time steps
# Define test dataset loader
test_spike_loader = torch.utils.data.DataLoader(
MNIST_Dataset('./data', train=False, spiking=True, tWindow=tWindow),
batch_size=1, shuffle=False)
test(sinabs_model, test_spike_loader, num_batches=200)
- sinabs集成了计算突触操作总数的方法:
sinabs_model.get_synops()
2.2 使用BPTT进行训练
BPTT一般在训练卷积神经网络中使用,与普通的神经网络不同,SNN的内部状态会持续一段时间,即使网络不是循环出现的,仍然可以通过膜电位的持续性记忆其先前的处理步骤。
- 定义数据集,其中像素值介于0-1之间,将这些值转化为峰值概率
from torchvision import datasets
import torch
torch.manual_seed(0)
class MNIST_Dataset(datasets.MNIST):
def __init__(self, root, train=True, single_channel=False):
datasets.MNIST.__init__(self, root, train=train, download=True)
self.single_channel = single_channel
def __getitem__(self, index):
img, target = self.data[index], self.targets[index]
img = img.float() / 255.
# default is by row, output is [time, channels] = [28, 28]
# OR if we want by single item, output is [784, 1]
if self.single_channel:
img = img.reshape(-1).unsqueeze(1)
spikes = torch.rand(size=img.shape) < img
spikes = spikes.float()
return spikes, target
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64
dataset_test = MNIST_Dataset(root="./data/", train=False)
dataloader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=BATCH_SIZE, drop_last=True)
dataset = MNIST_Dataset(root="./data/", train=True)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=BATCH_SIZE, drop_last=True)
- 训练一个baseline
from torch import nn
ann = nn.Sequential(
nn.Linear(28, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 10),
nn.ReLU()
)
from tqdm.notebook import tqdm
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters())
for epoch in range(5):
pbar = tqdm(dataloader)
for img, target in pbar:
optimizer.zero_grad()
target = target.unsqueeze(1).repeat([1, 28])
img = img.reshape([-1, 28])
target = target.reshape([-1])
out = ann(img)
# out = out.sum(1)
loss = criterion(out, target)
loss.backward()
optimizer.step()
pbar.set_postfix(loss=loss.item())
- 定义一个SNN并训练,网络状态必须在每次迭代时重置
from sinabs.from_torch import from_model
model = from_model(ann, batch_size=BATCH_SIZE).to(device)
model = model.train()
from tqdm.notebook import tqdm
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
pbar = tqdm(dataloader)
for img, target in pbar:
optimizer.zero_grad()
model.reset_states()
out = model.spiking_model(img.to(device))
# the output of the network is summed over the 28 time steps (rows)
out = out.sum(1)
loss = criterion(out, target.to(device))
loss.backward()
optimizer.step()
pbar.set_postfix(loss=loss.item())
- 测试
accs = []
pbar = tqdm(dataloader_test)
for img, target in pbar:
model.reset_states()
out = model(img.to(device))
out = out.sum(1)
predicted = torch.max(out, axis=1)[1]
acc = (predicted == target.to(device)).sum().cpu().numpy() / BATCH_SIZE
accs.append(acc)
print(sum(accs)/len(accs))