700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > pytorch读取常用数据集dataset实现例子

pytorch读取常用数据集dataset实现例子

时间:2019-01-02 03:12:41

相关推荐

pytorch读取常用数据集dataset实现例子

MNIST示例

定义

import os
import warnings
from typing import Any, Callable, Dict, Optional, Tuple
from urllib.error import URLError

import torch
from PIL import Image
from torchvision.datasets.mnist import read_image_file, read_label_file
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.datasets.vision import VisionDataset


class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset. The raw idx files are read
            from ``MNIST/raw/``. For backward compatibility, if both
            ``MNIST/processed/training.pt`` and ``MNIST/processed/test.pt`` exist,
            the split is loaded from those instead.
        train (bool, optional): If True, creates dataset from the training files
            (``train-*`` / ``training.pt``), otherwise from the test files
            (``t10k-*`` / ``test.pt``).
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Download locations, tried in order for every file until one succeeds.
    mirrors = [
        'http://yann.lecun.com/exdb/mnist/',
        'https://ossci-datasets.s3.amazonaws.com/mnist/',
    ]

    # (archive file name, md5 checksum) of every raw file.
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    # Deprecated aliases kept so older code keeps working.
    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train  # training set or test set

        # Prefer the legacy processed .pt files when they are present.
        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self):
        """Return True if both legacy processed files exist."""
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        """Read images and labels of the selected split from the raw idx files."""
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        """Return True if every extracted raw file (archive name without .gz) exists."""
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(filename))[0]))
            for filename, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already."""
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = "{}{}".format(mirror, filename)
                try:
                    print("Downloading {}".format(url))
                    download_and_extract_archive(
                        url, download_root=self.raw_folder,
                        filename=filename,
                        md5=md5
                    )
                except URLError as error:
                    print("Failed to download (trying next):\n{}".format(error))
                    continue
                finally:
                    print()
                break
            else:
                # No mirror succeeded for this file.
                raise RuntimeError("Error downloading {}".format(filename))

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")

FashionMNIST、KMNIST、QMNIST 也都在 torchvision.datasets 中,可以同样直接读取。

可通过下面的方式加载

import torch
import torchvision

# Example batch sizes. The original snippet used these names without defining
# them; adjust to your hardware.
batch_size_train = 64
batch_size_test = 1000

# Shared preprocessing: convert to a tensor, then normalize. The constants are
# presumably the MNIST pixel mean/std — confirm before reusing for other data.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# Evaluation does not need shuffling; a fixed order keeps results reproducible.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size_test, shuffle=False)

CIFAR

import os
import pickle
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.datasets.vision import VisionDataset


class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: if the batch files or the metadata file are missing or fail
            their md5 check.
    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # [file name, md5] of every pickled batch.
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        # Verifies the md5 of every batch file before any of them is unpickled below.
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data: Any = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                # NOTE: pickle can execute arbitrary code; this is only acceptable
                # because the file's md5 was verified by _check_integrity above.
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                # CIFAR-10 batches use 'labels'; otherwise fall back to 'fine_labels'.
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self) -> None:
        """Load class names from the (md5-checked) metadata file."""
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    def _check_integrity(self) -> bool:
        """Return True if every train and test batch file exists with the expected md5."""
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self) -> None:
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")

CIFAR100 同理,对应 torchvision.datasets.CIFAR100,用法与 CIFAR10 相同。

预定义数据集

PyTorch 1.8 以后,torchvision.datasets 中其余已预定义的数据集有:

Caltech101、Caltech256(文件名 caltech.py)、STL10(stl10.py)、SVHN(svhn.py)、CelebA(celeba.py)、INaturalist(inaturalist.py)、Omniglot(omniglot.py)、Places365(places365.py)

以下数据集需要自己手动下载完整数据:

LSUN(LSUNClass)、ImageNet

补充

Food101

from pathlib import Path
import json
from typing import Any, Tuple, Callable, Optional

import torch
import PIL.Image
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
    download_url,
    verify_str_arg,
)
from torchvision.datasets.vision import VisionDataset


class Food101(VisionDataset):
    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.

    The Food-101 is a challenging data set of 101 food categories, with 101'000 images.
    For each class, 250 manually reviewed test images are provided as well as 750 training images.
    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
    rescaled to have a maximum side length of 512 pixels.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.

    Raises:
        RuntimeError: if the ``meta`` or ``images`` folder is missing after the optional download.
    """

    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
    _MD5 = "85eeb15f3717b99a5da872d97d918f87"

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = Path(self.root) / "food-101"
        self._meta_folder = self._base_folder / "meta"
        self._images_folder = self._base_folder / "images"
        # Human-readable class names indexed by label id. The order must match
        # ``self.classes = sorted(metadata.keys())`` below. Plain string sorting
        # puts "_" before lowercase letters, so assuming the keys are snake_case
        # folder names (e.g. ``cheese_plate`` / ``cheesecake`` — confirm against
        # meta/*.json), "Cheese plate" must come before "Cheesecake".
        self.class_names_str = [
            'Apple pie', 'Baby back ribs', 'Baklava', 'Beef carpaccio', 'Beef tartare', 'Beet salad',
            'Beignets', 'Bibimbap', 'Bread pudding', 'Breakfast burrito', 'Bruschetta', 'Caesar salad',
            'Cannoli', 'Caprese salad', 'Carrot cake', 'Ceviche', 'Cheese plate', 'Cheesecake',
            'Chicken curry', 'Chicken quesadilla', 'Chicken wings', 'Chocolate cake', 'Chocolate mousse',
            'Churros', 'Clam chowder', 'Club sandwich', 'Crab cakes', 'Creme brulee', 'Croque madame',
            'Cup cakes', 'Deviled eggs', 'Donuts', 'Dumplings', 'Edamame', 'Eggs benedict', 'Escargots',
            'Falafel', 'Filet mignon', 'Fish and chips', 'Foie gras', 'French fries', 'French onion soup',
            'French toast', 'Fried calamari', 'Fried rice', 'Frozen yogurt', 'Garlic bread', 'Gnocchi',
            'Greek salad', 'Grilled cheese sandwich', 'Grilled salmon', 'Guacamole', 'Gyoza', 'Hamburger',
            'Hot and sour soup', 'Hot dog', 'Huevos rancheros', 'Hummus', 'Ice cream', 'Lasagna',
            'Lobster bisque', 'Lobster roll sandwich', 'Macaroni and cheese', 'Macarons', 'Miso soup',
            'Mussels', 'Nachos', 'Omelette', 'Onion rings', 'Oysters', 'Pad thai', 'Paella', 'Pancakes',
            'Panna cotta', 'Peking duck', 'Pho', 'Pizza', 'Pork chop', 'Poutine', 'Prime rib',
            'Pulled pork sandwich', 'Ramen', 'Ravioli', 'Red velvet cake', 'Risotto', 'Samosa', 'Sashimi',
            'Scallops', 'Seaweed salad', 'Shrimp and grits', 'Spaghetti bolognese', 'Spaghetti carbonara',
            'Spring rolls', 'Steak', 'Strawberry shortcake', 'Sushi', 'Tacos', 'Takoyaki', 'Tiramisu',
            'Tuna tartare', 'Waffles',
        ]

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self._labels = []
        self._image_files = []
        # meta/<split>.json maps each class key to a list of image paths relative
        # to images/, without the ".jpg" suffix and using "/" separators.
        with open(self._meta_folder / f"{self._split}.json") as f:
            metadata = json.load(f)

        self.classes = sorted(metadata.keys())
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

        for class_label, im_rel_paths in metadata.items():
            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
            self._image_files += [
                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
            ]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx) -> Tuple[Any, Any]:
        """Return ``(image, label)``; the image is converted to RGB before transforms."""
        image_file, label = self._image_files[idx], self._labels[idx]
        # Close the file handle deterministically instead of relying on PIL/GC.
        with PIL.Image.open(image_file) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True if both the ``meta`` and ``images`` folders exist."""
        return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))

    def _download(self) -> None:
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)


def examine_count(counter, name="train"):
    """Print the number of samples per label in ``counter``."""
    print(f"in the {name} set")
    for label in counter:
        print(label, counter[label])


if __name__ == "__main__":
    # Optional debug listing of label names, one per line.
    label_names = []
    with open('debug/food101_labels.txt') as f:
        for name in f:
            label_names.append(name.strip())
    print(label_names)

    train_set = Food101(root="/nobackup/dataset_myf", split="train", download=True)
    test_set = Food101(root="/nobackup/dataset_myf", split="test")
    print(f"train set len {len(train_set)}")
    print(f"test set len {len(test_set)}")

    from collections import Counter
    train_label_count = Counter(train_set._labels)
    test_label_count = Counter(test_set._labels)
    # examine_count(train_label_count, name = "train")
    # examine_count(test_label_count, name = "test")

    kwargs = {'num_workers': 4, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(test_set, batch_size=16, shuffle=False, **kwargs)

Flowers102(Oxford 102 Flower)

from pathlib import Path
from typing import Any, Tuple, Callable, Optional

import torch
import PIL.Image
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
    download_url,
    verify_str_arg,
)
from torchvision.datasets.vision import VisionDataset


class Flowers102(VisionDataset):
    """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
    flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
    between 40 and 258 images.

    The images have large scale, pose and light variations. In addition, there are categories that
    have large variations within the category, and several very similar categories.

    Targets are 0-based class indices (0..101), so ``class_names_str[target]`` is the class name.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: if the image folder is missing or a ``.mat`` file fails its md5 check.
    """

    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
    _file_dict = {  # filename, md5
        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
    }
    # Maps a split name to its key in setid.mat.
    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}

    # Class names indexed by 0-based label, from
    # https://gist.github.com/JosephKJ/94c7728ed1a8e0cd87fe6a029769cde1
    label_names = [
        'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold',
        'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon',
        "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower',
        'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower',
        'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers',
        'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist',
        'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort',
        'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia',
        'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion',
        'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura',
        'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone',
        'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus',
        'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose',
        'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani',
        'clematis', 'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen',
        'watercress', 'canna lily', 'hippeastrum', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea',
        'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper',
        'blackberry lily',
    ]

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        self._base_folder = Path(self.root) / "flowers-102"
        self._images_folder = self._base_folder / "jpg"

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        from scipy.io import loadmat

        set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
        image_ids = set_ids[self._splits_map[self._split]].tolist()

        labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
        # imagelabels.mat stores 1-based class ids (1..102). Shift them to 0-based so
        # they index ``label_names`` and fit a 102-way classifier, the same way the
        # other datasets in this file subtract 1. Image ids stay 1-based, matching
        # the image_00001.jpg file naming.
        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))

        self._labels = []
        self._image_files = []
        for image_id in image_ids:
            self._labels.append(image_id_to_label[image_id])
            self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")

        self.class_names_str = self.label_names

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx) -> Tuple[Any, Any]:
        """Return ``(image, label)``; the image is converted to RGB before transforms."""
        image_file, label = self._image_files[idx], self._labels[idx]
        # Close the file handle deterministically instead of relying on PIL/GC.
        with PIL.Image.open(image_file) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_integrity(self):
        """Return True if the image folder exists and both ``.mat`` files match their md5."""
        if not (self._images_folder.exists() and self._images_folder.is_dir()):
            return False

        for key in ["label", "setid"]:
            filename, md5 = self._file_dict[key]
            if not check_integrity(str(self._base_folder / filename), md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            return
        download_and_extract_archive(
            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
            str(self._base_folder),
            md5=self._file_dict["image"][1],
        )
        for key in ["label", "setid"]:
            filename, md5 = self._file_dict[key]
            download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)


def examine_count(counter, name="train"):
    """Print the number of samples per label in ``counter``."""
    print(f"in the {name} set")
    for label in counter:
        print(label, counter[label])


if __name__ == "__main__":
    # label_names = []
    # with open('debug/flowers102_labels.txt') as f:
    #     for name in f:
    #         label_names.append(name.strip()[1:-1])
    # print(label_names)

    train_set = Flowers102(root="/nobackup/dataset_myf", split="train", download=True)
    val_set = Flowers102(root="/nobackup/dataset_myf", split="val")
    test_set = Flowers102(root="/nobackup/dataset_myf", split="test")

    from collections import Counter
    train_label_count = Counter(train_set._labels)
    val_label_count = Counter(val_set._labels)
    test_label_count = Counter(test_set._labels)
    examine_count(train_label_count, name="train")
    examine_count(val_label_count, name="val")
    examine_count(test_label_count, name="test")

    kwargs = {'num_workers': 4, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=16, shuffle=False, **kwargs)

StanfordCars(Car196)

import pathlib
from typing import Callable, Optional, Any, Tuple

from PIL import Image
import torch
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
    download_url,
    verify_str_arg,
)
from torchvision.datasets.vision import VisionDataset


class StanfordCars(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (string): Root directory of dataset
        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: if scipy is not installed, or the dataset files are missing.
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        try:
            import scipy.io as sio
        except ImportError as err:
            # Keep the original ImportError as the cause for easier debugging.
            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") from err

        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "stanford_cars"
        devkit = self._base_folder / "devkit"

        # Train annotations ship in the devkit; labelled test annotations are a separate file.
        if self._split == "train":
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # List of (image path, 0-based class id) pairs.
        self._samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
            )
            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
        ]

        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.class_names_str = self.classes

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Returns pil_image and class_id for given index"""
        image_path, target = self._samples[idx]
        # Close the file handle deterministically instead of relying on PIL/GC.
        with Image.open(image_path) as img:
            pil_image = img.convert("RGB")

        if self.transform is not None:
            pil_image = self.transform(pil_image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pil_image, target

    def download(self) -> None:
        """Download the devkit plus the archives for the current split only."""
        if self._check_exists():
            return

        download_and_extract_archive(
            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
            download_root=str(self._base_folder),
            md5="c3b158d763b6e2245038c8ad08e45376",
        )
        if self._split == "train":
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
                download_root=str(self._base_folder),
                md5="065e5b463ae28d29e77c1b4b166cfe61",
            )
        else:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
                download_root=str(self._base_folder),
                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
            )
            download_url(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
                root=str(self._base_folder),
                md5="b0a2b23655a3edd16d84508592a98d10",
            )

    def _check_exists(self) -> bool:
        """Return True if the devkit, the split's annotation file and its image folder exist."""
        if not (self._base_folder / "devkit").is_dir():
            return False

        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()


def examine_count(counter, name="train"):
    """Print the number of samples per label in ``counter``."""
    print(f"in the {name} set")
    for label in counter:
        print(label, counter[label])


if __name__ == "__main__":
    train_set = StanfordCars(root="/nobackup/dataset_myf", split="train", download=True)
    test_set = StanfordCars(root="/nobackup/dataset_myf", split="test", download=True)
    print(f"train set len {len(train_set)}")
    print(f"test set len {len(test_set)}")

    from collections import Counter
    train_label_count = Counter([label for img, label in train_set._samples])
    test_label_count = Counter([label for img, label in test_set._samples])
    examine_count(train_label_count, name="train")
    examine_count(test_label_count, name="test")

    kwargs = {'num_workers': 4, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(test_set, batch_size=16, shuffle=False, **kwargs)

CUB200

import os

import matplotlib.image  # reads images as numpy arrays
import numpy as np
import torch
from PIL import Image
from torchvision import transforms


def _read_last_column(path):
    """Return the last space-separated field of every non-blank line of the text file ``path``."""
    # strip() instead of line[:-1]: the final line may lack a trailing newline,
    # and chopping it would corrupt the last entry.
    with open(path) as f:
        return [line.strip().split(' ')[-1] for line in f if line.strip()]


class CUB(torch.utils.data.Dataset):
    """CUB-200 bird dataset read from its text index files.

    ``root`` must contain ``images.txt``, ``image_class_labels.txt``,
    ``train_test_split.txt`` and an ``images/`` folder. All images of the
    selected split are decoded into memory when the dataset is constructed.

    Args:
        root: dataset root directory.
        is_train: if True use the images flagged 1 in ``train_test_split.txt``,
            otherwise those flagged 0.
        data_len: optional cap on the number of samples kept (``None`` keeps all).
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the integer label.
    """

    def __init__(self, root, is_train=True, data_len=None,
                 transform=None, target_transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform
        self.target_transform = target_transform

        # Image paths relative to images/.
        img_name_list = _read_last_column(os.path.join(self.root, 'images.txt'))
        # Labels in the file start at 1; subtract 1 so they start at 0.
        label_list = [int(x) - 1 for x in _read_last_column(os.path.join(self.root, 'image_class_labels.txt'))]
        # Split flags: 1 = training image, 0 = test image.
        train_test_list = [int(x) for x in _read_last_column(os.path.join(self.root, 'train_test_split.txt'))]

        # Pair every image/label with its split flag and partition by it.
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]

        train_label_list = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
        test_label_list = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]

        if self.is_train:
            # matplotlib.image.imread returns a numpy array.
            self.train_img = [matplotlib.image.imread(os.path.join(self.root, 'images', train_file))
                              for train_file in train_file_list[:data_len]]
            self.train_label = train_label_list
        else:
            self.test_img = [matplotlib.image.imread(os.path.join(self.root, 'images', test_file))
                             for test_file in test_file_list[:data_len]]
            self.test_label = test_label_list

    def __getitem__(self, index):
        """Return ``(image, target)`` with the image as an RGB PIL image before transforms."""
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
        else:
            img, target = self.test_img[index], self.test_label[index]

        if len(img.shape) == 2:
            # Grayscale image: replicate the single channel into three channels.
            img = np.stack([img] * 3, 2)
        img = Image.fromarray(img, mode='RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


if __name__ == '__main__':
    # dataset = CUB(root='./CUB_200_')
    # for data in dataset:
    #     print(data[0].size(), data[1])

    # Read the dataset through a PyTorch DataLoader.
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # NOTE(review): the directory name looks truncated (the dataset is usually
    # distributed as CUB_200_2011) — adjust to your local path.
    dataset = CUB(root='../dataset/CUB_200_', is_train=True, transform=transform_train,)
    print(len(dataset))

    trainloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0,
                                              drop_last=True)
    print(len(trainloader))

Aircraft

import os

import numpy as np
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torchvision.datasets.utils import extract_archive


class Aircraft(VisionDataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from the ``trainval`` split,
            otherwise from the ``test`` split.
        class_type (string, optional): choose from ('variant', 'family', 'manufacturer').
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    # The official archive and its top-level folder are named "fgvc-aircraft-2013b".
    base_folder = 'fgvc-aircraft-2013b'
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')
    img_folder = os.path.join(base_folder, 'data', 'images')

    def __init__(self, root, train=True, class_type='variant', transform=None,
                 target_transform=None, download=False):
        super(Aircraft, self).__init__(root, transform=transform, target_transform=target_transform)
        # `train` only selects between two of the published splits.
        split = 'trainval' if train else 'test'
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))
        self.class_type = class_type
        self.split = split
        # e.g. <root>/fgvc-aircraft-2013b/data/images_variant_trainval.txt
        self.classes_file = os.path.join(self.root, self.base_folder, 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))

        if download:
            self.download()

        (image_ids, targets, classes, class_to_idx) = self.find_classes()
        samples = self.make_dataset(image_ids, targets)

        self.loader = default_loader
        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, class_index)`` with transforms applied."""
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def _check_exists(self):
        """True if both the image folder and this split's annotation file are present."""
        return os.path.exists(os.path.join(self.root, self.img_folder)) and \
            os.path.exists(self.classes_file)

    def download(self):
        """Download and extract the archive into ``root`` unless the data is already there."""
        if self._check_exists():
            return

        # prepare to download data to PARENT_DIR/fgvc-aircraft-2013b.tar.gz
        print('Downloading %s...' % self.url)
        tar_name = self.url.rpartition('/')[-1]
        download_url(self.url, root=self.root, filename=tar_name)
        tar_path = os.path.join(self.root, tar_name)
        print('Extracting %s...' % tar_path)
        extract_archive(tar_path)
        print('Done!')

    def find_classes(self):
        """Parse the split file into image ids and integer targets.

        Each line is ``<image_id> <class name possibly containing spaces>``.

        Returns:
            (image_ids, targets, classes, class_to_idx), where ``classes`` is the
            sorted array of unique class names and ``targets`` are their indices.
        """
        image_ids = []
        targets = []
        with open(self.classes_file, 'r') as f:
            for line in f:
                # Strip the newline so the last line (which may lack one) does not
                # produce a second, distinct spelling of the same class name.
                line = line.strip()
                if not line:
                    continue
                image_id, _, class_name = line.partition(' ')
                image_ids.append(image_id)
                targets.append(class_name)

        # index class names
        classes = np.unique(targets)
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        targets = [class_to_idx[c] for c in targets]
        return image_ids, targets, classes, class_to_idx

    def make_dataset(self, image_ids, targets):
        """Pair every image id's ``.jpg`` path with its target index."""
        if len(image_ids) != len(targets):
            raise ValueError('image_ids and targets must have the same length')
        return [(os.path.join(self.root, self.img_folder, '%s.jpg' % image_id), target)
                for image_id, target in zip(image_ids, targets)]


if __name__ == '__main__':
    train_dataset = Aircraft('./aircraft', train=True, download=False)
    test_dataset = Aircraft('./aircraft', train=False, download=False)

PermutedMNIST

import numpy as np
import torch
import torchvision


class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
    """One MNIST split whose flattened pixels are reordered by a fixed permutation.

    Despite the name, this is a Dataset (a torchvision MNIST subclass), not a DataLoader.

    Args:
        source: root directory handed to MNIST (downloaded if missing).
        train: use the training split if True, otherwise the test split.
        shuffle_seed: index array used to reorder the flattened pixels, e.g. a
            shuffled ``np.arange(28 * 28)``. ``None`` keeps the original pixel order.
    """

    def __init__(self, source='data/mnist_data', train=True, shuffle_seed=None):
        super(PermutedMNISTDataLoader, self).__init__(source, train, download=True)
        self.train = train
        # self.data already holds the split selected by `train`. Assumed to be a
        # uint8 tensor of shape (N, 28, 28) -- TODO confirm for the installed
        # torchvision. Flatten each image and divide by 255 in one vectorized op
        # instead of a Python loop over every image.
        flat = self.data.to(torch.float32).reshape(len(self.data), -1) / 255.0
        if shuffle_seed is not None:
            perm = torch.as_tensor(np.asarray(shuffle_seed), dtype=torch.long)
            flat = flat[:, perm]
        if self.train:
            self.permuted_train_data = flat
        else:
            self.permuted_test_data = flat
        self.num_data = flat.shape[0]

    def __getitem__(self, index):
        """Return ``(permuted flattened image, label)``."""
        data = self.permuted_train_data if self.train else self.permuted_test_data
        # self.targets is what the deprecated train_labels/test_labels aliases return.
        return data[index], self.targets[index]

    def getNumData(self):
        return self.num_data


batch_size = 64
learning_rate = 1e-3
num_task = 10
criterion = torch.nn.CrossEntropyLoss()
cuda_available = torch.cuda.is_available()


def permute_mnist():
    """Build one train/test DataLoader pair per task, each with its own random pixel permutation.

    Returns:
        (train_loader, test_loader, train_data_num, test_data_num): two dicts mapping
        task index to DataLoader, and the average number of train / test samples per task.
    """
    train_loader = {}
    test_loader = {}
    train_data_num = 0
    test_data_num = 0
    for i in range(num_task):
        # Train and test of a task share the same permutation.
        perm = np.arange(28 * 28)
        np.random.shuffle(perm)
        train_dataset = PermutedMNISTDataLoader(train=True, shuffle_seed=perm)
        test_dataset = PermutedMNISTDataLoader(train=False, shuffle_seed=perm)
        train_data_num += train_dataset.getNumData()
        test_data_num += test_dataset.getNumData()
        train_loader[i] = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
        test_loader[i] = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
    return train_loader, test_loader, train_data_num // num_task, test_data_num // num_task


train_loader, test_loader, train_data_num, test_data_num = permute_mnist()

TinyImageNet

import os
import warnings

import pandas as pd
from torchvision.datasets import ImageFolder
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_arg


class TinyImageNet(VisionDataset):
    """`tiny-imageNet <http://cs231n.stanford.edu/tiny-imagenet-200.zip>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'tiny-imagenet-200/'
    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    filename = 'tiny-imagenet-200.zip'
    md5 = '90528d7ca1a48142e341f4ef8d21d0de'

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)
        # Build every path from self.root so all of them agree (the original mixed
        # the raw `root` argument and self.root).
        self.dataset_path = os.path.join(self.root, self.base_folder)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        # An already-extracted dataset is usable even if the zip has been deleted.
        if not os.path.isdir(self.dataset_path):
            if self._check_integrity():
                print('Files already downloaded and verified.')
                print('Extracting...')
                extract_archive(os.path.join(self.root, self.filename))
            elif download:
                self._download()
            else:
                raise RuntimeError('Dataset not found. You can use download=True to download it.')

        _, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))
        self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)

    def _download(self):
        """Download the zip into root and extract it."""
        print('Downloading...')
        download_url(self.url, root=self.root, filename=self.filename)
        print('Extracting...')
        extract_archive(os.path.join(self.root, self.filename))

    def _check_integrity(self):
        """True if the zip exists in root and matches the expected md5."""
        return check_integrity(os.path.join(self.root, self.filename), self.md5)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, class_index)`` with transforms applied."""
        img_path, target = self.data[index]
        image = self.loader(img_path)

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.data)


def find_classes(class_file):
    """Read one WordNet id per line; return (sorted ids, id -> index mapping).

    Blank lines are ignored so they do not become an empty class.
    """
    with open(class_file) as r:
        classes = sorted(line.strip() for line in r if line.strip())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(root, base_folder, dirname, class_to_idx):
    """List ``(image_path, class_index)`` pairs for the ``train`` or ``val`` folder.

    ``train`` is laid out as ``train/<wnid>/images/*``. For any other split, labels
    come from the tab-separated ``val_annotations.txt`` (filename, wnid, ...).
    """
    images = []
    dir_path = os.path.join(root, base_folder, dirname)

    if dirname == 'train':
        for fname in sorted(os.listdir(dir_path)):
            cls_fpath = os.path.join(dir_path, fname)
            if os.path.isdir(cls_fpath):
                cls_imgs_path = os.path.join(cls_fpath, 'images')
                for imgname in sorted(os.listdir(cls_imgs_path)):
                    path = os.path.join(cls_imgs_path, imgname)
                    images.append((path, class_to_idx[fname]))
    else:
        imgs_path = os.path.join(dir_path, 'images')
        imgs_annotations = os.path.join(dir_path, 'val_annotations.txt')

        cls_map = {}
        with open(imgs_annotations) as r:
            for line in r:
                fields = line.split('\t')
                # Skip blank / malformed lines instead of raising IndexError.
                if len(fields) >= 2:
                    cls_map[fields[0]] = fields[1]

        for imgname in sorted(os.listdir(imgs_path)):
            path = os.path.join(imgs_path, imgname)
            images.append((path, class_to_idx[cls_map[imgname]]))

    return images


if __name__ == '__main__':
    train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=False)
    test_dataset = TinyImageNet('./tiny-imagenet', split='val', download=False)

MiniImageNet

##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yaoyao Liu
## NUS School of Computing
## Email: yaoyao.liu@nus.edu.sg
## Copyright (c)
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import random
import numpy as np
from tqdm import trange
import imageio


class MiniImageNetDataLoader(object):
    """Episode sampler for N-way K-shot learning on pre-processed mini-ImageNet.

    Class folders are read from ``./processed_images/{train,val,test}/``. Episode
    file-name and label lists are cached as .npy files under
    ``./episode_filename_list/<config>/``.

    Args:
        shot_num: support ("training") images per class in an episode.
        way_num: number of classes per episode.
        episode_test_sample_num: query ("test") images per class in an episode.
        shuffle_images: shuffle the sampled (label, path) pairs inside each episode.
    """

    def __init__(self, shot_num, way_num, episode_test_sample_num, shuffle_images = False):
        self.shot_num = shot_num
        self.way_num = way_num
        self.episode_test_sample_num = episode_test_sample_num
        self.num_samples_per_class = episode_test_sample_num + shot_num
        self.shuffle_images = shuffle_images
        metatrain_folder = './processed_images/train'
        metaval_folder = './processed_images/val'
        metatest_folder = './processed_images/test'

        npy_dir = './episode_filename_list/'
        os.makedirs(npy_dir, exist_ok=True)
        # One cache directory per (shot, way, query count, shuffle) configuration.
        self.npy_base_dir = npy_dir + str(self.shot_num) + 'shot_' + str(self.way_num) + 'way_' + str(episode_test_sample_num) + 'shuffled_' + str(self.shuffle_images) + '/'
        os.makedirs(self.npy_base_dir, exist_ok=True)

        self.metatrain_folders = self._class_folders(metatrain_folder)
        self.metaval_folders = self._class_folders(metaval_folder)
        self.metatest_folders = self._class_folders(metatest_folder)

    @staticmethod
    def _class_folders(folder):
        """Paths of the immediate sub-directories (one per class) of ``folder``."""
        return [os.path.join(folder, label)
                for label in os.listdir(folder)
                if os.path.isdir(os.path.join(folder, label))]

    def get_images(self, paths, labels, nb_samples=None, shuffle=True):
        """Return ``(label, image_path)`` pairs, ``nb_samples`` random images per folder.

        Pairs are grouped by folder unless ``shuffle`` is True.
        """
        if nb_samples is not None:
            sampler = lambda x: random.sample(x, nb_samples)
        else:
            sampler = lambda x: x
        images = [(i, os.path.join(path, image))
                  for i, path in zip(labels, paths)
                  for image in sampler(os.listdir(path))]
        if shuffle:
            random.shuffle(images)
        return images

    def generate_data_list(self, phase='train', episode_num=None):
        """Sample episodes for ``phase`` and cache their file names and labels as .npy.

        Nothing is regenerated if ``<phase>_filenames.npy`` already exists. When
        ``episode_num`` is None, 20000 episodes are used for 'train' and 600 for
        'val' / 'test'.
        """
        phase_folders = {'train': self.metatrain_folders,
                         'val': self.metaval_folders,
                         'test': self.metatest_folders}
        if phase not in phase_folders:
            print('Please select vaild phase')
            return
        if episode_num is None:
            episode_num = 20000 if phase == 'train' else 600
        filenames_path = self.npy_base_dir + '/' + phase + '_filenames.npy'
        if os.path.exists(filenames_path):
            return

        print('Generating ' + phase + ' filenames')
        folders = phase_folders[phase]
        all_filenames = []
        all_labels = []
        for _ in trange(episode_num):
            sampled_character_folders = random.sample(folders, self.way_num)
            random.shuffle(sampled_character_folders)
            labels_and_images = self.get_images(sampled_character_folders, range(self.way_num), nb_samples=self.num_samples_per_class, shuffle=self.shuffle_images)
            # Keep EVERY episode's labels, aligned with all_filenames. The previous
            # version saved only the last episode's labels, which is wrong for the
            # other episodes when shuffle_images=True.
            all_labels.extend(li[0] for li in labels_and_images)
            all_filenames.extend(li[1] for li in labels_and_images)
        np.save(self.npy_base_dir + '/' + phase + '_labels.npy', all_labels)
        np.save(filenames_path, all_filenames)
        print(phase.capitalize() + ' filename and label lists are saved')

    def load_list(self, phase='train'):
        """Load cached lists into ``<phase>_filenames`` / ``<phase>_labels``.

        ``phase='all'`` loads train, val and test.
        """
        if phase == 'all':
            phases = ['train', 'val', 'test']
        elif phase in ('train', 'val', 'test'):
            phases = [phase]
        else:
            print('Please select vaild phase')
            return
        for p in phases:
            setattr(self, p + '_filenames', np.load(self.npy_base_dir + p + '_filenames.npy').tolist())
            setattr(self, p + '_labels', np.load(self.npy_base_dir + p + '_labels.npy').tolist())

    def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True):
        """Load the images of one side (support or query) of an episode.

        Inputs are indexed as ``class_idx * batch_sample_num + k``, i.e. grouped by
        class. The output interleaves them: for each ``k``, one sample of every
        class in a random class order.

        Returns:
            (images divided by 255 as a numpy array, one-hot labels). The labels
            have shape ``(1, way*n, C)`` if ``reshape_with_one``, otherwise
            ``(way*n, C)``.
        """
        new_path_list = []
        new_label_list = []
        for k in range(batch_sample_num):
            class_idxs = list(range(0, self.way_num))
            random.shuffle(class_idxs)
            for class_idx in class_idxs:
                true_idx = class_idx * batch_sample_num + k
                new_path_list.append(input_filename_list[true_idx])
                new_label_list.append(input_label_list[true_idx])

        img_list = [imageio.imread(filepath) / 255.0 for filepath in new_path_list]
        img_array = np.array(img_list)
        one_hot_labels = self.one_hot(np.array(new_label_list))
        if reshape_with_one:
            label_array = one_hot_labels.reshape([1, self.way_num * batch_sample_num, -1])
        else:
            label_array = one_hot_labels.reshape([self.way_num * batch_sample_num, -1])
        return img_array, label_array

    def one_hot(self, inp):
        """One-hot encode a 1-D int array; the width is ``inp.max() + 1``."""
        n_class = inp.max() + 1
        n_sample = inp.shape[0]
        out = np.zeros((n_sample, n_class))
        out[np.arange(n_sample), inp] = 1
        return out

    def get_batch(self, phase='train', idx=0):
        """Return ``(support_x, support_y, query_x, query_y)`` for episode ``idx`` of ``phase``.

        Requires ``load_list(phase)`` to have been called first.

        Raises:
            ValueError: if ``phase`` is not 'train', 'val' or 'test'.
        """
        if phase == 'train':
            all_filenames = self.train_filenames
            labels = self.train_labels
        elif phase == 'val':
            all_filenames = self.val_filenames
            labels = self.val_labels
        elif phase == 'test':
            all_filenames = self.test_filenames
            labels = self.test_labels
        else:
            raise ValueError("Please select a valid phase: 'train', 'val' or 'test'")

        one_episode_sample_num = self.num_samples_per_class * self.way_num
        start = idx * one_episode_sample_num
        this_task_filenames = all_filenames[start:start + one_episode_sample_num]
        if len(labels) == len(all_filenames):
            # Current cache format: one label per file name, for every episode.
            this_task_labels = labels[start:start + one_episode_sample_num]
        else:
            # Legacy cache files hold a single episode's labels.
            this_task_labels = labels[:one_episode_sample_num]

        epitr_sample_num = self.shot_num
        epite_sample_num = self.episode_test_sample_num

        this_task_tr_filenames = []
        this_task_tr_labels = []
        this_task_te_filenames = []
        this_task_te_labels = []
        n = self.num_samples_per_class
        for class_k in range(self.way_num):
            this_class_filenames = this_task_filenames[class_k * n:(class_k + 1) * n]
            this_class_label = this_task_labels[class_k * n:(class_k + 1) * n]
            # First shot_num samples of each class form the support set; the rest the query set.
            this_task_tr_filenames += this_class_filenames[0:epitr_sample_num]
            this_task_tr_labels += this_class_label[0:epitr_sample_num]
            this_task_te_filenames += this_class_filenames[epitr_sample_num:]
            this_task_te_labels += this_class_label[epitr_sample_num:]

        this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False)
        this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False)
        return this_inputa, this_labela, this_inputb, this_labelb

CINIC10

参考 CINIC10

import torch
import torchvision
import torchvision.transforms as transforms

# Replace with the directory that contains the CINIC-10 split folders.
cinic_directory = '/path/to/cinic/directory'
# CINIC-10 per-channel mean / std, passed to Normalize below.
cinic_mean = [0.47889522, 0.47227842, 0.43047404]
cinic_std = [0.24205776, 0.23828046, 0.25874835]
# Training loader: ImageFolder over <cinic_directory>/train, converted to tensors and normalized.
cinic_train = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(
        cinic_directory + '/train',
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std),
        ])),
    batch_size=128, shuffle=True)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。