将训练模型的权重加载到新的tensorflow图中

目标

使用tensorflow,我正在尝试对一个LSTM模型进行训练,数据的每个样本有N个时间步长,然后随着模型的训练逐步增加每个样本的时间步长。

所以可能RNN模型最初查看每个训练样本的4个时间步长。训练一段时间后,性能趋于平稳。我希望现在继续用8个时间步长来训练模型。这基本上是RNN的一种微调形式。

进展

看似最直接的方法是在训练一段时间后保存模型,然后重建一个新图,并定义一个具有更多时间步长的新变量X。

不幸的是,我找不到一种方法来避免将时间步长硬编码到我的模型中。但没关系,因为如果我重新创建模型并填充保存的权重,模型的形状应该是相同的,所以应该可以工作。

所以我第一次运行模型以生成一个保存文件。然后我加载该保存文件,并尝试用旧的(几乎相同)的tensorflow图中的权重填充一个新图。

这让我很头疼,所以任何帮助都非常感激。

代码

这是我到目前为止的代码:

if MODEL_FILE is not None:    # 从保存的模型文件加载    new_saver = tf.train.import_meta_graph(MODEL_FILE + '.meta')weights = {        'out': tf.Variable(tf.random_uniform([LSTM_SIZE, n_outputs_sm]))        }biases = {        'out': tf.Variable(tf.random_uniform([n_outputs_sm]))        }# 设置输入X和输出Y图变量x = tf.placeholder('float', [None, NUM_TIMESTEPS, n_input], name='input_x')y = tf.placeholder('float', [None, n_outputs_sm], name='output_y')# 前馈函数以获取RNN输出。我们使用一种高级的LSTM单元。def TFEncoderRNN(inp, weights, biases):    # current_input_shape: (batch_size, n_steps, n_input    # required shape: 'n_steps' tensors list of shape (batch_size, n_input)    inp = tf.unstack(inp, NUM_TIMESTEPS, 1)    lstm_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(LSTM_SIZE, dropout_keep_prob=DROPOUT)    outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, inp, dtype=tf.float32)    return tf.matmul(outputs[-1], weights['out']) + biases['out']# 我们将能够调用这个来获取模型输出pred = TFEncoderRNN(x, weights, biases)# 定义损失和优化器cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))# 我在这里定义了一些其他内容,为了简洁起见,我会省略init = Noneif new_saver:    new_saver.restore(sess, './' + MODEL_FILE)    init = tf.initialize_variables([global_step])else:    init = tf.global_variables_initializer()sess.run(init)######### 训练和其他操作######print "Optimization finished!"# 保存当前图,你可以再次运行此脚本以# 继续训练if SAVE_MODEL:    print "Saving model"    saver = tf.train.Saver()    saver.save(sess, 'tf_model_001')

关于如何将训练模型的权重移动到新创建的图/模型中,有什么想法吗?


回答:

看似最直接的方法是在训练一段时间后保存模型,然后重建一个新图,并定义一个具有更多时间步长的新变量X。

实际上,这就是tf.nn.dynamic_rnn的用途——同一个模型适用于任何序列长度。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注