计算机系统应用教程网站

网站首页 > 技术文章 正文

YOLOv3网络构建

btikc 2024-09-01 15:29:31 技术文章 8 ℃ 0 评论

参数解析

首先我们先来看一下,在训练模型的时候,我们需要准备的参数有哪些

创建一个config.py文件用于存放参数

config.py

# cuda 是否使用GPU 没有则设置为False
   cuda = True

   '''
      distributed : 用于指定是否使用单机多卡卡分布式运行
  '''
   distributed = False

   # 是否使用sync_bn,DDP模式多卡可用
   sync_bn = False

   # 是否使用混合精度训练 可减少一般的显存,需要pyrotch1.7.1以上
   fp16 = False

   # 指向model_data下的txt,与自己训练的数据相关
   classes_path = 'model_data/yolo_anchors.txt'

   # anchors_path 代表先验框对应的txt文件,一般不修改
   anchors_path = 'model_data/yolo_anchors.txt'

   # anchors_mask 用于帮助代码找到对应的先验框 一般不修改
   anchors_mask = [[6,7,8],[3,4,5],[0,1,2]]

   model_path = 'model_data/yolo_weights.pth'

   # input_shape   输入的shape大小
   input_shape = [416,416]

   # pretrained   是否使用主干网络的预训练权重,此处使用的是主干的权重,因此是在模型构建的时候进行加载的
   pretrained = False

   Init_Epoch = 0
   Freeze_Epoch = 50
   Freeze_batch_size = 16

   UnFreeze_Epoch = 300
   UnFreeze_batch_size = 8

   # 是否进行冻结训练
   Freeze_Train = True

   # 模型最大学习率
   Init_lr = 1e-2
   # 模型的最小学习率,默认为最大学习率的0.01
   Min_lr = Init_lr * 0.01

   optimzer_type = 'sgd'
   momentum = 0.937
   # 权重衰减,防止过拟合
   weight_decay = 5e-4

   # 使用学习率下降方式,可选的有step(等间隔的学习率),cos(余弦退火)
   lr_decay_type = 'cos'

   # 多少个epoch并保存一次权值
   save_period = 10

   # 权值与日志文件保存的文件夹
   save_dir = 'logs'

   # 用于设置是否使用多线程读取数据
   num_workers = 4

   # 获得图片路径和标签
   train_annotation_path = '2007_train.txt'
   val_annotatio_path = '2007_val.txt'

设置好参数之后,我们在根目录下创建一个train.py文件,首先导入相关依赖,并设置用到的显卡

import os
import numpy as np

import torch
import torch.backends.cuda as cudnn
import torch.distributed as dist
import torch.nn as nn
from torch import nn,optim
from torch.utils.data import DataLoader
from config import *

if __name__ == '__main__':
  # 设置用到的显卡
   ngpus_per_node = torch.cuda.device_count()
   if distributed:  # 判断是否使用sync_bn,DDP模式多卡可用
       # 初始化函数
       # backend 通信后端,可选的包括:nccl(NVIDIA推出)、gloo(Facebook推出)、mpi(OpenMPI)
       dist.init_process_group(backend='nccl')
       #local_rank 指在一个node上进程的相对序号,local_rank在node之间相互独立
       local_rank = int(os.environ['LOCAL_RANK'])

       # rank 进程号,在多进程上下文中,我们通常假定rank 0是第一个进程或者主进程,其它进程分别具有1,2,3不同rank号,
       # 这样总共具有4个进程
       rank = int(os.environ['RANK'])
       device = torch.device('cuda',local_rank)

       if local_rank == 0:  
           print(f'[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...')
           print('Gpu Device Count:',ngpus_per_node)
   else:
       device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
       local_rank = 0

构建网络

接下来我们需要完成网络结构的构建,构建的代码在前面我们已经写好了,只需要在这里进行调用即可

  # 获classes和anchor
   class_names,num_classes = get_classes(classes_path)
   anchors,num_anchors = get_anchors(anchors_path)

   # 构建网络结构
   model = YoloBody(anchors_mask,num_classes,pretrained=pretrained)
   print(model)

然后我们就需要判断是否使用主干网络的预训练权重,如果使用则加载预训练权重,否则需要进行初始化

权重文件下载地址

链接: https://pan.baidu.com/s/1hCV4kg8NyStkywLiAeEr3g 提取码: 6da3

代码如下:

if not pretrained:  # 如果没有使用主干网络预训练权重,则初始化
       weights_init(model)
   if model_path != '':   # 判断权重文件是否存在
       if local_rank == 0:
           print('Load weight {}.'.format(model_path))

       # 根据预训练权重的key和模型的key进行加载 返回的结果的key值是对网络参数的说明,value是对应的参数
       model_dict = model.state_dict()
       pretrained_dict = torch.load(model_path,map_location=device)  # 加载模型
       load_key,no_load_key,temp_dict = [],[],{}
?
       for k,v in pretrained_dict.items(): 
           if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
               temp_dict[k] = v
               load_key.append(k)
           else:
               no_load_key.append(k)
               
         # 显示没有匹配上的key
       if local_rank == 0:
           print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
           print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
           print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m")

这里我们需要创建一个yolotraing,py文件,然后实现weights_init()函数

腾科IT教育,提升您的IT价值!了解IT培训、有趣的IT知识干货,免费领取PPT课件,加入学习交流群,可后台私信“福利”获取。

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表