当前位置:
首页
文章
后端
详情

Pytorch训练模型得到输出后计算F1-Score 和AUC的操作

我们在pytorch训练模型完成后我们需要计算F1-Score和AUC来评估这个模型的训练效果。在pytorch中计算F1-Score和AUC是比较简单的。那么pytorch怎么求这两个值呢?接下来这篇文章告诉你。

1、计算F1-Score

对于二分类来说,假设batch size 大小为64的话,那么模型一个batch的输出应该是torch.size([64,2]),所以首先做的是得到这个二维矩阵的每一行的最大索引值,然后添加到一个列表中,同时把标签也添加到一个列表中,最后使用sklearn中计算F1的工具包进行计算,代码如下

import numpy as np
import sklearn.metrics import f1_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
    prob = model(data) #表示模型的预测输出
    prob = prob.cpu().numpy() #先把prob转到CPU上,然后再转成numpy,如果本身在CPU上训练的话就不用先转成CPU了
    prob_all.extend(np.argmax(prob,axis=1)) #求每一行的最大值索引
    label_all.extend(label)
print("F1-Score:{:.4f}".format(f1_score(label_all,prob_all)))

2、计算AUC

计算AUC的时候,本次使用的是sklearn中的roc_auc_score () 方法

输入参数:

y_true:真实的标签。形状 (n_samples,) 或 (n_samples, n_classes)。二分类的形状 (n_samples,1),而多标签情况的形状 (n_samples, n_classes)。

y_score:目标分数。形状 (n_samples,) 或 (n_samples, n_classes)。二分类情况形状 (n_samples,1),“分数必须是具有较大标签的类的分数”,通俗点理解:模型打分的第二列。举个例子:模型输入的得分是一个数组 [0.98361117 0.01638886],索引是其类别,这里 “较大标签类的分数”,指的是索引为 1 的分数:0.01638886,也就是正例的预测得分。

average='macro':二分类时,该参数可以忽略。用于多分类,' micro ':将标签指标矩阵的每个元素看作一个标签,计算全局的指标。' macro ':计算每个标签的指标,并找到它们的未加权平均值。这并没有考虑标签的不平衡。' weighted ':计算每个标签的指标,并找到它们的平均值,根据支持度 (每个标签的真实实例的数量) 进行加权。

sample_weight=None:样本权重。形状 (n_samples,),默认 = 无。

max_fpr=None

multi_class='raise':(多分类的问题在下一篇文章中解释)

labels=None

输出:

auc:是一个 float 的值。

import numpy as np
import sklearn.metrics import roc_auc_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
    prob = model(data) #表示模型的预测输出
    prob_all.extend(prob[:,1].cpu().numpy()) #prob[:,1]返回每一行第二列的数,根据该函数的参数可知,y_score表示的较大标签类的分数,因此就是最大索引对应的那个值,而不是最大索引值
    label_all.extend(label)
print("AUC:{:.4f}".format(roc_auc_score(label_all,prob_all)))

补充:pytorch训练模型的一些坑

1. 图像读取

opencv的python和c++读取的图像结果不一致,是因为python和c++采用的opencv版本不一样,从而使用的解码库不同,导致读取的结果不同。

2. 图像变换

PIL和pytorch的图像resize操作,与opencv的resize结果不一样,这样会导致训练采用PIL,预测时采用opencv,结果差别很大,尤其是在检测和分割任务中比较明显。

3. 数值计算

pytorch的torch.exp与c++的exp计算,10e-6的数值时候会有10e-3的误差,对于高精度计算需要特别注意,比如

两个输入5.601597, 5.601601, 经过exp计算后变成270.85862343143174, 270.85970686809225

以上就是Pytorch训练模型后计算F1-Score和AUC的方法介绍,希望能给大家一个参考,也希望大家多多支持W3Cschool。


免责申明:本站发布的内容(图片、视频和文字)以转载和分享为主,文章观点不代表本站立场,如涉及侵权请联系站长邮箱:xbc-online@qq.com进行反馈,一经查实,将立刻删除涉嫌侵权内容。

同类热门文章

深入了解C++中的new操作符:使用具体实例学习

C++中的new操作符是动态分配内存的主要手段之一。在程序运行时,我们可能需要动态地创建和销毁对象,而new就是为此提供了便利。但是,使用new也常常会引发一些问题,如内存泄漏、空指针等等。因此,本文将通过具体的示例,深入介绍C++中的new操作符,帮助读者更好地掌握其使用。


深入了解C++中的new操作符:使用具体实例学习

怎么用Java反射获取包下所有类? 详细代码实例操作

Java的反射机制就是在运行状态下,对于任何一个类,它能知道这个类的所有属性和方法;对于任何一个对象,都能调用这个对象的任意一个方法。本篇文章将通过具体的代码示例,展示如何通过Java反射来获取包下的所有类。


怎么用Java反射获取包下所有类? 详细代码实例操作

了解Java中的volati关键字的作用 以及具体使用方法

本篇文章将和大家分享一下Java当中的volatile关键字,下面将为各位小伙伴讲述volatile关键字的作用以及它的具体使用方法。


了解Java中的volati关键字的作用 以及具体使用方法

Java Map 所有的值转为String类型

可以使用 Java 8 中的 Map.replaceAll() 方法将所有的值转为 String 类型: 上面的代码会将 map 中所有的值都转为 String 类型。 HashMap 是 Java

Java Map 所有的值转为String类型

员工线上学习考试系统

有点播,直播,在线支付,三级分销等功能,可以对学员学习情况的监督监控,有源码,可二次开发。支持外网和局域网私有化部署,经过测试源码完整可用!1、视频点播:视频播放,图文资料,课件下载,章节试学,限时免

员工线上学习考试系统