700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > pytorch学习之GAN生成MNIST手写数字

pytorch学习之GAN生成MNIST手写数字

时间:2019-09-28 08:34:29

相关推荐

pytorch学习之GAN生成MNIST手写数字

0.简单介绍:

学深度学习的人必然知道,最基本的GAN模型由一个生成器 G 和判别器 D 组成。生成器用于生成假样本,判别器用于判断样本是真实的还是假的。

在整个训练过程中,生成器努力地让生成的图像更加真实,而判别器则努力地去识别出图像的真假,这个过程相当于一个二人博弈,随着时间的推移,生成器和判别器在不断地进行对抗,最终期望两个网络达到一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近0.5(相当于随机猜测类别)。

以下是作为初学者的我 了解GAN的结构和运作机制的代码:

1.必要的库函数:

import argparseimport osimport numpy as npimport mathimport torchvision.transforms as transformsfrom torchvision.utils import save_imagefrom torch.utils.data import DataLoaderfrom torchvision import datasetsfrom torch.autograd import Variableimport torch.nn as nnimport torch.nn.functional as Fimport torchos.makedirs("images", exist_ok=True) #生成的数字最后会放到image文件夹,没有则创建

2.参数设置

使用argparse模块主要用来为脚本传递命令参数功能**,代码更加灵活:**

parser = argparse.ArgumentParser() #创建一个参数对象#调用 add_argument() 方法给 ArgumentParser对象添加程序所需的参数信息parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")parser.add_argument("--channels", type=int, default=1, help="number of image channels")parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")opt = parser.parse_args() # parse_args()返回我们定义的参数字典print(opt)img_shape = (opt.channels, opt.img_size, opt.img_size) #(1,28,28)cuda = True if torch.cuda.is_available() else False #是否使用“cuda”

3.定义生成器:

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)] #该例只是用全连接层,未卷积if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8)) #BatchNorm:在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布,momentum=0.8layers.append(nn.LeakyReLU(0.2, inplace=True)) #inplace = True ,直接覆盖原输入数据的值return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),#opt.latent_dim,100维的随机噪声*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))), #np.prod(img_shape),返回1*28*28nn.Tanh() #使用Tanh()激活函数)def forward(self, z): #前向传播img = self.model(z)#高斯噪声信号z调用model(),model()调用block(),完成生成图像操作img = img.view(img.size(0), *img_shape)#img.size(0)为784的像素,转换为(1,28,28)的图像return img

3.1通过输入噪声图片,generator输出一个与真实图片一样大小的图像

3.2.生成器生成图像只用了全连接层哦,没有进行复杂的卷积操作

3.3隐层激活函数采用的是Leaky ReLU,了解各种激活函数,可参考:/p/88429934?from_voters_page=true

3.4在输出层我们使用tanh函数,这是因为tanh在这里相比sigmoid的结果会更好一点

4.定义判别器:

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1) validity = self.model(img_flat)return validity

4.1判别器接收一张图片,输出层为1个结点,输出为1的概率

4.2同样隐层使用了Leaky ReLU

5.定义损失函数,初始化一个生成器和一个判别器对象,加载数据,定义优化器

# Loss functionadversarial_loss = torch.nn.BCELoss()# Initialize generator and discriminatorgenerator = Generator()discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Configure data loaderos.makedirs("./data/mnist", exist_ok=True) #不存在则下载dataloader = torch.utils.data.DataLoader(datasets.MNIST("./data/mnist",train=True,download=True,transform=pose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] #transforms.Resize重置图像分辨率),),batch_size=opt.batch_size, #一个batch:128shuffle=True,)# Optimizersoptimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) #Betas是动量梯度的下降Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

6.训练网络

对于生成器来说,传给辨别器的生成图片,生成器希望辨别器打上标签1。因为它要不断训练减小损失,以期望骗过判别器。

对于判别器来说,给定的真实图片,辨别器要为其打上标签1;给定的生成图片,辨别器要为其打上标签0;它要能够识别真假。

for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) #fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) ## Configure inputreal_imgs = Variable(imgs.type(Tensor))# -----------------# Train Generator# -----------------optimizer_G.zero_grad()#梯度置0# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) #输入从0到1之间,形状为imgs.shape[0], opt.latent_dim的随机高斯数据。# Generate a batch of imagesgen_imgs = generator(z)#生成图像# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid) #计算生成图像的损失,一开始很大g_loss.backward() #算梯度optimizer_G.step() #更新权重# ---------------------# Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid) #计算真实图像的损失fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) #计算生成图像的损失#noise 从 generator 输入,到discriminator 输出,计算 generator 损失,回传,这一步更新了 generator 的参数,并释放了计算图。# 下一步更新 discriminator 的参数时,generator 的输出经过 detach 后,又通过了一遍 discriminator,相当于,generator 的输出前后两次通过了 discriminator ,得到相同的输出d_loss = (real_loss + fake_loss) / 2 #计算判别器的损失d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:#已完成的batch是400的倍数save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) #将生成的图片的25张保存下来

6.1生成器端,g_loss表示它希望让判别器对自己生成的图片尽可能输出为1,相当于它在于判别器进行对抗。

6.2判别器端,real_loss对应着真实图片的loss,它尽可能让判别器的输出接近于1,real_loss与 fake_loss加起来就是整个判别器的损失。

6.3 我的例子是先训练生成器,再训练判别器。有的方案是反过来的,可以自己找来参考。

看看我们生成的数据:

马马虎虎。代码清楚了。

深部学习小白,欢迎交流。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。