码迷,mamicode.com
首页 > 其他好文 > 详细

pytorch自定义loss函数的几种方法

时间:2020-03-30 23:17:53      阅读:109      评论:0      收藏:0      [点我收藏+]

标签:tensor   strong   style   s函数   loss   包装   int   tom   __init__   

1. 让张量使用Variable类型,如下所示

1 from torch.autograd import Variable
2 
3 inp = torch.zeros(2, 3)
4 inp = Variable(inp).type(torch.LongTensor)
5 print(inp)

Variable类型包装了Tensor类型,并提供了backward()接口

使用Variable类型的好处是,可以按照论文公式来直接使用,并在做张量运算之后,使用继承的backward()直接进行反向传播

2. 自定义类继承nn.Module

1 class CustomMSELoss(nn.Module):
2     def __init__(self):
3         super().__init__()
4         
5     def forward(self, x, y):
6         return torch.mean(torch.pow((x - y), 2))

这种方法结构化程度高,在开发给用户使用时,由于不知道用户的Tensor是否是Variable类型,采用该方法可以减少问题。

pytorch自定义loss函数的几种方法

标签:tensor   strong   style   s函数   loss   包装   int   tom   __init__   

原文地址:https://www.cnblogs.com/webbery/p/12601936.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!