网站首页 > 技术文章 正文
介绍
SimCLR论文(http://cse.iitkgp.ac.in/~arastogi/papers/simclr.pdf)解释了这个框架如何从更大的模型和更大的批处理中获益,并且如果有足够的计算能力,可以产生与监督模型类似的结果。
但是这些需求使得框架的计算量相当大。如果我们可以拥有这个框架的简单性和强大功能,并且有更少的计算需求,这样每个人都可以访问它,这不是很好吗?Moco-v2前来救援。
注意:在之前的一篇博文中,我们在PyTorch中实现了SimCLR框架,它是在一个包含5个类别的简单数据集上实现的,总共只有1250个训练图像。
数据集
这次我们将在Pytorch中在更大的数据集上实现Moco-v2,并在Google Colab上训练我们的模型。这次我们将使用Imagenette和Imagewoof数据集
来自Imagenette数据集的一些图像
这些数据集的快速摘要(更多信息在这里:https://github.com/fastai/imagenette):
- Imagenette由Imagenet的10个容易分类的类组成,总共有9479个训练图像和3935个验证集图像。
- Imagewoof是一个由Imagenet提供的10个难分类组成的数据集,因为所有的类都是狗的品种。总共有9035个训练图像,3939个验证集图像。
对比学习
对比学习在自我监督学习中的作用是基于这样一个理念:我们希望同一类别中不同的图像观具有相似的表征。但是,由于我们不知道哪些图像属于同一类别,通常所做的是将同一图像的不同外观的表示拉近。我们把这些不同的外观称为正对(positive pairs)。
另外,我们希望不同类别的图像有不同的外观,使它们的表征彼此远离。不同图像的不同外观的呈现与类别无关,会被彼此推开。我们把这些不同的外观称为负对(negative pairs)。
在这种情况下,一个图像的前景是什么?前景可以被认为是以一种经过修改的方式看待图像的某些部分,它本质上是图像的一种变换。
根据手头的任务,有些转换可以比其他转换工作得更好。SimCLR表明,应用随机裁剪和颜色抖动可以很好地完成各种任务,包括图像分类。这本质上来自于网格搜索,从旋转、裁剪、剪切、噪声、模糊、Sobel滤波等选项中选择一对变换。
从外观到表示空间的映射是通过神经网络完成的,通常,resnet用于此目的。下面是从图像到表示的管道
负对是如何产生的?
在同一幅图像中,由于随机裁剪,我们可以得到多个表示。这样,我们就可以产生正对。
但是如何生成负对呢?负对是来自不同图像的表示。SimCLR论文在同一批中创建了这些。如果一个批包含N个图像,那么对于每个图像,我们将得到2个表示,这总共占2*N个表示。对于一个特定的表示x,有一个表示与x形成正对(与x来自同一个图像的表示),其余所有表示(正好是2*N–2)与x形成负对。
如果我们手头有大量的负样本,这些表示就会得到改善。但是,在SimCLR中,只有当批量较大时,才能实现大量的负样本,这导致了对计算能力的更高要求。MoCo-v2提供了生成负样本的另一种方法。让我们详细了解一下。
动态词典
我们可以用一种稍微不同的方式来看待对比学习方法,即将查询与键进行匹配。我们现在有两个编码器,一个用于查询,另一个用于键。此外,为了得到大量的负样本,我们需要一个大的键编码字典。
此上下文中的正对表示查询与键匹配。如果查询和键都来自同一个图像,则它们匹配。编码的查询应该与其匹配的键相似,而与其他查询不同。
对于负对,我们维护一个大字典,其中包含以前批处理的编码键。它们作为查询的负样本。我们以队列的形式维护字典。新的batch被入队,较早的batch被出列。通过更改此队列的大小,可以更改负采样数。
这种方法的挑战
- 随着键编码器的更改,在稍后时间点排队的键可能与较早排队的键不一致。为了使用对比学习方法,与查询进行比较的所有键必须来自相同或相似的编码器,这样比较才会有意义且一致。
- 另一个挑战是,使用反向传播学习编码器参数是不可行的,因为这将需要计算队列中所有样本的梯度(这将导致大的计算图)。
为了解决这两个问题,MoCo将键编码器实现为基于动量的查询编码器的移动平均值[1]。这意味着它以这种方式更新关键编码器参数:
其中m非常接近于1(例如,典型值为0.999),这确保我们在不同的时间从相似的编码器获得编码键。
损失函数-InfoNCE
我们希望查询接近其所有正样本,远离所有负样本。InfoNC函数E会捕获它。它代表信息噪声对比估计。对于查询q和键k,InfoNCE损失函数是:
我们可以重写为:
当q和k的相似性增大,q与负样本的相似性减小时,损失值减小
以下是损失函数的代码:
τ = 0.05
def loss_function(q, k, queue):
# N是批量大小
N = q.shape[0]
# C是表示的维数
C = q.shape[1]
# bmm代表批处理矩阵乘法
# 如果mat1是b×n×m张量,那么mat2是b×m×p张量,
# 然后输出一个b×n×p张量。
pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),τ))
# 在查询和队列张量之间执行矩阵乘法
neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),τ)), dim=1)
# 求和
denominator = neg + pos
return torch.mean(-torch.log(torch.div(pos,denominator)))
让我们再看看这个损失函数,并将它与分类交叉熵损失函数进行比较。
这里pred?是数据点在第i类中的概率值预测,true?是该点属于第i类的实际概率值(可以是模糊的,但大多数情况下是一个one-hot)。
如果你不熟悉这个话题,你可以看这个视频来更好地理解交叉熵。另外,请注意,我们经常通过softmax这样的函数将分数转换为概率值:https://www.youtube.com/watch?v=ErfnhcEV1O8
我们可以把信息损失函数看作交叉熵损失。数据样本“q”的正确类是第r类,底层分类器基于softmax,它试图在K+1类之间进行分类。
Info-NCE还与编码表示之间的相互信息有关;关于这一点的更多细节见[4]。
MoCo-v2框架
现在,让我们把所有的东西放在一起,看看整个Moco-v2算法是什么样子的。
步骤1:
我们必须得到查询和键编码器。最初,键编码器具有与查询编码器相同的参数。它们是彼此的复制品。随着训练的进行,键编码器将成为查询编码器的移动平均值(在这一点上进展缓慢)。
由于计算能力的限制,我们使用Resnet-18体系结构来实现。在通常的resnet架构之上,我们添加了一些密集的层,以使表示的维数降到25。这些层中的某些层稍后将充当投影。
# 定义我们的深度学习架构
resnetq = resnet18(pretrained=False)
classifier = nn.Sequential(OrderedDict([
('fc1', nn.Linear(resnetq.fc.in_features, 100)),
('added_relu1', nn.ReLU(inplace=True)),
('fc2', nn.Linear(100, 50)),
('added_relu2', nn.ReLU(inplace=True)),
('fc3', nn.Linear(50, 25))
]))
resnetq.fc = classifier
resnetk = copy.deepcopy(resnetq)
# 将resnet架构迁移到设备
resnetq.to(device)
resnetk.to(device)
步骤2:
现在,我们已经有了编码器,并且假设我们已经设置了其他重要的数据结构,现在是时候开始训练循环并理解管道了。
这一步是从训练批中获取编码查询和键。我们用L2范数对表示进行规范化。
只是一个约定警告,所有后续步骤中的代码都将位于批处理和epoch循环中。我们还将张量“k”从它的梯度中分离出来,因为我们不需要计算图中的键编码器部分,因为动量更新方程会更新键编码器。
# 梯度零化
optimizer.zero_grad()
# 检索xq和xk这两个图像batch
xq = sample_batched['image1']
xk = sample_batched['image2']
# 把它们移到设备上
xq = xq.to(device)
xk = xk.to(device)
# 获取他们的输出
q = resnetq(xq)
k = resnetk(xk)
k = k.detach()
# 将输出规范化,使它们成为单位向量
q = torch.div(q,torch.norm(q,dim=1).reshape(-1,1))
k = torch.div(k,torch.norm(k,dim=1).reshape(-1,1))
步骤3:
现在,我们将查询、键和队列传递给前面定义的loss函数,并将值存储在一个列表中。然后,像往常一样,对损失值调用backward函数并运行优化器。
# 获得损失值
loss = loss_function(q, k, queue)
# 把这个损失值放到epoch损失列表中
epoch_losses_train.append(loss.cpu().data.item())
# 反向传播
loss.backward()
# 运行优化器
optimizer.step()
步骤4:
我们将最新的batch加入我们的队列。如果我们的队列大小大于我们定义的最大队列大小(K),那么我们就从其中取出最老的batch。可以使用torch.cat进行队列操作。
# 更新队列
queue = torch.cat((queue, k), 0)
# 如果队列大于最大队列大小(k),则出列
# batch大小是256,可以用变量替换
if queue.shape[0] > K:
queue = queue[256:,:]
步骤5:
现在我们进入训练循环的最后一步,即更新键编码器。我们使用下面的for循环来实现这一点。
# 更新resnet
for θ_k, θ_q in zip(resnetk.parameters(), resnetq.parameters()):
θ_k.data.copy_(momentum*θ_k.data + θ_q.data*(1.0 - momentum))
一些训练细节
训练resnet-18模型的Imagenette和Imagewoof数据集的GPU时间接近18小时。为此,我们使用了googlecolab的GPU(16GB)。我们使用的batch大小为256,tau值为0.05,学习率为0.001,最终降低到1e-5,权重衰减为1e-6。我们的队列大小为8192,键编码器的动量值为0.999。
结果
前3层(将relu视为一层)定义了投影头,我们将其移除用于图像分类的下游任务。在剩下的网络上,我们训练了一个线性分类器。
我们得到了64.2%的正确率,而使用10%的标记训练数据,使用MoCo-v2。相比之下,使用最先进的监督学习方法,其准确率接近95%。
对于Imagewoof,我们对10%的标记数据得到了38.6%的准确率。在这个数据集上进行对比学习的效果低于我们的预期。我们怀疑这是因为首先,数据集非常困难,因为所有类都是狗类。
其次,我们认为颜色是这些类的一个重要的区别特征。应用颜色抖动可能会导致来自不同类的多个图像彼此混合表示。相比之下,监督方法的准确率接近90%。
能够弥合自监督模型和监督模型之间差距的设计变更:
- 使用更大更宽的模型。
- 通过使用更大的批量和字典大小。
- 使用更多的数据,如果可以的话。同时引入所有未标记的数据。
- 在大量数据上训练大型模型,然后提取它们。
- 谷歌Colab:https://colab.research.google.com/drive/1AepjEbcHPw2Z-xY8iJkvou-Njnn0VZmd?usp=sharing
- Imagewoof Github仓库结果:https://github.com/thunderInfy/mocov2-imagewoof-results
- Imagenette Github仓库结果:https://github.com/thunderInfy/simclr-with-momentum
- Imagewoof数据集链接:https://github.com/thunderInfy/imagewoof
- Imagenette数据集链接:https://github.com/thunderInfy/imagenette
猜你喜欢
- 2024-10-20 IJCAI 2019 | ProNE: 高精度快速网络表示学习算法
- 2024-10-20 (多图) 一种无刷直流电机电流高精度采样及保护电路的设计
- 2024-10-20 AI大模型企业应用实战(21)-RAG的核心-结果召回和重排序
- 2024-10-20 推荐系统召回负样本专题 - 负样本的构建艺术
- 2024-10-20 差分放大电路的应用 差分放大电路应用场景
- 2024-10-20 "全能选手"召回表征算法实践
- 2024-10-20 闲鱼搜索召回升级:向量召回&个性化召回
- 2024-10-20 机器学习评估指标 AUC 综述 机器学习auc小于0.5
- 2024-10-20 阿里飞猪个性化推荐:召回篇 阿里飞猪部门在哪个园区
- 2024-10-20 揭秘:反馈点接到运放同相端,输出震荡了,电路还是负反馈吗?
你 发表评论:
欢迎- 11-18软考系统分析师知识点十六:系统实现与测试
- 11-18第16篇 软件工程(四)过程管理与测试管理
- 11-18编程|实例(分书问题)了解数据结构、算法(穷举、递归、回溯)
- 11-18算法-减治法
- 11-18笑疯了!巴基斯坦首金!没有技巧全是蛮力!解说:真远啊!笑死!
- 11-18搜索算法之深度优先、广度优先、约束条件、限界函数及相应算法
- 11-18游戏中的优化指的的是什么?
- 11-18算法-分治法
- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)