网站首页 > 技术文章 正文
YOLO-NAS 是 Deci 开发的一种新的最先进的目标检测模型。 在本指南中,我们将讨论什么是 YOLO-NAS 以及如何在自定义数据集上训练 YOLO-NAS 模型。
在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D场景编辑器
为了训练我们的自定义模型,我们将:
- 加载预训练的YOLO-NAS模型;
- 从 Roboflow 加载自定义数据集,或者使用UnrealSynth制作合成数据集
- 设置超参数值;
- 使用超级梯度 Python 包根据我们的数据训练模型;
- 评估模型以了解结果。
话不多说,让我们开始吧!
1、什么是 YOLO-NAS?
You Only Look Once??神经架构搜索(YOLO-NAS)是最新最先进的(SOTA)实时目标检测模型。 在 COCO 数据集上进行评估并与其前身 YOLOv6 和 YOLOv8? 相比,YOLO-NAS 以更低的延迟实现了更高的 mAP 值。
YOLO-NAS 作为 Deci 维护的 super-gradient包的一部分提供。
下图展示了Deci在YOLO-NAS上的基准测试结果:
YOLO-NAS 与其他顶级实时检测器在 COCO 数据集上的性能对比图
YOLO-NAS 在 Roboflow 100 数据集基准测试中也是最好的,这表明它可以轻松地在自定义数据集上进行微调。
YOLO-NAS 和其他顶级实时检测器在 RF100 数据集上的性能对比图
2、Python环境设置
在开始训练之前,我们需要准备好Python环境。 让我们从安装三个 pip 包开始。 YOLO-NAS 模型本身是使用 super-gradient 包进行分发的。 请记住,该模型仍在积极开发中。 为了保持环境的稳定性,最好固定特定版本的包。 此外,我们将安装 roboflow 和监督,这将使我们能够从 Roboflow Universe 下载数据集并分别可视化我们的训练结果。
pip install super-gradients==3.1.1
pip install roboflow
pip install supervision
如果你在 Jupyter Notebook 中运行 YOLO-NAS,请不要忘记在安装完成后重新启动环境。
3、使用预训练模型进行推理
在开始培训之前,最好确保安装按计划进行。 最简单的方法是使用预先训练的模型之一进行测试推理。 同时,这也能让我们熟悉YOLO-NAS API。
3.1 加载YOLO-NAS模型
为了使用预训练的 COCO 模型进行推理,我们首先需要选择模型的大小。 YOLO-NAS提供三种不同的模型大小:yolo_nas_s、yolo_nas_m和yolo_nas_l。
yolo_nas_s 模型是最小且最快的,但它可能不会像较大的模型那么准确。 相反,yolo_nas_l 模型最大、最准确、最慢。 yolo_nas_m 模型提供了两者之间的中间立场。
import torch
from super_gradients.training import models
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_ARCH = 'yolo_nas_l'
# 'yolo_nas_m'
# 'yolo_nas_s'
model = models.get(MODEL_ARCH, pretrained_weights="coco").to(DEVICE)
3.2 YOLO-NAS模型推理
推理过程包括设置置信度阈值和调用预测方法。 预测方法将返回预测列表,其中每个预测对应于图像中检测到的对象。
CONFIDENCE_TRESHOLD = 0.35
result = list(model.predict(image, conf=CONFIDENCE_TRESHOLD))[0]
3.3 YOLO-NAS 推理输出格式
YOLO-NAS 推理的输出是一个 ImageDetectionPrediction 对象,它封装了图像中检测到的对象的详细信息。 该对象包含三个字段:
- image - 表示用于推理的图像的 NumPy 数组。
- class_names - 模型训练期间使用的类别名称的 Python 列表。
- Prediction -DetectionPrediction 类的实例,其中包含有关模型检测的详细信息。
DetectionPrediction对象具有三个字段:
- bboxes_xyxy - 形状 (N, 4) 的 NumPy 数组,以 xyxy 格式表示检测到的对象的边界框。
- confidence - 形状 (N,) 的 NumPy 数组,表示检测的置信度值。 每个值都在 0 和 1 之间。
- labels - 形状 (N,) 的 NumPy 数组,表示检测到的对象的类 ID。 每个类 ID 对应于 class_names 列表中的一个索引。
4、使用开源数据集微调 YOLO-NAS
为了微调模型,我们需要数据。 我们将使用足球运动员检测图像数据集。
如果你已经有 YOLO 格式的数据集,请随意使用它。 如果没有,请看看 Roboflow Universe,那里拥有超过 200,000 个开源项目,并且所有项目都可以以任何格式导出。
另外一种获取数据集的方法是使用UnrealSynth,一个基于虚幻引擎开发的YOLO合成数据生成器,可以自动生成包括标注的训练数据集,非常方便:
import roboflow
from roboflow import Roboflow
roboflow.login()
rf = Roboflow()
project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
dataset = project.version(PROJECT_VERSION).download("yolov5")
要训练 YOLO-NAS 模型,你需要设置几个关键参数。
首先,你需要选择模型尺寸。 有三个选项可供选择:小型、中型和大型。 请记住,较大的模型可能需要更长的时间来训练并需要更多的内存,因此如果使用的资源有限,你可能需要考虑使用较小的模型。
接下来,你需要设置批量大小。 该参数指示在训练过程的每次迭代期间将有多少图像通过神经网络。 较大的批量大小将加快训练过程,但也需要更多的内存。
MODEL_ARCH = 'yolo_nas_l'
BATCH_SIZE = 8
MAX_EPOCHS = 25
CHECKPOINT_DIR = f'{HOME}/checkpoints'
EXPERIMENT_NAME = project.name.lower().replace(" ", "_")
LOCATION = dataset.location
CLASSES = sorted(project.classes.keys())
dataset_params = {
'data_dir': LOCATION,
'train_images_dir':'train/images',
'train_labels_dir':'train/labels',
'val_images_dir':'valid/images',
'val_labels_dir':'valid/labels',
'test_images_dir':'test/images',
'test_labels_dir':'test/labels',
'classes': CLASSES
}
from super_gradients.training.dataloaders.dataloaders import (
coco_detection_yolo_format_train, coco_detection_yolo_format_val)
train_data = coco_detection_yolo_format_train(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['train_images_dir'],
'labels_dir': dataset_params['train_labels_dir'],
'classes': dataset_params['classes']
},
dataloader_params={
'batch_size': BATCH_SIZE,
'num_workers': 2
}
)
val_data = coco_detection_yolo_format_val(
dataset_params={
'data_dir': dataset_params['data_dir'],
'images_dir': dataset_params['val_images_dir'],
'labels_dir': dataset_params['val_labels_dir'],
'classes': dataset_params['classes']
},
dataloader_params={
'batch_size': BATCH_SIZE,
'num_workers': 2
}
)
最后,你需要设置训练过程的纪元数。 这本质上是整个数据集通过神经网络的次数。
5、训练自定义 YOLO-NAS 模型
你可能已经注意到,训练模型的过程比 YOLOv8 更加冗长。 Ultralytics 模型中的许多功能需要在 CLI 中传递参数,而对于 YOLO-NAS,则需要编写自定义逻辑。
最后,我们准备开始训练。 在调用 train 方法之前,值得运行 TensorBoard。 这将使我们能够实时跟踪培训的关键指标。 值得一提的是,YOLO-NAS还支持W&B等最流行的实验记录仪。
YOLO-NAS 训练期间获得的指标图
trainer.train(
model=model,
training_params=train_params,
train_loader=train_data,
valid_loader=val_data
)
6、评估自定义 YOLO-NAS 模型
训练结束后,你可以使用Trainer提供的测试方法评估模型的性能。 你需要传入测试集数据加载器,训练器将返回一个指标列表,包括通常用于评估对象检测模型的平均精度(mAP)。
trainer.test(
model=best_model,
test_loader=test_data,
test_metrics_list=DetectionMetrics_050(
score_thres=0.1,
top_k_predictions=300,
num_cls=len(dataset_params['classes']),
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(
score_threshold=0.01,
nms_top_k=1000,
max_predictions=300,
nms_threshold=0.7
)
)
)
此外,你可以对测试集图像进行推理并可视化结果,以更好地了解模型在各个示例上的表现。 你还可以计算混淆矩阵,以更详细地了解每个类别的模型性能:
模型评估过程中创建的混淆矩阵
7、结束语
一夜之间,YOLO-NAS 成为实时物体检测器的新选择。 在为你的项目微调模型时,请记住要考虑所有方面——从模型准确性到推理速度,再到易于训练和许可限制。
原文链接:http://www.bimant.com/blog/train-yolo-nas-on-custom-dataset/
猜你喜欢
- 2024-10-19 物体分割检测YOLO4算法环境配置 yolo实例分割
- 2024-10-19 Yolo框架优化:黑夜中也可以实时目标检测,
- 2024-10-19 目标检测开源框架YOLOv6全面升级,更快更准的2.0版本来啦
- 2024-10-19 YOLOv5全面解析教程⑤:计算mAP用到的Numpy函数详解
- 2024-10-19 Q-YOLO:用于实时目标检测的高效推理
- 2024-10-19 Drone-YOLO:一种有效的无人机图像目标检测
- 2024-10-19 PE-YOLO:解决黑夜中的目标检测难点
- 2024-10-19 大改Yolo框架 | 能源消耗极低的目标检测新框架(附论文下载)
- 2024-10-19 Yolo框架大改 | 消耗极低的目标检测新框架(附论文下载)
你 发表评论:
欢迎- 11-18软考系统分析师知识点十六:系统实现与测试
- 11-18第16篇 软件工程(四)过程管理与测试管理
- 11-18编程|实例(分书问题)了解数据结构、算法(穷举、递归、回溯)
- 11-18算法-减治法
- 11-18笑疯了!巴基斯坦首金!没有技巧全是蛮力!解说:真远啊!笑死!
- 11-18搜索算法之深度优先、广度优先、约束条件、限界函数及相应算法
- 11-18游戏中的优化指的的是什么?
- 11-18算法-分治法
- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)