Removing teacher forcing from an LSTM decoder – Tensorflow

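I am trying to build a sequence-to-sequence model in Tensorflow. I followed several tutorials and everything worked fine, until I decided to remove teacher forcing from my model. Below is an example of the decoder network I am using: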

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in TensorFlow 2.0


def decoding_layer_train(encoder_state, dec_cell, dec_embed_input,
                         target_sequence_length, max_summary_length,
                         output_layer, keep_prob):
    """Create a decoding layer for training
    :param encoder_state: Encoder State
    :param dec_cell: Decoder RNN Cell
    :param dec_embed_input: Decoder embedded input
    :param target_sequence_length: The lengths of each sequence in the target batch
    :param max_summary_length: The length of the longest sequence in the batch
    :param output_layer: Function to apply the output layer
    :param keep_prob: Dropout keep probability (not used in this function)
    :return: BasicDecoderOutput containing training logits and sample_id
    """
    # TrainingHelper reads the embedded ground-truth input at every step (teacher forcing)
    training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=dec_embed_input,
                                                        sequence_length=target_sequence_length,
                                                        time_major=False)
    training_decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, training_helper, encoder_state, output_layer)
    # dynamic_decode returns a tuple; [0] keeps only the decoder outputs
    training_decoder_output = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                impute_finished=True,
                                                                maximum_iterations=max_summary_length)[0]
    return training_decoder_output

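As far as I understand, TrainingHelper is what performs teacher forcing: in particular, it takes the ground-truth outputs as part of its arguments. I tried to use the decoder without the training helper, but a helper seems to be required. I tried setting the ground-truth outputs to 0, but TrainingHelper apparently needs the real outputs. I also searched Google for a solution but found nothing relevant.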
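
The difference between the two modes can be shown without TensorFlow. Below is a minimal, framework-free sketch under stated assumptions: `step` is a hypothetical stand-in for one decoder step (RNN cell, output layer and argmax), and `GO` is an assumed token id. With `teacher_forcing=True` the next input is the ground-truth token, which is roughly what TrainingHelper does with `dec_embed_input`; with `teacher_forcing=False` the next input is the model's own previous prediction.

```python
from typing import Callable, List

GO = 0  # hypothetical id of the <GO> token


def decode(step: Callable[[int], int], targets: List[int], teacher_forcing: bool) -> List[int]:
    """Toy decoder loop that runs exactly len(targets) steps.

    step is a hypothetical function mapping the current input token to the predicted next token.
    """
    predictions = []
    current = GO
    for target in targets:
        pred = step(current)
        predictions.append(pred)
        # Teacher forcing feeds the ground truth; free running feeds back the prediction.
        current = target if teacher_forcing else pred
    return predictions


def toy_step(tok: int) -> int:
    """Hypothetical 'model' that always predicts the input token plus 2."""
    return tok + 2


targets = [3, 4, 2, 1]
forced = decode(toy_step, targets, teacher_forcing=True)   # [2, 5, 6, 4]
free = decode(toy_step, targets, teacher_forcing=False)    # [2, 4, 6, 8]
```

Because both loops run for exactly `len(targets)` steps, their outputs always line up with the labels; a decoder that instead stops when it emits `<EOS>` loses that guarantee.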

===================UPDATE=============

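Sorry for not mentioning this earlier, but I also tried GreedyEmbeddingHelper. The model runs for a few iterations and then throws a runtime error. It looks as though GreedyEmbeddingHelper starts producing outputs whose shape differs from the expected one. Here is my function when using GreedyEmbeddingHelper: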

import tensorflow as tf  # TensorFlow 1.x


def decoding_layer_train(encoder_state, dec_cell, dec_embeddings,
                         target_sequence_length, max_summary_length,
                         output_layer, keep_prob):
    """
    Create a decoding layer for training
    :param encoder_state: Encoder State
    :param dec_cell: Decoder RNN Cell
    :param dec_embeddings: Decoder embedding matrix
    :param target_sequence_length: The lengths of each sequence in the target batch (not used here)
    :param max_summary_length: The length of the longest sequence in the batch
    :param output_layer: Function to apply the output layer
    :param keep_prob: Dropout keep probability (not used here)
    :return: BasicDecoderOutput containing training logits and sample_id
    """
    # target_vocab_to_int and batch_size are globals defined outside this function
    start_tokens = tf.tile(tf.constant([target_vocab_to_int['<GO>']], dtype=tf.int32), [batch_size],
                           name='start_tokens')
    # GreedyEmbeddingHelper feeds back the embedding of the previous step's argmax (no teacher forcing)
    training_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(dec_embeddings,
                                                                start_tokens,
                                                                target_vocab_to_int['<EOS>'])
    training_decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, training_helper, encoder_state, output_layer)
    # Decoding stops once every sequence has emitted <EOS> (or max_summary_length is reached),
    # so the time dimension of the output can be shorter than that of the targets
    training_decoder_output = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                impute_finished=True,
                                                                maximum_iterations=max_summary_length)[0]
    return training_decoder_output
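
To see why the output shape can change from batch to batch, the sketch below mimics in plain Python the stopping rule assumed for greedy decoding with `dynamic_decode`: each sequence is fed its own previous prediction, a sequence counts as finished once it emits `<EOS>`, finished sequences get zero outputs (as with `impute_finished=True`), and the loop ends when every sequence is finished or `max_iterations` is reached. `greedy_decode`, `toy_step` and the token ids are hypothetical.

```python
from typing import Callable, List


def greedy_decode(step: Callable[[int, int], int], batch_size: int, go_id: int,
                  eos_id: int, max_iterations: int) -> List[List[int]]:
    """Toy batch greedy decoder; returns one list of token ids per sequence."""
    inputs = [go_id] * batch_size  # same role as tf.tile([<GO>], [batch_size])
    finished = [False] * batch_size
    outputs: List[List[int]] = [[] for _ in range(batch_size)]
    for _ in range(max_iterations):
        if all(finished):  # stop as soon as every sequence has emitted <EOS>
            break
        preds = [step(b, tok) for b, tok in enumerate(inputs)]
        for b, pred in enumerate(preds):
            # Finished sequences still take up a time step but get a zero output.
            outputs[b].append(0 if finished[b] else pred)
            finished[b] = finished[b] or pred == eos_id
        inputs = preds  # feed back the predictions: no teacher forcing
    return outputs


def toy_step(b: int, tok: int) -> int:
    """Hypothetical model: from <GO> (0) start at 3 + b, then count down to <EOS> (1)."""
    return 3 + b if tok == 0 else tok - 1


decoded = greedy_decode(toy_step, batch_size=2, go_id=0, eos_id=1, max_iterations=6)
# decoded == [[3, 2, 1, 0], [4, 3, 2, 1]]: 4 time steps, although up to 6 were allowed
```

If the targets in this toy batch were padded to 6 steps, the decoder output would have 4 time steps against 6 label steps.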

Here is an example of the error thrown after a few training iterations:

    Ok
    Epoch   0 Batch    5/91 - Train Accuracy: 0.4347, Validation Accuracy: 0.3557, Loss: 2.8656
    ++++
    Epoch   0 Batch    5/91 - Train WER: 1.0000, Validation WER: 1.0000
    Epoch   0 Batch   10/91 - Train Accuracy: 0.4050, Validation Accuracy: 0.3864, Loss: 2.6347
    ++++
    Epoch   0 Batch   10/91 - Train WER: 1.0000, Validation WER: 1.0000
    ---------------------------------------------------------------------------
    InvalidArgumentError                      Traceback (most recent call last)
    <ipython-input-115-1d2a9495ad42> in <module>()
         57                  target_sequence_length: targets_lengths,
         58                  source_sequence_length: sources_lengths,
    ---> 59                  keep_prob: keep_probability})
         60 
         61 

    /Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
        887     try:
        888       result = self._run(None, fetches, feed_dict, options_ptr,
    --> 889                          run_metadata_ptr)
        890       if run_metadata:
        891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

    /Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
       1116     if final_fetches or final_targets or (handle and feed_dict_tensor):
       1117       results = self._do_run(handle, final_targets, final_fetches,
    -> 1118                              feed_dict_tensor, options, run_metadata)
       1119     else:
       1120       results = []

    /Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
       1313     if handle is None:
       1314       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
    -> 1315                            options, run_metadata)
       1316     else:
       1317       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

    /Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
       1332         except KeyError:
       1333           pass
    -> 1334       raise type(e)(node_def, op, message)
       1335 
       1336   def _extend_graph(self):

    InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [1100,78] and labels shape [1400]
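
The numbers in this error fit a time-dimension mismatch. Assuming the loss is `tf.contrib.seq2seq.sequence_loss` or something similar, logits of shape [B, T_dec, 78] are flattened to [B*T_dec, 78] and labels of shape [B, T_tgt] to [B*T_tgt]; with the same batch size B on both sides, 1100 vs 1400 means the decoder produced fewer time steps than the padded targets (T_dec/T_tgt = 11/14). One possible workaround, sketched here in plain Python rather than TensorFlow, is to pad (or truncate) the decoder's time axis to the target length before computing the loss; in TensorFlow 1.x the analogous operation would be `tf.pad` on axis 1. `align_time_axis` and `flat_rows` are hypothetical helper names.

```python
from typing import List

Logits = List[List[List[float]]]  # indexed [batch][time][vocab]


def align_time_axis(logits: Logits, target_len: int, vocab_size: int) -> Logits:
    """Pad with all-zero logit rows, or truncate, so every sequence has target_len steps."""
    aligned = []
    for seq in logits:
        seq = seq[:target_len]  # truncate if the decoder ran longer than the targets
        padding = [[0.0] * vocab_size for _ in range(target_len - len(seq))]
        aligned.append(seq + padding)
    return aligned


def flat_rows(logits: Logits) -> int:
    """First dimension after reshaping [batch, time, vocab] to [batch * time, vocab]."""
    return sum(len(seq) for seq in logits)


# Toy batch: 2 sequences, decoder stopped after 3 steps, targets padded to 4, vocab of 5.
batch_logits = [[[0.1] * 5 for _ in range(3)] for _ in range(2)]
num_labels = 2 * 4
before = flat_rows(batch_logits)                        # 6, does not match 8 labels
after = flat_rows(align_time_axis(batch_logits, 4, 5))  # 8, matches
```

All-zero logits correspond to a uniform distribution over the vocabulary, so padded steps that fall inside a target's true length still incur a loss instead of being silently dropped; steps beyond each target's length are normally masked out with a sequence mask built from `target_sequence_length`.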

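I'm not sure, but my guess is that GreedyEmbeddingHelper is not meant to be used for training. I would appreciate any help or ideas on how to stop teacher forcing.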

Thank you.

