如何正确使用sklearn.naive_bayes中的sample_weight?

我在使用sklearn实现Naive Bayes分类器时,处理的是不平衡数据。我的数据有超过16,000条记录和6个输出类别。

我尝试使用sklearn.utils.class_weight计算出的sample_weight来拟合模型。

sample_weight的值大致如下:

sample_weight = [11.77540107 1.82284768 0.64688602 2.47138047 0.38577435 1.21389195]

import numpy as np
data_set = np.loadtxt("./data/_vector21.csv", delimiter=",")
inp_vec = data_set[:, 1:22]
out_vec = data_set[:, 22:]
## # Split dataset into training set and test set
from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(inp_vec, out_vec, test_size=0.2)    # 80% training and 20% test
## class weight
from keras.utils.np_utils import to_categorical
output_vec_categorical = to_categorical(y_train)
from sklearn.utils import class_weight
y_ints = [y.argmax() for y in output_vec_categorical]
c_w = class_weight.compute_class_weight('balanced', np.unique(y_ints), y_ints)
cw = {}
for i in set(y_ints):
    cw[i] = c_w[i]
# Create a Gaussian Classifier
from sklearn.naive_bayes import *
model = GaussianNB()
# Train the model using the training sets
print(c_w)
model.fit(X_train, y_train, c_w)
# Predict the response for test dataset
y_pred = model.predict(X_test)
# Import scikit-learn metrics module for accuracy calculation
from sklearn import metrics
# Model Accuracy, how often is the classifier correct?
print("\nClassification Report: \n", (metrics.classification_report(y_test, y_pred)))
print("\nAccuracy: %.3f%%" % (metrics.accuracy_score(y_test, y_pred)*100))

我收到了以下错误信息:ValueError: Found input variables with inconsistent numbers of samples: [13212, 6]

谁能告诉我我做错了什么,以及如何修正这个问题?

非常感谢。


回答:

sample_weightclass_weight是两种不同的东西。

顾名思义:

  • sample_weight应用于单个样本(数据中的行)。因此,sample_weight的长度必须与你的X中的样本数量相匹配。

  • class_weight是为了让分类器对某些类别给予更多的重要性和关注。因此,class_weight的长度必须与你的目标中的类别数量相匹配。

你使用sklearn.utils.class_weight计算的是class_weight而不是sample_weight,但随后尝试将其传递给sample_weight。因此导致了维度不匹配的错误。

请参阅以下问题以更好地理解这两个权重如何在内部相互作用:

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注