网站首页 > 技术文章 正文
在这里插入图片描述
小麦穗目标检测【数据集】和出处
在这里插入图片描述
yolo小麦检测 包含两种标签: [1]yolo格式(txt标签) [2]coco格式(json标签) 一个类别:麦穗 4000张左右,分辨率为1024x1024
在这里插入图片描述
,包含了如何使用YOLOv8训练小麦穗目标检测的数据集。文档包括数据集准备、模型训练、评估、可视化训练结果、清理临时文件以及推理和显示结果的步骤。
在这里插入图片描述
使用 YOLOv8 训练小麦穗目标检测
在这里插入图片描述
在这里插入图片描述
数据集信息
- 类别: 麦穗
- 图片数量: 约4000张
- 分辨率: 1024x1024
- 标签格式:
- YOLO 格式 (txt)
- COCO 格式 (json)
在这里插入图片描述
步骤概述
- 数据集准备
- 创建数据集配置文件 (data.yaml)
- 分割数据集
- 训练模型
- 评估模型
- 可视化训练结果
- 清理临时文件
- 推理和显示结果
详细步骤
1. 数据集准备
确保你的数据集已经按照上述格式准备好,并且包含 images 和 labels 目录。
wheat_detection/
├── datasets/
│ └── wheat_dataset/
│ ├── images/
│ │ ├── image1.jpg
│ │ ├── image2.jpg
│ │ └── ...
│ └── labels/
│ ├── image1.txt
│ ├── image2.txt
│ └── ...
└── main.py
2. 创建数据集配置文件 (data.yaml)
创建一个 data.yaml 文件来配置数据集路径和类别信息。
train: ./datasets/wheat_dataset/train/images
val: ./datasets/wheat_dataset/val/images
nc: 1 # 类别数量
names: ['wheat_head'] # 类别名称
3. 分割数据集
将数据集分割成训练集和验证集。
import os
import random
from pathlib import Path
import shutil
def split_dataset(data_dir, train_ratio=0.8):
images = list(Path(data_dir).glob('*.jpg'))
random.shuffle(images)
num_train = int(len(images) * train_ratio)
train_images = images[:num_train]
val_images = images[num_train:]
train_dir = Path(data_dir).parent / 'train'
val_dir = Path(data_dir).parent / 'val'
train_img_dir = train_dir / 'images'
train_label_dir = train_dir / 'labels'
val_img_dir = val_dir / 'images'
val_label_dir = val_dir / 'labels'
train_img_dir.mkdir(parents=True, exist_ok=True)
train_label_dir.mkdir(parents=True, exist_ok=True)
val_img_dir.mkdir(parents=True, exist_ok=True)
val_label_dir.mkdir(parents=True, exist_ok=True)
for img in train_images:
label_path = img.with_suffix('.txt')
shutil.copy(img, train_img_dir / img.name)
shutil.copy(label_path, train_label_dir / label_path.name)
for img in val_images:
label_path = img.with_suffix('.txt')
shutil.copy(img, val_img_dir / img.name)
shutil.copy(label_path, val_label_dir / label_path.name)
# 使用示例
split_dataset('./datasets/wheat_dataset/images')
4. 训练模型
使用YOLOv8进行训练。
import torch
from ultralytics import YOLO
# 设置随机种子以保证可重复性
torch.manual_seed(42)
# 定义数据集路径
dataset_config = 'data.yaml'
# 加载预训练的YOLOv8n模型
model = YOLO('yolov8n.pt')
# 训练模型
results = model.train(
data=dataset_config,
epochs=100,
imgsz=1024,
batch=8,
name='wheat_detection',
project='runs/train'
)
# 评估模型
metrics = model.val()
# 保存最佳模型权重
best_model_weights = 'runs/train/wheat_detection/weights/best.pt'
print(f"Best model weights saved to {best_model_weights}")
5. 可视化训练结果
可视化训练结果。
from ultralytics import YOLO
# 加载训练好的模型
model = YOLO('runs/train/wheat_detection/weights/best.pt')
# 可视化训练结果
model.plot_results(save=True, save_dir='runs/train/wheat_detection')
6. 清理临时文件
清理不必要的临时文件。
import shutil
def clean_temp_files(project_dir):
temp_dirs = [
f'{project_dir}/wandb',
f'{project_dir}/cache'
]
for dir_path in temp_dirs:
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
print(f"Removed directory: {dir_path}")
# 使用示例
clean_temp_files('runs/train/wheat_detection')
7. 推理和显示结果
推理和显示结果。
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# 7.1 检测单张图片
def detect_image(model, image_path, conf_threshold=0.5):
results = model.predict(image_path, conf=conf_threshold)[0]
annotated_frame = annotate_image(image_path, results, model)
return annotated_frame
def annotate_image(image_path, results, model):
frame = cv2.imread(image_path)
for result in results.boxes.cpu().numpy():
r = result.xyxy[0].astype(int)
cls = int(result.cls[0])
conf = result.conf[0]
label = f"{model.names[cls]} {conf:.2f}"
color = (0, 255, 0)
cv2.rectangle(frame, (r[0], r[1]), (r[2], r[3]), color, 2)
cv2.putText(frame, label, (r[0], r[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
return frame
# 7.2 检测视频
def detect_video(model, video_path, conf_threshold):
cap = cv2.VideoCapture(video_path)
frames = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
results = model.predict(frame, conf=conf_threshold)[0]
annotated_frame = annotate_image(frame, results, model)
frames.append(annotated_frame)
cap.release()
return frames
# 7.3 检测摄像头
def detect_camera(model, conf_threshold):
cap = cv2.VideoCapture(0)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
results = model.predict(frame, conf=conf_threshold)[0]
annotated_frame = annotate_image(frame, results, model)
frames.append(annotated_frame)
cv2.imshow('Camera Detection', annotated_frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
return frames
# 主函数
def main():
# 加载模型
model = YOLO('runs/train/wheat_detection/weights/best.pt')
# 动态调整置信度阈值
conf_threshold = 0.5
# 输入方式选择
input_type = "Image" # 可选: "Image", "Video", "Camera"
if input_type == "Image":
test_image_path = './datasets/wheat_dataset/images/test_image.jpg'
annotated_image = detect_image(model, test_image_path, conf_threshold)
cv2.imwrite('annotated_test_image.jpg', annotated_image)
plt.imshow(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()
# 统计检测到的物体数量
results = model.predict(test_image_path, conf=conf_threshold)[0]
class_counts = {}
for result in results.boxes.cpu().numpy():
cls = int(result.cls[0])
class_name = model.names[cls]
if class_name in class_counts:
class_counts[class_name] += 1
else:
class_counts[class_name] = 1
print("Detection Summary:")
for class_name, count in class_counts.items():
print(f"{class_name}: {count}")
elif input_type == "Video":
test_video_path = './datasets/wheat_dataset/videos/test_video.mp4'
frames = detect_video(model, test_video_path, conf_threshold)
for i, frame in enumerate(frames):
cv2.imwrite(f'frame_{i}.jpg', frame)
plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()
elif input_type == "Camera":
frames = detect_camera(model, conf_threshold)
for i, frame in enumerate(frames):
cv2.imwrite(f'camera_frame_{i}.jpg', frame)
plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()
if __name__ == "__main__":
main()
运行脚本
在终端中运行以下命令来执行整个流程:
python main.py
总结
以上文档包含了从数据集准备、模型训练、评估、可视化训练结果、清理临时文件到推理和显示结果的所有步骤。希望这些详细的信息和代码能够帮助你顺利实施和优化你的小麦穗目标检测系统
猜你喜欢
- 2024-12-29 国内首个非Attention大模型发布!训练效率是Transformer的7倍
- 2024-12-29 AI大模型探索之路 - 训练篇8:Transformer库预训练全流程实战指南
- 2024-12-29 基于yolov8,训练一个安全帽佩戴的目标检测模型
- 2024-12-29 从零手搓中文大模型计划|Day06|预训练代码汇总和梳理
- 2024-12-29 YOLOv8姿态估计模型训练简明教程 姿态估计heatmap
- 2024-12-29 首次!用合成人脸数据集训练的识别模型,性能高于真实数据集
- 2024-12-29 风控模型应聘,80%会被问到的面试题
- 2024-12-29 快乐8第24271期训练与验证 快乐八2021248期
- 2024-12-29 AI系列:怎么对模型进行测试 ai模拟量
- 2024-12-29 QAF2D:利用2D检测引导查询3D anchor来增强BEV远距离目标检测
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)