700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > GAN —— 《Generative Adversarial Nets》

GAN —— 《Generative Adversarial Nets》

时间:2022-03-20 05:08:12

相关推荐

GAN —— 《Generative Adversarial Nets》

《Generative Adversarial Nets》

生成式对抗网络;作者:lan Goodfellow;单位:加拿大蒙特利尔大学;发表会议及时间:NeurlPS(NIPS) ;

核心要点

提出了一个基于对抗的 新生成式模型,由一个生成器和一个判别器组成;生成器的目标是学习到样本的数据分布,从而能生成样本欺骗判别器;判别器的目标是判断输入样本时生成/真实的概率;GAN模型等同于博弈论中的二人零和博弈;对于任意的生成器和判别器,都存在一个独特的全局最优解;在本文中,生成器和判别器都是由多层感知机实现,整个网络可以用反向传播算法来训练;通过实验的定性与定量分析显示,GAN具备很大的潜力;

研究背景

1、零和博弈

一方的收益必然意味着另一方的损失,博弈各方的收益和损失相加总和永远为“零”,双方不存在合作的可能;在零和博弈中,为了使己方达到最优解,所以把目标设为让对方的最大化收益最小化;

2、使用数据集

MNIST:手写数据集,源自NIST;28*28的灰度图,训练集60000张,测试集10000张;

TFD:The Toronro face dataset,人脸数据集;

CIFAR-10:32*32彩图,10个类别,每类6000张图,训练集50000张,测试集10000张;

3、GAN价值函数

价值函数

minGmaxDV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]min_G max_D V(D,G)=E_{x\sim p_{data}(x)}[log D(x)]+E_{z\sim p_z(z)}[log(1-D(G(z)))]minG​maxD​V(D,G)=Ex∼pdata​(x)​[logD(x)]+Ez∼pz​(z)​[log(1−D(G(z)))]

datadatadata:真实数据;DDD:判别器,输出值为[0,1],代表输入来自真实数据的概率;zzz:随机噪声;GGG:生成器,输出为合成数据;

判别器DDD的目的是最大化价值函数VVV,对数函数log在底数大于1时为单调递增函数,最大化VVV就是最大化D(x)D(x)D(x)和1−D(G(z))1-D(G(z))1−D(G(z)),对于任意的x,都有D(x)=1D(x)=1D(x)=1,对于任意的zzz都有D(G(z))=0D(G(z))=0D(G(z))=0。

生成器GGG的目的是针对特定的DDD,去最小化价值函数VVV;最小化价值函数VVV,就是最小化D(x)D(x)D(x)和1−D(G(z))1-D(G(z))1−D(G(z));对于任意的zzz,都有D(G(z))=1D(G(z))=1D(G(z))=1。

训练小trick

在开始训练的时候,生成器GGG的性能较差,D(G(z))D(G(z))D(G(z))接近于0,此时价值函数中的log(1−D(G(z)))log(1-D(G(z)))log(1−D(G(z)))的梯度值较小,而log(D(G(z)))log(D(G(z)))log(D(G(z)))的梯度值较大,所以可以把生成器GGG的目标改为最大化logD(G(z))logD(G(z))logD(G(z)),这样可以在早期学习中提供更强的梯度。

4、训练流程

使用mini-batch梯度下降(带momentum);训练k次判别器(本论文实验中k=1);训练1次生成器;

根据伪代码可以知道,对应两个神经网络模型——生成器GGG和判别器DDD,首先会固定生成器GGG的参数,使用生成器GGG生成的数据和真实的数据训练判别器DDD,训练k次判别器DDD后,固定判别器DDD的参数,训练生成器GGG。

理想情况下,判别器的最优解为:DG∗(x)=Pdata(x)Pdata(x)+Pg(x)D^*_{G}(x)=\frac{P_{data}(x)}{P_{data}(x)+P_g(x)}DG∗​(x)=Pdata​(x)+Pg​(x)Pdata​(x)​判别器取得最优解时,生成器的最优解为:Pg=PdataP_g=P_{data}Pg​=Pdata​此时价值函数的值为C∗=−log(4)C^*=-log(4)C∗=−log(4)

模型优劣势

缺点:

没有显式表示的Pg(x)P_g(x)Pg​(x);必须同步训练G和D,可能会发生模式崩溃;

优点:

不使用马尔科夫链,在学习过程中不需要推理;可以将多种函数合并到模型中;可以表示非常尖锐、甚至退化的分布;不是直接使用数据来计算loss更新生成器,而是使用判别器的梯度,所以数据不会直接复制到生成器的参数中;

Pytorch代码

# 代码来源:/eriklindernoren/PyTorch-GANimport argparseimport osimport numpy as npimport mathimport torchvision.transforms as transformsfrom torchvision.utils import save_imagefrom torch.utils.data import DataLoaderfrom torchvision import datasetsfrom torch.autograd import Variableimport torch.nn as nnimport torch.nn.functional as Fimport torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") # 迭代次数parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") # 批量大小parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") # adam的学习率parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") # 动量法parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") # 生成器输入维度parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension") # 照片的尺寸parser.add_argument("--channels", type=int, default=1, help="number of image channels") # 通道数,1表示灰度图parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples") # 采样照片频率opt = parser.parse_args()print(opt)img_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else Falseclass Generator(nn.Module): # 生成器def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return imgclass Discriminator(nn.Module): # 判别器def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# Loss functionadversarial_loss = torch.nn.BCELoss()# Initialize generator and discriminatorgenerator = Generator()discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Configure data loaderos.makedirs("../../data/mnist", exist_ok=True)dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True, # 训练模式download=True, # 如果MNIST没有下载则直接下载transform=pose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]), # 照片处理方式), # 数据集batch_size=opt.batch_size, # 训练数据批量大小shuffle=True, # 是否打乱)# Optimizersoptimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) # 生成器的优化器optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) # 判别器的优化器Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------# Training# ----------for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 真实数据的labelfake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 生成数据的label# Configure inputreal_imgs = Variable(imgs.type(Tensor)) # 真实照片# -----------------# Train Generator# -----------------optimizer_G.zero_grad() ## Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 生成随机分布的数据# Generate a batch of imagesgen_imgs = generator(z) # 生成器生成伪照片# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid) # 生成器的目的是骗过判别器,所以希望生成器生成的照片被预测为1g_loss.backward()optimizer_G.step()# ---------------------# Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid) # 判别器希望真实的照片预测为1fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 判别器希望伪造的照片预测为0d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)os.makedirs("model", exist_ok=True)torch.save(generator, 'model/generator.pkl')torch.save(discriminator, 'model/discriminator.pkl')

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。