Grid search with a Keras neural network

I'm trying to tune the parameters of a neural network built with Keras. Here is my code, with a comment marking the line where the error occurs:

```python
from sklearn.cross_validation import StratifiedKFold, cross_val_score
from sklearn import grid_search
from sklearn.metrics import classification_report
import multiprocessing
from keras.models import Sequential
from keras.layers import Dense
from sklearn.preprocessing import LabelEncoder
from keras.utils import np_utils
from keras.wrappers.scikit_learn import KerasClassifier
import numpy as np


def tuning(X_train, Y_train, X_test, Y_test):

    in_size = X_train.shape[1]
    num_cores = multiprocessing.cpu_count()
    model = Sequential()
    model.add(Dense(in_size, input_dim=in_size, init='uniform', activation='relu'))
    model.add(Dense(8, init='uniform', activation='relu'))
    model.add(Dense(1, init='uniform', activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    batch_size = [10, 20, 40, 60, 80, 100]
    epochs = [10, 20]
    param_grid = dict(batch_size=batch_size, nb_epoch=epochs)

    k_model = KerasClassifier(build_fn=model, verbose=0)
    clf = grid_search.GridSearchCV(estimator=k_model, param_grid=param_grid,
                                   cv=StratifiedKFold(Y_train, n_folds=10, shuffle=True, random_state=1234),
                                   scoring="accuracy", verbose=100, n_jobs=num_cores)

    clf.fit(X_train, Y_train)  # ERROR HERE

    print("Best parameters set found on development set:")
    print()
    print(clf.best_params_)
    print()
    print("Grid scores on development set:")
    print()
    for params, mean_score, scores in clf.grid_scores_:
        print("%0.3f (+/-%0.03f) for %r"
              % (mean_score, scores.std() * 2, params))
    print()
    print("Detailed classification report:")
    print()
    print("The model is trained on the full development set.")
    print("The scores are computed on the full evaluation set.")
    print()
    y_true, y_pred = Y_test, clf.predict(X_test)
    print(classification_report(y_true, y_pred))
    print()
```

This is the error report:

```
    clf.fit(X_train, Y_train)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/grid_search.py", line 804, in fit
    return self._fit(X, y, ParameterGrid(self.param_grid))
  File "/usr/local/lib/python2.7/dist-packages/sklearn/grid_search.py", line 553, in _fit
    for parameters in parameter_iterable
  File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 800, in __call__
    while self.dispatch_one_batch(iterator):
  File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 658, in dispatch_one_batch
    self._dispatch(tasks)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 566, in _dispatch
    job = ImmediateComputeBatch(batch)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 180, in __init__
    self.results = batch()
  File "/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.py", line 72, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/usr/local/lib/python2.7/dist-packages/sklearn/cross_validation.py", line 1531, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/usr/local/lib/python2.7/dist-packages/keras/wrappers/scikit_learn.py", line 135, in fit
    **self.filter_sk_params(self.build_fn.__call__))
TypeError: __call__() takes at least 2 arguments (1 given)
```

Am I missing something? The grid search runs fine with random forests, SVMs, and logistic regression. I only have the problem with neural networks.


Answer:

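The error here indicates that build_fn must be a callable that builds and returns the model, not the model itself. The wrapper calls build_fn each time it fits, and calling an already built Sequential instance requires an input argument, hence "__call__() takes at least 2 arguments (1 given)". The function may take the parameters listed in param_grid as its arguments.

So you need to define a new function explicitly and pass it as build_fn=make_model: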

```python
def make_model(batch_size, nb_epoch):
    # The arguments mirror param_grid; the body does not use them.
    # in_size must be in scope here (for example, define make_model inside tuning()).
    model = Sequential()
    model.add(Dense(in_size, input_dim=in_size, init='uniform', activation='relu'))
    model.add(Dense(8, init='uniform', activation='relu'))
    model.add(Dense(1, init='uniform', activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
```
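To see why passing the model object itself fails while passing a function works, here is a minimal sketch in plain Python that imitates, in much simplified form, how a wrapper might call build_fn. FakeModel and wrapper_fit are hypothetical stand-ins written for illustration, not Keras APIs, and the real wrapper's logic is more involved.

```python
import inspect


class FakeModel:
    """Stand-in for a built model: calling it requires an input."""

    def __call__(self, x):
        return x


def make_model():
    """Factory in the spirit of build_fn: builds and returns a fresh model."""
    return FakeModel()


def wrapper_fit(build_fn, **sk_params):
    """Hypothetical, simplified imitation of the wrapper's model-building step."""
    # Keep only the parameters that build_fn itself accepts.
    accepted = inspect.signature(build_fn).parameters
    kwargs = {k: v for k, v in sk_params.items() if k in accepted}
    # Call build_fn to obtain a model.
    return build_fn(**kwargs)


# A factory function works: it needs no input to produce a model.
model = wrapper_fit(make_model, batch_size=10, nb_epoch=10)
print(type(model).__name__)

# A model instance fails: calling it without an input raises TypeError.
try:
    wrapper_fit(FakeModel(), batch_size=10, nb_epoch=10)
except TypeError as exc:
    print("TypeError:", exc)
```

The factory succeeds because it can be called with no required arguments, while the instance fails in the same way as the traceback in the question: it is called without the input it expects.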

Also take a look at keras/examples/mnist_sklearn_wrapper.py, which uses GridSearchCV for a hyperparameter search.
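For a rough sense of how much work the grid in the question implies, the following sketch expands param_grid by hand and counts the model fits. expand_grid is a hypothetical helper written for illustration; it only approximates what a grid search does with the grid and is not scikit-learn's code.

```python
from itertools import product

# Values taken from the question.
batch_size = [10, 20, 40, 60, 80, 100]
epochs = [10, 20]
n_folds = 10

param_grid = dict(batch_size=batch_size, nb_epoch=epochs)


def expand_grid(grid):
    """Hypothetical helper: list every combination of the grid's values."""
    keys = sorted(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


candidates = expand_grid(param_grid)
# Each candidate is trained and scored once per cross-validation fold.
total_fits = len(candidates) * n_folds

print(len(candidates), "candidates,", total_fits, "cross-validation fits")
```

Because each of those fits trains a network from scratch, build_fn has to return a fresh model every time it is called.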
