TensorFlow: 记住之前轮次的权重

我正在尝试使用TensorFlow。我刚刚在这里发布了一个问题,关于我遇到的问题。不过我还有一个可能更具理论性但具有实际影响的问题。

在训练模型时,我发现准确率可能会有所不同。因此,可能发生的情况是,最后一个轮次的准确率并不是最高的。例如,在第N个轮次我的准确率可能达到85%,而在最后一个轮次准确率只有65%。我想使用第N个轮次的权重来进行预测。

我想知道是否有办法记住具有最佳准确率的轮次的权重值以便后续使用?

最简单和直接的方法可能是:

  1. 运行N个轮次
  2. 记住最佳准确率
  3. 重新开始训练,直到达到与步骤2中存储的准确率相同的轮次
  4. 使用当前权重进行预测

有没有更好的方法?


回答:

是的!你需要创建一个保存器并定期保存你的会话,在你的训练过程中。伪代码实现如下所示:

model = my_model()saver = tf.train.Saver()with tf.Session() as sess:    sess.run(init_op)    for epoch in range(NUM_EPOCHS):        for batch in range(NUM_BATCHES):            # ... 训练你的模型 ...            if batch % VALIDATION_FREQUENCY == 0:                # 定期对验证集进行测试。                error = sess.run(model.error, feed_dict=valid_dict)                if error < min_error:                    min_error = error  # 存储迄今为止的最佳错误                    saver.save(sess, MODEL_PATH)  # 保存迄今为止表现最佳的网络

然后,当你想测试模型在表现最佳的迭代时的表现时:

saver.restore(sess, MODEL_PATH)test_error = sess.run(model.error, feed_dict=test_dict)

也可以查看这个教程,关于保存和加载元图。我发现根据你的用例,加载步骤可能有点棘手。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注