700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > Pytorch学习笔记7——自定义数据集

Pytorch学习笔记7——自定义数据集

时间:2023-04-04 11:31:32

相关推荐

Pytorch学习笔记7——自定义数据集

Pytorch学习笔记7——自定义数据集

1.读取数据

首先让自定义的数据集类继承 torch.utils.data.Dataset

重写 `__len__` 与 `__getitem__` 方法:前者返回样本数量,后者根据下标返回一个(图片, 标签)样本

根据 mode 参数划分数据:train 模式使用训练集(打乱后的前 60%),val 模式使用验证集(60%–80%),test 模式使用测试集(最后 20%)。

自定义数据集的读取

import torch
import os
import glob
import random
import csv
import torchvision.datasets
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from torchvision.transforms import InterpolationMode

# ImageNet per-channel statistics, shared by Normalize (in the transform
# pipeline) and by denormalize() so the two always agree.
# The original code had 0.845 for the red-channel mean, a typo for 0.485.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _load_rgb(path):
    """Open the image file at ``path`` and return it as an RGB PIL image.

    A named module-level function (instead of a lambda) keeps the transform
    pipeline stored on the dataset picklable for DataLoader worker
    processes, and the ``with`` block closes the file handle as soon as the
    pixels have been converted.
    """
    with Image.open(path) as im:
        return im.convert('RGB')


class Pokemon(Dataset):
    """Image-folder dataset with one sub-directory of ``root`` per class.

    On first use an index file ``root/images.csv`` with ``path,label`` rows
    is built from a shuffled listing of every class folder; later runs read
    that file back, so the split is stable across runs. The index is split
    60% / 20% / 20% into the train / val / test subsets.
    """

    def __init__(self, root, resize, mode):
        """
        :param root: directory containing one sub-directory per class
        :param resize: side length, in pixels, of the square output image
        :param mode: 'train' or 'val'; any other value selects the test split
        """
        super().__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Class name -> integer label. Labels are assigned in sorted-name
        # order, which makes the mapping deterministic.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        # Image paths and their numeric labels, read from the CSV index.
        images, labels = self.load_csv('images.csv')
        n = len(images)
        if mode == 'train':
            start, stop = 0, int(0.6 * n)
        elif mode == 'val':
            start, stop = int(0.6 * n), int(0.8 * n)
        else:
            start, stop = int(0.8 * n), n
        self.images = images[start:stop]
        self.labels = labels[start:stop]

        # Built once here rather than on every __getitem__ call.
        self.transform = self._build_transform()

    def load_csv(self, filename):
        """Return ``(image_paths, labels)`` read from ``root/filename``.

        If the file does not exist yet it is created first: every *.png,
        *.jpg and *.jpeg file in each class folder is collected, the list is
        shuffled, and one ``path,label`` row is written per image.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            # Pair each path with the label of the folder it was found in,
            # instead of recovering the folder name by splitting the path.
            samples = []
            for name, label in self.name2label.items():
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    for img in glob.glob(os.path.join(self.root, name, pattern)):
                        samples.append((img, label))
            random.shuffle(samples)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                writer.writerows(samples)  # e.g. pokemon/bulbasaur/00001.png,0
            print('writen into csvfile:', filename)

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def _build_transform(self):
        """Return the ``path -> normalized [3, resize, resize] tensor`` pipeline."""
        # Presumably enlarging by 1.25x before the centre crop leaves a
        # margin so that rotation corners are mostly cropped away.
        enlarged = int(self.resize * 1.25)
        steps = [_load_rgb, transforms.Resize((enlarged, enlarged))]
        if self.mode == 'train':
            # Random rotation is data augmentation. Applying it only to the
            # training split keeps validation/test accuracy deterministic.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        return transforms.Compose(steps)

    def denormalize(self, x_hat):
        """Undo the Normalize step so an image can be displayed.

        Normalize computed x_hat = (x - mean) / std, so x = x_hat * std + mean.
        mean/std are reshaped to [3, 1, 1] so they broadcast over a
        [c, h, w] image or a [b, c, h, w] batch. They are created on the
        same device and dtype as ``x_hat``.
        """
        mean = torch.as_tensor(IMAGENET_MEAN, dtype=x_hat.dtype, device=x_hat.device).view(3, 1, 1)
        std = torch.as_tensor(IMAGENET_STD, dtype=x_hat.dtype, device=x_hat.device).view(3, 1, 1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``."""
        img_path, label = self.images[idx], self.labels[idx]
        return self.transform(img_path), torch.tensor(label)


if __name__ == '__main__':
    import visdom
    import time

    viz = visdom.Visdom()

    # Alternative: torchvision.datasets.ImageFolder loads the same folder
    # layout in one line (db.class_to_idx gives the name -> label mapping):
    # tf = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
    # db = torchvision.datasets.ImageFolder(root='pokemon/pokeman', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)

    db = Pokemon('pokemon/pokeman', 64, 'train')
    # Indexing calls __getitem__: x is the normalized image tensor of sample 0.
    x, y = db[0]
    print('sample:', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    # Fetch whole batches instead of one sample at a time.
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)

自定义神经网络模型的搭建

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    The first conv carries ``stride`` (downsampling). The shortcut is the
    identity when input and output already line up. Otherwise it is a 1x1
    conv + BatchNorm projection.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        :param stride: stride of the first conv and of the projection shortcut
        """
        # Calling super().__init__() lets nn.Module set up its internal state
        # before sub-modules are registered.
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        # A projection is required whenever the identity would not match the
        # main path: a different channel count OR a downsampling stride.
        # Checking only the channels breaks ResBlk(c, c, stride > 1).
        if ch_out != ch_in or stride != 1:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, (h - 1) // stride + 1, (w - 1) // stride + 1]
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Shortcut: element-wise add of the (possibly projected) input.
        # NOTE(review): no ReLU follows the addition, unlike the original
        # ResNet design. It is left as-is to keep trained weights valid.
        out = self.extra(x) + out
        return out


class ResNet18(nn.Module):
    """Small ResNet-style classifier: stem conv, 4 residual blocks, linear head.

    Despite the name, this has 4 residual blocks, not the 18 layers of the
    original ResNet-18. For a 224x224 input, the spatial size goes
    224 -> 74 (stem) -> 25 -> 9 -> 5 -> 3.
    """

    # Spatial size of the feature map the linear head was sized for.
    _head_hw = (3, 3)

    def __init__(self, num_class):
        """
        :param num_class: number of output classes (logits per sample)
        """
        super(ResNet18, self).__init__()
        # [b, 3, h, w] => [b, 16, (h - 3) // 3 + 1, (w - 3) // 3 + 1]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # Each block increases the channel count while shrinking h and w,
        # which keeps the activation volume manageable.
        self.blk1 = ResBlk(16, 32, stride=3)    # [b, 16, ...] => [b, 32, ...]
        self.blk2 = ResBlk(32, 64, stride=3)    # [b, 32, ...] => [b, 64, ...]
        self.blk3 = ResBlk(64, 128, stride=2)   # [b, 64, ...] => [b, 128, ...]
        self.blk4 = ResBlk(128, 256, stride=2)  # [b, 128, ...] => [b, 256, ...]
        # Input: 256 channels x 3 x 3 positions. Output: num_class logits.
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """
        :param x: [b, 3, h, w] image batch
        :return: [b, num_class] logits
        """
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # Bring the feature map to the 3x3 size the linear head expects.
        # For 224x224 inputs this is an exact identity. For other input
        # sizes it makes the head usable instead of raising a shape error.
        x = F.adaptive_avg_pool2d(x, self._head_hw)
        x = torch.flatten(x, 1)
        return self.outlayer(x)


if __name__ == '__main__':
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 64, 224, 224)
    out = blk(tmp)
    print('block', out.shape)

    model = ResNet18(5)  # 5-way classification
    tmp = torch.randn(2, 3, 224, 224)
    out = model(tmp)
    print('resnet:', out.shape)

    n_params = sum(param.numel() for param in model.parameters())
    print('parameters size:', n_params)

自定义数据集的训练与测试:

import torch
from torch import optim, nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pytorch1.pt3 import Pokemon
from resnet import ResNet18

batchsz = 32
lr = 1e-3
epochs = 10

# Fall back to the CPU so the script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)

train_db = Pokemon('pokemon/pokeman', 224, mode='train')
val_db = Pokemon('pokemon/pokeman', 224, mode='val')
test_db = Pokemon('pokemon/pokeman', 224, mode='test')
# The loaders yield batches of decoded, normalized image tensors.
# Only the training loader shuffles.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=4)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=4)

viz = visdom.Visdom()


def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader``.

    The model is put in eval mode for the pass, so BatchNorm uses its
    running statistics instead of updating them. It is then restored to
    the mode it was in before. An empty dataset yields 0.0.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
    finally:
        model.train(was_training)
    return correct / total


if __name__ == '__main__':
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first validation always
    # writes a checkpoint. Otherwise a 0%-accuracy run would leave no
    # (or a stale) 'best.mdl' for the load below.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    # Initialise (clear) the visdom plots.
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            loss = criteon(model(x), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 2 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best_epoch:', best_epoch)

    # map_location lets a checkpoint saved on GPU be reloaded on CPU.
    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from ckpt!')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)

实验效果:

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。