使用Keras scikit-learn包装器在交叉验证中对独热编码标签进行评分

我正在实现一个神经网络,并希望通过交叉验证来评估其性能。以下是我当前的代码:

def recall_m(y_true, y_pred):    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))    recall = true_positives / (possible_positives + K.epsilon())    return recalldef precision_m(y_true, y_pred):    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))    precision = true_positives / (predicted_positives + K.epsilon())    return precisiondef f1_m(y_true, y_pred):    precision = precision_m(y_true, y_pred)    recall = recall_m(y_true, y_pred)    return 2*((precision*recall)/(precision+recall+K.epsilon()))def build_model():    hiddenLayers = 1    neurons = 100    #hidden_neurons = int(train_x.shape[0]/(3*(neurons+1)))    hidden_neurons = 500    opt = optimizers.Adam(learning_rate=0.00005, amsgrad=False)    model = Sequential()    model.add(Dense(units=neurons, activation="relu", input_shape=(15,)))    model.add(Dense(units=2*hidden_neurons, activation="relu", input_shape=(18632,)))    model.add(Dense(units=4, activation="softmax"))    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['acc',f1_m,precision_m, recall_m])    return modelx = df[['start-sin', 'start-cos', 'start-sin-lag', 'start-cos-lag', 'prev-close-sin', 'prev-close-cos', 'prev-length', 'state-lag', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']]y = df[['wait-categ-none', 'wait-categ-short', 'wait-categ-medium', 'wait-categ-long']]print(y)#enforce, this is gone wrong somewherey = y.replace(False, 0)y = y.replace(True, 1)ep = 1#fit = model.fit(train_x, train_y, epochs=ep, verbose=1)#pred = model.predict(test_x)#loss, accuracy, f1_score, precision, recall = model.evaluate(test_x, test_y, verbose=0)classifier = KerasClassifier(build_fn=build_model, batch_size=10, epochs=ep)accuracies = cross_val_score(estimator=classifier, X=x, y=y, cv=10, scoring="f1_macro", verbose=5)

我正在使用cross_val_score,并尝试在函数中使用不同于准确率的指标,但得到的错误是

ValueError: Classification metrics can’t handle a mix of multilabel-indicator and binary targets

我在这里读到,我需要在评分之前取消独热编码输出,但我找不到使用此函数的任何方法来做到这一点。

是否有更好的方法来实现多个评分,而不是自己编写整个过程?如您所见,我已经实现了这些评分,并且它们在训练期间按预期工作,但我似乎无法提取信息,因为cross_val_score的原因。

编辑:

我只运行了一次迭代,以下是代码:

train, test = train_test_split(df, test_size=0.1, shuffle=True)train_x = train[['start-sin', 'start-cos', 'start-sin-lag', 'start-cos-lag', 'prev-close-sin', 'prev-close-cos', 'prev-length', 'state-lag', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']]train_y = train[['wait-categ-none', 'wait-categ-short', 'wait-categ-medium', 'wait-categ-long']]test_x = test[['start-sin', 'start-cos', 'start-sin-lag', 'start-cos-lag', 'prev-close-sin', 'prev-close-cos', 'prev-length', 'state-lag', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']]test_y = test[['wait-categ-none', 'wait-categ-short', 'wait-categ-medium', 'wait-categ-long']]test_y = test_y.replace(False, 0).replace(True,1)train_y = train_y.replace(False, 0).replace(True,1)ep = 500model = build_model()print("Train y")print(train_y)print("Test y")print(test_y)model.fit(train_x, train_y, epochs=1, verbose=1)pred = model.predict(test_x)print(pred)loss, accuracy, f1_score, precision, recall = model.evaluate(test_x, test_y, verbose=0)

这产生了以下输出:

训练集 y

       wait-categ-none  wait-categ-short  wait-categ-medium  wait-categ-long4629                 1                 0                  0                07643                 0                 1                  0                04425                 0                 1                  0                010548                1                 0                  0                014180                1                 0                  0                0...                ...               ...                ...              ...13661                1                 0                  0                010546                1                 0                  0                01966                 1                 0                  0                05506                 0                 1                  0                010793                1                 0                  0                0[15632 rows x 4 columns]

测试集 y

       wait-categ-none  wait-categ-short  wait-categ-medium  wait-categ-long10394                0                 1                  0                03804                 0                 1                  0                015136                0                 1                  0                07050                 1                 0                  0                030                   0                 1                  0                0...                ...               ...                ...              ...12040                0                 1                  0                04184                 0                 1                  0                012345                1                 0                  0                012629                0                 1                  0                0664                  1                 0                  0                0[1737 rows x 4 columns]

预测结果

[[2.63620764e-01 5.09552181e-01 1.72765702e-01 5.40613122e-02] [5.40941073e-07 9.99827385e-01 1.72021420e-04 5.32279255e-11] [5.91083081e-05 9.97556090e-01 2.38463446e-03 1.01058276e-07] ... [2.69533932e-01 3.99731129e-01 2.22193986e-01 1.08540975e-01] [5.87045122e-03 9.67754781e-01 2.62637101e-02 1.11028130e-04] [2.32783407e-01 4.53738511e-01 2.31750652e-01 8.17274228e-02]]

我原样复制了输出内容。


回答:

我尝试了@的答案,但由于我有一个多类问题,即使在循环本身之外,我也遇到了问题,尤其是在np.argmax()这一行。经过谷歌搜索,我没有找到任何简单的方法来解决这个问题,所以在该用户的建议下,我最终手动实现了交叉验证。由于我使用的是pandas数据框,代码确实可以进一步清理,但这是工作代码:

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注