ANN: keras/sklearn的扩展性不佳

我不知道为什么我得到的结果扩展性不好。如您在下面的图片中所见,存在缩放问题。

enter image description here

enter image description here

存在两个问题:

  • 没有负值
  • 预测最大值时存在问题

我不知道为什么会出现这些问题。您有任何解决这些问题的想法吗?

如果您能提供帮助,我将非常感激

代码:

# 读取输入X = dataset.iloc[0:20000, [1, 4, 10]].values# 读取输出y = dataset.iloc[0:20000, 5].values# 将数据集拆分为训练集和测试集from sklearn.cross_validation import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)# 输出矩阵转换y_train = y_train.reshape(-1, 1)y_test = y_test.reshape(-1, 1)# 特征缩放from sklearn.preprocessing import StandardScalersc = StandardScaler()X_train = sc.fit_transform(X_train)X_test = sc.transform(X_test)y_train = sc.fit_transform(y_train) y_test = sc.transform(y_test)# 导入Keras库和包from keras.models import Sequentialfrom keras.layers import Dense# 构建模型classifier = Sequential()classifier.add(Dense(activation="sigmoid", input_dim=3, units=64, kernel_initializer="uniform"))classifier.add(Dense(activation="sigmoid", units=32, kernel_initializer="uniform"))classifier.add(Dense(activation="sigmoid", units=16, kernel_initializer="uniform"))classifier.add(Dense(activation="sigmoid", units=1, kernel_initializer="uniform"))classifier.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])# 将ANN拟合到训练集results = classifier.fit(X_train, y_train, batch_size=16, epochs=25)# 预测测试集结果y_pred = classifier.predict(X_test)

回答:

如果您的蓝色曲线显示的是初始输出y,而橙色曲线显示的是模型的输出(您并未详细说明这一点…),那么这里没有什么奇怪的…

预测最大值时存在问题

仔细查看您的代码,您会发现您实际上并没有将初始y输入到网络中,而是输入了其缩放版本,即sc.transform()的结果;因此,您的输出也是缩放的,您应该使用inverse_transform方法将其转换回初始比例:

y_final = sc.inverse_transform(y_pred)

顺便说一下,现在这样做是可以的,但通常情况下,使用同一个缩放器(这里是sc)来处理两个不同的数据集(即您的X和y)并不是一个好主意 – 您应该定义两个不同的缩放器,例如sc_Xsc_y

没有负值

这是因为您在输出层中使用的sigmoid函数只能在[0, 1]之间取正值,所以您可能需要将其更改为其他能够提供所需值范围的函数(linear是一个候选),并且可能还需要将其他sigmoid函数更改为tanh

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注