我的网络有多个输入,其中一个输入是一个索引,用于在网络中索引其他张量。
我在使用张量作为索引时遇到了问题。
class MemoryLayer(tf.keras.layers.Layer): def __init__(self, memory_size, k, **kwargs): super().__init__(kwargs) self.memory_size = memory_size self.k = k def build(self,input_shape): # 设置内存变量 # 输入的形状为[(1,3,6), (1,3)] def call(self, input): for i in range(3): statement = input[0][0,i] cluster = input[1][0,i] old_sub_mem = self.memory_var[cluster, :-1] # 这里出错 # 这里应该有很多我移除的内容,因为它们与问题无关 return tf.expand_dims(self.memory_var, axis=0)
我收到了一个TypeError
,提示<tf.Tensor 'memory_layer_19/strided_slice_2:0' shape=() dtype=float32>
不是有效的索引。我尝试对input[1]
调用.numpy()
,但这不起作用,因为张量没有形状。从我输入的数据来看,cluster
应该是一个单一的数字。
回答:
默认情况下,层的输入是tf.float32
。然而,要索引一个张量,你需要整数。你可以将层的输入转换为整数,或者指定该层的输入应该是整数类型。
通过转换
cluster = tf.cast(input[1], dtype=tf.int32)[0,i]old_sub_mem = self.memory_var[cluster, :-1]
指定输入的类型
我在这个例子中使用了函数式API:
inp_statement = tf.keras.Input(shape=(3,6))inp_cluster = tf.keras.Input(shape=(3,), dtype=tf.int32)memory = MemoryLayer(memory_size=10)([inp1,inp2])
注意:我不完全理解你试图实现什么,但这个for循环可能会通过调用tf.gather
或tf.gather_nd
来优化掉。