[深度学习-实践]条件生成对抗网络cGAN的例子-Tensorflow2.x Keras

时间:2022-11-22 03:47:39


[深度学习-实践]条件生成对抗网络cGAN的例子-Tensorflow2.x Keras





文章目录
1. 什么是cGAN
2. 数据集准备
3. 代码实现cGAN
3.1. 定义判别器discriminator
3.2. 定义生成器Generator
3.3. 定义GAN
3.4. 训练
4. 完整代码
5. 根据条件生成相应的衣服
6. 个人总结

1. 什么是cGAN



普通的GAN由两个网络组成:

Generator — 给定一个随机值向量作为输入,此网络生成与训练数据结构相同的数据。

Discriminator — 给定一批同时包含真实训练数据和生成器生成数据的样本,此网络用于判断每个样本是“真实”的还是“生成”的。


条件GAN(cGAN)在此基础上把标签作为额外的条件输入给两个网络:

Generator — 给定标签和随机向量作为输入,此网络生成与训练数据中相同标签的样本结构一致的数据。

Discriminator — 给定一批带标签的样本(既有真实训练数据,也有生成器生成的数据),此网络结合标签判断每个样本是“真实”的还是“生成”的。


例如,对于MNIST,可以生成特定的手写数字,例如数字9;对于CIFAR-10,可以生成特定的对象照片,例如“青蛙”;对于Fashion MNIST数据集,可以生成特定的服装项目,例如“dress”


2. 数据集准备






# example of loading the fashion_mnist dataset
from keras.datasets.fashion_mnist import load_data

# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# summarize the shape of the dataset
print('Train', trainX.shape, trainy.shape)
print('Test', testX.shape, testy.shape)



Train (60000, 28, 28) (60000,)
Test (10000, 28, 28) (10000,)


# example of loading the fashion_mnist dataset
from keras.datasets.fashion_mnist import load_data
from matplotlib import pyplot

# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# plot the first 100 images of the training dataset on a 10x10 grid
for i in range(100):
    # define subplot (subplot indices are 1-based)
    pyplot.subplot(10, 10, 1 + i)
    # turn off axis
    pyplot.axis('off')
    # plot raw pixel data
    pyplot.imshow(trainX[i], cmap='gray_r')
pyplot.show()


3. 代码实现cGAN

3.1. 定义判别器discriminator





# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1), n_classes=10):
    """Build and compile the conditional discriminator.

    The class label is embedded, projected to ``in_shape[0] * in_shape[1]``
    values and reshaped into one extra image channel, which is concatenated
    with the input image before the convolutional stack.

    Args:
        in_shape: shape of one input image as (height, width, channels).
        n_classes: number of distinct class labels accepted by the embedding.

    Returns:
        A compiled Keras ``Model`` taking ``[image, label]`` and producing a
        single sigmoid output per sample.
    """
    # label input
    in_label = Input(shape=(1,))
    # embedding for categorical input
    li = Embedding(n_classes, 50)(in_label)
    # scale up to image dimensions with linear activation
    n_nodes = in_shape[0] * in_shape[1]
    li = Dense(n_nodes)(li)
    # reshape to additional channel
    li = Reshape((in_shape[0], in_shape[1], 1))(li)
    # image input
    in_image = Input(shape=in_shape)
    # concat label as a channel
    merge = Concatenate()([in_image, li])
    # downsample
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(merge)
    fe = LeakyReLU(alpha=0.2)(fe)
    # downsample
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    # flatten feature maps
    fe = Flatten()(fe)
    # dropout
    fe = Dropout(0.4)(fe)
    # output
    out_layer = Dense(1, activation='sigmoid')(fe)
    # define model
    model = Model([in_image, in_label], out_layer)
    # compile model (`learning_rate` replaces the deprecated `lr` argument)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    # optional: tf.keras.utils.plot_model(model, 'discriminator.png', show_shapes=True)
    return model



3.2. 定义生成器Generator



# define the standalone generator model
def define_generator(latent_dim, n_classes=10):
    """Build the conditional generator (not compiled).

    The class label is embedded and reshaped into one 7x7 channel, which is
    concatenated with a 7x7x128 projection of the latent vector. Two stride-2
    transposed convolutions then upsample 7x7 -> 14x14 -> 28x28.

    Args:
        latent_dim: length of the latent noise vector.
        n_classes: number of distinct class labels accepted by the embedding.

    Returns:
        A Keras ``Model`` taking ``[latent_vector, label]`` and producing a
        single-channel image with tanh activation.
    """
    # label input
    in_label = Input(shape=(1,))
    # embedding for categorical input
    li = Embedding(n_classes, 50)(in_label)
    # linear multiplication
    n_nodes = 7 * 7
    li = Dense(n_nodes)(li)
    # reshape to additional channel
    li = Reshape((7, 7, 1))(li)
    # image generator input
    in_lat = Input(shape=(latent_dim,))
    # foundation for 7x7 image
    n_nodes = 128 * 7 * 7
    gen = Dense(n_nodes)(in_lat)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Reshape((7, 7, 128))(gen)
    # merge image gen and label input
    merge = Concatenate()([gen, li])
    # upsample to 14x14
    gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(merge)
    gen = LeakyReLU(alpha=0.2)(gen)
    # upsample to 28x28
    gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    # output
    out_layer = Conv2D(1, (7,7), activation='tanh', padding='same')(gen)
    # define model
    model = Model([in_lat, in_label], out_layer)
    return model



3.3. 定义GAN





# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    """Stack the generator and discriminator into one compiled model.

    Args:
        g_model: generator model with inputs ``[noise, label]``.
        d_model: discriminator model with inputs ``[image, label]``.

    Returns:
        A compiled Keras ``Model`` taking ``[noise, label]`` and producing the
        discriminator's output for the generated image.
    """
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # get noise and label inputs from generator model
    gen_noise, gen_label = g_model.input
    # get image output from the generator model
    gen_output = g_model.output
    # connect image output and label input from generator as inputs to discriminator
    gan_output = d_model([gen_output, gen_label])
    # define gan model as taking noise and label and outputting a classification
    model = Model([gen_noise, gen_label], gan_output)
    # compile model (`learning_rate` replaces the deprecated `lr` argument)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model



3.4. 训练




# load fashion mnist images
def load_real_samples():
    """Load the Fashion-MNIST training split and scale it for the GAN.

    Returns:
        ``[X, trainy]`` where ``X`` is the float32 image array with a trailing
        channel axis, scaled from [0, 255] to [-1, 1], and ``trainy`` is the
        label array returned by ``load_data``.
    """
    # load dataset
    (trainX, trainy), (_, _) = load_data()
    # expand to 3d, e.g. add channels
    X = expand_dims(trainX, axis=-1)
    # convert from ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return [X, trainy]


# select real samples
def generate_real_samples(dataset, n_samples):
    """Draw a random batch of real images with their labels.

    Args:
        dataset: ``[images, labels]`` pair of arrays indexed along axis 0.
        n_samples: number of samples to draw (with replacement).

    Returns:
        ``([X, labels], y)`` where ``y`` is an ``(n_samples, 1)`` array of
        ones marking every sample as real.
    """
    # split into images and labels
    images, labels = dataset
    # choose random instances
    ix = randint(0, images.shape[0], n_samples)
    # select images and labels
    X, labels = images[ix], labels[ix]
    # generate class labels
    y = ones((n_samples, 1))
    return [X, labels], y



# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    """Sample generator inputs: Gaussian noise plus random class labels.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of samples to draw.
        n_classes: labels are drawn uniformly from ``0 .. n_classes - 1``.

    Returns:
        ``[z_input, labels]`` with ``z_input`` of shape
        ``(n_samples, latent_dim)`` and ``labels`` of shape ``(n_samples,)``.
    """
    # generate points in the latent space, directly as a batch of inputs
    z_input = randn(n_samples, latent_dim)
    # generate labels
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]


# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate a batch of fake images with the labels they were made for.

    Args:
        generator: model whose ``predict`` accepts ``[z_input, labels]``.
        latent_dim: length of each latent vector.
        n_samples: number of fake samples to generate.

    Returns:
        ``([images, labels_input], y)`` where ``y`` is an ``(n_samples, 1)``
        array of zeros marking every sample as fake.
    """
    # generate points in latent space
    z_input, labels_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    images = generator.predict([z_input, labels_input])
    # create class labels
    y = zeros((n_samples, 1))
    return [images, labels_input], y


# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=128):
    """Alternate discriminator and generator updates, then save the generator.

    Each step trains the discriminator on half a batch of real samples and
    half a batch of fake samples, then trains the generator through the
    combined model on a full batch of latent points labelled as real.

    Args:
        g_model: the generator model.
        d_model: the compiled discriminator model.
        gan_model: the compiled combined model used to update the generator.
        dataset: ``[images, labels]`` pair of the real training data.
        latent_dim: length of each latent vector.
        n_epochs: number of passes over the training set.
        n_batch: batch size for the generator update; the discriminator sees
            ``n_batch / 2`` real and ``n_batch / 2`` fake samples per step.
    """
    bat_per_epo = int(dataset[0].shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected 'real' samples
            [X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
            # update discriminator model weights
            d_loss1, _ = d_model.train_on_batch([X_real, labels_real], y_real)
            # generate 'fake' examples
            [X_fake, labels], y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # update discriminator model weights
            d_loss2, _ = d_model.train_on_batch([X_fake, labels], y_fake)
            # prepare points in latent space as input for the generator
            [z_input, labels_input] = generate_latent_points(latent_dim, n_batch)
            # create inverted labels for the fake samples
            y_gan = ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch([z_input, labels_input], y_gan)
            # summarize loss on this batch
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' %
                (i+1, j+1, bat_per_epo, d_loss1, d_loss2, g_loss))
    # save the generator model
    g_model.save('cgan_generator.h5')

4. 完整代码

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.datasets.fashion_mnist import load_data
from tensorflow.keras.layers import Concatenate, Dense, Reshape, Embedding, Flatten
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import LeakyReLU
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint


# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1), n_classes=10):
    """Build and compile the conditional discriminator.

    The class label is embedded, projected to ``in_shape[0] * in_shape[1]``
    values and reshaped into one extra image channel, which is concatenated
    with the input image before the convolutional stack.

    Returns a compiled ``Model`` taking ``[image, label]`` and producing a
    single sigmoid output per sample.
    """
    # label input
    in_label = Input(shape=(1,))
    # embedding for categorical input
    li = Embedding(n_classes, 50)(in_label)
    # scale up to image dimensions with linear activation
    n_nodes = in_shape[0] * in_shape[1]
    li = Dense(n_nodes)(li)
    # reshape to additional channel
    li = Reshape((in_shape[0], in_shape[1], 1))(li)
    # image input
    in_image = Input(shape=in_shape)
    # concat label as a channel
    merge = Concatenate()([in_image, li])
    # downsample
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(merge)
    fe = LeakyReLU(alpha=0.2)(fe)
    # downsample
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    # flatten feature maps
    fe = Flatten()(fe)
    # dropout
    fe = Dropout(0.4)(fe)
    # output
    out_layer = Dense(1, activation='sigmoid')(fe)
    # define model
    model = Model([in_image, in_label], out_layer)
    # compile model (`learning_rate` replaces the deprecated `lr` argument)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    # optional: tf.keras.utils.plot_model(model, 'discriminator.png', show_shapes=True)
    return model


# define the standalone generator model
def define_generator(latent_dim, n_classes=10):
    """Build the conditional generator (not compiled).

    The class label is embedded and reshaped into one 7x7 channel, which is
    concatenated with a 7x7x128 projection of the latent vector. Two stride-2
    transposed convolutions then upsample 7x7 -> 14x14 -> 28x28.

    Returns a ``Model`` taking ``[latent_vector, label]`` and producing a
    single-channel image with tanh activation.
    """
    # label input
    in_label = Input(shape=(1,))
    # embedding for categorical input
    li = Embedding(n_classes, 50)(in_label)
    # linear multiplication
    n_nodes = 7 * 7
    li = Dense(n_nodes)(li)
    # reshape to additional channel
    li = Reshape((7, 7, 1))(li)
    # image generator input
    in_lat = Input(shape=(latent_dim,))
    # foundation for 7x7 image
    n_nodes = 128 * 7 * 7
    gen = Dense(n_nodes)(in_lat)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Reshape((7, 7, 128))(gen)
    # merge image gen and label input
    merge = Concatenate()([gen, li])
    # upsample to 14x14
    gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(merge)
    gen = LeakyReLU(alpha=0.2)(gen)
    # upsample to 28x28
    gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    # output
    out_layer = Conv2D(1, (7,7), activation='tanh', padding='same')(gen)
    # define model
    model = Model([in_lat, in_label], out_layer)
    return model


# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    """Stack the generator and discriminator into one compiled model.

    Returns a compiled ``Model`` taking ``[noise, label]`` and producing the
    discriminator's output for the generated image.
    """
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # get noise and label inputs from generator model
    gen_noise, gen_label = g_model.input
    # get image output from the generator model
    gen_output = g_model.output
    # connect image output and label input from generator as inputs to discriminator
    gan_output = d_model([gen_output, gen_label])
    # define gan model as taking noise and label and outputting a classification
    model = Model([gen_noise, gen_label], gan_output)
    # compile model (`learning_rate` replaces the deprecated `lr` argument)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model


# load fashion mnist images
def load_real_samples():
    """Load the Fashion-MNIST training split, scaled to [-1, 1].

    Returns ``[X, trainy]`` where ``X`` is float32 with a trailing channel axis.
    """
    # load dataset
    (trainX, trainy), (_, _) = load_data()
    # expand to 3d, e.g. add channels
    X = expand_dims(trainX, axis=-1)
    # convert from ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return [X, trainy]


# select real samples
def generate_real_samples(dataset, n_samples):
    """Draw a random batch of real images with labels; targets are all ones."""
    # split into images and labels
    images, labels = dataset
    # choose random instances
    ix = randint(0, images.shape[0], n_samples)
    # select images and labels
    X, labels = images[ix], labels[ix]
    # generate class labels
    y = ones((n_samples, 1))
    return [X, labels], y


# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    """Sample ``(n_samples, latent_dim)`` Gaussian noise and random labels."""
    # generate points in the latent space, directly as a batch of inputs
    z_input = randn(n_samples, latent_dim)
    # generate labels
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]


# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate fake images with their labels; targets are all zeros."""
    # generate points in latent space
    z_input, labels_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    images = generator.predict([z_input, labels_input])
    # create class labels
    y = zeros((n_samples, 1))
    return [images, labels_input], y


# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=128):
    """Alternate discriminator and generator updates, then save the generator.

    Each step trains the discriminator on half a batch of real samples and
    half a batch of fake samples, then trains the generator through the
    combined model on a full batch of latent points labelled as real.
    """
    bat_per_epo = int(dataset[0].shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches over the training set
        for j in range(bat_per_epo):
            # get randomly selected 'real' samples
            [X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
            # update discriminator model weights
            d_loss1, _ = d_model.train_on_batch([X_real, labels_real], y_real)
            # generate 'fake' examples
            [X_fake, labels], y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # update discriminator model weights
            d_loss2, _ = d_model.train_on_batch([X_fake, labels], y_fake)
            # prepare points in latent space as input for the generator
            [z_input, labels_input] = generate_latent_points(latent_dim, n_batch)
            # create inverted labels for the fake samples
            y_gan = ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch([z_input, labels_input], y_gan)
            # summarize loss on this batch
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' %
                (i+1, j+1, bat_per_epo, d_loss1, d_loss2, g_loss))
    # save the generator model
    g_model.save('cgan_generator.h5')


if __name__ == '__main__':
    # size of the latent space
    latent_dim = 100
    # create the discriminator
    d_model = define_discriminator()
    # create the generator
    g_model = define_generator(latent_dim)
    # create the gan
    gan_model = define_gan(g_model, d_model)
    # load image data
    dataset = load_real_samples()
    # train model
    train(g_model, d_model, gan_model, dataset, latent_dim)

5. 根据条件生成相应的衣服




# example of loading the generator model and generating images
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from tensorflow.keras.models import load_model
from matplotlib import pyplot


# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    """Sample ``(n_samples, latent_dim)`` Gaussian noise and random labels."""
    # generate points in the latent space
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    z_input = x_input.reshape(n_samples, latent_dim)
    # generate labels
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]


# create and save a plot of generated images
def save_plot(examples, n):
    """Show the first ``n * n`` images of ``examples`` on an n-by-n grid."""
    # plot images
    for i in range(n * n):
        # define subplot (subplot indices are 1-based)
        pyplot.subplot(n, n, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(examples[i, :, :, 0], cmap='gray_r')
    pyplot.show()


# load model
model = load_model('cgan_generator.h5')
# generate latent points (the random labels returned here are replaced below)
latent_points, labels = generate_latent_points(100, 100)
# specify labels: classes 0..9 repeated 10 times, so each grid row shows every class
labels = asarray([x for _ in range(10) for x in range(10)])
# generate images
X = model.predict([latent_points, labels])
# scale from [-1,1] to [0,1]
X = (X + 1) / 2.0
# plot the result
save_plot(X, 10)




6. 个人总结

