
Implementing a GAN in PyTorch to Generate Handwritten Digit Images

Date: 2019-11-27 08:13:33


1. Import the required libraries

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

2. Training set

# mini-batch size
mb_size = 64
# convert the data to tensor format, which is what PyTorch expects
transform = transforms.Compose([transforms.ToTensor()])
# training set (use download=True on the first run if ./NewData is empty)
trainset = torchvision.datasets.MNIST(root='./NewData', download=False, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=mb_size)
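As an optional sanity check (not part of the original post), you can inspect the dataset that was just loaded; MNIST has 60,000 training images, each a 1 x 28 x 28 tensor after ToTensor:

# quick inspection of the training set defined above
print(len(trainset))        # 60000 training images
img, label = trainset[0]
print(img.size(), label)    # torch.Size([1, 28, 28]) and the digit label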

Notes

torchvision.transforms is the image preprocessing package in PyTorch.

Compose() is usually used to chain several preprocessing steps together, for example:

transforms.Compose([transforms.CenterCrop(10), transforms.ToTensor()])

In addition, some commonly used functions in transforms (combined into a short example after this list):

Resize: resize the given image to the given size

Normalize: normalize a tensor image with mean and standard deviation

ToTensor: convert a PIL Image (H x W x C) in the range [0, 255] to a torch.Tensor (C x H x W) in the range [0.0, 1.0]

ToPILImage: convert a tensor to PIL image
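As a hedged illustration of how these pieces fit together (the target size and the mean/std values here are placeholders, not taken from the original post), a typical preprocessing pipeline looks like this:

import torchvision.transforms as transforms

preprocess = transforms.Compose([
    transforms.Resize((28, 28)),            # resize to the given size
    transforms.ToTensor(),                  # PIL Image [0, 255] -> tensor [0.0, 1.0], C x H x W
    transforms.Normalize((0.5,), (0.5,)),   # normalize with mean and std (single-channel example)
])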

Reference: /ftimes/article/details/105202795

3. Visualization

Reference: /xiongchengluo1129/article/details/79078478

# define an iterator over the training loader
data_iter = iter(trainloader)
# get the next batch of images and labels
images, labels = next(data_iter)
test = images.view(images.size(0), -1)
print(test.size())

# dimensions and learning rate
z_dim = 100
x_dim = test.size(1)
h_dim = 128
lr = 0.003

def imshow(img):
    # stitch the batch of images into one grid image
    im = torchvision.utils.make_grid(img)
    # convert to numpy
    npimg = im.numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.xticks([])
    plt.yticks([])
    plt.show()

imshow(images)

Output: torch.Size([64, 784]), followed by an 8 * 8 grid of MNIST training digits.

Notes:

x.view(x.size(0), -1) flattens a multi-dimensional tensor into one dimension per sample.

So for the 64 images in one batch, each 28 * 28 in size, the output size is 64 * 784.

make_grid stitches a number of images into a single image; its padding argument controls how wide the gap between neighboring sub-images is (see the short sketch below).
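A minimal sketch of the two notes above, using a dummy batch shaped like MNIST (the shapes printed in the comments are what these calls produce):

import torch
import torchvision

batch = torch.rand(64, 1, 28, 28)              # a dummy batch shaped like MNIST
flat = batch.view(batch.size(0), -1)
print(flat.size())                             # torch.Size([64, 784])
grid = torchvision.utils.make_grid(batch, padding=2)
print(grid.size())                             # a single 3 x H x W image containing all 64 tiles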

plt.figure() syntax (a small usage example follows the parameter list):

figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True)

num: figure number or name; an integer is treated as a number, a string as a name

figsize: width and height of the figure, in inches

dpi: resolution of the figure, i.e. pixels per inch; the default was 80 in older Matplotlib releases (1 inch = 2.54 cm; an A4 sheet is about 21 * 30 cm)

facecolor: background color

edgecolor: border color

frameon: whether to draw the figure frame
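A small usage sketch for the parameters listed above (the specific values are illustrative, not from the original post):

import matplotlib.pyplot as plt

fig = plt.figure(num='demo', figsize=(4, 3), dpi=100,
                 facecolor='white', edgecolor='gray', frameon=True)
plt.plot([0, 1], [0, 1])
plt.show()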

np.transpose(img, (1, 2, 0)) converts the image from (channels, height, width) to (height, width, channels) so that plt.imshow()/plt.show() can display it. For plt.xticks() usage, see: /Tenderness___/article/details/82972845
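A one-line shape check of that conversion (an assumed illustration, not from the original post):

import numpy as np

chw = np.zeros((3, 28, 28))            # (channels, height, width)
hwc = np.transpose(chw, (1, 2, 0))     # (height, width, channels)
print(hwc.shape)                       # (28, 28, 3)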

4. Initialize the weights and biases

def init_weights(m):
    if type(m) == nn.Linear:
        # Xavier-initialize the weights
        nn.init.xavier_uniform_(m.weight)
        # set all biases to 0
        m.bias.data.fill_(0)

Reference: /dss_dssssd/article/details/83959474

Example from the official PyTorch tutorial:
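Below is a minimal sketch in the spirit of the nn.Module.apply example from the PyTorch docs, reusing init_weights from above (not a verbatim copy of the tutorial):

import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)      # init_weights from section 4 is called on every submodule
print(net[0].weight)         # Xavier-initialized weights; biases are filled with 0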

5. Generator and Discriminator

class Generate(nn.Module):
    def __init__(self):
        super(Generate, self).__init__()
        self.predict = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, x_dim),
            nn.Sigmoid()
        )
        self.predict.apply(init_weights)

    def forward(self, input):
        return self.predict(input)


class Dis(nn.Module):
    def __init__(self):
        super(Dis, self).__init__()
        self.predict = nn.Sequential(
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )
        self.predict.apply(init_weights)

    def forward(self, input):
        return self.predict(input)


G = Generate()
D = Dis()
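As a quick shape check (an assumed illustration, not part of the original post), the generator maps noise vectors of size z_dim to flattened 28 x 28 images, and the discriminator maps each image to a single probability:

z = torch.randn(4, z_dim)
fake = G(z)
print(fake.size())        # torch.Size([4, 784]) -- flattened 28 x 28 images
print(D(fake).size())     # torch.Size([4, 1])   -- one probability per image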

6. Optimizer

G_solver = optim.Adam(G.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)

7. Training

for epoch in range(2):
    G_loss_run = 0.0
    D_loss_run = 0.0
    for i, data in enumerate(trainloader):
        # data contains the image tensors (inputs) and the label tensors
        X, label = data
        mb_size = X.size(0)
        X = X.view(X.size(0), -1)

        one_labels = torch.ones(mb_size, 1)
        zero_labels = torch.zeros(mb_size, 1)

        # ----- train the discriminator -----
        z = torch.randn(mb_size, z_dim)
        G_sample = G(z)
        D_fake = D(G_sample)
        D_real = D(X)

        D_fake_loss = F.binary_cross_entropy(D_fake, zero_labels)
        D_real_loss = F.binary_cross_entropy(D_real, one_labels)
        D_loss = D_fake_loss + D_real_loss

        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # ----- train the generator (with fresh noise) -----
        z = torch.randn(mb_size, z_dim)
        G_sample = G(z)
        D_fake = D(G_sample)

        G_loss = F.binary_cross_entropy(D_fake, one_labels)

        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        # accumulate the running losses so the epoch averages below are meaningful
        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()

    print('Epoch: {}, G_loss: {}, D_loss: {}'.format(epoch, G_loss_run / (i + 1), D_loss_run / (i + 1)))

    samples = G(z).detach()
    samples = samples.view(mb_size, 1, 28, 28)
    imshow(samples)
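After training, a short hedged sketch (not from the original post) for sampling fresh digits from the trained generator without tracking gradients and displaying them with the imshow() helper defined above:

G.eval()
with torch.no_grad():
    z = torch.randn(64, z_dim)
    samples = G(z).view(-1, 1, 28, 28)
imshow(samples)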

Complete code:

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

torch.manual_seed(0)

mb_size = 64
# convert the data to tensor format, which is what PyTorch expects
transform = transforms.Compose([transforms.ToTensor()])
# training set
trainset = torchvision.datasets.MNIST(root='./NewData', download=False, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=mb_size)

# visualization
# define an iterator
dataiter = iter(trainloader)
# get the next batch of images and labels
imgs, labels = next(dataiter)
test = imgs.view(imgs.size(0), -1)
print(test.size())

h_dim = 128     # number of hidden neurons in the hidden layer
Z_dim = 100     # dimension of the input noise for the generator
lr = 1e-3       # learning rate
X_dim = imgs.view(imgs.size(0), -1).size(1)
print(X_dim)

def imshow(img):
    im = torchvision.utils.make_grid(img)
    npimg = im.numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.xticks([])
    plt.yticks([])
    plt.show()

imshow(imgs)

def xavier_init(m):
    """Xavier initialization"""
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0)

class Gen(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(Z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, X_dim),
            nn.Sigmoid()
        )
        self.model.apply(xavier_init)

    def forward(self, input):
        return self.model(input)

class Dis(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(X_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )
        self.model.apply(xavier_init)

    def forward(self, input):
        return self.model(input)

test = Dis()
print(test)

G = Gen()
D = Dis()

G_solver = optim.Adam(G.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)

for epoch in range(20):
    G_loss_run = 0.0
    D_loss_run = 0.0
    for i, data in enumerate(trainloader):
        X, _ = data
        X = X.view(X.size(0), -1)
        mb_size = X.size(0)

        # labels for real (1s) and fake (0s) images
        one_labels = torch.ones(mb_size, 1)
        zero_labels = torch.zeros(mb_size, 1)

        # random normal noise for each image
        z = torch.randn(mb_size, Z_dim)

        # feed both real and fake images through the discriminator
        D_real = D(X)
        D_fake = D(G(z))

        # defining the loss for the discriminator
        D_real_loss = F.binary_cross_entropy(D_real, one_labels)
        D_fake_loss = F.binary_cross_entropy(D_fake, zero_labels)
        D_loss = D_fake_loss + D_real_loss

        # backward propagation for the discriminator
        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # feed forward for the generator
        z = torch.randn(mb_size, Z_dim)
        D_fake = D(G(z))

        # loss function of the generator
        G_loss = F.binary_cross_entropy(D_fake, one_labels)

        # backward propagation for the generator
        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()

    # printing the average losses after each epoch
    print('Epoch:{}, G_loss:{}, D_loss:{}'.format(epoch, G_loss_run / (i + 1), D_loss_run / (i + 1)))

    # plotting fake images generated by the generator after each epoch
    samples = G(z).detach()
    samples = samples.view(samples.size(0), 1, 28, 28)
    imshow(samples)
