700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 【小白学习PyTorch教程】十七 PyTorch 中 数据集torchvision和torchtext

【小白学习PyTorch教程】十七 PyTorch 中 数据集torchvision和torchtext

时间:2021-01-08 05:38:23

相关推荐

【小白学习PyTorch教程】十七  PyTorch 中 数据集torchvision和torchtext

@Author:Runsen

对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。

之前使用 torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集

Torchvision 中的数据集

MNIST

MNIST 是一个由标准化和中心裁剪的手写图像组成的数据集。它有超过 60,000 张训练图像和 10,000 张测试图像。这是用于学习和实验目的最常用的数据集之一。要加载和使用数据集,使用以下语法导入:torchvision.datasets.MNIST()

Fashion MNIST

Fashion MNIST数据集类似于MNIST,但该数据集包含T恤、裤子、包包等服装项目,而不是手写数字,训练和测试样本数分别为60,000和10,000。要加载和使用数据集,使用以下语法导入:torchvision.datasets.FashionMNIST()

CIFAR

CIFAR数据集有两个版本,CIFAR10和CIFAR100。CIFAR10 由 10 个不同标签的图像组成,而 CIFAR100 有 100 个不同的类。这些包括常见的图像,如卡车、青蛙、船、汽车、鹿等。

torchvision.datasets.CIFAR10()torchvision.datasets.CIFAR100()

COCO

COCO数据集包含超过 100,000 个日常对象,如人、瓶子、文具、书籍等。这个图像数据集广泛用于对象检测和图像字幕应用。下面是可以加载 COCO 的位置​​:torchvision.datasets.CocoCaptions()

EMNIST

EMNIST数据集是 MNIST 数据集的高级版本。它由包括数字和字母的图像组成。如果您正在处理基于从图像中识别文本的问题,EMNIST是一个不错的选择。下面是可以加载 EMNIST的位置​​::torchvision.datasets.EMNIST()

IMAGE-NET

ImageNet 是用于训练高端神经网络的旗舰数据集之一。它由分布在 10,000 个类别中的超过 120 万张图像组成。通常,这个数据集加载在高端硬件系统上,因为单独的 CPU 无法处理这么大的数据集。下面是加载 ImageNet 数据集的类:torchvision.datasets.ImageNet()

Torchtext 中的数据集

IMDB

IMDB是一个用于情感分类的数据集,其中包含一组 25,000 条高度极端的电影评论用于训练,另外 25,000 条用于测试。使用以下类加载这些数据torchtext:torchtext.datasets.IMDB()

WikiText2

WikiText2语言建模数据集是一个超过 1 亿个标记的集合。它是从维基百科中提取的,并保留了标点符号和实际的字母大小写。它广泛用于涉及长期依赖的应用程序。可以从torchtext以下位置加载此数据:torchtext.datasets.WikiText2()

除了上述两个流行的数据集,torchtext库中还有更多可用的数据集,例如 SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k 等。

深入查看 MNIST 数据集

MNIST 是最受欢迎的数据集之一。现在我们将看到 PyTorch 如何从 pytorch/vision 存储库加载 MNIST 数据集。让我们首先下载数据集并将其加载到名为 的变量中data_train

from torchvision.datasets import MNIST# Download MNIST data_train = MNIST('~/mnist_data', train=True, download=True)import matplotlib.pyplot as pltrandom_image = data_train[0][0]random_image_label = data_train[0][1]# Print the Image using Matplotlibplt.imshow(random_image)print("The label of the image is:", random_image_label)

DataLoader加载MNIST

下面我们使用DataLoader该类加载数据集,如下所示。

import torchfrom torchvision import transformsdata_train = torch.utils.data.DataLoader(MNIST('~/mnist_data', train=True, download=True, transform = pose([transforms.ToTensor()])),batch_size=64,shuffle=True)for batch_idx, samples in enumerate(data_train):print(batch_idx, samples)

CUDA加载

我们可以启用 GPU 来更快地训练我们的模型。现在让我们使用CUDA加载数据时可以使用的(GPU 支持 PyTorch)的配置。

device = "cuda" if torch.cuda.is_available() else "cpu"kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('/files/', train=True, download=True),batch_size=batch_size_train, **kwargs)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('files/', train=False, download=True),batch_size=batch_size, **kwargs)

ImageFolder

ImageFolder是一个通用数据加载器类torchvision,可帮助加载自己的图像数据集。处理一个分类问题并构建一个神经网络来识别给定的图像是apple还是orange。要在 PyTorch 中执行此操作,第一步是在默认文件夹结构中排列图像,如下所示:

root├── orange│ ├── orange_image1.png│ └── orange_image1.png├── apple│ └── apple_image1.png│ └── apple_image2.png│ └── apple_image3.png

可以使用ImageLoader该类加载所有这些图像。

torchvision.datasets.ImageFolder(root, transform)

transforms

PyTorch 转换定义了简单的图像转换技术,可将整个数据集转换为独特的格式。

如果是一个包含不同分辨率的不同汽车图片的数据集,在训练时,我们训练数据集中的所有图像都应该具有相同的分辨率大小。如果我们手动将所有图像转换为所需的输入大小,则很耗时,因此我们可以使用transforms;使用几行 PyTorch 代码,我们数据集中的所有图像都可以转换为所需的输入大小和分辨率。

现在让我们加载 CIFAR10torchvision.datasets并应用以下转换:

将所有图像调整为 32×32对图像应用中心裁剪变换将裁剪后的图像转换为张量标准化图像

import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport numpy as nptransform = pose([# resize 32×32transforms.Resize(32),# center-crop裁剪变换transforms.CenterCrop(32),# to-tensortransforms.ToTensor(),# normalize 标准化transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=False)

在 PyTorch 中创建自定义数据集

下面将创建一个由数字和文本组成的简单自定义数据集。需要封装Dataset 类中的__getitem__()__len__()方法。

__getitem__()方法通过索引返回数据集中的选定样本。__len__()方法返回数据集的总大小。

下面是曾经封装FruitImagesDataset数据集的代码,基本是比较好的 PyTorch 中创建自定义数据集的模板。

import osimport numpy as npimport cv2import torchimport matplotlib.patches as patchesimport albumentations as Afrom albumentations.pytorch.transforms import ToTensorV2from matplotlib import pyplot as pltfrom torch.utils.data import Datasetfrom xml.etree import ElementTree as etfrom torchvision import transforms as torchtransclass FruitImagesDataset(torch.utils.data.Dataset):def __init__(self, files_dir, width, height, transforms=None):self.transforms = transformsself.files_dir = files_dirself.height = heightself.width = widthself.imgs = [image for image in sorted(os.listdir(files_dir))if image[-4:] == '.jpg']self.classes = ['_','apple', 'banana', 'orange']def __getitem__(self, idx):img_name = self.imgs[idx]image_path = os.path.join(self.files_dir, img_name)# reading the images and converting them to correct size and colorimg = cv2.imread(image_path)img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA)# diving by 255img_res /= 255.0# annotation fileannot_filename = img_name[:-4] + '.xml'annot_file_path = os.path.join(self.files_dir, annot_filename)boxes = []labels = []tree = et.parse(annot_file_path)root = tree.getroot()# cv2 image gives size as height x widthwt = img.shape[1]ht = img.shape[0]# box coordinates for xml files are extracted and corrected for image size givenfor member in root.findall('object'):labels.append(self.classes.index(member.find('name').text))# bounding boxxmin = int(member.find('bndbox').find('xmin').text)xmax = int(member.find('bndbox').find('xmax').text)ymin = int(member.find('bndbox').find('ymin').text)ymax = int(member.find('bndbox').find('ymax').text)xmin_corr = (xmin / wt) * self.widthxmax_corr = (xmax / wt) * self.widthymin_corr = (ymin / ht) * self.heightymax_corr = (ymax / ht) * self.heightboxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])# convert boxes into a torch.Tensorboxes = torch.as_tensor(boxes, dtype=torch.float32)# getting the areas of the boxesarea = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])# suppose all instances are not crowdiscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)labels = torch.as_tensor(labels, dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelstarget["area"] = areatarget["iscrowd"] = iscrowd# image_idimage_id = torch.tensor([idx])target["image_id"] = image_idif self.transforms:sample = self.transforms(image=img_res,bboxes=target['boxes'],labels=labels)img_res = sample['image']target['boxes'] = torch.Tensor(sample['bboxes'])return img_res, targetdef __len__(self):return len(self.imgs)def get_transform(train):if train:return pose([A.HorizontalFlip(0.5),ToTensorV2(p=1.0)], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})else:return pose([ToTensorV2(p=1.0)], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})files_dir = '../input/fruit-images-for-object-detection/train_zip/train'test_dir = '../input/fruit-images-for-object-detection/test_zip/test'dataset = FruitImagesDataset(train_dir, 480, 480)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。