我有一个使用 TensorFlow 2.1.0 在 Python 3.7 中编写的模型,试图在 Java 应用程序(使用 TensorFlow 1.4)中使用它,但模型无法接受输入。我猜测这是一个兼容性问题,但模型在 Java 中成功加载。我尝试使用 keras.Sequential
和 keras.Model
,但似乎没有效果。我看到在 TF v1 中使用了 tf.placeholder
,但了解到 v2 的替代品是 tf.keras.Input
。
Python:
#方法1model = tf.keras.Sequential([ tf.keras.Input(name='input', shape=(60,), dtype=tf.dtypes.float32), tf.keras.layers.Flatten(), tf.keras.layers.Dense(30, activation='relu'), tf.keras.layers.Dense(10, activation='relu'), tf.keras.layers.Dense(3, activation='softmax', name='output')])
#方法2inputs = tf.keras.Input(name='input', shape=(60,), dtype=tf.dtypes.float32)outputs = tf.keras.layers.Dense(3, activation='softmax')(inputs)model = tf.keras.Model(inputs, outputs)
Java:
Session.Runner runner = session.runner();runner.feed("input", Tensor.create(testData)); List<Tensor<?>> tensors = runner.fetch("output").run();System.out.println("答案是: " + tensors.get(0).floatValue());
异常:
2020-05-07 01:32:23.596732: I tensorflow/cc/saved_model/loader.cc:311] 加载 SavedModel 用于标签 { serve }; 状态: 成功。耗时 50986 微秒。线程 "main" 中的异常 java.lang.IllegalArgumentException: 图形中没有名为 [input] 的操作 at org.tensorflow.Session$Runner.operationByName(Session.java:380) at org.tensorflow.Session$Runner.parseOutput(Session.java:389) at org.tensorflow.Session$Runner.feed(Session.java:131) at com.treyyoder.smurge.ml.TensorFlowTest.main(TensorFlowTest.java:40)
!!!!!!!!!!!!!!!!!!!!!!! 更新 !!!!!!!!!!!!!!!!!!!!!!!
根据 @*** 的建议,我包含了 org.tensorflow:proto
以便能够检查 MetaGraphDef
MetaGraphDef 大约有 15k 行,以下是关键部分:
node { name: "StatefulPartitionedCall" op: "StatefulPartitionedCall" input: "serving_default_input" input: "dense/kernel" input: "dense/bias" input: "dense_1/kernel" input: "dense_1/bias" input: "output/kernel" input: "output/bias" attr { key: "_gradient_op_type" value { s: "PartitionedCallUnused" } } attr { key: "f" value { func { name: "__inference_signature_wrapper_9526" } } } attr { key: "Tout" value { list { type: DT_FLOAT } } } attr { key: "config_proto" value { s: "\n\a\n\003CPU\020\001\n\a\n\003GPU\020\0012\005*\0010J\0008\001" } } attr { key: "_output_shapes" value { list { shape { dim { size: -1 } dim { size: 3 } } } } } attr { key: "Tin" value { list { type: DT_FLOAT type: DT_RESOURCE type: DT_RESOURCE type: DT_RESOURCE type: DT_RESOURCE type: DT_RESOURCE type: DT_RESOURCE } } } }...node { name: "serving_default_input" op: "Placeholder" attr { key: "shape" value { shape { dim { size: -1 } dim { size: 60 } } } } attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "_output_shapes" value { list { shape { dim { size: -1 } dim { size: 60 } } } } } }...signature_def { key: "serving_default" value { inputs { key: "input" value { name: "serving_default_input:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 60 } } } } outputs { key: "output" value { name: "StatefulPartitionedCall:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 3 } } } } method_name: "tensorflow/serving/predict" }}
我发现了正确的输入 serving_default_input
和输出 StatefulPartitionedCall
更新后的 Java 代码:
float[] fa = //传递给模型的数据List<Tensor<?>> tensor = runner.feed("serving_default_input", Tensor.create(fa)) .fetch("StatefulPartitionedCall").run();Tensor<Float> t1 = tensor.get(0).expect(Float.class);float[][] vector = t1.copyTo(new float[1][3]);for (float[] f : vector) { for (float ff : f) { System.out.println("结果: " + ff); }}
回答:
最佳方案是从模型签名中动态获取这些名称,并将它们提供给模型进行推理。
要在 Java 中查看保存模型的输入/输出列表,您可以从 SavedModelBundle
中检索 MetaGraphDef
,如这里所述:Tensorflow 2.0 & Java API。(您也可以使用 [saved_model_cli][1]
命令行工具进行双重检查)。
但请注意,TF2.x 在处理功能模型时存在一个错误,TF 在编码输入/输出签名时会进行一些未记录的名称混淆,如这里所描述。
此外,您可能想查看下一版 TF Java,它原生支持 TF2.x 版本,但目前仅以快照形式提供。