社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Python

图像风格迁移也有框架了:使用Python编写,与PyTorch完美兼容,外行也能用

小白学视觉 • 2 年前 • 615 次点击  

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

易于使用的神经风格迁移框架 pystiche。


将内容图片与艺术风格图片进行融合,生成一张具有特定风格的新图,这种想法并不新鲜。早在 2015 年,Gatys、 Ecker 以及 Bethge 开创性地提出了神经风格迁移(Neural Style Transfer ,NST)。
不同于深度学习,目前 NST 还没有现成的库或框架。因此,新的 NST 技术要么从头开始实现所有内容,要么基于现有的方法实现。但这两种方法都有各自的缺点:前者由于可重用部分的冗长实现,限制了技术创新;后者继承了 DL 硬件和软件快速发展导致的技术债务。
最近,新项目 pystiche 很好地解决了这些问题,虽然它的核心受众是研究人员,但其易于使用的用户界面为非专业人员使用 NST 提供了可能。
pystiche 是一个用 Python 编写的 NST 框架,基于 PyTorch 构建,并与之完全兼容。相关研究由 pyOpenSci 进行同行评审,并发表在 JOSS 期刊 (Journal of Open Source Software) 上。

  • 论文地址:https://joss.theoj.org/papers/10.21105/joss.02761

  • 项目地址:https://github.com/pmeier/pystiche

在深入实现之前,我们先来回顾一下 NST 的原理。它有两种优化方式:基于图像的优化和基于模型的优化。虽然 pystiche 能够很好地处理后者,但更为复杂,因此本文只讨论基于图像的优化方法。
在基于图像的方法中,将图像的像素迭代调整训练,来拟合感知损失函数(perceptual loss)。感知损失是 NST 的核心部分,分为内容损失(content loss)和风格损失(style loss),这些损失评估输出图像与目标图像的匹配程度。与传统的风格迁移算法不同,感知损失包含一个称为编码器的多层模型,这就是 pystiche 基于 PyTorch 构建的原因。
如何使用 pystiche
让我们用一个例子介绍怎么使用 pystiche 生成神经风格迁移图片。首先导入所需模块,选择处理设备。虽然 pystiche 的设计与设备无关,但使用 GPU 可以将 NST 的速度提高几个数量级。
模块导入与设备选择:
import torchimport pystichefrom pystiche import demo, enc, loss, ops, optimprint(f"pystiche=={pystiche.__version__}")device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
输出:
pystiche==0.7.0
多层编码器
content_loss 和 style_loss 是对图像编码进行操作而不是图像本身,这些编码是由在不同层级的预训练编码器生成的。pystiche 定义了 enc.MultiLayerEncoder 类,该类在单个前向传递中可以有效地处理编码问题。该示例使用基于 VGG19 架构的 vgg19_multi_layer_encoder。默认情况下,它将加载 torchvision 提供的权重。
多层编码器:
multi_layer_encoder = enc.vgg19_multi_layer_encoder()print(multi_layer_encoder)
输出:
VGGMultiLayerEncoder(  arch=vgg19, framework=torch, allow_inplace=True  (preprocessing): TorchPreprocessing(   (0): Normalize(     mean=('0.485', '0.456', '0.406'),     std=('0.229', '0.224', '0.225')    )  ) (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu1_1): ReLU(inplace=True) (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu1_2): ReLU(inplace=True) (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu2_1): ReLU(inplace=True) (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu2_2): ReLU(inplace=True) (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu3_1): ReLU(inplace=True) (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu3_2): ReLU(inplace=True) (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu3_3): ReLU(inplace=True) (conv3_4): Conv2d(256, 256, kernel_size=(3, 3


    
), stride=(1, 1), padding=(1, 1)) (relu3_4): ReLU(inplace=True) (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu4_1): ReLU(inplace=True) (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu4_2): ReLU(inplace=True) (conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu4_3): ReLU(inplace=True) (conv4_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu4_4): ReLU(inplace=True) (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu5_1): ReLU(inplace=True) (conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu5_2): ReLU(inplace=True) (conv5_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu5_3): ReLU(inplace=True) (conv5_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu5_4): ReLU(inplace=True) (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
感知损失
pystiche 将内容损失和风格损失定义为操作符。使用 ops.FeatureReconstructionOperator 作为 content_loss,直接与编码进行对比。如果编码器针对分类任务进行过训练,如该示例中这些编码表示内容。对于content_layer,选择 multi_layer_encoder 的较深层来获取抽象的内容表示,而不是许多不必要的细节。
content_layer = "relu4_2"encoder = multi_layer_encoder.extract_encoder(content_layer)content_loss = ops.FeatureReconstructionOperator(encoder)
pystiche 使用 ops.GramOperator 作为 style_loss 的基础,通过比较编码各个通道之间的相关性来丢弃空间信息。这样就可以在输出图像中的任意区域合成风格元素,而不仅仅是风格图像中它们所在的位置。对于 ops.GramOperator,如果它在浅层和深层 style_layers 都能很好地运行,则其性能达到最佳。
style_weight 可以控制模型对输出图像的重点——内容或风格。为了方便起见,pystiche 将所有内容包装在 ops.MultiLayerEncodingOperator 中,该操作处理在同一 multi_layer_encoder 的多个层上进行操作的相同类型操作符的情况。
style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")style_weight = 1e3def get_encoding_op(encoder, layer_weight):    return ops.GramOperator(encoder, score_weight=layer_weight)style_loss = ops.MultiLayerEncodingOperator(    multi_layer_encoder, style_layers, get_encoding_op, score_weight=style_weight,)loss.PerceptualLoss 结合了 content_loss 与 style_loss,将作为优化的标准。criterion = loss.PerceptualLoss(content_loss, style_loss).to(device)print(criterion)
输出:
PerceptualLoss( (content_loss): FeatureReconstructionOperator(   score_weight=1,   encoder=VGGMultiLayerEncoder(     layer=relu4_2,     arch=vgg19,     framework=torch,     allow_inplace=True   ) ) (style_loss): MultiLayerEncodingOperator(   encoder=VGGMultiLayerEncoder(     arch=vgg19,     framework=torch,     allow_inplace=True ), score_weight=1000 (relu1_1): GramOperator(score_weight=0.2) (relu2_1): GramOperator(score_weight=0.2) (relu3_1): GramOperator(score_weight=0.2) (relu4_1): GramOperator(score_weight=


    
0.2) (relu5_1): GramOperator(score_weight=0.2) ))

图像加载
首先加载并显在 NST 需要的目标图片。因为 NST 占用内存较多,故将图像大小调整为 500 像素。
size = 500images = demo.images()content_image = images["bird1"].read(size=size, device=device)criterion.set_content_image(content_image)

内容图片
style_image = images["paint"].read(size=size, device=device)criterion.set_style_image(style_image)

风格图片

神经风格迁移
创建 input_image。从 content_image 开始执行 NST,这样可以实现快速收敛。image_optimization 函数是为了方便,也可以由手动优化循环代替,且不受限制。如果没有指定,则使用 torch.optim.LBFGS 作为优化器。
input_image = content_image.clone()output_image = optim.image_optimization(input_image, criterion, num_steps=500)


好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇




下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/153399
 
615 次点击