700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > Pytorch教程(十六):FashionMNIST数据集DataSet DataLoader

Pytorch教程(十六):FashionMNIST数据集DataSet DataLoader

时间:2023-04-09 22:50:10

相关推荐

Pytorch教程(十六):FashionMNIST数据集DataSet DataLoader

torchvision.datasets

由于MNIST数据集太简单,简单的网络就可以达到99%以上的top one准确率,也就是说在这个数据集上表现较好的网络,在别的任务上表现不一定好。因此zalando research的工作人员建立了fashion mnist数据集,该数据集由衣服、鞋子等服饰组成,包含70000张图像,其中60000张训练图像加10000张测试图像,图像大小为28x28,单通道,共分10个类,如下图,每3行表示一个类。

所以我们通过torchvison来处理FashionMNIST数据集:

import torchimport torchvisionimport torchvision.transforms as transformstrain_set = torchvision.datasets.FashionMNIST(root = './data/FasionMNIST', # 将数据保存在本地什么位置train=True, # 我们希望数据用于训练集,其中6万张图片用作训练数据,1万张图片用于测试数据download=True, # 如果目录下没有文件,则自动下载transform=pose([transforms.ToTensor()]) # 我们将数据转为Tensor类型)

这样我们就完成了FashionMNIST数据的提取和转换。

如果这个过程中报错:ImportError: IProgress not found. Please update jupyter and ipywidgets.。一般是jupyter的版本有些低了,可能是你默认的环境,所以重装以下就好了:

# 可以先用你的环境 conda activate xx# 卸载jupyter:pip install --upgrade jupyter

访问单独某个训练数据:

torchvision.dataloader

dataloader使我们能够访问数据并提供查询功能。

train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)

通过train_loader方式得到的batch包含图像的张量是4维的张量,形状是[10, 1, 28, 28],这告诉我们有10个图像,他们都有1个单独的颜色通道,高度宽度都是28;对于包含标签的张量,他的长度是10,每10个图像为一批数据。

现在让我们看看如何使用torchvison.utils.make_grid函数一次性的画出整批图像:

我们可以看到,我们已经使用torchvision.utils.make_grid函数创建了一个网络,我们把图像张量作为第一个参数,nrow=10这样我们所有的图像就会沿着一行显示,nrow参数指定每一行的图像数量,因为我们的batch_size=10,这就给我们了一排图像,我们使用np.transpose(grid, (1,2,0)),这样轴就满足了图像的功能需要的规格。

现在我们知道了datasetdataloader之间如何交互的了。现在试试如何批量处理数据:

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。