TensorFlow: 在同一会话中训练和测试

我想在同一个TensorFlow会话中训练和测试我的模型。我使用两个不同的tf.FIFOQueue来通过多线程加载训练和测试数据(因为feed_dict会导致性能不佳)。我尝试了两种方法:

  1. 我尝试创建两个模型(用于训练和测试)并共享参数。但我使用了tf.contrib.layers.batch_norm,它不允许共享批量归一化的参数。

  2. 我尝试根据is_training布尔占位符来条件化网络的输入tf.FIFOQueue,使用tf.cond,但显然tf.cond会执行两个tf.FIFOQueue的出队函数,无论is_training的值是什么。

我想知道在同一会话中训练和测试的常规设置是什么,且不使用feed_dict


回答:

显然,如果在全局tf.variable_scope中定义,tf.contrib.layers.batch_norm确实允许共享批量归一化参数。

示例代码:摘自这里

def model(data, is_training=False, reuse=None, scope='my_model'):  # 定义一个变量范围来包含模型的所有变量  with tf.variable_scope(scope, 'model', data, reuse=reuse):    ....    net = tf.contrib.layers.batch_norm(net, is_training)   return nettrain_outputs = model(train_data, is_training=True)eval_outputs = model(eval_data, is_training=False, reuse=True)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注