我在尝试运行这个简单的程序来计算梯度,但是遇到了一个与 None 有关的 TypeError:
# Minimal TensorFlow 1.x (graph mode) example: build a single tanh layer,
# compute the gradient of the cross-entropy loss with respect to each
# variable, and print every (variable, gradient) pair.
import numpy as np
import tensorflow as tf

batch_size = 5
dim = 3
hidden_units = 8

# Build the model in its own graph. In a notebook the default graph is shared
# across cell executions, so tf.trainable_variables() would also return stale
# variables from earlier runs. Those are not connected to this loss, their
# gradient is None, and sess.run() then fails with
# "Fetch argument None has invalid type".
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(dtype=tf.float32, shape=[None, dim], name="x")
    y = tf.placeholder(dtype=tf.int32, shape=[None], name="y")
    w = tf.Variable(
        initial_value=tf.random_normal(shape=[dim, hidden_units]), name="w")
    b = tf.Variable(initial_value=tf.zeros(shape=[hidden_units]), name="b")
    logits = tf.nn.tanh(tf.matmul(x, w) + b)
    # Keyword arguments are required in TF >= 1.0 and also work in older
    # releases where the positional order was (logits, labels).
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y, logits=logits, name="xentropy")
    # End of model definition.

    # Start of training setup.
    optimizer = tf.train.GradientDescentOptimizer(1e-5)
    # Restrict the gradient computation to the variables of this model and
    # drop any pair without a gradient *before* fetching: a None entry cannot
    # be passed to sess.run().
    grads_and_vars = [
        (grad, var)
        for grad, var in optimizer.compute_gradients(
            cross_entropy, var_list=[w, b])
        if grad is not None
    ]
    init_op = tf.global_variables_initializer()

# Generate data. Labels must lie in [0, hidden_units) because the logits have
# hidden_units columns; larger class ids are invalid for the loss.
data = np.random.randn(batch_size, dim)
labels = np.random.randint(0, hidden_units, size=batch_size)

# The context manager closes the session even if a run raises.
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    gradients_and_vars = sess.run(grads_and_vars, feed_dict={x: data, y: labels})

for g, v in gradients_and_vars:
    print("****************这是变量*************")
    print("变量的形状:", v.shape)
    print(v)
    print("****************这是梯度*************")
    print("梯度的形状:", g.shape)
    print(g)
错误:
---------------------------------------------------------------------------TypeError Traceback (most recent call last)<ipython-input-14-8096b2e21e06> in <module>() 29 30 sess.run(tf.initialize_all_variables())---> 31 outnet = sess.run(grads_and_vars, feed_dict={x:data, y:labels}) 32 # print(gradients_and_vars) 33 # if g is not None://anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 893 try: 894 result = self._run(None, fetches, feed_dict, options_ptr,--> 895 run_metadata_ptr) 896 if run_metadata: 897 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 1107 # Create a fetch handler to take care of the structure of fetches. 1108 fetch_handler = _FetchHandler(-> 1109 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 1110 1111 # Run request and get response.//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles) 411 """ 412 with graph.as_default():--> 413 self._fetch_mapper = _FetchMapper.for_fetch(fetches) 414 self._fetches = [] 415 self._targets = []//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch) 231 elif isinstance(fetch, (list, tuple)): 232 # NOTE(touts): This is also the code path for namedtuples.--> 233 return _ListFetchMapper(fetch) 234 elif isinstance(fetch, dict): 235 return _DictFetchMapper(fetch)//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, fetches) 338 """ 339 self._fetch_type = type(fetches)--> 340 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 341 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 342 //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in <listcomp>(.0) 338 """ 339 
self._fetch_type = type(fetches)--> 340 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 341 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 342 //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch) 231 elif isinstance(fetch, (list, tuple)): 232 # NOTE(touts): This is also the code path for namedtuples.--> 233 return _ListFetchMapper(fetch) 234 elif isinstance(fetch, dict): 235 return _DictFetchMapper(fetch)//anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, fetches) 338 """ 339 self._fetch_type = type(fetches)--> 340 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 341 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 342 //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in <listcomp>(.0) 338 """ 339 self._fetch_type = type(fetches)--> 340 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 341 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 342 //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch) 228 if fetch is None: 229 raise TypeError('Fetch argument %r has invalid type %r' %--> 230 (fetch, type(fetch))) 231 elif isinstance(fetch, (list, tuple)): 232 # NOTE(touts): This is also the code path for namedtuples.TypeError: Fetch参数None的类型无效
为什么会出现这个错误?是版本问题吗?
回答:
如果图中某个变量与损失之间没有明确的连接,梯度计算(`tf.gradients` / `compute_gradients`)会对该变量返回 `None`。在你的代码中,似乎所有声明的变量都有连接,所以很可能是图中混入了来自其他图或之前运行遗留下来的变量(例如在 notebook 中重复执行同一个单元格)。你可以使用:
# List every variable registered in the current default graph;
# tf.global_variables() replaces the deprecated tf.all_variables().
print([v.name for v in tf.global_variables()])
并检查该图中是否只包含你预期的那些变量。
尝试像这样的代码:
# tf.global_variables_initializer() replaces the deprecated
# tf.initialize_all_variables().
sess.run(tf.global_variables_initializer())
# Keep only the pairs that actually have a gradient: a None gradient cannot be
# fetched by sess.run(), and dropping just those pairs (rather than every
# gradient) still returns the gradients that do exist.
fetchable_grads_and_vars = [
    (grad, variable)
    for grad, variable in grads_and_vars
    if grad is not None
]
gradients_and_vars = sess.run(
    fetchable_grads_and_vars, feed_dict={x: data, y: labels})
print(gradients_and_vars)