Py学习  »  机器学习算法

深度学习笔记 | 第18讲:深度生成模型之生成式对抗网络(GAN)

狗熊会 • 5 年前 • 799 次点击  



大家好!又到了每周一狗熊会的深度学习时间了。在上一讲中,我们对深度生成模型中的自编码器模型进行了相对详细的介绍,对原始的自编码器及其简单实现、变分自编码器的原理、与概率图模型的一些关系以及基于Keras的简单实现都进行了较为细致的讲解。本节小编要和大家介绍的是另一种深度生成模型——生成式对抗网络。作为生成模型的另一座大山,生成式对抗网络又有哪些特性呢?

1

GAN

作为生成模型两座大山之一,生成式对抗网络(Generative Adversial Networks)自从问世以来就颇受瞩目。相对于变分自编码器,生成式对抗网络也可以学习图像的潜在空间表征,它可以生成与真实图像再统计上几乎无法区分的合成图像。本节就介绍一下 GAN 的基本原理。

追本溯源,开创 GAN 的必读论文是 Ian Goodfellow 的 Generative Adversarial Networks,Goodfellow 想必大家都很熟悉了,就是那本号称深度学习圣经的“花书”的作者。本节在结合论文的基础上对GAN进行解读。

在正式讲 GAN 之前,我们先以一个例子作为类比,先来直观的体会一下GAN的基本思想。假设一名书法家想伪造一幅王羲之的书法,一开始的时候,这名书法家对于模仿王羲之的书法并不精通,开始时候模仿的书法和王羲之的真迹放在一起交给另一位行家,这位行家对每一幅书法的真实性都进行了鉴定和评估,并向第一位画家进行反馈:告诉他王羲之的书法特点和精髓,以及如何模仿才像真正的王羲之书法。模仿的画家根据反馈回去继续研究,并不断给出新的模仿书法。随着时间的推移,模仿者越来越擅长模仿王羲之的书法,鉴定者也越来越擅长找出真正的赝品。

所以,GAN 的核心思想就在于两个部分:一个伪造者网络和一个鉴定网络。二者互相对抗,共同演进,在此过程大家的水平都越来越高,伪造者网络生成的图像就足以达到以假乱真的水平。基于这个思想,我们来看一下 GAN 的原理与细节。

GAN 的基本原理就在于两个网络:G(Generator)和D(Discriminator),分别是生成器和判别器。生成器网络以一个随机向量作为输入,并将其解码生成为一张图像,而判别器一张真实或者合成的图像作为输入,并预测该图像是来自于真实数据还是合成的图像。在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。在理想状态下,博弈的结果就是G可以生成足以以假乱真的图片G(z),而此时的 D 难以判定生成的图像到底是真是假,最后得到D(G(z)) = 0.5的结果。这块的理解跟博弈论中零和博弈非常类似,可以说 GAN 借鉴了博弈论中相关的思想和方法。

将上述表述转化为数学语言描述时就是:为了学习生成器关于数据x上的分布pg, 我们定义输入噪声的先验变量pz(z)),然后使用G(z;θg)来代表数据空间的映射。这里G是一个由含有参数θg的多层感知机表示的可微函数。我们再定义了一个多层感知机D(x;θd)用来输出一个单独的标量。D(x)代表xx来自于真实数据分布而不是pg的概率,我们训练D来最大化分配正确标签给不管是来自于训练样例还是G生成的样例的概率.我们同时训练GG来最小化log(1−D(G(z)))。换句话说,D和G的训练是关于值函数V(G,D)的极小化极大的二人博弈问题:

下图是真实数据和生成数据所代表的两个分布在 GAN 训练中的演化过程:

GAN 在训练过程中,同时更新判别分布(D,蓝色虚线)使D能区分数据真实分布px(黑色虚线)中的样本和生成分布pg (G,绿色实线) 中的样本。下面的黑色箭头表示生成模型 x=G(z) 如何将分布pg作用在转换后的样本上。可以看到,在经过若干次训练之后,判别分布接近某个稳定点,此时真实分布等于生成分布,即pdata=pg。判别器将无法区分训练数据分布和生成数据分布,即D(x)=1/2。

生成器和判别器的优化算法都是随机梯度下降。但有个细节需要注意:第一步我们训练D,D是希望V(G, D)越大越好,所以这里是梯度上升(ascending)。第二步训练G时,V(G, D)越小越好,所以到这里则是梯度下降(descending)。整个训练是一个动态的交替过程。

上面我们已经知道了极小化极大的二人博弈问题的全局最优结果为pg=pdata,在给定任意生成器G的情况下,考虑最优判别器D。给定任意生成器G,判别器D 的训练标准为最大化目标函数V(G, D):

可以看到,对于任意不为零的(a, b),函数 y=alog(y)+blog(1-y) 在[0,1]中的 a/a+b 处达到最大值。

以上便是生成式对抗网络的基本原理,笔者将在下一讲继续分享有关 GAN 的扩展网络和一些基本实现等内容。

2

训练一个DCGAN

自从GoodFellow提出GAN以后,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。为了解决这些问题,后来的研究者不断推陈出新,以至于现在有着各种各样的GAN变体和升级网络。比如 LSGAN,WGAN,WGAN-GP,DRAGAN,CGAN,infoGAN, ACGAN,EBGAN,BEGAN,DCGAN以及最近号称史上最强图像生成网络的BigGAN等等。本节仅选取其中的DCGAN——深度卷积对抗网络进行简单讲解并利用keras进行实现。

DCGAN的原始论文为 UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS,所谓DCGAN,顾名思义就是生成器和判别器都是深度卷积神经网络的GAN。

搭建一个稳健的DCGAN要点在于:

  • 所有的pooling层使用步幅卷积(判别网络)和微步幅度卷积(生成网络)进行替换。

  • 在生成网络和判别网络上使用批处理规范化。

  • 对于更深的架构移除全连接隐藏层。

  • 在生成网络的所有层上使用ReLU激活函数,除了输出层使用Tanh激活函数。

  • 在判别网络的所有层上使用LeakyReLU激活函数。

基于DCGAN生成的卧室图片:

下面就基于keras搭建一个DCGAN。
导入相关模块并设置相关参数:

from keras.layers import Dense, Conv2D, LeakyReLU, Dropout, Input
from keras.layers import Reshape, Conv2DTranspose, Flatten
from keras.models import Model
from keras import optimizers
import kerasimport numpy as npimport warnings warnings.filterwarnings('ignore')

设置相关参数:

# 潜变量维度
latent_dim = 32
# 输入像素维度
height = 32
width = 32
channels = 3

下面开始搭建生成器网络:




    
generator_input = Input(shape=(latent_dim,))
x = Dense(128 * 16 * 16)(generator_input)
x = LeakyReLU()(x)
x = Reshape((16, 16, 128))(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(channels, 7, activation='tanh', padding='same')(x)
generator = Model(generator_input, x)
generator.summary()

生成器网络概要如下:

然后搭建判别器网络:

discriminator_input = Input(shape=(height, width, channels))
x = Conv2D(128, 3)(discriminator_input)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Flatten()(x)
x = Dropout(0.4)(x)
x = Dense(1, activation='sigmoid')(x)
discriminator = Model(discriminator_input, x)
discriminator.summary()
discriminator_optimizer = optimizers.RMSprop(lr=0.0008, 
                                             clipvalue=1.0, 
                                             decay=1e-8)

discriminator.compile(optimizer=discriminator_optimizer,
                      loss='binary_crossentropy')

判别器网络概要如下:

将生成器网络和判别器网络进行组合成DCGAN:




    
# 将判别器参数设置为不可训练
discriminator.trainable = False
gan_input = Input(shape=(latent_dim,)) gan_output = discriminator(generator(gan_input))
# 搭建对抗网络
gan = Model(gan_input, gan_output) gan_optimizer = optimizers.RMSprop(lr=0.0004,                                   clipvalue=1.0,                                   decay=1e-8) gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

DCGAN搭建完成之后,我们使用CIFAR-10数据来进行训练,构建训练代码如下:

import os
from keras.preprocessing import image
# 加载cifar-10数据
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()
# 指定青蛙图像(编号为6)
x_train = x_train[y_train.flatten() == 6] x_train = x_train.reshape((x_train.shape[0],) +(height, width, channels)).astype('float32') / 255.
iterations = 10000

batch_size = 20

save_dir = './image'

start = 0
for step in range(iterations):    
   # 潜在空间随机采样    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))    # 解码生成虚假图像    generated_images = generator.predict(random_latent_vectors)    stop = start + batch_size    real_images = x_train[start: stop]    
   # 将虚假图像和真实图像混合    combined_images = np.concatenate([generated_images, real_images])    # 合并标签,区分真实和虚假图像    labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])    
   # 向标签中添加随机噪声    labels += 0.05 * np.random.random(labels.shape)    
   # 训练判别器    d_loss = discriminator.train_on_batch(combined_images, labels)    
   # 潜在空间随机采样    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))    
   # 合并标签,以假乱真    misleading_targets = np.zeros((batch_size, 1))    
   # 通过gan模型来训练生成器模型,冻结判别器模型权重    a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)    start += batch_size    
   if start > len(x_train) - batch_size:        start = 0    # 每100步绘图并保存    if step % 100 == 0:        gan.save_weights('gan.h5')        print('discriminator loss:', d_loss)        print('adversarial loss:', a_loss)        img = image.array_to_img(generated_images[0] * 255., scale=False )        img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))        img = image.array_to_img(real_images[0] * 255., scale=False)        img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))

训练过程如下:

DCGAN生成的青蛙图片和真实图片混在一起如下图所示,能否辨别出哪张是真实样本,哪张是DCGAN生成的样本?

受限于CIFAR-10数据本身的低像素性,DCGAN生成出来的图像虽然也很模糊,但基本上足以达到以假乱真的水平。上图图片中,每一列有两张是生成样本,有一张是真实样本,按列第2、1、3和2张图片是真实样本,其余都是DCGAN伪造出来的青蛙图片。

以上便是本讲内容。

本节小编跟大家介绍了第二种深度生成模型,生成式对抗网络。在依据论文的基础上对GAN的基本原理和架构进行了简单的介绍,并在此基础上介绍了GAN的扩展网络——DCGAN,给出了基于keras在CIFAR-10上的训练示例作为参考。咱们下一期见!

 


【参考资料】

Generative Adversarial Nets

https://zhuanlan.zhihu.com/p/24767059

Deep Learning with Python

UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS

https://blog.csdn.net/liuxiao214/article/details/74502975

thttp://www.twistedwg.com/2018/01/31/Various-GAN.html


作者简介


鲁伟,狗熊会人才计划一期学员。目前在杭州某软件公司从事数据分析和深度学习相关的研究工作,研究方向为贝叶斯统计、计算机视觉和迁移学习。

识别二维码,查看作者更多精彩文章





识别下方二维码成为狗熊会会员!

友情提示:

个人会员不提供数据、代码

视频only!

个人会员网址:http://teach.xiong99.com.cn

点击“阅读原文”,成为狗熊会会员!

今天看啥 - 高品质阅读平台
本文地址:http://www.jintiankansha.me/t/nCsqz77WN5
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/28082
 
799 次点击