Keras Graph断开连接,多输入嵌入

我正在尝试构建一个模型:一个具有4个输入的模型,这些输入会被嵌入并用于生成分数输出

import numpy as npimport tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import layersH = keras.Input(shape=(1,), name="H") R = keras.Input(shape=(1,), name="R") T = keras.Input(shape=(1,), name="T") N = keras.Input(shape=(1,), name="N") embedding = keras.layers.Embedding(10000, 100)embedding_r = keras.layers.Embedding(1000, 100)H = embedding(H)R = embedding_r(R)T = embedding(T)N = embedding(N)H = keras.layers.Flatten()(H)R = keras.layers.Flatten()(R)T = keras.layers.Flatten()(T)N = keras.layers.Flatten()(N)H_plus_R = keras.layers.Concatenate()([H, R])T_plus_N = keras.layers.Concatenate()([N, T])H_plus_R = keras.layers.Dense(100, activation='relu')(H_plus_R)T_plus_N = keras.layers.Dense(100, activation='relu')(T_plus_N)score = keras.layers.Concatenate()([T_plus_N,H_plus_R])score = keras.layers.Dense(1, activation='relu')(score)model = tf.keras.Model(    inputs=[H,R,T,N],    outputs=score,)model.summary()

我得到了这个错误,这意味着输入和输出没有连接,但实际上它们是连接的:

ValueError                                Traceback (most recent call last)<ipython-input-8-90804bccaf4f> in <module>()     32 model = tf.keras.Model(     33     inputs=[H,R,T,N],---> 34     outputs=score,     35 )     36 5 frames/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in _map_graph_network(inputs, outputs)    929                              'The following previous layers '    930                              'were accessed without issue: ' +--> 931                              str(layers_with_complete_input))    932         for x in nest.flatten(node.outputs):    933           computable_tensors.add(id(x))ValueError: Graph disconnected: cannot obtain value for tensor Tensor("R_7:0", shape=(None, 1), dtype=float32) at layer "embedding_13". The following previous layers were accessed without issue: []

我该如何修复这个问题?


回答:

在嵌入你的输入后,你覆盖了HRTN,尝试使用其他变量名

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注