700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 深度卷积生成对抗网络DCGAN——生成手写数字图片

深度卷积生成对抗网络DCGAN——生成手写数字图片

时间:2019-08-17 04:53:28

相关推荐

深度卷积生成对抗网络DCGAN——生成手写数字图片

前言

本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。

本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:

# 用于生成 GIF 图片pip install -q imageio

目录

前言

一、什么是生成对抗网络?

二、加载数据集

三、创建模型

3.1 生成器

3.1 判别器

四、定义损失函数和优化器

4.1 生成器的损失和优化器

4.2 判别器的损失和优化器

五、训练模型

5.1 保存检查点

5.2 定义训练过程

5.3 训练模型

六、评估模型

一、什么是生成对抗网络?

生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。

生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。

判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。

训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。

当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。

本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。

二、加载数据集

使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内BUFFER_SIZE = 60000BATCH_SIZE = 256# 批量化和打乱数据train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

三、创建模型

主要创建两个模型,一个是生成器,另一个是判别器

3.1 生成器

生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。

然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。

后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。

def make_generator_model():model = tf.keras.Sequential()model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Reshape((7, 7, 256)))assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))assert model.output_shape == (None, 7, 7, 128)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))assert model.output_shape == (None, 14, 14, 64)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))assert model.output_shape == (None, 28, 28, 1)return model

用tf.keras.utils.plot_model( ),看一下模型结构

用summary(),看一下模型结构和参数

使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。

generator = make_generator_model()noise = tf.random.normal([1, 100])generated_image = generator(noise, training=False)plt.imshow(generated_image[0, :, :, 0], cmap='gray')

3.1 判别器

判别器是基于 CNN卷积神经网络 的图片分类器。

def make_discriminator_model():model = tf.keras.Sequential()model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',input_shape=[28, 28, 1]))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Flatten())model.add(layers.Dense(1))return model

用tf.keras.utils.plot_model( ),看一下模型结构

用summary(),看一下模型结构和参数

四、定义损失函数和优化器

由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。

首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。

# 该方法返回计算交叉熵损失的辅助函数cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

4.1 生成器的损失和优化器

1)生成器损失

生成器损失是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。

这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。

def generator_loss(fake_output):return cross_entropy(tf.ones_like(fake_output), fake_output)

2)生成器优化器

generator_optimizer = tf.keras.optimizers.Adam(1e-4)

4.2 判别器的损失和优化器

1)判别器损失

判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。

def discriminator_loss(real_output, fake_output):real_loss = cross_entropy(tf.ones_like(real_output), real_output)fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)total_loss = real_loss + fake_lossreturn total_loss

2)判别器优化器

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

五、训练模型

5.1 保存检查点

保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。

checkpoint_dir = './training_checkpoints'checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,discriminator_optimizer=discriminator_optimizer,generator=generator,discriminator=discriminator)

5.2 定义训练过程

EPOCHS = 50noise_dim = 100num_examples_to_generate = 16# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)seed = tf.random.normal([num_examples_to_generate, noise_dim])

训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。

判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。

两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。

# 注意 `tf.function` 的使用# 该注解使函数被“编译”@tf.functiondef train_step(images):noise = tf.random.normal([BATCH_SIZE, noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:generated_images = generator(noise, training=True)real_output = discriminator(images, training=True)fake_output = discriminator(generated_images, training=True)gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))def train(dataset, epochs):for epoch in range(epochs):start = time.time()for image_batch in dataset:train_step(image_batch)# 继续进行时为 GIF 生成图像display.clear_output(wait=True)generate_and_save_images(generator,epoch + 1,seed)# 每 15 个 epoch 保存一次模型if (epoch + 1) % 15 == 0:checkpoint.save(file_prefix = checkpoint_prefix)print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))# 最后一个 epoch 结束后生成图片display.clear_output(wait=True)generate_and_save_images(generator,epochs,seed)# 生成与保存图片def generate_and_save_images(model, epoch, test_input):# 注意 training` 设定为 False# 因此,所有层都在推理模式下运行(batchnorm)。predictions = model(test_input, training=False)fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.show()

5.3 训练模型

调用上面定义的train()函数,来同时训练生成器和判别器。

注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。

%%timetrain(train_dataset, EPOCHS)

在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。

训练了15轮的效果:

训练了30轮的效果:

训练过程:

恢复最新的检查点

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

六、评估模型

这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。

# 使用 epoch 数生成单张图片def display_image(epoch_no):return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))display_image(EPOCHS)

anim_file = 'dcgan.gif'with imageio.get_writer(anim_file, mode='I') as writer:filenames = glob.glob('image*.png')filenames = sorted(filenames)last = -1for i,filename in enumerate(filenames):frame = 2*(i**0.5)if round(frame) > round(last):last = frameelse:continueimage = imageio.imread(filename)writer.append_data(image)image = imageio.imread(filename)writer.append_data(image)import IPythonif IPython.version_info > (6,2,0,''):display.Image(filename=anim_file)

完整代码:

import tensorflow as tfimport globimport imageioimport matplotlib.pyplot as pltimport numpy as npimport osimport PILfrom tensorflow.keras import layersimport timefrom IPython import display(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内BUFFER_SIZE = 60000BATCH_SIZE = 256# 批量化和打乱数据train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)# 创建模型--生成器def make_generator_model():model = tf.keras.Sequential()model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Reshape((7, 7, 256)))assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))assert model.output_shape == (None, 7, 7, 128)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))assert model.output_shape == (None, 14, 14, 64)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))assert model.output_shape == (None, 28, 28, 1)return model# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。generator = make_generator_model()noise = tf.random.normal([1, 100])generated_image = generator(noise, training=False)plt.imshow(generated_image[0, :, :, 0], cmap='gray')tf.keras.utils.plot_model(generator)# 判别器def make_discriminator_model():model = tf.keras.Sequential()model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',input_shape=[28, 28, 1]))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Flatten())model.add(layers.Dense(1))return model# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。discriminator = make_discriminator_model()decision = discriminator(generated_image)print (decision)# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)# 生成器的损失和优化器def generator_loss(fake_output):return cross_entropy(tf.ones_like(fake_output), fake_output)generator_optimizer = tf.keras.optimizers.Adam(1e-4)# 判别器的损失和优化器def discriminator_loss(real_output, fake_output):real_loss = cross_entropy(tf.ones_like(real_output), real_output)fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)total_loss = real_loss + fake_lossreturn total_lossdiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)# 保存检查点checkpoint_dir = './training_checkpoints'checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,discriminator_optimizer=discriminator_optimizer,generator=generator,discriminator=discriminator)# 定义训练过程EPOCHS = 50noise_dim = 100num_examples_to_generate = 16# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)seed = tf.random.normal([num_examples_to_generate, noise_dim])# 注意 `tf.function` 的使用# 该注解使函数被“编译”@tf.functiondef train_step(images):noise = tf.random.normal([BATCH_SIZE, noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:generated_images = generator(noise, training=True)real_output = discriminator(images, training=True)fake_output = discriminator(generated_images, training=True)gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))def train(dataset, epochs):for epoch in range(epochs):start = time.time()for image_batch in dataset:train_step(image_batch)# 继续进行时为 GIF 生成图像display.clear_output(wait=True)generate_and_save_images(generator,epoch + 1,seed)# 每 15 个 epoch 保存一次模型if (epoch + 1) % 15 == 0:checkpoint.save(file_prefix = checkpoint_prefix)print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))# 最后一个 epoch 结束后生成图片display.clear_output(wait=True)generate_and_save_images(generator,epochs,seed)# 生成与保存图片def generate_and_save_images(model, epoch, test_input):# 注意 training` 设定为 False# 因此,所有层都在推理模式下运行(batchnorm)。predictions = model(test_input, training=False)fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))plt.show()# 训练模型train(train_dataset, EPOCHS)# 恢复最新的检查点checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))# 评估模型# 使用 epoch 数生成单张图片def display_image(epoch_no):return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))display_image(EPOCHS)anim_file = 'dcgan.gif'with imageio.get_writer(anim_file, mode='I') as writer:filenames = glob.glob('image*.png')filenames = sorted(filenames)last = -1for i,filename in enumerate(filenames):frame = 2*(i**0.5)if round(frame) > round(last):last = frameelse:continueimage = imageio.imread(filename)writer.append_data(image)image = imageio.imread(filename)writer.append_data(image)import IPythonif IPython.version_info > (6,2,0,''):display.Image(filename=anim_file)

参考:/tutorials/generative/dcgan

一篇文章“简单”认识《生成对抗网络》(GAN)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。