Keras functional API: 拟合和测试接受多个输入的模型

我构建了一个Keras模型,它有两个分支,每个分支接受相同数据的不同特征表示。任务是将句子分类到6个类别之一。

我已经测试了我的代码,直到model.fit,它接受一个包含两个输入特征矩阵的列表作为X。一切正常。但在预测时,当我传递测试数据的两个输入特征矩阵时,会生成一个错误。

代码如下:

X_train_feature1 = ... # 形状:(2200, 100) 每行是一个句子,每列是一个特征
X_train_feature2 = ... # 形状:(2200, 13) 每行是一个句子,每列是一个特征
y_train= ... # 形状:(2200,6)
X_test_feature1 = ... # 形状:(587, 100) 每行是一个句子,每列是一个特征
X_test_feature2 = ... # 形状:(587, 13) 每行是一个句子,每列是一个特征
y_test= ... # 形状:(587,6)
model= ... # 创建一个有两个分支的模型,参见下图
model.fit([X_train_feature1, X_train_feature2],y_train,epochs=100, batch_size=10, verbose=2) # 模型训练正常
model.predict([X_test_feature1, X_test_feature2],y_test,epochs=100, batch_size=10, verbose=2) # 此处报错

模型看起来像这样: enter image description here

错误信息如下:

predictions = model.predict([X_test_feature1,X_test_feature2], y_test, verbose=2)
  File "/home/zz/Programs/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1748, in predict
    verbose=verbose, steps=steps)
  File "/home/zz/Programs/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1290, in _predict_loop
    batches = _make_batches(num_samples, batch_size)
  File "/home/zz/Programs/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 384, in _make_batches
    num_batches = int(np.ceil(size / float(batch_size)))
TypeError: only length-1 arrays can be converted to Python scalars

我非常希望能得到一些帮助来理解这个错误以及如何修复它。


回答:

predict方法只接受数据(即x)和batch_size作为输入(设置这个参数不是必须的)。它不接受标签或epochs作为输入。

如果你想预测类别,那么你应该使用predict_classes方法,它会给你预测的类别标签(而不是predict方法给出的概率):

preds_prob = model.predict([X_test_feature1, X_test_feature2])
preds = model.predict_classes([X_test_feature1, X_test_feature2])

如果你想在测试数据上评估你的模型以找到损失和指标值,那么你应该使用evaluate方法:

loss_metrics = model.evaluate([X_test_feature1, X_test_feature2], y_test)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注