我在使用 seq2seq 的以下代码时,遇到了下面的错误:
cell = tf.nn.rnn_cell.BasicLSTMCell(size)a, b = tf.nn.dynamic_rnn(cell, seq_input, dtype=tf.float32)cell_a = tf.contrib.rnn.OutputProjectionWrapper(cell, frame_dim)dec_output= tf.contrib.legacy_seq2seq.rnn_decoder(seq_input, b, cell_a)
但是我得到了以下错误:
TypeError: 'Tensor' 对象不可迭代。
我检查后发现错误来自 seq2seq 那一行。
回答:
看起来 seq_input
是一个张量,而不是张量列表。单个张量对 tf.nn.dynamic_rnn
来说是没问题的,但 rnn_decoder
需要将序列拆分成张量列表:
decoder_inputs
:一个二维张量的列表[batch_size x input_size]
。
在 源代码 中,你可以看到实现只是在 for
循环中迭代 decoder_inputs
。