码迷,mamicode.com
首页 > 其他好文 > 详细

torch.argmax和argmin返回值

时间:2020-07-12 20:51:13      阅读:65      评论:0      收藏:0      [点我收藏+]

标签:序号   port   意思   print   import   span   结果   example   com   

  在进行深度学习张量计算时,经常要获取张量在某个维度的最大值和最小值,以及这些值的位置。如果只需要知道位置,则torch.argmax和torch.argmin函数便可以实现。

Torch.argmax(input, dim=None, keepdim=False):返回指定维度最大值的序号。

  有时候返回的值比较难理解,所以这里直接放example以帮助理解:

 1 import torch
 2 
 3 t = torch.tensor([[1,2],[3,4],[2,8]])
 4 
 5 print(torch.argmax(t,0))
 6 
 7 
 8 g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0],        
 9                             [2,8,4],[7,5,3]]])
10 print(g)
11 print(torch.argmax(g,0))

先从简单的2维张量来看,t 是一个2维张量,大小为(3,2)。t 为 技术图片,此时我们使dim=0,意思使求第0维的(即(3,2)中的3行)中的最大值的序号,所以固定行,直接看列,第一列中3最大,故得到值1,第2列中8最大,故得到值2。最终的结果为  tensor([1,2])

 


再来看一个3维张量g , tensor([[[1, 2, 3],

              [2, 3, 4],
              [5, 6, 7]],

              [[3, 4, 5],
              [7, 6, 5],
              [5, 4, 3]],

              [[8, 9, 0],
              [2, 8, 4],
              [7, 5, 3]]]),其大小为(3,3,3) 其中我们希望在dim=0的维度中求最大值的序号,则固定第一个维度,第一个维度为channel,则每个channel中对应位置进行比较。

比如每个channel中的(0,0)比较,1<3<8,所以得到的值为2;(0,1)比较,2<4<9,依然得到2,....以此类推。最终得到结果tensor([[2, 2, 1],[1, 2, 1],[2, 0, 0]])。

 

torch.argmax和argmin返回值

标签:序号   port   意思   print   import   span   结果   example   com   

原文地址:https://www.cnblogs.com/ASTHNONT/p/13289579.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!