centernet代码解读-mutipose
时间:2022-11-24 01:00:01
前言:
centernet目标现目标检测,估计人体姿势3D测试,本文解释了人体姿势估计代码。本文的核心思想是使用热图来表示物体的关键点和中心点,并使用中心点到关键点的矢量来保证测试结果。
1、数据集
数据集使用coco2017中的person_keypoints_{}2017.json,使用目标检测image_info_test-dev2017.json,同一个coco数据集有不同类型的标记文件,根据标记文件中的不同图片id图片可以在同一个图片文件夹中找到。
coco_hp.py
在init在函数中定义一些 a、在datasets/sample/multi_pose.py变量使用的变量和值;b、读取标记文件的路径和名称 (self._eig_valself._eig_vec不知道是什么作用)
multi_pose.py制作真值
a、根据图片和标注文件的路径读取 b、在getitem将标记文件中的真值处理成我们需要的数据存储模式,c、做了一些数据增强操作,如翻转及仿射变换 d、将图片除以255,使图片值变为0到1之间的值 e、归一化,减均值,除方差
ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh 'hps': kps, 'hps_mask': kps_mask} #图片,热力图,int_float回归mask,index,物体宽度高,关键点热图,关键点
inp:图片通过cv2.imread()读取图片,除255外,减均值除方差
hm :建立一个与输入图片相同宽度的数组,其中值是物体中心点的热图。热图是怎么来的?1.根据高斯半径和物体中心点,通过物体宽度和高度的高斯半径hm及函数draw_gaussian画出属于这个物体的热力图。
reg_mask:#第k对象的偏差mask=1,初始值为0 reg物体中心点int到float的偏差
ind:中心点所在的宽高拉成1维索引
wh:物体宽高
kps:中心点到关键点之间的距离
kps_mask:初始值为0,有关键点的地方值为1
2、模型
根据opt.py配置文件中的配置--arch,该任务使用哪种网络结构作为决定backbone,共有“res_18 | res_101 | resdcn_18 | resdcn_101 |''dlav0_34 | dla_34 | hourglass'”七种。
以resnet101为例,在main.py中,调用了createmodel,在createmodel中又调用了_model_factory及getmodel。在_model_factory中传入opt中写的模型名通过_model_factory的getmodel调用真正的建模方法,如get_pose_net。在get_pose_net中构建PoseResNet类实例,同时构建网络模型get_pose_net初始化模型参数。
初始化:如果模型层是卷积,则采用正态分布初始化,如果偏置,则采用常数0初始化;如果是BN层,weight使用常数1初始化,bias使用0初始化……
def init_weights(self, num_layers, pretrained=True): if pretrained: # print('=> init resnet deconv weights from normal distribution') for _, m in self.deconv_layers.named_modules(): if isinstance(m, nn.ConvTranspose2d): # print('=> init {}.weight as normal(0, 0.001)'.format(name)) # print('=> init {}.bias as 0'.format(name)) nn.init.normal_(m.weight, std=0.001) if self.deconv_with_bias: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): # print('=> init {}.weight as 1'.format(name)) # print('=> init {}.bias as 0'.format(name)) nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) ……
3、训练
在for epoch循环训练从中开始。使用trainer.train(),MutiposeTrainer继承BaseTrainer,BaseTrainer.train.调用runepoch,在runepoch中主要调用model_with_loss(batch)获得模型的检测结果和loss
获取模型结果和计算loss
获得模型结果和loss用这句话
output, loss, loss_stats = model_with_loss(batch)
loss调用 MultiPoseTrainer中的__get_losses,再次调用此函数MultiPoseLoss类完成整个loss的计算。ModelWithLoss中在BaseTrainer获得初中初始化模型及配置参数opt,调用ModelWithLoss类,在此类中输入模型,并将结果输入loss得到loss。
class BaseTrainer(object): def __init__( self, opt, model, optimizer=None): self.opt = opt self.optimizer = optimizer self.loss_stats, self.loss = self._get_losses(opt) self.model_with_loss = ModelWithLoss(model, self.loss) class ModelWithLoss(torch.nn.Module): def __init__(self, model, loss): super(ModelWithLoss, self).__init__() self.model = model self.loss = loss def forward(self, batch): outputs = self.model(batch['input']) loss, loss_stats = self.loss(outputs, batch) return outputs[-1], loss, loss_stats class MultiPoseLoss(torch.nn.Module): def __init__(self, opt): super(MultiPoseLoss, self).__init__() self.crit = FocalLoss() self.crit_hm_hp = torch.nn.MSELoss() if opt.mse_loss else FocalLoss() self.crit_kp = RegWeightedL1Loss() if not opt.dense_hp else \ torch.nn.L1Loss(reduction='sum') self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \ RegLoss() if opt.reg_loss == 'sl1' else None self.opt = opt def forward(self, outputs, batch): ………………………………………… ……………………………… loss = opt.hm_weight * hm_loss opt.wh_weight * wh_loss \ opt.off_weight * off_loss opt.hp_weight * hp_loss \ opt.hm_hp_weight *hm_hp_loss + opt.off_weight * hp_offset_loss
loss_stats = {'loss': loss, 'hm_loss': hm_loss, 'hp_loss': hp_loss,
'hm_hp_loss': hm_hp_loss, 'hp_offset_loss': hp_offset_loss,
'wh_loss': wh_loss, 'off_loss': off_loss}
return loss, loss_stats
class MultiPoseTrainer(BaseTrainer):
def __init__(self, opt, model, optimizer=None):
super(MultiPoseTrainer, self).__init__(opt, model, optimizer=optimizer)
def _get_losses(self, opt):
loss_states = ['loss', 'hm_loss', 'hp_loss', 'hm_hp_loss',
'hp_offset_loss', 'wh_loss', 'off_loss']
loss = MultiPoseLoss(opt)
return loss_states, loss
4、decode