700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 生成对抗网络(GAN)原理和实现

生成对抗网络(GAN)原理和实现

时间:2018-11-06 17:04:35

相关推荐

生成对抗网络(GAN)原理和实现

个人博客:/

原文链接:/show-54.html

生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。论文《Generative Adversarial Nets》首次提出GAN。

GAN的思想

GAN由生成器G和判别器D组成。生成器G根据输入先验分布的随机向量(一般使用随机分布,论文里用的是高斯分布)得到符合数据集数据分布。判别器D判别输入数据来源于G还是真实数据。框架如下:

GAN的训练过程:刚开始G和D里面的参数随机初始化,第一步使用真实图片训练D,则D能轻易判断G生成的图片和真实图片。接着训练G,使得G生成的图片更加逼真,直到D无法判断G生成的图片和真实图片。接着训练D,使D能轻易判断真实图片。。。以此类推,最终G生成的图片和真实图片很相似。就好像论文里的例子,G就像货币的伪造者,D就像警察,G造假币,D识别假币,两者互相对抗。G造的假币越来越逼真,D识别的手段越来越高超。最终G生成的东西跟真的差不多。过程示意:

目标函数

从本质上看,GAN的训练目标就是使G恢复训练数据的数据分布。数据分布可以理解为数据的概率函数。

用更数学的表示,G将先验分布(比如高斯分布)P_prior(z)中的z通过G映射到x,即x的分布为P_G(x,theta),theta即G的参数。将生成器分布P_G(x,theta)和真实数据分布P_data(x)对比即可得到损失loss,如下图:

判别器D评估P_G(x)和已知数据分布P_data(x)的差异,即判别输入的x是来自真实分布还是生成器。

根据GAN的思想,其优化过程可以表示为以下公式:

将上式拆分,得到D和G的目标函数。

优化D,D的目标是将G(z)判断为真的概率D(G(z))尽可能小,即1-D(G(z))尽可能大,且将真实数据x判断为真的概率D(x)尽可能大。因此得公式如下:

优化G,G的目标是令D判断为真的概率尽可能大,即D(G(z))尽可能大。因此得公式如下:

实际训练G的时候,早期要求V的初始斜率大,因此需要替换V:

实际的数据是离散的,因此计算分布的期望是通过采样计算得到的。论文里提出,迭代的优化k步D和1步G。这可以让D保持在最优解附近,可得迭代优化参数的算法:

理论证明

以下的理论分析将证明GAN的目标函数有一个最优解p_g=p_data。

对于目标函数:

上式是求两个期望的相加,等价于:

我们想要在找到最优的D*,使得V(G,D)最大:

求V(G,D)等价与求以下公式最大:

上式中, P_data(x)是一个常量,表示x对应的概率分布中的值,这里设为a,P_G(x)也是如此,设为b。因此可以对上式进行求导,即可得到D*,过程如下:

代入D目标函数,得:

当且仅当p_g=p_data,C(G)取得最大值,此时C(G)=-log4,如下:

将D*代入V(G,D)的积分表达式,得:

分子分母同时除以2。再将1/2提出来,且和等于1,则:

上式的KL是KL散度。KL散度:相对熵(relative entropy),又被称为Kullback-Leibler散度(Kullback-Leibler divergence)或信息散度(information divergence),是两个概率分布(probability distribution)间差异的非对称性度量 [1] 。在在信息理论中,相对熵等价于两个概率分布的信息熵(Shannon entropy)的差值。离散和连续随机变量的公式如下:

继续得到:

JSD是Jesen-Shannon散度。由于两个分布之间的Jensen-Shannon散度总是非负的,只有当它们相等时才为零,所以我们已经证明了C∗=−log(4)是C(G)的全局最小值,而唯一的解决方案是p_g = p_data,即,生成模型完美地复制了数据生成过程。

代码实现

论文里生成器使用全连接网络,使用relu和sigmoid激活函数;判别器也是全连接网络,使用maxout激活函数,同时应用了dropout。生成器的输入是符合高斯分布的随机向量,theta值根据交叉验证得到。在MNIST、TFD和CIFAR-10上面测试。

我这里的代码也是参照了网上开源的代码,判别器和生成器均使用relu和sigmoid激活函数。经过我的测试,发现每轮迭代的时候,生成器应该要比判别器训练更多,否则会发散,这跟论文里的描述相反。

网络的计算图如下:

代码如下:

定义参数

import matplotlib.pyplot as pltimport numpy as npimport tensorflow as tffrom tensorflow import name_scope as namespacefrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data", one_hot=True)# 训练参数num_steps = 20000batch_size = 128learning_rate = 0.0002# 网络参数image_dim = 784 # 28*28gen_hidden_dim = 256disc_hidden_dim = 256noise_dim = 100 # Noise data pointsk=1# 保存隐藏层的权重和偏置,用于变量共享with namespace('var'):weights = {'gen_hidden1': tf.Variable(tf.truncated_normal([noise_dim, gen_hidden_dim],stddev=0.1)),'gen_out': tf.Variable(tf.truncated_normal([gen_hidden_dim, image_dim],stddev=0.1)),'disc_hidden1': tf.Variable(tf.truncated_normal([image_dim, disc_hidden_dim],stddev=0.1)),'disc_out': tf.Variable(tf.truncated_normal([disc_hidden_dim, 1],stddev=0.1)),}biases = {'gen_hidden1': tf.Variable(tf.zeros([gen_hidden_dim])),'gen_out': tf.Variable(tf.zeros([image_dim])),'disc_hidden1': tf.Variable(tf.zeros([disc_hidden_dim])),'disc_out': tf.Variable(tf.zeros([1])),}

定义网络和优化器

# 生成网络def generator(x):with namespace('gen_hidden1'):hidden_layer = tf.matmul(x, weights['gen_hidden1'])hidden_layer = tf.add(hidden_layer, biases['gen_hidden1'])hidden_layer = tf.nn.relu(hidden_layer)with namespace('gen_out'):out_layer = tf.matmul(hidden_layer, weights['gen_out'])out_layer = tf.add(out_layer, biases['gen_out'])out_layer = tf.nn.sigmoid(out_layer)return out_layer# 判别网络def discriminator(x):with namespace('disc_hidden1'):hidden_layer = tf.matmul(x, weights['disc_hidden1'])hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])hidden_layer = tf.nn.relu(hidden_layer)with namespace('disc_output'):out_layer = tf.matmul(hidden_layer, weights['disc_out'])out_layer = tf.add(out_layer, biases['disc_out'])out_layer = tf.nn.sigmoid(out_layer)return out_layer# 网络输入gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise')disc_input = tf.placeholder(tf.float32, shape=[None, image_dim], name='disc_input')# 创建生成网络with namespace('generator'):gen_sample = generator(gen_input)# 创建两个判别网络 (一个来自噪声输入, 一个来自生成的样本)with namespace('discriminator'):with namespace('discriminator_real'):disc_real = discriminator(disc_input)with namespace('discriminator_fake'):disc_fake = discriminator(gen_sample)with namespace('loss'):# 定义损失函数gen_loss = -tf.reduce_mean(tf.log(disc_fake))disc_loss = -tf.reduce_mean(tf.log(disc_real) + tf.log(1. - disc_fake))#将变量的损失值写入Losstf.summary.scalar('gen_loss', gen_loss)tf.summary.scalar('disc_loss', disc_loss)merged_summary = tf.summary.merge_all()with namespace('train'): # 定义优化器optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)# 训练每个优化器的变量# 生成网络变量gen_vars = [weights['gen_hidden1'], weights['gen_out'],biases['gen_hidden1'], biases['gen_out']]# 判别网络变量disc_vars = [weights['disc_hidden1'], weights['disc_out'],biases['disc_hidden1'], biases['disc_out']]# 最小损失函数train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)# 初始化变量init = tf.global_variables_initializer()

训练网络

def getData(batch_size=128):batch_x, _ = mnist.train.next_batch(batch_size)# 准备数据z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])# 产生噪声给生成网络return batch_x,z# 开始训练with tf.Session() as sess:sess.run(init)writer=tf.summary.FileWriter('D:/Jupyter/GAN/mnist_gan_train_log/log',sess.graph)saver=tf.train.Saver()for i in range(1, num_steps+1):for j in range(k):x,z=getData()_,dl = sess.run([train_disc, disc_loss], feed_dict={disc_input: x, gen_input: z})x,z=getData()_,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})x,z=getData()_,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})x,z=getData()_,gl = sess.run([train_gen, gen_loss], feed_dict={disc_input: x, gen_input: z})x,z=getData()summary,g = sess.run([merged_summary,gen_sample], feed_dict={disc_input:x,gen_input:z})writer.add_summary(summary,i)#写summary和i到文件if i % 1000 == 0 or i == 1:print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))# 使用生成器网络从噪声生成图像f, a = plt.subplots(4, 10, figsize=(10, 4))for i in range(10):# 噪声输入.z = np.random.uniform(-1., 1., size=[4, noise_dim])g = sess.run([gen_sample], feed_dict={gen_input: z})g = np.reshape(g, newshape=(4, 28, 28, 1))# 将原来黑底白字转换成白底黑字,更好的显示g = -1 * (g - 1)for j in range(4):# 从噪音中生成图像。 扩展到3个通道,用于matplotlibimg = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),newshape=(28, 28, 3))a[j][i].imshow(img)plt.savefig('test.png')#保存图片#f.show()#plt.draw()#plt.waitforbuttonpress()

生成的结果:

训练过程的损失值:

从损失曲线可以看到,GAN的训练过程是不稳定的。

参考文献

[1]Ian J. Goodfellow,etc.Generative Adversarial Nets..arXiv:1406.2661v1

[2]小白的成长. GAN之V(D,G)函数. /qq_42413820/article/details/80673857. -06-13

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。