我正在学习Tensorflow,并按照教程成功创建了一个自定义模型以在Android应用中运行,但遇到了问题。我的代码如下:
public void testModel(Context ctx) { String model_file = "file:///android_asset/model_graph.pb"; int[] result = new int[2]; float[] input = new float[]{0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F}; TensorFlowInferenceInterface inferenceInterface; inferenceInterface = new TensorFlowInferenceInterface(ctx.getAssets(), model_file); inferenceInterface.feed("input", input, 68); inferenceInterface.run(new String[]{"output"}); inferenceInterface.fetch("output", result); Log.v(TAG, Arrays.toString(result)); }
当应用尝试运行 inferenceInterface.run(new String[]{"output"})
方法时,我遇到了以下错误:
java.lang.IllegalArgumentException: In[0] is not a matrix[[Node: MatMul = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_0_0, W1)]]
我不认为是我创建的模型出了问题,因为我之前在Python代码中使用它时结果是正面的。
回答:
从错误信息(In[0] is not a matrix
)来看,似乎您的模型需要输入一个矩阵(即一个二维张量),而您提供的是一个包含68个元素的一维张量(向量)。
特别是,TensorFlowInferenceInterface.feed
方法中的 dims
参数似乎在以下行中设置不正确:
inferenceInterface.feed("input", input, 68);
应该改为类似于:
inferenceInterface.feed("input", input, 68, 1);
如果您的模型期望一个68×1的矩阵(或者如果期望34×2的矩阵则使用34, 2
,期望17×4的矩阵则使用17, 4
等)
希望这对您有帮助。