“标量变量的索引无效” – 使用Scikit Learn的”accuracy_score”时

我不确定具体哪里出了问题。然而,我的目标是建立一个交叉验证的Python代码。我知道有各种度量标准,但我认为我使用的是正确的度量标准。然而,我没有得到我想要的CV10结果,而是收到了一个错误:

“标量变量的索引无效”

我在StackOverflow上找到了这个:当你试图索引一个numpy标量,如numpy.int64或numpy.float64时,会发生IndexError: invalid index to scalar variable。这与当你试图索引一个int时发生的TypeError: ‘int’ object has no attribute ‘_getitem‘_非常相似。

任何帮助都将不胜感激…

我正在尝试按照以下链接进行操作: http://scikit-learn.org/stable/modules/model_evaluation.html

from sklearn.ensemble import RandomForestClassifierfrom sklearn import cross_validationfrom numpy import genfromtxtimport numpy as npfrom sklearn.metrics import accuracy_scoredef main():    #读取数据,并解析为训练和目标集    dataset = genfromtxt(open('D:\\CA_DataPrediction_TrainData\\CA_DataPrediction_TrainDataGenetic.csv','r'), delimiter=',', dtype='f8')[1:]       target = np.array( [x[0] for x in dataset] )    train = np.array( [x[1:] for x in dataset] )    #在这种情况下,我们将使用随机森林,但这可以是任何分类器    cfr = RandomForestClassifier(n_estimators=10)    #简单的K-Fold交叉验证。10折。    cv = cross_validation.KFold(len(train), k=10, indices=False)    #迭代训练和测试交叉验证段,并    #在每个段上运行分类器,将结果聚合到一个列表中    results = []    for traincv, testcv in cv:        pred = cfr.fit(train[traincv], target[traincv]).predict(train[testcv])        results.append(accuracy_score(target[testcv], [x[1] for x in pred]) )    #打印出交叉验证结果的平均值    print "Results: " + str( np.array(results).mean() )if __name__=="__main__":    main()

回答:

你的pred变量只是一个预测列表,因此你不能索引它的元素(这就是错误的原因)

results.append(accuracy_score(target[testcv], [x[1] for x in pred]) )

应该改为

results.append(accuracy_score(target[testcv], pred) )

或者如果你真的想要一个副本

results.append(accuracy_score(target[testcv], [x for x in pred]) )

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注