码迷,mamicode.com
首页 > 其他好文 > 详细

keras中的keras.utils.to_categorical方法

时间:2020-02-14 22:38:26      阅读:108      评论:0      收藏:0      [点我收藏+]

标签:one   fun   keras   nbsp   function   世界   import   classes   源码   

参考链接:https://blog.csdn.net/nima1994/article/details/82468965 

参考链接:https://blog.csdn.net/gdl3463315/article/details/82659378

to_categorical(y, num_classes=None, dtype=‘float32‘)

将整型的类别标签转为onehot编码。y为int数组,num_classes为标签类别总数,大于max(y)(标签从0开始的)。

返回:如果num_classes=None,返回len(y) * [max(y)+1](维度,m*n表示m行n列矩阵,下同),否则为len(y) * num_classes。
 

  1.  
    import keras
  2.  
     
  3.  
    ohl=keras.utils.to_categorical([1,3])
  4.  
    # ohl=keras.utils.to_categorical([[1],[3]])
  5.  
    print(ohl)
  6.  
    """
  7.  
    [[0. 1. 0. 0.]
  8.  
    [0. 0. 0. 1.]]
  9.  
    """
  10.  
    ohl=keras.utils.to_categorical([1,3],num_classes=5)
  11.  
    print(ohl)
  12.  
    """
  13.  
    [[0. 1. 0. 0. 0.]
  14.  
    [0. 0. 0. 1. 0.]]
  15.  
    """

该部分keras源码如下:

  1.  
    def to_categorical(y, num_classes=None, dtype=‘float32‘):
  2.  
    """Converts a class vector (integers) to binary class matrix.
  3.  
     
  4.  
    E.g. for use with categorical_crossentropy.
  5.  
     
  6.  
    # Arguments
  7.  
    y: class vector to be converted into a matrix
  8.  
    (integers from 0 to num_classes).
  9.  
    num_classes: total number of classes.
  10.  
    dtype: The data type expected by the input, as a string
  11.  
    (`float32`, `float64`, `int32`...)
  12.  
     
  13.  
    # Returns
  14.  
    A binary matrix representation of the input. The classes axis
  15.  
    is placed last.
  16.  
    """
  17.  
    y = np.array(y, dtype=‘int‘)
  18.  
    input_shape = y.shape
  19.  
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
  20.  
    input_shape = tuple(input_shape[:-1])
  21.  
    y = y.ravel()
  22.  
    if not num_classes:
  23.  
    num_classes = np.max(y) + 1
  24.  
    n = y.shape[0]
  25.  
    categorical = np.zeros((n, num_classes), dtype=dtype)
  26.  
    categorical[np.arange(n), y] = 1
  27.  
    output_shape = input_shape + (num_classes,)
  28.  
    categorical = np.reshape(categorical, output_shape)
  29.  
    return categorical
  30.  
     

简单来说:**keras.utils.to_categorical函数:是把类别标签转换为onehot编码(categorical就是类别标签的意思,表示现实世界中你分类的各类别), 而onehot编码是一种方便计算机处理的二元编码。**

keras中的keras.utils.to_categorical方法

标签:one   fun   keras   nbsp   function   世界   import   classes   源码   

原文地址:https://www.cnblogs.com/klausage/p/12309823.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!