为什么我在 scikit-learn 中使用 SVM 得到的所有结果都相同?

我在尝试使用 scikit-learn 计算多类别数据集的概率。然而,出于某种原因,每个样本的概率结果都相同。有人知道这是怎么回事吗?这与我的模型、我对库的使用,还是其他方面有关?感谢任何帮助!

svm_model = svm.SVC(probability=True, kernel='rbf',C=1, decision_function_shape='ovr', gamma=0.001,verbose=100)svm_model.fit(train_X,train_y)preds= svm_model.predict_proba(test_X)

train_X 看起来像这样

array([[2350, 5550, 2750.0, ..., 23478, 1, 3],       [2500, 5500, 3095.5, ..., 23674, 0, 3],       [3300, 6900, 3600.0, ..., 6529, 0, 3],       ...,        [2150, 6175, 2500.0, ..., 11209, 0, 3],       [2095, 5395, 2595.4, ..., 10070, 0, 3],       [1650, 2850, 2000.0, ..., 25463, 1, 3]], dtype=object)

train_y 看起来像这样

0        11        210       2100      21000     210000    210001    210002    210003    210004    210005    210006    210007    210008    110009    11001     210010    2

test_X 看起来像这样

array([[2190, 3937, 2200.5, ..., 24891, 1, 5],       [2695, 7000, 2850.0, ..., 5491, 1, 4],       [2950, 12000, 4039.5, ..., 22367, 0, 4],       ...,        [2850, 5200, 3000.0, ..., 15576, 1, 1],       [3200, 16000, 4100.0, ..., 1320, 0, 3],       [2100, 3750, 2400.0, ..., 6022, 0, 1]], dtype=object)

我的结果看起来像

array([[ 0.07819139,  0.22727628,  0.69453233],       [ 0.07819139,  0.22727628,  0.69453233],       [ 0.07819139,  0.22727628,  0.69453233],       ...,        [ 0.07819139,  0.22727628,  0.69453233],       [ 0.07819139,  0.22727628,  0.69453233],       [ 0.07819139,  0.22727628,  0.69453233]])

回答:

从预处理开始吧!

将数据标准化为零均值和单位方差是非常重要的。scikit-learn 的文档中提到这一点

支持向量机算法不是尺度不变的,因此强烈建议对数据进行缩放。例如,将输入向量 X 的每个属性缩放到 [0,1] 或 [-1,+1],或者标准化为均值为 0 且方差为 1。请注意,必须对测试向量应用相同的缩放以获得有意义的结果。有关缩放和归一化的更多详细信息,请参阅预处理数据部分

接下来的步骤是参数调整(C、gamma 等)。这通常通过网格搜索来完成。但我通常建议人们在尝试核 SVM 之前先尝试简单的线性 SVM(超参数更少,计算时间更短,对于非最优参数选择有更好的泛化能力)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注