在Android上使用Tensorflow时遇到错误

我正在学习Tensorflow,并按照教程成功创建了一个自定义模型以在Android应用中运行,但遇到了问题。我的代码如下:

    public void testModel(Context ctx) {        String model_file = "file:///android_asset/model_graph.pb";        int[] result = new int[2];        float[] input = new float[]{0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F};        TensorFlowInferenceInterface inferenceInterface;        inferenceInterface = new TensorFlowInferenceInterface(ctx.getAssets(), model_file);        inferenceInterface.feed("input", input, 68);        inferenceInterface.run(new String[]{"output"});        inferenceInterface.fetch("output", result);        Log.v(TAG, Arrays.toString(result));    }

当应用尝试运行 inferenceInterface.run(new String[]{"output"}) 方法时,我遇到了以下错误:

java.lang.IllegalArgumentException: In[0] is not a matrix[[Node: MatMul = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_0_0, W1)]]

我不认为是我创建的模型出了问题,因为我之前在Python代码中使用它时结果是正面的。


回答:

从错误信息(In[0] is not a matrix)来看,似乎您的模型需要输入一个矩阵(即一个二维张量),而您提供的是一个包含68个元素的一维张量(向量)。

特别是,TensorFlowInferenceInterface.feed 方法中的 dims 参数似乎在以下行中设置不正确:

inferenceInterface.feed("input", input, 68);

应该改为类似于:

inferenceInterface.feed("input", input, 68, 1);

如果您的模型期望一个68×1的矩阵(或者如果期望34×2的矩阵则使用34, 2,期望17×4的矩阵则使用17, 4等)

希望这对您有帮助。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注