使用Tensorflow构建RNN。我该如何正确地预处理数据集以匹配RNN的输入和输出形状?

我正在进行一个关于从音频中检测鼓点开始的项目。我已经预处理了训练数据,并尝试在tensorflow中组装了一个SimpleRNN神经网络,但无法使两者协同工作。

在每个时间步,我的输入是一个形状为(84)的一维张量,输出应该是一个形状为(3)的张量。

我当前的代码如下所示:

train_epochs = 10batch_num = 10learning_Rate = 0.001''' 我也尝试使用tf.dataset,但无法使其工作train_dataset = dataset.batch(batch_num, drop_remainder=True)test_dataset = dataset.take(10000).batch(batch_num, drop_remainder=True)print(train_dataset.element_spec)'''x_data = x_data[:70000]y_data = y_data[:70000]x_data.resize((70000, 84))y_data.resize((70000, 3))print(x_data.shape, y_data.shape) model = keras.Sequential()model.add(keras.Input(shape=(None,84)))model.add(layers.SimpleRNN(200,activation='relu', dropout=0.2))model.add(layers.Dense(3, activation='sigmoid'))model.compile(    optimizer=keras.optimizers.RMSprop(learning_rate=learning_Rate),    loss=keras.losses.BinaryCrossentropy(),    #metrics F measure    metrics=['acc',f1_m,precision_m, recall_m])model.summary()history = model.fit(    x_data,y_data,    epochs=train_epochs,    batch_size=batch_num,    # We pass some validation for    # monitoring validation loss and metrics    # at the end of each epoch    validation_data=(x_data, y_data))print("Evaluate on test data")results = model.evaluate(test_dataset)print("test loss, test acc:", results)

当我执行它时,它给我错误消息:

 ValueError: Input 0 of layer sequential_35 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (10, 84)

如果我将x_data和y_data的形状改为(7000,10, 84)和(7000,10, 3),错误消息变为

 ValueError: logits and labels must have the same shape ((10, 3) vs (10, 10, 3))

我该如何解决这个问题?我对深度学习非常新手,所以任何关于如何进行项目的建议都非常受欢迎。


回答:

SimpleRNN的输入应为3D:

x_data.resize((70000, 84, 1))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注