Py学习  »  机器学习算法

基于微软开源深度学习算法,用 Python 实现图像和视频修复

CSDN • 4 年前 • 649 次点击  

‍‍

作者 | 李秋键       责编 | 欧阳姝黎
出品 | AI科技大本营(ID:rgznai100)

图像修复是计算机视觉领域的一个重要任务,在数字艺术品修复、公安刑侦面部修复等种种实际场景中被广泛应用。图像修复的核心挑战在于为缺失区域合成视觉逼真和语义合理的像素,要求合成的像素与原像素具有一致性。
传统的图像修复技术有基于结构和纹理两种方法。基于结构的图像修复算法具有代表性的是 Bertalmio 等提出的 BSCB 模型和 Shen 等提出的基于曲率扩散的修复模型 CDD。基于纹理的修复算法中具有代表性的有 Criminisi 等提出的基于 patch 的纹理合成算法。这两种传统的修复算法可以修复小块区域的破损,但是在破损区域越来越大时, 修复效果则直线下降, 并且修复结果存在图像模糊、结构扭曲、纹理不清晰和视觉不连贯等问题。
近年来,随着硬件设备等计算能力的不断提升, 以及深度学习技术在图像翻译、图像超分辨率、图像修复等计算机视觉领域的迅速发展, 采用深度学习技术的修复方法能够捕获图像的高层语义信息, 与传统的修复方法相比, 具有良好的修复效果。故今天我们使用 Python 实现 Bringing Old Photo Back to Life 算法实现对图像和视频的修复。得到的模型评估效果如下:


基本介绍

传统的图像修复技术可以分为基于结构的图像修复技术和基于纹理的图像修复技术两大类。其中,变分偏微分方程模型是基于结构的图像修复技术的典型代表,由变分模型和偏微分方程模型组成。纹理合成是基于纹理的图像修复技术的典型代表。传统数字图像修复技术分类如下图所示。
传统的图像修复方法结果中存在语义信息不完整、图像模糊等问题,无法达到目前对图像修复的要求。基于深度学习的图像修复算法能够捕获更多图像的高级特征,修复结果较好,所以经常用于图像修复。目前基于生成式对抗网络的图像修复是深度学习图像修复领域的一大研究热点,为图像修复技术的发展奠定了坚实的基础。而我们使用的算法就是基于深度学习的微软开源的 Bringing Old Photo Back to Life 去修复图像。
1.1 环境要求
本次环境使用的是 Python3.6.5+windows 平台。主要用的库有:
  • PyTorch 模块。PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 Facebookd 的人工智能小组开发,不仅能够 实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很多主流框架如 TensorFlow 都不支持的。PyTorch提供了两个高级功能:1.具有强大的 GPU 加速的张量计算(如Numpy) 2.包含自动求导系统的深度神经网络 除了 Facebook 之外,Twitter、GMU 和 Salesforce 等机构都采用了 PyTorch。

  • pillow 模块。Pillow 是Python里的图像处理库(PIL:Python Image Library),提供了了广泛的文件格式支持,强大的图像处理能力,主要包括图像储存、图像显示、格式转换以及基本的图像处理操作等。

  • Numpy 模块。Numpy 是应用 Python 进行科学计算时的基础模块。它是一个提供多维数组对象的 Python 库,除此之外,还包含了多种衍生的对象(比如掩码式数组(masked arrays)或矩阵)以及一系列的为快速计算数组而生的例程,包括数学运算,逻辑运算,形状操作,排序,选择,I/O,离散傅里叶变换,基本线性代数,基本统计运算,随机模拟等等。

  • collections 这个模块实现了特定目标的容器,以提供 Python 标准内建容器 dict、list、set、tuple 的替代选择。Counter:字典的子类,提供了可哈希对象的计数功能;defaultdict:字典的子类,提供了一个工厂函数,为字典查询提供了默认值;OrderedDict:字典的子类,保留了他们被添加的顺序;namedtuple:创建命名元组子类的工厂函数;deque:类似列表容器,实现了在两端快速添加(append)和弹出(pop);ChainMap:类似字典的容器类,将多个映射集合到一个视图里面。


修复模型算法

本文所使用的 Bringing Old Photo Back to Life 算法流程分别为全局修复、脸部检测、脸部特征加强和特征融合。其中隐空间修复网络采用局部-全局视野融合,其中全局支路采用 nonlocal 模块大大增强处理视野。我们对局部破损图片建立了数据集,训练网络预测破损区域,该破损区域显式的送入 nonlocal 模块,并设置模块感受野为非破损区域
2.1 全局视野修复
本文的模型主要由三个部分组成两个变分自编码器(variational-autoencoder,VAE)和一个 latent space 映射网络,每个部分都可以看作是单独的一个模块。下面将介绍网络设计的思想和不同部分的作用。

模型使用了两个 VAE:
第一个 VAE 用于将合成的老照片(模糊、磨损)进行编码到隐空间。
第二个 VAE 用于将对应的干净的老照片进行编码。
然后,在隐空间学习从污损的老照片到干净照片的映射。
就这样,实现了一个老照片的修复算法。
这个有点像在学习控制图片清晰、磨损的一个特征表示,通过控制这个特征,可以达到修复破损照片的目的。
关键代码如下:
model = networks.UNet(in_channels=1, out_channels=1, depth=4, conv_num=2, wf=6, padding=True, batch_norm=True, up_mode="upsample",with_tanh=False, sync_bn=True, antialiasing=True,)for image_name in imagelist:idx += 1print("processing", image_name)results = []scratch_image = Image.open(os.path.join(config.test_path, image_name)).convert("RGB")w, h = scratch_image.size


    
transformed_image_PIL = data_transforms(scratch_image, config.input_size)scratch_image = transformed_image_PIL.convert("L")scratch_image = tv.transforms.ToTensor()(scratch_image)scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)scratch_image = torch.unsqueeze(scratch_image, 0)scratch_image = scratch_image.to(config.GPU)P = torch.sigmoid(model(scratch_image))P = P.data.cpu()tv.utils.save_image((P >= 0.4).float(),os.path.join(output_dir, image_name[:-4] + ".png",),nrow=1,padding=0,normalize=True,)    transformed_image_PIL.save(os.path.join(input_dir, image_name[:-4] + ".png"))

2.2 局部脸部修复加强

脸部特征的加强使用pixpix2模型对脸部二次修复。其中, Pix2Pix模型由Isola等于2017年提出, 它由U-Net和PatchGAN组成, 分别充当Pix2Pix模型中的生成器和判别器。该模型使用户只需提供一个草图便能生成一个与之对应的高质量图像; 对应到图像着色工作中, 网络接收真实图像的亮度信息, 对亮度信息进行特征提取并预测图像颜色值。

关键代码:

def create_optimizers(self, opt):    G_params = list(self.netG.parameters())if opt.use_vae:        G_params += list(self.netE.parameters())if opt.isTrain:        D_params = list(self.netD.parameters())    beta1, beta2 = opt.beta1, opt.beta2if opt.no_TTUR:        G_lr, D_lr = opt.lr, opt.lrelse:        G_lr, D_lr = opt.lr / 2, opt.lr * 2    optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))    optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))return optimizer_G, optimizer_Ddef generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):    z = None    KLD_loss = Noneif self.opt.use_vae:        z, mu, logvar = self.encode_z(real_image)if compute_kld_loss:            KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld    fake_image = self.netG(input_semantics, degraded_image, z=z)    assert (not compute_kld_loss    ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"return fake_image, KLD_lossdef discriminate(self, input_semantics, fake_image, real_image):if self.opt.no_parsing_map:        fake_concat = fake_image        real_concat = real_imageelse:        fake_concat = torch.cat([input_semantics, fake_image], dim=1)        real_concat = torch.cat([input_semantics, real_image], dim=1)    fake_and_real = torch.cat([fake_concat, real_concat], dim=0)


    
    discriminator_out = self.netD(fake_and_real)    pred_fake, pred_real = self.divide_pred(discriminator_out)return pred_fake, pred_real

源代码:https://pan.baidu.com/s/1lAzmWvAEyxi6RFsLpA5l_Q

提取码:osuh




    

字节跳动1/3员工不支持取消大小周!库克称iPhone将采用可回收材料生产;清华博士接亲被要求现场写代码|极客头条

“Replit 威胁我,要求我关闭我的开源项目!”

万维网源代码以 NFT 形式拍卖,价值或超 4.5 亿?

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/116079
 
649 次点击