Keras模型是否会影响输入数据的大小?

我的意思是,如果一个卷积神经网络模型输入n个模型,它会输出n个结果,对吗?然而,当我尝试使用一个瓶颈模型(基于VGG16卷积神经网络构建)时,VGG16卷积神经网络的输出比输入少了16个。

这是控制台输出:

import numpy as np

train_data = np.load(open('bottleneck_features_train.npy'))
train_data.shape
(8384, 7, 7, 512)

validation_data = np.load(open('bottleneck_features_validation.npy')) validation_data.shape
(3584, 7, 7, 512)

生成此输出的脚本可以在这里找到。

上述脚本的堆栈跟踪信息如下:

使用Theano后端。
https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5下载数据
发现8400张图片,属于120个类别。正在保存训练特征…
发现3600张图片,属于120个类别。正在保存测试特征…
训练顶层…
编译瓶颈模型…
训练瓶颈模型…
跟踪(最近一次调用):

文件”pretrained_network.py”, 第87行,
train_top_model()

文件 “pretrained_network.py”, 第82行,train_top_model
validation_data=(validation_data, validation_labels))

文件 “/home/ashish/ml-projects/venv/local/lib/python2.7/site-packages/keras/models.py”,第845行,fit initial_epoch=initial_epoch)

文件 “/home/ashish/ml-projects/venv/local/lib/python2.7/site-packages/keras/engine/training.py”, 第1405行,fit batch_size=batch_size)

文件 “/home/ashish/ml-projects/venv/local/lib/python2.7/site-packages/keras/engine/training.py”, 第1307行,_standardize_user_data _check_array_lengths(x, y, sample_weights)

文件 “/home/ashish/ml-projects/venv/local/lib/python2.7/site-packages/keras/engine/training.py”, 第229行,_check_array_lengths ‘and ‘ + str(list(set_y)[0]) + ‘ target samples.’)

ValueError: 输入数组的样本数应与目标数组的样本数相同。
发现8384个输入样本和8400个目标样本


回答:

问题出在你的脚本中的这一部分,例如:

bottleneck_features_train = model.predict_generator(        generator, nb_train_samples // batch_size)

应该改为:

bottleneck_features_train = model.predict_generator(        generator, (nb_train_samples // batch_size) + 1)

如果不这样做,generator调用的次数就不够。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注