网站首页 > 技术文章 正文
最近使用efficientdet-tf2目标检测算法完成了一个目标检测项目,涉及到数据集的创建、数据集格式转换、深度学习模型训练、使用模型进行目标预测等,中间踩了很多的坑,为避免后续继续采坑,故写下此文作为总结。
一、efficientdet-tf2目标检测算法概述
efficientdet-tf2是Google新推出的基于Tensorflow2平台的目标检测算法,可用于深度学习(Deep Learning)模型的训练,以及使用训练的模型进行目标检测应用(目标预测)。
二、软件环境搭建
本次实验软件环境如下所示:
1)电脑操作系统:WIN10 64位 家庭中文版,CPU为I7-8550U,8G内存;
2)Python:3.7.9 64位 Windows版本(裸装,不需要额外安装Anaconda或者Pycharm),可直接从Python官方网站直接下载;
3)Tensorflow模块:tensorflow-cpu版,2.6版本(当前的最新版本),可在配置了Python的环境变量后,直接在cmd窗口通过如下指令进行安装;
pip install tensorflow-cpu
# 若需指定其它版本,可通过在最后指定版本即可,如:pip install tensorflow-cpu==2.1
#若需安装GPU版本的tensorflow,可直接通过指令安装:pip install tensorflow,该指令默认安装同时支持CPU和GPU加速的版本
#若安装GPU版本的tensorflow,需要先安装NVIDIA的CUDA和cuDNN驱动
4)Python其它支持模块:cv2(pip install opencv-python)、PIL(pip install pillow)、numpy(pip install numpy)等。
三、数据集的创建和转换
efficientdet-tf2需要使用VOC格式的数据集进行模型训练。
为便于对图片进行标记,本项目采用了百度在Labelme基础上进行二次开发的Easyyibiao软件对图片标记检测目标以及添加标签。Easyyibiao软件简单易用,也是开源软件,可直接在GitHub上下载(在必应或百度搜索Easyyibiao,第一条显示的即为该软件)。
Easyyibiao软件使用非常便捷,下载完成后:
1)双击“main.exe”即可打开软件;
2)通过“文件”→“打开目录”,选定待标记的图片所在的文件夹,即可载入图片(注意将图片的后缀名更改为小写,若不更改,后续生成VOC数据集时可能会出现异常,血泪教训));
3)通过“创建矩形”(也可以是其它形状)即可在图片上拖拽矩形框(此时矩形框为绿色边框)标记待进行识别的区域,并会弹出对话框让你输入标签;
添加标签后,绿色矩形边框会变成红色。
4)标记完成后点击“保存”按钮,会弹出对话框以.json的格式保存图片标记信息,注意不要修改保存的图片标记信息的文件名(该文件名与图片的文件名保持一致)。
此处注意:在保存图片标记信息前将“保存图像数据”取消选择(若不更改,后续生成VOC数据集时可能会出现异常,血泪教训)。
所有图片标记完成后,均会生成一一对应的.json格式的labelme数据集,而efficientdet-tf2只能识别VOC数据集(.xml格式),故需要通过工具将labelme数据集转换为VOC数据集,以下为python环境下的数据集转换代码(之所以在此处贴转换代码,是因为自带的数据集转换脚本运行时报错)。
import os
from typing import List, Any
import numpy as np
import codecs
import json
from glob import glob
import cv2
import shutil
from sklearn.model_selection import train_test_split
# 1.标签路径
labelme_path = "original_data/" # 原始labelme标注数据路径
saved_path = "VOC2007/" # 保存路径
isUseTest=True#是否创建test集
# 2.创建要求文件夹
if not os.path.exists(saved_path + "Annotations"):
os.makedirs(saved_path + "Annotations")
if not os.path.exists(saved_path + "JPEGImages/"):
os.makedirs(saved_path + "JPEGImages/")
if not os.path.exists(saved_path + "ImageSets/Main/"):
os.makedirs(saved_path + "ImageSets/Main/")
# 3.获取待处理文件
files = glob(labelme_path + "*.json")
files = [i.replace("\\","/").split("/")[-1].split(".json")[0] for i in files]
print(files)
# 4.读取标注信息并写入 xml
for json_file_ in files:
json_filename = labelme_path + json_file_ + ".json"
json_file = json.load(open(json_filename, "r", encoding="utf-8"))
height, width, channels = cv2.imread(labelme_path + json_file_ + ".jpg").shape
with codecs.open(saved_path + "Annotations/" + json_file_ + ".xml", "w", "utf-8") as xml:
xml.write('<annotation>\n')
xml.write('\t<folder>' + 'WH_data' + '</folder>\n')
xml.write('\t<filename>' + json_file_ + ".jpg" + '</filename>\n')
xml.write('\t<source>\n')
xml.write('\t\t<database>WH Data</database>\n')
xml.write('\t\t<annotation>WH</annotation>\n')
xml.write('\t\t<image>flickr</image>\n')
xml.write('\t\t<flickrid>NULL</flickrid>\n')
xml.write('\t</source>\n')
xml.write('\t<owner>\n')
xml.write('\t\t<flickrid>NULL</flickrid>\n')
xml.write('\t\t<name>WH</name>\n')
xml.write('\t</owner>\n')
xml.write('\t<size>\n')
xml.write('\t\t<width>' + str(width) + '</width>\n')
xml.write('\t\t<height>' + str(height) + '</height>\n')
xml.write('\t\t<depth>' + str(channels) + '</depth>\n')
xml.write('\t</size>\n')
xml.write('\t\t<segmented>0</segmented>\n')
for multi in json_file["shapes"]:
points = np.array(multi["points"])
labelName=multi["label"]
xmin = min(points[:, 0])
xmax = max(points[:, 0])
ymin = min(points[:, 1])
ymax = max(points[:, 1])
label = multi["label"]
if xmax <= xmin:
pass
elif ymax <= ymin:
pass
else:
xml.write('\t<object>\n')
xml.write('\t\t<name>' + labelName+ '</name>\n')
xml.write('\t\t<pose>Unspecified</pose>\n')
xml.write('\t\t<truncated>1</truncated>\n')
xml.write('\t\t<difficult>0</difficult>\n')
xml.write('\t\t<bndbox>\n')
xml.write('\t\t\t<xmin>' + str(int(xmin)) + '</xmin>\n')
xml.write('\t\t\t<ymin>' + str(int(ymin)) + '</ymin>\n')
xml.write('\t\t\t<xmax>' + str(int(xmax)) + '</xmax>\n')
xml.write('\t\t\t<ymax>' + str(int(ymax)) + '</ymax>\n')
xml.write('\t\t</bndbox>\n')
xml.write('\t</object>\n')
print(json_filename, xmin, ymin, xmax, ymax, label)
xml.write('</annotation>')
# 5.复制图片到 VOC2007/JPEGImages/下
image_files = glob(labelme_path + "*.jpg")
print("copy image files to VOC007/JPEGImages/")
for image in image_files:
shutil.copy(image, saved_path + "JPEGImages/")
# 6.split files for txt
txtsavepath = saved_path + "ImageSets/Main/"
ftrainval = open(txtsavepath + '/trainval.txt', 'w')
ftest = open(txtsavepath + '/test.txt', 'w')
ftrain = open(txtsavepath + '/train.txt', 'w')
fval = open(txtsavepath + '/val.txt', 'w')
total_files = glob("./VOC2007/Annotations/*.xml")
total_files = [i.replace("\\","/").split("/")[-1].split(".xml")[0] for i in total_files]
trainval_files=[]
test_files=[]
if isUseTest:
trainval_files, test_files = train_test_split(total_files, test_size=0.15, random_state=55)
else:
trainval_files=total_files
for file in trainval_files:
ftrainval.write(file + "\n")
# split
train_files, val_files = train_test_split(trainval_files, test_size=0.15, random_state=55)
# train
for file in train_files:
ftrain.write(file + "\n")
# val
for file in val_files:
fval.write(file + "\n")
for file in test_files:
print(file)
ftest.write(file + "\n")
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
四、下载及配置efficientdet-tf2应用框架
直接在必应搜索efficientdet-tf2,选择GitHub中的“bubbliiiing/efficientdet-tf2”。
下载ZIP格式的efficientdet-tf2深度学习目标检测算法应用框架
五、模型训练
- efficientdet-tf2使用VOC数据集进行模型训练,首先将之前准备好的数据集按照如下方式放置:
1)训练前将标签文件(.xml格式)放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中。
2)训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
- 利用应用程序框架中的voc2efficientdet.py文件生成对应的txt。
- 再运行根目录下的voc_annotation.py,运行前需要将classes改成你自己的classes。注意不要使用中文标签,文件夹中不要有空格!
classes = ["bird", "fish"]
- 运行完voc_annotation.py后会在根目录下生成对应的2007_train.txt,每一行对应其图片位置及其真实框的位置
- 在训练前需要务必在model_data下新建一个txt文档,文档中输入需要分的类,在train.py中将classes_path指向该文件,示例如下:
classes_path = 'model_data/my_classes.txt'
model_data/new_classes.txt文件内容为标签分类,如:
bird
fish
- 修改train.py的classes_path,根据需要修改Epoch、Batch_Size等参数,运行train.py即可开始训练模型,模型训练时间的长短与电脑配置、tensorflow版本、数据集的大小以及train.py中的配置的参数有很大的关联(本项目80余图片的数据集,Batch_Size设置为4,Epoch为100,训练时长大概在1个多小时)。
#----------------------------------------------------#
# classes的路径,非常重要
# 训练前一定要修改classes_path,使其对应自己的数据集
#----------------------------------------------------#
classes_path = 'model_data/my_classes.txt'
#------------------------------------------------------#
- 训练完成后,将在logs文件夹下生成训练的模型以及其它日志信息,其中.h5格式的文件即为我们需要的训练好的模型文件。
其中在loss_XXX文件夹下可以查看训练的损失曲线。
六、目标预测应用
- 将训练好的模型文件拷贝至model_data文件夹下(一般拷贝最后生成的那个模型文件);
- 在efficientdet.py文件里面,在如下部分修改model_path、classes_path和phi使其对应训练好的文件;model_path对应logs文件夹下面的权值文件,classes_path是model_path对应分的类。phi为所使用的efficientdet的版本。
#--------------------------------------------#
# 使用自己训练好的模型预测需要修改3个参数
# model_path和classes_path和phi都需要修改!
# 训练时的model_path和classes_path参数的修改
#--------------------------------------------#
class EfficientDet(object):
_defaults = {
"model_path" : father_path + '/model_data/ep077-loss0.139-val_loss0.337.h5',
"classes_path" : father_path + '/model_data/my_classes.txt',
"phi" : 0,
"confidence" : 0.4,
"iou" : 0.3,
}
运行predict.py,输入待预测的图片路径,即可完成目标预测,示例:img/XXX.jpg。
七、总结
通过efficientdet-tf2目标检测应用框架,可以在不需掌握深奥的深度学习算法和专业的机器学习相关理论知识的情况下,通过设定的数据集训练自己的目标检测模型,并应用模型实现目标预测。
猜你喜欢
- 2024-11-07 谷歌AI公开新一代“目标检测”系统
- 2024-11-07 NVIDIA Jetson Nano 2GB 系列文章(53):TAO模型训练工具简介
- 2024-11-07 学会这招,再也不怕下载大文件失败了
- 2024-11-07 Gemini 目标检测能力实测 目标检测nms
- 2024-11-07 2024 年十大物体检测模型 物体检测算法的源代码
- 2024-11-07 基于改进EfficientDet的电力元件及缺陷智能检测方法研究
- 2024-11-07 EfficientDet目标检测谷歌官方终于开源了!
- 2024-11-07 「品览AI论技」精度和速度的极佳平衡——EfficientDet做到了
- 2024-11-07 手把手教物体检测——EfficientDet
- 2024-11-07 比当前SOTA小4倍、计算量少9倍,谷歌最新目标检测器EfficientDet
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)