计算机系统应用教程网站

网站首页 > 技术文章 正文

使用 PyTorch/YOLO 进行小目标检测的模型训练代码示例

btikc 2024-09-02 16:50:32 技术文章 15 ℃ 0 评论


# 导入必要的库
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from pathlib import Path
from models.yolo import Model  # 引入自定义的 YOLOv5 模型
from datasets import CustomDataset  # 引入自定义的数据集类

# 定义一些超参数
batch_size = 16
num_epochs = 50
learning_rate = 0.001

# 定义数据转换
train_transforms = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

# 加载训练集和测试集
train_dataset = CustomDataset(root='data/train', transforms=train_transforms)
test_dataset = CustomDataset(root='data/test', transforms=train_transforms)

# 加载数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 初始化模型
model = Model()

# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = model.loss

# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将模型移动到设备
model.to(device)

# 训练模型
for epoch in range(num_epochs):
    model.train()
    for i, (images, targets) in enumerate(train_loader):
        # 将数据移动到设备
        images = images.to(device)
        targets = targets.to(device)

        # 前向传播和计算损失
        outputs = model(images)
        loss = criterion(outputs, targets)

        # 反向传播和更新参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 打印训练信息
        if (i+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    # 在测试集上测试模型
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, targets in test_loader:
            # 将数据移动到设备
            images = images.to(device)
            targets = targets.to(device)

            # 前向传播和计算损失
            outputs = model(images)
            loss = criterion(outputs, targets)

            # 统计正确率和总数
            preds = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (preds == targets).sum().item()

        # 打印测试信息
        accuracy = 100 * correct / total
        print(f'Test Accuracy of the model on the test images: {accuracy:.2f} %')

# 保存模型
model_save_path = Path('models/yolov5.pth')
torch.save(model.state_dict(), model_save_path) 

作者:A1程序设计开发 https://www.bilibili.com/read/cv21994644 出处:bilibili

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表