锐单电子商城 , 一站式电子元器件采购平台!
  • 电话:400-990-0325

centernet代码解读-mutipose

时间:2022-11-24 01:00:01 hps拉绳位稳传感器

前言:

centernet目标现目标检测,估计人体姿势3D测试,本文解释了人体姿势估计代码。本文的核心思想是使用热图来表示物体的关键点和中心点,并使用中心点到关键点的矢量来保证测试结果。

1、数据集

数据集使用coco2017中的person_keypoints_{}2017.json,使用目标检测image_info_test-dev2017.json,同一个coco数据集有不同类型的标记文件,根据标记文件中的不同图片id图片可以在同一个图片文件夹中找到。

coco_hp.py

在init在函数中定义一些 a、在datasets/sample/multi_pose.py变量使用的变量和值;b、读取标记文件的路径和名称 (self._eig_valself._eig_vec不知道是什么作用)

multi_pose.py制作真值

a、根据图片和标注文件的路径读取 b、在getitem将标记文件中的真值处理成我们需要的数据存储模式,c、做了一些数据增强操作,如翻转及仿射变换 d、将图片除以255,使图片值变为0到1之间的值 e、归一化,减均值,除方差

    ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh            'hps': kps, 'hps_mask': kps_mask}  #图片,热力图,int_float回归mask,index,物体宽度高,关键点热图,关键点 

inp:图片通过cv2.imread()读取图片,除255外,减均值除方差

hm :建立一个与输入图片相同宽度的数组,其中值是物体中心点的热图。热图是怎么来的?1.根据高斯半径和物体中心点,通过物体宽度和高度的高斯半径hm及函数draw_gaussian画出属于这个物体的热力图。

reg_mask:#第k对象的偏差mask=1,初始值为0 reg物体中心点int到float的偏差

ind:中心点所在的宽高拉成1维索引

wh:物体宽高

kps:中心点到关键点之间的距离

kps_mask:初始值为0,有关键点的地方值为1

2、模型

根据opt.py配置文件中的配置--arch,该任务使用哪种网络结构作为决定backbone,共有“res_18 | res_101 | resdcn_18 | resdcn_101 |''dlav0_34 | dla_34 | hourglass'”七种。

以resnet101为例,在main.py中,调用了createmodel,在createmodel中又调用了_model_factory及getmodel。在_model_factory中传入opt中写的模型名通过_model_factory的getmodel调用真正的建模方法,如get_pose_net。在get_pose_net中构建PoseResNet类实例,同时构建网络模型get_pose_net初始化模型参数。

初始化:如果模型层是卷积,则采用正态分布初始化,如果偏置,则采用常数0初始化;如果是BN层,weight使用常数1初始化,bias使用0初始化……

def init_weights(self, num_layers, pretrained=True):         if pretrained:             # print('=> init resnet deconv weights from normal distribution')             for _, m in self.deconv_layers.named_modules():                 if isinstance(m, nn.ConvTranspose2d):                     # print('=> init {}.weight as normal(0, 0.001)'.format(name))                     # print('=> init {}.bias as 0'.format(name))                     nn.init.normal_(m.weight, std=0.001)                     if self.deconv_with_bias:                         nn.init.constant_(m.bias, 0)                 elif isinstance(m, nn.BatchNorm2d):                     # print('=> init {}.weight as 1'.format(name))                     # print('=> init {}.bias as 0'.format(name))                     nn.init.constant_(m.weight, 1)                     nn.init.constant_(m.bias, 0)            ……

3、训练

在for epoch循环训练从中开始。使用trainer.train(),MutiposeTrainer继承BaseTrainer,BaseTrainer.train.调用runepoch,在runepoch中主要调用model_with_loss(batch)获得模型的检测结果和loss

获取模型结果和计算loss

获得模型结果和loss用这句话

output, loss, loss_stats = model_with_loss(batch)

loss调用 MultiPoseTrainer中的__get_losses,再次调用此函数MultiPoseLoss类完成整个loss的计算。ModelWithLoss中在BaseTrainer获得初中初始化模型及配置参数opt,调用ModelWithLoss类,在此类中输入模型,并将结果输入loss得到loss。

class BaseTrainer(object):   def __init__(     self, opt, model, optimizer=None):     self.opt = opt     self.optimizer = optimizer     self.loss_stats, self.loss = self._get_losses(opt)     self.model_with_loss = ModelWithLoss(model, self.loss) class ModelWithLoss(torch.nn.Module):   def __init__(self, model, loss):     super(ModelWithLoss, self).__init__()     self.model = model     self.loss = loss      def forward(self, batch):     outputs = self.model(batch['input'])     loss, loss_stats = self.loss(outputs, batch)     return outputs[-1], loss, loss_stats class MultiPoseLoss(torch.nn.Module):   def __init__(self, opt):     super(MultiPoseLoss, self).__init__()     self.crit = FocalLoss()     self.crit_hm_hp = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()     self.crit_kp = RegWeightedL1Loss() if not opt.dense_hp else \                    torch.nn.L1Loss(reduction='sum')     self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \                     RegLoss() if opt.reg_loss == 'sl1' else None     self.opt = opt        def forward(self, outputs, batch):      …………………………………………     ………………………………     loss = opt.hm_weight * hm_loss   opt.wh_weight * wh_loss   \            opt.off_weight * off_loss   opt.hp_weight * hp_loss   \            opt.hm_hp_weight *hm_hp_loss + opt.off_weight * hp_offset_loss
    loss_stats = {'loss': loss, 'hm_loss': hm_loss, 'hp_loss': hp_loss, 
                  'hm_hp_loss': hm_hp_loss, 'hp_offset_loss': hp_offset_loss,
                  'wh_loss': wh_loss, 'off_loss': off_loss}
    return loss, loss_stats

class MultiPoseTrainer(BaseTrainer):
  def __init__(self, opt, model, optimizer=None):
    super(MultiPoseTrainer, self).__init__(opt, model, optimizer=optimizer)
  
  def _get_losses(self, opt):
    loss_states = ['loss', 'hm_loss', 'hp_loss', 'hm_hp_loss', 
                   'hp_offset_loss', 'wh_loss', 'off_loss']
    loss = MultiPoseLoss(opt)
    return loss_states, loss

4、decode

锐单商城拥有海量元器件数据手册IC替代型号,打造电子元器件IC百科大全!

相关文章