我有两个一维布尔张量,需要使用tf.case函数根据每个输入向量的元素值逐个为输出张量赋值。我的代码如下:
f1 = lambda: tf.constant(1)
f2 = lambda: tf.constant(2)
f3 = lambda: tf.constant(0)
result = tf.case({c1 : f1, c2 : f2}, default=f3)
其中c1和c2是一维张量,输出张量的形状与它们相同。
回答:
解决此问题的有效答案是:
f1 = tf.math.add(tf.zero_like(c1),1)
f2 = tf.math.add(tf.zero_like(c1),2)
f3 = tf.zero_like(c1)
result = tf.where(c1, f1, f3)
result = tf.where(c2, f2, result)