一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据Photo-RealisticSingleImageSuper-ResolutionUsingaGenerativeAdversarialNetwork

大家好,又见面了,我是你们的朋友全栈君。

目录

一.  图像超分辨率重建概述

1. 概念

2. 应用领域

3. 研究进展

3.1 传统超分辨率重建算法

3.2 基于深度学习的超分辨率重建算法

二.  SRResNet算法原理和Pytorch实现

1. 超分重建基本处理流程

2. 构建深度网络模型提高超分重建性能

3.  基于子像素卷积放大图像尺寸

4.  SRResNet结构剖析

5. Pytorch实现

5.1 运行环境

5.2 训练

5.3 评估

5.4 测试

​  

三.  SRGAN算法原理和Pytorch实现

1. 生成对抗网络(GAN)

2. 感知损失

3. SRGAN结构剖析

4. Pytorch实现

4.1 训练

4.2 评估

4.3 测试

四. 总结

 参考文献


一.  图像超分辨率重建概述

1. 概念

图像分辨率是一组用于评估图像中蕴含细节信息丰富程度的性能参数,包括时间分辨率、空间分辨率及色阶分辨率等,体现了成像系统实际所能反映物体细节信息的能力。相较于低分辨率图像,高分辨率图像通常包含更大的像素密度、更丰富的纹理细节及更高的可信赖度。但在实际上情况中,受采集设备与环境、网络传输介质与带宽、图像退化模型本身等诸多因素的约束,我们通常并不能直接得到具有边缘锐化、无成块模糊的理想高分辨率图像。提升图像分辨率的最直接的做法是对采集系统中的光学硬件进行改进,但是由于制造工艺难以大幅改进并且制造成本十分高昂,因此物理上解决图像低分辨率问题往往代价太大。由此,从软件和算法的角度着手,实现图像超分辨率重建的技术成为了图像处理和计算机视觉等多个领域的热点研究课题。

图像的超分辨率重建技术指的是将给定的低分辨率图像通过特定的算法恢复成相应的高分辨率图像。具体来说,图像超分辨率重建技术指的是利用数字图像处理、计算机视觉等领域的相关知识,借由特定的算法和处理流程,从给定的低分辨率图像中重建出高分辨率图像的过程。其旨在克服或补偿由于图像采集系统或采集环境本身的限制,导致的成像图像模糊、质量低下、感兴趣区域不显著等问题。

简单来理解超分辨率重建就是将小尺寸图像变为大尺寸图像,使图像更加“清晰”。具体效果如下图所示:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                      图1 图像超分辨率重建示例

可以看到,通过特定的超分辨率重建算法,使得原本模糊的图像变得清晰了。读者可能会疑惑,直接对低分辨率图像进行“拉伸”不就可以了吗?答案是可以的,但是效果并不好。传统的“拉伸”型算法主要采用近邻搜索等方式,即对低分辨率图像中的每个像素采用近邻查找或近邻插值的方式进行重建,这种手工设定的方式只考虑了局部并不能满足每个像素的特殊情况,难以恢复出低分辨率图像原本的细节信息。因此,一系列有效的超分辨率重建算法开始陆续被研究学者提出,重建能力不断加强,直至今日,依托深度学习技术,图像的超分辨率重建已经取得了非凡的成绩,在效果上愈发真实和清晰。

2. 应用领域

1955年,Toraldo di Francia在光学成像领域首次明确定义了超分辨率这一概念,主要是指利用光学相关的知识,恢复出衍射极限以外的数据信息的过程。1964年左右,Harris和Goodman则首次提出了图像超分辨率这一概念,主要是指利用外推频谱的方法合成出细节信息更丰富的单帧图像的过程。1984 年,在前人的基础上,Tsai和 Huang 等首次提出使用多帧低分辨率图像重建出高分辨率图像的方法后, 超分辨率重建技术开始受到了学术界和工业界广泛的关注和研究。

图像超分辨率重建技术在多个领域都有着广泛的应用范围和研究意义。主要包括:

(1) 图像压缩领域

在视频会议等实时性要求较高的场合,可以在传输前预先对图片进行压缩,等待传输完毕,再由接收端解码后通过超分辨率重建技术复原出原始图像序列,极大减少存储所需的空间及传输所需的带宽。

(2) 医学成像领域

对医学图像进行超分辨率重建,可以在不增加高分辨率成像技术成本的基础上,降低对成像环境的要求,通过复原出的清晰医学影像,实现对病变细胞的精准探测,有助于医生对患者病情做出更好的诊断。

(3) 遥感成像领域

高分辨率遥感卫星的研制具有耗时长、价格高、流程复杂等特点,由此研究者将图像超分辨率重建技术引入了该领域,试图解决高分辨率的遥感成像难以获取这一挑战,使得能够在不改变探测系统本身的前提下提高观测图像的分辨率。

(4) 公共安防领域

公共场合的监控设备采集到的视频往往受到天气、距离等因素的影响,存在图像模糊、分辨率低等问题。通过对采集到的视频进行超分辨率重建,可以为办案人员恢复出车牌号码、清晰人脸等重要信息,为案件侦破提供必要线索。

(5) 视频感知领域

通过图像超分辨率重建技术,可以起到增强视频画质、改善视频的质量,提升用户的视觉体验的作用。

3. 研究进展

按照时间和效果进行分类,可以将超分辨率重建算法分为传统算法和深度学习算法两类。

3.1 传统超分辨率重建算法

传统的超分辨率重建算法主要依靠基本的数字图像处理技术进行重建,常见的有如下几类:

(1) 基于插值的超分辨率重建

基于插值的方法将图像上每个像素都看做是图像平面上的一个点,那么对超分辨率图像的估计可以看做是利用已知的像素信息为平面上未知的像素信息进行拟合的过程,这通常由一个预定义的变换函数或者插值核来完成。基于插值的方法计算简单、易于理解,但是也存在着一些明显的缺陷。

首先,它假设像素灰度值的变化是一个连续的、平滑的过程,但实际上这种假设并不完全成立。其次,在重建过程中,仅根据一个事先定义的转换函数来计算超分辨率图像,不考虑图像的降质退化模型,往往会导致复原出的图像出现模糊、锯齿等现象。常见的基于插值的方法包括最近邻插值法、双线性插值法和双立方插值法等。

(2) 基于退化模型的超分辨率重建

此类方法从图像的降质退化模型出发,假定高分辨率图像是经过了适当的运动变换、模糊及噪声才得到低分辨率图像。这种方法通过提取低分辨率图像中的关键信息,并结合对未知的超分辨率图像的先验知识来约束超分辨率图像的生成。常见的方法包括迭代反投影法、凸集投影法和最大后验概率法等。

(3) 基于学习的超分辨率重建

基于学习的方法则是利用大量的训练数据,从中学习低分辨率图像和高分辨率图像之间某种对应关系,然后根据学习到的映射关系来预测低分辨率图像所对应的高分辨率图像,从而实现图像的超分辨率重建过程。常见的基于学习的方法包括流形学习、稀疏编码方法。

3.2 基于深度学习的超分辨率重建算法

机器学习是人工智能的一个重要分支,而深度学习则是机器学习中最主要的一个算法,其旨在通过多层非线性变换,提取数据的高层抽象特征,学习数据潜在的分布规律,从而获取对新数据做出合理的判断或者预测的能力。随着人工智能和计算机硬件的不断发展,Hinton等人在2006年提出了深度学习这一概念,其旨在利用多层非线性变换提取数据的高层抽象特征。凭借着强大的拟合能力,深度学习开始在各个领域崭露头角,特别是在图像与视觉领域,卷积神经网络大放异,这也使得越来越多的研究者开始尝试将深度学习引入到超分辨率重建领域。

2014年,Dong等人首次将深度学习应用到图像超分辨率重建领域,他们使用一个三层的卷积神经网络学习低分辨率图像与高分辨率图像之间映射关系,自此,在超分辨率重建率领域掀起了深度学习的浪潮,他们的设计的网络模型命名为SRCNN(Super-Resolution Convolutional Neural Network)。

SRCNN采用了插值的方式先将低分辨率图像进行放大,再通过模型进行复原。Shi等人则认为这种预先采用近邻插值的方式本身已经影响了性能,如果从源头出发,应该从样本中去学习如何进行放大,他们基于这个原理提出了ESPCN (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network)算法。该算法在将低分辨率图像送入神经网络之前,无需对给定的低分辨率图像进行一个上采样过程,而是引入一个亚像素卷积层(Sub-pixel convolution layer),来间接实现图像的放大过程。这种做法极大降低了SRCNN的计算量,提高了重建效率。

这里需要注意到,不管是SRCNN还是ESPCN,它们均使用了MSE作为目标函数来训练模型。2017年,Christian Ledig等人从照片感知角度出发,通过对抗网络来进行超分重建(论文题目:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network)。他们认为,大部分深度学习超分算法采用的MSE损失函数会导致重建的图像过于平滑,缺乏感官上的照片真实感。他们改用生成对抗网络(Generative Adversarial Networks, GAN)来进行重建,并且定义了新的感知目标函数,算法被命名为SRGAN,由一个生成器和一个判别器组成。生成器负责合成高分辨率图像,判别器用于判断给定的图像是来自生成器还是真实样本。通过一个博弈的对抗过程,使得生成器能够将给定的低分辨率图像重建为高分辨率图像。在SRGAN这篇论文中,作者同时提出了一个对比算法,名为SRResNet。SRResNet依然采用了MSE作为最终的损失函数,与以往不同的是,SRResNet采用了足够深的残差卷积网络模型,相比于其它的残差学习重建算法,SRResNet本身也能够取得较好的效果。

由于SRGAN这篇论文同时提出了两种当前主流模式的深度学习超分重建算法,因此,接下来将以SRGAN这篇论文为主线,依次讲解SRResNet和SRGAN算法实现原理,并采用Pytorch深度学习框架完成上述两个算法的复现。

二.  SRResNet算法原理和Pytorch实现

1. 超分重建基本处理流程

最早的采用深度学习进行超分重建的算法是SRCNN算法,其原理很简单,对于输入的一张低分辨率图像,SRCNN首先使用双立方插值将其放大至目标尺寸,然后利用一个三层的卷积神经网络去拟合低分辨率图像与高分辨率图像之间的非线性映射,最后将网络输出的结果作为重建后的高分辨率图像。尽管原理简单,但是依托深度学习模型以及大样本数据的学习,在性能上超过了当时一众传统的图像处理算法,开启了深度学习在超分辨率领域的研究征程。SRCNN的网络结构如图2所示。

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                                             图2 SRCNN的网络结构

其中f_1f_2f_3分别表示1、2、3层卷积对应的核大小。

SRCNN作为早期开创性的研究论文,也为后面的工作奠定了处理超分问题的基本流程:

(1) 寻找大量真实场景下图像样本;

(2) 对每张图像进行下采样处理降低图像分辨率,一般有2倍下采样、3倍下采样、4倍下采样等。如果是2倍下采样,则图像长宽均变成原来的1/2.。下采样前的图像作为高分辨率图像H,下采样后的图像作为低分辨率图像L,L和H构成一个有效的图像对用于后期模型训练;

(3) 训练模型时,对低分辨率图像L进行放大还原为高分辨率图像SR,然后与原始的高分辨率图像H进行比较,其差异用来调整模型的参数,通过迭代训练,使得差异最小。实际情况下,研究学者提出了多种损失函数用来定义这种差异,不同的定义方式也会直接影响最终的重建效果;

(4) 训练完的模型可以用来对新的低分辨率图像进行重建,得到高分辨率图像。

从实际操作上来看,整个超分重建分为两步:图像放大和修复。所谓放大就是采用某种方式(SRCNN采用了插值上采样)将图像放大到指定倍数,然后再根据图像修复原理,将放大后的图像映射为目标图像。超分辨率重建不仅能够放大图像尺寸,在某种意义上具备了图像修复的作用,可以在一定程度上削弱图像中的噪声、模糊等。因此,超分辨率重建的很多算法也被学者迁移到图像修复领域中,完成一些诸如jpep压缩去燥、去模糊等任务。

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                           图3 简化版超分重建处理流程

简化版的超分重建处理流程如图3所示,当然,图像放大和修复两个步骤的顺序可以任意互换。

2. 构建深度网络模型提高超分重建性能

SRCNN只采用了3个卷积层来实现超分重建,有文献指出如果采用更深的网络结构模型,那么可以重建出更高质量的图像,因为更深的网络模型可以抽取出更高级的图像特征,这种深层模型对图像可以更好的进行表达。在SRCNN之后,有不少研究人员尝试加深网络结构以期取得更佳的重建性能,但是越深的模型越不能很好的收敛,无法得到期望的结果。部分研究学者通过迁移学习来逐步的增加模型深度,但这种方式加深程度有限。因此,亟需一种有效的模型,使得构建深层网络模型变得容易并且有效。这个问题直到2015年由何凯明团队提出ResNet网络才得以有效解决。

ResNet中文名字叫作深度残差网络,主要作用是图像分类。现在在图像分割、目标检测等领域都有很广泛的运用。ResNet在传统卷积神经网络中加入了残差学习(residual learning),解决了深层网络中梯度弥散和精度下降(训练集)的问题,使网络能够越来越深,既保证了精度,又控制了速度。

ResNet可以直观的来理解其背后的意义。以往的神经网络模型每一层学习的是一个 y = f(x) 的映射,可以想象的到,随着层数不断加深,每个函数映射出来的y误差逐渐累计,误差越来越大,梯度在反向传播的过程中越来越发散。这时候,如果改变一下每层的映射关系,改为 y = f(x) + x,也就是在每层的结束加上原始输入,此时输入是x,输出是f(x)+x,那么自然的f(x)趋向于0,或者说f(x)是一个相对较小的值,这样,即便层数不断加大,这个误差f(x)依然控制在一个较小值,整个模型训练时不容易发散。

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                         图4 残差网络原理图

上图为残差网络的原理图,可以看到一根线直接跨越两层网络(跳链),将原始数据x带入到了输出中,此时F(x)预测的是一个差值。有了残差学习这种强大的网络结构,就可以按照SRCNN的思路构建用于超分重建的深度神经网络。SRResNet算法主干部分就采用了这种网络结构,如下图所示:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                             图5 超分重建深度残差模块

上述模型采用了多个深度残差模块进行图像的特征抽取,多次运用跳链技术将输入连接到网络输出,这种结构能够保证整个网络的稳定性。由于采用了深度模型,相比浅层模型能够更有效的挖掘图像特征,在性能上可以超越浅层模型算法(SRResNet使用了16个残差模块)。注意到,上述模型每层仅仅改变了图像的通道数,并没有改变图像的尺寸大小,从这个意义上来说这个网络可以认为是前面提到的修复模型。下面会介绍如何在这个模型基础上再增加一个子模块用来放大图像,从而构建一个完整的超分重建模型。

3.  基于子像素卷积放大图像尺寸

子像素卷积(Sub-pixel convolution)是一种巧妙的图像及特征图放大方法,又叫做pixel shuffle(像素清洗)。在深度学习超分辨率重建中,常见的扩尺度方法有直接上采样,双线性插值,反卷积等等。ESPCN算法中提出了一种超分辨率扩尺度方法,即为子像素卷积方法,该方法后续也被应用在了SRResNet和SRGAN算法中。因此,这里需要先介绍子像素卷积的原理及实现方式。

采用CNN对特征图进行放大一般会采用deconvolution等方法,这种方法通常会带入过多人工因素,而子像素卷积会大大降低这个风险。因为子像素卷积放大使用的参数是需要学习的,相比那些手工设定的方式,这种通过样本学习的方式其放大性能更加准确。

具体实现原理如下图所示:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                             图6 子像素卷积示意图

上图很直观得表达了子像素卷积的流程。假设,如果想对原图放大3倍,那么需要生成出3^2=9个同等大小的特征图,也就是通道数扩充了9倍(这个通过普通的卷积操作即可实现)。然后将九个同等大小的特征图拼成一个放大3倍的大图,这就是子像素卷积操作了。

实现时先将原始特征图通过卷积扩展其通道数,如果是想放大4倍,那么就需要将通道数扩展为原来的16倍。特征图做完卷积后再按照特定的格式进行排列,即可得到一张大图,这就是所谓的像素清洗。通过像素清洗,特征的通道数重新恢复为原来输入时的大小,但是每个特征图的尺寸变大了。这里注意到每个像素的扩展方式由对应的卷积来决定,此时卷积的参数是需要学习的,因此,相比于手工设计的放大方式,这种基于学习的放大方式能够更好的去拟合像素之间的关系。

SRResNet模型也利用子像素卷积来放大图像,具体的,在图5所示模型后面添加两个子像素卷积模块,每个子像素卷积模块使得输入图像放大2倍,因此这个模型最终可以将图像放大4倍,如下图所示:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                     图7 SRResNet子像素卷积模块

4.  SRResNet结构剖析

SRResNet使用深度残差网络来构建超分重建模型,主要包含两部分:深度残差模型、子像素卷积模型。深度残差模型用来进行高效的特征提取,可以在一定程度上削弱图像噪点。子像素卷积模型主要用来放大图像尺寸。完整的SRResNet网络结果如下图所示:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                          图 8 SRResNet网络结构

上图中,k表示卷积核大小,n表示输出通道数,s表示步长。除了深度残差模块和子像素卷积模块以外,在整个模型输入和输出部分均添加了一个卷积模块用于数据调整和增强。

需要注意的是,SRResNet模型使用MSE作为目标函数,也就是通过模型还原出来的高分辨率图像与原始高分辨率图像的均方误差,公式如下:

                                                                           L=\frac{1}{n}\sum_{i=1}^{n}\|F(X_i;\Theta))-Y_i\|^2

MSE也是目前大部分超分重建算法采用的目标函数。后面我们会看到,使用该目标函数重建的超分图像并不能很好的符合人眼主观感受,SRGAN算法正是基于此进行的改进。

5. Pytorch实现

本节将从源码出发,完成SRResNet算法的建模、训练和推理。本文基于深度学习框架Pytorch来完成所有的编码工作,读者在阅读本文代码前需要熟悉Pytorch基本操作命令。

所有代码和数据可以从百度云上进行下载https://pan.baidu.com/s/1yUCK8JMmMRjDgtwwUt7z2w

提取码:jkpy

该工程比较大,主要是包含了用于训练的COCO2014数据集。提供这样一个完整的工程包是为了方便读者只需要下载和解压就可以直接运行,而不需要再去额外的寻找数据集和测试集。代码里也提供了已经训练好的.pth模型文件。

5.1 运行环境

(1) 基本配置

本文使用Python语言进行代码编写,Python版本为3.6

采用Pytorch深度学习框架进行算法建模,对应的版本为torch1.4.0。Pytorch的完整安装教程请参考另一篇博客。Pytorch的兼容性较好,1.0版本以后代码都可以正常运行。

操作系统为Windows 10,使用2块 GTX 1080TI显卡进行加速运算。

IDE为VS Code。

(2) 安装scikit-image

按照下述命令进行安装,该包主要用来提供PSNR和SSIM的计算。PSNR和SSIM是超分重建中经常会使用的两个评价指标。

pip install scikit-image==0.16.2

(3) 可视化结果

为了能够在模型的训练过程中方便的查看运行结果,推荐使用tensorboard来实现。由于tensorboard是tensorflow推出来的,因此首先需要安装tensorflow,但是此处没必要再安装GPU版的tensorflow,只需要直接安装CPU版本的即可,安装命令如下:

pip install tensorflow

本文安装的tensorflow版本为2.1.0。在安装tensorflow的过程中会自动安装tensorboard。

(4) 数据集

本文使用COCO2014数据集进行训练,训练时联合使用train2014(82783张图片)和val2014(40504张图片)两部分数据,共有123285张图像。测试时使用Set5、Set14和BSD100数据集分别进行测试。

5.2 训练

(1)代码结构组织

为了方便读者阅读、运行和修改代码,本文采用比较简单的代码组织方式。完整结构如下图所示:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                                图9 代码组织结构

项目根目录下有8个.py文件和2个文件夹,下面对各个文件和文件夹进行简单说明。

  • create_data_lists.py:生成数据列表,检查数据集中的图像文件尺寸,并将符合的图像文件名写入JSON文件列表供后续Pytorch调用;
  • datasets.py:用于构建数据集加载器,主要沿用Pytorch标准数据加载器格式进行封装;
  • models.py:模型结构文件,存储各个模型的结构定义;
  • utils.py:工具函数文件,所有项目中涉及到的一些自定义函数均放置在该文件中;
  • train_srresnet.py:用于训练SRResNet算法;
  • train_srgan.py:用于训练SRGAN算法;
  • eval.py:用于模型评估,主要以计算测试集的PSNR和SSIM为主;
  • test.py:用于单张样本测试,运用训练好的模型为单张图像进行超分重建;
  • data:用于存放训练和测试数据集以及文件列表;
  • results:用于存放运行结果,包括训练好的模型以及单张样本测试结果;

读者可以下载本文代码和数据集进行查看和运行,整个代码运行顺序如下:

  • 运行create_data_lists.py文件用于为数据集生成文件列表;
  • 运行train_srresnet.py进行SRResNet算法训练,训练结束后在results文件夹中会生成checkpoint_srresnet.pth模型文件;
  • 运行eval.py文件对测试集进行评估,计算每个测试集的平均PSNR、SSIM值;
  • 运行test.py文件对results文件夹下名为test.jpg的图像进行超分还原,还原结果存储在results文件夹下面;
  • 运行train_srgan.py文件进行SRGAN算法训练,训练结束后在results文件夹中会生成checkpoint_srgan.pth模型文件;
  • 修改并运行eval.py文件对测试集进行评估,计算每个测试集的平均PSNR、SSIM值;
  • 修改并运行test.py文件对results文件夹下名为test.jpg的图像进行超分还原,还原结果存储在results文件夹下面;

(2)生成数据列表

在训练前需要先准备好数据集,按照特定的格式生成文件列表供Pytorch的数据加载器torch.utils.data.DataLoader对图像进行高效率并行加载。只有准确生成了数据集列表文件才能进行下面的训练。

create_data_lists.py文件内容如下:

from utils import create_data_lists

if __name__ == '__main__':
    create_data_lists(train_folders=['./data/COCO2014/train2014',
                                     './data/COCO2014/val2014'],
                      test_folders=['./data/BSD100',
                                    './data/Set5',
                                    './data/Set14'],
                      min_size=100,
                      output_folder='./data/')

首先从utils中导入create_data_lists函数,该函数用于执行具体的JSON文件创建。在主函数部分设置好训练集train_folders和测试集test_folders文件夹路径,参数min_size=100用于检查训练集和测试集中的每张图像分辨率,无论是图像宽度还是高度,如果小于min_size则该图像路径不写入JSON文件列表中。output_folder用于指明最后的JSON文件列表存放路径。

create_data_lists实现方式如下:

from PIL import Image
import os
import json

def create_data_lists(train_folders, test_folders, min_size, output_folder):
    """
    创建训练集和测试集列表文件.
        参数 train_folders: 训练文件夹集合; 各文件夹中的图像将被合并到一个图片列表文件里面
        参数 test_folders: 测试文件夹集合; 每个文件夹将形成一个图片列表文件
        参数 min_size: 图像宽、高的最小容忍值
        参数 output_folder: 最终生成的文件列表,json格式
    """
    print("\n正在创建文件列表... 请耐心等待.\n")
    train_images = list()
    for d in train_folders:
        for i in os.listdir(d):
            img_path = os.path.join(d, i)
            img = Image.open(img_path, mode='r')
            if img.width >= min_size and img.height >= min_size:
                train_images.append(img_path)
    print("训练集中共有 %d 张图像\n" % len(train_images))
    with open(os.path.join(output_folder, 'train_images.json'), 'w') as j:
        json.dump(train_images, j)

    for d in test_folders:
        test_images = list()
        test_name = d.split("/")[-1]
        for i in os.listdir(d):
            img_path = os.path.join(d, i)
            img = Image.open(img_path, mode='r')
            if img.width >= min_size and img.height >= min_size:
                test_images.append(img_path)
        print("在测试集 %s 中共有 %d 张图像\n" %
              (test_name, len(test_images)))
        with open(os.path.join(output_folder, test_name + '_test_images.json'),'w') as j:
            json.dump(test_images, j)

    print("生成完毕。训练集和测试集文件列表已保存在 %s 下\n" % output_folder)

运行后,在data文件夹下会产生下列文件列表(每个文件列表均存储了用于训练和测试的图像路径):

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

(3)训练SRResNet

train_srresnet.py文件内容如下:

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from models import SRResNet
from datasets import SRDataset
from utils import *


# 数据集参数
data_folder = './data/'          # 数据存放路径
crop_size = 96      # 高分辨率图像裁剪尺寸
scaling_factor = 4  # 放大比例

# 模型参数
large_kernel_size = 9   # 第一层卷积和最后一层卷积的核大小
small_kernel_size = 3   # 中间层卷积的核大小
n_channels = 64         # 中间层通道数
n_blocks = 16           # 残差模块数量

# 学习参数
checkpoint = None   # 预训练模型路径,如果不存在则为None
batch_size = 400    # 批大小
start_epoch = 1     # 轮数起始位置
epochs = 130        # 迭代轮数
workers = 4         # 工作线程数
lr = 1e-4           # 学习率

# 设备参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ngpu = 2           # 用来运行的gpu数量

cudnn.benchmark = True # 对卷积进行加速

writer = SummaryWriter() # 实时监控     使用命令 tensorboard --logdir runs  进行查看

def main():
    """
    训练.
    """
    global checkpoint,start_epoch,writer

    # 初始化
    model = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    # 初始化优化器
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),lr=lr)

    # 迁移至默认设备进行训练
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # 加载预训练模型
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # 定制化的dataloaders
    train_dataset = SRDataset(data_folder,split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True) 

    # 开始逐轮训练
    for epoch in range(start_epoch, epochs+1):

        model.train()  # 训练模式:允许使用批样本归一化

        loss_epoch = AverageMeter()  # 统计损失函数

        n_iter = len(train_loader)

        # 按批处理
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # 数据移至默认设备进行训练
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed 格式
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96),  [-1, 1]格式

            # 前向传播
            sr_imgs = model(lr_imgs)

            # 计算损失
            loss = criterion(sr_imgs, hr_imgs)  

            # 后向传播
            optimizer.zero_grad()
            loss.backward()

            # 更新模型
            optimizer.step()

            # 记录损失值
            loss_epoch.update(loss.item(), lr_imgs.size(0))

            # 监控图像变化
            if i==(n_iter-2):
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_1', make_grid(lr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_2', make_grid(sr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_3', make_grid(hr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)

            # 打印结果
            print("第 "+str(i)+ " 个batch训练结束")
 
        # 手动释放内存              
        del lr_imgs, hr_imgs, sr_imgs

        # 监控损失值变化
        writer.add_scalar('SRResNet/MSE_Loss', loss_epoch.val, epoch)    

        # 保存预训练模型
        torch.save({
            'epoch': epoch,
            'model': model.module.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'results/checkpoint_srresnet.pth')
    
    # 训练结束关闭监控
    writer.close()


if __name__ == '__main__':
    main()

上述代码中已经配上详细的注释,读者可以先Debug逐行运行即可。其中参数batch_size 设置为400,GPU数量ngpu设置为2,读者可以根据自己的机器性能调节设置,对于一般的GPU,batch_size设置为128可以满足。
        需要说明的是本文使用了tensorboard来查看训练结果,在程序运行的过程中可以再开一个终端运行下述命令:

tensorboard --logdir runs

从而打开tensorboard监视器,打开后可以在浏览器中访问:http://localhost:6006/ 来监视训练结果。

上述代码实现时首先构建了一个SRResNet模型,该模型的定义在models.py文件中给出:

class SRResNet(nn.Module):
    """
    SRResNet模型
    """
    def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
        """
        :参数 large_kernel_size: 第一层卷积和最后一层卷积核大小
        :参数 small_kernel_size: 中间层卷积核大小
        :参数 n_channels: 中间层通道数
        :参数 n_blocks: 残差模块数
        :参数 scaling_factor: 放大比例
        """
        super(SRResNet, self).__init__()

        # 放大比例必须为 2、 4 或 8
        scaling_factor = int(scaling_factor)
        assert scaling_factor in {2, 4, 8}, "放大比例必须为 2、 4 或 8!"

        # 第一个卷积块
        self.conv_block1 = ConvolutionalBlock(in_channels=3, out_channels=n_channels, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='PReLu')

        # 一系列残差模块, 每个残差模块包含一个跳连接
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(kernel_size=small_kernel_size, n_channels=n_channels) for i in range(n_blocks)])

        # 第二个卷积块
        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
                                              kernel_size=small_kernel_size,
                                              batch_norm=True, activation=None)

        # 放大通过子像素卷积模块实现, 每个模块放大两倍
        n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
        self.subpixel_convolutional_blocks = nn.Sequential(
            *[SubPixelConvolutionalBlock(kernel_size=small_kernel_size, n_channels=n_channels, scaling_factor=2) for i
              in range(n_subpixel_convolution_blocks)])

        # 最后一个卷积模块
        self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=3, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='Tanh')

    def forward(self, lr_imgs):
        """
        前向传播.

        :参数 lr_imgs: 低分辨率输入图像集, 张量表示,大小为 (N, 3, w, h)
        :返回: 高分辨率输出图像集, 张量表示, 大小为 (N, 3, w * scaling factor, h * scaling factor)
        """
        output = self.conv_block1(lr_imgs)  # (16, 3, 24, 24)
        residual = output  # (16, 64, 24, 24)
        output = self.residual_blocks(output)  # (16, 64, 24, 24)
        output = self.conv_block2(output)  # (16, 64, 24, 24)
        output = output + residual  # (16, 64, 24, 24)
        output = self.subpixel_convolutional_blocks(output)  # (16, 64, 24 * 4, 24 * 4)
        sr_imgs = self.conv_block3(output)  # (16, 3, 24 * 4, 24 * 4)

        return sr_imgs

整个模型完全参照SRResNet的实现方式,组成方式为:1个卷积模块+16个残差模块+1个卷积模块+2个子像素卷积模块+1个卷积模块。

数据的加载通过自定义的SRDataset来实现,其定义在datasets.py文件中给出:

import torch
from torch.utils.data import Dataset
import json
import os
from PIL import Image
from utils import ImageTransforms


class SRDataset(Dataset):
    """
    数据集加载器
    """

    def __init__(self, data_folder, split, crop_size, scaling_factor, lr_img_type, hr_img_type, test_data_name=None):
        """
        :参数 data_folder: # Json数据文件所在文件夹路径
        :参数 split: 'train' 或者 'test'
        :参数 crop_size: 高分辨率图像裁剪尺寸  (实际训练时不会用原图进行放大,而是截取原图的一个子块进行放大)
        :参数 scaling_factor: 放大比例
        :参数 lr_img_type: 低分辨率图像预处理方式
        :参数 hr_img_type: 高分辨率图像预处理方式
        :参数 test_data_name: 如果是评估阶段,则需要给出具体的待评估数据集名称,例如 "Set14"
        """

        self.data_folder = data_folder
        self.split = split.lower()
        self.crop_size = int(crop_size)
        self.scaling_factor = int(scaling_factor)
        self.lr_img_type = lr_img_type
        self.hr_img_type = hr_img_type
        self.test_data_name = test_data_name

        assert self.split in {'train', 'test'}
        if self.split == 'test' and self.test_data_name is None:
            raise ValueError("请提供测试数据集名称!")
        assert lr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'}
        assert hr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'}

        # 如果是训练,则所有图像必须保持固定的分辨率以此保证能够整除放大比例
        # 如果是测试,则不需要对图像的长宽作限定
        if self.split == 'train':
            assert self.crop_size % self.scaling_factor == 0, "裁剪尺寸不能被放大比例整除!"

        # 读取图像路径
        if self.split == 'train':
            with open(os.path.join(data_folder, 'train_images.json'), 'r') as j:
                self.images = json.load(j)
        else:
            with open(os.path.join(data_folder, self.test_data_name + '_test_images.json'), 'r') as j:
                self.images = json.load(j)

        # 数据处理方式
        self.transform = ImageTransforms(split=self.split,
                                         crop_size=self.crop_size,
                                         scaling_factor=self.scaling_factor,
                                         lr_img_type=self.lr_img_type,
                                         hr_img_type=self.hr_img_type)

    def __getitem__(self, i):
        """
        为了使用PyTorch的DataLoader,必须提供该方法.

        :参数 i: 图像检索号
        :返回: 返回第i个低分辨率和高分辨率的图像对
        """
        # 读取图像
        img = Image.open(self.images[i], mode='r')
        img = img.convert('RGB')
        if img.width <= 96 or img.height <= 96:
            print(self.images[i], img.width, img.height)
        lr_img, hr_img = self.transform(img)

        return lr_img, hr_img

    def __len__(self):
        """
        为了使用PyTorch的DataLoader,必须提供该方法.

        :返回: 加载的图像总数
        """
        return len(self.images)

其中需要注意的是,我们提供了4中图像数据的变换方式:'[0, 255]’, ‘[0, 1]’, ‘[-1, 1]’, ‘imagenet-norm’。对于SRResNet数据集,我们希望输入的图像数据经过标准的ImageNet处理,即减去ImageNet均值并除以其方差,对于输出将其变换至[-1,1]之间。具体的变换由utils.py文件的convert_image函数实现。

图像加载的处理流程如下:

  • 加载一张图像,从任意位置处裁剪96×96的子块,将该子块作为原始高分辨率图像hr_img;
  • 对hr_img进行双线性下采样(4倍),得到24×24的子块,将该子块作为初始的低分辨率图像lr_img;
  • 对lr_img按照ImageNet数据集方式进行预处理,将hr_img转换至[-1,1];
  • 将lr_img和hr_img作为一对训练对返回;

图像变换的实现方式在utils.py文件中的ImageTransforms类给出:

class ImageTransforms(object):
    """
    图像变换.
    """

    def __init__(self, split, crop_size, scaling_factor, lr_img_type,
                 hr_img_type):
        """
        :参数 split: 'train' 或 'test'
        :参数 crop_size: 高分辨率图像裁剪尺寸
        :参数 scaling_factor: 放大比例
        :参数 lr_img_type: 低分辨率图像预处理方式
        :参数 hr_img_type: 高分辨率图像预处理方式
        """
        self.split = split.lower()
        self.crop_size = crop_size
        self.scaling_factor = scaling_factor
        self.lr_img_type = lr_img_type
        self.hr_img_type = hr_img_type

        assert self.split in {'train', 'test'}

    def __call__(self, img):
        """
        对图像进行裁剪和下采样形成低分辨率图像
        :参数 img: 由PIL库读取的图像
        :返回: 特定形式的低分辨率和高分辨率图像
        """

        # 裁剪
        if self.split == 'train':
            # 从原图中随机裁剪一个子块作为高分辨率图像
            left = random.randint(1, img.width - self.crop_size)
            top = random.randint(1, img.height - self.crop_size)
            right = left + self.crop_size
            bottom = top + self.crop_size
            hr_img = img.crop((left, top, right, bottom))
        else:
            # 从图像中尽可能大的裁剪出能被放大比例整除的图像
            x_remainder = img.width % self.scaling_factor
            y_remainder = img.height % self.scaling_factor
            left = x_remainder // 2
            top = y_remainder // 2
            right = left + (img.width - x_remainder)
            bottom = top + (img.height - y_remainder)
            hr_img = img.crop((left, top, right, bottom))

        # 下采样(双三次差值)
        lr_img = hr_img.resize((int(hr_img.width / self.scaling_factor),
                                int(hr_img.height / self.scaling_factor)),
                               Image.BICUBIC)

        # 安全性检查
        assert hr_img.width == lr_img.width * self.scaling_factor and hr_img.height == lr_img.height * self.scaling_factor

        # 转换图像
        lr_img = convert_image(lr_img, source='pil', target=self.lr_img_type)
        hr_img = convert_image(hr_img, source='pil', target=self.hr_img_type)

        return lr_img, hr_img

由于采用的是Pytorch框架,读者可以方便的通过debug方式逐行运行和调试代码。整体实现难度并不大,本文不再对此进行赘述。         

(4)训练结果

训练共用时5小时19分6秒(2块GTX 1080Ti显卡),训练完成后保存的模型共17.8M。下图展示了训练过程中的损失函数变化。可以看到,随着训练的进行,损失函数逐渐开始收敛,在结束的时候基本处在收敛平稳点。

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                              图10 训练时损失函数变化

下图展示了训练过程中训练数据超分重建的效果图,依次展示epoch=1、60和130时的效果,每张图像共三行,第一行为低分辨率图像,第二行为当前模型重建出的超分图像,第三行为实际的真实原始清晰图像。可以看到,随着迭代次数的增加,超分还原的效果越来越好,到了第99个epoch的时候还原出来的图像已经大幅削弱了块状噪点的影响,图像更加的平滑和清晰。

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                      图11 epoch=1时超分重建效

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                     图12 epoch=60时超分重建效果

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                     图13 epoch=130时超分重建效果

5.3 评估

eval.py文件完整代码如下:

from utils import *
from torch import nn
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from datasets import SRDataset
from models import SRResNet
import time

# 模型参数
large_kernel_size = 9   # 第一层卷积和最后一层卷积的核大小
small_kernel_size = 3   # 中间层卷积的核大小
n_channels = 64         # 中间层通道数
n_blocks = 16           # 残差模块数量
scaling_factor = 4      # 放大比例
ngpu = 2                # GP数量
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':
    
    # 测试集目录
    data_folder = "./data/"
    test_data_names = ["Set5","Set14", "BSD100"]

    # 预训练模型
    srresnet_checkpoint = "./results/checkpoint_srresnet.pth"

    # 加载模型SRResNet
    checkpoint = torch.load(srresnet_checkpoint)
    srresnet = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    srresnet = srresnet.to(device)
    srresnet.load_state_dict(checkpoint['model'])

    # 多GPU测试
    if torch.cuda.is_available() and ngpu > 1:
        srresnet = nn.DataParallel(srresnet, device_ids=list(range(ngpu)))
   
    srresnet.eval()
    model = srresnet

    for test_data_name in test_data_names:
        print("\n数据集 %s:\n" % test_data_name)

        # 定制化数据加载器
        test_dataset = SRDataset(data_folder,
                                split='test',
                                crop_size=0,
                                scaling_factor=4,
                                lr_img_type='imagenet-norm',
                                hr_img_type='[-1, 1]',
                                test_data_name=test_data_name)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1,
                                                pin_memory=True)

        # 记录每个样本 PSNR 和 SSIM值
        PSNRs = AverageMeter()
        SSIMs = AverageMeter()

        # 记录测试时间
        start = time.time()

        with torch.no_grad():
            # 逐批样本进行推理计算
            for i, (lr_imgs, hr_imgs) in enumerate(test_loader):
                
                # 数据移至默认设备
                lr_imgs = lr_imgs.to(device)  # (batch_size (1), 3, w / 4, h / 4), imagenet-normed
                hr_imgs = hr_imgs.to(device)  # (batch_size (1), 3, w, h), in [-1, 1]

                # 前向传播.
                sr_imgs = model(lr_imgs)  # (1, 3, w, h), in [-1, 1]                

                # 计算 PSNR 和 SSIM
                sr_imgs_y = convert_image(sr_imgs, source='[-1, 1]', target='y-channel').squeeze(
                    0)  # (w, h), in y-channel
                hr_imgs_y = convert_image(hr_imgs, source='[-1, 1]', target='y-channel').squeeze(0)  # (w, h), in y-channel
                psnr = peak_signal_noise_ratio(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(),
                                            data_range=255.)
                ssim = structural_similarity(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(),
                                            data_range=255.)
                PSNRs.update(psnr, lr_imgs.size(0))
                SSIMs.update(ssim, lr_imgs.size(0))


        # 输出平均PSNR和SSIM
        print('PSNR  {psnrs.avg:.3f}'.format(psnrs=PSNRs))
        print('SSIM  {ssims.avg:.3f}'.format(ssims=SSIMs))
        print('平均单张样本用时  {:.3f} 秒'.format((time.time()-start)/len(test_dataset)))

    print("\n")

最终在三个数据集上的测试结果如下表所示:

  Set5 Set14 BSD100
PSNR 31.866 28.504 27.498
SSIM 0.900 0.797 0.754
单张图片平均用时(毫秒) 886 304 81

上表中结果与论文中的值基本一致。由于初始化等随机因素的影响,读者在复现的时候并不一定与上述值完全一致,较为接近即可。上述测试值在性能上已经超越了很多算法,例如DRCNN、ESPCN等。这个主要归功于深度残差网络的作用,我们采用了16个残差模块进行学习,其对于图像的特征表示能力更加显著。读者可以自行尝试进一步再加深网络模块数查看效果,本文不再赘述。

5.4 测试

test.py文件完整代码如下:

from utils import *
from torch import nn
from models import SRResNet
import time
from PIL import Image

# 测试图像
imgPath = './results/test.jpg'

# 模型参数
large_kernel_size = 9   # 第一层卷积和最后一层卷积的核大小
small_kernel_size = 3   # 中间层卷积的核大小
n_channels = 64         # 中间层通道数
n_blocks = 16           # 残差模块数量
scaling_factor = 4      # 放大比例
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':

    # 预训练模型
    #srgan_checkpoint = "./results/checkpoint_srgan.pth"
    srresnet_checkpoint = "./results/checkpoint_srresnet.pth"

    # 加载模型SRResNet 或 SRGAN
    checkpoint = torch.load(srresnet_checkpoint)
    srresnet = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    srresnet = srresnet.to(device)
    srresnet.load_state_dict(checkpoint['model'])
   
    srresnet.eval()
    model = srresnet
    # srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
    # srgan_generator.eval()
    # model = srgan_generator

    # 加载图像
    img = Image.open(imgPath, mode='r')
    img = img.convert('RGB')

    # 双线性上采样
    Bicubic_img = img.resize((int(img.width * scaling_factor),int(img.height * scaling_factor)),Image.BICUBIC)
    Bicubic_img.save('./results/test_bicubic.jpg')

    # 图像预处理
    lr_img = convert_image(img, source='pil', target='imagenet-norm')
    lr_img.unsqueeze_(0)

    # 记录时间
    start = time.time()

    # 转移数据至设备
    lr_img = lr_img.to(device)  # (1, 3, w, h ), imagenet-normed

    # 模型推理
    with torch.no_grad():
        sr_img = model(lr_img).squeeze(0).cpu().detach()  # (1, 3, w*scale, h*scale), in [-1, 1]   
        sr_img = convert_image(sr_img, source='[-1, 1]', target='pil')
        sr_img.save('./results/test_srres.jpg')

    print('用时  {:.3f} 秒'.format(time.time()-start))

本文从网上选取一张低分辨率的证件照片进行实验测试。原图像大小为130×93,然后分别对其进行双线性上采样以及超分重建,图像放大4倍,变为520×372。对比效果如下所示:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据  一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据 一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                 图14 低分辨率证件照测试效果图,从左到右依次为:原图、bicubic上采样、超分重建

三.  SRGAN算法原理和Pytorch实现

SRResNet算法是一个单模型算法,从图像输入到图像输出中间通过各个卷积模块的操作完成,整个结构比较清晰。但是SRResNet也有不可避免的缺陷,就是它采用了MSE作为最终的目标函数,而这个MSE是直接通过衡量模型输出和真值的像素差异来计算的,SRGAN算法指出,这种目标函数会使得超分重建出的图像过于平滑,尽管PSNR和SSIM值会比较高,但是重建出来的图像并不能够很好的符合人眼主观感受,丢失了细节纹理信息。下面给出一张图来说明SRResNet算法和SRGAN算法超分重建效果的不同之处:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                        图15 超分重建效果对比,从左至右分别为:双线性插值、SRResNet、SRGAN、真值

从图上可以看到,原图因为分辨率较低,产生了模糊并且丢失了大量的细节信息,双线性插值无法有效的去模糊,而SRResNet算法尽管能够一定程度上去除模糊,但是其纹理细节不清晰。最后会发现,SRGAN算法不仅去除了模糊,而且还逼真的重建出了水面上的纹理细节,使得重建的图片视觉上与真值图非常吻合。

那怎么让模型在纹理细节丢失的情况下“无中生有”的重建出这些信息呢?答案就是生成对抗网络(Generative Adversarial
Network, GAN)。

1. 生成对抗网络(GAN)

 GAN的主要灵感来源于博弈论中博弈的思想,应用到深度学习上来说,就是构造两个深度学习模型:生成网络G(Generator)和判别网络D(Discriminator),然后两个模型不断博弈,进而使G生成逼真的图像,而D具有非常强的判断图像真伪的能力。生成网络和判别网络的主要功能是:

  • G是一个生成式的网络,它通过某种特定的网络结构以及目标函数来生成图像;
  • D是一个判别网络,判别一张图片是不是“真实的”,即判断输入的照片是不是由G生成;

G的作用就是尽可能的生成逼真的图像来迷惑D,使得D判断失败;而D的作用就是尽可能的挖掘G的破绽,来判断图像到底是不是由G生成的“假冒伪劣”。整个过程就好比两个新手下棋博弈,随着对弈盘数的增加,一个迷惑手段越来越高明,而另一个甄别本领也越来越强大,最后,两个新手都变成了高手。这个时候再让G去和其它的人下棋,可以想到G迷惑的本领已经超越了一众普通棋手。

以上就是GAN算法的原理。运用在图像领域,例如风格迁移,超分重建,图像补全,去噪等,运用GAN可以避免损失函数设计的困难,不管三七二十一,只要有一个基准,直接加上判别器,剩下的就交给对抗训练。相比其他所有模型, GAN可以产生更加清晰,真实的样本。

2. 感知损失

为了防止重建图像过度平滑,SRGAN重新定义了损失函数,并将其命名为感知损失(Perceptual loss)。感知损失有两部分构成:

感知损失=内容损失+对抗损失

对抗损失就是重建出来的图片被判别器正确判断的损失,这部分内容跟一般的GAN定义相同。SRGAN的一大创新点就是提出了内容损失,SRGAN希望让整个网络在学习的过程中更加关注重建图片和原始图片的语义特征差异,而不是逐个像素之间的颜色亮度差异。以往我们在计算超分重建图像和原始高清图像差异的时候是直接在像素图像上进行比较的,用的MSE准则。SRGAN算法提出者认为这种方式只会过度的让模型去学习这些像素差异,而忽略了重建图像的固有特征。实际的差异计算应该在图像的固有特征上计算。但是这种固有特征怎么表示呢?其实很简单,已经有很多模型专门提出来提取图像固有特征然后进行分类等任务。我们只需要把这些模型中的特征提取模块截取出来,然后去计算重建图像和原始图像的特征,这些特征就是语义特征了,然后再在特征层上进行两幅图像的MSE计算。在众多模型中,SRGAN选用了VGG19模型,其截取的模型命名为truncated_vgg19。所谓模型截断,也就是只提取原始模型的一部分,然后作为一个新的单独的模型进行使用。

至此重新整理下内容损失计算方式:

  • 通过SRResNet模型重建出高清图像SR;
  • 通过truncated_vgg19模型对原始高清图像H和重建出的高清图像SR分别进行计算,得到两幅图像对应的特征图H_fea和SR_fea;
  • 计算H_fea和SR_fea的MSE值;

从上述计算方式上看出,原来的计算方式是直接计算H和SR的MSE值,而改用新的内容损失后只需要利用truncated_vgg19模型对图像多作一次推理得到特征图,再在特征图上进行计算。

3. SRGAN结构剖析

SRGAN分为两部分:生成器模型(Generator)和判别器模型(Discriminator)。

生成器模型采用了SRResNet完全一样的结构,只是在计算损失函数时需要利用截断的VGG19模型进行计算。这里注意,截断的VGG19模型只是用来计算图像特征,其本身并不作为一个子模块加在生成器后面。可以将此处的VGG19模型理解为静止的(梯度不更新的),只是用它来计算一下特征而已,其使用与一般的图像滤波器sobel、canny算子等类似。

判别器模型结构如下所示:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                        图16 SRGAN判别器模型

判别器模型对原始高清图像或者重建的高清图像进行判断,判断图像到底是不是生成器创造出来。本质上是一个分类模型,因此判别器的最终输出是一个1维的张量。判别器模型中间部分使用了多个卷积模块进行特征提取,这部分内容并没有特别之处,因此本文不再对该模型进行阐述,读者自行对照代码理解即可。

4. Pytorch实现

Pytorch实现沿用前面SRResNet的设计框架。由于SRGAN算法的生成器部分采用的是与SRResNet模型完全一样的结构,因此我们在训练时就可以直接使用前面训练好的SRResNet模型对生成器进行初始化以加快整个算法的收敛。

4.1 训练

(1)训练脚本train_srgan.py

参照SRResNet,完整的训练代码如下:

import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from models import Generator, Discriminator, TruncatedVGG19
from datasets import SRDataset
from utils import *


# 数据集参数
data_folder = './data/'    # 数据存放路径
crop_size = 96             # 高分辨率图像裁剪尺寸
scaling_factor = 4         # 放大比例

# 生成器模型参数(与SRResNet相同)
large_kernel_size_g = 9   # 第一层卷积和最后一层卷积的核大小
small_kernel_size_g = 3   # 中间层卷积的核大小
n_channels_g = 64         # 中间层通道数
n_blocks_g = 16           # 残差模块数量
srresnet_checkpoint = "./results/checkpoint_srresnet.pth"  # 预训练的SRResNet模型,用来初始化

# 判别器模型参数
kernel_size_d = 3  # 所有卷积模块的核大小
n_channels_d = 64  # 第1层卷积模块的通道数, 后续每隔1个模块通道数翻倍
n_blocks_d = 8     # 卷积模块数量
fc_size_d = 1024  # 全连接层连接数

# 学习参数
batch_size = 128    # 批大小
start_epoch = 1     # 迭代起始位置
epochs = 50         # 迭代轮数
checkpoint = None   # SRGAN预训练模型, 如果没有则填None
workers = 4         # 加载数据线程数量
vgg19_i = 5         # VGG19网络第i个池化层
vgg19_j = 4         # VGG19网络第j个卷积层
beta = 1e-3         # 判别损失乘子
lr = 1e-4           # 学习率

# 设备参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ngpu = 2                 # 用来运行的gpu数量
cudnn.benchmark = True   # 对卷积进行加速
writer = SummaryWriter() # 实时监控     使用命令 tensorboard --logdir runs  进行查看


def main():
    """
    训练.
    """
    global checkpoint,start_epoch,writer

    # 模型初始化
    generator = Generator(large_kernel_size=large_kernel_size_g,
                              small_kernel_size=small_kernel_size_g,
                              n_channels=n_channels_g,
                              n_blocks=n_blocks_g,
                              scaling_factor=scaling_factor)

    discriminator = Discriminator(kernel_size=kernel_size_d,
                                    n_channels=n_channels_d,
                                    n_blocks=n_blocks_d,
                                    fc_size=fc_size_d)

    # 初始化优化器
    optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad,generator.parameters()),lr=lr)
    optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad,discriminator.parameters()),lr=lr)

    # 截断的VGG19网络用于计算损失函数
    truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
    truncated_vgg19.eval()

    # 损失函数
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # 将数据移至默认设备
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)
    
    # 加载预训练模型
    srresnetcheckpoint = torch.load(srresnet_checkpoint)
    generator.net.load_state_dict(srresnetcheckpoint['model'])

    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator.load_state_dict(checkpoint['generator'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer_g.load_state_dict(checkpoint['optimizer_g'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
    
    # 单机多GPU训练
    if torch.cuda.is_available() and ngpu > 1:
        generator = nn.DataParallel(generator, device_ids=list(range(ngpu)))
        discriminator = nn.DataParallel(discriminator, device_ids=list(range(ngpu)))

    # 定制化的dataloaders
    train_dataset = SRDataset(data_folder,split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='imagenet-norm')
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True) 

    # 开始逐轮训练
    for epoch in range(start_epoch, epochs+1):
        
        if epoch == int(epochs / 2):  # 执行到一半时降低学习率
            adjust_learning_rate(optimizer_g, 0.1)
            adjust_learning_rate(optimizer_d, 0.1)

        generator.train()   # 开启训练模式:允许使用批样本归一化
        discriminator.train()

        losses_c = AverageMeter()  # 内容损失
        losses_a = AverageMeter()  # 生成损失
        losses_d = AverageMeter()  # 判别损失

        n_iter = len(train_loader)

        # 按批处理
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # 数据移至默认设备进行训练
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24),  imagenet-normed 格式
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96),  imagenet-normed 格式

            #-----------------------1. 生成器更新----------------------------
            # 生成
            sr_imgs = generator(lr_imgs)  # (N, 3, 96, 96), 范围在 [-1, 1]
            sr_imgs = convert_image(
                sr_imgs, source='[-1, 1]',
                target='imagenet-norm')  # (N, 3, 96, 96), imagenet-normed

            # 计算 VGG 特征图
            sr_imgs_in_vgg_space = truncated_vgg19(sr_imgs)              # batchsize X 512 X 6 X 6
            hr_imgs_in_vgg_space = truncated_vgg19(hr_imgs).detach()     # batchsize X 512 X 6 X 6

            # 计算内容损失
            content_loss = content_loss_criterion(sr_imgs_in_vgg_space,hr_imgs_in_vgg_space)

            # 计算生成损失
            sr_discriminated = discriminator(sr_imgs)  # (batch X 1)   
            adversarial_loss = adversarial_loss_criterion(
                sr_discriminated, torch.ones_like(sr_discriminated)) # 生成器希望生成的图像能够完全迷惑判别器,因此它的预期所有图片真值为1

            # 计算总的感知损失
            perceptual_loss = content_loss + beta * adversarial_loss

            # 后向传播.
            optimizer_g.zero_grad()
            perceptual_loss.backward()

            # 更新生成器参数
            optimizer_g.step()

            #记录损失值
            losses_c.update(content_loss.item(), lr_imgs.size(0))
            losses_a.update(adversarial_loss.item(), lr_imgs.size(0))

            #-----------------------2. 判别器更新----------------------------
            # 判别器判断
            hr_discriminated = discriminator(hr_imgs)
            sr_discriminated = discriminator(sr_imgs.detach())

            # 二值交叉熵损失
            adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.zeros_like(sr_discriminated)) + \
                            adversarial_loss_criterion(hr_discriminated, torch.ones_like(hr_discriminated))  # 判别器希望能够准确的判断真假,因此凡是生成器生成的都设置为0,原始图像均设置为1

            # 后向传播
            optimizer_d.zero_grad()
            adversarial_loss.backward()

            # 更新判别器
            optimizer_d.step()

            # 记录损失
            losses_d.update(adversarial_loss.item(), hr_imgs.size(0))

            # 监控图像变化
            if i==(n_iter-2):
                writer.add_image('SRGAN/epoch_'+str(epoch)+'_1', make_grid(lr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRGAN/epoch_'+str(epoch)+'_2', make_grid(sr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRGAN/epoch_'+str(epoch)+'_3', make_grid(hr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)

            # 打印结果
            print("第 "+str(i)+ " 个batch结束")
 
        # 手动释放内存              
        del lr_imgs, hr_imgs, sr_imgs, hr_imgs_in_vgg_space, sr_imgs_in_vgg_space, hr_discriminated, sr_discriminated  # 手工清除掉缓存

        # 监控损失值变化
        writer.add_scalar('SRGAN/Loss_c', losses_c.val, epoch) 
        writer.add_scalar('SRGAN/Loss_a', losses_a.val, epoch)    
        writer.add_scalar('SRGAN/Loss_d', losses_d.val, epoch)    

        # 保存预训练模型
        torch.save({
            'epoch': epoch,
            'generator': generator.module.state_dict(),
            'discriminator': discriminator.module.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
        }, 'results/checkpoint_srgan.pth')
    
    # 训练结束关闭监控
    writer.close()


if __name__ == '__main__':
    main()

相关定义模型在models.py中给出:

class Generator(nn.Module):
    """
    生成器模型,其结构与SRResNet完全一致.
    """

    def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
        """
        参数 large_kernel_size:第一层和最后一层卷积核大小
        参数 small_kernel_size:中间层卷积核大小
        参数 n_channels:中间层卷积通道数
        参数 n_blocks: 残差模块数量
        参数 scaling_factor: 放大比例
        """
        super(Generator, self).__init__()
        self.net = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                            n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)

    def forward(self, lr_imgs):
        """
        前向传播.

        参数 lr_imgs: 低精度图像 (N, 3, w, h)
        返回: 超分重建图像 (N, 3, w * scaling factor, h * scaling factor)
        """
        sr_imgs = self.net(lr_imgs)  # (N, n_channels, w * scaling factor, h * scaling factor)

        return sr_imgs


class Discriminator(nn.Module):
    """
    SRGAN判别器
    """

    def __init__(self, kernel_size=3, n_channels=64, n_blocks=8, fc_size=1024):
        """
        参数 kernel_size: 所有卷积层的核大小
        参数 n_channels: 初始卷积层输出通道数, 后面每隔一个卷积层通道数翻倍
        参数 n_blocks: 卷积块数量
        参数 fc_size: 全连接层连接数
        """
        super(Discriminator, self).__init__()

        in_channels = 3

        # 卷积系列,参照论文SRGAN进行设计
        conv_blocks = list()
        for i in range(n_blocks):
            out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
            conv_blocks.append(
                ConvolutionalBlock(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0, activation='LeakyReLu'))
            in_channels = out_channels
        self.conv_blocks = nn.Sequential(*conv_blocks)

        # 固定输出大小
        self.adaptive_pool = nn.AdaptiveAvgPool2d((6, 6))

        self.fc1 = nn.Linear(out_channels * 6 * 6, fc_size)

        self.leaky_relu = nn.LeakyReLU(0.2)

        self.fc2 = nn.Linear(1024, 1)

        # 最后不需要添加sigmoid层,因为PyTorch的nn.BCEWithLogitsLoss()已经包含了这个步骤

    def forward(self, imgs):
        """
        前向传播.

        参数 imgs: 用于作判别的原始高清图或超分重建图,张量表示,大小为(N, 3, w * scaling factor, h * scaling factor)
        返回: 一个评分值, 用于判断一副图像是否是高清图, 张量表示,大小为 (N)
        """
        batch_size = imgs.size(0)
        output = self.conv_blocks(imgs)
        output = self.adaptive_pool(output)
        output = self.fc1(output.view(batch_size, -1))
        output = self.leaky_relu(output)
        logit = self.fc2(output)

        return logit


class TruncatedVGG19(nn.Module):
    """
    truncated VGG19网络,用于计算VGG特征空间的MSE损失
    """

    def __init__(self, i, j):
        """
        :参数 i: 第 i 个池化层
        :参数 j: 第 j 个卷积层
        """
        super(TruncatedVGG19, self).__init__()

        # 加载预训练的VGG模型
        vgg19 = torchvision.models.vgg19(pretrained=True)  # C:\Users\Administrator/.cache\torch\checkpoints\vgg19-dcbb9e9d.pth

        maxpool_counter = 0
        conv_counter = 0
        truncate_at = 0
        # 迭代搜索
        for layer in vgg19.features.children():
            truncate_at += 1

            # 统计
            if isinstance(layer, nn.Conv2d):
                conv_counter += 1
            if isinstance(layer, nn.MaxPool2d):
                maxpool_counter += 1
                conv_counter = 0

            # 截断位置在第(i-1)个池化层之后(第 i 个池化层之前)的第 j 个卷积层
            if maxpool_counter == i - 1 and conv_counter == j:
                break

        # 检查是否满足条件
        assert maxpool_counter == i - 1 and conv_counter == j, "当前 i=%d 、 j=%d 不满足 VGG19 模型结构" % (
            i, j)

        # 截取网络
        self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])

    def forward(self, input):
        """
        前向传播
        参数 input: 高清原始图或超分重建图,张量表示,大小为 (N, 3, w * scaling factor, h * scaling factor)
        返回: VGG19特征图,张量表示,大小为 (N, feature_map_channels, feature_map_w, feature_map_h)
        """
        output = self.truncated_vgg19(input)  # (N, feature_map_channels, feature_map_w, feature_map_h)

        return output

在代码中已经给出了详细的注释,读者在运行调试时结合注释相信可以快速的理解整个处理流程。

(2)训练结果

下图分别展示了整个训练过程中内容损失、生成损失和判别损失的变化曲线。

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

 

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                                      图18 损失函数变化曲线

从上图中可以看到,相对SRResNet的收敛曲线,SRGAN非常不平稳,判别损失和生成损失此消彼长,这说明判别器和生成器正在做着激烈的对抗。一般来说,生成对抗网络相比普通的网络其训练难度更大,我们无法通过查看loss来说明gan训练得怎么样。目前也有不少文献开始尝试解决整个问题,使得GAN算法的训练进程可以更加明显。

尽管不能从loss损失函数变化曲线上看出训练进程,我们还可以从每次epoch的训练样本重建效果上进行查看。下图分别显示了epoch=1、25和50 部分训练样本重建效果图,第一行为低分辨率图,第二行为超分重建图,第三行为原始高清图。

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                          图19 epoch=1时训练结果图

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                           图20 epoch=25时训练结果图

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

                                                                                   图21 epoch=50时训练结果图

从训练图上可以看到,在epoch=50即训练结束的时候,其生成到的超分图已经非常接近原始高清图,重建出的图像视觉感受更加突出,细节较丰富,相比SRResNet的过度平滑,其生成的图像更符合真实场景效果。

4.2 评估

完整的评估代码如下eval.py:

from utils import *
from torch import nn
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from datasets import SRDataset
from models import SRResNet,Generator
import time

# 模型参数
large_kernel_size = 9   # 第一层卷积和最后一层卷积的核大小
small_kernel_size = 3   # 中间层卷积的核大小
n_channels = 64         # 中间层通道数
n_blocks = 16           # 残差模块数量
scaling_factor = 4      # 放大比例
ngpu = 2                # GP数量
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':
    
    # 测试集目录
    data_folder = "./data/"
    test_data_names = ["Set5","Set14", "BSD100"]

    # 预训练模型
    srgan_checkpoint = "./results/checkpoint_srgan.pth"
    #srresnet_checkpoint = "./results/checkpoint_srresnet.pth"

    # 加载模型SRResNet 或 SRGAN
    checkpoint = torch.load(srgan_checkpoint)
    generator = Generator(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    generator = generator.to(device)
    generator.load_state_dict(checkpoint['generator'])

    # 多GPU测试
    if torch.cuda.is_available() and ngpu > 1:
        generator = nn.DataParallel(generator, device_ids=list(range(ngpu)))
   
    generator.eval()
    model = generator
    # srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device)
    # srgan_generator.eval()
    # model = srgan_generator

    for test_data_name in test_data_names:
        print("\n数据集 %s:\n" % test_data_name)

        # 定制化数据加载器
        test_dataset = SRDataset(data_folder,
                                split='test',
                                crop_size=0,
                                scaling_factor=4,
                                lr_img_type='imagenet-norm',
                                hr_img_type='[-1, 1]',
                                test_data_name=test_data_name)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1,
                                                pin_memory=True)

        # 记录每个样本 PSNR 和 SSIM值
        PSNRs = AverageMeter()
        SSIMs = AverageMeter()

        # 记录测试时间
        start = time.time()

        with torch.no_grad():
            # 逐批样本进行推理计算
            for i, (lr_imgs, hr_imgs) in enumerate(test_loader):
                
                # 数据移至默认设备
                lr_imgs = lr_imgs.to(device)  # (batch_size (1), 3, w / 4, h / 4), imagenet-normed
                hr_imgs = hr_imgs.to(device)  # (batch_size (1), 3, w, h), in [-1, 1]

                # 前向传播.
                sr_imgs = model(lr_imgs)  # (1, 3, w, h), in [-1, 1]                

                # 计算 PSNR 和 SSIM
                sr_imgs_y = convert_image(sr_imgs, source='[-1, 1]', target='y-channel').squeeze(
                    0)  # (w, h), in y-channel
                hr_imgs_y = convert_image(hr_imgs, source='[-1, 1]', target='y-channel').squeeze(0)  # (w, h), in y-channel
                psnr = peak_signal_noise_ratio(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(),
                                            data_range=255.)
                ssim = structural_similarity(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(),
                                            data_range=255.)
                PSNRs.update(psnr, lr_imgs.size(0))
                SSIMs.update(ssim, lr_imgs.size(0))


        # 输出平均PSNR和SSIM
        print('PSNR  {psnrs.avg:.3f}'.format(psnrs=PSNRs))
        print('SSIM  {ssims.avg:.3f}'.format(ssims=SSIMs))
        print('平均单张样本用时  {:.3f} 秒'.format((time.time()-start)/len(test_dataset)))

    print("\n")

最终在三个数据集上的测试结果如下表所示:

  Set5 Set14 BSD100
PSNR 29.021 25.652 24.833
SSIM 0.839 0.693 0.650
单张图片平均用时(毫秒) 850 650 80

上表中结果与论文中的值较为接近。可以看到,其PSNR和SSIM效果并不好,这是因为SRGAN本质上就不是为了PSNR和SSIM指标而设计优化的。SRGAN论文中也指出使用PSNR和SSIM会让算法过度平滑,在视觉效果上并不理想。为了定量评价SRGAN算法效果,论文中新设计了MOS(Mean Option Score)指标,简单来说,就是让多位观察者采用主观评价的方式对重建效果图进行打分,最后将分数作平均并以此作为评价指标。

4.3 测试

完整的测试代码如下test.py:

from utils import *
from torch import nn
from models import SRResNet,Generator
import time
from PIL import Image

# 测试图像
imgPath = './results/test.jpg'

# 模型参数
large_kernel_size = 9   # 第一层卷积和最后一层卷积的核大小
small_kernel_size = 3   # 中间层卷积的核大小
n_channels = 64         # 中间层通道数
n_blocks = 16           # 残差模块数量
scaling_factor = 4      # 放大比例
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':

    # 预训练模型
    srgan_checkpoint = "./results/checkpoint_srgan.pth"
    #srresnet_checkpoint = "./results/checkpoint_srresnet.pth"

    # 加载模型SRResNet 或 SRGAN
    checkpoint = torch.load(srgan_checkpoint)
    generator = Generator(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    generator = generator.to(device)
    generator.load_state_dict(checkpoint['generator'])
   
    generator.eval()
    model = generator

    # 加载图像
    img = Image.open(imgPath, mode='r')
    img = img.convert('RGB')

    # 双线性上采样
    Bicubic_img = img.resize((int(img.width * scaling_factor),int(img.height * scaling_factor)),Image.BICUBIC)
    Bicubic_img.save('./results/test_bicubic.jpg')

    # 图像预处理
    lr_img = convert_image(img, source='pil', target='imagenet-norm')
    lr_img.unsqueeze_(0)

    # 记录时间
    start = time.time()

    # 转移数据至设备
    lr_img = lr_img.to(device)  # (1, 3, w, h ), imagenet-normed

    # 模型推理
    with torch.no_grad():
        sr_img = model(lr_img).squeeze(0).cpu().detach()  # (1, 3, w*scale, h*scale), in [-1, 1]   
        sr_img = convert_image(sr_img, source='[-1, 1]', target='pil')
        sr_img.save('./results/test_srgan.jpg')

    print('用时  {:.3f} 秒'.format(time.time()-start))

测试效果如下:

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据 一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据 一文掌握图像超分辨率重建(算法原理、Pytorch实现)——含完整代码和数据

图22 低分辨率证件照测试效果图,从左到右从上到下依次为:原图、bicubic上采样、SRResNet超分重建、SRGAN超分重建

四. 总结

本文详细讲述了超分重建的概念和研究进展。针对有代表性的SRResNet和SRGAN算法,分别进行了原理剖析,并给出了Pytorch实现代码。从最终的效果上来看,达到了原论文里的效果。如果读者希望能够较好的入门超分重建领域,那么可以从本文出发,在掌握原理的基础上按照本文示例自己动手完成算法建模,可以为自己在超分重建领域打下良好的基础。当然,如果读者并不是研究超分重建领域,那么本文也可以作为一个实战案例,学习生成对抗网络的操作技巧。本文在组织代码时力求简单明了,并不希望成为一个负担较重的“工程”,在简洁的基础上尽量“傻瓜式”,这样也方便读者可以任意的对其进行扩展和操作。

由于水平有限,本文肯定有不少理解上的错误或者是代码实现上的问题,还请读者能够多多指正,共同进步!

下一篇博文打算研究图像语义分割、抠图领域,同样会以简洁实用为目标,有兴趣的读者后面可以继续关注!

 参考文献

【1】Dong C, Loy C C, He K, et al. Image Super-Resolution Using Deep Convolutional Networks[C]. Computer Vision and Pattern Recognition, 2014.

【2】Shi W, Caballero J, Huszar F, et al. Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network[C]. Computer Vision and Pattern Recognition, 2016: 1874-1883.

【3】Ledig C, Theis L, Huszar F, et al. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network[C]. Computer Vision and Pattern Recognition, 2017.

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/153416.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(2)


相关推荐

  • LinkedList基本用法「建议收藏」

    LinkedList基本用法「建议收藏」LinkedList类是双向列表,列表中的每个节点都包含了对前一个和后一个元素的引用.LinkedList的构造函数如下1.publicLinkedList(): ——生成空的链表2.publicLinkedList(Collectioncol): 复制构造函数1、获取链表的第一个和最后一个元素[java] viewplaincopy

  • java如何获取随机数(两种方式)

    java如何获取随机数(两种方式)在小的知识,都有深挖之价值。很久没有生产随机数,竟然忘了!我明明记得我做过关于随机数产生的总结,but,我翻遍了整个笔记本,就是没找到。即便我知道笔记就在某一个角落;我还是放弃了查找笔记,跑去Google了,所以我决定建立电子笔记,记录那些小知识点。//获取100以内的随机数packagecom.isea.java;importjava.util.Random;public……

  • unity3d c# 产生真正的随机数

    unity3d c# 产生真正的随机数

  • JavaWeb-简单学生信息管理系统的实现-Jsp+Servlet+MySql

    JavaWeb-简单学生信息管理系统的实现-Jsp+Servlet+MySql关注公众号:吾爱代码,回复Java学生管理系统,获取下载链接~关注公众号:吾爱代码,回复Java学生管理系统,获取下载链接~关注公众号:吾爱代码,回复Java学生管理系统,获取下载链接~

  • MATLAB函数句柄

    MATLAB函数句柄之前一直在用,也知道这么个东西,但是没怎么总结。感觉matlab函数句柄就是c语言里面的函数指针,在matlab里面叫它handle,句柄嘛,有了它就可以操纵这个对象(这里也可以叫做函数),这个概念其实可以推广到很多东西,图形fig,自定义函数句柄(也就是下面将会展示的),matlab自带函数句柄,以及某些函数返回的函数句柄,某些类对象或者表达式,也可以叫做句柄。1、何为函数句柄?函数句柄…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号