I followed this Colab exercise from Google's Machine Learning Crash Course and built a model for the MNIST database in Python. The code looks like this:
    import pandas as pd
    import tensorflow as tf


    def create_model(my_learning_rate):
        model = tf.keras.models.Sequential()
        model.add(tf.keras.Input(shape=(28, 28), name='input'))
        model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
        model.add(tf.keras.layers.Dense(units=256, activation='relu'))
        model.add(tf.keras.layers.Dense(units=128, activation='relu'))
        model.add(tf.keras.layers.Dropout(rate=0.2))
        model.add(tf.keras.layers.Dense(units=10, activation='softmax', name='output'))
        model.compile(optimizer=tf.keras.optimizers.Adam(lr=my_learning_rate),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        return model


    def train_model(model, train_features, train_label, epochs,
                    batch_size=None, validation_split=0.1):
        history = model.fit(x=train_features, y=train_label,
                            batch_size=batch_size, epochs=epochs,
                            shuffle=True, validation_split=validation_split)
        epochs = history.epoch
        hist = pd.DataFrame(history.history)
        return epochs, hist


    if __name__ == '__main__':
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        x_train_normalized = x_train / 255.0
        x_test_normalized = x_test / 255.0

        learning_rate = 0.003
        epochs = 50
        batch_size = 4000
        validation_split = 0.2

        my_model = create_model(learning_rate)
        epochs, hist = train_model(my_model, x_train_normalized, y_train,
                                   epochs, batch_size, validation_split)
        my_model.save('my_model')
The model is saved correctly to the "my_model" folder. Now I reload it in a Java program:
    import java.nio.file.Paths;
    import java.util.List;

    import org.tensorflow.SavedModelBundle;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;

    public class HelloTensorFlow {
        public static void main(final String[] args) {
            final String filePath = Paths.get("my_model").toAbsolutePath().toString();
            try (final SavedModelBundle b = SavedModelBundle.load(filePath, "serve")) {
                final Session sess = b.session();
                final Tensor<Float> x = Tensor.create(new float[1][28 * 28], Float.class);
                // This call is where "No Operation named [input] in the Graph" is raised.
                final List<Tensor<?>> run = sess.runner()
                        .feed("input", x)
                        .fetch("output")
                        .run();
                final float[] y = run.get(0).copyTo(new float[1]);
                System.out.println(y[0]);
            }
        }
    }
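The model loads successfully, but the runner does not work. When I execute the program I get the error "No Operation named [input] in the Graph", even though my input really is named "input". What am I doing wrong? I am using the latest TensorFlow versions: 2.3.0 (Python) and 1.15.0 (Java).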
Answer:
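I have solved the problem. TensorFlow 2 seems to use an odd naming scheme, so the names given to the Keras layers are not the names that end up in the graph, but the real names can be decoded with MetaGraphDef. First, you need the org.tensorflow.proto dependency. Then you can extract the information from the meta graph like this: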
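    import java.util.Objects;

    import org.tensorflow.framework.MetaGraphDef;
    import org.tensorflow.framework.SignatureDef;
    import org.tensorflow.framework.TensorInfo;

    // bundle is the loaded SavedModelBundle; parseFrom can throw InvalidProtocolBufferException.
    final MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(bundle.metaGraphDef());
    final SignatureDef signatureDef = metaGraphDef.getSignatureDefMap().get("serving_default");

    // Take the first (and here only) input of the serving signature.
    final TensorInfo inputTensorInfo = signatureDef.getInputsMap()
            .values()
            .stream()
            .filter(Objects::nonNull)
            .findFirst()
            .orElseThrow(() -> ...);

    // Take the first (and here only) output of the serving signature.
    final TensorInfo outputTensorInfo = signatureDef.getOutputsMap()
            .values()
            .stream()
            .filter(Objects::nonNull)
            .findFirst()
            .orElseThrow(() -> ...);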
Now you can feed the tensor you created to the name returned by inputTensorInfo.getName() and fetch the result from the name returned by outputTensorInfo.getName().
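The names stored in a TensorInfo are tensor names of the form operation_name:output_index, not the bare layer names set in Keras, which is why feeding "input" fails. The following is only a sketch of how such a name breaks down. TensorName is a hypothetical helper, not part of the TensorFlow API, and the sample names in main are assumptions about what a Keras export typically produces; print the values from your own model to see the real ones.

```java
import java.util.Objects;

/** Hypothetical helper (not a TensorFlow class) that splits "operation:index" names. */
public final class TensorName {
    private final String operation;
    private final int outputIndex;

    private TensorName(final String operation, final int outputIndex) {
        this.operation = operation;
        this.outputIndex = outputIndex;
    }

    /** Parses "op:index"; a name without a colon is treated as output 0 of that operation. */
    public static TensorName parse(final String name) {
        Objects.requireNonNull(name, "name");
        final int colon = name.lastIndexOf(':');
        if (colon < 0) {
            return new TensorName(name, 0);
        }
        final String op = name.substring(0, colon);
        final int index = Integer.parseInt(name.substring(colon + 1));
        return new TensorName(op, index);
    }

    public String operation() {
        return operation;
    }

    public int outputIndex() {
        return outputIndex;
    }

    public static void main(final String[] args) {
        // Assumed sample names; the actual values come from TensorInfo.getName().
        final String[] samples = {"serving_default_input:0", "StatefulPartitionedCall:0", "input"};
        for (final String raw : samples) {
            final TensorName t = parse(raw);
            System.out.println(raw + " -> op=" + t.operation() + ", index=" + t.outputIndex());
        }
    }
}
```

With the 1.15 Java API there is no need to split the name yourself, since Session.Runner.feed and fetch accept the full operation_name:output_index string, so sess.runner().feed(inputTensorInfo.getName(), x).fetch(outputTensorInfo.getName()).run() should work. Note also that the model above declares an input of shape (28, 28) and a 10-unit output, so the fed tensor most likely needs to be a float[1][28][28] and the result copied into a float[1][10].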