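I'm trying to build a sequence-to-sequence model in TensorFlow. I followed several tutorials and everything worked well, until I decided to remove teacher forcing from my model. Here is an example of the decoder network I'm using: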
```python
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib.seq2seq)


def decoding_layer_train(encoder_state, dec_cell, dec_embed_input, target_sequence_length,
                         max_summary_length, output_layer, keep_prob):
    """Create a decoding layer for training
    :param encoder_state: Encoder State
    :param dec_cell: Decoder RNN Cell
    :param dec_embed_input: Decoder embedded input
    :param target_sequence_length: The lengths of each sequence in the target batch
    :param max_summary_length: The length of the longest sequence in the batch
    :param output_layer: Function to apply the output layer
    :param keep_prob: Dropout keep probability (not used in this function)
    :return: BasicDecoderOutput containing training logits and sample_id"""
    # TrainingHelper feeds the embedded ground-truth target tokens as the
    # decoder input at every step, i.e. teacher forcing.
    training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=dec_embed_input,
                                                        sequence_length=target_sequence_length,
                                                        time_major=False)
    training_decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, training_helper, encoder_state, output_layer)
    training_decoder_output = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                impute_finished=True,
                                                                maximum_iterations=max_summary_length)[0]
    return training_decoder_output
```
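As far as I understand, TrainingHelper is what performs teacher forcing: it takes the ground-truth target sequence (here `dec_embed_input`) as one of its arguments. I tried using the decoder without the training helper, but the helper seems to be required. I also tried setting the ground-truth outputs to 0, but TrainingHelper evidently needs the real outputs. I searched Google for a solution as well, but found nothing relevant.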
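To make the difference concrete, here is a framework-free toy sketch (plain Python, not TensorFlow) of the two ways a decoder can pick its next input. `decode`, `toy_step` and the integer "tokens" are hypothetical names for illustration only; the comments describe which TensorFlow helper plays each role.

```python
from typing import Callable, List


def decode(step: Callable[[int], int], go: int, targets: List[int],
           teacher_forcing: bool) -> List[int]:
    """Toy decoder loop that runs for len(targets) steps.

    `step` is a hypothetical one-step model: previous token -> predicted token.
    With teacher_forcing=True the next input is the ground-truth token
    (the role TrainingHelper plays); with False it is the model's own
    prediction (the role GreedyEmbeddingHelper plays).
    """
    outputs = []
    prev = go  # the first input is the <GO> token
    for target in targets:
        pred = step(prev)
        outputs.append(pred)
        prev = target if teacher_forcing else pred
    return outputs


def toy_step(tok: int) -> int:
    # Hypothetical "model" that always predicts previous token + 1.
    return tok + 1


targets = [5, 9, 2]
print(decode(toy_step, 0, targets, teacher_forcing=True))   # [1, 6, 10]
print(decode(toy_step, 0, targets, teacher_forcing=False))  # [1, 2, 3]
```

Unlike GreedyEmbeddingHelper, this toy always runs exactly `len(targets)` steps; the real helper can stop early once `<EOS>` is produced, which is relevant to the shape error reported in the update.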
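=================== UPDATE ===================
Sorry for not mentioning this earlier, but I also tried GreedyEmbeddingHelper. The model runs for a few iterations and then starts throwing a runtime error. It looks like GreedyEmbeddingHelper starts producing outputs whose shape differs from the expected one. Here is my function when using GreedyEmbeddingHelper: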
```python
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib.seq2seq)


def decoding_layer_train(encoder_state, dec_cell, dec_embeddings,
                         target_sequence_length, max_summary_length,
                         output_layer, keep_prob):
    """
    Create a decoding layer for training
    :param encoder_state: Encoder State
    :param dec_cell: Decoder RNN Cell
    :param dec_embeddings: Decoder embedding matrix
    :param target_sequence_length: The lengths of each sequence in the target batch
    :param max_summary_length: The length of the longest sequence in the batch
    :param output_layer: Function to apply the output layer
    :param keep_prob: Dropout keep probability (not used in this function)
    :return: BasicDecoderOutput containing training logits and sample_id
    """
    # Note: target_vocab_to_int and batch_size are globals, not parameters.
    start_tokens = tf.tile(tf.constant([target_vocab_to_int['<GO>']], dtype=tf.int32), [batch_size], name='start_tokens')
    # GreedyEmbeddingHelper feeds the argmax of the previous step back in
    # (no teacher forcing) and marks a sequence finished once it emits <EOS>.
    training_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(dec_embeddings,
                                                               start_tokens,
                                                               target_vocab_to_int['<EOS>'])
    training_decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, training_helper, encoder_state, output_layer)
    training_decoder_output = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                impute_finished=True,
                                                                maximum_iterations=max_summary_length)[0]
    return training_decoder_output
```
Here is an example of the error thrown after a few training iterations:
```
Ok
Epoch 0 Batch 5/91 - Train Accuracy: 0.4347, Validation Accuracy: 0.3557, Loss: 2.8656
++++Epoch 0 Batch 5/91 - Train WER: 1.0000, Validation WER: 1.0000
Epoch 0 Batch 10/91 - Train Accuracy: 0.4050, Validation Accuracy: 0.3864, Loss: 2.6347
++++Epoch 0 Batch 10/91 - Train WER: 1.0000, Validation WER: 1.0000
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-115-1d2a9495ad42> in <module>()
     57                    target_sequence_length: targets_lengths,
     58                    source_sequence_length: sources_lengths,
---> 59                    keep_prob: keep_probability})
     60 
     61 

/Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    887     try:
    888       result = self._run(None, fetches, feed_dict, options_ptr,
--> 889                          run_metadata_ptr)
    890       if run_metadata:
    891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1116     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1117       results = self._do_run(handle, final_targets, final_fetches,
-> 1118                              feed_dict_tensor, options, run_metadata)
   1119     else:
   1120       results = []

/Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1313     if handle is None:
   1314       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1315                            options, run_metadata)
   1316     else:
   1317       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

/Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1332       except KeyError:
   1333         pass
-> 1334       raise type(e)(node_def, op, message)
   1335 
   1336   def _extend_graph(self):

InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [1100,78] and labels shape [1400]
```
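The shapes in this error can be reproduced outside TensorFlow. Assuming the loss is computed with something like `tf.contrib.seq2seq.sequence_loss`, which flattens logits to `[batch * time, vocab]` and labels to `[batch * time]`, a greedy decoder that stops before the longest target yields fewer time steps than the labels. The NumPy sketch below shows this and one possible workaround: padding the logits along the time axis. `pad_logits_to` is a hypothetical helper, and the batch size of 100 is an assumption chosen only so the numbers match the error (the post does not state the batch size).

```python
import numpy as np


def pad_logits_to(logits: np.ndarray, target_len: int) -> np.ndarray:
    """Pad or truncate decoder logits along the time axis to target_len.

    logits has shape [batch, decoded_len, vocab]. Free-running (greedy)
    decoding can stop before the longest target, so decoded_len may be
    smaller than target_len. Padded steps are filled with zeros; they
    should be masked out by the loss weights (e.g. a sequence mask).
    """
    batch, decoded_len, vocab = logits.shape
    if decoded_len >= target_len:
        return logits[:, :target_len, :]
    pad = np.zeros((batch, target_len - decoded_len, vocab), dtype=logits.dtype)
    return np.concatenate([logits, pad], axis=1)


# Hypothetical sizes: batch 100 is an assumption; 78 is the vocabulary
# size shown in the error message.
batch, vocab = 100, 78
rng = np.random.default_rng(0)
logits = rng.random((batch, 11, vocab))         # greedy decoding stopped after 11 steps
labels = np.zeros((batch, 14), dtype=np.int64)  # longest target has 14 steps

# A loss that folds time into the batch axis sees mismatched first dimensions:
print(logits.reshape(-1, vocab).shape, labels.reshape(-1).shape)  # (1100, 78) (1400,)

fixed = pad_logits_to(logits, labels.shape[1])
print(fixed.reshape(-1, vocab).shape)  # (1400, 78)
```

In a TF1 graph the analogous step would presumably be `tf.pad` on the time dimension with an amount computed from `tf.shape`; treat this as a sketch of the shape issue rather than a verified fix for this model.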
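I'm not sure, but my guess is that GreedyEmbeddingHelper is not meant to be used for training. I'd appreciate any help or ideas on how to stop teacher forcing.
Thank you.
Answer: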