如何在TFlearn中使用两个网络?

我正在尝试训练DQN来玩井字游戏。我已经训练它来玩X(而O的移动是随机的)。经过12小时的训练,它玩得还可以,但并非完美。现在我想同时训练两个网络——一个用于X的移动,另一个用于O的移动。但是当我尝试在第二个网络上执行model.predict(state)时,我遇到了如下错误:

ValueError: Cannot feed value of shape (9,) for Tensor 'InputData/X:0', which has shape '(?, 9)'

但我确定网络定义和数据维度是相同的。定义两个DNN时存在一些问题。

这是一个通用的例子:

import tflearnimport randomX = [[random.random(),random.random()] for x in range(1000)]#将值顺序颠倒,例如 [1,0] -> [0,1]Y = [[x[1],x[0]] for x in X]n = tflearn.input_data(shape=[None,2])n = tflearn.fully_connected(n, 2)n = tflearn.regression(n)m = tflearn.DNN(n)m.fit(X, Y, n_epoch = 20)#应打印类似 [0.1,0.9]print(m.predict([[0.9,0.1]]))n2 = tflearn.input_data(shape=[None,2])n2 = tflearn.fully_connected(n2, 2)n2 = tflearn.regression(n2)m2 = tflearn.DNN(n2)#将第二个元素的值设为第一个,例如 [1,0] -> [1,1]Y = [[x[0],x[0]] for x in X]m2.fit(X, Y, n_epoch = 20)#应打印类似 [0.9,0.9]print(m2.predict([[0.9,0.1]]))

错误将类似于:

Traceback (most recent call last):  File "2_dnn_test.py", line 25, in <module>    m2.fit(X, Y, n_epoch = 20)  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/models/dnn.py", line 157, in fit    self.targets)  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/utils.py", line 267, in feed_dict_builder    feed_dict[net_inputs[i]] = xIndexError: list index out of range

错误不同,因为在我的井字游戏中,我在第一次fit()之前就对第二个DNN调用了predict。如果我在我的例子中注释掉m2.fit(X, Y, n_epoch = 20),我会得到相同的错误:

Traceback (most recent call last):  File "2_dnn_test.py", line 27, in <module>    print(m2.predict([[0.9,0.1]]))  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/models/dnn.py", line 204, in predict    return self.predictor.predict(feed_dict)  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/helpers/evaluator.py", line 69, in predict    o_pred = self.session.run(output, feed_dict=feed_dict).tolist()  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 372, in run    run_metadata_ptr)  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 625, in _run    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))ValueError: Cannot feed value of shape (2,) for Tensor 'InputData/X:0', which has shape '(?, 2)'

所以两个相同的网络不能同时工作。我该如何让它们都正常工作呢?

顺便说一句,这个例子没有得到预期的预测结果 🙂


回答:

看起来我应该添加

with tf.Graph().as_default():    #在这里定义模型

以防止TFLearn将两个模型都添加到默认图中。有了这个添加,一切正常工作。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注