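Getting different exceptions when running Keras with scikit-learn

I tried to pass a Keras model (as a function) to the KerasClassifier wrapper from scikit_learn, then used GridSearchCV to set up some parameter settings, and finally fit the training and test datasets (both are numpy arrays).

Running the same Python script each time, I got different exceptions. Some of them are: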

_1.

Traceback (most recent call last):
  File "mnist_flat_imac.py", line 63, in <module>
    grid_result = validator.fit(train_images, train_labels)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 626, in fit
    base_estimator = clone(self.estimator)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/base.py", line 62, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/base.py", line 53, in clone

snipped here

in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 174, in deepcopy
    rv = reductor(4)
TypeError: can't pickle SwigPyObject objects
Exception ignored in: >
Traceback (most recent call last):
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/tensorflow/python/framework/c_api_util.py", line 52, in __del__
    c_api.TF_DeleteGraph(self.graph)
AttributeError: 'ScopedTFGraph' object has no attribute 'graph'

_2.

Traceback (most recent call last):
  File "mnist_flat_imac.py", line 63, in <module>
    grid_result = validator.fit(train_images, train_labels)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 626, in fit
    base_estimator = clone(self.estimator)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/base.py", line 62, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/base.py", line 53, in clone
    return copy.deepcopy(estimator)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 182, in deepcopy
    y = _reconstruct(x, rv, 1, memo)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 297, in _reconstruct

snipped here

in deepcopy
    y = _reconstruct(x, rv, 1, memo)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 297, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 174, in deepcopy
    rv = reductor(4)
TypeError: can't pickle SwigPyObject objects

_3.

Traceback (most recent call last):
  File "mnist_flat_imac.py", line 63, in <module>
    grid_result = validator.fit(train_images, train_labels)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 626, in fit
    base_estimator = clone(self.estimator)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/base.py", line 62, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/site-packages/sklearn/base.py", line 53, in clone

snipped here

in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 182, in deepcopy
    y = _reconstruct(x, rv, 1, memo)
  File "/home/longnv/PYTHON_ENV/DataScience/lib/python3.5/copy.py", line 306, in _reconstruct
    y.__dict__.update(state)
AttributeError: 'NoneType' object has no attribute 'update'

Why does the same Python script produce different errors? How can I fix this?

Thank you very much!

P.S.:

  • python: 3.5
  • tensorflow: 1.10.1
  • pandas: 0.23.4
  • Ubuntu: 4.4.0-124-generic

Answer:

Found the solution.

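Use: clf = KerasClassifier(build_fn=get_model)

Instead of: clf = KerasClassifier(build_fn=get_model())

build_fn expects the function itself. With get_model() the wrapper holds an already-built model, and, as the tracebacks show, GridSearchCV's clone step then tries to deep-copy that model and fails on its TensorFlow internals.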
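The difference between the two calls can be sketched without Keras or TensorFlow installed. This is only an illustration under assumptions: `FakeModel` and `clone_params` are hypothetical names invented here, the lock merely stands in for the SWIG handles inside a built TensorFlow model, and `clone_params` roughly imitates the `copy.deepcopy` call that the tracebacks show inside sklearn's `clone`.

```python
import copy
import threading


class FakeModel:
    """Hypothetical stand-in for a built Keras model.

    The lock plays the role of TensorFlow's SWIG handles:
    copy.deepcopy cannot copy it and raises TypeError.
    """

    def __init__(self):
        self._handle = threading.Lock()


def get_model():
    """Stand-in for the question's model-building function."""
    return FakeModel()


def clone_params(params):
    """Rough imitation of how clone deep-copies plain parameters."""
    return {name: copy.deepcopy(value) for name, value in params.items()}


# Passing the function: deepcopy treats functions as atomic, so this works.
cloned = clone_params({"build_fn": get_model})
function_survives = cloned["build_fn"] is get_model

# Passing the call result: the built object is deep-copied and that fails.
try:
    clone_params({"build_fn": get_model()})
    instance_error = None
except TypeError as exc:
    instance_error = exc

print("function cloned intact:", function_survives)
print("instance clone failed with:", instance_error)
```

With the function passed in, nothing heavy needs to be copied, and the wrapper can call it to build a fresh model for each parameter setting.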
