Why can't I use TensorArray.gather() inside @tf.function?

Reading data from the TensorArray:

```python
import tensorflow as tf

def __init__(self, size):
    self.size = size  # get_sample() reads self.size
    self.obs_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.obs2_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.act_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.rew_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.done_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)

def get_sample(self, batch_size):
    idxs = tf.random.uniform(shape=[batch_size], maxval=self.size, dtype=tf.int32)
    tf.print(idxs)
    return (self.obs_buf.gather(indices=idxs),  # this is where the problem is
            self.act_buf.gather(indices=idxs),
            self.rew_buf.gather(indices=idxs),
            self.obs2_buf.gather(indices=idxs),
            self.done_buf.gather(indices=idxs))
```

Usage:

```python
@tf.function
def train(self, rpm, batch_size, gradient_steps):
    for gradient_step in tf.range(1, gradient_steps + 1):
        obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
        with tf.GradientTape() as tape:
            ...
```

Problem:

```
Traceback (most recent call last):
  File ".\main.py", line 130, in
    rl_training.train()
  File "C:\Users\user\Documents\Projects\rl-toolkit\rl_training.py", line 129, in train
    self._rpm, self.batch_size, self.gradient_steps, logging_wandb=self.logging_wandb
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 726, in _initialize
    *args, **kwds))
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3206, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py", line 3887, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:

    C:\Users\user\Documents\Projects\rl-toolkit\policy\sac\sac.py:183 update  *
        obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
    C:\Users\user\Documents\Projects\rl-toolkit\utils\replay_buffer.py:39 __call__  *
        return self.obs_buf.gather(indices=idxs),
            self.act_buf.gather(indices=idxs),
            self.rew_buf.gather(indices=idxs),
            self.obs2_buf.gather(indices=idxs),
            self.done_buf.gather(indices=idxs)
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:1190 gather  **
        return self._implementation.gather(indices, name=name)
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:861 gather
        return array_ops.stack([self._maybe_zero(i) for i in indices])
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:505 __iter__
        self._disallow_iteration()
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:498 _disallow_iteration
        self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
    C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:476 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))

    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
```
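A stripped-down sketch of the same failure, assuming TensorFlow 2.x. The names `buf` and `sample` are hypothetical; `buf` stands in for `obs_buf`, a TensorArray filled eagerly outside any `tf.function`. Gathering from it with concrete indices works, while the same `gather` call inside a `@tf.function` fails on the Python loop seen at `tensor_array_ops.py:861`.

```python
import tensorflow as tf

# Hypothetical stand-in for obs_buf: a TensorArray built eagerly, outside any tf.function.
buf = tf.TensorArray(tf.float32, size=4, clear_after_read=False)
for i in range(4):
    buf = buf.write(i, tf.fill([2], float(i)))  # row i is [i, i]

# Eager call: the indices are a concrete tensor, so gather works.
eager_rows = buf.gather(indices=tf.constant([0, 2]))
print(eager_rows.numpy())


@tf.function
def sample(batch_size):
    # While tracing, idxs is a symbolic tensor, so the eager
    # TensorArray cannot loop over it in Python.
    idxs = tf.random.uniform(shape=[batch_size], maxval=4, dtype=tf.int32)
    return buf.gather(indices=idxs)


try:
    sample(2)
except TypeError as err:  # OperatorNotAllowedInGraphError subclasses TypeError
    print(type(err).__name__)
```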

Why can't I use a TensorArray in this case, and what are the alternatives?


Answer:

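The problem has been solved: use a tf.Variable instead of a tf.TensorArray. The TensorArrays are created in `__init__`, outside the `@tf.function`, so they are eager TensorArrays, and the traceback shows that their `gather` loops over `indices` in Python (`tensor_array_ops.py:861`). Inside `train`, `idxs` is a symbolic tensor, and iterating over a symbolic tensor is exactly what AutoGraph rejects. State that has to persist between calls of a `tf.function`, such as a replay buffer, belongs in a `tf.Variable`, which can be sampled with `tf.gather` using symbolic indices.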
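A minimal sketch of that fix, assuming TensorFlow 2.x and flat observation/action vectors. The class name `VariableReplayBuffer`, the `obs_dim`/`act_dim` arguments, the `store` method and the ring-buffer pointer are hypothetical additions for illustration; only the swap from `tf.TensorArray` to `tf.Variable` plus `tf.gather` reflects the answer.

```python
import tensorflow as tf


class VariableReplayBuffer:
    """Hypothetical replay buffer whose storage lives in tf.Variables."""

    def __init__(self, size, obs_dim, act_dim):
        self.size = size
        # Fixed-size storage; the shapes are assumptions (flat obs/actions).
        self.obs_buf = tf.Variable(tf.zeros([size, obs_dim]), trainable=False)
        self.obs2_buf = tf.Variable(tf.zeros([size, obs_dim]), trainable=False)
        self.act_buf = tf.Variable(tf.zeros([size, act_dim]), trainable=False)
        self.rew_buf = tf.Variable(tf.zeros([size]), trainable=False)
        self.done_buf = tf.Variable(tf.zeros([size]), trainable=False)
        self.ptr = tf.Variable(0, dtype=tf.int32, trainable=False)    # next write slot
        self.count = tf.Variable(0, dtype=tf.int32, trainable=False)  # filled slots

    def store(self, obs, act, rew, obs2, done):
        # Write one transition at self.ptr, overwriting the oldest entry when full.
        idx = tf.reshape(self.ptr, [1, 1])
        for buf, value in ((self.obs_buf, obs), (self.act_buf, act),
                           (self.rew_buf, rew), (self.obs2_buf, obs2),
                           (self.done_buf, done)):
            buf.scatter_nd_update(idx, tf.expand_dims(tf.cast(value, tf.float32), 0))
        self.ptr.assign((self.ptr + 1) % self.size)
        self.count.assign(tf.minimum(self.count + 1, self.size))

    def get_sample(self, batch_size):
        # Sample only among filled slots; tf.gather accepts symbolic indices,
        # so this also works when called from inside a @tf.function.
        idxs = tf.random.uniform(shape=[batch_size], maxval=self.count, dtype=tf.int32)
        return (tf.gather(self.obs_buf, idxs),
                tf.gather(self.act_buf, idxs),
                tf.gather(self.rew_buf, idxs),
                tf.gather(self.obs2_buf, idxs),
                tf.gather(self.done_buf, idxs))


rpm = VariableReplayBuffer(size=100, obs_dim=3, act_dim=2)
for t in range(10):
    rpm.store(tf.fill([3], float(t)), tf.zeros([2]), float(t), tf.ones([3]), 0.0)


@tf.function
def draw(batch_size):
    # Mirrors the call pattern in train(): sampling inside a traced function.
    return rpm.get_sample(batch_size)


obs, act, rew, next_obs, done = draw(4)
print(obs.shape, rew.shape)
```

Because every buffer is a `tf.Variable`, the traced function reads the current contents on each call instead of capturing a snapshot, and the same index tensor keeps the five sampled fields row-aligned.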
