检查输入时发生错误:期望 input_3 有 3 个维度,但得到的数组形状为 (860, 11)

我遇到了这个错误。

ValueError: 检查输入时发生错误:期望 input_1 有 3 个维度,但得到的数组形状为 (860, 11)

以下是我使用的代码。 df 具有 860x15 维度,因此 datX 具有 860×11 维度

# first neural network with keras tutorialfrom numpy import loadtxtfrom keras.models import Sequentialfrom keras.layers import Denseimport pandas as pdfrom pandas import ExcelWriterfrom pandas import ExcelFiledf =pd.read_excel('C:/Users/ASUS/Documents/Script/Simulation.Machine.V1/Final.xlsx', sheetname= "C0")datX = df.drop(['C0', 'C1', 'C2', 'C3'], axis=1)import numpy as npdatY = df['C1'] / df['C0'] datW = df['C0']**(1/2)datZ = df['C1']q=20p= len(datX.columns)from keras import backend as K# define the keras modelmodel = Sequential()model.add(Dense(q, input_dim=p, activation='tanh'))model.add(Dense(1, activation= K.exp))# define the keras modeloffset = Sequential()offset.add(Dense(1, input_dim=1, activation='linear'))from keras.layers import Inputtweet_a = Input(shape=(860, 11))tweet_b = Input(shape=(860, 1))tweetx = model(tweet_a)tweety = offset(tweet_b)from keras.layers import Multiply, addoutput = Multiply()([tweetx, tweety])from keras.models import Modelmodelnew = Model(inputs=[tweet_a, tweet_b], outputs=output)modelnew.compile(optimizer='rmsprop',loss='mse',metrics=['accuracy'])modelnew.fit([datX, datW], datY, epochs=100, batch_size=10000)

我期望输出是一个维度,输入是一个 11 维度


回答:

这里发生的事情非常明显。为了理解这一点,您需要知道 Input 层中 batch_shapeshape 参数的区别。

当您指定 shape 参数时,它实际上会在形状的开头添加一个新的维度(即批次维度)。因此,当您将 860x11 作为 shape 传递时,实际模型期望一个 bx860x11 大小的输出(其中 b 是批次大小)。您在这里指定的是 batch_shape 参数的值。因此有两种解决方案。

对您来说最好的解决方案是将上述内容更改为以下内容。因为这样您就不依赖于固定的批次维度了。

tweet_a = Input(shape=(11,))tweet_b = Input(shape=(1,))tweetx = model(tweet_a)tweety = offset(tweet_b)

但如果您 100% 确定批次大小始终为 860,您可以尝试下面的选项。

tweet_a = Input(batch_shape=(860, 11))tweet_b = Input(batch_shape=(860, 1))tweetx = model(tweet_a)tweety = offset(tweet_b)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注