为什么我喂给TensorFlow的Placeholder值会被计算值替换?

我定义了一些计算值和相关的摘要,使用的是

keep_prob_val = tf.sub(1.0, tf.train.exponential_decay(1 - FLAGS.keep_prob, global_step, FLAGS.decay_steps,                                                        FLAGS.dropout_decay_rate, staircase=False))keep_prob_summary = tf.scalar_summary('params/keep_prob', keep_prob_val)

并设置了一些合并的摘要,使用的是

train_writer = tf.train.SummaryWriter(FLAGS.train_dir.format(log_id), tf.get_default_graph())test_writer = tf.train.SummaryWriter(test_dir)merged = tf.merge_all_summaries()

但当我随后

sess.run(train_step, feed_dict={x: xs, y_: ys,                                keep_prob: sess.run(keep_prob_val)})# ...test_writer.add_summary(sess.run(merged, feed_dict={ ..., keep_prob: some_val}),                         global_step=gs)

其中keep_prob之前已被定义为

keep_prob = tf.placeholder(tf.float32, name='keep_prob')

我得到的摘要信息是计算得到的keep_prob_val,而不是我喂入的some_val。即使我没有明确喂入keep_prob,也会如预期般报错,这种情况仍然会发生。

据我所知,这种情况只发生在涉及keep_prob的摘要上。例如,当我

sess.run(accuracy, feed_dict={ ..., keep_prob: some_val})

我似乎得到了使用some_val计算的accuracy,这与相关的摘要相符。

为什么我在摘要中喂入的keep_prob值被忽略了?


回答:

摘要并不依赖于你的keep_prob占位符,而是依赖于FLAGS.keep_prob值:

keep_prob_val = tf.sub(1.0, tf.train.exponential_decay(    1 - FLAGS.keep_prob, global_step, FLAGS.decay_steps,                    FLAGS.dropout_decay_rate, staircase=False))keep_prob_summary = tf.scalar_summary('params/keep_prob', keep_prob_val)

因此,你的摘要将始终包含由--keep_prob标志设置的(可能已经衰减的)值(或其默认值)。

为了使摘要反映喂入的值,你应该简单地重新定义keep_prob_valkeep_prob_summary,使其依赖于占位符:

keep_prob = tf.placeholder(tf.float32, name='keep_prob')# ...keep_prob_val = tf.sub(1.0, tf.train.exponential_decay(    1 - keep_prob,  # 注意:关键更改在这里!    global_step, FLAGS.decay_steps, FLAGS.dropout_decay_rate, staircase=False))keep_prob_summary = tf.scalar_summary('params/keep_prob', keep_prob_val)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注