700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > Pytorch自定义数据集(Custom Dataset)的读取方式

Pytorch自定义数据集(Custom Dataset)的读取方式

时间:2020-02-22 18:24:24

相关推荐

Pytorch自定义数据集(Custom Dataset)的读取方式

外部数据集的接入

相关模块:torchvision具体操作自定义数据集的基础方法:使用 Torchvision Transforms 结合 Pandas 使用 __getitem__()使用 Dataloader 读取自定义数据集Stanford Dogs 数据集自定义实例FaceLandmarks实例

相关模块:torchvision

torchvision 是独立于pytorch 之外的图像操作库

具体介绍详见:DrHW的文章

torchvision主要包括一下几个包:1

torchvision.datasets:几个常用视觉数据集,可以下载和加载这里主要的高级用法就是可以看源码如何自己写自己的Dataset的子类

这部分就是本文要介绍的重点torchvision.models: 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。torchvision.transforms: 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到tensor ,numpy 数组到tensor , tensor 到 图像等。torchvision.utils: 用于把形似(3 x H x W)的张量保存到硬盘中,给一个mini-batch的图像可以产生一个图像格网。

shape = (channel, height, weight)

具体操作

自定义数据集的基础方法:

引文2

"""inout pipline for custom dataset"""from torch.utils.data.dataset import Datasetclass CustomDataset(Dataset):def __init__(self):"""一些初始化过程写在这里"""# TODO# 1. Initialize file paths or a list of file names. passdef __getitem__(self, index):"""返回数据和标签,可以这样显示调用:img, label = MyCustomDataset.__getitem__(99)"""# TODO# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).# 2. Preprocess the data (e.g. torchvision.Transform).# 3. Return a data pair (e.g. image and label).passdef __len__(self):"""返回所有数据的数量"""# You should change 9 to the total size of your dataset.return 9 # e.g. 9 is size of dataset

使用 Torchvision Transforms
方法一:

from torch.utils.data.dataset import Datasetfrom torchvision import transformsclass MyCustomDataset(Dataset):def __init__(self, ..., transforms=None):# stuff...self.transforms = transformsdef __getitem__(self, index):# stuff...data = # 一些读取的数据if self.transforms is not None:data = self.transforms(data)# 如果 transform 不为 None,则进行 transform 操作return (img, label)def __len__(self):return count if __name__ == \'__main__\':# 定义我们的 transforms (1)transformations = pose([transforms.CenterCrop(100), transforms.ToTensor()])# 创建 datasetcustom_dataset = MyCustomDataset(..., transformations)

方法二:

有些人不喜欢将transform写在Dataset外, 即在Dataset内定义transform

from torch.utils.data.dataset import Datasetfrom torchvision import transformsclass MyCustomDataset(Dataset):def __init__(self, ...):# stuff...# (2) 一种方法是单独定义 transformself.center_crop = transforms.CenterCrop(100)self.to_tensor = transforms.ToTensor()# (3) 或者写成下面这样 self.transformations = \pose([transforms.CenterCrop(100),transforms.ToTensor()])def __getitem__(self, index):# stuff...data = #一些读取的数据# 当第二次调用 transform 时,调用的是 __call__()data = self.center_crop(data) # (2)data = self.to_tensor(data) # (2)# 或者写成下面这样data = self.trasnformations(data) # (3)# 注意 (2) 和 (3) 中只需要实现一种return (img, label)def __len__(self):return countif __name__ == \'__main__\':custom_dataset = MyCustomDataset(...)

结合 Pandas 使用getitem()

另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下getitem() 函数。

class CustomDatasetFromCSV(Dataset):def __init__(self, csv_path, height, width, transforms=None):"""Args:csv_path (string): csv 文件路径height (int): 图像高度width (int): 图像宽度transform: transform 操作"""self.data = pd.read_csv(csv_path)self.labels = np.asarray(self.data.iloc[:, 0])self.height = heightself.width = widthself.transforms = transformdef __getitem__(self, index):single_image_label = self.labels[index]# 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28]) img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype(\'uint8\')# 把 numpy array 格式的图像转换成灰度 PIL imageimg_as_img = Image.fromarray(img_as_np)img_as_img = img_as_img.convert(\'L\')# 将图像转换成 tensorif self.transforms is not None:img_as_tensor = self.transforms(img_as_img)# 返回图像及其 labelreturn (img_as_tensor, single_image_label)def __len__(self):return len(self.data.index)if __name__ == "__main__":transformations = pose([transforms.ToTensor()])custom_mnist_from_csv = \CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\', 28, 28, transformations)

使用 Dataloader 读取自定义数据集

PyTorch 中的 Dataloader 只是调用getitem() 方法并组合成 batch,我们可以这样调用:

...if __name__ == "__main__":# 定义 transformstransformations = pose([transforms.ToTensor()])# 自定义数据集custom_mnist_from_csv = \CustomDatasetFromCSV(\'../data/mnist_in_csv.csv\',28, 28,transformations)# 定义 data loadermn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,batch_size=10,shuffle=False)for images, labels in mn_dataset_loader:# 将数据传给网络模型

需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。

Stanford Dogs 数据集自定义实例

from torch.utils.data.dataset import Datasetfrom torchvision import transformsclass MyDateset(Dataset):def __init__(self, file_folder, is_test=False, transform=None):self.img_folder_path = '../input/images/Images/'self.annotation_folder_path = '../input/annotations/Annotation/'self.file_folder = file_folderself.transform = transform#self.transform = poseself.is_test = is_testdef __getitem__(self, idx):file = self.file_folder[idx]img_path = self.img_folder_path + fileimg = Image.open(img_path).convert('RGB')if not self.is_test:annotation_path = self.annotation_folder_path + file.split('.')[0]with open(annotation_path) as f:annotation = f.read()xy = self.get_xy(annotation)box = torch.FloatTensor(list(xy))new_box = self.box_resize(box, img)if self.transform is not None:img = self.transform(img)return img, new_boxelse:if self.transform is not None:img = self.transform(img)return imgdef __len__(self):return len(self.file_folder)def get_xy(self, annotation):xmin = int(re.findall('(?<=<xmin>)[0-9]+?(?=</xmin>)', annotation)[0])xmax = int(re.findall('(?<=<xmax>)[0-9]+?(?=</xmax>)', annotation)[0])ymin = int(re.findall('(?<=<ymin>)[0-9]+?(?=</ymin>)', annotation)[0])ymax = int(re.findall('(?<=<ymax>)[0-9]+?(?=</ymax>)', annotation)[0])return xmin, ymin, xmax, ymaxdef show_box(self):file = random.choice(self.file_folder)annotation_path = self.annotation_folder_path + file.split('.')[0]img_box = Image.open(self.img_folder_path + file)with open(annotation_path) as f:annotation = f.read()draw = ImageDraw.Draw(img_box)xy = self.get_xy(annotation)print('bbox:', xy)draw.rectangle(xy=[xy[:2], xy[2:]])return img_boxdef box_resize(self, box, img, dims=(332, 332)):old_dims = torch.FloatTensor([img.width, img.height, img.width, img.height]).unsqueeze(0)new_box = box / old_dimsnew_dims = torch.FloatTensor([dims[1], dims[0], dims[1], dims[0]]).unsqueeze(0)new_box = new_box * new_dimsreturn new_box

FaceLandmarks实例

class FaceLandmarksDataset(Dataset):"""Face Landmarks dataset."""def __init__(self, csv_file, root_dir, transform=None):"""Args:csv_file (string): Path to the csv file with annotations.root_dir (string): Directory with all the images.transform (callable, optional): Optional transform to be appliedon a sample."""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}if self.transform:sample = self.transform(sample)return sample

参考文献:

yunjey的github代码pytorch官方教程数据集:Stanford Dogs Datasetpytorch中文网:PyTorch 中自定义数据集的读取方法小结

/yjphhw/p/9773333.html ↩︎

/yunjey/pytorch-tutorial/ ↩︎

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。