生成对抗网络——GAN(一)「建议收藏」

生成对抗网络——GAN(一)「建议收藏」Generativeadversarialnetwork据有关媒体统计:CVPR2018的论文里,有三分之一的论文与GAN有关!由此可见,GAN在视觉领域的未来多年内,将是一片沃土(CVer们是时候入门GAN了)。而发现这片矿源的就是GAN之父,Goodfellow大神。~~~生成对抗网络GAN,是当今的一大热门研究方向。在2014年,被Goodfellow大神提出来,当时的G…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全家桶1年46,售后保障稳定

Generative adversarial network

据有关媒体统计:CVPR2018的论文里,有三分之一的论文与GAN有关
由此可见,GAN在视觉领域的未来多年内,将是一片沃土(CVer们是时候入门GAN了)。而发现这片矿源的就是GAN之父,Goodfellow大神。
文末有基于keras的GAN代码,有助于理解GAN的原理


生成对抗网络GAN,是当今的一大热门研究方向。在2014年,被Goodfellow大神提出来,当时的G神还只是蒙特利尔大学的博士生而已。
GAN之父的主页:
http://www.iangoodfellow.com/

GAN的论文首次出现在NIPS2014上,论文地址如下:
https://arxiv.org/pdf/1406.2661.pdf


入坑GAN,首先需要理由,GAN能做什么,为什么要学GAN。
GAN的初衷就是生成不存在于真实世界的数据,类似于使得 AI具有创造力或者想象力。应用场景如下:

  1. AI作家,AI画家等需要创造力的AI体;
  2. 将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有所谓的“想象力”,能脑补情节;
  3. 进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。
    以上的场景都可以找到相应的paper。而且GAN的用处也远不止此,期待我们继续挖掘,是发论文的好方向哦

GAN的原理介绍

这里介绍的是原生的GAN算法,虽然有一些不足,但提供了一种生成对抗性的新思路。放心,我这篇博文不会堆一大堆公式,只会提供一种理解思路。

理解GAN的两大护法GD

G是generator,生成器: 负责凭空捏造数据出来

D是discriminator,判别器: 负责判断数据是不是真数据

这样可以简单的看作是两个网络的博弈过程。在最原始的GAN论文里面,G和D都是两个多层感知机网络。首先,注意一点,GAN操作的数据不一定非得是图像数据,不过为了更方便解释,我在这里用图像数据为例解释以下GAN:
图片名称

稍微解释以下上图,z是随机噪声(就是随机生成的一些数,也就是GAN生成图像的源头)。D通过真图和假图的数据(相当于天然label),进行一个二分类神经网络训练(想各位必再熟悉不过了)。G根据一串随机数就可以捏造一个“假图像”出来,用这些假图去欺骗D,D负责辨别这是真图还是假图,会给出一个score。比如,G生成了一张图,在D这里得分很高,那证明G是很成功的;如果D能有效区分真假图,则G的效果还不太好,需要调整参数。GAN就是这么一个博弈的过程。


那么,GAN是怎么训练呢
根据GAN的训练算法,我画一张图:
图片名称

GAN的训练在同一轮梯度反传的过程中可以细分为2步,先训练D在训练G;注意不是等所有的D训练好以后,才开始训练G,因为D的训练也需要上一轮梯度反传中G的输出值作为输入。

当训练D的时候
,上一轮G产生的图片,和真实图片,直接拼接在一起,作为x。然后根据,按顺序摆放0和1,假图对应0,真图对应1。然后就可以通过,x输入生成一个score(从0到1之间的数),通过score和y组成的损失函数,就可以进行梯度反传了。(我在图片上举的例子是batch = 1,len(y)=2*batch,训练时通常可以取较大的batch)

当训练G的时候, 需要把G和D当作一个整体,我在这里取名叫做’D_on_G’。这个整体(下面简称DG系统)的输出仍然是score。输入一组随机向量,就可以在G生成一张图,通过D对生成的这张图进行打分,这就是DG系统的前向过程。score=1就是DG系统需要优化的目标,score和y=1之间的差异可以组成损失函数,然后可以反向传播梯度。注意,这里的D的参数是不可训练的。这样就能保证G的训练是符合D的打分标准的。这就好比:如果你参加考试,你别指望能改变老师的评分标准


需要注意的是,整个GAN的整个过程都是无监督的(后面会有监督性GAN比如cGAN),怎么理解这里的无监督呢?
这里,给的真图是没有经过人工标注的,你只知道这是真实的图片,比如全是人脸,而系统里的D并不知道来的图片是什么玩意儿,它只需要分辨真假。G也不知道自己生成的是什么玩意儿,反正就是学真图片的样子骗D。

正由于GAN的无监督,在生成过程中,G就会按照自己的意思天马行空生成一些“诡异”的图片,可怕的是D还能给一个很高的分数。比如,生成人脸极度扭曲的图片。这就是无监督目的性不强所导致的,所以在同年的NIPS大会上,有一篇论文conditional GAN就加入了监督性进去,将可控性增强,表现效果也好很多。


from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            # Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            # Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, sample_interval=200)

Jetbrains全家桶1年46,售后保障稳定

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/219096.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • html倒计时代码

    <SPANid=span_dt_dt></SPAN><SCRIPTlanguage=javascript><!–//document.write(“”);functionshow_date_time(){window.setTimeout(“show_date_time()”,1000);BirthDay=newDate(“…

  • modprobe命令不能用_modprobe查看已加载模块

    modprobe命令不能用_modprobe查看已加载模块modprobe命令用于智能地向内核中加载模块或者从内核中移除模块。modprobe可载入指定的个别模块,或是载入一组相依的模块。modprobe会根据depmod所产生的相依关系,决定要载入哪些模块。若在载入过程中发生错误,在modprobe会卸载整组的模块。语法modprobe(选项)(参数)选项 -a或–all:载入全部的模块; -c或–show-conf:显示所有模块的设置信息; -d或–debug:使用排错模式; -l或–li.

    2022年10月24日
  • 卷积神经网络CNN的反向传播原理

    卷积神经网络CNN的反向传播原理  上一篇博客《详解神经网络的前向传播和反向传播》推导了普通神经网络(多层感知器)的反向传播过程,这篇博客则讨论一下卷积神经网络中反向传播的不同之处。先简单回顾一下普通神经网络中反向传播的四个核心公式:…

  • 如何运行SpringBoot项目

    如何运行SpringBoot项目最近在Ecplise上面写了一个简单的SpringBoot的测试项目,SpringBoot里面是有主函数的:我们知道的是在Ecplise上面找到这个主函数然后runas-&gt;javaApplication就可以了但是总不能一直不脱离Ecplise,总要出来自己单练的第一步:我就新建的一个文件夹boottest,然后右键导出整个工程:导出的是jar包,然后我们看…

    2022年10月13日
  • 建站指南和总结(期末总结)

    换了一个新的站点,Wordpress也没想象中的好用嘛

  • js用户注册表单验证_onclick调用js函数

    js用户注册表单验证_onclick调用js函数源代码<!DOCTYPEhtml><html><body><h1>js通过button的简单验证</h1><pid=&quo

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号