700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 生成式对抗网络GAN生成手写数字

生成式对抗网络GAN生成手写数字

时间:2024-06-12 16:12:49

相关推荐

生成式对抗网络GAN生成手写数字

GAN(Generative Adversarial Networks)是较为火热的一种神经网络,具有较多的优势和特点。

一、GAN

1. 原理

源自于零和博弈(zero-sum game),包括生成模型(generative model, G)和判别模型(discriminative model, D)。

G,D的主要功能是:

(1)G是一个生成式的网络,它接收一个随机的噪声z(随机数),通过这个噪声生成图像;

(2)D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量辨别出G生成的假图像和真实的图像。这样,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点。

2. 特点

(1)相比较传统的模型,他存在两个不同的网络,而不是单一的网络,并且训练方式采用的是对抗训练方式;

(2)GAN中G的梯度更新信息来自判别器D,而不是来自数据样本。

3. 优点

(1)GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域

(2)相比VAE,GAN没有变分下界,如果判别器训练良好,那么生成器可以完美的学习到训练样本的分布。换句话说,GANs是渐进一致的,但是VAE是有偏差的;

(3)GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,只要有一个基准,直接上判别器,剩下的就交给对抗训练了。

二、MNIST数据集

下载地址:/download/yql_617540298/10618317

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

三、GAN生成MNIST手写数字

采用Tensorflow框架,生成手写数字。

tensorflow安装:

pip install tensorflow-gpu==版本号

检查tensorflow版本:

代码如下:

import tensorflow as tfimport numpy as npimport matplotlib.pyplot as pltimport tensorflow.examples.tutorials.mnist.input_data as input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)class G_Net:def forward(self, x, reuse=False):with tf.variable_scope("gnet", reuse=reuse):full_connect= tf.layers.dense(x, 4*4*512)flatten = tf.reshape(full_connect, [-1, 4, 4, 512])layer_1 = tf.nn.leaky_relu(tf.layers.batch_normalization(flatten, training=True))layer_1_dout = tf.nn.dropout(layer_1, keep_prob=0.8)# 4 * 4 * 512 to 7 x 7 x 256layer_2 = tf.nn.leaky_relu(tf.layers.batch_normalization(tf.layers.conv2d_transpose(layer_1_dout, 256, 4, strides=1, padding='valid'), training=True))layer_2_dout = tf.nn.dropout(layer_2, keep_prob=0.8)# 7 x 7 256 to 14 x 14 x 128layer_3 = tf.nn.leaky_relu(tf.layers.batch_normalization(tf.layers.conv2d_transpose(layer_2_dout, 128, 3, strides=2, padding='same'), training=True))layer_3_dout = tf.nn.dropout(layer_3, keep_prob=0.8)# 14 x 14 x 128 to 28 x 28 x 1f_img = tf.layers.conv2d_transpose(layer_3_dout, 1, 3, strides=2, padding='same')out = tf.tanh(f_img)return outdef getParam(self):return tf.get_collection(tf.GraphKeys.VARIABLES, scope="gnet")class D_Net:def forward(self, x, reuse=False):with tf.variable_scope("dnet", reuse=reuse):# 28 x 28 x 1 to 14 x 14 x 128layer_1 = tf.nn.leaky_relu(tf.layers.conv2d(x, 128, kernel_size=3, strides=2, padding="same"))layer_1_dout = tf.nn.dropout(layer_1, keep_prob=0.8)# # 14 x 14 x 128 to 7 x 7 x 256layer_2 = tf.nn.leaky_relu(tf.layers.batch_normalization(tf.layers.conv2d(layer_1_dout, 256, 3, strides=2, padding="same"), training=True))layer_2_dout = tf.nn.dropout(layer_2, keep_prob=0.8)# 7 x 7 x 256 to 4 x 4 x 512layer_3 = tf.nn.leaky_relu(tf.layers.batch_normalization(tf.layers.conv2d(layer_2_dout, 512, 3, strides=2, padding="same"), training=True))layer_3_dout = tf.nn.dropout(layer_3, keep_prob=0.8)# 4 x 4 x 512 to 4 * 4* 512 x 1flatten = tf.reshape(layer_3_dout, (-1, 4*4*512))# logits = tf.sigmoid(tf.layers.dense(flatten, 1))logits = tf.layers.dense(flatten, 1)return logitsdef getParam(self):return tf.get_collection(tf.GraphKeys.VARIABLES, scope="dnet")class GAN_NET:def __init__(self):self.f_xs = tf.placeholder(dtype=tf.float32, shape=[None, 100])self.f_ys = tf.placeholder(dtype=tf.float32, shape=[None, 1])self.t_xs = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])self.t_ys = tf.placeholder(dtype=tf.float32, shape=[None ,1])self.gnet = G_Net()self.dnet = D_Net()self.forward()self.backward()def forward(self):self.g_out = self.gnet.forward(self.f_xs)self.g_d_out = self.dnet.forward(self.g_out)self.t_d_out = self.dnet.forward(self.t_xs, True)def backward(self):# self.d_loss = tf.reduce_mean((self.t_d_out - self.t_ys) ** 2)self.d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.g_d_out, labels=self.f_ys)+ tf.nn.sigmoid_cross_entropy_with_logits(logits=self.t_d_out, labels=self.t_ys))self.d_opt = tf.train.AdamOptimizer().minimize(self.d_loss, var_list=self.dnet.getParam())# self.g_loss = tf.reduce_mean((self.g_d_out - self.f_ys) ** 2)self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.g_d_out, labels=self.f_ys))self.g_opt = tf.train.AdamOptimizer().minimize(self.g_loss, var_list=self.gnet.getParam())if __name__ == '__main__':gan_net = GAN_NET()save = tf.train.Saver(max_to_keep=1)d_batch = 10g_batch = 80plt.ion()with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(100000000):xs, _ = mnist.train.next_batch(d_batch)t_xs = np.reshape(xs, newshape=(d_batch,28,28,-1))t_ys = np.ones(shape=[d_batch,1])f_xs = np.random.normal(-1,1,size=(d_batch,100))f_ys = np.zeros(shape=[d_batch, 1])d_loss, _ = sess.run([gan_net.d_loss, gan_net.d_opt], feed_dict={gan_net.f_xs:f_xs, gan_net.f_ys:f_ys,gan_net.t_xs:t_xs, gan_net.t_ys:t_ys})_f_xs = np.random.normal(-1,1,size=(g_batch,100))_t_ys = np.ones(shape=[g_batch, 1])g_loss,_ = sess.run([gan_net.g_loss,gan_net.g_opt], feed_dict={gan_net.f_xs:_f_xs,gan_net.f_ys:_t_ys})print("i---",i," d_loss=",d_loss,"-- g_loss",g_loss)if (i+1) % 100 == 0:save_path = save.save(sess, "./save/gan_mnist")print(save_path)if (i+1) % 5 == 0:# save.restore(sess, "./save/gan_mnist")t_data = np.random.normal(-1,1,size=(1,100))img = sess.run(gan_net.g_out, feed_dict={gan_net.f_xs:t_data})img = (img + 1) * 127.5img = np.array(img, np.uint8)img = np.reshape(img[0],newshape=[28,28])plt.imshow(img)plt.pause(1)

训练:

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。