700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > PyTorch 中自定义数据集的读取方法

PyTorch 中自定义数据集的读取方法

时间:2019-02-17 08:56:49

相关推荐

PyTorch 中自定义数据集的读取方法

显然我们在学习深度学习时,不能只局限于通过使用官方提供的MNSIT、CIFAR-10、CIFAR-100这样的数据集,很多时候我们还是需要根据自己遇到的实际问题自己去搜集数据,然后制作数据集(收集数据集的方法有很多,这里就不过多的展开了)。这里只介绍数据集的读取。

自定义数据集的方法

首先创建一个Dataset类

在代码中:

definit() 一些初始化的过程写在这个函数下

deflen() 返回所有数据的数量,比如我们这里将数据划分好之后,这里仅仅返回的是被处理后的关系

defgetitem() 回数据和标签

补充代码

上述已经将框架打出来了,接下来就是将框架填充完整就行了,下面是完整的代码,代码的解释说明我也已经写在其中了

# -*- coding: utf-8 -*-# @Author : 胡子旋# @Email :1017190168@import torchimport os,globimport visdomimport timeimport torchvisionimport random,csvfrom torch.utils.data import Dataset,DataLoaderfrom torchvision import transformsfrom PIL import Imageclass pokemom(Dataset):def __init__(self,root,resize,mode,):super(pokemom,self).__init__()# 保存参数self.root=rootself.resize=resize# 给每一个类做映射self.name2label={} # "squirtle":0 ,"pikachu":1……for name in sorted(os.listdir(os.path.join(root))):# 过滤掉文件夹if not os.path.isdir(os.path.join(root,name)):continue# 保存在表中;将最长的映射作为最新的元素的label的值self.name2label[name]=len(self.name2label.keys())print(self.name2label)# 加载文件self.images,self.labels=self.load_csv('images.csv')# 裁剪数据if mode=='train':self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合elif mode=='val':self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]else:self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾self.labels = self.labels[int(0.8 * len(self.labels)):]# image+label 的路径def load_csv(self,filename):# 将所有的图片加载进来# 如果不存在的话才进行创建if not os.path.exists(os.path.join(self.root,filename)):images=[]for name in self.name2label.keys():images+=glob.glob(os.path.join(self.root,name,'*.png'))images+=glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))print(len(images),images)# 1167 'pokeman\\bulbasaur\\00000000.png'# 将文件以上述的格式保存在csv文件内random.shuffle(images)with open(os.path.join(self.root,filename),mode='w',newline='') as f:writer=csv.writer(f)for img in images: # 'pokeman\\bulbasaur\\00000000.png'name=img.split(os.sep)[-2]label=self.name2label[name]writer.writerow([img,label])print("write into csv into :",filename)# 如果存在的话就直接的跳到这个地方images,labels=[],[]with open(os.path.join(self.root, filename)) as f:reader=csv.reader(f)for row in reader:# 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象img,label=row# 将label转码为int类型label=int(label)images.append(img)labels.append(label)# 保证images和labels的长度是一致的assert len(images)==len(labels)return images,labels# 返回数据的数量def __len__(self):return len(self.images) # 返回的是被裁剪之后的关系def denormalize(self, x_hat):mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)std = torch.tensor(std).unsqueeze(1).unsqueeze(1)# print(mean.shape, std.shape)x = x_hat * std + meanreturn x# 返回idx的数据和当前图片的labeldef __getitem__(self,idx):# idex-[0-总长度]# retrun images,labels# 将图片,label的路径取出来# 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png'# 然而label得到的则是 0,1,2 这样的整形的格式img,label=self.images[idx],self.labels[idx]tf=pose([lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据# 进行数据加强transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),# 随机旋转transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度# 中心裁剪transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])img=tf(img)label=torch.tensor(label)return img,labeldef main():# 验证工作viz=visdom.Visdom()db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看# 可视化样本x,y=next(iter(db))print('sample:',x.shape,y.shape,y)viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))# 加载batch_size的数据loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)for x,y in loader:viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))# 每一次加载后,休息10stime.sleep(10)if __name__ == '__main__':main()

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。