码迷,mamicode.com
首页 > 其他好文 > 详细

pytorch保存模型并记录最优模型

时间:2021-04-02 12:56:50      阅读:0      评论:0      收藏:0      [点我收藏+]

标签:poi   一个   com   ORC   copy   bsp   str   模型   state   

 

# https://github.com/tczhangzhi/pytorch-distributed/blob/master/distributed.py

# remember best acc@1 and save checkpoint
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)

if args.local_rank == 0:
    save_checkpoint(
               {
                    ‘epoch‘: epoch + 1,
                    ‘arch‘: args.arch,
                    ‘state_dict‘: model.module.state_dict(),
                    ‘best_acc1‘: best_acc1,
                }, is_best)



def save_checkpoint(state, is_best, filename=‘checkpoint.pth.tar‘):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, ‘model_best.pth.tar‘)

  

shutil.copyfile(filename, ‘model_best.pth.tar‘) # 如果是当前最优精度的模型,则保存时维护一个副本

 

pytorch保存模型并记录最优模型

标签:poi   一个   com   ORC   copy   bsp   str   模型   state   

原文地址:https://www.cnblogs.com/jiangkejie/p/14606279.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!