0 Preparation
Create a folder named "动物分类" (animal classification), and inside it a folder named "数据" (data).
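Equivalently, from a terminal (one command; -p creates both levels at once):
mkdir -p 动物分类/数据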
1 Scraping Animal Images
In the "动物分类" folder, right-click to open a terminal and run:
gedit get_data.py
python get_data.py
get_data.py
import requests
import urllib.parse as up
import time
import os

major_url = 'https://image.baidu.com/search/index?'
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36'}

def pic_spider(kw, path, page=10):
    path = os.path.join(path, kw)
    if not os.path.exists(path):
        os.mkdir(path)
    if kw != '':
        for num in range(page):
            data = {
                "tn": "resultjson_com",
                "logid": "11587207680030063767",
                "ipn": "rj",
                "ct": "201326592",
                "is": "",
                "fp": "result",
                "queryWord": kw,
                "cl": "2",
                "lm": "-1",
                "ie": "utf-8",
                "oe": "utf-8",
                "adpicid": "",
                "st": "-1",
                "z": "",
                "ic": "0",
                "hd": "",
                "latest": "",
                "copyright": "",
                "word": kw,
                "s": "",
                "se": "",
                "tab": "",
                "width": "",
                "height": "",
                "face": "0",
                "istype": "2",
                "qc": "",
                "nc": "1",
                "fr": "",
                "expermode": "",
                "force": "",
                "pn": num * 30,       # result offset: 30 results per page
                "rn": "30",           # results per request
                "gsm": oct(num * 30),
                "1602481599433": ""
            }
            url = major_url + up.urlencode(data)
            i = 0
            pic_list = []
            while i < 5:  # retry up to 5 times on network errors
                try:
                    pic_list = requests.get(url=url, headers=headers).json().get('data')
                    break
                except Exception:
                    print('Network error, retrying...')
                    i += 1
                    time.sleep(1.3)
            for pic in pic_list:
                url = pic.get('thumbURL', '')  # some entries have no image link; default to empty
                if url == '':
                    continue
                name = pic.get('fromPageTitleEnc', '')
                for char in ['?', '\\', '/', '*', '"', '|', ':', '<', '>']:
                    name = name.replace(char, '')  # strip characters that are illegal in file names
                pic_type = pic.get('type', 'jpg')  # image type; default to jpg if missing
                pic_path = os.path.join(path, '%s.%s' % (name, pic_type))
                if not os.path.exists(pic_path):
                    with open(pic_path, 'wb') as f:
                        f.write(requests.get(url=url, headers=headers).content)
                print(name, 'downloaded')

cwd = os.getcwd()  # current working directory
save_path = os.path.join(cwd, '数据/下载数据')
keywords = ['猫', '哈士奇', '燕子', '恐龙', '鹦鹉', '老鹰', '柴犬', '田园犬', '咖啡猫', '老虎', '狮子', '哥斯拉', '奥特曼']
print("number of keywords: ", len(keywords))
if not os.path.exists(save_path):
    os.mkdir(save_path)
for kw in keywords:
    pic_spider(kw, save_path, page=10)
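Once the crawl finishes, it is worth checking how many images actually landed in each class folder, since failed downloads silently shrink a class. A minimal sketch, run from the same "动物分类" directory:
import os

root = '数据/下载数据'
for cla in sorted(os.listdir(root)):
    print(cla, len(os.listdir(os.path.join(root, cla))))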
2 Data Splitting
Split the downloaded data into a training set (80%), a validation set (10%), and a test set (10%).
gedit spile_data.py
python spile_data.py
spile_data.py
import os
from shutil import copy

def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)

file = '数据/下载数据'
classes = [cla for cla in os.listdir(file) if ".txt" not in cla]

mkfile('数据/train')
for cla in classes:
    mkfile('数据/train/' + cla)

mkfile('数据/val')
for cla in classes:
    mkfile('数据/val/' + cla)

mkfile('数据/predict')
for cla in classes:
    mkfile('数据/predict/' + cla)

split_rate = 0.1  # fraction of each class for val, and again for predict
for cla in classes:
    cla_path = file + '/' + cla + '/'
    images = [f for f in os.listdir(cla_path) if f.endswith('.jpg') or f.endswith('.png')]
    num = len(images)
    for index, image in enumerate(images):
        image_path = cla_path + image
        if index < split_rate * num:           # first 10% -> val
            new_path = '数据/val/' + cla
        elif index < (1 - split_rate) * num:   # middle 80% -> train
            new_path = '数据/train/' + cla
        else:                                  # last 10% -> predict (test)
            new_path = '数据/predict/' + cla
        copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # progress bar
    print()
print("Done!")
3 Model
gedit model.py
python model.py
model.py
import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224] -> output[48, 55, 55]; fractional sizes are floored
            nn.ReLU(inplace=True),  # inplace saves memory, allowing larger models
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[48, 27, 27]; channel counts are half of the original paper's
            nn.Conv2d(48, 128, kernel_size=5, padding=2),  # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),  # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(  # fully connected layers
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # flatten; view() would also work
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')  # Kaiming (He) initialization
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  # normal-distribution initialization
                nn.init.constant_(m.bias, 0)
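As written, `python model.py` only defines the class and produces no output. A quick sanity check (a sketch, not part of the original tutorial) is to append a dummy forward pass and confirm the output size matches the 13 animal classes used later:
if __name__ == '__main__':
    net = AlexNet(num_classes=13)
    dummy = torch.randn(1, 3, 224, 224)  # one fake RGB image, 224x224
    print(net(dummy).shape)              # expected: torch.Size([1, 13])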
4 Training and Validation
gedit train.py
python train.py
train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
import torchvision

# device: GPU if available, otherwise CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# data preprocessing
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),  # random crop to 224x224
                                 transforms.RandomHorizontalFlip(),  # random horizontal flip
                                 transforms.ToTensor(),              # convert to tensor
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),  # mean and std of 0.5
    "val": transforms.Compose([transforms.Resize((224, 224)),        # resize
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

batch_size = 32
data_root = os.getcwd()           # current working directory
image_path = data_root + "/数据"  # data directory
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])  # load and preprocess the training set
train_num = len(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0)
print("Training set size: ", train_num, "\n")
print("Validation set size: ", val_num, "\n")

def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

net = AlexNet(num_classes=13, init_weights=True)  # instantiate the model
net.to(device)
loss_function = nn.CrossEntropyLoss()                # cross-entropy loss
optimizer = optim.Adam(net.parameters(), lr=0.0002)  # Adam optimizer
save_path = './AlexNet.pth'  # where the best weights are saved
best_acc = 0.0               # best validation accuracy so far

# train one epoch, then validate, for 10 epochs
for epoch in range(10):
    # training
    print(">>Start training epoch: ", epoch + 1)
    net.train()  # enable dropout
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        # print("\nlabels: ", labels)
        # imshow(torchvision.utils.make_grid(images))
        optimizer.zero_grad()  # reset gradients
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()  # backpropagation
        optimizer.step()
        running_loss += loss.item()  # accumulate loss
        rate = (step + 1) / len(train_loader)  # training progress
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter() - t1)  # time taken by this epoch

    # validation
    print(">>Start validating epoch: ", epoch + 1)
    net.eval()  # disable dropout
    acc = 0.0   # number of correctly classified validation images
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()  # count correct predictions
        val_accurate = acc / val_num  # validation accuracy
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  # save the weights with the best accuracy
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / (step + 1), val_accurate))  # average over the number of steps, not the last index

print('Finished Training')

# class_to_idx maps folder name -> index, e.g. {'咖啡猫': 0, '哈士奇': 1, ...}
class_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in class_list.items())  # invert: index -> class name
# write the mapping to class_indices.json for predict.py
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)
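One caveat: class_indices.json is only written after all 10 epochs complete, so an interrupted run leaves predict.py with nothing to read. A small rearrangement (a sketch of the same json code, just moved up) writes it as soon as train_dataset exists:
# in train.py, right after train_dataset is created
cla_dict = {v: k for k, v in train_dataset.class_to_idx.items()}
with open('class_indices.json', 'w') as json_file:
    json_file.write(json.dumps(cla_dict, indent=4))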
5 Testing
gedit predict.py
python predict.py
predict.py
import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import json
import os

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cwd = os.getcwd()  # current working directory
predict_path = os.path.join(cwd, '数据/predict')

# read the index -> class-name mapping written by train.py
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create the model and load the trained weights once, outside the loop
model = AlexNet(num_classes=13)
model.load_state_dict(torch.load("./AlexNet.pth"))
model.eval()

for j, animal in class_indict.items():
    print(">>Testing: ", animal)
    path = os.path.join(predict_path, animal)
    images = [f1 for f1 in os.listdir(path) if ".gif" not in f1]  # skip animated gifs
    counts = [0] * 13  # how many test images were assigned to each of the 13 classes
    for image in images:
        # convert('RGB') guards against RGBA images, which otherwise raise
        # "RuntimeError: The size of tensor a (4) must match the size of tensor b (3)"
        img = Image.open(path + '/' + image).convert('RGB')
        img = data_transform(img)          # [C, H, W]
        img = torch.unsqueeze(img, dim=0)  # add batch dimension -> [N, C, H, W]
        with torch.no_grad():
            output = torch.squeeze(model(img))
            probs = torch.softmax(output, dim=0)
            predicted = torch.argmax(probs).item()
        counts[predicted] += 1
    print("{} has {} images in total \n".format(animal, len(images)))
    print("Accuracy for {}: {}%".format(animal, 100 * counts[int(j)] / len(images)))
    print("\n")
print(">>Testing finished!")
Test results:
>>Testing: 咖啡猫
咖啡猫 has 21 images in total
Accuracy for 咖啡猫: 14.285714285714286%
>>Testing: 哈士奇
哈士奇 has 24 images in total
Accuracy for 哈士奇: 58.333333333333336%
>>Testing: 哥斯拉
哥斯拉 has 21 images in total
Accuracy for 哥斯拉: 57.142857142857146%
>>Testing: 奥特曼
奥特曼 has 23 images in total
Accuracy for 奥特曼: 30.434782608695652%
>>Testing: 恐龙
恐龙 has 23 images in total
Accuracy for 恐龙: 34.78260869565217%
>>Testing: 柴犬
柴犬 has 17 images in total
Accuracy for 柴犬: 70.58823529411765%
>>Testing: 燕子
燕子 has 18 images in total
Accuracy for 燕子: 61.111111111111114%
>>Testing: 狮子
狮子 has 24 images in total
Accuracy for 狮子: 58.333333333333336%
>>Testing: 猫
猫 has 22 images in total
Accuracy for 猫: 18.181818181818183%
>>Testing: 田园犬
田园犬 has 22 images in total
Accuracy for 田园犬: 4.545454545454546%
>>Testing: 老虎
老虎 has 22 images in total
Accuracy for 老虎: 22.727272727272727%
>>Testing: 老鹰
老鹰 has 22 images in total
Accuracy for 老鹰: 0.0%
>>Testing: 鹦鹉
鹦鹉 has 22 images in total
Accuracy for 鹦鹉: 72.72727272727273%
Note: web scraping carries risks; use it with care!