在Tensorflow中处理eval()时发生内存溢出

我在使用Tensorflow处理一个简单的矩阵分解算法。所有的步骤都进行得很顺利,但在最后一步,当我想使用eval()来存储一个张量时,程序无法正常工作,并且占用的内存越来越多。所以我的代码是否有问题?我是Tensorflow的新手,不知道问题出在哪里。下面是我的代码。

    class model(object):    def __init__(self, D, Q, stepsize = 6e-7, max_iter = 200, inner_maxiter = 50, dim = 200, verbose = 5):        self.D = tf.constant(D, dtype = tf.float32)        self.Q = tf.constant(Q, dtype = tf.float32)        self.rank = dim        self.stepsize = stepsize        self.max_iter = max_iter        self.inner_maxiter = inner_maxiter        self.verbose = verbose        self.W = tf.Variable((np.random.rand(self.rank, sample_num)), dtype = tf.float32, name = 'W')        self.C = tf.Variable((np.random.rand(context_num, self.rank)), dtype = tf.float32, name = 'C')    def _run(self, sess):        Q = self.Q        D = self.D        W = self.W        C = self.C        for i in xrange(self.max_iter):            if (i + 1) % 2 == 1:                for j in xrange(self.inner_maxiter):                    ED = tf.transpose(Q) * (1.0 / (1.0 + tf.exp(- tf.matmul(C, W))))                    recons = D - ED                    W_grad = tf.matmul(tf.transpose(C), recons)                    W = W + self.stepsize * W_grad            else:                for j in xrange(self.inner_maxiter):                    ED = tf.transpose(Q) * (1.0 / (1.0 + tf.exp(- tf.matmul(C, W))))                    recons = D - ED                    C_grad = tf.matmul(recons, tf.transpose(W))                    C = C + self.stepsize * C_grad            print 'epoch: %d' % i        print W.eval()        print C.eval()train_epoch = model(D, Q, args.step_size, \        args.max_iter, args.inner_maxiter, args.dim, args.verbose)with tf.Session(config = config) as sess:    tf.initialize_all_variables().run()    train_epoch._run(sess)

程序在_run()的最后两行停止并占用内存,这两行包含W.eval()C.eval()。那么我应该如何修复这个问题?有人能帮我吗?


回答:

已解决。在Tensorflow中,你不能添加符号操作的迭代。相反,你应该首先构建数据流,这意味着你应该在初始化步骤中定义你的操作。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注