TensorFlow代码阅读:`sess.run()`除了评估还有什么作用?

我正在阅读影响函数(ICML 2017最佳论文)的代码。在第190行有一个令人困惑的代码行,它什么也不返回self.sess.run(self.set_params_op, feed_dict=params_feed_dict)。相关代码片段如下:

params_feed_dict = {}
params_feed_dict[self.W_placeholder] = W
# params_feed_dict[self.b_placeholder] = b
self.sess.run(self.set_params_op, feed_dict=params_feed_dict)
if save_checkpoints:
    self.saver.save(self.sess, self.checkpoint_file, global_step=0)
if verbose:
    print('LBFGS训练花费了%s次迭代。' % model.n_iter_)
    print('LBFGS训练后:')
    self.print_model_eval()

self.set_params_op是在初始化函数中通过以下函数赋值的(无论我是否执行,它都会被执行):

def set_params(self):
    self.W_placeholder = tf.placeholder(
        tf.float32,
        shape=[self.input_dim * self.num_classes],
        name='W_placeholder')
    set_weights = tf.assign(self.weights, self.W_placeholder, validate_shape=True)
    return [set_weights]

然后我尝试注释掉那行代码,但self.print_model_eval()的打印消息发生了变化。

关于print_model_eval()的有用代码片段定义如下:

def print_model_eval():
    ...
    grad_loss_val, loss_no_reg_val, loss_val, train_acc_val = self.sess.run(
        [self.grad_total_loss_op, self.loss_no_reg, self.total_loss, self.accuracy_op],
        feed_dict=self.all_train_feed_dict)
    test_loss_val, test_acc_val = self.sess.run(
        [self.loss_no_reg, self.accuracy_op],
        feed_dict=self.all_test_feed_dict)

这些张量是通过以下方式获得的:

self.total_loss, self.loss_no_reg, self.indiv_loss_no_reg = self.loss(
            self.logits,
            self.labels_placeholder)

以及

def loss(self, logits, labels):
    labels = tf.one_hot(labels, depth=self.num_classes)
    # correct_prob = tf.reduce_sum(tf.multiply(labels, tf.nn.softmax(logits)), reduction_indices=1)
    cross_entropy = - tf.reduce_sum(tf.multiply(labels, tf.nn.log_softmax(logits)), reduction_indices=1)
    indiv_loss_no_reg = cross_entropy
    loss_no_reg = tf.reduce_mean(cross_entropy, name='xentropy_mean')
    tf.add_to_collection('losses', loss_no_reg)
    total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
    return total_loss, loss_no_reg, indiv_loss_no_reg

我的问题是:

  1. sess.run()没有返回值是如何影响其他张量的计算的?
  2. 调试TensorFlow代码的最佳实践是什么?

任何建议都将不胜感激。提前谢谢您。


回答:

在TensorFlow的非急切执行模式下,首先你创建一个包含操作的图。然后,在会话中运行这些操作。某些操作,比如变量的初始化,不返回任何东西。它们只是将值赋给图中的变量。

这里,self.set_params_op听起来像是这种类型的操作。它可能是在设置某些参数的值。如果你将其注释掉,我的猜测是这些参数没有被设置为“训练后”的值。因此,当你评估模型时,会得到不同的结果。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注