# 导入必要的库
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from pathlib import Path
from models.yolo import Model # 引入自定义的 YOLOv5 模型
from datasets import CustomDataset # 引入自定义的数据集类
# 定义一些超参数
batch_size = 16
num_epochs = 50
learning_rate = 0.001
# 定义数据转换
train_transforms = transforms.Compose([
transforms.Resize((640, 640)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor()
])
# 加载训练集和测试集
train_dataset = CustomDataset(root='data/train', transforms=train_transforms)
test_dataset = CustomDataset(root='data/test', transforms=train_transforms)
# 加载数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 初始化模型
model = Model()
# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = model.loss
# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 将模型移动到设备
model.to(device)
# 训练模型
for epoch in range(num_epochs):
model.train()
for i, (images, targets) in enumerate(train_loader):
# 将数据移动到设备
images = images.to(device)
targets = targets.to(device)
# 前向传播和计算损失
outputs = model(images)
loss = criterion(outputs, targets)
# 反向传播和更新参数
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印训练信息
if (i+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
# 在测试集上测试模型
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, targets in test_loader:
# 将数据移动到设备
images = images.to(device)
targets = targets.to(device)
# 前向传播和计算损失
outputs = model(images)
loss = criterion(outputs, targets)
# 统计正确率和总数
preds = outputs.argmax(dim=1)
total += targets.size(0)
correct += (preds == targets).sum().item()
# 打印测试信息
accuracy = 100 * correct / total
print(f'Test Accuracy of the model on the test images: {accuracy:.2f} %')
# 保存模型
model_save_path = Path('models/yolov5.pth')
torch.save(model.state_dict(), model_save_path)
作者:A1程序设计开发 https://www.bilibili.com/read/cv21994644 出处:bilibili
本文暂时没有评论,来添加一个吧(●'◡'●)