Java中出现”No Operation named [input] in the Graph”错误

我按照Google的机器学习速成课程中的这个Colab练习,用Python为MNIST数据库生成了一个模型。代码如下所示:

import pandas as pdimport tensorflow as tfdef create_model(my_learning_rate):    model = tf.keras.models.Sequential()    model.add(tf.keras.Input(shape=(28, 28), name='input'))    model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))    model.add(tf.keras.layers.Dense(units=256, activation='relu'))    model.add(tf.keras.layers.Dense(units=128, activation='relu'))    model.add(tf.keras.layers.Dropout(rate=0.2))    model.add(tf.keras.layers.Dense(units=10, activation='softmax', name='output'))    model.compile(optimizer=tf.keras.optimizers.Adam(lr=my_learning_rate),                  loss='sparse_categorical_crossentropy',                  metrics=['accuracy'])    return modeldef train_model(model, train_features, train_label, epochs,                batch_size=None, validation_split=0.1):    history = model.fit(x=train_features, y=train_label, batch_size=batch_size,                        epochs=epochs, shuffle=True,                        validation_split=validation_split)    epochs = history.epoch    hist = pd.DataFrame(history.history)    return epochs, histif __name__ == '__main__':    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()    x_train_normalized = x_train / 255.0    x_test_normalized = x_test / 255.0    learning_rate = 0.003    epochs = 50    batch_size = 4000    validation_split = 0.2    my_model = create_model(learning_rate)    epochs, hist = train_model(my_model, x_train_normalized, y_train,                               epochs, batch_size, validation_split)    my_model.save('my_model')

模型被正确地保存到了“my_model”文件夹中。现在我在Java程序中重新加载它:

public class HelloTensorFlow {    public static void main(final String[] args) {        final String filePath = Paths.get("my_model").toAbsolutePath().toString();        try (final SavedModelBundle b = SavedModelBundle.load(filePath, "serve")) {            final Session sess = b.session();            final Tensor<Float> x = Tensor.create(new float[1][28 * 28], Float.class);            final List<Tensor<?>> run = sess.runner()                    .feed("input", x)                    .fetch("output")                    .run();            final float[] y = run.get(0).copyTo(new float[1]);            System.out.println(y[0]);        }    }}

模型加载成功,但运行器无法工作。当我执行程序时,得到“No Operation named [input] in the Graph”的错误,尽管我的输入确实命名为“input”。我做错了什么?我使用的是最新的TensorFlow版本:2.3.0(Python)和1.15.0(Java)。


回答:

我已经解决了这个问题。TensorFlow 2似乎有奇怪的命名方案,但使用MetaGraphDef可以解码。首先,你需要org.tensorflow.proto依赖。然后,你可以从元图中提取信息,如下所示:

final MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(bundle.metaGraphDef());final SignatureDef signatureDef = metaGraphDef.getSignatureDefMap().get("serving_default");final TensorInfo inputTensorInfo = signatureDef.getInputsMap()    .values()    .stream()    .filter(Objects::nonNull)    .findFirst()    .orElseThrow(() -> ...);final TensorInfo outputTensorInfo = signatureDef.getOutputsMap()    .values()    .stream()    .filter(Objects::nonNull)    .findFirst()    .orElseThrow(() -> ...);

现在你可以将创建的张量输入到inputTensorInfo.getName()返回的名称中,并从outputTensorInfo.getName()获取结果。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注