PyTorch Study Notes: DataLoaders

Posted: 2017-07-22 00:46:18

A DataLoader wraps a Dataset and handles minibatching, shuffling (off by default; pass `shuffle=True`), and parallel loading in worker processes (`num_workers`) for you.
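
To make the minibatching and shuffling concrete, here is a conceptual pure-Python sketch of what they amount to. `iterate_minibatches` and `collate` are hypothetical helpers written for illustration, not PyTorch's implementation: the real DataLoader uses samplers, a default collate function that stacks samples into tensors, and optional worker processes.

```python
import random

def iterate_minibatches(dataset, batch_size, shuffle=False, seed=None):
    """Hypothetical helper: yield lists of samples, keeping a short last batch
    (similar in spirit to DataLoader with drop_last=False)."""
    indices = list(range(len(dataset)))
    if shuffle:
        # A fresh random order; pass a seed to make it reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]

def collate(batch):
    """Hypothetical helper: turn [(x0, y0), (x1, y1), ...] into
    ([x0, x1, ...], [y0, y1, ...]). PyTorch's default collate also
    stacks these pieces into tensors."""
    return tuple(list(field) for field in zip(*batch))

data = [(i, i * i) for i in range(10)]
for xs, ys in (collate(b) for b in iterate_minibatches(data, batch_size=4)):
    print(xs, ys)
# [0, 1, 2, 3] [0, 1, 4, 9]
# [4, 5, 6, 7] [16, 25, 36, 49]
# [8, 9] [64, 81]
```

Shuffling only permutes the indices before slicing, so each (x, y) pair stays together; only the grouping into batches changes.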

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# Define our whole model as a single Module
class TwoLayerNet(nn.Module):
    # The initializer sets up two children (Modules can contain other Modules)
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        self.linear1 = nn.Linear(D_in, H)
        self.linear2 = nn.Linear(H, D_out)

    # Define the forward pass using child modules and autograd ops on tensors.
    # No need to define backward: autograd handles it.
    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred

N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# To load custom data, write your own Dataset class;
# TensorDataset just wraps tensors that share the same first dimension.
loader = DataLoader(TensorDataset(x, y), batch_size=8)

model = TwoLayerNet(D_in, H, D_out)

# reduction='sum' is the current spelling of the deprecated size_average=False
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for epoch in range(10):
    # Iterate over the loader to get minibatches
    for x_batch, y_batch in loader:
        # Since PyTorch 0.4, Variable is merged into Tensor, so the
        # batches the loader yields can go straight into the model.
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
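
The comment in the example says that custom data calls for your own Dataset class. A minimal sketch, assuming an in-memory list of (features, target) pairs: a map-style Dataset only needs `__len__` and `__getitem__`, and DataLoader takes care of batching and shuffling. `ListDataset` and the toy `records` are hypothetical, chosen for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ListDataset(Dataset):
    """Hypothetical Dataset wrapping a Python list of (features, target) pairs."""

    def __init__(self, records):
        self.records = records

    def __len__(self):
        # Number of samples; DataLoader uses this to build its index order.
        return len(self.records)

    def __getitem__(self, idx):
        # Convert one record to tensors; this is where custom loading
        # (reading files, decoding, augmentation) would normally happen.
        features, target = self.records[idx]
        return (torch.tensor(features, dtype=torch.float32),
                torch.tensor(target, dtype=torch.float32))

# Toy data: features [i, 2i], target [3i]
records = [([float(i), 2.0 * i], [3.0 * i]) for i in range(10)]
dataset = ListDataset(records)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for x_batch, y_batch in loader:
    print(x_batch.shape, y_batch.shape)
```

With `shuffle=True` the sample order changes every epoch, and the default collate function stacks the per-sample tensors, so full batches have shapes (4, 2) and (4, 1); the last batch holds the 2 leftover samples unless `drop_last=True`.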

 

Original post: http://www.cnblogs.com/Joyce-song94/p/7220102.html
