我有一个在强化学习环境中使用的小模型。
我可以输入一个状态的二维张量,并得到一个动作权重的二维张量。
假设我输入两个状态,并得到以下动作权重:
[[0.1, 0.2], [0.3, 0.4]]
现在我有另一个二维张量,其中包含我想获取权重的动作编号:
[[1], [0]]
我如何使用这个张量来获取动作的权重?
在这个例子中,我希望得到:
[[0.2], [0.3]]
回答:
类似于TensorFlow tf.gather with axis parameter,这里的索引处理略有不同:
a = tf.constant( [[0.1, 0.2], [0.3, 0.4]])
indices = tf.constant([[1],[0]])
# 转换为完整索引
full_indices = tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices], axis=2)
# 收集
result = tf.gather_nd(a,full_indices)
with tf.Session() as sess:
print(sess.run(result))
#[[0.2]
#[0.3]]