码迷,mamicode.com
首页 > 其他好文 > 详细

pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件

时间:2018-07-15 00:50:52      阅读:486      评论:0      收藏:0      [点我收藏+]

标签:模型   绝对路径   load   ram   star   第一部分   home   epo   文件名   

本文分为两部分,第一部分讲如何保存模型参数,优化器参数等等,第二部分则讲如何读取。

假设网络为model = Net(), optimizer = optim.Adam(model.parameters(), lr=args.lr), 假设在某个epoch,我们要保存模型参数,优化器参数以及epoch

一、

1. 先建立一个字典,保存三个参数:

state = {‘net‘:model.state_dict(), ‘optimizer‘:optimizer.state_dict(), ‘epoch‘:epoch}

2.调用torch.save():

torch.save(state, dir)

其中dir表示保存文件的绝对路径+保存文件名,如‘/home/qinying/Desktop/modelpara.pth‘

二、

当你想恢复某一阶段的训练(或者进行测试)时,那么就可以读取之前保存的网络模型参数等。

checkpoint = torch.load(dir)

model.load_state_dict(checkpoint[‘net‘])

optimizer.load_state_dict(checkpoint[‘optimizer‘])

start_epoch = checkpoint[‘epoch‘] + 1

 

pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件

标签:模型   绝对路径   load   ram   star   第一部分   home   epo   文件名   

原文地址:https://www.cnblogs.com/qinduanyinghua/p/9311410.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!