我刚开始使用Tensorflow,正在尝试实现一个讽刺检测模型。我的数据集由标记为1或0的推文组成,用以表示它们是否具有讽刺性。
经过预处理、标记化和填充阶段后,我得到了固定长度的序列和一个相关的标签向量,用于划分为训练集和测试集,并作为模型的输入。序列的形式如下:
>>> dataarray([[ 1, 677, 348, ..., 0, 0, 0],
[ 1, 677, 348, ..., 0, 0, 0],
[ 1, 825, 1, ..., 0, 0, 0],
...,
[ 908, 1376, 686, ..., 0, 0, 0],
[ 8, 158, 14579, ..., 0, 0, 0],
[ 1, 1, 35, ..., 0, 0, 0]], dtype=int32)>>> data.shape(3977, 50)>>> data[0].shape(50,)
模型如下所示:
num_words = len(tok.word_index) + 1 # tok是一个我用于数据的Tokenizerimport tensorflow as tkfrom tensorflow import kerasearly_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)# 模型model = keras.Sequential()model.add(keras.layers.Embedding(num_words, 64, input_length=Config.SEQUENCE_LENGTH, mask_zero=True))model.add(keras.layers.GRU(64, return_sequences=True))model.add(keras.layers.GRU(64))model.add(keras.layers.Dense(1, activation='sigmoid'))
在使用model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
编译模型并使用sklearn
的工具函数划分数据集后,我调用了模型的fit方法:
model.fit(x_train, y_train, batch_size=10, epochs=10, validation_split=0.1, callbacks=[early_stopping])
训练模型后,evaluate
方法按预期工作,输入为x_test
和y_test
,但如果我调用model.predict_classes(x_test[0])
(或(model.predict(x_test[0]) > 0.5).astype("int32")
),而不是得到单个预测,我得到的是形状为(50,1)的预测数组。我尝试这样重塑x_test[0]
:model.predict_classes(x_test[0].reshape(1,50))
,然后我得到一个包含单个预测的数组:array([[1]], dtype=int32)
因此,现在我留下了以下问题(也因为在调用evaluate(x_test, y_test)
时得到0.6的准确率):
- 为什么如果我将数据集作为数组的数组(x_train)传递给模型,我不能直接将测试集的一个元素传递给预测函数(例如
x_test[0]
),而是必须重塑它呢? - 这是正常的还是有错误?我是否错误地设置了模型的输入维度?在将序列输入模型之前,我是否也应该重塑它们?
回答:
为什么如果我将数据集作为数组的数组(x_train)传递给模型,我不能直接将测试集的一个元素传递给预测函数(例如x_test[0]),而是必须重塑它呢?
这是因为model.predict
只能接受一组示例,而不是单个示例,如果你想提供一个单个示例,你必须将其重塑为(1,50)
,这样它就是一组示例,其中集合的大小为1。
然而,与其将示例一个接一个地输入model.predict
,更有效的方法是将一组示例输入model.predict
,即pred_test = model.predict(x_test)
,之后如果你想知道第i个示例的预测,可以做prediction_of_the_ith_example = pred_test[i]
这是正常的还是有错误?我是否错误地设置了模型的输入维度?在将序列输入模型之前,我是否也应该重塑它们?
模型定义和训练部分是正确的。
关于标签的形状(y_train
和y_test
),正确的形状是(数据大小,标签长度),即分别为(3181,1)
和(796,1)
。你在这里进行的是单标签分类,因此标签长度为1。但即使你使用了(3181,)
和(796,)
,它也正常工作,只是因为它为你自动进行了广播处理。