1.导入所需库
# Third-party imports for the GAN example (PyTorch, NumPy, Matplotlib,
# torchvision for the MNIST dataset and image utilities).
# BUG FIX: the original had all imports fused onto one line, which is a
# syntax error; split them into one import statement per line.
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
2. 训练集
# Mini-batch size for the DataLoader.
mb_size = 64

# Translate data to tensor format, which is PyTorch's expected format.
# BUG FIX: the original called `pose(...)` (a garbled `transforms.Compose`)
# and assigned the result to the name `transforms`, shadowing the
# torchvision.transforms module; use a distinct variable name.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training set; download=False assumes the data already exists
# under ./NewData.
trainset = torchvision.datasets.MNIST(
    root='./NewData', download=False, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, shuffle=True, batch_size=mb_size)
Notes
torchvision.transforms是pytorch中的图像预处理包。
一般用Compose()把多个步骤整合到一起,例如:
transforms.Compose([transforms.CenterCrop(10), transforms.ToTensor()])
此外,常用的transforms中的函数:
Resize:把给定的图片resize到given size
Normalize: normalize a tensor image with mean and standard deviation
ToTensor: convert a PIL image (H x W x C) in the range [0, 255] to a
torch.Tensor (C x H x W) in the range [0.0, 1.0]
ToPILImage: convert a tensor to PIL image
参考:/ftimes/article/details/105202795
3.可视化
参考:/xiongchengluo1129/article/details/79078478
# --- Visualization: grab one batch from the loader and display it ---

# Define an iterator over the training DataLoader.
data_iter = iter(trainloader)
# BUG FIX: the `.next()` method was removed from Python-3 iterators (and
# from DataLoader iterators in modern PyTorch); use the built-in next().
images, labels = next(data_iter)

# Flatten each 1x28x28 image into a 784-dim vector -> shape (batch, 784).
test = images.view(images.size(0), -1)
print(test.size())

# Model dimensions and learning rate.
z_dim = 100           # generator noise dimension
x_dim = test.size(1)  # flattened image dimension (784 for MNIST)
h_dim = 128           # hidden-layer width
lr = 0.003


def imshow(img):
    """Tile a batch of image tensors into a single grid and display it."""
    # Stitch the batch into one grid image.
    im = torchvision.utils.make_grid(img)
    # Convert to numpy and move channels last: (C, H, W) -> (H, W, C),
    # which is what plt.imshow expects.
    npimg = im.numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.xticks([])  # hide axis ticks
    plt.yticks([])
    plt.show()


imshow(images)
输出:
Notes:
将多维度的tensor展平成一维,x.view(x.size(0), -1)就实现的这个功能。
所以我们一个batch里面的64张图,图片的大小是28 * 28,输出的size为64 * 784。
make_grid的作用是将若干幅图像拼成一幅图像。其中padding的作用就是子图像与子图像之间的pad有多宽。
plt.figure()语法
figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True)
np.transpose(img,(1,2,0))将图片的格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),这样plt.show()就可以显示图片了。plt.xticks()用法参考:/Tenderness___/article/details/82972845
num:图像编号或名称,数字为编号 ,字符串为名称
figsize:指定figure的宽和高,单位为英寸;
dpi: 指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80 1英寸等于2.5cm,A4纸是 21*30cm的纸张
facecolor:背景颜色
edgecolor:边框颜色
frameon:是否显示边框
4. 初始化weight和bias
def init_weights(m):
    """Xavier-initialize the weights of Linear layers and zero their biases.

    Intended for use with nn.Module.apply(), which calls this on every
    submodule; non-Linear modules are left untouched.
    """
    if isinstance(m, nn.Linear):
        # BUG FIX: nn.init.xavier_uniform is deprecated; the supported
        # API is the in-place variant xavier_uniform_.
        nn.init.xavier_uniform_(m.weight)
        # Biases all start at zero.
        m.bias.data.fill_(0)
参考:/dss_dssssd/article/details/83959474
pytorch官方教程中的例子:
5. Generator and Discriminator
class Generate(nn.Module):
    """Generator: maps a z_dim noise vector to a flattened image in [0, 1]."""

    def __init__(self):
        super(Generate, self).__init__()
        # Two-layer MLP; the final Sigmoid keeps outputs in [0, 1],
        # matching ToTensor-scaled MNIST pixels.
        self.predict = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, x_dim),
            nn.Sigmoid(),
        )
        self.predict.apply(init_weights)

    def forward(self, input):
        return self.predict(input)


class Dis(nn.Module):
    """Discriminator: maps a flattened image to a real/fake probability."""

    def __init__(self):
        super(Dis, self).__init__()
        self.predict = nn.Sequential(
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid(),
        )
        self.predict.apply(init_weights)

    def forward(self, input):
        return self.predict(input)


G = Generate()
D = Dis()
6.Optimizer
# Separate Adam optimizers for generator and discriminator, sharing
# the same learning rate.
G_solver = optim.Adam(G.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)
7.Training
# --- Training loop ---
for epoch in range(2):
    # Running sums of per-batch losses, used for the epoch averages.
    G_loss_run = 0.0
    D_loss_run = 0.0
    for i, data in enumerate(trainloader):
        # data is a (images, labels) pair of tensors; labels are unused.
        X, label = data
        mb_size = X.size(0)        # last batch may be smaller than 64
        X = X.view(X.size(0), -1)  # flatten to (batch, 784)

        # Targets: 1 for real images, 0 for generated ones.
        one_labels = torch.ones(mb_size, 1)
        zero_labels = torch.zeros(mb_size, 1)

        # ---- Discriminator step ----
        z = torch.randn(mb_size, z_dim)
        G_samples = G(z)
        D_fake = D(G_samples)
        D_real = D(X)
        D_fake_loss = F.binary_cross_entropy(D_fake, zero_labels)
        D_real_loss = F.binary_cross_entropy(D_real, one_labels)
        D_loss = D_fake_loss + D_real_loss

        D_solver.zero_grad()
        # BUG FIX: retain_graph=True is no longer needed because the
        # generator step below builds its own fresh graph.
        D_loss.backward()
        D_solver.step()

        # ---- Generator step ----
        # BUG FIX: sample from the same normal distribution as above
        # (the original used torch.rand, i.e. uniform noise, here) and
        # feed the NEW samples to D (the original computed G_sample but
        # then reused the stale G_samples from the discriminator step).
        z = torch.randn(mb_size, z_dim)
        G_sample = G(z)
        D_fake = D(G_sample)
        G_loss = F.binary_cross_entropy(D_fake, one_labels)

        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        # BUG FIX: accumulate the losses; the original never updated the
        # running sums, so the printed averages were always zero.
        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()

    print('Epoch: {}, G_loss: {}. D_loss:{}'.format(
        epoch, G_loss_run / (i + 1), D_loss_run / (i + 1)))

    # Show a grid of generator samples after each epoch.
    samples = G(z).detach()
    samples = samples.view(mb_size, 1, 28, 28)
    imshow(samples)
完整代码:
# Complete, runnable version of the GAN script.
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

torch.manual_seed(0)

mb_size = 64

# Translate data to tensor format, which is PyTorch's expected format.
# BUG FIX: the original called `pose(...)` (a garbled `transforms.Compose`)
# and shadowed the torchvision.transforms module with the result.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training set; download=False assumes data already exists locally.
trainset = torchvision.datasets.MNIST(
    root='./NewData', download=False, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, shuffle=True, batch_size=mb_size)

# --- Visualization of one batch ---
dataiter = iter(trainloader)
# BUG FIX: iterator .next() was removed; use the built-in next().
imgs, labels = next(dataiter)
test = imgs.view(imgs.size(0), -1)
print(test.size())

h_dim = 128  # number of hidden neurons in the hidden layer
Z_dim = 100  # dimension of the input noise for the generator
lr = 1e-3    # learning rate
X_dim = imgs.view(imgs.size(0), -1).size(1)  # flattened image size (784)
print(X_dim)


def imshow(img):
    """Tile a batch of image tensors into one grid and display it."""
    im = torchvision.utils.make_grid(img)
    npimg = im.numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.xticks([])
    plt.yticks([])
    plt.show()


imshow(imgs)


def xavier_init(m):
    """Xavier initialization for Linear layers; biases set to zero."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0)


class Gen(nn.Module):
    """Generator: noise (Z_dim) -> flattened image (X_dim) in [0, 1]."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(Z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, X_dim),
            nn.Sigmoid(),
        )
        self.model.apply(xavier_init)

    def forward(self, input):
        return self.model(input)


class Dis(nn.Module):
    """Discriminator: flattened image (X_dim) -> probability it is real."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(X_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid(),
        )
        self.model.apply(xavier_init)

    def forward(self, input):
        return self.model(input)


# BUG FIX: the original built and printed a throwaway Dis() twice in a
# row; one print is enough to inspect the architecture.
print(Dis())

G = Gen()
D = Dis()
G_solver = optim.Adam(G.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)

for epoch in range(20):
    G_loss_run = 0.0
    D_loss_run = 0.0
    for i, data in enumerate(trainloader):
        X, _ = data
        X = X.view(X.size(0), -1)
        mb_size = X.size(0)

        # Defining labels for real (1s) and fake (0s) images.
        one_labels = torch.ones(mb_size, 1)
        zero_labels = torch.zeros(mb_size, 1)

        # Random normal noise vector for each image.
        z = torch.randn(mb_size, Z_dim)

        # Feed forward in the discriminator: both fake and real images.
        D_real = D(X)
        D_fake = D(G(z))

        # Discriminator loss and backward propagation.
        D_real_loss = F.binary_cross_entropy(D_real, one_labels)
        D_fake_loss = F.binary_cross_entropy(D_fake, zero_labels)
        D_loss = D_fake_loss + D_real_loss
        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # Generator loss (fool D into predicting "real") and update.
        z = torch.randn(mb_size, Z_dim)
        D_fake = D(G(z))
        G_loss = F.binary_cross_entropy(D_fake, one_labels)
        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()

    # Printing average loss after each epoch.
    print('Epoch:{}, G_loss:{}, D_loss:{}'.format(
        epoch, G_loss_run / (i + 1), D_loss_run / (i + 1)))

    # Plotting fake images generated after each epoch by the generator.
    samples = G(z).detach()
    samples = samples.view(samples.size(0), 1, 28, 28)
    imshow(samples)