计算分类模型预测的概率

我有一个分类任务。训练数据有50个不同的标签。客户希望区分出低概率的预测,这意味着,我需要根据模型的概率(确定性?)将一些测试数据分类为未分类/其他

当我测试我的代码时,预测结果是一个numpy数组(我使用不同的模型,这个是预训练的BertTransformer)。预测数组不包含像Keras中的predict_proba()方法那样的概率。这些是预训练的BertTransformer模型的预测方法生成的数字。

[[-1.7862008  -0.7037363   0.09885322  1.5318055   2.1137428  -0.2216074   0.18905772 -0.32575375  1.0748093  -0.06001111  0.01083148  0.47495762   0.27160102  0.13852511 -0.68440574  0.6773654  -2.2712054  -0.2864312  -0.8428862  -2.1132915  -1.0157436  -1.0340284  -0.35126117 -1.0333195   9.149789   -0.21288703  0.11455813 -0.32903734  0.10503325 -0.3004114  -1.3854568  -0.01692022 -0.4388664  -0.42163098 -0.09182278 -0.28269592  -0.33082992 -1.147654   -0.6703184   0.33038092 -0.50087476  1.1643585   0.96983343  1.3400391   1.0692116  -0.7623776  -0.6083422  -0.91371405   0.10002492]]

我使用numpy.argmax()来识别正确的标签。预测工作得很好。然而,由于这些不是概率,我无法将最佳结果与阈值进行比较。

我的问题是,我如何定义一个阈值(比如,0.6),然后将BertTransformer预测数组中argmax()元素的概率与该阈值进行比较,以便如果概率低于阈值,我可以将预测分类为“其他”?

编辑1:

我们使用两种不同的模型。一个是Keras,另一个是BertTransformer。我们在Keras中没有问题,因为它提供了概率,所以我跳过了Keras模型。

Bert模型是预训练的。这是它的生成方式:

def model(self, data):        number_of_categories = len(data['encoded_categories'].unique())        model = BertForSequenceClassification.from_pretrained(            "dbmdz/bert-base-turkish-128k-uncased",            num_labels=number_of_categories,            output_attentions=False,            output_hidden_states=False,        )        # model.cuda()        return model

上面给出的输出是model.predict()方法的结果。我们比较了两个模型,Bert略胜一筹,因此我们知道预测工作得很好。然而,我们不确定这些数字代表什么或有什么意义。

这是Bert文档的链接。


回答:

BertForSequenceClassification返回logits,即归一化前的分类得分。您可以通过调用F.softmax(output, dim=-1)来归一化这些得分,其中torch.nn.functional已作为F导入。

对于数千个标签,归一化可能成本高昂,而当您只对argmax感兴趣时,您不需要它。这可能就是模型只返回原始得分的原因。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注