Keras模型在本地运行良好,但在Flask API上无法工作

我正在使用不同的分类器研究心脏病检测问题。我的做法是将模型保存为h5文件,并创建其对象,然后以JSON格式返回响应。

但同样的模型在我的终端上运行良好,却无法在Flask API上工作。

这是我的神经网络

def ANN():    global x_train,x_test,y_train,y_test    model = Sequential()    #隐式输入层与隐藏层结合    model.add(Dense(units = 13, kernel_initializer = 'uniform', activation = 'relu', input_dim = 13))    #隐藏层2    model.add(Dense(units = 13, kernel_initializer = 'uniform', activation = 'relu', input_dim = 13))    #输出层    model.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid'))    model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])    #使用最佳超参数进行拟合    model.fit(x_train, y_train, batch_size = 25, nb_epoch = 287)    return {'model':model,            'accuracy':accuracy_score(model.predict(x_test) > 0.5,y_test)*100}

将模型保存为.h5文件后,在我的Flask API中,

ann = load_model('ann8524.h5')print(ann.predict(x_test)) #测试集,仅用于检查。

这是错误信息:

* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit) [2018-12-20 23:37:43,548] ERROR in app: Exception on /heart/predict[GET] Traceback (most recent call last):   File "C:\python_installation\lib\site-packages\flask\app.py", line 1813, in full_dispatch_request    rv = self.dispatch_request()   File "C:\python_installation\lib\site-packages\flask\app.py", line 1799, in dispatch_request    return self.view_functions[rule.endpoint](**req.view_args)   File "C:\python_installation\lib\site-packages\flask_restful\__init__.py", line 458, in wrapper    resp = resource(*args, **kwargs)   File "C:\python_installation\lib\site-packages\flask\views.py", line 88, in view    return self.dispatch_request(*args, **kwargs)   File "C:\python_installation\lib\site-packages\flask_restful\__init__.py", line 573, in dispatch_request    resp = meth(*args, **kwargs)   File "app.py", line 41, in get    print(ann.predict(x_test))   File "C:\python_installation\lib\site-packages\keras\engine\training.py", line 1164, in predict    self._make_predict_function()   File "C:\python_installation\lib\site-packages\keras\engine\training.py", line 554, in _make_predict_function    **kwargs)   File "C:\python_installation\lib\site-packages\keras\backend\tensorflow_backend.py", line 2744, in function    return Function(inputs, outputs, updates=updates, **kwargs)   File "C:\python_installation\lib\site-packages\keras\backend\tensorflow_backend.py", line 2546, in __init__    with tf.control_dependencies(self.outputs):   File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 5004, in control_dependencies    return get_default_graph().control_dependencies(control_inputs)   File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 4543, in control_dependencies    c = self.as_graph_element(c)   File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 3490, in as_graph_element    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)   File "C:\python_installation\lib\site-packages\tensorflow\python\framework\ops.py", line 3569, in _as_graph_element_locked    raise ValueError("Tensor %s is not an element of this graph." % obj) ValueError: Tensor Tensor("dense_3/Sigmoid:0", shape=(?, 1), dtype=float32) is not an element of this graph.127.0.0.1 - - [20/Dec/2018 23:37:43] "[1m[35mGET /heart/predict HTTP/1.1[0m" 500 -

但在Spyder中它运行得非常好。(完全相同的代码)


回答:

您需要从TensorFlow中获取默认图,按照以下步骤应该可以解决这个问题:

import tensorflow as tfann = load_model('ann8524.h5')graph = tf.get_default_graph()def your_handler():    global graph    with graph.as_default():        print(ann.predict(x_test))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注