Py学习  »  机器学习算法

【深度学习】用Pytorch给你的母校做一个样式迁移吧!

机器学习初学者 • 2 年前 • 238 次点击  


前言


先看下效果,我实在没有拍过学校的照片,随便谷歌了一张,学校是哈尔滨理工大学荣成校区。


Github代码我已经开源在文末,环境我使用的是Colab pro,下载直接运行。(别忘了Star~)大家可以用好看的照片哦!


输出图像:


输入图像:


样式迁移


如果你是一位摄影爱好者,你也许接触过滤镜。它能改变照片的颜色样式,从而使风景照更加锐利或者令人像更加美白。但一个滤镜通常只能改变照片的某个方面。如果要照片达到理想中的样式,你可能需要尝试大量不同的组合。这个过程的复杂程度不亚于模型调参。


这里我们需要两张输入图像:一张是内容图像,另一张是样式图像


我们将使用神经网络修改内容图像,使其在样式上接近样式图像。


例如,图像为本书作者在西雅图郊区的雷尼尔山国家公园拍摄的风景照,而样式图像则是一幅主题为秋天橡树的油画。


最终输出的合成图像应用了样式图像的油画笔触让整体颜色更加鲜艳,同时保留了内容图像中物体主体的形状。


2.1 方法

简单的例子阐述了基于卷积神经网络的样式迁移方法。


首先,我们初始化合成图像,例如将其初始化为内容图像。该合成图像是样式迁移过程中唯一需要更新的变量,即样式迁移所需迭代的模型参数。然后,我们选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新。


这个深度卷积神经网络凭借多个层逐级抽取图像的特征,我们可以选择其中某些层的输出作为内容特征或样式特征。


接下来,我们通过正向传播(实线箭头方向)计算样式迁移的损失函数,并通过反向传播(虚线箭头方向)迭代模型参数,即不断更新合成图像


样式迁移常用的损失函数由3部分组成:

  1. 内容损失使合成图像与内容图像在内容特征上接近;

  2. 样式损失使合成图像与样式图像在样式特征上接近;

  3. 总变差损失则有助于减少合成图像中的噪点。


最后,当模型训练结束时,我们输出样式迁移的模型参数,即得到最终的合成图像。


这里选取的预训练的神经网络含有3个卷积层,其中第二层输出内容特征,第一层和第三层输出样式特征


和无监督学习不是一个道理哈。

这句话很重要哦。


2.2 数据处理和网络实现


第一步,统一图像尺寸。

下面,定义图像的预处理函数和后处理函数。预处理函数preprocess对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入格式。后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。由于图像打印函数要求每个像素的浮点数值在0到1之间,我们对小于0和大于1的值分别取0和1。

rgb_mean = torch.tensor([0.485, 0.456, 0.406])rgb_std = torch.tensor([0.229, 0.224, 0.225])
def preprocess(img, image_shape):    transforms = torchvision.transforms.Compose([        torchvision.transforms.Resize(image_shape),        torchvision.transforms.ToTensor(),        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])    return transforms(img).unsqueeze(0)
def postprocess(img):    img = img[0].to(rgb_std.dnet = nn.Sequential(*[pretrained_net.features[i] for i in                      range(max(content_layers + style_layers) + 1)])evice)    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))


基于ImageNet数据集预训练的VGG-19模型来抽取图像特征

pretrained_net = torchvision.models.vgg19(pretrained=True)


使用VGG层抽取特征时,我们只需要用到从输入层到最靠近输出层的内容层或样式层之间的所有层。下面构建一个新的网络net,它只保留需要用到的VGG的所有层。

net = nn.Sequential(*[pretrained_net.features[i] for i in                      range(max(content_layers + style_layers) + 1)])

下面定义两个函数:get_contents函数对内容图像抽取内容特征;get_styles函数对样式图像抽取样式特征。因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和样式特征。由于合成图像是样式迁移所需迭代的模型参数,我们只能在训练过程中通过调用extract_features函数来抽取合成图像的内容特征和样式特征。

def get_contents(image_shape, device):    content_X = preprocess(content_img, image_shape).to(device)    contents_Y, _ = extract_features(content_X, content_layers, style_layers)    return content_X, contents_Y
def get_styles(image_shape, device):    style_X = preprocess(style_img, image_shape).to(device)    _, styles_Y = extract_features(style_X, content_layers, style_layers)    return style_X, styles_Y


2.3 训练


在训练模型进行样式迁移时,我们不断抽取合成图像的内容特征和样式特征,然后计算损失函数。下面定义了训练循环。

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)    animator = d2l.Animator(xlabel='epoch', ylabel='loss',                            xlim=[10


    
, num_epochs],                            legend=['content', 'style', 'TV'],                            ncols=2, figsize=(7, 2.5))    for epoch in range(num_epochs):        trainer.zero_grad()        contents_Y_hat, styles_Y_hat = extract_features(            X, content_layers, style_layers)        contents_l, styles_l, tv_l, l = compute_loss(            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)        l.backward()        trainer.step()        scheduler.step()        if (epoch + 1) % 10 == 0:            animator.axes[1].imshow(postprocess(X))            animator.add(epoch + 1, [float(sum(contents_l)),                                     float(sum(styles_l)), float(tv_l)])    return X


现在我们训练模型:首先将内容图像和样式图像的高和宽分别调整为300和450像素,用内容图像来初始化合成图像。

device, image_shape = d2l.try_gpu(), (300, 450)net = net.to(device)content_X, contents_Y = get_contents(image_shape, device)_, styles_Y = get_styles(image_shape, device)output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

device, image_shape = d2l.try_gpu(), (300, 450)net = net.to(device)content_X, contents_Y = get_contents(image_shape, device)_, styles_Y = get_styles(image_shape, device)output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)


  • 样式迁移常用的损失函数由3部分组成:(i) 内容损失使合成图像与内容图像在内容特征上接近;(ii) 样式损失令合成图像与样式图像在样式特征上接近;(iii) 总变差损失则有助于减少合成图像中的噪点。

  • 我们可以通过预训练的卷积神经网络来抽取图像的特征,并通过最小化损失函数来不断更新合成图像来作为模型参数。

  • 我们使用格拉姆矩阵表达样式层输出的样式。


开源代码


代码地址:https://github.com/lixiang007666/Style-Transfer-pytorch


运行style-pytorch.ipynb:


训练300个epoch结果:

epoch  10, content loss 25.22, style loss 3014.07, TV loss 1.16, 0.01 secepoch  20, content loss 29.34, style loss 740.11, TV loss 1.31, 0.00 secepoch  30, content loss 30.87, style loss 383.17, TV loss 1.36, 0.00 secepoch  40, content loss 31.51, style loss 250.63, TV loss 1.40, 0.01 secepoch  50, content loss 31.39, style loss 190.49, TV loss 1.45, 0.01 secepoch  60, content loss 30.82, style loss 152.23, TV loss 1.46, 0.01 secepoch  70, content loss 29.83, style loss 124.40, TV loss 1.49, 0.01 secepoch  80, content loss 29.00, style loss 108.24, TV loss 1.50, 0.01 secepoch  90, content loss 28.27, style loss 92.64, TV loss 1.52, 0.01 secepoch 100, content loss 27.65, style loss 82.47, TV loss 1.53, 0.00 secepoch 110, content loss 27.15, style loss 73.10, TV loss 1.54, 0.01 secepoch 120, content loss 26.44, style loss 65.02, TV loss 1.56, 0.01 secepoch 130, content loss 25.90, style loss 58.60, TV loss 1.57, 0.01 secepoch 140, content loss 25.44, style loss 53.61, TV loss 1.58, 0.01 secepoch 150, content loss 24.98, style loss 49.11, TV loss 1.59, 0.00 secepoch 160, content loss 24.60, style loss 45.28, TV loss 1.60, 0.01 secepoch 170, content loss 24.11, style loss 42.02, TV loss 1.61, 0.01 sec


    
epoch 180, content loss 23.78, style loss 39.58, TV loss 1.61, 0.01 secepoch 190, content loss 23.41, style loss 37.26, TV loss 1.62, 0.01 secepoch 200, content loss 23.05, style loss 35.32, TV loss 1.62, 0.00 secepoch 210, content loss 22.81, style loss 33.80, TV loss 1.62, 0.00 secepoch 220, content loss 22.49, style loss 32.43, TV loss 1.62, 0.00 secepoch 230, content loss 22.19, style loss 31.25, TV loss 1.62, 0.01 secepoch 240, content loss 21.94, style loss 29.98, TV loss 1.62, 0.00 secepoch 250, content loss 21.65, style loss 28.75, TV loss 1.62, 0.00 secepoch 260, content loss 21.44, style loss 27.63, TV loss 1.62, 0.01 secepoch 270, content loss 21.19, style loss 26.77, TV loss 1.62, 0.01 secepoch 280, content loss 20.97, style loss 25.81, TV loss 1.62, 0.01 secepoch 290, content loss 20.81, style loss 24.97, TV loss 1.62, 0.01 secepoch 300, content loss 20.57, style loss 24.25, TV loss 1.62, 0.01 sec


参考


[1].https://zh-v2.d2l.ai/index.html


注:本文仅代表作者个人观点。如有不同看法,欢迎留言反馈/讨论。

作者:李响Superb,CSDN百万访问量博主,普普通通男大学生,深度学习算法、医学图像处理专攻,偶尔也搞全栈开发,没事就写文章。

博客地址:lixiang.blog.csdn.net



—End—


往期精彩回顾




本站qq群851320808,加入微信群请扫码:
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/120080
 
238 次点击