使用队列在Tensorflow中训练模型

我在tensorflow中设计了一个神经网络来解决我的回归问题,我参考并调整了tensorflow的教程。然而,由于我的问题结构(约300,000个数据点和使用了昂贵的FTRLOptimizer),即使在我的32核CPU机器上(我没有GPU),执行时间也过长。

根据这个评论和通过htop的快速确认,似乎我有一些单线程操作,应该是feed_dict导致的。

因此,按照这里的建议,我尝试使用队列来多线程化我的程序。

我编写了一个简单的代码文件,使用队列来训练模型,如下所示:

import numpy as npimport tensorflow as tfimport threading#并行入队数据的函数def enqueue_thread():    sess.run(enqueue_op, feed_dict={x_batch_enqueue: x, y_batch_enqueue: y})#设置用于“训练”模型的(x, y)对的数量BATCH_SIZE = 5#生成数据,其中y=x+1+少量噪声x = np.random.randn(10, 1).astype('float32')y = x+1+np.random.randn(10, 1)/100#创建模型变量 y = x*W+b,然后W和b都应该收敛到1。W = tf.get_variable('W', shape=[1, 1], dtype='float32')b = tf.get_variable('b', shape=[1, 1], dtype='float32')#准备入队的占位符x_batch_enqueue = tf.placeholder(tf.float32, shape=[None, 1])y_batch_enqueue = tf.placeholder(tf.float32, shape=[None, 1])#创建队列q = tf.RandomShuffleQueue(capacity=2**20, min_after_dequeue=BATCH_SIZE, dtypes=[tf.float32, tf.float32], seed=12, shapes=[[1], [1]])#入队操作enqueue_op = q.enqueue_many([x_batch_enqueue, y_batch_enqueue])#出队操作x_batch, y_batch = q.dequeue_many(BATCH_SIZE)#使用线性模型+偏置进行预测y_pred=tf.add(tf.mul(x_batch, W), b)#MAE成本函数cost = tf.reduce_mean(tf.abs(y_batch-y_pred))learning_rate = 1e-3train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)init = tf.initialize_all_variables()sess = tf.Session()sess.run(init)available_threads = 1024#填充队列for i in range(available_threads):    threading.Thread(target=enqueue_thread).start()#训练模型for step in range(1000):    _, cost_step = sess.run([train_op, cost])    print(cost_step)Wf=sess.run(W)bf=sess.run(b)

这个代码不起作用,因为每次我调用x_batch时,也会出队一个y_batch,反之亦然。然后,我无法将特征与相应的“结果”进行比较。

有没有简单的方法来避免这个问题?


回答:

我的错误,一切正常。我被误导了,因为我在算法的每一步都估计了不同批次的性能,而且我的模型对于一个简单的模型来说太复杂了(我应该使用类似y=W*x或y=x+b的模型)。然后,当我在控制台尝试打印时,我多次执行了sess.run在不同的变量上,显然得到了不一致的结果。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注