在网络中使用张量作为索引

我的网络有多个输入,其中一个输入是一个索引,用于在网络中索引其他张量。

我在使用张量作为索引时遇到了问题。

class MemoryLayer(tf.keras.layers.Layer):  def __init__(self, memory_size, k, **kwargs):    super().__init__(kwargs)    self.memory_size = memory_size    self.k = k  def build(self,input_shape):    # 设置内存变量  # 输入的形状为[(1,3,6), (1,3)]  def call(self, input):    for i in range(3):      statement = input[0][0,i]      cluster = input[1][0,i]      old_sub_mem = self.memory_var[cluster, :-1] # 这里出错              # 这里应该有很多我移除的内容,因为它们与问题无关    return tf.expand_dims(self.memory_var, axis=0)

我收到了一个TypeError,提示<tf.Tensor 'memory_layer_19/strided_slice_2:0' shape=() dtype=float32>不是有效的索引。我尝试对input[1]调用.numpy(),但这不起作用,因为张量没有形状。从我输入的数据来看,cluster应该是一个单一的数字。


回答:

默认情况下,层的输入是tf.float32。然而,要索引一个张量,你需要整数。你可以将层的输入转换为整数,或者指定该层的输入应该是整数类型。

通过转换

cluster = tf.cast(input[1], dtype=tf.int32)[0,i]old_sub_mem = self.memory_var[cluster, :-1]

指定输入的类型

我在这个例子中使用了函数式API:

inp_statement = tf.keras.Input(shape=(3,6))inp_cluster = tf.keras.Input(shape=(3,), dtype=tf.int32)memory = MemoryLayer(memory_size=10)([inp1,inp2])

注意:我不完全理解你试图实现什么,但这个for循环可能会通过调用tf.gathertf.gather_nd来优化掉。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注