Py学习  »  Git

Github大热论文 | U-GAT-IT:基于GAN的新型无监督图像转换

PaperWeekly • 4 年前 • 397 次点击  


作者丨武广

学校丨合肥工业大学硕士生

研究方向丨图像生成


生成对抗网络(GAN)在这几年的发展下已经渐渐沉淀下来,在网络的架构、训练的稳定性控制、模型参数设计上都有了指导性的研究成果。我们可以看出 17、18 年大部分关于 GAN 的有影响力的文章多集中在模型自身的理论改进上,如 PGGAN、SNGAN、SAGAN、BigGAN、StyleGAN 等,这些模型都还在强调如何通过随机采样生成高质量图像。19 年关于 GAN 的有影响力的文章则更加关注 GAN 的应用上,如 FUNIT、SPADE 等已经将注意力放在了应用层,也就是如何利用 GAN 做好图像翻译等实际应用任务。


学术上的一致性也暗示了 GAN 研究的成熟,本文主要介绍一篇利用 GAN 的新型无监督图像转换论文。




论文引入


图像到图像转换可以应用在很多计算机视觉任务,图像分割、图像修复、图像着色、图像超分辨率、图像风格(场景)变换等都是图像到图像转换的范畴。生成对抗网络 [1] 不仅仅在模型训练的收敛速度上,同时在图像转换质量上展示了优越的结果。

这些优越性能相比 Pixel CNN、VAE、Glow 都是具有很大竞争力的。所以近年来的围绕 GAN 实现图像翻译的研究是很多的,例如 CycleGAN、UNIT、MUNIT、DRIT、FUNIT、SPADE。图像翻译是 GAN 铺开应用的第一步,跨模态间的转换,文本到图像、文本到视频、语音到视频等,凡是这种端到端,希望实现一个分布到另一个分布转换的过程,GAN 都是可以发挥一定的作用的。

回归到现实,图像到图像的转换到目前为止还是具有一定挑战性的,大多数的工作都围绕着局部纹理间的转换展开的,例如人脸属性变换、画作的风格变换、图像分割等,但是在图像差异性较大的情况下,在猫到狗或者是仅仅是语义联系的图像转换上的表现则不佳的。

这就是图像转换模型的适用域问题了,实现一个具有多任务下鲁棒的图像转换模型是十分有必要的。本文将要介绍的 U-GAT-IT 正是为了实现这种鲁棒性能设计的,我们先宏观的看一下文章采用何种方式去实现这种鲁棒性能。


首先是引入注意力机制,这里的注意力机制并不传统的 Attention 或者 Self-Attention 的计算全图的权重作为关注,而是采用全局和平均池化下的类激活图(Class Activation Map-CAM)[2]  来实现的,CAM 对于做分类和检测的应该很熟悉,通过 CNN 确定分类依据的位置,这个思想和注意力是一致的,同时这对于无监督下语义信息的一致性判断也是有作用的,这块我们后续再进行展开。

有了这个注意力图,文章再加上自适应图层实例归一化(AdaLIN),其作用是帮助注意力引导模型灵活控制形状和纹理的变化量。有了上述的两项作用,使得 U-GAT-IT 实现了鲁棒下的图像转换。总结一下 U-GAT-IT 的优势:


  • 提出了一种新的无监督图像到图像转换方法,它具有新的注意模块和新的归一化函数 AdaLIN。

  • 注意模块通过基于辅助分类器获得的注意力图区分源域和目标域,帮助模型知道在何处进行密集转换。

  • AdaLIN 功能帮助注意力引导模型灵活地控制形状和纹理的变化量,增强模型鲁棒性。

模型结构


端到端模型最直观的展示就是模型结构图,我们看一下 U-GAT-IT 实现结构:



我们先把我们能直观看懂的部分做一个介绍,模型分为生成器和判别器,可以看到生成器和判别器的结构几乎相同,生成器好像多了一点操作(这多的这点就是 AdaLIN 和 Decoder部分),我们分析生成器,首先是对端的输入端进行图像的下采样,配合残差块增强图像特征提取,接下来就是注意力模块(这部分乍一看,看不出具体细节,后续分析),接着就是对注意力模块通过 AdaLIN 引导下残差块,最后通过上采样得到转换后的图像。对于判别器相对于生成器而言,就是将解码过程换成判别输出。


重点来了,就是如何实现图像编码后注意力模块以及 AdaLIN 怎样引导解码得到目标域图像的呢?


CAM & Auxillary classifier


对于这部分计算 CAM,结合模型结构图做进一步理解:



由上图,我们可以看到对于图像经过下采样和残差块得到的  Encoder Feature map 经过 Global average pooling 和 Global max pooling 后得到依托通道数的特征向量。创建可学习参数 weight,经过全连接层压缩到 B×1 维,这里的 B 是 BatchSize,对于图像转换,通常取为 1。

对于学习参数 weight 和 Encoder Feature map 做 multiply(对应位想乘)也就是对于 Encoder Feature map 的每一个通道,我们赋予一个权重,这个权重决定了这一通道对应特征的重要性,这就实现了 Feature map 下的注意力机制。

对于经过全连接得到的 B×1 维,在 average 和 max pooling 下做 concat 后送入分类,做源域和目标域的分类判断,这是个无监督过程,仅仅知道的是源域和目标域,这种二分类问题在 CAM 全局和平均池化下可以实现很好的分类。

当生成器可以很好的区分出源域和目标域输入时在注意力模块下可以帮助模型知道在何处进行密集转换。将 average 和 max 得到的注意力图做 concat,经过一层卷积层还原为输入通道数,便送入 AdaLIN 下进行自适应归一化。


AdaLIN



由上图,完整的 AdaLIN 操作就是上图展示,对于经过 CAM 得到的输出,首先经过 MLP 多层感知机得到 γ,β,在 Adaptive Instance Layer resblock 中,中间就是 AdaLIN 归一化。

AdaLIN 正如图中展示的那样,就是 Instance Normalization 和 Layer Normalization 的结合,学习参数为 ρ,论文作者也是参考自 BIN [3] 设计。AdaIN 的前提是保证通道之间不相关,因为它仅对图像 map 本身做归一化,文中说明 AdaIN 会保留稍多的内容结构,而 LN 则并没有假设通道相关性,它做了全局的归一化,却不能很好的保留内容结构,AdaLIN 的设计正是为了结合 AdaIN 和 LN 的优点。


判别器


文章的源码中,判别器的设计采用一个全局判别器(Global Discriminator)以及一个局部判别器(Local Discriminator)结合实现,所谓的全局判别器和局部判别器的区别就在于全局判别器对输入的图像进行了更深层次的特征压缩,最后输出的前一层,feature map 的尺寸达到了

根据感受野的传递,这个尺度卷积下的感受野是作用在全局的(感受野超过了图像尺寸),读者可以自行按照论文给出的网络设计参数进行计算(kernel 全为 4),(我算的结果是 286×286 比输入图像 256×256 要大)对于局部判别器,最后输出的前一层,feature map 的尺寸达到了,此时感受野是达不到图像尺寸(我算的结果是 70×70),这部分称为局部判别器。

最后通过 extend 将全局和局部判别结果进行连接,此处要提一下,在判别器中也加入了 CAM 模块,虽然在判别器下 CAM 并没有做域的分类,但是加入注意力模块对于判别图像真伪是有益的,文中给出的解释是注意力图通过关注目标域中的真实图像和伪图像之间的差异来帮助进行微调。

损失函数


对于利用 GAN 实现图像到图像转换的损失函数其实也就那几个,首先是 GAN 的对抗损失,循环一致性损失,以及身份损失(相同域之间不希望进行转换),最后说一下 CAM 的损失。CAM 的损失主要是生成器中对图像域进行分类,希望源域和目标域尽可能分开,这部分利用交叉熵损失: 



在判别器中,也对真假图像的 CAM 进行了对抗损失优化,主要是为了在注意图上进一步区分真假图像,最后得到完整的目标函数:



Trick 和 AdaLIN 实现代码


作者公布的源码如果大家尝试运行可能会参数爆内存,这里的罪魁祸首发生在计算自适应图层实例归一化下的 MLP 直接从卷积展平送到全连接层中参数过大,大致计算一下,从 1×64×64×256 直接展平的尺寸是 1×1048576 再进行全连接操作,这个计算量是很大的,为了减小计算量,可以在这一步先对 map 进行 global average pooling 再进行全连接操作。


对于生成器 Encoder 部分的归一化采用 IN,Decoder 采用 LIN,判别器网络加上谱归一化限制(生成器并没有加),我们接下来放上论文的核心 AdaLIN 的 tensorflow 下代码实现:


def adaptive_instance_layer_norm(x, gamma, beta, smoothing=True, scope='instance_layer_norm'):
    with tf.variable_scope(scope):
        ch = x.shape[-1]
        eps = 1e-5
        # 计算Instance mean,sigma and ins
        ins_mean, ins_sigma = tf.nn.moments(x, axes=[12], keep_dims=True)
        x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))

        # 计算Layer mean,sigma and ln
        ln_mean, ln_sigma = tf.nn.moments(x, axes=[123], keep_dims=True)
        x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))

        # 给定rho的范围,smoothing控制rho的弹性范围
        if smoothing:
            rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(0.9),
                                  constraint=lambda x: tf.clip_by_value(x, 
                                  clip_value_min=0.0, clip_value_max=0.9))
        else:
            rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0),
                                  constraint=lambda x: tf.clip_by_value(x,
                                  clip_value_min=0.0, clip_value_max=1.0))

        # rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0)

        x_hat = rho * x_ins + (1 - rho) * x_ln

        x_hat = x_hat * gamma + beta

        return x_hat


实验


作者在五个不成对的图像数据集评估了方法的性能,有比较熟悉的马和斑马,猫到狗,人脸到油画,风格场景还有就是最让我感兴趣的作者团队创建的女性到动漫的数据集,不过可惜的是这个数据集作者并没有公布。


在定性和定量上,U-GAT-IT 都展示了优越的结果:



总结


论文提出了无监督的图像到图像转换(U-GAT-IT),其中注意模块和 AdaLIN 可以在具有固定网络架构和超参数的各种数据集中产生更加视觉上令人愉悦的结果。辅助分类器获得的关注图可以指导生成器更多地关注源域和目标域之间的不同区域,从而进行有效的密集转换。此外,自适应图层实例规范化(AdaLIN)可以进一步增强模型在不同数据集下的鲁棒性。


参考文献


[1] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.

[2] B. Zhou, A. Khosla, A. Lapedriza, A. Oliva, and A. Torralba. Learning deep features for discriminative localization. In Computer Vision and Pattern Recognition (CVPR), 2016 IEEE Conference on, pages 2921–2929. IEEE, 2016. 2, 3

[3] H. Nam and H.-E. Kim. Batch-instance normalization for adaptively style-invariant neural networks. arXiv preprint arXiv:1805.07925, 2018. 2, 3




点击以下标题查看更多往期内容: 




#投 稿 通 道#

 让你的论文被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。


📝 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志


📬 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



关于PaperWeekly


PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。


▽ 点击 | 阅读原文 | 下载论文 & 源码

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/37661
 
397 次点击