
[PyTorch Tutorial for Beginners] Part 8: Improving Accuracy on the CIFAR-10 Dataset with Image Data Augmentation

Date: 2024-04-26 07:31:40


@Author: Runsen

In the previous post, the image classification model built with PyTorch on the CIFAR-10 dataset reached about 60% accuracy. A common way to push that number higher is image data augmentation with torchvision transforms.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import os
import warnings
from matplotlib import pyplot as plt

warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

Load the dataset

```python
# number of images in one forward and backward pass
batch_size = 128
# number of subprocesses used for data loading
# (normally set this to 0 if your OS is Windows)
num_workers = 2

# transform_train and transform_test are defined in the next code block
train_dataset = datasets.CIFAR10('./data/CIFAR10/', train=True, download=True,
                                 transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers)

val_dataset = datasets.CIFAR10('./data/CIFAR10', train=True,
                               transform=transform_test)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers)

test_dataset = datasets.CIFAR10('./data/CIFAR10', train=False,
                                transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers)

# declare classes in CIFAR10
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
```

The previous transform only rescaled and normalized the images; here we add RandomCrop and RandomHorizontalFlip:

```python
# define a transform to normalize the data
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # converting images to tensor
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    # if the image dataset is black and white, there can be just one number
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
```
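The pad-then-crop idea behind `RandomCrop(32, padding=4)` can be sketched in plain Python on a 2-D grid (a simplified, hypothetical `random_crop` helper for illustration, not torchvision's implementation):

```python
import random

def random_crop(img, size, padding):
    # Pad a 2-D grid with zeros on all sides, then cut out a random
    # size x size window -- the image is padded to 40x40 and a 32x32
    # patch is taken, so content shifts by up to 4 pixels each way.
    h, w = len(img), len(img[0])
    padded = [[0] * (w + 2 * padding) for _ in range(padding)]
    padded += [[0] * padding + row + [0] * padding for row in img]
    padded += [[0] * (w + 2 * padding) for _ in range(padding)]
    top = random.randint(0, len(padded) - size)
    left = random.randint(0, len(padded[0]) - size)
    return [row[left:left + size] for row in padded[top:top + size]]

img = [[1] * 32 for _ in range(32)]
crop = random_crop(img, 32, 4)
print(len(crop), len(crop[0]))  # 32 32
```

The crop is always the original size, so downstream layers are unaffected; only the window position is random.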

Visualize some images

```python
# function that will be used for visualizing the data
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    plt.imshow(np.transpose(img, (1, 2, 0)))  # convert from Tensor image

# obtain one batch of images from the train dataset
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.numpy()  # convert images to numpy for display

# plot the images in one batch with the corresponding labels
fig = plt.figure(figsize=(25, 4))
for idx in np.arange(10):
    ax = fig.add_subplot(1, 10, idx + 1, xticks=[], yticks=[])
    imshow(images[idx])
    ax.set_title(classes[labels[idx]])
```

Build a standard CNN model

```python
# define the CNN architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.main = nn.Sequential(
            # input: 3 x 32 x 32
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),  # 32 x 32 x 32, O = (N + 2P - F) / S + 1
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32 x 16 x 16
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 64 x 16 x 16
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 64 x 8 x 8
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 3, padding=1),  # 128 x 8 x 8
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 128 x 4 x 4
            nn.BatchNorm2d(128),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        # Conv and Pooling layers
        x = self.main(x)
        # Flatten before Fully Connected layers
        x = x.view(-1, 128 * 4 * 4)
        # Fully Connected layers
        x = self.fc(x)
        return x

cnn = CNN().to(device)
cnn
```

torch.nn.CrossEntropyLoss combines LogSoftmax and NLLLoss in a single criterion, so the model's forward pass returns raw logits; the loss converts them internally into class probabilities between 0 and 1.
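What that criterion computes for a single sample can be written out by hand (the `logits` values here are hypothetical scores for three classes):

```python
import math

# CrossEntropyLoss = log-softmax of the raw logits, then negative
# log-likelihood of the true class. No softmax layer is needed in the model.
logits = [2.0, 0.5, 0.1]   # hypothetical scores for 3 classes
target = 0                 # index of the true class

exps = [math.exp(z) for z in logits]
probs = [e / sum(exps) for e in exps]   # softmax: probabilities in (0, 1)
loss = -math.log(probs[target])

print(abs(sum(probs) - 1.0) < 1e-9)  # True: softmax outputs sum to 1
print(loss > 0)                      # True: loss is 0 only at probability 1
```

The loss shrinks toward 0 as the probability assigned to the true class approaches 1, which is exactly what the training loop below minimizes.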

Train the model

```python
# Hyper Parameters
learning_rate = 0.001
train_losses = []
val_losses = []

# Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)

# define the train function that trains the model using the CIFAR10 dataset
def train(model, epoch, num_epochs):
    model.train()
    total_batch = len(train_dataset) // batch_size
    for i, (images, labels) in enumerate(train_loader):
        X = images.to(device)
        Y = labels.to(device)

        ### forward pass and loss calculation
        pred = model(X)
        cost = criterion(pred, Y)

        ### backward pass and optimization
        optimizer.zero_grad()  # gradient initialization
        cost.backward()        # backward pass
        optimizer.step()       # parameter update

        # training stats
        if (i + 1) % 100 == 0:
            train_losses.append(cost.item())
            print('Train, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_batch, np.average(train_losses)))

# define the validation function that validates the model using the CIFAR10 dataset
def validation(model, epoch, num_epochs):
    model.eval()
    total_batch = len(val_dataset) // batch_size
    for i, (images, labels) in enumerate(val_loader):
        X = images.to(device)
        Y = labels.to(device)
        with torch.no_grad():
            pred = model(X)
            cost = criterion(pred, Y)
        if (i + 1) % 100 == 0:
            val_losses.append(cost.item())
            print('Validation, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_batch, np.average(val_losses)))

def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(5, 5))
    plt.plot(train_losses, label='Train', alpha=0.5)
    plt.plot(val_losses, label='Validation', alpha=0.5)
    plt.xlabel('Epochs')
    plt.ylabel('Losses')
    plt.legend()
    plt.grid(True)
    plt.title('CIFAR10 Train/Val Losses Over Epoch')
    plt.show()

num_epochs = 20
for epoch in range(num_epochs):
    train(cnn, epoch, num_epochs)
    validation(cnn, epoch, num_epochs)
    torch.save(cnn.state_dict(), './data/Tutorial_3_CNN_Epoch_{}.pkl'.format(epoch + 1))

plot_losses(train_losses, val_losses)
```

Test the model

```python
def test(model):
    # declare that the model is about to evaluate
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_dataset:
            images = images.unsqueeze(0).to(device)
            # forward pass
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += 1
            correct += (predicted == labels).sum().item()
    print("Accuracy of Test Images: %f %%" % (100 * float(correct) / total))

test(cnn)
```
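The loop above runs one image at a time; iterating over `test_loader` in batches is much faster. A minimal sketch of the batched version, using a tiny stand-in linear model and random tensors so it is self-contained (the real code would loop over `test_loader` with `cnn`):

```python
import torch
import torch.nn as nn

# Tiny stand-in classifier: flattens a 3x32x32 image into 10 class scores.
model = nn.Linear(3 * 32 * 32, 10)
model.eval()

correct, total = 0, 0
with torch.no_grad():
    for _ in range(4):  # stand-in for batches from test_loader
        images = torch.rand(8, 3, 32, 32)
        labels = torch.randint(0, 10, (8,))
        outputs = model(images.view(images.size(0), -1))
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)                       # whole batch at once
        correct += (predicted == labels).sum().item()

print("Accuracy of Test Images: %f %%" % (100 * correct / total))
```

The accounting is identical; only the unit of work changes from one image to one batch.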

With image data augmentation, test accuracy improved from 60% to 84%.
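The training loop saves a checkpoint each epoch with `torch.save(cnn.state_dict(), ...)`. To reuse the trained weights later, rebuild the architecture and load the state dict back in, sketched here with a small stand-in module rather than the full CNN:

```python
import os
import tempfile
import torch
import torch.nn as nn

# Save a state dict, then restore it into a freshly built module.
model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
torch.save(model.state_dict(), path)

restored = nn.Linear(4, 2)              # must match the saved architecture
restored.load_state_dict(torch.load(path))
restored.eval()                         # switch to eval mode for inference

# The restored weights match the saved ones exactly.
same = torch.equal(model.weight, restored.weight)
print(same)  # True
```

For the model in this post, the same pattern is `cnn = CNN().to(device)` followed by `cnn.load_state_dict(torch.load('./data/Tutorial_3_CNN_Epoch_20.pkl'))`.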

Check which classes the model performs well on:

```python
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = cnn(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        # tally every sample in the batch (not just the first few)
        for i in range(labels.size(0)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
```
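The per-class bookkeeping in the loop above reduces to a simple tally, shown here with hypothetical predictions over three classes:

```python
# Same tallying idea as the per-class loop above, on made-up predictions.
classes = ('plane', 'car', 'bird')
predicted = [0, 1, 1, 2, 0, 2]  # hypothetical model outputs
labels    = [0, 1, 2, 2, 1, 2]  # hypothetical ground truth

class_correct = [0] * len(classes)
class_total = [0] * len(classes)
for p, y in zip(predicted, labels):
    class_correct[y] += int(p == y)  # credit goes to the TRUE class
    class_total[y] += 1

for name, c, t in zip(classes, class_correct, class_total):
    print('Accuracy of %5s : %2d %%' % (name, 100 * c / t))
```

Indexing by the true label `y` (not the prediction) is what makes each class's accuracy meaningful even when the model confuses similar classes like cat and dog.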

