
[GAN Tutorial Series] (10) InfoGAN: Interpretable Representation Learning by Information Maximizing GAN


This post is based on the original paper InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets (NIPS 2016) and Chapter 6 of 《生成对抗网络入门指南》 (Getting Started with Generative Adversarial Networks). The complete code, with brief commentary, appears in Section V below.

I. Why use InfoGAN

InfoGAN learns in an unsupervised fashion and aims to produce interpretable features. It draws on information theory, optimizing the network by maximizing the mutual information (MI) between part of the generator's input and the observations. InfoGAN works on a range of complex datasets and can model discrete and continuous latent features at the same time.

II. Input data

InfoGAN splits the generator's random input into two parts:

the first part, z, is the noise;

the second part, c, is the latent code;

The goal is for every dimension of c to carry an interpretable feature.

After feeding both the noise z and the latent code c to the generator, a plain GAN is free to ignore c and learn a generator distribution with P_G(x|c) = P_G(x). To counter this, InfoGAN computes the mutual information between the latent code c and the generator distribution G(z,c), and maximizes it.
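Concretely, in the MNIST experiment later in this post the generator input is a 72-dimensional vector: 62 noise dimensions plus a 10-way one-hot categorical code. The following is a minimal numpy sketch (my own illustration, mirroring sample_generator_input in the code below) of how such an input is assembled:

import numpy as np

batch_size, noise_dim, num_classes = 32, 62, 10

# z: incompressible noise
z = np.random.normal(0, 1, (batch_size, noise_dim))

# c: a categorical latent code, drawn uniformly and one-hot encoded
c = np.eye(num_classes)[np.random.randint(0, num_classes, batch_size)]

# the generator consumes the concatenation [z, c] (shape: batch_size x 72)
gen_input = np.concatenate((z, c), axis=1)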

III. InfoGAN architecture

InfoGAN differs from the GANs introduced earlier in that the real training data carries no labels; instead, the input is the combination of a latent code and random noise, and the latent code is recovered at the discriminator end by maximizing mutual information. In other words, the discriminator D must ultimately be able both to recover the latent code and to tell real from fake. The former ensures the generated images actually exhibit the properties carried by the code, i.e. the latent code has a clearly visible effect on the generator's output; the latter requires the generator, while preserving that information, to produce data that stays very close to the real data.

1. Mutual information

Mutual information measures how strongly two random variables depend on each other. For random variables X and Y, with marginal entropies H(X) and H(Y) and conditional entropies H(X|Y) and H(Y|X), the mutual information is

I(X;Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)
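To make the definition concrete, here is a minimal numpy sketch (not from the original post; the joint table is made up for illustration) that computes I(X;Y) for a small discrete joint distribution and checks it against H(X) - H(X|Y):

import numpy as np

# A made-up joint distribution P(X, Y) over 2 x 3 outcomes (rows: X, cols: Y).
p_xy = np.array([[0.20, 0.10, 0.10],
                 [0.05, 0.25, 0.30]])

p_x = p_xy.sum(axis=1)   # marginal P(X)
p_y = p_xy.sum(axis=0)   # marginal P(Y)

def entropy(p):
    p = p[p > 0]
    return -np.sum(p * np.log(p))

# I(X;Y) from the definition: sum over p(x,y) * log p(x,y) / (p(x) p(y))
mi = np.sum(p_xy * np.log(p_xy / np.outer(p_x, p_y)))

# H(X|Y) = H(X,Y) - H(Y), so I(X;Y) should equal H(X) - H(X|Y)
h_x_given_y = entropy(p_xy.ravel()) - entropy(p_y)
assert np.isclose(mi, entropy(p_x) - h_x_given_y)
print("I(X;Y) =", mi)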

2. Architecture

The architecture diagram from the original post is omitted here. In the implementation below, the discriminator D and the recognition network Q share one convolutional trunk and differ only in their output heads: D ends in a single sigmoid unit, Q in a softmax over the code classes.

3. Objective function

When X and Y are independent, the mutual information is 0. For any given input, we want the generator to yield a relatively small conditional entropy H(c | G(z,c)); that is, the information in the latent code c should not be lost during generation. We therefore modify the objective function to

\min_G \max_D V_I(D, G) = V(D, G) - \lambda I(c; G(z, c))

Because the posterior P(c|x) is hard to obtain, the mutual information is difficult to maximize directly. In practice we can define an auxiliary distribution Q(c|x) that approximates the posterior and obtain a lower bound on the mutual information. The derivation is as follows:

I(c; G(z,c)) = H(c) - H(c | G(z,c))
= E_{x \sim G(z,c)}[ E_{c' \sim P(c|x)}[ \log P(c'|x) ] ] + H(c)
= E_{x \sim G(z,c)}[ D_{KL}( P(\cdot|x) \| Q(\cdot|x) ) + E_{c' \sim P(c|x)}[ \log Q(c'|x) ] ] + H(c)
\geq E_{x \sim G(z,c)}[ E_{c' \sim P(c|x)}[ \log Q(c'|x) ] ] + H(c)

This gives the lower bound on the mutual information:

L_I(G, Q) = E_{c \sim P(c),\, x \sim G(z,c)}[ \log Q(c|x) ] + H(c) \leq I(c; G(z,c))
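As a sanity check (not from the original post; the joint distribution is made up), the following numpy sketch builds a small discrete joint P(c, x), evaluates the bound for an arbitrary Q and for Q = P(c|x), and confirms it sits below the true mutual information and is tight at Q = P(c|x):

import numpy as np

rng = np.random.default_rng(0)

# Made-up joint distribution P(c, x) over 4 codes x 6 observations.
p_cx = rng.random((4, 6))
p_cx /= p_cx.sum()

p_c = p_cx.sum(axis=0 * 0 + 1 - 1 + 1)  # marginal P(c): sum over x (axis=1)
p_c = p_cx.sum(axis=1)
p_x = p_cx.sum(axis=0)                   # marginal P(x)
post = p_cx / p_x                        # true posterior P(c|x); columns sum to 1

def bound(q):
    """L_I = E_{(c,x)~P}[log q(c|x)] + H(c) for a candidate posterior q."""
    return np.sum(p_cx * np.log(q)) - np.sum(p_c * np.log(p_c))

true_mi = np.sum(p_cx * np.log(p_cx / np.outer(p_c, p_x)))

q_uniform = np.full_like(post, 1.0 / 4)  # an arbitrary Q(c|x)
assert bound(q_uniform) <= true_mi + 1e-12
assert np.isclose(bound(post), true_mi)  # tight when Q = P(c|x)
print("MI =", true_mi, " bound(uniform Q) =", bound(q_uniform))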

4. Derivation of InfoGAN

We can rewrite the inequality above and approximate it with Monte Carlo methods: by Lemma 5.1 of the paper, E_{x \sim G(z,c)}[ E_{c' \sim P(c|x)}[ \log Q(c'|x) ] ] = E_{c \sim P(c),\, x \sim G(z,c)}[ \log Q(c|x) ], so L_I(G, Q) can be estimated from mini-batch samples without ever sampling the intractable posterior.

This gives our final objective function:

\min_{G, Q} \max_D V_{InfoGAN}(D, G, Q) = V(D, G) - \lambda L_I(G, Q)
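In the Keras implementation below, λ is effectively fixed at 1 because the combined model's two losses are summed with default weights. As a hedged sketch (my own variant, not the post's code; lambda_info is a name I introduce), Keras's Model.compile accepts a loss_weights argument that would expose λ explicitly inside INFOGAN.__init__:

# Hypothetical variant of the combined-model compile call shown in Section V:
# weight the mutual-information term by an explicit lambda (the original code
# uses the default weighting, i.e. lambda = 1).
lambda_info = 1.0
self.combined.compile(
    loss=['binary_crossentropy', self.mutual_info_loss],
    loss_weights=[1.0, lambda_info],  # [adversarial term, -L_I term]
    optimizer=optimizer)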

IV. Experimental results

1. MNIST

We find that controlling the categorical latent code selects which digit is generated, while the other (continuous) codes adjust properties such as the slant and stroke width of the characters.

2. 3D faces

3. Chairs dataset

4. Street View House Numbers (SVHN)

V. Experiment code

1. Imports and hyperparameters

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, concatenate
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
import keras.backend as K

import matplotlib.pyplot as plt
import numpy as np


class INFOGAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.num_classes = 10
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 72  # 62 noise dims + 10-dim one-hot categorical code

        optimizer = Adam(0.0002, 0.5)
        losses = ['binary_crossentropy', self.mutual_info_loss]

        # Build the discriminator and recognition network
        self.discriminator, self.auxilliary = self.build_disk_and_q_net()

        self.discriminator.compile(loss=['binary_crossentropy'],
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build and compile the recognition network Q
        self.auxilliary.compile(loss=[self.mutual_info_loss],
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label
        gen_input = Input(shape=(self.latent_dim,))
        img = self.generator(gen_input)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated image as input and determines validity
        valid = self.discriminator(img)
        # The recognition network produces the label
        target_label = self.auxilliary(img)

        # The combined model (stacked generator and discriminator)
        self.combined = Model(gen_input, [valid, target_label])
        self.combined.compile(loss=losses, optimizer=optimizer)

2. Building the generator and discriminator

    def build_generator(self):

        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=3, padding='same'))
        model.add(Activation("tanh"))

        gen_input = Input(shape=(self.latent_dim,))
        img = model(gen_input)

        model.summary()

        return Model(gen_input, img)

    def build_disk_and_q_net(self):

        img = Input(shape=self.img_shape)

        # Shared layers between discriminator and recognition network
        model = Sequential()
        model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())

        img_embedding = model(img)

        # Discriminator head: real/fake probability
        validity = Dense(1, activation='sigmoid')(img_embedding)

        # Recognition head: recover the categorical code c
        q_net = Dense(128, activation='relu')(img_embedding)
        label = Dense(self.num_classes, activation='softmax')(q_net)

        # Return discriminator and recognition network
        return Model(img, validity), Model(img, label)

3. Mutual information loss

    def mutual_info_loss(self, c, c_given_x):
        """The mutual information metric we aim to minimize"""
        eps = 1e-8
        conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))
        entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))

        return conditional_entropy + entropy

    def sample_generator_input(self, batch_size):
        # Generator inputs: 62-dim Gaussian noise z and a one-hot categorical code c
        sampled_noise = np.random.normal(0, 1, (batch_size, 62))
        sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)
        sampled_labels = to_categorical(sampled_labels, num_classes=self.num_classes)

        return sampled_noise, sampled_labels
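Note the connection to L_I: the conditional_entropy term is the categorical cross-entropy between the sampled code c and the recognition network's output Q(c|x), so minimizing it maximizes the sample estimate of E[log Q(c|x)]. Because c is one-hot here, the entropy term evaluates to roughly zero rather than H(c), so it is a constant that does not affect training. A quick numpy check of that claim (my own illustration, not from the post):

import numpy as np

eps = 1e-8
c = np.eye(10)[np.random.randint(0, 10, 32)]  # one-hot codes, as in training
entropy_term = np.mean(-np.sum(np.log(c + eps) * c, axis=1))
print(entropy_term)  # ~0: each row contributes -log(1 + eps); zero entries contribute nothing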

4. Training

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise and categorical labels
            sampled_noise, sampled_labels = self.sample_generator_input(batch_size)
            gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(gen_input)

            # Train on real and generated data
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)

            # Avg. loss
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator and Q-network
            # ---------------------

            g_loss = self.combined.train_on_batch(gen_input, [valid, sampled_labels])

            # Plot the progress
            print("%d [D loss: %.2f, acc.: %.2f%%] [Q loss: %.2f] [G loss: %.2f]" % (
                epoch, d_loss[0], 100 * d_loss[1], g_loss[1], g_loss[2]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

5. Visualization and model saving

    def sample_images(self, epoch):
        r, c = 10, 10
        fig, axs = plt.subplots(r, c)
        for i in range(c):
            sampled_noise, _ = self.sample_generator_input(c)
            label = to_categorical(np.full(fill_value=i, shape=(r, 1)), num_classes=self.num_classes)
            gen_input = np.concatenate((sampled_noise, label), axis=1)
            gen_imgs = self.generator.predict(gen_input)
            gen_imgs = 0.5 * gen_imgs + 0.5
            for j in range(r):
                axs[j, i].imshow(gen_imgs[j, :, :, 0], cmap='gray')
                axs[j, i].axis('off')
        fig.savefig("images/%d.png" % epoch)
        plt.close()

    def save_model(self):

        def save(model, model_name):
            model_path = "saved_model/%s.json" % model_name
            weights_path = "saved_model/%s_weights.hdf5" % model_name
            options = {"file_arch": model_path,
                       "file_weight": weights_path}
            json_string = model.to_json()
            open(options['file_arch'], 'w').write(json_string)
            model.save_weights(options['file_weight'])

        save(self.generator, "generator")
        save(self.discriminator, "discriminator")


if __name__ == '__main__':
    infogan = INFOGAN()
    infogan.train(epochs=50000, batch_size=128, sample_interval=50)

Results

The generated-sample figures from the original post are omitted here. During training, sample_images saves a 10x10 grid to images/ every sample_interval steps, with one categorical code per column.

