700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 生成式对抗网络生成mnist数据集

生成式对抗网络生成mnist数据集

时间:2023-08-15 10:03:59

相关推荐

生成式对抗网络生成mnist数据集

生成式对抗网络(GAN)是基于可微生成器网络的另一种生成式建模方法。

GAN一般有两个内容,一是生成器(generator),二是辨别器(discriminator),辨别器的目的是:尽可能地分辨输入的数据是生成器生成的假数据还是真实的数据;生成器的目的是:尽可能地骗过辨别器,使得辨别器认为它生成的数据是真实的数据

生成式对抗网络基于博弈论场景,其中生成器网络必须与对手竞争,生成器网络直接产生样本。其对手判别器网络(diacriminator network)试图区分从训练数据中抽取的样本和从生成器中抽取的样本。通过下图可以简单看一下生成对抗的过程。

首先是对抗过程

(Discriminator Network)

对抗过程简单来说就是一个判断真假的判别器,相当于一个二分类问题,我们输入一张真的图片希望判别器输出的结果是1,输入一张假的图片希望判别器输出的结果是0。这其实已经和原图片的label没有关系了,不管原图片到底是一个多少类别的图片,他们都统一称为真的图片,label是1表示真实的;而生成的假的图片的label是0表示假的。我们训练的过程就是希望这个判别器能够正确的判出真的图片和假的图片。

代码如下

class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28nn.LeakyReLU(0.2, True),nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14)self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14nn.LeakyReLU(0.2, True),nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7)self.fc = nn.Sequential(nn.Linear(64 * 7 * 7, 1024),nn.LeakyReLU(0.2, True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):'''x: batch, width, height, channel=1'''x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)x = self.fc(x)return x

接着是生成图片的过程

(Generative Network)

首先给出一个简单的高维的正态分布的噪声向量,如上图所示的D-dimensional noise vector,这个时候我们可以通过仿射变换,也就是xw+b将其映射到一个更高的维度,然后将他重新排列成一个矩形,这样看着更像一张图片,接着进行一些卷积、池化、激活函数处理,最后得到了一个与我们输入图片大小一模一样的噪音矩阵,这就是我们所说的假的图片,这个时候我们如何去训练这个生成器呢?就是通过判别器来得到结果,然后希望增大判别器判别这个结果为真的概率,在这一步我们不会更新判别器的参数,只会更新生成器的参数。

代码如下

class generator(nn.Module):def __init__(self, input_size, num_feature):super(generator, self).__init__()self.fc = nn.Linear(input_size, num_feature) # batch, 3136=1x56x56self.br = nn.Sequential(nn.BatchNorm2d(1),nn.ReLU(True))self.downsample1 = nn.Sequential(nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, 56, 56nn.BatchNorm2d(50),nn.ReLU(True))self.downsample2 = nn.Sequential(nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, 56, 56nn.BatchNorm2d(25),nn.ReLU(True))self.downsample3 = nn.Sequential(nn.Conv2d(25, 1, 2, stride=2), # batch, 1, 28, 28nn.Tanh())def forward(self, x):x = self.fc(x)x = x.view(x.size(0), 1, 56, 56)x = self.br(x)x = self.downsample1(x)x = self.downsample2(x)x = self.downsample3(x)return x

Discriminator训练部分

首先我们需要定义loss的度量方式和优化器,loss度量使用二分类的交叉熵,优化器注意使用的学习率是0.0003

代码如下

D = discriminator().cuda() # discriminator modelG = generator(z_dimension, 3136).cuda() # generator modelcriterion = nn.BCELoss() # binary cross entropyd_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

接下来是进入正式的train部分

代码如下

# trainfor epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):num_img = img.size(0)# =================train discriminatorreal_img = Variable(img).cuda()real_label = Variable(torch.ones(num_img)).cuda()fake_label = Variable(torch.zeros(num_img)).cuda()# compute loss of real_imgreal_out = D(real_img)d_loss_real = criterion(real_out, real_label)real_scores = real_out # closer to 1 means better# compute loss of fake_imgz = Variable(torch.randn(num_img, z_dimension)).cuda()fake_img = G(z)fake_out = D(fake_img)d_loss_fake = criterion(fake_out, fake_label)fake_scores = fake_out # closer to 0 means better# bp and optimized_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()d_loss.backward()d_optimizer.step()

以上部分的代码时进行判别器的训练,我们希望最后的结果是判别器能够正确辨别出真假图片。

Generative训练部分

代码如下

z = Variable(torch.randn(num_img, z_dimension)).cuda()fake_img = G(z)output = D(fake_img)g_loss = criterion(output, real_label)# bp and optimizeg_optimizer.zero_grad()g_loss.backward()g_optimizer.step()

通过训练生成网络,我们希望能够生成一张假的图片,然后经过判别器之后希望他能够判断为真的图片,在这个过程中,我们将判别器固定,将假的图片传入判别器的结果与真实label对应,反向传播更新的参数是生成网络里面的参数,这样我们就可以通过跟新生成网络里面的参数来使得判别器判断生成的假的图片为真,这样就达到了生成对抗的作用。

这一部分代码可以看到训练过程中相关loss的变化,这里我只截图了第一个epoch的部分和最后一个epoch的部分

if (i + 1) % 100 == 0:print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epoch, d_loss.item(), g_loss.item(),real_scores.data.mean(), fake_scores.data.mean()))

最后是生成图片的保存和生成模型的保存部分的代码

if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, './dc_img/real_images.png')fake_images = to_img(fake_img.cpu().data)save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch + 1))torch.save(G.state_dict(), './generator.pth')torch.save(D.state_dict(), './discriminator.pth')

通过训练之后,我们可以在目录中找到生成mnist的图片

最后是完整的卷积网络代码,相比上面部分的代码,我在里面里面添加了一些比较详细注释,下面完整的代码可以直接运行

import torchimport torch.nn as nnfrom torch.autograd import Variablefrom torch.utils.data import DataLoaderfrom torchvision import transformsfrom torchvision import datasetsfrom torchvision.utils import save_imageimport osif not os.path.exists('./dc_img'):os.mkdir('./dc_img')def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outbatch_size = 16num_epoch = 50z_dimension = 100 # noise dimensionimg_transform = pose([transforms.ToTensor(),transforms.Normalize((0.1307), (0.3081))])mnist = datasets.MNIST('./data', download=True, transform=img_transform)dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True,num_workers=0)class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28nn.LeakyReLU(0.2, True),nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14)self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14nn.LeakyReLU(0.2, True),nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7)self.fc = nn.Sequential(nn.Linear(64 * 7 * 7, 1024),nn.LeakyReLU(0.2, True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):'''x: batch, width, height, channel=1'''x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)x = self.fc(x)return x

class generator(nn.Module):def __init__(self, input_size, num_feature):super(generator, self).__init__()self.fc = nn.Linear(input_size, num_feature) # batch, 3136=1x56x56self.br = nn.Sequential(nn.BatchNorm2d(1),nn.ReLU(True))self.downsample1 = nn.Sequential(nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, 56, 56nn.BatchNorm2d(50),nn.ReLU(True))self.downsample2 = nn.Sequential(nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, 56, 56nn.BatchNorm2d(25),nn.ReLU(True))self.downsample3 = nn.Sequential(nn.Conv2d(25, 1, 2, stride=2), # batch, 1, 28, 28nn.Tanh())def forward(self, x):x = self.fc(x)x = x.view(x.size(0), 1, 56, 56)x = self.br(x)x = self.downsample1(x)x = self.downsample2(x)x = self.downsample3(x)return xD = discriminator().cuda() # discriminator modelG = generator(z_dimension, 3136).cuda() # generator model# 首先是定义loss的度量方式,使用的是单目标二分类交叉熵函数criterion = nn.BCELoss()# 其次定义 优化函数,优化函数的学习率为0.0003d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# trainfor epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):num_img = img.size(0)# =================train discriminatorreal_img = Variable(img).cuda() # 将tensor变成Variable放入计算图中real_label = Variable(torch.ones(num_img)).cuda() # 这一步是定义真实图片的lable为1fake_label = Variable(torch.zeros(num_img)).cuda() # 这一步是定义假的图片的lable为0# compute loss of real_imgreal_out = D(real_img) # 将真实图片放入到判别器中d_loss_real = criterion(real_out, real_label) # 得到真实图片的lossreal_scores = real_out # 真实图片放入判别器中输出越接近1越好# compute loss of fake_imgz = Variable(torch.randn(num_img, z_dimension)).cuda() # 随机生成噪声fake_img = G(z) # 在生成网络中放入一张假的图片fake_out = D(fake_img) # 判别器判断假的的图片d_loss_fake = criterion(fake_out, fake_label) # 得到假的图片的lossfake_scores = fake_out # 假的图片放入判别器中输出越接近0越好# bp and optimized_loss = d_loss_real + d_loss_fake # 将假的图片的;oss加起来d_optimizer.zero_grad() # 梯度置零d_loss.backward() # 反向传播d_optimizer.step() # 更新参数# ===============train generator# compute loss of fake_imgz = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到随机噪声fake_img = G(z) # 生成假的图片output = D(fake_img) # 经过判别器得到结果g_loss = criterion(output, real_label) # 得到假的图片与真实图片label的loss# bp and optimizeg_optimizer.zero_grad() # 梯度置零g_loss.backward() # 反向传播g_optimizer.step() # 更新生成的网络参数# 这个部分会打印出相关的lossif (i + 1) % 100 == 0:print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epoch, d_loss.item(), g_loss.item(),real_scores.data.mean(), fake_scores.data.mean()))# 生成图片的保存,训练好的模型的保存if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, './dc_img/real_images.png')fake_images = to_img(fake_img.cpu().data)save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch + 1))torch.save(G.state_dict(), './generator.pth')torch.save(D.state_dict(), './discriminator.pth')

这是训练50个epoch之后所得到的手写数字

训练了50个epoch,每次训练之后都会得到一张通过噪声的方式生成的图片,在最后我们生成了一张看起来非常真的图片

最后我们来说一下为何Gans能够成为最近来机器学习以及深度学习界革命性的发现。这是因为不管是深度学习还是机器学习仍然很大一部分是监督学习,但是创建这么多有label的数据集所需要的人力物力是极大的,同时遇到的新的任务时我们很容易得到原始的没有label的数据集,这是我们需要花大量的时间去给其标定label,所以很多人都认为无监督学习才是机器学习的未来,这个时候Gans的出现为无监督学习提供了有力的支持,这当然引起了学界的大量关注,同时基于Gans的应用也越来越多,所以业界对其也非常狂热。

最后希望本文的代码能够帮助大家更好的理解使用生成式对抗网络GAN生成mnist数据集的相关内容。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。