700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > GAN 生成MNIST数据集

GAN 生成MNIST数据集

时间:2023-08-23 00:48:06

相关推荐

GAN 生成MNIST数据集

1、GAN是什么

GAN(生成式对抗网络,Generative Adversarial Networks)是一种深度学习模型,模型通过框架中两个模块(生成模型 Generative Model 和判别模型 Discriminative Model)的互相博弈学习,从而产生相当好的输出。原始GAN理论中,只要求G和D能拟合出相应生成和判别的函数即可,而并不要求他们必须都是神经网络,但是我们的实际应用中,一般都是采用深度神经网络作为G和D。

GAN论文:/abs/1406.2661

2、GAN的原理

@基本原理

GAN分为一个判别器(Discriminator,简称D)和一个生成器(Generator,简称G),简单的说,G和D就是两个多层感知机或卷积神经网络,它的基本思想,即为G和D的生成博弈过程。

G是一个生成图片的网络,它接收一个随机的噪声z,并且通过这个噪声生成图片,记做G(z)

D是一个进行判别的网络,它可以判别出一张图片是不是真实的。即给D输入真图片,它会将label赋值为1,输入假图片,就将label赋值为0

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D,使D认为自己生成的是真图片;而D的目标就是尽量将G生成的图片和真实的图片区别开来,这样G和D就形成了一个动态的博弈过程。

那么,最后博弈的结果是什么呢?在理想的状态下,G足以生成以假乱真的图片G(z),对于D来说,它难以判定G生成的图片究竟是不是真实的,因此有D(G(z))=0.5

具体的流程如图中所示。

首先有一个一代的G,它生成的是一些很差的图片然后有一个一代的D,它能很准确的把G生成的假图片和真实的图片区分出来,打上标签0。其实这个D就是一个二分类器,对生成的图片输出0,而对真实的图片输出1。

接着经过训练,出现了二代的G,它能生成稍好一点的图片,能让一代的D认为他生成的是真图片。这时也出现了二代的D,它能识别哪些图片是G生成的,哪些是真实的图片。

以此类推,会有三代,四代。。。。n 代的 G(generator) 和D( discriminator),最后 D 无法分辨生成的图片和真实图片,这个网络就拟合了。

这两种网络具体是怎样的呢?

@Discriminator Network

首先要说的是对抗网络,因为这个网络相对较为简单。

对抗网络简单来说就是一个判断真假的判别器,解决的是一个二分类的问题。输入一张真的图片时我们希望它的输出结果是1,输入一张假的图片我们希望它能输出0。这其实和原图片的类别没有什么关系,无论原图片是什么类别的图片,我们都统称它为真图片,label为1;而生成的图片它是假的,label为0

我们对D训练的过程就是希望这个判别器能准确地判别出真的图片和假的图片,对于这个二分类问题可以有很多解决的方法,比如 logistic回归,深层网络,卷积神经网络,循环神经网络都可以。

# 判别网络class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Linear(784, 256),nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),# sigmoid激活函数得到一个0到1之间的概率进行二分类nn.Sigmoid()) def forward(self, x):x = self.dis(x)return x

@Generative Network

怎样才能生成一张假的图片呢?

首先给出一个简单高维的正态分布的噪声向量,如上图所示的D-dimensional noise vector,通过对它进行仿射变换,也就是将 xw+b 映射到一个更高的维度,然后将它重新排列成一个矩形,这样看起来就更加像一张图片,然后经过一系列的卷积、池化、激活操作,最后得到了一个与我们输入图片大小一模一样的噪声矩阵,这就是我们所说的假的图片。这个时候是怎样训练我们的生成器呢?其实通过判别器来得到结果,我们不断增大判别器识别生成图片为真的概率,在这一步我们不会更新判别器的参数,只会更新生成器的参数。

# 生成网络class generator(nn.Module):def __init__(self, input_size):super(generator, self).__init__()self.gen = nn.Sequential(nn.Linear(input_size, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 784),#Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间nn.Tanh())def forward(self, x):x = self.gen(x)return x

3、训练 Train

@判别器训练

判别器的训练由两部分构成,分别是真的图片判别为真,假的图片判别为假,而在这个过程中,生成器的参与不参与更新

首先是定义loss函数的度量方式和优化函数,loss度量使用二分类的交叉熵,优化函数要注意使用的学习率是0.0003

criterion = nn.BCELoss()d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

然后进入训练

img = img.view(num_img, -1)# 将图片展开乘28x28=784real_img = Variable(img).cuda() # 将tensor变成Variable放入计算图中real_label = Variable(torch.ones(num_img)).cuda() # 定义真的label为1fake_label = Variable(torch.zeros(num_img)).cuda() # 定义假的label为0# 计算真图片的lossreal_out = D(real_img) # 将真实的图片放入判别器中d_loss_real = criterion(real_out, real_label) # 得到真实图片的loss real_scores = real_out # 真实图片放入判别器输出越接近1越好# 计算假图片的lossz = Variable(torch.randn(num_img, z_dimension)).cuda() # 随机生成一些噪声fake_img = G(z)# 将一张假的图片放入生成网络中fake_out = D(fake_img) # 判别器判断假的图片d_loss_fake = criterion(fake_out, fake_label) # 得到假的图片的lossfake_scores = fake_out # 假的图片放入判别器越接近0越好# bp and optimized_loss = d_loss_real + d_loss_fake # 将真假图片的loss加起来d_optimizer.zero_grad() # 归0梯度d_loss.backward() # 反向传播d_optimizer.step() # 更新参数

@生成器训练

在生成网络的训练过程中,我们生成假的图片,但是希望判别器能将它识别为真的图片。我们将判别器固定,将假的图片传入判别器的结果与真实label对应,反向传播更新的参数是生成网络里面的参数,这样我们就可以通过更新生成网络里面的参数来使得判别器判断生成的假的图片为真,这样就达到了生成对抗的作用。

# 计算假图片的lossz = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到随机噪声fake_img = G(z) # 生成假的图片output = D(fake_img) # 经过判别器得到结果g_loss = criterion(output, real_label) # 得到假的图片与真实图片label的loss# bp and optimizeg_optimizer.zero_grad() # 归0梯度g_loss.backward() # 反向传播g_optimizer.step() # 更新生成网络的参数

4、全部代码 (pytorch实现)

贴上程序完整的代码:

import torchimport torchvisionimport torch.nn as nnimport torch.nn.functional as Ffrom torchvision import datasetsfrom torchvision import transformsfrom torchvision.utils import save_imagefrom torch.autograd import Variableimport osif not os.path.exists('img'):os.mkdir('img')def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return out# 初始化参数batch_size = 128num_epoch = 50z_dimension = 100# 对图片进行一些前期处理img_transform = pose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])# img_transform = pose([# transforms.ToTensor(),# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))# ]# 下载数据集mnist = datasets.MNIST(root='mnist_data', train=True, transform=img_transform, download=False)# 加载数据集dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)# 判别网络class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Linear(784, 256),nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2), nn.Linear(256, 1),nn.Sigmoid()) # sigmoid激活函数得到一个0到1之间的概率进行二分类def forward(self, x):x = self.dis(x)return x# 生成器class generator(nn.Module):def __init__(self):super(generator, self).__init__()self.gen = nn.Sequential(nn.Linear(100, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 784),nn.Tanh()) # Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。def forward(self, x):x = self.gen(x)return xD = discriminator()G = generator()if torch.cuda.is_available():D = D.cuda()G = G.cuda()# 判别器的训练由两部分组成,第一部分是真的图像判别为真,第二部分是假的图片判别为假,在这两个过程中,生成器的参数不参与更新。# 二进制交叉熵损失和优化器criterion = nn.BCELoss()d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)# 开始训练for epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):num_img = img.size(0)# ================================训练判别器===================================img = img.view(num_img, -1) # # 将图片展开乘28x28=784# real_img = Variable(img).cuda()# real_label = Variable(torch.ones(num_img)).cuda()# fake_label = Variable(torch.zeros(num_img)).cuda()real_img = Variable(img)real_label = Variable(torch.ones(num_img)) # 定义真实label为1fake_label = Variable(torch.zeros(num_img)) # 定义假label为1# 计算 real_img 的损失real_out = D(real_img) # 将真实的图片放入判别器中d_loss_real = criterion(real_out, real_label) # 得到真实图片的lossreal_scores = real_out # 越接近一越好# 计算 fake_img的损失# z = Variable(torch.randn(num_img, z_dimension)).cuda()z = Variable(torch.randn(num_img, z_dimension)) # 随机生成一些噪声fake_img = G(z) # 放入生成网络生成一张假的图片fake_out = D(fake_img) ## 判别器判断假的图片d_loss_fake = criterion(fake_out, fake_label) ## 得到假的图片的lossfake_scores = fake_out # 越接近0越好# 反向传播和优化d_loss = d_loss_real + d_loss_fake # 将真假图片的loss加起来d_optimizer.zero_grad() # 每次梯度归零d_loss.backward() # 反向传播d_optimizer.step() # 更新参数# =================================训练生成器================================# 计算fake_img损失# z = Variable(torch.randn(num_img, z_dimension)).cuda()z = Variable(torch.randn(num_img, z_dimension)) # 得到随机噪声fake_img = G(z) # 生成假的图片output = D(fake_img) # 经过判别器得到结果g_loss = criterion(output, real_label) ##得到假的图片与真实图片label的loss# 反向传播和优化g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f},D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epoch, d_loss.item(), g_loss.item(),real_scores.data.mean(), fake_scores.data.mean()))if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, 'real_images.png')fake_images = to_img(fake_img.cpu().data)save_image(fake_images, 'fake_images-{}.png'.format(epoch + 1))torch.save(G.state_dict(), 'generator.pth')torch.save(D.state_dict(), 'discriminator.pth')

5、结果 Result

结果展示:

随着epoch的增加,可以发现产生的噪声更少了,训练也更加稳定,图片中的数字也从模糊逐渐变为清晰,epoch-49中的图片简直就像真的图片一样。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。