码迷,mamicode.com
首页 > 其他好文 > 详细

理解metrics.classification_report

时间:2018-03-04 22:13:12      阅读:2052      评论:0      收藏:0      [点我收藏+]

标签:.class   pre   pytho   ==   report   str   code   call   dom   

混淆矩阵是一个矩阵,类别个数可以有多个,a[i][j]表示将类别i的样本误判为类别j的个数。

classification_report用来分析不同类别的准确率,召回率,F1值等,从而便于按照类别查看准确率、召回率。

总体的正确率跟classification_report中的正确率是不一样。


import numpy as np
import sklearn.metrics as metrics


def report(mine, real):
    if len(mine) != len(real):
        print("mine和real长度不一样")
        exit(0)
    all_classes = set(list(mine) + list(real))
    precision = dict()
    recall = dict()
    f1 = dict()
    support = dict()
    for c in all_classes:
        if np.count_nonzero(mine == c):
            precision[c] = np.count_nonzero(np.logical_and(mine == real, real == c)) / np.count_nonzero(mine == c)
        else:
            precision[c] = 0
        if np.count_nonzero(real == c):
            recall[c] = np.count_nonzero(np.logical_and(mine == real, real == c)) / np.count_nonzero(real == c)
        else:
            recall[c] = 0
        if precision[c] and recall[c]:
            f1[c] = 2 / (1 / precision[c] + 1 / recall[c])
        else:
            f1[c] = 0
        support[c] = np.count_nonzero(real_ans == c)
    s = ''
    s += "%10s%10s%10s%10s%10s\n" % ("class", "precision", "recall", "f1", "support")
    fmtstr2 = "%10s%10.2f%10.2f%10.2f%10d\n"
    for c in all_classes:
        s += (fmtstr2 % (c, precision[c], recall[c], f1[c], support[c]))
    s += fmtstr2 % ("avg",
                    np.sum([precision[c] * support[c] for c in all_classes]) / len(mine),
                    np.sum([recall[c] * support[c] for c in all_classes]) / len(mine),
                    np.sum([f1[c] * support[c] for c in all_classes]) / len(mine),
                    len(mine)
                    )
    return s


my_ans = np.random.randint(0, 2, 10)
real_ans = np.random.randint(0, 2, 10)
print(my_ans)
print(real_ans)
print("分类报告是按照类别分开的")
print('=' * 10)
print(metrics.classification_report(real_ans, my_ans))
print('=' * 10)
print(report(my_ans, real_ans))
print("准确率跟上面的正确率不一样")
print(metrics.accuracy_score(real_ans, my_ans))
print(np.count_nonzero(my_ans == real_ans) / len(my_ans))

理解metrics.classification_report

标签:.class   pre   pytho   ==   report   str   code   call   dom   

原文地址:https://www.cnblogs.com/weidiao/p/8506389.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!