Tensorflow中的BeamSearch执行时间过长

我已经在Tensorflow(Python)中尝试使用Seq2Seq模型几个星期了,我有一个工作模型,它使用双向编码器和基于注意力的解码器,之前运行得很好。今天我添加了Beam Search,但我注意到当Beam宽度为1或更大时,推理过程现在需要很长时间,而之前仅使用双向编码器和注意力解码器时,推理只需要几秒钟。

环境详情:TensorFlow版本:1.3.0,MacOS 10.12.4

以下是我代码的相关部分:

def decoding_layer(dec_input, encoder_state,                   target_sequence_length, max_target_sequence_length,                   rnn_size,                   num_layers, target_vocab_to_int, target_vocab_size,                   batch_size, keep_prob, decoding_embedding_size , encoder_outputs):    """    创建解码层    :param dec_input: 解码器输入    :param encoder_state: 编码器状态    :param target_sequence_length: 目标批次中每个序列的长度    :param max_target_sequence_length: 目标序列的最大长度    :param rnn_size: RNN大小    :param num_layers: 层数    :param target_vocab_to_int: 从目标词到ID的字典    :param target_vocab_size: 目标词汇大小    :param batch_size: 批次大小    :param keep_prob: Dropout保留概率    :param decoding_embedding_size: 解码嵌入大小    :encoder_outputs : 编码器的输出     :return: 包含(训练BasicDecoderOutput,推理BasicDecoderOutput)的元组    """    encoder_outputs_tr =encoder_outputs #tf.transpose(encoder_outputs,[1,0,2])    # 1. 解码器嵌入    dec_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))    dec_embed_input = tf.nn.embedding_lookup(dec_embeddings, dec_input)    # 2. 构建解码器单元    def create_cell(rnn_size):        lstm_cell = tf.contrib.rnn.LSTMCell(rnn_size,                                            initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))        drop = tf.contrib.rnn.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)        return drop    def create_complete_cell(rnn_size,num_layers,encoder_outputs_tr,batch_size,encoder_state , infer ):        if infer and beam_width >0:             encoder_outputs_tr = tf.contrib.seq2seq.tile_batch(encoder_outputs_tr, multiplier=beam_width)            encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)            batch_size = batch_size * beam_width        dec_cell = tf.contrib.rnn.MultiRNNCell([create_cell(rnn_size) for _ in range(num_layers)])        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=rnn_size, memory=encoder_outputs_tr)         attn_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell, attention_mechanism , attention_layer_size=rnn_size , output_attention=False)        attn_zero = attn_cell.zero_state(batch_size , tf.float32 )        attn_zero = attn_zero.clone(cell_state = encoder_state)        return attn_zero ,  attn_cell    intial_train_state , train_cell = create_complete_cell(rnn_size,num_layers,encoder_outputs_tr,batch_size,encoder_state , False )    intial_infer_state , infer_cell = create_complete_cell(rnn_size,num_layers,encoder_outputs_tr,batch_size,encoder_state , True )    output_layer = Dense(target_vocab_size,                         kernel_initializer = tf.truncated_normal_initializer(mean = 0.0, stddev=0.1))    with tf.variable_scope("decode"):        train_decoder_out = decoding_layer_train(intial_train_state, train_cell, dec_embed_input,                          target_sequence_length, max_target_sequence_length, output_layer, keep_prob)    with tf.variable_scope("decode", reuse=True):        if beam_width == 0 :            infer_decoder_out = decoding_layer_infer(intial_infer_state, infer_cell, dec_embeddings,                                  target_vocab_to_int['<GO>'], target_vocab_to_int['<EOS>'], max_target_sequence_length,                                  target_vocab_size, output_layer, batch_size, keep_prob)        else :            infer_decoder_out = decoding_layer_infer_with_Beam(intial_infer_state, infer_cell, dec_embeddings,                                  target_vocab_to_int['<GO>'], target_vocab_to_int['<EOS>'], max_target_sequence_length,                                  target_vocab_size, output_layer, batch_size, keep_prob)            print('beam search')    return (train_decoder_out, infer_decoder_out)"""DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE"""#tests.test_decoding_layer(decoding_layer)def decoding_layer_infer_with_Beam(encoder_state, dec_cell, dec_embeddings, start_of_sequence_id,                         end_of_sequence_id, max_target_sequence_length,                         vocab_size, output_layer, batch_size, keep_prob):    """    为推理创建解码层    :param encoder_state: 编码器状态    :param dec_cell: 解码器RNN单元    :param dec_embeddings: 解码器嵌入    :param start_of_sequence_id: GO ID    :param end_of_sequence_id: EOS Id    :param max_target_sequence_length: 目标序列的最大长度    :param vocab_size: 解码器/目标词汇大小    :param decoding_scope: TensorFlow变量作用域用于解码    :param output_layer: 应用输出层的函数    :param batch_size: 批次大小    :param keep_prob: Dropout保留概率    :return: 包含推理logits和sample_id的BasicDecoderOutput    """    start_tokens = tf.tile(tf.constant([start_of_sequence_id], dtype=tf.int32), [batch_size], name='start_tokens')    inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(              cell=dec_cell,              embedding=dec_embeddings,              start_tokens=start_tokens,              end_token=end_of_sequence_id,              initial_state=encoder_state,              beam_width=beam_width,              output_layer=output_layer)    inference_decoder_output = tf.contrib.seq2seq.dynamic_decode(inference_decoder,                                                            impute_finished=False                                                            )[0]    return inference_decoder_output"""DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE"""#tests.test_decoding_layer_infer(decoding_layer_infer)

以下是模型参数:

# 轮数epochs = 200# 批次大小batch_size = 30# RNN大小rnn_size = 512# 层数num_layers = 2# 嵌入大小encoding_embedding_size = 100decoding_embedding_size = 100# 学习率learning_rate = 0.001# Dropout保留概率keep_probability = 0.55display_step = 10beam_width=1

我非常希望得到您的帮助,我不确定具体哪里出了问题。

谢谢您


回答:

好的,我刚刚发现了我做错的地方。

我只需要在动态解码函数中设置最大迭代值如下:

    inference_decoder_output = tf.contrib.seq2seq.dynamic_decode(inference_decoder,                                                        impute_finished=False,                                                        maximum_iterations=max_target_sequence_length)[0]

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注