700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > pytorch创建自己的Dataset加载数据集

pytorch创建自己的Dataset加载数据集

时间:2021-05-28 05:11:50

相关推荐

pytorch创建自己的Dataset加载数据集

文章目录

创建一个类并继承torch.utils.data.dataset.Datase类创建__getitem__方法加载数据集

创建一个类并继承torch.utils.data.dataset.Datase类

class MyDataset(Dataset):'''data_path: 数据集路径img_size: 图片大小train_lines: 图片名数组'''def __init__(self,data_path,img_size,train_lines):super(MyDataset, self).__init__()self.data_path = data_pathself.img_size = img_sizeself.train_lines = train_linesself.length = len(train_lines)

创建__getitem__方法

class MyDataset(Dataset):'''data_path: 数据集路径img_size: 图片大小train_lines: 图片名数组'''def __init__(self,data_path,img_size,train_lines):super(MyDataset, self).__init__()self.data_path = data_pathself.img_size = img_sizeself.train_lines = train_linesdef __getitem__(self, index):annotation_line = self.train_lines[index]name = annotation_line.split()[0] # 获取图片名image = Image.open(os.path.join(os.path.join(self.data_path,"dem"),name+".tif"))label = Image.open(os.path.join(os.path.join(self.data_path, "label"), name + ".png"))image = np.array(image)label = np.array(label)image = cv2.resize(image,(self.img_size,self.img_size))label = cv2.resize(label,(self.img_size,self.img_size))# image = image[np.newaxis,:]print("images size: {}, label size: {}".format(image.shape,label.shape))return image,label

加载数据集

如果不知道如何将文件夹中所有图片名称写入TXT中可以参考:python读取文件夹中的所有图片并将图片名逐行写入txt中:/weixin_43598687/article/details/125666776?spm=1001..3001.5501

dataset_path = r"E:/workspace/PyCharmProject/dem_feature/dem/512"# 打开数据集的txt, 逐行读取图片名with open(os.path.join(dataset_path, "dem/train.txt"), "r") as f:train_lines = f.readlines()with open(os.path.join(dataset_path, "dem/val.txt"), "r") as f:val_lines = f.readlines()train_dataset = MyDataset(dataset_path, img_size=512,train_lines=train_lines)train_dataloader = DataLoader(train_dataset,batch_size=8,shuffle=False)for iteration,data in enumerate(train_dataloader):imgs,labels = dataprint(imgs,labels)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。