在预测另一个列表的对应列表时,y 数据应采用何种形状?

背景

有两个字符列表。一个列表包含钢琴音符,另一个列表包含弦乐音符。想法是训练模型根据钢琴音符来预测弦乐音符。这样它可以生成与钢琴相匹配的弦乐旋律。为了使其更加流畅,它不仅应考虑当前的钢琴音符,还应考虑之前的音符。

数据

我已经创建了一个包含超过100首歌曲的数据集(并且还在选择新的歌曲)。目前,钢琴和弦乐列表的总音符数为48523。钢琴音符的词汇量为447,弦乐音符的词汇量为261

len(set(piano_notes)) #447len(set(string_notes)) #261

所有音符都是独热编码,序列长度为100。两个列表的形状如下:

print(x.shape) #(48523, 100, 447)print(y.shape) #(48523, 100, 261)

问题

我不清楚y数据应该采用什么形状?网络结构如下:

def create_network(x, n_vocab_string_notes):""" create the structure of the neural network """model = Sequential()model.add(LSTM(    512,    input_shape=(x.shape[1], x.shape[2]),    return_sequences=True))model.add(Dropout(0.3))model.add(LSTM(512, return_sequences=True))model.add(Dropout(0.3))model.add(LSTM(512))model.add(Dense(256))model.add(Dropout(0.3))model.add(Dense(n_vocab_string_notes))model.add(Activation('softmax'))model.compile(loss='categorical_crossentropy', optimizer='rmsprop')return model

训练过程如下:

def train(model, x, y):""" train the neural network """file_path = "weights-improved.hdf5"checkpoint = ModelCheckpoint(    file_path,    monitor='loss',    verbose=0,    save_best_only=True,    mode='min')callbacks_list = [checkpoint]model.fit(x, y, epochs=200, batch_size=64, callbacks=callbacks_list)

现在它返回错误,因为y的形状不是它应该有的。它说得到了形状为(48523, 100, 261)的数组,而期望的是2维数组。

目标

目标是根据钢琴音符预测弦乐音符。也就是说,对于长度为100的钢琴音符,预测相应长度的弦乐音符。因此,从单一的钢琴音符列表中,我可以预测弦乐音符列表。这意味着对于每首钢琴旋律,都可以添加相应的弦乐旋律。


回答:

您目前通过在第二个LSTM(100)中不返回序列来压缩时间维度。您也需要返回一个序列并处理上层。类似这样的操作:

# second lstmmodel.add(LSTM(512, return_sequences=True))model.add(TimeDistributed(Dense(256)))model.add(Dropout(0.3))model.add(TimeDistributed(Dense(n_vocab_string_notes, activation='softmax')))

现在输出也将是序列。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注