在Android上使用Tensorflow时遇到错误

我正在学习Tensorflow,并按照教程成功创建了一个自定义模型以在Android应用中运行,但遇到了问题。我的代码如下:

    public void testModel(Context ctx) {        String model_file = "file:///android_asset/model_graph.pb";        int[] result = new int[2];        float[] input = new float[]{0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 0.0F, 0.0F, 1.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 1.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F, 1.0F, 0.0F};        TensorFlowInferenceInterface inferenceInterface;        inferenceInterface = new TensorFlowInferenceInterface(ctx.getAssets(), model_file);        inferenceInterface.feed("input", input, 68);        inferenceInterface.run(new String[]{"output"});        inferenceInterface.fetch("output", result);        Log.v(TAG, Arrays.toString(result));    }

当应用尝试运行 inferenceInterface.run(new String[]{"output"}) 方法时,我遇到了以下错误:

java.lang.IllegalArgumentException: In[0] is not a matrix[[Node: MatMul = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_0_0, W1)]]

我不认为是我创建的模型出了问题,因为我之前在Python代码中使用它时结果是正面的。


回答:

从错误信息(In[0] is not a matrix)来看,似乎您的模型需要输入一个矩阵(即一个二维张量),而您提供的是一个包含68个元素的一维张量(向量)。

特别是,TensorFlowInferenceInterface.feed 方法中的 dims 参数似乎在以下行中设置不正确:

inferenceInterface.feed("input", input, 68);

应该改为类似于:

inferenceInterface.feed("input", input, 68, 1);

如果您的模型期望一个68×1的矩阵(或者如果期望34×2的矩阵则使用34, 2,期望17×4的矩阵则使用17, 4等)

希望这对您有帮助。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注