
mxnet - Saving Model Parameters


#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 10 16:13:29 2018
@author: myhaspl
"""
import mxnet as mx
import mxnet.ndarray as nd
from mxnet import nd, autograd, gluon
from mxnet.gluon.data.vision import datasets, transforms
from time import time

def build_lenet(net):
    with net.name_scope():
        net.add(gluon.nn.Conv2D(channels=6, kernel_size=5, activation="relu"),
                gluon.nn.MaxPool2D(pool_size=2, strides=2),
                gluon.nn.Conv2D(channels=16, kernel_size=3, activation="relu"),
                gluon.nn.MaxPool2D(pool_size=2, strides=2),
                gluon.nn.Flatten(),
                gluon.nn.Dense(120, activation="relu"),
                gluon.nn.Dense(84, activation="relu"),
                gluon.nn.Dense(10))
    return net

mnist_train = datasets.FashionMNIST(train=True)
X, y = mnist_train[0]
print ('X shape: ', X.shape, 'X dtype', X.dtype, 'y:', y, 'Y dtype', y.dtype)
# X: (height, width, channel)
# y: numpy scalar, the class label
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

# Convert each image to (channel, height, width) floating-point format with
# transforms.ToTensor, then normalize all pixel values with transforms.Normalize
# using mean 0.13 and standard deviation 0.31.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31)])

# Transform only the first element (the image); the second element is the label.
mnist_train = mnist_train.transform_first(transformer)

# Load the data in batches
batch_size = 200
train_data = gluon.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

# Read one batch
i = 1
for data, label in train_data:
    print i
    print data, label
    break  # Without this break, the loop keeps reading batches of 200 samples.

mnist_valid = gluon.data.vision.FashionMNIST(train=False)
valid_data = gluon.data.DataLoader(mnist_valid.transform_first(transformer),
                                   batch_size=batch_size)

# Define the network
net = build_lenet(gluon.nn.Sequential())
net.initialize(init=mx.init.Xavier())
print net

# Softmax output and loss
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

# Define the trainer
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

def acc(output, label):
    # output: (batch, num_output) float32 ndarray
    # label: (batch, ) int32 ndarray
    return (output.argmax(axis=1) == label.astype('float32')).mean().asscalar()

for epoch in range(10):
    train_loss, train_acc, valid_acc = 0., 0., 0.
    tic = time()
    for data, label in train_data:
        # Forward and backward pass
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        # Update the parameters, then move on to the next batch
        trainer.step(batch_size)
        # Accumulate training loss and accuracy
        train_loss += loss.mean().asscalar()
        train_acc += acc(output, label)
        print "."
    # Validation accuracy
    for data, label in valid_data:
        predict_data = net(data)
        valid_acc += acc(predict_data, label)
    print("Epoch %d: Loss: %.3f, Train acc %.3f, Test acc %.3f, Time %.1f sec" % (
        epoch, train_loss/len(train_data), train_acc/len(train_data),
        valid_acc/len(valid_data), time()-tic))

# Save the model parameters (not the model structure)
file_name = "net.params"
net.save_parameters(file_name)
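Since save_parameters writes only the weights, the loading side has to rebuild the same architecture first and then restore the parameters into it. Below is a minimal sketch of that reverse step, reusing build_lenet, valid_data, and the net.params file from the script above; it is an assumption-based illustration, not part of the original post.

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# Sketch: load the saved parameters back into a freshly built network.
import mxnet as mx
from mxnet import gluon

# Rebuild the exact same structure that was trained above.
new_net = build_lenet(gluon.nn.Sequential())
# Restore the saved weights; no initialize() or retraining is needed.
new_net.load_parameters("net.params")

# Run one validation batch through the restored network as a quick check.
for data, label in valid_data:
    pred = new_net(data).argmax(axis=1)
    print pred[:10], label[:10]
    break

Because only the parameters are serialized, the Python code defining the structure (build_lenet) must be available wherever the model is loaded; saving both the structure and the weights requires export() on a hybridized network instead.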


Original source: http://blog.51cto.com/13959448/2316666
