700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > Pytorch实现mnist手写数字识别

Pytorch实现mnist手写数字识别

时间:2024-05-27 21:42:27

相关推荐

Pytorch实现mnist手写数字识别

/6/29

Hey,突然想起来之前做的一个入门实验,用pytorch实现mnist手写数字识别。可以在这个基础上增加网络层数,或是尝试用不同的数据集,去实现不一样的功能。

Mnist数据集如图:

代码如下:

importtorchimporttorch.nnasnnimporttorch.utils.dataasDataimporttorchvision#数据库模块importmatplotlib.pyplotasplttorch.manual_seed(1)#reproducible#HyperParametersEPOCH=1#训练整批数据多少次,为了节约时间,我们只训练一次BATCH_SIZE=50LR=0.001#学习率DOWNLOAD_MNIST=True#如果你已经下载好了mnist数据就写上False#Mnist手写数字train_data=torchvision.datasets.MNIST(root='./mnist/',#保存或者提取位置train=True,#thisistrainingdatatransform=torchvision.transforms.ToTensor(),#转换PIL.Imageornumpy.ndarray成#torch.FloatTensor(CxHxW),训练的时候normalize成[0.0,1.0]区间download=DOWNLOAD_MNIST,#没下载就下载,下载了就不用再下了)test_data=torchvision.datasets.MNIST(root='./mnist/',train=False)#批训练50samples,1channel,28x28(50,1,28,28)train_loader=Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)#为了节约时间,我们测试时只测试前2000个test_x=torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)[:2000]/255.#shapefrom(2000,28,28)to(2000,1,28,28),valueinrange(0,1)test_y=test_data.test_labels[:2000]classCNN(nn.Module):def__init__(self):super(CNN,self).__init__()self.conv1=nn.Sequential(#inputshape(1,28,28)nn.Conv2d(in_channels=1,#inputheightout_channels=16,#n_filterskernel_size=5,#filtersizestride=1,#filtermovement/steppadding=2,#如果想要con2d出来的图片长宽没有变化,padding=(kernel_size-1)/2当stride=1),#outputshape(16,28,28)nn.ReLU(),#activationnn.MaxPool2d(kernel_size=2),#在2x2空间里向下采样,outputshape(16,14,14))self.conv2=nn.Sequential(#inputshape(16,14,14)nn.Conv2d(16,32,5,1,2),#outputshape(32,14,14)nn.ReLU(),#activationnn.MaxPool2d(2),#outputshape(32,7,7))self.out=nn.Linear(32*7*7,10)#fullyconnectedlayer,output10classesdefforward(self,x):x=self.conv1(x)x=self.conv2(x)x=x.view(x.size(0),-1)#展平多维的卷积图成(batch_size,32*7*7)output=self.out(x)returnoutputcnn=CNN()print(cnn)#netarchitecture"""CNN((conv1):Sequential((0):Conv2d(1,16,kernel_size=(5,5),stride=(1,1),padding=(2,2))(1):ReLU()(2):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1)))(conv2):Sequential((0):Conv2d(16,32,kernel_size=(5,5),stride=(1,1),padding=(2,2))(1):ReLU()(2):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1)))(out):Linear(1568->10))"""optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)#optimizeallcnnparametersloss_func=nn.CrossEntropyLoss()#thetargetlabelisnotone-hotted#trainingandtestingforepochinrange(EPOCH):forstep,(b_x,b_y)inenumerate(train_loader):#分配batchdata,normalizexwheniteratetrain_loaderoutput=cnn(b_x)#cnnoutputloss=loss_func(output,b_y)#crossentropylossoptimizer.zero_grad()#cleargradientsforthistrainingsteploss.backward()#backpropagation,computegradientsoptimizer.step()#applygradients"""...Epoch:0|trainloss:0.0306|testaccuracy:0.97Epoch:0|trainloss:0.0147|testaccuracy:0.98Epoch:0|trainloss:0.0427|testaccuracy:0.98Epoch:0|trainloss:0.0078|testaccuracy:0.98"""test_output=cnn(test_x[:10])pred_y=torch.max(test_output,1)[1].data.numpy().squeeze()print(pred_y,'predictionnumber')print(test_y[:10].numpy(),'realnumber')"""[7210414959]predictionnumber[7210414959]realnumber"""

这个项目还是很有意思,对于初学者可以先试着对32-60行进行修改,增加网络层数。看看最后效果如何。

九层之台,起于累土。那天看到一句话,一个人把自己的事情做好,已经很不容易了。现在回想起之前安安静静在实验室的日子感觉很遥远,这半年来总是有各种各样的烦心事儿,也少了很多可以静下心来安静学习的时间。也许这就是生活吧C'est La Vie。我们总是要迎接挑战的,虽然没法回学习但是在智星云组用的GPU也是一样的好用,环境都是配置好了的,用来做实验非常节省时间和精力。有同样需求的朋友可以参考:智星云官网:http://www.ai-/,淘宝店:/公众号: 智星AI,

最后再唠叨两句,明天就是6月的最后一天了,眼看着就要过去一半了,岁月不居,时节如流。通过这次疫情也让我深刻的认识到管理好自己的时间是多么的重要。往者不可谏,来者犹可追。

PEACE

参考资料:

/docs/stable/index.html

https://morvanzhou.github.io/tutorials/machine-learning/torch/

http://www.planetb.ca/syntax-highlight-word

http://www.ai-/

/

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。