在Keras中声明转换后的Sequence的input_shape?

我在尝试对文本输入运行神经网络。这是一个二分类问题。以下是我目前的工作代码:

df = pd.read_csv(pathname, encoding = "ISO-8859-1")df = df[['content_cleaned', 'meaningful']] #Content cleaned: 文本, meaningful: 标签X = df['content_cleaned']y = df['meaningful']X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=21) tokenizer = Tokenizer(num_words=100)tokenizer.fit_on_texts(X_train)X_train_encoded = tokenizer.texts_to_sequences(X_train)X_test_encoded = tokenizer.texts_to_sequences(X_test)max_len = 100X_train = pad_sequences(X_train_encoded, maxlen=max_len)X_test = pad_sequences(X_test_encoded, maxlen=max_len)batch_size = 100max_words = 100input_dim = X_train.shape[1]  # 特征数量model = Sequential()model.add(layers.Dense(10, activation='relu', input_shape=X_train.shape[1:]))model.add(layers.Dense(1, activation='sigmoid'))model.compile(loss='categorical_crossentropy',          optimizer='adam',          metrics=['accuracy'])history = model.fit(X_train, X_test,                batch_size=batch_size,                epochs=5,                verbose=1,                validation_split=0.1)

我的问题分两部分。第一部分是关于创建层时input_shape的使用。我对声明这个参数的语法感到困惑。当我运行以下命令时:

print(X_train.shape)

我得到了这样的形状:(3609, 100)

据我理解,这告诉我有3609个实例。从查看其他例子来看,我天真地假设应该使用100,因为有100种类型(可能理解错误),这与我初始化的max_words相对应。我认为我在初始化input_shape时可能语法错误了。

第二个问题是运行所有这些代码时出现的错误消息(很可能是由于错误的input_shape)。错误消息突出显示了以下这行代码:

 validation_split=0.1)

错误消息是:

ValueError: Error when checking target: expected dense_2 to have shape (None, 1) but got array with shape (1547, 1

我的方法有问题吗?我对深度学习非常新手。


回答:

input_shape参数指定了一个训练样本的形状。因此,你需要将其设置为X_train.shape[1:](即忽略样本或批次轴):

model.add(layers.Dense(10, activation='relu', input_shape=X_train.shape[1:]))

此外,将X_trainy_train传递给fit_generator(而不是X_train_encodedX_test_encoded)。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注