假设我有一个大小为[batch_size, 5, 10]
的张量,称为my_tensor
。我还有另一个大小为[batch_size, 1]
的张量,保存索引,称为selecter
。
我想根据selecter
过滤my_tensor
,生成一个大小为[batch_size, 10]
的新张量,即只选择selecter
包含的值。基本上,这是在减少中间维度(大小为5)。
我觉得tf.where
可能是合适的选择,但不太确定。我非常希望得到你的帮助!
回答:
解决方案是使用tf.gather_nd
。
tf.gather_nd( my_tensor, tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1))
如果你从一开始就构造selecter
为一维的,你可以省略squeeze
。