相同的Tensorflow模型在Android和Python上给出不同结果

我在Android应用程序中运行一个Tensorflow模型时,发现与在桌面Python上运行时相比,同一训练模型给出了不同的结果(错误的推断)。

该模型是一个简单的顺序CNN,用于识别字符,类似于这个车牌识别网络,但没有窗口处理,因为我的模型已经将字符裁剪到位。

我有以下几点:

  • 模型保存为protobuf (.pb) 文件 – 在Python/Linux + GPU上使用Keras建模和训练
  • 在另一台电脑上使用纯Tensorflow测试了推断,以确保Keras不是问题所在。在这里,结果如预期的那样。
  • Python和Android上使用的是Tensorflow 1.3.0。在Python上通过PIP安装,在Android上通过jcenter安装。
  • Android上的结果与预期结果不符。
  • 输入是一张129*45的RGB图像,因此是一个129*45*3的数组,输出是一个4*36的数组(表示4个字符,从0-9和a-z)。

我使用了这个代码将Keras模型保存为.pb文件。

Python代码,这按预期工作:

test_image = [ndimage.imread("test_image.png", mode="RGB").astype(float)/255]imTensor = np.asarray(test_image)def load_graph(model_file):  graph = tf.Graph()  graph_def = tf.GraphDef()  with open(model_file, "rb") as f:    graph_def.ParseFromString(f.read())  with graph.as_default():    tf.import_graph_def(graph_def)  return graphgraph=load_graph("model.pb")with tf.Session(graph=graph) as sess:    input_operation = graph.get_operation_by_name("import/conv2d_1_input")    output_operation = graph.get_operation_by_name("import/output_node0")    results = sess.run(output_operation.outputs[0],                  {input_operation.outputs[0]: imTensor})

Android代码,基于这个示例;这给出了看似随机的结果:

Bitmap bitmap;try {    InputStream stream = getAssets().open("test_image.png");    bitmap = BitmapFactory.decodeStream(stream);} catch (IOException e) {    e.printStackTrace();}inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), "model.pb");int[] intValues = new int[129*45];float[] floatValues = new float[129*45*3];String outputName = "output_node0";String[] outputNodes = new String[]{outputName};float[] outputs = new float[4*36];bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());for (int i = 0; i < intValues.length; ++i) {    final int val = intValues[i];    floatValues[i * 3 + 0] = ((val >> 16) & 0xFF) / 255;    floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255;    floatValues[i * 3 + 2] = (val & 0xFF) / 255;}inferenceInterface.feed("conv2d_1_input", floatValues, 1, 45, 129, 3);inferenceInterface.run(outputNodes, false);inferenceInterface.fetch(outputName, outputs);

任何帮助都将不胜感激!


回答:

一个问题出现在以下几行代码中:

    floatValues[i * 3 + 0] = ((val >> 16) & 0xFF) / 255;    floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255;    floatValues[i * 3 + 2] = (val & 0xFF) / 255;

在这里,RGB值被一个整数除,因此每次得到的结果都是整数(即每次都是0)。

此外,即使使用255.0进行除法得到0到1.0之间的浮点数,可能会产生问题,因为这些值在投影空间(0..1)中的分布与自然状态下不同。为了解释这一点:在传感器域中一个值为255(例如R值)意味着测量信号的自然值落在“255”这个范围内,这个范围包含了一系列的能量/强度等。将这个值映射到1.0可能会切掉它的一半范围,因为后续的计算可能会在最大乘数1.0处饱和,而1.0实际上只是+-1/256范围的中点。因此,也许更正确的转换是将值映射到0..1范围内的256个桶的中点:

((val & 0xff) / 256.0) + (0.5/256.0)

但这只是我的猜测。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注