在从TF1迁移到TF2时,带有多个输入的函数模型中的AssertionError

你好,我正在尝试将一个旧模型从TF1转换到TF2,但遇到了几个问题。我使用Google Colab在TF1和TF2之间切换,TF1运行一切正常,但在TF2下却不行。我用下面的简短代码复制了这个问题。

from keras.layers import *from keras import Modelfrom keras.backend import squeezedef create_model():    inputA = Input(shape=(1,))    x = Dense(1)(inputA)    x = Model(inputs=inputA, outputs=x)    print(x.predict([0.1]))        inputB = Input(shape=(1,))    y = Dense(1)(inputB)    y = Model(inputs=inputB, outputs=y)        print(y.predict([0.1]))        combined  = concatenate(inputs = [x.output,y.output])    model = Model(inputs=[x.input, y.input], outputs=combined)        return modelif (__name__  == "__main__") :    model = create_model()    model.compile(loss='mse',optimizer='RMSprop')    model.summary()    print(model.predict([[0.1],[0.1]]))

这是使用TF2时的错误信息:

AssertionError: in user code:    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:1462 predict_function  *        return step_function(self, iterator)    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:1452 step_function  **        outputs = model.distribute_strategy.run(run_step, args=(data,))    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:1211 run        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica        return self._call_for_each_replica(fn, args, kwargs)    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica        return fn(*args, **kwargs)    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:1445 run_step  **        outputs = model.predict_step(data)    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:1418 predict_step        return self(x, training=False)    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:985 __call__        outputs = call_fn(inputs, *args, **kwargs)    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py:386 call        inputs, training=training, mask=mask)    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py:517 _run_internal_graph        assert x_id in tensor_dict, 'Could not compute output ' + str(x)    AssertionError: Could not compute output Tensor("concatenate/concat:0", shape=(None, 2), dtype=float32)

任何帮助都将不胜感激。

谢谢,V_W


回答:

你可以像这样修改你的代码,

from tf.keras.layers import *from tf.keras import Modeldef create_model():    inputA = Input(shape=(1,))    x = Dense(1)(inputA)    modelA = Model(inputs=inputA, outputs=x)    print(modelA.predict([0.1]))        inputB = Input(shape=(1,))    y = Dense(1)(inputB)    modelB = Model(inputs=inputB, outputs=y)        print(modelB.predict([0.1]))        concat = Concatenate()( [ x , y ] )    model = Model(inputs=[ inputA, inputB ], outputs=concat )        return model

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注