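I am working on heart disease detection using several different classifiers. My approach is to save each model as an .h5 file, load it back into a model object, and return predictions as a JSON response.
However, the same model that works fine in my terminal fails inside my Flask API.
Here is my neural network: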
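from keras.models import Sequential
from keras.layers import Dense
from sklearn.metrics import accuracy_score

def ANN():
    global x_train, x_test, y_train, y_test
    model = Sequential()
    # Input layer (implicit via input_dim) combined with the first hidden layer
    model.add(Dense(units=13, kernel_initializer='uniform', activation='relu', input_dim=13))
    # Hidden layer 2 (input_dim is only needed on the first layer)
    model.add(Dense(units=13, kernel_initializer='uniform', activation='relu'))
    # Output layer: a single sigmoid unit for binary classification
    model.add(Dense(units=1, kernel_initializer='uniform', activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    # Fit with the best hyperparameters found (Keras 2 uses epochs; nb_epoch is deprecated)
    model.fit(x_train, y_train, batch_size=25, epochs=287)
    # accuracy_score takes (y_true, y_pred); predictions are thresholded at 0.5
    return {'model': model, 'accuracy': accuracy_score(y_test, model.predict(x_test) > 0.5) * 100}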
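ANN() reports accuracy by thresholding the sigmoid outputs at 0.5 and scoring them with accuracy_score, scaled to a percentage. The pure-Python sketch below spells out that arithmetic; accuracy_percent is a hypothetical helper, not part of Keras or scikit-learn. Note that model.predict returns an array of shape (n, 1), which would need flattening (for example with .ravel()) before being passed in.

```python
def accuracy_percent(probabilities, labels, threshold=0.5):
    """Percentage of samples whose thresholded probability matches the 0/1 label.

    Hypothetical helper mirroring accuracy_score(y_test, model.predict(x_test) > 0.5) * 100.
    """
    if len(probabilities) != len(labels):
        raise ValueError("probabilities and labels must have the same length")
    # Strict '>' matches the original code: exactly 0.5 is classified as 0
    predictions = [1 if p > threshold else 0 for p in probabilities]
    correct = sum(1 for pred, y in zip(predictions, labels) if pred == y)
    return 100.0 * correct / len(labels)


# Four sigmoid outputs vs. true labels: three of four match after thresholding
print(accuracy_percent([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 0]))  # 75.0
```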
After saving the model as an .h5 file, I load it in my Flask API like this:
from keras.models import load_model

ann = load_model('ann8524.h5')
print(ann.predict(x_test))  # test set, only used as a sanity check
Here is the error message:
 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
[2018-12-20 23:37:43,548] ERROR in app: Exception on /heart/predict [GET]
Traceback (most recent call last):
  File "C:\python_installation\lib\site-packages\flask\app.py", line 1813, in full_dispatch_request
    rv = self.dispatch_request()
  File "C:\python_installation\lib\site-packages\flask\app.py", line 1799, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)
  File "C:\python_installation\lib\site-packages\flask_restful\__init__.py", line 458, in wrapper
    resp = resource(*args, **kwargs)
  File "C:\python_installation\lib\site-packages\flask\views.py", line 88, in view
    return self.dispatch_request(*args, **kwargs)
  File "C:\python_installation\lib\site-packages\flask_restful\__init__.py", line 573, in dispatch_request
    resp = meth(*args, **kwargs)
  File "app.py", line 41, in get
    print(ann.predict(x_test))
  File "C:\python_installation\lib\site-packages\keras\engine\training.py", line 1164, in predict
    self._make_predict_function()
  File "C:\python_installation\lib\site-packages\keras\engine\training.py", line 554, in _make_predict_function
    **kwargs)
  File "C:\python_installation\lib\site-packages\keras\backend\tensorflow_backend.py", line 2744, in function
    return Function(inputs, outputs, updates=updates, **kwargs)
  File "C:\python_installation\lib\site-packages\keras\backend\tensorflow_backend.py", line 2546, in __init__
    with tf.control_dependencies(self.outputs):
  File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 5004, in control_dependencies
    return get_default_graph().control_dependencies(control_inputs)
  File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 4543, in control_dependencies
    c = self.as_graph_element(c)
  File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 3490, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 3569, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("dense_3/Sigmoid:0", shape=(?, 1), dtype=float32) is not an element of this graph.
127.0.0.1 - - [20/Dec/2018 23:37:43] "GET /heart/predict HTTP/1.1" 500 -
Yet in Spyder it works perfectly (with exactly the same code).
Answer:
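In TensorFlow 1.x the default graph is a property of the current thread, and Flask's development server typically handles each request in a separate thread. The model is loaded into the main thread's default graph, so when predict runs inside the request handler it looks for the model's tensors in a different graph and fails. Capture the graph right after loading the model and make it the default again inside the handler; the following should fix the problem: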
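import tensorflow as tf
from keras.models import load_model

# Load the model once at startup and remember the graph it was built in
ann = load_model('ann8524.h5')
graph = tf.get_default_graph()

def your_handler():
    # Request threads get their own default graph, so re-enter the model's graph
    with graph.as_default():
        print(ann.predict(x_test))  # x_test comes from the question's own code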
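To see why capturing the graph helps, here is a stdlib-only toy model of a thread-local "default graph". Graph, get_default_graph, load_model, predict and run_in_thread below are hypothetical stand-ins written for illustration, not the TensorFlow or Keras API; the sketch only assumes, as described above, that the default graph is per-thread and that requests run in a worker thread.

```python
import threading
from contextlib import contextmanager

# Toy model of a thread-local default graph.
# These names are hypothetical stand-ins, NOT the TensorFlow API.
_state = threading.local()  # each thread sees its own attributes on this object


class Graph:
    def __init__(self, name):
        self.name = name
        self.tensors = set()

    @contextmanager
    def as_default(self):
        """Make this graph the default for the current thread inside a with-block."""
        stack = getattr(_state, "stack", None)
        if stack is None:
            stack = _state.stack = []
        stack.append(self)
        try:
            yield self
        finally:
            stack.pop()


def get_default_graph():
    """Innermost graph entered in this thread, else a lazily created per-thread graph."""
    stack = getattr(_state, "stack", None)
    if stack:
        return stack[-1]
    if not hasattr(_state, "fallback"):
        _state.fallback = Graph("default-" + threading.current_thread().name)
    return _state.fallback


def load_model():
    """Pretend to load a model by registering its output tensor in the current default graph."""
    tensor = "dense_3/Sigmoid:0"
    get_default_graph().tensors.add(tensor)
    return tensor


def predict(output_tensor):
    """Fail the same way when the tensor lives in a different graph."""
    if output_tensor not in get_default_graph().tensors:
        raise ValueError("Tensor %s is not an element of this graph." % output_tensor)
    return "ok"


def run_in_thread(fn):
    """Run fn in a new thread, as a threaded dev server would for a request."""
    result = {}

    def target():
        try:
            result["value"] = fn()
        except ValueError as exc:
            result["value"] = "error: %s" % exc

    worker = threading.Thread(target=target, name="request-handler")
    worker.start()
    worker.join()
    return result["value"]


output = load_model()        # main thread, like loading at import time in app.py
graph = get_default_graph()  # capture the graph the model was loaded into


def naive_handler():
    return predict(output)


def fixed_handler():
    with graph.as_default():
        return predict(output)


print(run_in_thread(naive_handler))  # error: Tensor dense_3/Sigmoid:0 is not an element of this graph.
print(run_in_thread(fixed_handler))  # ok
```

Calling predict(output) directly in the main thread succeeds, which is presumably why the same code works in Spyder, where everything runs in one thread.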