使用Keras训练的回归模型对所有测试特征集预测相同输出

我正在尝试构建一个回归模型,使用数据集https://www.kaggle.com/shubhammehta21/movie-lens-small-latest-dataset来预测电影的“评级”。然而,在训练模型后,预测结果对所有测试特征都输出了相同的值。我已经阅读了之前类似的建议,这些建议包括调整学习率、特征数量以及检查预测模型是否与训练模型相同,但这些方法对我都没有效果。

我加载并处理数据如下:

links= pd.read_csv('../input/movie-lens-small-latest-dataset/links.csv')movies=pd.read_csv('../input/movie-lens-small-latest-dataset/movies.csv')...dataset=movies.merge(ratings,on='movieId').merge(tags,on='movieId').merge(links,on='movieId')to_drop='title','genres','timestamp_x','timestamp_y','userId_y','imdbId','tmdbId']dataset.drop(columns=to_drop,inplace=True)dataset=pd.get_dummies(dataset)

代码展示了我是如何构建回归模型的。我尝试过调整神经元和层的数量,但这并没有影响输出结果。

from keras.models import Sequentialfrom keras.layers.core import Dense, Activationfrom keras.optimizers import Adammodel = Sequential()model.add(Dense(13, input_dim=1586, kernel_initializer='zero', activation='relu'))model.add(Dense(6, kernel_initializer='normal', activation='relu'))model.add(Dense(1, kernel_initializer='normal',activation='linear'))# Compile modeladam = Adam(lr=0.001)model.compile(loss='mean_squared_error', optimizer=adam,metrics=['mse','mae'])model.summary()history = model.fit(train_dataset,train_labels,batch_size=30, epochs=10,verbose=1, validation_split=0.3)score = model.evaluate(validation_dataset,validation_labels)print("Test score:", score)

每当我尝试预测测试数据集时:

model.predict(test_dataset)

它预测的值总是

3.97

对于所有值。我期望得到的值范围在0到5之间。


回答:

你绝对不应该(我是说,绝对不应该)使用kernel_initializer='zero' – 说实话,我很惊讶Keras居然有这个选项!

此外,kernel_initializer='normal'也是不推荐的。

作为第一步,移除所有kernel_initializer参数,以便恢复到默认和推荐的kernel_initializer='glorot-uniform';请记住,默认值的存在是有原因的(通常它们效果很好),只有在你确实有理由这样做(我相信你这里没有这样的理由)并且你知道自己在做什么时,你才应该改变它们。

如果仍然没有得到你期望的结果,可以尝试其他参数(层的数量/神经元数量、更多的训练轮次等);起初,你应该保持Adam优化器的学习率(lr)不变(它也是这些似乎在不同情况下都能很好工作的默认值之一)。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注