700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 【PyTorch】CIFAR-10数据集的训练和测试

【PyTorch】CIFAR-10数据集的训练和测试

时间:2022-12-12 02:43:25

相关推荐

【PyTorch】CIFAR-10数据集的训练和测试

代码:

import torch.nn as nnimport torchimport torch.nn.functional as Ffrom torchvision import transforms, datasets, modelsfrom torch.utils.data import DataLoaderimport argparseimport os# 训练def train(args, model, device, train_loader, optimizer):for epoch in range(1, args.epochs + 1):model.train()for batch_index, data in enumerate(train_loader):images, labels = dataimages = images.to(device)labels = labels.to(device)# forwardoutput = model(images)loss = F.cross_entropy(output, labels)# backwardoptimizer.zero_grad() # 梯度清空loss.backward() # 梯度回传,更新参数optimizer.step()# 打印lossprint(f'Epoch:{epoch},Batch ID:{batch_index}/{len(train_loader)}, loss:{loss}')# 保存模型if epoch % args.checkpoint_interval == 0:torch.save(model.state_dict(), f'checkpoints/cifar10_%d.pth' % epoch)def test(args, model, device, test_loader):model.eval()total_loss = 0num_correect = 0with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)# 总的losstotal_loss += F.cross_entropy(outputs, labels).item()# 预测值_, predected = torch.max(outputs, dim=1)# 预测对的总个数num_correect += (predected==labels).sum().item()# 计算平均lossaverage_loss = total_loss / len(test_loader.dataset)# 计算准确率accuracy = num_correect / len(test_loader.dataset)# 打印平均loss和准确率print(f'Average loss:{average_loss}\nTest Accuracy:{accuracy*100}%')if __name__ == '__main__':parser = argparse.ArgumentParser(description = 'Pytorch-cifar10_classification')parser.add_argument('--epochs', type=int, default=10, help='number of epochs')parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch')parser.add_argument('--num_classes', type=int, default=10, help='number of classes')parser.add_argument('--lr', type=float, default=0.001, help='learning rate')parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')parser.add_argument('--pretrained_weights', type=str, default='checkpoints/cifar10_17.pth',help='if specified starts from checkpoint 
model')parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")parser.add_argument("--train", default=True, help="train or test")args = parser.parse_args()print(args)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# os.makedirs() 方法用于递归创建目录os.makedirs("output", exist_ok=True)os.makedirs("checkpoints", exist_ok=True)# transformdata_transform = pose([transforms.ToTensor(),transforms.RandomResizedCrop(args.img_size)])# 下载训练数据集trian_data = datasets.CIFAR10(root = 'data',train = True,download = False,transform = data_transform,target_transform = None,)# 下载测试数据集test_data = datasets.CIFAR10(root = "data",train = False,download = False,transform = data_transform,target_transform = None)# 加载数据train_loader = DataLoader(dataset = trian_data,batch_size = args.batch_size,shuffle = True)test_loader = DataLoader(dataset = test_data,batch_size = args.batch_size)# 创建模型,使用预训练好的权重model = models.vgg16(pretrained = True)# # 冻结模型,参数不更新# for para in model.parameters():#para.requires_grad = False# # 只训练全连接层# model.classifier[3].requires_grad = True# model.classifier[6].requires_grad = True# 修改vgg16的输出维度model.classifier[6] = nn.Linear(in_features=4096, out_features=args.num_classes, bias=True)model = model.to(device)# 打印网络结构print(model)# 定义优化器(也可以选择其他优化器)optimizer = torch.optim.SGD(model.parameters(), lr = args.lr, momentum = args.momentum)# optimizer = torch.optim.Adam(model.parameters())if train == True:if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))for epoch in range(1, epochs+1):train(args, model, device, train_loader, optimizer)else:if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))test(args, model, device, test_loader)

说明:

CIFAR-10数据集可以通过torchvision中的datasets.CIFAR10下载(将download参数设为True即可),也可以自己下载(注意存放路径,代码中root为data目录);我的模型使用的是torchvision.models中预训练好的VGG16网络,也可以自己搭建网络。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。