
Classifying CIFAR10 with VGG16

Date: 2020-08-01 14:36:14


1. Define the dataloaders

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Train on a GPU when one is available; in the notebook it can be enabled via the menu "Runtime" -> "Change runtime type"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,  download=True, transform=transform_train)
testset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
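
As a rough illustration of what transforms.Normalize does with the values above: ToTensor scales each 8-bit pixel to the range [0, 1], and Normalize then computes (x - mean) / std per channel. The sketch below redoes that arithmetic in plain Python. The helper name normalize_pixel is hypothetical and not part of torchvision.

```python
# Per-channel statistics copied from the transforms above (R, G, B).
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)

def normalize_pixel(rgb_uint8):
    """Mimic ToTensor followed by Normalize for one RGB pixel (hypothetical helper)."""
    scaled = [v / 255.0 for v in rgb_uint8]  # ToTensor: 0..255 -> 0..1
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]

black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

With these statistics the network inputs span roughly -2.4 to 2.8 rather than 0 to 1, and an average-coloured pixel lands near zero in every channel.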

2. Define and initialize the VGG network

class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
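        # 'M' marks a 2x2 max-pool; each number is a conv layer's output channels.
        # Note: 8 conv layers is the VGG11 layout; VGG16 has 13.
        self.cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        self.features = self._make_layers(self.cfg)
        # Five max-pools reduce a 32x32 input to 1x1, so the flattened size is 512.
        self.classifier = nn.Linear(512, 10)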

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
# Move the network to the selected device (GPU if available)
net = VGG().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
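
A quick way to check the in_features value the classifier needs is to trace the tensor shape through the layer list by hand. The sketch below does this in plain Python, assuming 3x3 convolutions with padding=1 (spatial size unchanged) and 2x2 max-pools with stride 2 (size halved), as in _make_layers. The function name flattened_features is hypothetical.

```python
def flattened_features(cfg, input_size=32):
    """Return (conv layer count, flattened feature size) for a VGG-style cfg."""
    size, channels, convs = input_size, 3, 0
    for x in cfg:
        if x == 'M':
            size //= 2      # 2x2 max-pool with stride 2 halves height and width
        else:
            channels = x    # 3x3 conv with padding=1 keeps the spatial size
            convs += 1
    # The trailing AvgPool2d(kernel_size=1, stride=1) leaves the shape unchanged.
    return convs, channels * size * size

cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
convs, features = flattened_features(cfg)
print(convs, features)
```

For 32x32 CIFAR10 images this gives 512 features after flattening. A value of 2048 would correspond to 64x64 inputs, since the final feature map would then be 2x2 instead of 1x1.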

3. Train the network

for epoch in range(10):  # train for several epochs
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Reset the gradients held by the optimizer
        optimizer.zero_grad()
        # Forward pass + backward pass + optimizer step
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Print progress every 100 minibatches
        if i % 100 == 0:
            print('Epoch: %d Minibatch: %5d loss: %.3f' % (epoch + 1, i + 1, loss.item()))

print('Finished Training')
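
To get a feel for how much work the loop above does and how many log lines to expect, the arithmetic can be sketched as follows. It assumes the standard 50,000-image CIFAR10 training split and a DataLoader that keeps the final, smaller batch (the default, drop_last=False). The constant names are illustrative only.

```python
import math

TRAIN_IMAGES = 50000  # size of the CIFAR10 training split
BATCH_SIZE = 128      # matches the DataLoader batch size
EPOCHS = 10

# The last, smaller batch is kept, so round up.
batches_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
last_batch_size = TRAIN_IMAGES - (batches_per_epoch - 1) * BATCH_SIZE
# Minibatch numbers (i + 1) that trigger the `i % 100 == 0` print.
logged_steps = [i + 1 for i in range(batches_per_epoch) if i % 100 == 0]
total_steps = EPOCHS * batches_per_epoch

print(batches_per_epoch, last_batch_size, logged_steps, total_steps)
```

Each epoch therefore prints four progress lines, and the printed loss is that of a single minibatch rather than a running average, so some noise from line to line is expected.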

4. Measure accuracy on the test set

correct = 0
total = 0

# Use BatchNorm's running statistics and skip gradient tracking during evaluation
net.eval()
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # The index of the largest logit is the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %.2f %%' % (
    100 * correct / total))
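
The accuracy computation above boils down to taking the arg-max of each row of logits and counting matches with the labels. A minimal plain-Python sketch of the same bookkeeping, on made-up logits for a 3-class toy problem (the function name accuracy_percent and the numbers are illustrative, not results from the network):

```python
def accuracy_percent(logits, labels):
    """Percentage of rows whose largest logit sits at the label's index."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # arg-max over classes
        correct += int(predicted == label)
    return 100 * correct / len(labels)

# Made-up logits: four samples, three classes.
toy_logits = [[0.1, 2.0, -1.0],
              [1.5, 0.2, 0.3],
              [0.0, 0.1, 0.9],
              [0.3, 0.2, 0.1]]
toy_labels = [1, 0, 1, 0]

toy_accuracy = accuracy_percent(toy_logits, toy_labels)
print('%.2f %%' % toy_accuracy)
```

Only the third sample is misclassified here (predicted class 2, label 1), so three of four predictions match.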

Original article: https://www.cnblogs.com/lixinhh/p/13414346.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!