码迷,mamicode.com
首页 > 其他好文 > 详细

F.cross_entropy()函数

时间:2020-06-10 18:54:23      阅读:706      评论:0      收藏:0      [点我收藏+]

标签:ast   from   poc   向量   多个   epo   ORC   pre   oat   

技术图片

 1 x = np.array([[1, 2,3,4,5],#共三3样本,有5个类别
 2              [1, 2,3,4,5],
 3              [1, 2,3,4,5]]).astype(np.float32)
 4 y = np.array([1, 1, 0])#这3个样本的标签分别是1,1,0即两个是第2类,一个是第1类
 5 x = torch.from_numpy(x)
 6 y = torch.from_numpy(y).long()
 7 
 8 soft_out = F.softmax(x,dim=1)#给每个样本的pred向量做指数归一化---softmax

 9 log_soft_out = torch.log(soft_out)#将上面得到的归一化的向量再point-wise取对数

10 loss = F.nll_loss(log_soft_out, y)#将归一化且取对数后的张量根据标签求和,实际就是计算loss的过程
"""
这里的loss计算式根据batch_size归一化后的,即是一个batch的平均单样本的损失,迭代一次模型对一个样本平均损失。
在多个epoch训练时,还会求每个epoch内的总损失,用于衡量epoch之间模型性能的提升。
"""
11 print(soft_out)
12 print(log_soft_out)
13 print(loss)
14   
15 loss = F.cross_entropy(x, y)
16 print(loss)

#输出:
softmax:
tensor([[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
[0.0117, 0.0317, 0.0861, 0.2341, 0.6364]])


tensor([[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519]])


tensor(3.7852)
tensor(3.7852)

F.cross_entropy()函数

标签:ast   from   poc   向量   多个   epo   ORC   pre   oat   

原文地址:https://www.cnblogs.com/Henry-ZHAO/p/13087275.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!