使用Tensorflow构建RNN。我该如何正确地预处理数据集以匹配RNN的输入和输出形状?

我正在进行一个关于从音频中检测鼓点开始的项目。我已经预处理了训练数据,并尝试在tensorflow中组装了一个SimpleRNN神经网络,但无法使两者协同工作。

在每个时间步,我的输入是一个形状为(84)的一维张量,输出应该是一个形状为(3)的张量。

我当前的代码如下所示:

train_epochs = 10batch_num = 10learning_Rate = 0.001''' 我也尝试使用tf.dataset,但无法使其工作train_dataset = dataset.batch(batch_num, drop_remainder=True)test_dataset = dataset.take(10000).batch(batch_num, drop_remainder=True)print(train_dataset.element_spec)'''x_data = x_data[:70000]y_data = y_data[:70000]x_data.resize((70000, 84))y_data.resize((70000, 3))print(x_data.shape, y_data.shape) model = keras.Sequential()model.add(keras.Input(shape=(None,84)))model.add(layers.SimpleRNN(200,activation='relu', dropout=0.2))model.add(layers.Dense(3, activation='sigmoid'))model.compile(    optimizer=keras.optimizers.RMSprop(learning_rate=learning_Rate),    loss=keras.losses.BinaryCrossentropy(),    #metrics F measure    metrics=['acc',f1_m,precision_m, recall_m])model.summary()history = model.fit(    x_data,y_data,    epochs=train_epochs,    batch_size=batch_num,    # We pass some validation for    # monitoring validation loss and metrics    # at the end of each epoch    validation_data=(x_data, y_data))print("Evaluate on test data")results = model.evaluate(test_dataset)print("test loss, test acc:", results)

当我执行它时,它给我错误消息:

 ValueError: Input 0 of layer sequential_35 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (10, 84)

如果我将x_data和y_data的形状改为(7000,10, 84)和(7000,10, 3),错误消息变为

 ValueError: logits and labels must have the same shape ((10, 3) vs (10, 10, 3))

我该如何解决这个问题?我对深度学习非常新手,所以任何关于如何进行项目的建议都非常受欢迎。


回答:

SimpleRNN的输入应为3D:

x_data.resize((70000, 84, 1))

Related Posts

如何对SVC进行超参数调优?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

如何在初始训练后向模型添加训练数据?

我想在我的scikit-learn模型已经训练完成后再…

使用Google Cloud Function并行运行带有不同用户参数的相同训练作业

我正在寻找一种方法来并行运行带有不同用户参数的相同训练…

加载Keras模型,TypeError: ‘module’ object is not callable

我已经在StackOverflow上搜索并阅读了文档,…

在计算KNN填补方法中特定列中NaN值的”距离平均值”时

当我从头开始实现KNN填补方法来处理缺失数据时,我遇到…

使用巨大的S3 CSV文件或直接从预处理的关系型或NoSQL数据库获取数据的机器学习训练/测试工作

已关闭。此问题需要更多细节或更清晰的说明。目前不接受回…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注