计算机系统应用教程网站

网站首页 > 技术文章 正文

小麦头小麦穗目标检测数据集yolo格式(txt标签)4000张左右

btikc 2024-12-29 01:42:36 技术文章 38 ℃ 0 评论

在这里插入图片描述

小麦穗目标检测【数据集】和出处

在这里插入图片描述


yolo小麦检测 包含两种标签: [1]yolo格式(txt标签) [2]coco格式(json标签) 一个类别:麦穗 4000张左右,分辨率为1024x1024

在这里插入图片描述


,包含了如何使用YOLOv8训练小麦穗目标检测的数据集。文档包括数据集准备、模型训练、评估、可视化训练结果、清理临时文件以及推理和显示结果的步骤。

在这里插入图片描述


使用 YOLOv8 训练小麦穗目标检测

在这里插入图片描述



在这里插入图片描述


数据集信息

  • 类别: 麦穗
  • 图片数量: 约4000张
  • 分辨率: 1024x1024
  • 标签格式:
  • YOLO 格式 (txt)
  • COCO 格式 (json)


在这里插入图片描述


步骤概述

  1. 数据集准备
  2. 创建数据集配置文件 (data.yaml)
  3. 分割数据集
  4. 训练模型
  5. 评估模型
  6. 可视化训练结果
  7. 清理临时文件
  8. 推理和显示结果

详细步骤

1. 数据集准备

确保你的数据集已经按照上述格式准备好,并且包含 imageslabels 目录。

wheat_detection/
├── datasets/
│   └── wheat_dataset/
│       ├── images/
│       │   ├── image1.jpg
│       │   ├── image2.jpg
│       │   └── ...
│       └── labels/
│           ├── image1.txt
│           ├── image2.txt
│           └── ...
└── main.py

2. 创建数据集配置文件 (data.yaml)

创建一个 data.yaml 文件来配置数据集路径和类别信息。

train: ./datasets/wheat_dataset/train/images
val: ./datasets/wheat_dataset/val/images

nc: 1  # 类别数量
names: ['wheat_head']  # 类别名称

3. 分割数据集

将数据集分割成训练集和验证集。

import os
import random
from pathlib import Path
import shutil

def split_dataset(data_dir, train_ratio=0.8):
    images = list(Path(data_dir).glob('*.jpg'))
    random.shuffle(images)

    num_train = int(len(images) * train_ratio)
    train_images = images[:num_train]
    val_images = images[num_train:]

    train_dir = Path(data_dir).parent / 'train'
    val_dir = Path(data_dir).parent / 'val'

    train_img_dir = train_dir / 'images'
    train_label_dir = train_dir / 'labels'
    val_img_dir = val_dir / 'images'
    val_label_dir = val_dir / 'labels'

    train_img_dir.mkdir(parents=True, exist_ok=True)
    train_label_dir.mkdir(parents=True, exist_ok=True)
    val_img_dir.mkdir(parents=True, exist_ok=True)
    val_label_dir.mkdir(parents=True, exist_ok=True)

    for img in train_images:
        label_path = img.with_suffix('.txt')
        shutil.copy(img, train_img_dir / img.name)
        shutil.copy(label_path, train_label_dir / label_path.name)

    for img in val_images:
        label_path = img.with_suffix('.txt')
        shutil.copy(img, val_img_dir / img.name)
        shutil.copy(label_path, val_label_dir / label_path.name)

# 使用示例
split_dataset('./datasets/wheat_dataset/images')

4. 训练模型

使用YOLOv8进行训练。

import torch
from ultralytics import YOLO

# 设置随机种子以保证可重复性
torch.manual_seed(42)

# 定义数据集路径
dataset_config = 'data.yaml'

# 加载预训练的YOLOv8n模型
model = YOLO('yolov8n.pt')

# 训练模型
results = model.train(
    data=dataset_config,
    epochs=100,
    imgsz=1024,
    batch=8,
    name='wheat_detection',
    project='runs/train'
)

# 评估模型
metrics = model.val()

# 保存最佳模型权重
best_model_weights = 'runs/train/wheat_detection/weights/best.pt'
print(f"Best model weights saved to {best_model_weights}")

5. 可视化训练结果

可视化训练结果。

from ultralytics import YOLO

# 加载训练好的模型
model = YOLO('runs/train/wheat_detection/weights/best.pt')

# 可视化训练结果
model.plot_results(save=True, save_dir='runs/train/wheat_detection')

6. 清理临时文件

清理不必要的临时文件。

import shutil

def clean_temp_files(project_dir):
    temp_dirs = [
        f'{project_dir}/wandb',
        f'{project_dir}/cache'
    ]

    for dir_path in temp_dirs:
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
            print(f"Removed directory: {dir_path}")

# 使用示例
clean_temp_files('runs/train/wheat_detection')

7. 推理和显示结果

推理和显示结果。

from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# 7.1 检测单张图片
def detect_image(model, image_path, conf_threshold=0.5):
    results = model.predict(image_path, conf=conf_threshold)[0]
    annotated_frame = annotate_image(image_path, results, model)
    return annotated_frame

def annotate_image(image_path, results, model):
    frame = cv2.imread(image_path)
    for result in results.boxes.cpu().numpy():
        r = result.xyxy[0].astype(int)
        cls = int(result.cls[0])
        conf = result.conf[0]

        label = f"{model.names[cls]} {conf:.2f}"
        color = (0, 255, 0)
        cv2.rectangle(frame, (r[0], r[1]), (r[2], r[3]), color, 2)
        cv2.putText(frame, label, (r[0], r[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    return frame

# 7.2 检测视频
def detect_video(model, video_path, conf_threshold):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        results = model.predict(frame, conf=conf_threshold)[0]
        annotated_frame = annotate_image(frame, results, model)
        frames.append(annotated_frame)
    cap.release()
    return frames

# 7.3 检测摄像头
def detect_camera(model, conf_threshold):
    cap = cv2.VideoCapture(0)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        results = model.predict(frame, conf=conf_threshold)[0]
        annotated_frame = annotate_image(frame, results, model)
        frames.append(annotated_frame)
        cv2.imshow('Camera Detection', annotated_frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()
    return frames

# 主函数
def main():
    # 加载模型
    model = YOLO('runs/train/wheat_detection/weights/best.pt')

    # 动态调整置信度阈值
    conf_threshold = 0.5

    # 输入方式选择
    input_type = "Image"  # 可选: "Image", "Video", "Camera"

    if input_type == "Image":
        test_image_path = './datasets/wheat_dataset/images/test_image.jpg'
        annotated_image = detect_image(model, test_image_path, conf_threshold)
        cv2.imwrite('annotated_test_image.jpg', annotated_image)
        plt.imshow(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        plt.show()

        # 统计检测到的物体数量
        results = model.predict(test_image_path, conf=conf_threshold)[0]
        class_counts = {}
        for result in results.boxes.cpu().numpy():
            cls = int(result.cls[0])
            class_name = model.names[cls]
            if class_name in class_counts:
                class_counts[class_name] += 1
            else:
                class_counts[class_name] = 1

        print("Detection Summary:")
        for class_name, count in class_counts.items():
            print(f"{class_name}: {count}")

    elif input_type == "Video":
        test_video_path = './datasets/wheat_dataset/videos/test_video.mp4'
        frames = detect_video(model, test_video_path, conf_threshold)
        for i, frame in enumerate(frames):
            cv2.imwrite(f'frame_{i}.jpg', frame)
            plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            plt.axis('off')
            plt.show()

    elif input_type == "Camera":
        frames = detect_camera(model, conf_threshold)
        for i, frame in enumerate(frames):
            cv2.imwrite(f'camera_frame_{i}.jpg', frame)
            plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            plt.axis('off')
            plt.show()

if __name__ == "__main__":
    main()

运行脚本

在终端中运行以下命令来执行整个流程:

python main.py

总结

以上文档包含了从数据集准备、模型训练、评估、可视化训练结果、清理临时文件到推理和显示结果的所有步骤。希望这些详细的信息和代码能够帮助你顺利实施和优化你的小麦穗目标检测系统

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表