计算机系统应用教程网站

网站首页 > 技术文章 正文

融合EfficientNet和YoloV5,非常实用的物体检测二阶段pipeline

btikc 2024-09-06 18:15:25 技术文章 11 ℃ 0 评论

作者:Mostafa Ibrahim

编译:ronghuaiyang

导读

使用EfficientNet和YoloV5的融合可以提升20%的performance。

在本文中,我将解释上一篇文章中称之为“2 class filter”的概念。这是一种用于目标检测和分类模型的综合技术,在过去几周我一直在做的Kaggle比赛中被大量使用。几乎所有参加比赛的人都使用了这种技术,它似乎可以提高大约5-25%的性能,这是非常有用的。

目标检测:YoloV5

我们首先在我们的数据集上训练YoloV5模型,同时使用加权框融合(WBF)进行后处理/预处理,如果你想了解更多,我建议查看这两篇文章:

1、Kaggle竞赛中使用YoloV5将物体检测的性能翻倍的心路历程

2、WBF:优化目标检测,融合过滤预测框

我不想再深入讨论使用WBF训练YoloV5的细节。但是,你需要做的基本上就是使用WBF消除重复的框,然后对数据进行预处理,在其上运行YoloV5。YoloV5需要一个特定的层次结构来显示数据集,以便开始训练和评估。

分类:EfficientNet

接下来要做的是在数据集上训练一个分类网络。但是,有趣的一点是,虽然目标检测模型在14个不同的类(13个不同类型的疾病和1个无疾病类)上训练,但我们只在2个类(疾病和无疾病)上训练分类网络。你可以认为这是一种建模方法,简化了我们的分类问题,因为2分类网络比14分类容易得多,当我们融合这两个网络时,我们真的不需要每一个疾病的细节,我们将只需要一个2分类。当然,对于你的问题,这可能有点不同,因此你可能需要试验不同的设置,但是希望你能从本文中获得一些想法。

目前最先进的分类网络之一是EfficientNet。对于这个数据集,我们将使用使用Keras (TensorFlow)训练的B6 EfficientNet,以及这些扩展:

(
    rescale=1.0 / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest",
)

集成

这就是使用2分类过滤器来提高性能的原因,也是本文真正要讨论的内容。关于训练YoloV5和EfficientNet,我不想说太多,因为有很多资源可以提供给他们。

我想强调的主要思想是,尽管Yolo的分类预测非常好,但如果你可以将它们与另一个更强大的网络的分类混合在一起,你可以获得相当不错的性能提升。让我们看看这是如何实现的。这里使用的想法是设置一个高阈值和一个低阈值。然后我们要检查每个分类预测。如果概率小于低阈值,我们将预测设置为“无疾病”。回想一下,我们最初的问题是对14种疾病中的一种进行分类,或者对“无疾病”进行分类。这个低阈值可以是0到1之间的任何值,但可能是0到0.1之间的某个值。此外,如果分类预测在低阈值和高阈值之间,我们得到一个“No Disease”的预测,该预测具有EfficientNet的置信度(不是Yolo)。最后,如果分类预测高于高阈值,我们什么也不做,因为这意味着网络是高度自信的。

可以这样实现:

low_thr  = 0.08
high_thr = 0.95

def filter_2cls(row, low_thr=low_thr, high_thr=high_thr):
    prob = row['target']
    if prob<low_thr:
        ## Less chance of having any disease
        row['PredictionString'] = '14 1 0 0 1 1'
    elif low_thr<=prob<high_thr:
        ## More chance of having any disease
        row['PredictionString']+=f' 14 {prob} 0 0 1 1'
    elif high_thr<=prob:
        ## Good chance of having any disease so believe in object detection model
        row['PredictionString'] = row['PredictionString']
    else:
        raise ValueError('Prediction must be from [0-1]')
    return row

最后的思考

在比赛期间,我已经在各种不同的场景和模型上试验了这2分类过滤器,它似乎总能提高25%的性能,这是令人惊讶的。我认为如果你想把它应用到你的自定义场景中,你需要考虑在哪些情况下分类网络预测可以帮助你的目标检测模型。这并不完全是交换预测的执行度,而是用一种聪明的方式“融合”它们。


—END—

英文原文:https://towardsdatascience.com/fusing-efficientnet-yolov5-advanced-object-detection-2-stage-pipeline-tutorial-da3a77b118d1

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表