我在Android应用程序中运行一个Tensorflow模型时,发现与在桌面Python上运行时相比,同一训练模型给出了不同的结果(错误的推断)。
该模型是一个简单的顺序CNN,用于识别字符,类似于这个车牌识别网络,但没有窗口处理,因为我的模型已经将字符裁剪到位。
我有以下几点:
- 模型保存为protobuf (.pb) 文件 – 在Python/Linux + GPU上使用Keras建模和训练
- 在另一台电脑上使用纯Tensorflow测试了推断,以确保Keras不是问题所在。在这里,结果如预期的那样。
- Python和Android上使用的是Tensorflow 1.3.0。在Python上通过PIP安装,在Android上通过jcenter安装。
- Android上的结果与预期结果不符。
- 输入是一张129*45的RGB图像,因此是一个129*45*3的数组,输出是一个4*36的数组(表示4个字符,从0-9和a-z)。
我使用了这个代码将Keras模型保存为.pb文件。
Python代码,这按预期工作:
test_image = [ndimage.imread("test_image.png", mode="RGB").astype(float)/255]imTensor = np.asarray(test_image)def load_graph(model_file): graph = tf.Graph() graph_def = tf.GraphDef() with open(model_file, "rb") as f: graph_def.ParseFromString(f.read()) with graph.as_default(): tf.import_graph_def(graph_def) return graphgraph=load_graph("model.pb")with tf.Session(graph=graph) as sess: input_operation = graph.get_operation_by_name("import/conv2d_1_input") output_operation = graph.get_operation_by_name("import/output_node0") results = sess.run(output_operation.outputs[0], {input_operation.outputs[0]: imTensor})
Android代码,基于这个示例;这给出了看似随机的结果:
Bitmap bitmap;try { InputStream stream = getAssets().open("test_image.png"); bitmap = BitmapFactory.decodeStream(stream);} catch (IOException e) { e.printStackTrace();}inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), "model.pb");int[] intValues = new int[129*45];float[] floatValues = new float[129*45*3];String outputName = "output_node0";String[] outputNodes = new String[]{outputName};float[] outputs = new float[4*36];bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());for (int i = 0; i < intValues.length; ++i) { final int val = intValues[i]; floatValues[i * 3 + 0] = ((val >> 16) & 0xFF) / 255; floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255; floatValues[i * 3 + 2] = (val & 0xFF) / 255;}inferenceInterface.feed("conv2d_1_input", floatValues, 1, 45, 129, 3);inferenceInterface.run(outputNodes, false);inferenceInterface.fetch(outputName, outputs);
任何帮助都将不胜感激!
回答:
一个问题出现在以下几行代码中:
floatValues[i * 3 + 0] = ((val >> 16) & 0xFF) / 255; floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255; floatValues[i * 3 + 2] = (val & 0xFF) / 255;
在这里,RGB值被一个整数除,因此每次得到的结果都是整数(即每次都是0)。
此外,即使使用255.0
进行除法得到0到1.0之间的浮点数,可能会产生问题,因为这些值在投影空间(0..1)中的分布与自然状态下不同。为了解释这一点:在传感器域中一个值为255(例如R值)意味着测量信号的自然值落在“255”这个范围内,这个范围包含了一系列的能量/强度等。将这个值映射到1.0可能会切掉它的一半范围,因为后续的计算可能会在最大乘数1.0处饱和,而1.0实际上只是+-1/256范围的中点。因此,也许更正确的转换是将值映射到0..1范围内的256个桶的中点:
((val & 0xff) / 256.0) + (0.5/256.0)
但这只是我的猜测。