计算机系统应用教程网站

网站首页 > 技术文章 正文

Pytorch-Faster-R-CNN进行多目标检测

btikc 2024-09-14 01:04:46 技术文章 22 ℃ 0 评论

1 说明

====

1.1 目标检测在python-opencv(cv2)里占很大的比重,而目标检测当红网络肯定少不了RCNN家族。

1.2 2014年,RBG(Ross B. Girshick)使用Region Proposal + CNN代替传统目标检测使用的滑动窗口+手工设计特征,设计了R-CNN框架,使得目标检测取得巨大突破,并开启了基于深度学习目标检测的热潮。

1.3 Fast R-CNN就是在R-CNN的基础上采纳了SPP Net方法,对R-CNN作了改进,使得性能进一步提高,让人们看到了Region Proposal + CNN这一框架实时检测的希望。

1.4 Faster R-CNN最新进展。

==================================

R-CNN(Selective Search + CNN + SVM)

SPP-net(ROI Pooling)

Fast R-CNN(Selective Search + CNN + ROI)

Faster R-CNN(RPN + CNN + ROI)

==================================

总的来说,从R-CNN, SPP-NET, Fast R-CNN, Faster R-CNN一路走来,基于深度学习目标检测的流程变得越来越精简,精度越来越高,速度也越来越快。可以说基于region proposal的R-CNN系列目标检测方法是当前目标检测技术领域最主要的一个分支。

2 准备

=====

2.1 环境:python3.8+deepin-linux操作系统+微软编辑器vscode。

2.2 切记:opencv4.2.0+torch1.5.1(注意版本,否则报错)。

2.3 图片来源:今日头条正版免费图库,效果图。

2.4 进行多目标检测,适当修改,提高可调试性,注意事项已经交代。

3 代码1.py

========

3.1 来源https://github.com/spmallick/learnopencv,对其代码进行删除,修改,注释,修复bug,调试和注意事项。

3.2 代码头注释部分

#Pytorch使用Faster R-CNN进行目标检测
#复杂性多种物体的目标检测

#注意版本
#查询opencv版本
'''
import cv2
cv2.__version__
'''
#opencv 4.2.0

#本机终端输入
#python3.8 1.py

3.3 完整代码1.py

#第1步:导入模块
import cv2
import numpy as np

import torch
import torchvision
import torchvision.transforms as T

#第2步:模型加载
#使用的是Faster R-CNN + ResNet50预训练模型。
#下载预训练模型,Resnet50 Faster R-CNN,带有训练好的权重参数。
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
#不启用 BatchNormalization 和 Dropout,保证BN和dropout不发生变化,
#pytorch框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值
model.eval()

#第3步:识别内容
#可识别内容,注意顺序不能调动和N/A不能删除
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


#第4步:定义预测函数
def get_prediction(img_path, threshold):
    img=cv2.imread(img_path) 
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    pred = model([img])
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
    #修改np.int32,原来打印出np.float
    pred_boxes = [[(np.int32(i[0]), np.int32(i[1])), (np.int32(i[2]), np.int32(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
    pred_score = list(pred[0]['scores'].detach().numpy())
    pred_t = [pred_score.index(x) for x in pred_score if x>threshold][-1]
    pred_boxes = pred_boxes[:pred_t+1]
    pred_class = pred_class[:pred_t+1]
    return pred_boxes, pred_class
  

#第5步:定义目标检测函数
def object_detection_api(img_path, threshold=0.5, rect_th=2, text_size=1, text_th=2):
    #调用预测函数,获取框和识别名
    boxes, pred_cls = get_prediction(img_path, threshold)
    #读取图片
    img = cv2.imread(img_path)
    
    for i in range(len(boxes)):
        #绿色框
        cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0))
        #红色识别物体名称
        cv2.putText(img,pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,0,255))

    #显示结果
    cv2.imshow("out" , img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

#第6步:指定图片启动目标检测
#默认threshold=0.5,也可以修改
object_detection_api('./1.jpeg', threshold=0.8)
#object_detection_api('./2.jpeg', threshold=0.5)

===以上仅仅只能对图片进行多目标检测,档次有点低===

===高级一点,全套功能:图片、摄像头实时和视频文件mp4===

4 全套功能代码2.py

==============

4.1 头文件注释和代码来源:

# -*- coding: utf-8 -*-
"""
Created on Thu Jul 30 08:47:12 2020
@author: Johnson
"""
#https://blog.csdn.net/zhonglongshen/article/details/107682640?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-7.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-7.control
#对其代码进行删除,修改,注释,修复bug,调试和注意事项。

#本机终端输入
#python3.8 2.py

4.2 coco.names文件内容,注意去除引号和逗号

4.3 完整代码2.py

#第1步:模块导入
import numpy as np
import cv2

import torch
import  torchvision
from torchvision import transforms


#第2步:打开文件coco.names
#将代码1的目标识别名放入coco.names中
with open("./coco.names") as f: #获取类别名称
    coco_names = [line.strip() for line in f.readlines()]

#第3步:加载模型
#在torchvision框架可以直接加载预训练模型
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
##将图片变成 Tensor,并且把数值normalize到[0,1]
transform = transforms.Compose([transforms.ToTensor()])

#第4步:图片检测函数定义,功能1
def faster_rcnn_detection(path):
    image = cv2.imread(path)
    blob = transform(image)
    c,h,w = blob.shape
    input_x = blob.view(1,c,h,w)
    output = model(input_x)[0]  #这里如果是GPU.cuda()
    boxes = output['boxes'].cpu().detach().numpy()
    scores = output['scores'].cpu().detach().numpy()
    labels = output['labels'].cpu().detach().numpy()
    index = 0
    for x1,y1,x2,y2 in boxes:
        #大于0.9比较好,太小要识别过多不精准,误识别
        #类似与代码1的threshold=0.6
        if scores[index]>0.6:
            #框的颜色等设置
            cv2.rectangle(image, (np.int32(x1), np.int32(y1)),
                         (np.int32(x2), np.int32(y2)), (0, 255, 0), 1, 8, 0)
            label_id = labels[index]
            label_txt = coco_names[label_id]
            #文字的颜色等设置
            #将字体1.0改为2.0
            #将1改为2,未报错
            cv2.putText(image, label_txt, (np.int32(x1), np.int32(y1)), 
            cv2.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), 2)
        index+=1
    
    cv2.imshow("Faster-RCNN Detection Demo", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    

#第5步:视频检测函数定义,功能2
def video_detection(path):
    capture = cv2.VideoCapture(path)

    #循环
    #while True:
    #while(capture.isOpened()):
    while cv2.waitKey(1) < 0:
        ret,frame = capture.read()
        
        if ret == True:
            #视频翻转
            frame = cv2.flip(frame,1)  #0颠倒,1翻转一下
            blob = transform(frame)
            c,h,w = blob.shape
            input_x = blob.view(1,c,h,w)
            output = model(input_x)[0]  #这里如果是GPU.cuda()
            boxes = output['boxes'].cpu().detach().numpy()
            scores = output['scores'].cpu().detach().numpy()
            labels = output['labels'].cpu().detach().numpy()
            index = 0
        
            for x1,y1,x2,y2 in boxes:
                #大一些,否则漏识别
                if scores[index]>0.5:
                    #框的颜色等设置
                    cv2.rectangle(frame, (np.int32(x1), np.int32(y1)),
                                (np.int32(x2), np.int32(y2)), (0, 255, 0), 1, 8, 0)
                    label_id = labels[index]
                    label_txt = coco_names[label_id]
                    #文字的颜色等设置
                    cv2.putText(frame, label_txt, (np.int32(x1), np.int32(y1)), 
                    cv2.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 255), 1)
                index+=1
    
            wk = cv2.waitKey(1)
            if wk==27:
                break
    
        cv2.imshow("video detection",frame)
   
#摄像头检测函数
#video_detection(0)
#视频文件检测
video_detection("./video.mp4")

#图片检测函数
#faster_rcnn_detection("./1.jpeg")

5 小bug

视频保存未成功!大家可以加油。总之比较完美。

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表