使用Keras模型预测新结果时检查模型输入出现错误

我在尝试使用一个已构建的Keras模型来处理新数据时,遇到了预测时的输入错误。

这是我的模型代码:

def build_model(max_features, maxlen):    """构建LSTM模型"""    model = Sequential()    model.add(Embedding(max_features, 128, input_length=maxlen))    model.add(LSTM(128))    model.add(Dropout(0.5))    model.add(Dense(1))    model.add(Activation('sigmoid'))    model.compile(loss='binary_crossentropy',                  optimizer='rmsprop')    return model

这是我用来预测新数据输出结果的代码:

LSTM_model = load_model('LSTMmodel.h5')data = pickle.load(open('traindata.pkl', 'rb'))#### LSTM ####"""在逻辑回归模型上运行训练/测试"""# 提取数据和标签X = [x[1] for x in data]labels = [x[0] for x in data]# 生成有效字符的字典valid_chars = {x:idx+1 for idx, x in enumerate(set(''.join(X)))}max_features = len(valid_chars) + 1maxlen = np.max([len(x) for x in X])# 将字符转换为整数并填充X = [[valid_chars[y] for y in x] for x in X]X = sequence.pad_sequences(X, maxlen=maxlen)# 将标签转换为0-1y = [0 if x == 'benign' else 1 for x in labels]y_pred = LSTM_model.predict(X)

运行这段代码时我得到的错误是:

ValueError: Error when checking input: expected embedding_1_input to have shape (57,) but got array with shape (36,)

我的错误来自于maxlen,因为在我的训练数据中,maxlen=57,而在我的新数据中,maxlen=36

所以我在预测代码中尝试设置maxlen=57,但接着我得到了这个错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[31,53] = 38 is not in [0, 38)     [[Node: embedding_1/embedding_lookup = GatherV2[Taxis=DT_INT32, Tindices=DT_INT32, Tparams=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embedding_1/embeddings/read, embedding_1/Cast, embedding_1/embedding_lookup/axis)]]

为了解决这些问题,我应该做些什么?更改我的嵌入层吗?


回答:

要么将嵌入层的input_length设置为数据集中可能见到的最大长度,要么在pad_sequences中使用构建模型时使用的相同maxlen值。这样,任何短于maxlen的序列都会被填充,任何长于maxlen的序列都会被截断。

另外,请确保在训练和测试时使用的特征是相同的(即它们的数量不应改变)。

Related Posts

在使用k近邻算法时,有没有办法获取被使用的“邻居”?

我想找到一种方法来确定在我的knn算法中实际使用了哪些…

Theano在Google Colab上无法启用GPU支持

我在尝试使用Theano库训练一个模型。由于我的电脑内…

准确性评分似乎有误

这里是代码: from sklearn.metrics…

Keras Functional API: “错误检查输入时:期望input_1具有4个维度,但得到形状为(X, Y)的数组”

我在尝试使用Keras的fit_generator来训…

如何使用sklearn.datasets.make_classification在指定范围内生成合成数据?

我想为分类问题创建合成数据。我使用了sklearn.d…

如何处理预测时不在训练集中的标签

已关闭。 此问题与编程或软件开发无关。目前不接受回答。…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注