700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > Pytorch自定义数据集

Pytorch自定义数据集

时间:2023-06-07 02:45:51

相关推荐

Pytorch自定义数据集

简述

Pytorch自定义数据集方法,应该是用pytorch做算法的最基本的东西。

往往网络上给的demo都是基于torch自带的MNIST的相关类。所以,为了解决使用其他的数据集,在查阅了torch关于MNIST数据集的源码之后,很容易就可以推广到了我们自己需要的代码上。

具体操作如下:

准备工作

需要导入一些包。

from torch.utils.data import Dataset, DataLoader

再自定义一个用于当训练集合的类。

class TrainSet(Dataset):def __init__(self, X, Y):# 定义好 image 的路径self.X, self.Y = X, Ydef __getitem__(self, index):return self.X[index], self.Y[index]def __len__(self):return len(self.X)

数据预处理

之后,假设你的训练集合为[X,Y],其中X是训练数据,Y是对应的数据的标签。

首先,需要知道的是,torch能处理的数据只能是torch.Tensor,所以有必要将其他数据转换为torch.Tensor

常见的有几种数据:

np.ndarrayPIL.Image

如果是图片数据,其实也有多种情况,根据数据维度不同,有些是二维图,有些是三维图(通俗来讲,就是黑白图和彩图)。

所以,我先按照数据类型的模式将一遍,再补充关于图片的处理。

np.ndarray

np.ndarray是非常常见的格式,转成Tensor也非常简单。

torch.Tensor(array)

这样代码的返回格式就是一个Tensor

PIL.Image

import torchvision.transforms as transformstransforms.ToTensor()(image)

这样代码的返回格式就是一个Tensor

关于图片

彩色的三维图: 上面方法就已经完成了对应的数据处理的步骤灰白或者是二值的二维图:就需要将数据增加一个维度了(因为往往关于图片,所用到的算法都是包括了卷积的步骤,所以要求增加一个维度)

具体操作如下: 明显,torch.Tensor(X)这样的步骤,其实是重复了上面的将np.ndarray转成torch.Tensor的步骤。同理可以换成上面的关于PIL.Image的方法

X_tensor = torch.unsqueeze(torch.Tensor(X), 1)Y_tensor = torch.unsqueeze(torch.Tensor(Y), 1)

导入数据

建立自己的数据集。

mydataset = TrainSet(X_tensor, Y_tensor)

再把自己的数据集导入到数据加载器上:

batch_size表示用将原数据拆分之后,每batch_size个数据作为一组数据被调用。shuffle表示数据是否被洗牌(即刷新顺序,避免训练的时候多次调用结果都遇到同一batch,从而避免误差)

train_loader = DataLoader(mydataset, batch_size=10, shuffle=True)

使用的方式也非常简单:

for step, (x, y) in enumerate(train_loader):

这里的x,y就是每个batch所处理的数据。

另外,附上一个我常用的读取自定义图片的dataset类

main函数部分是对数据集做测试。

import torch.utils.data as dataimport globimport osimport torchvision.transforms as transformsfrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as npimport torchimport piexifimport imghdrclass MyDataset(data.Dataset):def __init__(self, path, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False):if resize != -1:transform = pose([transforms.Resize(resize),transforms.CenterCrop(resize),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])else:transform = pose([transforms.ToTensor(),])img_format = '*.%s' % img_typeif remove_exif:for name in glob.glob(os.path.join(path, img_format)):try:piexif.remove(name) # 去除exifexcept Exception:continue# imghdr.what(img_path) 判断是否为损坏图片if Len == -1:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format)) if imghdr.what(name)]else:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format))[:Len] if imghdr.what(name)]self.dataset = np.array(self.dataset)self.dataset = torch.Tensor(self.dataset)self.Train = Traindef __len__(self):return len(self.dataset)def __getitem__(self, idx):return self.dataset[idx]if __name__ == '__main__':path = r'D:\Software\DataSet\faces'dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg')print(len(dataset))plt.imshow(dataset[0].numpy().transpose(1, 2, 0) * 0.5 + 0.5)plt.show()print(dataset[0].max(), dataset[0].min())

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。