1 说明
====
1.1 目标检测在python-opencv(cv2)里占很大的比重,而目标检测当红网络肯定少不了RCNN家族。
1.2 2014年,RBG(Ross B. Girshick)使用Region Proposal + CNN代替传统目标检测使用的滑动窗口+手工设计特征,设计了R-CNN框架,使得目标检测取得巨大突破,并开启了基于深度学习目标检测的热潮。
1.3 Fast R-CNN就是在R-CNN的基础上采纳了SPP Net方法,对R-CNN作了改进,使得性能进一步提高,让人们看到了Region Proposal + CNN这一框架实时检测的希望。
1.4 Faster R-CNN最新进展。
==================================
R-CNN(Selective Search + CNN + SVM)
SPP-net(ROI Pooling)
Fast R-CNN(Selective Search + CNN + ROI)
Faster R-CNN(RPN + CNN + ROI)
==================================
总的来说,从R-CNN, SPP-NET, Fast R-CNN, Faster R-CNN一路走来,基于深度学习目标检测的流程变得越来越精简,精度越来越高,速度也越来越快。可以说基于region proposal的R-CNN系列目标检测方法是当前目标检测技术领域最主要的一个分支。
2 准备
=====
2.1 环境:python3.8+deepin-linux操作系统+微软编辑器vscode。
2.2 切记:opencv4.2.0+torch1.5.1(注意版本,否则报错)。
2.3 图片来源:今日头条正版免费图库,效果图。
2.4 进行多目标检测,适当修改,提高可调试性,注意事项已经交代。
3 代码1.py
========
3.1 来源https://github.com/spmallick/learnopencv,对其代码进行删除,修改,注释,修复bug,调试和注意事项。
3.2 代码头注释部分
#Pytorch使用Faster R-CNN进行目标检测
#复杂性多种物体的目标检测
#注意版本
#查询opencv版本
'''
import cv2
cv2.__version__
'''
#opencv 4.2.0
#本机终端输入
#python3.8 1.py
3.3 完整代码1.py
#第1步:导入模块
import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
#第2步:模型加载
#使用的是Faster R-CNN + ResNet50预训练模型。
#下载预训练模型,Resnet50 Faster R-CNN,带有训练好的权重参数。
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
#不启用 BatchNormalization 和 Dropout,保证BN和dropout不发生变化,
#pytorch框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值
model.eval()
#第3步:识别内容
#可识别内容,注意顺序不能调动和N/A不能删除
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
#第4步:定义预测函数
def get_prediction(img_path, threshold):
img=cv2.imread(img_path)
transform = T.Compose([T.ToTensor()])
img = transform(img)
pred = model([img])
pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
#修改np.int32,原来打印出np.float
pred_boxes = [[(np.int32(i[0]), np.int32(i[1])), (np.int32(i[2]), np.int32(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
pred_score = list(pred[0]['scores'].detach().numpy())
pred_t = [pred_score.index(x) for x in pred_score if x>threshold][-1]
pred_boxes = pred_boxes[:pred_t+1]
pred_class = pred_class[:pred_t+1]
return pred_boxes, pred_class
#第5步:定义目标检测函数
def object_detection_api(img_path, threshold=0.5, rect_th=2, text_size=1, text_th=2):
#调用预测函数,获取框和识别名
boxes, pred_cls = get_prediction(img_path, threshold)
#读取图片
img = cv2.imread(img_path)
for i in range(len(boxes)):
#绿色框
cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0))
#红色识别物体名称
cv2.putText(img,pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,0,255))
#显示结果
cv2.imshow("out" , img)
cv2.waitKey(0)
cv2.destroyAllWindows()
#第6步:指定图片启动目标检测
#默认threshold=0.5,也可以修改
object_detection_api('./1.jpeg', threshold=0.8)
#object_detection_api('./2.jpeg', threshold=0.5)
===以上仅仅只能对图片进行多目标检测,档次有点低===
===高级一点,全套功能:图片、摄像头实时和视频文件mp4===
4 全套功能代码2.py
==============
4.1 头文件注释和代码来源:
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 30 08:47:12 2020
@author: Johnson
"""
#https://blog.csdn.net/zhonglongshen/article/details/107682640?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-7.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-7.control
#对其代码进行删除,修改,注释,修复bug,调试和注意事项。
#本机终端输入
#python3.8 2.py
4.2 coco.names文件内容,注意去除引号和逗号
4.3 完整代码2.py
#第1步:模块导入
import numpy as np
import cv2
import torch
import torchvision
from torchvision import transforms
#第2步:打开文件coco.names
#将代码1的目标识别名放入coco.names中
with open("./coco.names") as f: #获取类别名称
coco_names = [line.strip() for line in f.readlines()]
#第3步:加载模型
#在torchvision框架可以直接加载预训练模型
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
##将图片变成 Tensor,并且把数值normalize到[0,1]
transform = transforms.Compose([transforms.ToTensor()])
#第4步:图片检测函数定义,功能1
def faster_rcnn_detection(path):
image = cv2.imread(path)
blob = transform(image)
c,h,w = blob.shape
input_x = blob.view(1,c,h,w)
output = model(input_x)[0] #这里如果是GPU.cuda()
boxes = output['boxes'].cpu().detach().numpy()
scores = output['scores'].cpu().detach().numpy()
labels = output['labels'].cpu().detach().numpy()
index = 0
for x1,y1,x2,y2 in boxes:
#大于0.9比较好,太小要识别过多不精准,误识别
#类似与代码1的threshold=0.6
if scores[index]>0.6:
#框的颜色等设置
cv2.rectangle(image, (np.int32(x1), np.int32(y1)),
(np.int32(x2), np.int32(y2)), (0, 255, 0), 1, 8, 0)
label_id = labels[index]
label_txt = coco_names[label_id]
#文字的颜色等设置
#将字体1.0改为2.0
#将1改为2,未报错
cv2.putText(image, label_txt, (np.int32(x1), np.int32(y1)),
cv2.FONT_HERSHEY_PLAIN, 2.0, (0, 0, 255), 2)
index+=1
cv2.imshow("Faster-RCNN Detection Demo", image)
cv2.waitKey(0)
cv2.destroyAllWindows()
#第5步:视频检测函数定义,功能2
def video_detection(path):
capture = cv2.VideoCapture(path)
#循环
#while True:
#while(capture.isOpened()):
while cv2.waitKey(1) < 0:
ret,frame = capture.read()
if ret == True:
#视频翻转
frame = cv2.flip(frame,1) #0颠倒,1翻转一下
blob = transform(frame)
c,h,w = blob.shape
input_x = blob.view(1,c,h,w)
output = model(input_x)[0] #这里如果是GPU.cuda()
boxes = output['boxes'].cpu().detach().numpy()
scores = output['scores'].cpu().detach().numpy()
labels = output['labels'].cpu().detach().numpy()
index = 0
for x1,y1,x2,y2 in boxes:
#大一些,否则漏识别
if scores[index]>0.5:
#框的颜色等设置
cv2.rectangle(frame, (np.int32(x1), np.int32(y1)),
(np.int32(x2), np.int32(y2)), (0, 255, 0), 1, 8, 0)
label_id = labels[index]
label_txt = coco_names[label_id]
#文字的颜色等设置
cv2.putText(frame, label_txt, (np.int32(x1), np.int32(y1)),
cv2.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 255), 1)
index+=1
wk = cv2.waitKey(1)
if wk==27:
break
cv2.imshow("video detection",frame)
#摄像头检测函数
#video_detection(0)
#视频文件检测
video_detection("./video.mp4")
#图片检测函数
#faster_rcnn_detection("./1.jpeg")
5 小bug
视频保存未成功!大家可以加油。总之比较完美。
本文暂时没有评论,来添加一个吧(●'◡'●)