I'm building a classifier with an SVM and want to use grid search to find the best model automatically. Here is the code:
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier

X.shape  # (22343, 3233)
y.shape  # (22343, 1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=0)

tuned_parameters = [
    {
        'estimator__kernel': ['rbf'],
        'estimator__gamma': [1e-3, 1e-4],
        'estimator__C': [1, 10, 100, 1000]
    },
    {
        'estimator__kernel': ['linear'],
        'estimator__C': [1, 10, 100, 1000]
    }
]

# One binary SVC per class, fitted in parallel
model_to_set = OneVsRestClassifier(SVC(), n_jobs=-1)

clf = GridSearchCV(model_to_set, tuned_parameters)
clf.fit(X_train, y_train)
I get the following error (this is not the full stack trace, only the last three calls):
----------------------------------------------------
/anaconda/lib/python3.5/site-packages/sklearn/model_selection/_split.py in split(self, X, y, groups)
     88         X, y, groups = indexable(X, y, groups)
     89         indices = np.arange(_num_samples(X))
---> 90         for test_index in self._iter_test_masks(X, y, groups):
     91             train_index = indices[np.logical_not(test_index)]
     92             test_index = indices[test_index]

/anaconda/lib/python3.5/site-packages/sklearn/model_selection/_split.py in _iter_test_masks(self, X, y, groups)
    606
    607     def _iter_test_masks(self, X, y=None, groups=None):
--> 608         test_folds = self._make_test_folds(X, y)
    609         for i in range(self.n_splits):
    610             yield test_folds == i

/anaconda/lib/python3.5/site-packages/sklearn/model_selection/_split.py in _make_test_folds(self, X, y, groups)
    593         for test_fold_indices, per_cls_splits in enumerate(zip(*per_cls_cvs)):
    594             for cls, (_, test_split) in zip(unique_y, per_cls_splits):
--> 595                 cls_test_folds = test_folds[y == cls]
    596                 # the test split can be too big because we used
    597                 # KFold(...).split(X[:max(c, n_splits)]) when data is not 100%

IndexError: too many indices for array
Also, when I reshape the array so that y has shape (22343,), the grid search never finishes, even if I set tuned_parameters to only the default values.
Here are the package versions, in case they help:
Python: 3.5.2
scikit-learn: 0.18
pandas: 0.19.0
Answer:
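Apart from the shape of y (it needs to be one-dimensional, which your reshape already fixes), your implementation looks correct.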
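For reference, the IndexError in the traceback comes from exactly that shape. In the last frame, StratifiedKFold (the splitter GridSearchCV uses by default for classifiers) evaluates test_folds[y == cls], where test_folds is a 1-D array. With y of shape (n_samples, 1), y == cls is a 2-D boolean mask, and NumPy cannot index a 1-D array with it. The sketch below reproduces only that NumPy step on a small made-up label array (y_col and try_mask are hypothetical names), then flattens y with ravel():

```python
import numpy as np

# Hypothetical labels stored as a column vector, shape (6, 1),
# like the y of shape (22343, 1) in the question.
y_col = np.array([[0], [1], [0], [2], [1], [2]])

# 1-D array, mirroring test_folds inside _make_test_folds.
test_folds = np.zeros(y_col.shape[0], dtype=int)


def try_mask(y, cls=0):
    """Index the 1-D test_folds with the boolean mask y == cls."""
    try:
        return test_folds[y == cls]
    except IndexError as exc:
        return "IndexError: " + str(exc)


print(try_mask(y_col))    # 2-D mask on a 1-D array -> IndexError message
y_flat = y_col.ravel()    # shape (6,) instead of (6, 1)
print(y_flat.shape)       # (6,)
print(try_mask(y_flat))   # 1-D mask works: selects the rows where y == 0
```

If y is a single-column pandas DataFrame, y.values.ravel() gives the same 1-D shape.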
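The slowness, however, comes from the size of your data. As the scikit-learn SVC documentation puts it, "The fit time complexity is more than quadratic with the number of samples which makes it hard to scale to dataset with more than a couple of 10000 samples."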
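A rough way to see this growth on your own machine is to time a single SVC fit at increasing sample sizes. This is a minimal sketch on synthetic data from make_classification (an assumption standing in for your X and y); time_svc_fit is a hypothetical helper, and at small sizes fixed overhead can make the growth look less steep than it is for ~20k samples.

```python
import time

from sklearn.datasets import make_classification
from sklearn.svm import SVC


def time_svc_fit(n_samples_list, n_features=50, random_state=0):
    """Return {n_samples: seconds} for one SVC().fit on synthetic data."""
    timings = {}
    for n in n_samples_list:
        # Synthetic stand-in for the real data (assumption).
        X, y = make_classification(n_samples=n, n_features=n_features,
                                   random_state=random_state)
        start = time.perf_counter()
        SVC(kernel="rbf", C=1.0).fit(X, y)
        timings[n] = time.perf_counter() - start
    return timings


if __name__ == "__main__":
    # sorted() because dicts are not insertion-ordered before Python 3.7
    for n, secs in sorted(time_svc_fit([500, 1000, 2000, 4000]).items()):
        print("{:>5} samples: {:.3f} s".format(n, secs))
```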
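In your case you have 22343 samples (about 13,400 of them in the training set after the 60/40 split), which can lead to long computation times and memory problems. On top of that, GridSearchCV refits the model for every parameter combination in every cross-validation fold, and inside each of those fits OneVsRestClassifier trains one binary SVC per class (when there are more than two classes). That is why the grid search takes so long even with the default cross-validation. Try reducing your training set to 10000 samples or fewer.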
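A minimal sketch of both points, under stated assumptions: it counts the fits the question's grid implies (ParameterGrid expands the list of dicts into 12 candidates; scikit-learn 0.18 uses 3 folds by default, newer versions use 5), then runs the same search on a stratified subset. The data comes from make_classification and the sizes are scaled down so the example runs quickly; SUBSET_SIZE is a hypothetical knob you would set to 10000 or less for the real data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, ParameterGrid, train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# Same grid as in the question: 2 * 4 rbf + 4 linear = 12 candidates.
tuned_parameters = [
    {'estimator__kernel': ['rbf'],
     'estimator__gamma': [1e-3, 1e-4],
     'estimator__C': [1, 10, 100, 1000]},
    {'estimator__kernel': ['linear'],
     'estimator__C': [1, 10, 100, 1000]},
]
n_candidates = len(ParameterGrid(tuned_parameters))
n_folds = 3  # default in scikit-learn 0.18
print("fits:", n_candidates * n_folds, "+ 1 refit")  # each fit = one SVC per class

# Synthetic stand-in for the real data (assumption); y is already 1-D.
X, y = make_classification(n_samples=1500, n_features=20, n_informative=5,
                           n_classes=3, random_state=0)

# Stratified subset keeps the class proportions of the full training set.
SUBSET_SIZE = 300  # hypothetical; use 10000 or fewer for the real data
X_sub, _, y_sub, _ = train_test_split(X, y, train_size=SUBSET_SIZE,
                                      stratify=y, random_state=0)

clf = GridSearchCV(OneVsRestClassifier(SVC()), tuned_parameters, cv=n_folds)
clf.fit(X_sub, y_sub)
print(clf.best_params_)
```

Once a promising region of the grid is found on the subset, you can refit only those few candidates on more data instead of the whole grid.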