有人能解释一下tensorflow中cifar10教程的cifar10_train.py中的train函数吗?

我正在按照https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10上的cifar10教程进行学习。在这个项目中,有6个类别。通过在网上搜索,我理解了cifar10.py和cifar10_input.py中的类别。但是我无法理解cifar10_train.py中的train函数。以下是cifar10_train.py类中的train函数。

def train():with tf.Graph().as_default():    global_step = tf.contrib.framework.get_or_create_global_step()    # get images and labels for cifar 10    # Force input pipeline to CPU:0 to avoid operations sometime ending on    # GPU and resulting in a slow down    with tf.device('/cpu:0'):        images, labels = cifar10.distorted_inputs()    logits = cifar10.inference(images)    loss = cifar10.loss(logits, labels)    train_op = cifar10.train(loss, global_step)    class _LoggerHook(tf.train.SessionRunHook):        def begin(self):            self._step = -1            self._start_time = time.time()        def before_run(self, run_context):            self._step += 1            return tf.train.SessionRunArgs(loss)        def after_run(self, run_context, run_values):            if self._step % FLAGS.log_frequency == 0:                current_time = time.time()                duration = current_time - self._start_time                self._start_time = current_time                loss_value = run_values.results                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration                sec_per_batch = float(duration / FLAGS.log_frequency)                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '                              'sec/batch)')                print(format_str % (datetime.now(), self._step, loss_value,                                    examples_per_sec, sec_per_batch))    with tf.train.MonitoredTrainingSession(            checkpoint_dir=FLAGS.train_dir,            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),                   tf.train.NanTensorHook(loss),                   _LoggerHook()],            config=tf.ConfigProto(                log_device_placement=FLAGS.log_device_placement)) as mon_sess:        while not mon_sess.should_stop():            mon_sess.run(train_op)

有人能解释一下_LoggerHook类中发生了什么吗?


回答:

它使用MonitoredSessionSessionRunHook来记录训练时的损失。

_LoggerHookSessionRunHook的一个实现,按照以下顺序运行:

  call hooks.begin()  sess = tf.Session()  call hooks.after_create_session()  while not stop is requested:    call hooks.before_run()    try:      results = sess.run(merged_fetches, feed_dict=merged_feeds)    except (errors.OutOfRangeError, StopIteration):      break    call hooks.after_run()  call hooks.end()  sess.close()

这是从这里获得的。

它在session.run之前收集loss数据,然后以预定义的格式输出loss

教程:https://www.tensorflow.org/tutorials/layers

希望这对你有帮助。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注