多对多分类使用Keras LSTM

我对Keras中的RNN/LSTM很陌生,需要关于是否以及如何使用它们来解决我的多对多分类问题的一些建议。

我有多个时间序列:大约1500个“运行”,每个运行持续大约100-300个时间步,并且有多个通道。我明白我需要将数据零填充到最大时间步数,因此我的数据看起来像这样:

[nb_samples, timesteps, input_dim]: [1500, 300, 10]

由于即使是人类也无法在不知道过去的情况下为单个时间步获取标签,我可以进行特征工程并训练一个经典的分类算法,但是,我认为LSTMs在这里会是一个很好的选择。这个回答告诉我,在Keras中进行多对多分类时,我需要将return_sequences设置为True。然而,我不太明白接下来该怎么做 – 我应该将返回的序列作为另一个普通层的输入吗?我如何将它连接到我的输出层?

任何帮助、提示或教程链接都将不胜感激 – 我找到了很多关于多对一分类的资料,但没有找到关于多对多的好资料。


回答:

解决这个问题的方法有很多种,我将介绍一些可能适合你的问题的方法。

如果你想堆叠两个LSTM层,那么return-seq可以帮助学习另一个LSTM层,如下面的例子所示。

from keras.layers import Dense, Flatten, LSTM, Activationfrom keras.layers import Dropout, RepeatVector, TimeDistributedfrom keras import Input, Modelseq_length = 15input_dims = 10output_dims = 8 # 类别数量n_hidden = 10model1_inputs = Input(shape=(seq_length,input_dims,))model1_outputs = Input(shape=(output_dims,))net1 = LSTM(n_hidden, return_sequences=True)(model1_inputs)net1 = LSTM(n_hidden, return_sequences=False)(net1)net1 = Dense(output_dims, activation='relu')(net1)model1_outputs = net1model1 = Model(inputs=model1_inputs, outputs = model1_outputs, name='model1')## 拟合模型model1.summary()_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================input_1 (InputLayer)        (None, 15, 10)            0         _________________________________________________________________lstm_1 (LSTM)                (None, 15, 10)            840       _________________________________________________________________lstm_2 (LSTM)                (None, 10)                840       _________________________________________________________________dense_3 (Dense)              (None, 8)                 88        _________________________________________________________________
  1. 另一种选择是,你可以使用完整的返回序列作为下一层的特征。在这种情况下,创建一个简单的Dense层,其输入将是[batch, seq_len*lstm_output_dims]

注意:这些特征对于分类任务可能很有用,但大多数情况下,我们使用堆叠的LSTM层,并使用其输出无完整序列作为分类层的特征。

这个回答可能有助于理解用于不同目的的LSTM架构的其他方法。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注