
Generating Handwritten Digit Images with a GAN

Date: 2022-02-18 09:47:31


Keras is great. I used to write all my neural networks in TensorFlow, but since switching to Keras I've found it really convenient.

So far, Keras's main strength seems to be quickly assembling networks out of standard components: fully connected layers, convolutional layers, pooling layers, and so on. For experiments that need to innovate on the network itself, Keras may not be so convenient, and TensorFlow is still the way to go.

In this post, I'll use Keras to write a simple generative adversarial network (GAN).

The GAN's goal is to generate handwritten digit images.

First, a look at the experimental results:

At epoch = 1000:

At epoch = 10000: the digit 1 already looks roughly right.

At epoch = 60000: the digit 1 is very clear, and the other digits are getting clearer too.

At epoch = 80000: digits like 5 and 7 are being generated.

As training progresses, the generated digits become more and more realistic.

The code is open source; project repository:

/jmhIcoding/GAN_MNIST.git

Model principle

I won't dwell on the theory: this is the most basic GAN architecture.

The model consists of a generator and a discriminator.

The generator takes noise as input and produces a handwritten-digit image.

The discriminator judges whether an image fed to it was synthesized by the generator.

The generator's goal is to produce images that the discriminator classifies as not synthesized.

The discriminator's goal is to classify images as real or synthesized with as high an accuracy as possible.

That is the whole principle.

The model's loss functions are built around these two goals.
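Concretely, the two goals correspond to the standard GAN minimax objective (this is the textbook formulation, not anything specific to this repository):

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

In practice the generator is usually trained to maximize \(\log D(G(z))\) rather than minimize \(\log(1 - D(G(z)))\), which is exactly what labeling generated samples as "real" (1) when training the combined model achieves.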

Building the model

Generator

__author__ = 'dk'
# Generator
import sys
import numpy as np
import keras
from keras import layers
from keras import models
from keras import optimizers
from keras import losses

class Generator:
    def __init__(self, height=28, width=28, channel=1, latent_space_dimension=100):
        '''
        :param height: height of the generated images; 28 for MNIST
        :param width: width of the generated images; 28 for MNIST
        :param channel: number of channels of the generated images; 1 for MNIST grayscale
        :param latent_space_dimension: dimensionality of the input noise
        '''
        self.latent_space_dimension = latent_space_dimension
        self.height = height
        self.width = width
        self.channel = channel
        self.generator = self.build_model()
        self.generator.summary()

    def build_model(self, block_starting_size=128, num_blocks=4):
        model = models.Sequential(name='generator')
        for i in range(num_blocks):
            if i == 0:
                model.add(layers.Dense(block_starting_size,
                                       input_shape=(self.latent_space_dimension,)))
            else:
                block_size = block_starting_size * (2 ** i)
                model.add(layers.Dense(block_size))
                model.add(layers.LeakyReLU())
                model.add(layers.BatchNormalization(momentum=0.75))
        model.add(layers.Dense(self.height * self.width * self.channel, activation='tanh'))
        model.add(layers.Reshape((self.width, self.height, self.channel)))
        return model

    def summary(self):
        self.generator.summary()

    def save_model(self):
        self.generator.save("generator.h5")

Note that the generator is trained as part of the combined GAN model, so it does not need to be compiled on its own.

Discriminator

__author__ = 'dk'
# Discriminator
import sys
import os
import keras
from keras import layers
from keras import optimizers
from keras import models
from keras import losses

class Discriminator:
    def __init__(self, height=28, width=28, channel=1):
        '''
        :param height: input image height
        :param width: input image width
        :param channel: number of input image channels
        '''
        self.height = height
        self.width = width
        self.channel = channel
        OPTIMIZER = optimizers.Adam()
        self.discriminator = self.build_model()
        # The discriminator is trained on its own, so it is compiled here
        self.discriminator.compile(optimizer=OPTIMIZER,
                                   loss=losses.binary_crossentropy,
                                   metrics=['accuracy'])
        self.discriminator.summary()

    def build_model(self):
        model = models.Sequential(name='discriminator')
        model.add(layers.Flatten(input_shape=(self.width, self.height, self.channel)))
        model.add(layers.Dense(self.height * self.width * self.channel))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(self.height * self.width * self.channel // 2))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(1, activation='sigmoid'))
        return model

    def summary(self):
        return self.discriminator.summary()

    def save_model(self):
        self.discriminator.save("discriminator.h5")

The GAN network

Combine the generator and the discriminator into one model:

__author__ = 'dk'
# Generative adversarial network
import keras
from keras import layers
from keras import optimizers
from keras import losses
from keras import models
import sys
import os
from Discriminator import Discriminator
from Generator import Generator

class GAN:
    def __init__(self, latent_space_dimension, height, width, channel):
        self.generator = Generator(height, width, channel, latent_space_dimension)
        self.discriminator = Discriminator(height, width, channel)
        # In the combined GAN model only the generator is trained; the
        # discriminator is trained via explicit discriminator.train_on_batch calls
        self.discriminator.discriminator.trainable = False
        self.gan = self.build_model()
        OPTIMIZER = optimizers.Adamax()
        self.gan.compile(optimizer=OPTIMIZER, loss=losses.binary_crossentropy)
        self.gan.summary()

    def build_model(self):
        model = models.Sequential(name='gan')
        model.add(self.generator.generator)
        model.add(self.discriminator.discriminator)
        return model

    def summary(self):
        self.gan.summary()

    def save_model(self):
        self.gan.save("gan.h5")

Data preparation module

__author__ = 'dk'
# Dataset loader: a thin wrapper around MNIST
from keras.datasets import mnist
import numpy as np

def sample_latent_space(instances_number, latent_space_dimension):
    return np.random.normal(0, 1, (instances_number, latent_space_dimension))

class Dator:
    def __init__(self, batch_size=None, model_type=1):
        '''
        :param batch_size:
        :param model_type: -1 selects all digits 0-9; e.g. model_type=2 selects only the digit 2
        '''
        self.batch_size = batch_size
        self.model_type = model_type
        with np.load("mnist.npz", allow_pickle=True) as f:
            X_train, y_train = f['x_train'], f['y_train']
            # X_test, y_test = f['x_test'], f['y_test']
        if model_type != -1:
            X_train = X_train[np.where(y_train == model_type)[0]]
        if batch_size is None:
            self.batch_size = X_train.shape[0]
        else:
            self.batch_size = batch_size
        # Scale pixels into roughly [-1, 1] to match the generator's tanh output
        self.X_train = (np.float32(X_train) - 128) / 128.0
        self.X_train = np.expand_dims(self.X_train, 3)
        self.watch_index = 0
        self.train_size = self.X_train.shape[0]

    def next_batch(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        X = np.concatenate([self.X_train[self.watch_index:(self.watch_index + batch_size)],
                            self.X_train[:batch_size]])[:batch_size]
        self.watch_index = (self.watch_index + batch_size) % self.train_size
        return X

if __name__ == '__main__':
    print(sample_latent_space(5, 4))
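As a quick sanity check on this module's conventions — the noise shape returned by `sample_latent_space`, and pixels rescaled into roughly [-1, 1] to match the generator's tanh output — here is a small NumPy-only sketch. The uint8 array is random stand-in data, not real MNIST:

```python
import numpy as np

def sample_latent_space(instances_number, latent_space_dimension):
    # Same sampler as in data_utils: standard normal noise
    return np.random.normal(0, 1, (instances_number, latent_space_dimension))

noise = sample_latent_space(16, 100)
print(noise.shape)        # (16, 100)

# Stand-in for MNIST pixels (uint8 in [0, 255])
X = np.random.randint(0, 256, size=(8, 28, 28), dtype=np.uint8)
X_scaled = (np.float32(X) - 128) / 128.0   # maps 0 -> -1.0 and 255 -> ~0.992
X_scaled = np.expand_dims(X_scaled, 3)     # add the trailing channel axis
print(X_scaled.shape)     # (8, 28, 28, 1)
```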

Main training script: train.py

__author__ = 'dk'
# Model training code
from GAN import GAN
from data_utils import Dator, sample_latent_space
import numpy as np
from matplotlib import pyplot as plt
import time

epochs = 50000
height = 28
width = 28
channel = 1
latent_space_dimension = 100
batch = 128

dator = Dator(batch_size=batch, model_type=-1)
gan = GAN(latent_space_dimension, height, width, channel)
image_index = 0
for i in range(epochs):
    real_img = dator.next_batch(batch_size=batch * 2)
    real_label = np.ones(shape=(real_img.shape[0], 1))    # real samples are labeled 1
    noise = sample_latent_space(real_img.shape[0], latent_space_dimension)
    fake_img = gan.generator.generator.predict(noise)
    fake_label = np.zeros(shape=(fake_img.shape[0], 1))   # generated images are labeled 0
    # Assemble the batch for the discriminator
    x_batch = np.concatenate([real_img, fake_img])
    y_batch = np.concatenate([real_label, fake_label])
    # One discriminator update; the generator is not touched here
    discriminator_loss = gan.discriminator.discriminator.train_on_batch(x_batch, y_batch)[0]
    # Assemble the batch for the generator: its goal is to make the
    # discriminator label its images as 1 ("real")
    noise = sample_latent_space(batch * 2, latent_space_dimension)
    noise_labels = np.ones((batch * 2, 1))
    generator_loss = gan.gan.train_on_batch(noise, noise_labels)
    print('Epoch : {0}, [Discriminator Loss:{1} ], [Generator Loss:{2}]'.format(
        i, discriminator_loss, generator_loss))
    if i != 0 and (i % 50) == 0:
        # Every 50 steps, render 16 samples to check progress
        noise = sample_latent_space(16, latent_space_dimension)
        images = gan.generator.generator.predict(noise)
        plt.figure(figsize=(10, 10))
        plt.suptitle('epoch={0}'.format(i), fontsize=16)
        for index in range(images.shape[0]):
            plt.subplot(4, 4, index + 1)
            image = images[index, :, :, :]
            image = image.reshape(height, width)
            plt.imshow(image, cmap='gray')
        # plt.tight_layout()
        plt.savefig("./show_time/{0}.png".format(time.time()))
        image_index += 1
        plt.close()
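The two `train_on_batch` targets above both reduce to binary cross-entropy: the discriminator sees labels 1 for real and 0 for fake, while the generator is scored against all-ones labels. A NumPy sketch of the resulting loss values; the discriminator scores `d_real` and `d_fake` are made-up illustrative numbers, not outputs of the actual model:

```python
import numpy as np

def bce(p, y):
    # Binary cross-entropy, averaged over the batch
    p = np.clip(p, 1e-7, 1 - 1e-7)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))

# Hypothetical discriminator outputs on a tiny batch
d_real = np.array([0.9, 0.8])   # scores on real images (should be near 1)
d_fake = np.array([0.2, 0.3])   # scores on generated images (should be near 0)

# Discriminator loss: push real scores toward 1 and fake scores toward 0
d_loss = 0.5 * (bce(d_real, np.ones(2)) + bce(d_fake, np.zeros(2)))

# Generator loss: make the discriminator score the fakes as 1
g_loss = bce(d_fake, np.ones(2))

print(d_loss, g_loss)
```

With these scores the discriminator loss is small while the generator loss is large, which mirrors the log output further down: a discriminator that is currently winning leaves the generator with the bigger loss.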

Running the script

python3 train.py

is all it takes.

Output:

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_1 (Dense)              (None, 128)               12928
dense_2 (Dense)              (None, 256)               33024
leaky_re_lu_1 (LeakyReLU)    (None, 256)               0
batch_normalization_1 (Batch (None, 256)               1024
dense_3 (Dense)              (None, 512)               131584
leaky_re_lu_2 (LeakyReLU)    (None, 512)               0
batch_normalization_2 (Batch (None, 512)               2048
dense_4 (Dense)              (None, 1024)              525312
leaky_re_lu_3 (LeakyReLU)    (None, 1024)              0
batch_normalization_3 (Batch (None, 1024)              4096
dense_5 (Dense)              (None, 784)               803600
reshape_1 (Reshape)          (None, 28, 28, 1)         0
=================================================================
Total params: 1,513,616
Trainable params: 1,510,032
Non-trainable params: 3,584
_________________________________________________________________

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_2 (Flatten)          (None, 784)               0
dense_9 (Dense)              (None, 784)               615440
leaky_re_lu_6 (LeakyReLU)    (None, 784)               0
dense_10 (Dense)             (None, 392)               307720
leaky_re_lu_7 (LeakyReLU)    (None, 392)               0
dense_11 (Dense)             (None, 1)                 393
=================================================================
Total params: 923,553
Trainable params: 923,553
Non-trainable params: 0
_________________________________________________________________

Model: "gan"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
generator (Sequential)       (None, 28, 28, 1)         1513616
discriminator (Sequential)   (None, 1)                 923553
=================================================================
Total params: 2,437,169
Trainable params: 1,510,032
Non-trainable params: 927,137
_________________________________________________________________

...
Epoch : 117754, [Discriminator Loss:0.22975191473960876 ], [Generator Loss:2.57688570022583]
Epoch : 117755, [Discriminator Loss:0.26782122254371643 ], [Generator Loss:3.1791584491729736]
Epoch : 117756, [Discriminator Loss:0.2609345614910126 ], [Generator Loss:2.960988998413086]
Epoch : 117757, [Discriminator Loss:0.2673880159854889 ], [Generator Loss:2.317220687866211]
Epoch : 117758, [Discriminator Loss:0.24904575943946838 ], [Generator Loss:1.929720401763916]
Epoch : 117759, [Discriminator Loss:0.25158950686454773 ], [Generator Loss:2.954155683517456]
Epoch : 117760, [Discriminator Loss:0.20324105024337769 ], [Generator Loss:3.5244760513305664]
Epoch : 117761, [Discriminator Loss:0.2849388122558594 ], [Generator Loss:3.195873498916626]
Epoch : 117762, [Discriminator Loss:0.19631560146808624 ], [Generator Loss:2.328411340713501]
Epoch : 117763, [Discriminator Loss:0.20523831248283386 ], [Generator Loss:2.402683973312378]
Epoch : 117764, [Discriminator Loss:0.2625979781150818 ], [Generator Loss:3.2176101207733154]
Epoch : 117765, [Discriminator Loss:0.29969191551208496 ], [Generator Loss:2.9656052589416504]
Epoch : 117766, [Discriminator Loss:0.270328551530838 ], [Generator Loss:2.3880398273468018]
Epoch : 117767, [Discriminator Loss:0.26741161942481995 ], [Generator Loss:2.7729406356811523]
Epoch : 117768, [Discriminator Loss:0.28797847032546997 ], [Generator Loss:2.8959264755249023]
Epoch : 117769, [Discriminator Loss:0.30181047320365906 ], [Generator Loss:2.791097402572632]
Epoch : 117770, [Discriminator Loss:0.26939862966537476 ], [Generator Loss:2.3666043281555176]
Epoch : 117771, [Discriminator Loss:0.26297527551651 ], [Generator Loss:2.895970582962036]
Epoch : 117772, [Discriminator Loss:0.21928083896636963 ], [Generator Loss:3.4627976417541504]
Epoch : 117773, [Discriminator Loss:0.3553962707519531 ], [Generator Loss:3.2194197177886963]
Epoch : 117774, [Discriminator Loss:0.32673510909080505 ], [Generator Loss:2.473867893218994]
Epoch : 117775, [Discriminator Loss:0.31245478987693787 ], [Generator Loss:2.999265193939209]
Epoch : 117776, [Discriminator Loss:0.29536381363868713 ], [Generator Loss:3.733344554901123]
Epoch : 117777, [Discriminator Loss:0.2955515682697296 ], [Generator Loss:3.2467658519744873]
Epoch : 117778, [Discriminator Loss:0.3677394986152649 ], [Generator Loss:1.8517814874649048]
Epoch : 117779, [Discriminator Loss:0.31648850440979004 ], [Generator Loss:2.6385254859924316]
Epoch : 117780, [Discriminator Loss:0.31941041350364685 ], [Generator Loss:3.350475311279297]
Epoch : 117781, [Discriminator Loss:0.47521263360977173 ], [Generator Loss:1.9556307792663574]
Epoch : 117782, [Discriminator Loss:0.44070643186569214 ], [Generator Loss:1.9684114456176758]
