(可选)keypoints (FloatTensor[N, K, 3]):对于 N 个对象中的每一个,包含 K 个 [x, y, visibility] 格式的关键点,用于描述该对象。visibility=0 表示该关键点不可见。请注意,在数据增强时,关键点“翻转”的具体含义取决于数据的表示方式;如果采用新的关键点表示形式,您可能需要相应地修改 references/detection/transforms.py。
import torchvision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
def get_model_instance_segmentation(num_classes, *, hidden_layer=256):
    """Build a COCO-pretrained Mask R-CNN whose heads predict ``num_classes``.

    Args:
        num_classes: Number of output classes, including the background class.
        hidden_layer: Width of the hidden layer of the new mask predictor
            (keyword-only; defaults to 256).

    Returns:
        The ``maskrcnn_resnet50_fpn`` model with its box predictor and mask
        predictor replaced by freshly initialised heads sized for
        ``num_classes``.
    """
    # Load an instance segmentation model pre-trained on COCO. The keyword is
    # ``weights``; the previous ``weight=`` is not the pretrained-weights
    # parameter, so the COCO weights were never loaded.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the pre-trained box classifier with one sized for our classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the pre-trained mask predictor the same way.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes,
    )

    # Callers (e.g. main) use the returned model directly: model.to(device).
    return model
# Run the model once on a single batch, passing both images and targets.
# NOTE(review): presumably the model is still in training mode here (it is
# only switched to eval() afterwards), so `output` holds the losses — confirm.
images, targets = next(iter(data_loader))
images = list(images)
targets = [dict(t) for t in targets]
output = model(images, targets)
# Switch to evaluation mode and run the model on two random 3-channel
# images of different spatial sizes (300x400 and 500x400).
model.eval()
x = [torch.rand(3, height, width) for height, width in ((300, 400), (500, 400))]
predictions = model(x)
from engine import train_one_epoch, evaluate import utils
def main():
    """Fine-tune and evaluate the instance-segmentation model on PennFudanPed.

    Builds train/test splits of ``PennFudanDataset`` (the last 50 shuffled
    samples are held out), trains the model returned by
    ``get_model_instance_segmentation`` with SGD and a StepLR schedule for
    10 epochs, and evaluates on the held-out split after every epoch.
    """
    # Train on the GPU, or on the CPU if a GPU is not available.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Our dataset has only two classes: background and person.
    num_classes = 2
    # Use our dataset with the training / evaluation transformations.
    dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
    dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))

    # Split the dataset into train and test sets. The same permutation indexes
    # both datasets, so the two subsets are disjoint.
    indices = torch.randperm(len(dataset)).tolist()
    dataset = torch.utils.data.Subset(dataset, indices[:-50])
    dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

    # Define training and validation data loaders.
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True, num_workers=4,
        collate_fn=utils.collate_fn)
    # evaluate() below needs a test loader; it was previously never defined,
    # so the first epoch ended in a NameError.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=4,
        collate_fn=utils.collate_fn)

    # Get the model using our helper function and move it to the device.
    model = get_model_instance_segmentation(num_classes)
    model.to(device)

    # Construct an optimizer over the trainable parameters only.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    # Learning-rate scheduler: multiply the LR by 0.1 every 3 epochs. Named
    # `scheduler` so it does not shadow the `lr_scheduler` module import.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    # Train for 10 epochs.
    num_epochs = 10
    for epoch in range(num_epochs):
        # Train for one epoch, printing every 10 iterations.
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        # Update the learning rate.
        scheduler.step()
        # Evaluate on the test dataset.
        evaluate(model, data_loader_test, device=device)
import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler import torch.backends.cudnn as cudnn import numpy as np import torchvision from torchvision import datasets,models,transforms import matplotlib.pyplot as plt import time import os import copy
print(plt.ion()) # Turn on matplotlib interactive mode and print the value returned by plt.ion().
输出结果:
1
<matplotlib.pyplot._IonContext object at 0x000001E0FBBAE640>
# One ImageFolder dataset and one shuffled DataLoader per split, reading
# images from <data_dir>/train and <data_dir>/val.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('train', 'val')
}
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split], batch_size=4, shuffle=True, num_workers=4
    )
    for split in ('train', 'val')
}

# Number of samples in each split, and the class names of the training split.
dataset_sizes = {split: len(ds) for split, ds in image_datasets.items()}
class_names = image_datasets['train'].classes
# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
可视化一些图像
调用 imshow() 的代码一定要放在 `if __name__ == '__main__':` 块内执行,否则会出问题(在 Windows 上,num_workers > 0 的 DataLoader 会启动子进程并重新导入本模块):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization on an image tensor and display it.

    Args:
        inp: Image tensor; the ``transpose((1, 2, 0))`` below assumes a
            channels-first (C, H, W) layout.
        title: Optional plot title (the caller passes a list of class names).
        mean: Per-channel means to add back (defaults to the values the
            original function hard-coded).
        std: Per-channel standard deviations to multiply back (defaults to
            the values the original function hard-coded).
    """
    # detach()/cpu() let this accept tensors that require grad or live on the
    # GPU, where a bare .numpy() raises; for plain CPU tensors they are no-ops.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    inp = np.array(std) * inp + np.array(mean)
    # Clamp into [0, 1] so imshow does not warn about out-of-range floats.
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause briefly so the plot gets updated


if __name__ == '__main__':
    # Get a batch of training data. This must stay under the main guard:
    # the dataloaders are built with num_workers=4, and on Windows the worker
    # processes re-import this module.
    # Named `inputs` (not `input`) so the builtin input() is not shadowed.
    inputs, classes = next(iter(dataloaders['train']))
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes])
输出结果是一张图片:
训练模型
下面编写一个通用函数来训练模型,它演示了:
调度学习率
保存最佳模型
下文中,参数 scheduler 是来自 torch.optim.lr_scheduler 的学习率(LR)调度器对象。
# Decay the learning rate by a factor of 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=7, gamma=0.1)
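输出结果如下(预训练权重只会在第一次运行时下载)。其中的 UserWarning 表示 `pretrained` 参数自 torchvision 0.13 起已弃用,应改用 `weights=ResNet18_Weights.DEFAULT`: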
1 2 3 4 5 6
C:\ProgramData\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead. warnings.warn( C:\ProgramData\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights. warnings.warn(msg) Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\sy/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth 100%|██████████| 44.7M/44.7M [00:04<00:00, 11.5MB/s]