自定义交叉验证分割 sklearn

我在尝试为交叉验证和GridSearch分割数据集。我想定义自己的分割方法,但GridSearch只能使用内置的交叉验证方法。

然而,我不能使用内置的交叉验证方法,因为我需要某些示例组在同一个折叠中。因此,如果我有示例:[A1, A2, A3, A4, A5, B1, B2, B3, C1, C2, C3, C4, …. , Z1, Z2, Z3]

我想进行交叉验证,使得每个组[A,B,C…]的示例只存在于一个折叠中。

例如,K1包含[D,E,G,J,K…],K2包含[A,C,L,M,…],K3包含[B,F,I,…]等


回答:

这种情况通常可以使用sklearn.cross_validation.LeaveOneLabelOut来完成。你只需要构建一个编码你的组的标签向量。即,K1中的所有样本将使用标签1K2中的所有样本将使用标签2,依此类推。

这里是一个使用假数据的完全可运行的示例。重要的是创建cv对象的行,以及调用cross_val_score的行

import numpy as np
n_features = 10
# 生成一些数据
A = np.random.randn(3, n_features)
B = np.random.randn(5, n_features)
C = np.random.randn(4, n_features)
D = np.random.randn(7, n_features)
E = np.random.randn(9, n_features)
# 分组
K1 = np.concatenate([A, B])
K2 = np.concatenate([C, D])
K3 = E
data = np.concatenate([K1, K2, K3])
# 生成一些虚拟预测目标
target = np.random.randn(len(data)) > 0
# 生成对应的标签
labels = np.concatenate([[i] * len(K) for i, K in enumerate([K1, K2, K3])])
from sklearn.cross_validation import LeaveOneLabelOut, cross_val_score
cv = LeaveOneLabelOut(labels)
# 在数据上使用某种分类器进行交叉验证
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
scores = cross_val_score(lr, data, target, cv=cv)

当然,也可能遇到你希望完全手动定义折叠的情况。在这种情况下,你需要创建一个iterable(例如list),其中包含(train, test)对,通过索引指示哪些样本应纳入每个折叠的训练和测试集。我们来检查一下:

# 从我们的标签创建训练和测试折叠:
cv_by_hand = [(np.where(labels != label)[0], np.where(labels == label)[0]) 
              for label in np.unique(labels)]
# 我们通过将后者转换为列表来与现有的cv进行比较
cv_to_list = list(cv)
print cv_by_hand
print cv_to_list
# 检查相等性
for (train1, test1), (train2, test2) in zip(cv_by_hand, cv_to_list):
    assert (train1 == train2).all() and (test1 == test2).all()
# 在交叉验证中使用创建的cv_by_hand
scores2 = cross_val_score(lr, data, target, cv=cv_by_hand)
# 再次断言相等性
assert (scores == scores2).all()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注