码迷,mamicode.com
首页 > 其他好文 > 详细

PyTorch 有哪些坑/bug?

时间:2020-11-30 15:36:48      阅读:4      评论:0      收藏:0      [点我收藏+]

标签:概率   系统   add   data   感受   哪些   公众   batch   图片   

技术图片

算是动态图的一个坑吧。记录loss信息的时候直接使用了输出的Variable。
应该不止我经历过这个吧...
久久不用又会不小心掉到这个坑里去...

for data, label in trainloader:
    ......
    out = model(data)
    loss = criterion(out, label)
    loss_sum += loss     # <--- 这里
    ......

运行着就发现显存炸了

观察了一下发现随着每个batch显存消耗在不断增大..
参考了别人的代码发现那句loss一般是这样写 /(ㄒoㄒ)/~~

loss_sum += loss.data[0]

这是因为输出的loss的数据类型是Variable。

而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。

如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~

总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算...

包括数据的输入时候,如果“过早”把数据丢到Variable里面去,那么可能也会被系统视为网络的一部分。所以,要投入的时候再把数据丢到Variable里面去吧~
题外话

想更多感受动态图的话,可以通过Variable的grad_fun来观察到该Variable是通过什么运算得到的(前提是前面的Variable的required_grad置为True)。
大概是这样

>> >> z = x + y
>> z.grad_fn
out:
    <AddBackward1 at 0x107286240>

技术图片

推荐阅读:

【深度学习实战】pytorch中如何处理RNN输入变长序列padding
【机器学习基本理论】详解最大后验概率估计(MAP)的理解
【区块链】区块链最通俗入门教程

      欢迎关注公众号学习交流~         

技术图片

PyTorch 有哪些坑/bug?

标签:概率   系统   add   data   感受   哪些   公众   batch   图片   

原文地址:https://blog.51cto.com/15009309/2554205

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!