理解TensorFlow LSTM模型的输入?

我在理解TensorFlow中的LSTM模型时遇到了一些困难。

我使用tflearn作为包装器,因为它可以自动完成所有初始化和其他高级操作。为了简单起见,我们考虑这个示例程序。直到第42行net = tflearn.input_data([None, 200]),发生的事情非常清楚。你将数据集加载到变量中,并使其长度标准化(在这种情况下为200)。输入变量和两个类别都被转换为独热向量。

LSTM是如何接受输入的? 它是基于多少个样本来预测输出的?

net = tflearn.embedding(net, input_dim=20000, output_dim=128)代表什么?

我的目标是复制论文中的活动识别数据集。例如,我想将一个4096维的向量作为LSTM的输入,计划输入16个这样的向量,然后生成分类结果。我认为代码看起来会像这样,但我不知道如何向LSTM提供输入。

from __future__ import division, print_function, absolute_importimport tflearnfrom tflearn.data_utils import to_categorical, pad_sequencesfrom tflearn.datasets import imdbtrain, val = something.load_data()trainX, trainY = train #每个X样本是一个(16,4096)的nd float64 valX, valY = val #每个Y是一个101类的独热向量。net = tflearn.input_data([None, 16,4096])net = tflearn.embedding(net, input_dim=4096, output_dim=256)net = tflearn.lstm(net, 256)net = tflearn.dropout(net, 0.5)net = tflearn.lstm(net, 256)net = tflearn.dropout(net, 0.5)net = tflearn.fully_connected(net, 101, activation='softmax')net = tflearn.regression(net, optimizer='adam',                         loss='categorical_crossentropy')model = tflearn.DNN(net, clip_gradients=0., tensorboard_verbose=3)model.fit(trainX, trainY, validation_set=(testX, testY), show_metric=True,          batch_size=128,n_epoch=2,snapshot_epoch=True)

回答:

基本上,LSTM一次接受一个单元的向量大小:

lstm = rnn_cell.BasicLSTMCell(lstm_size, forget_bias=1.0)

然后,你想输入多少个时间序列?这取决于你输入的向量。X_split中的数组数量决定了时间步数:

X_split = tf.split(0, time_step_size, X)outputs, states = rnn.rnn(lstm, X_split, initial_state=init_state)

在你的例子中,我猜lstm_size是256,因为它是一个单词的向量大小。time_step_size将是训练/测试句子中的最大词数。

请看这个例子:https://github.com/nlintz/TensorFlow-Tutorials/blob/master/07_lstm.py

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注