为什么这个TensorFlow代码会崩溃?

我构建了一个用于图像分类的玩具模型。程序的结构大致参考了cifar10教程。训练开始时一切正常,但最终程序会崩溃。我已经对图进行了最终化处理,以防某些操作被添加到图中,并且在TensorBoard中看起来很好,但它总会不可避免地冻结并强制进行硬重启(或长时间等待最终重启)。退出情况看起来像是GPU内存问题,但模型很小,应该是可以容纳的。如果我分配了全部的GPU内存(增加了4GB),它仍然会崩溃。

数据是存储在tfrecords文件中的256x256x3图像和标签。训练函数的代码如下所示:

def train():    with tf.Graph().as_default():         global_step = tf.contrib.framework.get_or_create_global_step()         train_images_batch, train_labels_batch = distorted_inputs(batch_size=BATCH_SIZE)         train_logits = inference(train_images_batch)         train_batch_loss = loss(train_logits, train_labels_batch)         train_op = training(train_batch_loss, global_step, 0.1)         merged = tf.summary.merge_all()         saver = tf.train.Saver(tf.global_variables())         gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.75)         sess_config=tf.ConfigProto(gpu_options=gpu_options)         sess = tf.Session(config=sess_config)         train_summary_writer = tf.summary.FileWriter(         os.path.join(ROOT, 'logs', 'train'), sess.graph)         init = tf.global_variables_initializer()         sess.run(init)         coord = tf.train.Coordinator()         threads = tf.train.start_queue_runners(sess=sess, coord=coord)         tf.Graph().finalize()         for i in range(5540):             start_time = time.time()             summary, _, batch_loss = sess.run([merged, train_op, train_batch_loss])             duration = time.time() - start_time             train_summary_writer.add_summary(summary, i)             if i % 10 == 0:                 msg = 'batch: {} loss: {:.6f} time: {:.8} sec/batch'.format(                 i, batch_loss, str(time.time() - start_time))                 print(msg)         coord.request_stop()         coord.join(threads)         sess.close() 

损失函数和训练操作分别是交叉熵和Adam优化器:

def loss(logits, labels):    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits, name='cross_entropy_per_example')    xentropy_mean = tf.reduce_mean(xentropy, name='cross_entropy')    tf.add_to_collection('losses', xentropy_mean)    return xentropy_meandef training(loss, global_step, learning_rate):    optimizer = tf.train.AdamOptimizer(learning_rate)    train_op = optimizer.minimize(loss, global_step=global_step)    return train_op

批次是通过以下方式生成的:

 def distorted_inputs(batch_size):     filename_queue = tf.train.string_input_producer(         ['data/train.tfrecords'], num_epochs=None)    reader = tf.TFRecordReader()    _, serialized_example = reader.read(filename_queue)    features = tf.parse_single_example(serialized_example,        features={'label': tf.FixedLenFeature([], tf.int64),                  'image': tf.FixedLenFeature([], tf.string)})    label = features['label']    label = tf.cast(label, tf.int32)    image = tf.decode_raw(features['image'], tf.uint8)    image = (tf.cast(image, tf.float32) / 255) - 0.5    image = tf.reshape(image, shape=[256, 256, 3])    # 数据增强    image = tf.image.random_flip_up_down(image)    image = tf.image.random_flip_left_right(image)    print('filling the queue with {} images ' \          'before starting to train'.format(MIN_QUEUE_EXAMPLES))    return _generate_batch(image, label, MIN_QUEUE_EXAMPLES, BATCH_SIZE)

以及

def _generate_batch(image, label,                    min_queue_examples=MIN_QUEUE_EXAMPLES,                    batch_size=BATCH_SIZE):    images_batch, labels_batch = tf.train.shuffle_batch(        [image, label], batch_size=batch_size,        num_threads=12, capacity=min_queue_examples + 3 * BATCH_SIZE,        min_after_dequeue=min_queue_examples)    tf.summary.image('images', images_batch)    return images_batch, labels_batch

我遗漏了什么?


回答:

我已经解决了这个问题。以下是解决方案,以防对其他人有用。简而言之:这是一个硬件问题。

具体来说,这是一个PCIe总线错误,与这里票数最多的错误相同。可能这是由消息信号中断与PLX开关不兼容引起的,如这里所建议的。在那个线程中解决问题的方法是设置内核参数pci=nommconf来禁用msi。

在TensorFlow、Torch和Theano之间,只有TensorFlow会触发这个问题。为什么,我不确定。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注