
Training AlexNet with PyTorch


1. AlexNet network structure and parameters

[Figure: AlexNet network architecture]

[Figure: AlexNet layer output sizes and parameters]

 

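The figures summarize each layer's output size and parameter count. As a quick cross-check, the short sketch below (a minimal example, assuming the halved-channel AlexNet defined in model.py in the next section) pushes a dummy 224x224 RGB image through the convolutional part, prints every layer's output shape, and then reports the total number of trainable parameters:

import torch
from model import AlexNet  # defined in model.py below

net = AlexNet(num_classes=5)
x = torch.randn(1, 3, 224, 224)  # dummy batch: one 3x224x224 image
for layer in net.features:
    x = layer(x)
    print("{:<12} -> {}".format(layer.__class__.__name__, tuple(x.shape)))
print("trainable parameters:", sum(p.numel() for p in net.parameters() if p.requires_grad))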
2. Training code

model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # flatten everything except the batch dimension
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
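A minimal sanity check for model.py (the variable names here are only illustrative): the network returns raw logits of shape (batch, num_classes), which is exactly what nn.CrossEntropyLoss in train.py below expects, since that loss applies the softmax internally.

import torch
from model import AlexNet

net = AlexNet(num_classes=5, init_weights=True)  # Kaiming init for conv layers, normal init for linear layers
dummy = torch.randn(4, 3, 224, 224)              # batch of 4 RGB images
logits = net(dummy)
print(logits.shape)                              # torch.Size([4, 5])
probs = torch.softmax(logits, dim=1)             # only needed at inference time, not for training
print(probs.sum(dim=1))                          # each row sums to 1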

 

train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # must be (224, 224); a single int only resizes the shorter side
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open("class_indices.json", "w") as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print("Using {} dataloader workers per process".format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)  # set to nw for multi-process loading; 0 loads in the main process

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=True,
                                                  num_workers=0)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                            val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = "./AlexNet.pth"
    best_acc = 0.0
    for epoch in range(10):
        # train
        net.train()
        running_loss = 0.0
        t1 = time.perf_counter()
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # print train progress
            rate = (step + 1) / len(train_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
        print()
        print(time.perf_counter() - t1)

        # validate
        net.eval()
        acc = 0.0  # accumulate the number of correct predictions per epoch
        with torch.no_grad():
            for val_data in validate_loader:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
            val_accurate = acc / val_num
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
            print("[epoch %d] train_loss: %.3f  val_accuracy: %.3f" %
                  (epoch + 1, running_loss / (step + 1), val_accurate))

    print("Finished Training")


if __name__ == "__main__":
    main()
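After training, the best weights are saved to ./AlexNet.pth and the class mapping to class_indices.json. A rough inference sketch under those assumptions (the image path test.jpg is only a placeholder) could look like this:

import json
import torch
from PIL import Image
from torchvision import transforms
from model import AlexNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = transforms.Compose([transforms.Resize((224, 224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# index -> class name, written by train.py
with open("class_indices.json", "r") as f:
    class_indict = json.load(f)

net = AlexNet(num_classes=5).to(device)
net.load_state_dict(torch.load("./AlexNet.pth", map_location=device))
net.eval()

img = Image.open("test.jpg").convert("RGB")        # placeholder path
img = data_transform(img).unsqueeze(0).to(device)  # add the batch dimension

with torch.no_grad():
    output = net(img).squeeze(0)
    predict = torch.softmax(output, dim=0)
    idx = torch.argmax(predict).item()
print(class_indict[str(idx)], predict[idx].item())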

 


Original source: https://www.cnblogs.com/sclu/p/14162460.html
