我正在尝试训练DQN来玩井字游戏。我已经训练它来玩X(而O的移动是随机的)。经过12小时的训练,它玩得还可以,但并非完美。现在我想同时训练两个网络——一个用于X的移动,另一个用于O的移动。但是当我尝试在第二个网络上执行model.predict(state)
时,我遇到了如下错误:
ValueError: Cannot feed value of shape (9,) for Tensor 'InputData/X:0', which has shape '(?, 9)'
但我确定网络定义和数据维度是相同的。定义两个DNN时存在一些问题。
这是一个通用的例子:
import tflearnimport randomX = [[random.random(),random.random()] for x in range(1000)]#将值顺序颠倒,例如 [1,0] -> [0,1]Y = [[x[1],x[0]] for x in X]n = tflearn.input_data(shape=[None,2])n = tflearn.fully_connected(n, 2)n = tflearn.regression(n)m = tflearn.DNN(n)m.fit(X, Y, n_epoch = 20)#应打印类似 [0.1,0.9]print(m.predict([[0.9,0.1]]))n2 = tflearn.input_data(shape=[None,2])n2 = tflearn.fully_connected(n2, 2)n2 = tflearn.regression(n2)m2 = tflearn.DNN(n2)#将第二个元素的值设为第一个,例如 [1,0] -> [1,1]Y = [[x[0],x[0]] for x in X]m2.fit(X, Y, n_epoch = 20)#应打印类似 [0.9,0.9]print(m2.predict([[0.9,0.1]]))
错误将类似于:
Traceback (most recent call last): File "2_dnn_test.py", line 25, in <module> m2.fit(X, Y, n_epoch = 20) File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/models/dnn.py", line 157, in fit self.targets) File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/utils.py", line 267, in feed_dict_builder feed_dict[net_inputs[i]] = xIndexError: list index out of range
错误不同,因为在我的井字游戏中,我在第一次fit()
之前就对第二个DNN调用了predict
。如果我在我的例子中注释掉m2.fit(X, Y, n_epoch = 20)
,我会得到相同的错误:
Traceback (most recent call last): File "2_dnn_test.py", line 27, in <module> print(m2.predict([[0.9,0.1]])) File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/models/dnn.py", line 204, in predict return self.predictor.predict(feed_dict) File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/helpers/evaluator.py", line 69, in predict o_pred = self.session.run(output, feed_dict=feed_dict).tolist() File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 372, in run run_metadata_ptr) File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 625, in _run % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))ValueError: Cannot feed value of shape (2,) for Tensor 'InputData/X:0', which has shape '(?, 2)'
所以两个相同的网络不能同时工作。我该如何让它们都正常工作呢?
顺便说一句,这个例子没有得到预期的预测结果 🙂
回答:
看起来我应该添加
with tf.Graph().as_default(): #在这里定义模型
以防止TFLearn将两个模型都添加到默认图中。有了这个添加,一切正常工作。