查找最相关的三个类别及其相应的概率

从下面的脚本中,我找到了多类文本分类问题中最高的概率及其对应的类别。我想知道如何以最有效的方式找到最高的前三预测概率及其对应的类别,并且不使用循环。

probabilities = classifier.predict_proba(X_test)max_probabilities = probabilities.max(axis=1)order=np.argsort(probabilities, axis=1)classification=(classifier.classes_[order[:, -1:]])print(accuracy_score(classification,y_test))

提前感谢您。(我有大约50个类别,我想从中提取每个叙述中最相关的三个类别,并将它们显示在数据框中)


回答:

您已经完成了大部分艰难的工作,只差一点numpy技巧就可以完成。您的这行代码

order = np.argsort(probabilities, axis=1)

包含了排序后的概率索引,因此每个样本的格式为[[最低概率类别1, ..., 最高概率类别1]...]。您已经用order[:, -1:]来给出分类,即最高概率类别的索引。因此,要获得前三个类别,我们可以进行简单的修改

top_3_classes = classifier.classes_[order[:, -3:]]

然后,要获取相应的概率,我们可以使用

top_3_probabilities = probabilities[np.repeat(np.arange(order.shape[0]), 3),                                    order[:, -3:].flatten()].reshape(order.shape[0], 3)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注