700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > GAN: Generative Adversarial Nets

GAN: Generative Adversarial Nets

时间:2021-09-17 14:01:27

相关推荐

GAN: Generative Adversarial Nets

谈到生成对抗网络,我们首先想到的是Goodfellow的开山之作:Generative Adversarial Nets。今天,我们就来谈谈这篇文章。对于估计数据分布的问题,当模型的类别已知时,我们一般采用极大似然方法进行估计。然而,当模型的类别未知或数据分布过于复杂时,我们如何近似得到数据的分布呢?我想,对抗网络的提出给了我们一些思路。

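生成对抗网络由两个网络组成,即生成器和判别器;在Goodfellow这篇文章里,两者都是多层感知器。生成器用来建立服从某一分布(如高斯噪声)的随机噪声到目标分布的映射,判别器用来区分真实数据分布和生成器产生的数据分布。训练时交替迭代地更新生成器和判别器:生成器使其产生的数据分布逼近真实数据分布,以欺骗判别器;判别器则不断提升区分这两个分布的能力。最终两者达到纳什均衡,此时生成的分布与真实分布一致,判别器对任意样本的输出均为1/2,无法判断样本的真伪。需要注意的是,下面的代码实现的其实是WGAN(Wasserstein GAN):它用critic代替判别器,通过权重裁剪约束critic的参数,并使用RMSProp优化,而不是原论文中的交叉熵损失。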

from __future__ import print_functionfrom six.moves import xrangeimport tensorflow.contrib.slim as slimimport osimport tensorflow as tfimport numpy as npimport tensorflow.contrib.layers as lyfrom load_svhn import load_svhnfrom tensorflow.examples.tutorials.mnist import input_datadef lrelu(x, leak=0.3, name="lrelu"):with tf.variable_scope(name):f1 = 0.5 * (1 + leak)f2 = 0.5 * (1 - leak)return f1 * x + f2 * abs(x)batch_size = 64z_dim = 128learning_rate_ger = 5e-5learning_rate_dis = 5e-5device = '/gpu:0'# img sizes = 32# update Citers times of critic in one iter(unless i < 25 or i % 500 == 0, i is iterstep)Citers = 5# the upper bound and lower bound of parameters in criticclamp_lower = -0.01clamp_upper = 0.01# whether to use mlp or dcgan stuctureis_mlp = False# whether to use adam for parameter update, if the flag is set False, use tf.train.RMSPropOptimizer# as recommended in paperis_adam = False# whether to use SVHN or MNIST, set false and MNIST is usedis_svhn = Falsechannel = 3 if is_svhn is True else 1s2, s4, s8, s16 =\int(s / 2), int(s / 4), int(s / 8), int(s / 16)# hidden layer size if mlp is chosen, ignore if otherwisengf = 64ndf = 64# directory to store log, including loss and grad_norm of generator and criticlog_dir = './log_wgan'ckpt_dir = './ckpt_wgan'if not os.path.exists(ckpt_dir):os.makedirs(ckpt_dir)# max iter step, note the one step indicates that a Citers updates of critic and one update of generatormax_iter_step = 20000def generator_mlp(z):train = ly.fully_connected(z, 4 * 4 * 512, activation_fn=lrelu, normalizer_fn=ly.batch_norm)train = ly.fully_connected(train, ngf, activation_fn=lrelu, normalizer_fn=ly.batch_norm)train = ly.fully_connected(train, ngf, activation_fn=lrelu, normalizer_fn=ly.batch_norm)train = ly.fully_connected(train, s*s*channel, activation_fn=tf.nn.tanh, normalizer_fn=ly.batch_norm)train = tf.reshape(train, tf.stack([batch_size, s, s, channel]))return traindef critic_mlp(img, reuse=False):with tf.variable_scope('critic') as 
scope:if reuse:scope.reuse_variables()size = 64img = ly.fully_connected(tf.reshape(img, [batch_size, -1]), ngf, activation_fn=tf.nn.relu)img = ly.fully_connected(img, ngf,activation_fn=tf.nn.relu)img = ly.fully_connected(img, ngf,activation_fn=tf.nn.relu)logit = ly.fully_connected(img, 1, activation_fn=None)return logitdef build_graph():z = tf.placeholder(tf.float32, shape=(batch_size, z_dim))generator = generator_mlp if is_mlp else generator_convcritic = critic_mlp if is_mlp else critic_convwith tf.variable_scope('generator'):train = generator(z)real_data = tf.placeholder(dtype=tf.float32, shape=(batch_size, 32, 32, channel))true_logit = critic(real_data)fake_logit = critic(train, reuse=True)c_loss = tf.reduce_mean(fake_logit - true_logit)g_loss = tf.reduce_mean(-fake_logit)g_loss_sum = tf.summary.scalar("g_loss", g_loss)c_loss_sum = tf.summary.scalar("c_loss", c_loss)img_sum = tf.summary.image("img", train, max_outputs=10)theta_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')theta_c = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='critic')counter_g = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)opt_g = ly.optimize_loss(loss=g_loss, learning_rate=learning_rate_ger,optimizer=tf.train.AdamOptimizer if is_adam is True else tf.train.RMSPropOptimizer, variables=theta_g, global_step=counter_g,summaries = 'gradient_norm')counter_c = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)opt_c = ly.optimize_loss(loss=c_loss, learning_rate=learning_rate_dis,optimizer=tf.train.AdamOptimizer if is_adam is True else tf.train.RMSPropOptimizer, variables=theta_c, global_step=counter_c,summaries = 'gradient_norm')clipped_var_c = [tf.assign(var, tf.clip_by_value(var, clamp_lower, clamp_upper)) for var in theta_c]# merge the clip operations on critic variableswith tf.control_dependencies([opt_c]):opt_c = tf.tuple(clipped_var_c)return opt_g, opt_c, z, real_datadef main():if is_svhn is True:dataset = load_svhn()else:dataset 
= input_data.read_data_sets('MNIST_data', one_hot=True)with tf.device(device):opt_g, opt_c, z, real_data = build_graph()merged_all = tf.summary.merge_all()saver = tf.train.Saver()config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)config.gpu_options.allow_growth = Trueconfig.gpu_options.per_process_gpu_memory_fraction = 0.8def next_feed_dict():train_img = dataset.train.next_batch(batch_size)[0]train_img = 2*train_img-1if is_svhn is not True:train_img = np.reshape(train_img, (-1, 28, 28))npad = ((0, 0), (2, 2), (2, 2))train_img = np.pad(train_img, pad_width=npad,mode='constant', constant_values=-1)train_img = np.expand_dims(train_img, -1)batch_z = np.random.normal(0, 1, [batch_size, z_dim]) \.astype(np.float32)feed_dict = {real_data: train_img, z: batch_z}return feed_dictwith tf.Session(config=config) as sess:sess.run(tf.global_variables_initializer())summary_writer = tf.summary.FileWriter(log_dir, sess.graph)for i in range(max_iter_step):if i < 25 or i % 500 == 0:citers = 100else:citers = Citersfor j in range(citers):feed_dict = next_feed_dict()if i % 100 == 99 and j == 0:run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)run_metadata = tf.RunMetadata()_, merged = sess.run([opt_c, merged_all], feed_dict=feed_dict,options=run_options, run_metadata=run_metadata)summary_writer.add_summary(merged, i)summary_writer.add_run_metadata(run_metadata, 'critic_metadata {}'.format(i), i)else:sess.run(opt_c, feed_dict=feed_dict)feed_dict = next_feed_dict()if i % 100 == 99:_, merged = sess.run([opt_g, merged_all], feed_dict=feed_dict,options=run_options, run_metadata=run_metadata)summary_writer.add_summary(merged, i)summary_writer.add_run_metadata(run_metadata, 'generator_metadata {}'.format(i), i)else:sess.run(opt_g, feed_dict=feed_dict)if i % 1000 == 999:saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"), global_step=i)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。