计算机系统应用教程网站

网站首页 > 技术文章 正文

CycleGAN

btikc 2024-09-05 12:32:11 技术文章 8 ℃ 0 评论

一种主题的风格转换成另外一种主题的风格,其他位置还能保持不变,比如

普通马和斑马之间的转换,草原景色不变。

  • 配对的数据(造型相同,只是颜色不一样)

  • 不配对的数据

  • 左右不匹配,左边是实际的场景,右边是油画。
  • 无论是配对图像还是非配对图像都可以进行训练,即不需要配对,也能识别出来。只需要两组图像数据集即可,无需指定对应关系,例如
    • 马和斑马
    • 航拍地图和地图据

    • 马和斑马训练和测试数据的目录结构

普通马训练数据trainA

斑马训练数据trainB

只需要指定马长什么样子,斑马长什么样子,不需要它们之间有一一对应的关系。

航拍地图数据trainA

普通地图数据trainB

只需要知道航拍数据是怎么样的,普通的地图数据是怎么样,就可以进行训练了,不需要一一对应,让网络学一学实际拍出来是什么样子,转换到地图中又是什么样子。

CycleGan怎么进行学习的

  • 传统的方式

  • 输入一张真实的图片A,经过G A-B生成网络,生成一张假的图片B,将假的B输入判决网络D中,G A-B网络希望判决器认为B是真的,判决器希望自己能够识别出来B是假的,将真实的马输入判决器,判决器希望能够识别出来是真实的数据,所以,无论是生成网络生成的假的马的数据,还是输入的真实马的数据到判决器,判决器都会认为这些都是真实的马,所以G A-B网络只需要生成的图片是马的数据,就可以骗过判决器。
  • 新的方式

为了使得G AB生成出来的结果跟原始输入是有关联的,增加了G BA网络,由B还原成A,使得真实的输入图片A和还原的图片尽可能的相似。G AB是在原始输入图片的基础之上进行了合成,G BA是将合成之后的图片还原,然后比较输入图片和还原图片的差异,即计算L2 Loss(目标变量和预测值的差值平方和)

上图中有3个损失函数

  • G网络需要计算损失
  • D网络需要计算损失
  • GAB的输入和GBA的输出之间也需要计算损失即Cycle网络

当前这个网络(把马转换成斑马)主要考虑GAB,让网络学习怎样做转换怎样做还原,但怎样把GBA突出出来呢?即把斑马转换成马

上图中有2个生成器和2个判决器共4个网络。怎么样让网络达到训练要求,是由损失函数决定的,当前的输入(Input_A)和最终还原出来的输出(Cyclic_A)做L2 Loss计算。

普通马图像经过GAB生成一个假的斑马图像B,把这个假的斑马图像B再输入到GAB中,它应该生成和原始输入Input_A一摸一样的图像,因为将斑马输入到GAB中,GAB就应该知道:这就是我想输出的结果。所以看到是斑马作为输入,直接将斑马输出就可以了。

上图中涉及4种损失函数:

  • G网络
  • D网络
  • Cycle网络
  • Identity网络 比如将生成的斑马B再输入到GAB中,GAB就会知道这个斑马就是我要输出的结果,直接将输入原封不动的输出就行了,这个过程也需要计算损失函数

判决器D网络有点特别,PatchGAN

之前判决器是传入一个Sigmoid函数中,最终得到的是一个数值;现在判决器经过卷积神经网络得到一个输出结果,但是最终的输出结果不会输入到Sigmoid函数中,也不会连全连接层,就是一个特征图,比如N x N x 1的特征图:最终一次卷积,filter个数是1,就得到了N x N的矩阵,需要基于感受野来计算损失。

从特征图中的每一点都能看到原始输入的一部分

第一次卷积得到的特征图(圈红的地方)是3x3的卷积核得到的结果,它能看到的位置就是原始输入3x3的部分。

  • 基于感受野在特征图上预测结果,标签也需要是NxN的矩阵计算损失值

原始图像经过一次卷积得到一个特征图,里面有4个点,点1对应的位置是原始图像中红色的部分,其他依次类推,每个点都能看到原始输入的一个区域。这一张图中有4个小patch(区域),不通过一张图判断是真还是假,而是基于每个小patch都做判断。

标签跟输出结果是一样的,也得是N x N的矩阵,代表每个小patch的标签值,判决器对每个小patch判决结果都是1才达到完美。

实现该场景的开源项目

https://gitee.com/pingfanrenbiji/pytorch-CycleGAN-and-pix2pix

下载训练数据

下载数据源在sh脚本中可以看到

其中maps是航拍转换成地图数据,hosrse2zebra是马和斑马数据,apple2orange是苹果转换成橙子的数据

将下载好的数据,放到datasets目录下

训练模型

然后进行模型训练得到模型或者下载已经训练好的模型,

这些都是已经训练好的模型,就可以直接拿测试数据进行预测结果了

有了模型之后,进行模型预测

或者通过idea传入参数

  • 第一个参数是测试文件夹的数据
  • 第二个参数是模型名称

  • 从这里获取指定模型,自己训练的模型保存在这个文件夹中。
  • 这个文件夹是下载的已经训练好的模型
  • 预测结果存在在这里

比如其中一对数据如下:

一个是假的数据,一个是真的数据

训练模型参数指定

或者

这个模型训练需要的显存(如同计算机的内存一样,显存是用来存储要处理的图形信息的部件)比较大,如果没有一个非常好的服务器或工作站或专门做深度学习的,batch_size就设置为1。默认的图片输入大小是256x256,如果显存实在太小,可以调整为128x128,最少显存8G,最好12G,不然会报错 memery error。

关键代码分析-构建数据集

读取数据,

指定要当前所做项目的名称,非对齐,也就是CycleGAN

指定好数据集,

有1096个数据,

指定输入输出的颜色通道RGB,一般是3

判断是否要做resize,将原始输入数据(256x256) resize成 286x286的

先resize,再crop操作,固定大小256x256

随机50%的可能性做这个翻转操作

归一化操作,第一步转换成Tensor格式,

然后指定平均值和标准差。原始输入数据取值范围是0-1之间的,实际网络训练,尤其是GAN网络,希望结果是-1到1之间,可能训练的会更好一些。mean和std在各个颜色通道上都指定0.5之后,所有输入数据的取值范围就都是-1到1之间了。

RandomHorizontalFlip是数据增强

创建模型

损失函数的名称定义

输入数据经过GAB的合成与GBA的还原跟原始的输入尽可能的相同。G网络把A转换成B,实际输入一个B,更应该输出一个B即把实际要生成的数据当作输入之后,输出是否跟输入是一样的,通过损失函数来计算。

netG_A和netG_B两个网络架构是一样的,区别是输入和输出不同、标签指定不同,损失函数不同。

norm='batch',沿batch的方向做归一化,主要用在卷积网络当中,做分类或回归任务都是用它,通常指定的batch都是比较大的,这里没有用batch,而是用InstanceNorm2d,原因1是因为一个一个训练的,第二点在做划分的时候影响更大的在channel(颜色通道)或着特征图的方向,在R、G、B上分别自己做归一化

这个网络是残差网络,

网络先加上了一个Padding

圈红的地方是原始的输入数据,指定padding=2上下左右都加了2圈并做了翻转(默认是翻转模式,自己可以设置),比如036上面是63下面是30。

第一步先做一个基本的卷积操作(所有的提取操作都是卷积,没有全连接的概念),将彩色的3个特征图转换成64个特征图。不断的做卷积得到的特征图一般会越来越小,特征图的个数会越来越多,这个就是正常卷积的过程,接下来执行反卷积(特征图越来越少,特征图大小越来越大),相当于特征提取完了,再还原回去,直到最后一个,可能就和原始的输入一样了即256x256x3,设置最终filter=3,就得到类似图像数据了。

默认添加9个残差的模块,来提取特征

每个模块都是一样的,所有的卷积操作都是输入256,输出256,特征图的个数是不变的,

这是上采样,用反卷积去做的。

输入256个特征,输出128个特征,反卷积的过程当中,特征图的个数变少,特征图大小要变大。

按照倍数把当前结果比上2,相等于特征图是原来的一半。

第一步做了一次反卷积得到128个特征图,

第二次for循环,128个特征图变成了64个特征图,又执行了一次反卷积,

最后一步,做了一个正常的卷积,输入64个特征图,输出跟任务是挂钩的,希望对抗生成网络最后生成的结果是实际的一张图像,channel个数一定是3,

最后一步生成结果是

为了使得跟原始数据有可比性,添加一个激活函数

一般都加Tanh函数,因为刚开始做预处理的时候,把数值的取值范围映射到-1到1之间了,输出也得跟它是一致的取值范围才行,所以这里加上一个Tanh激活函数。

这一步就是权重参数初始化,一般用高斯分布初始化所有的权重参数。

测试就不需要判决器了,直接拿生成器生成结果

颜色通道的归一化,

一开始做卷积,输入是3个颜色通道,得到64个特征图,

特征图个数变多了


接下来还是执行这样一个操作

又加了一个卷积层,得到的特征是512个,H x W x 512。在patch_gan中,最终得到的结果是N x N的矩阵,N x N中的每一个值代表原始图像中的一个小区域,基于当前这个N x N的结果去判断一下,每个小区域是不是都做对了。一点代表一个结果,而不是512个点代表一个batch。

怎么样把一个点代表一个batch,得将H x W x 512转换成H x W x 1,这就得到了N x N x 1的当前的判决器的结果了。

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表