I am developing an Android application for age and gender recognition. I found a useful model on GitHub. The authors built a Keras model (TensorFlow backend) based on a paper that won first place. They provide the Python modules to train and build the network, the trained weight files for download, and a working webcam demo.
I want to convert the model and weights they provide in the demo into a .pb file so that it can also run on Android.
I used the following code, with a few model-specific modifications, to do the conversion:
    from keras.models import Sequential
    from keras.models import model_from_json
    from keras import backend as K
    import tensorflow as tf
    from tensorflow.python.tools import freeze_graph
    import os

    # Load the existing model.
    with open("model.json", 'r') as f:
        modelJSON = f.read()
    model = model_from_json(modelJSON)
    model.load_weights("weights.18-4.06.hdf5")
    print(model.summary())

    # All new operations will be in test mode from now on.
    K.set_learning_phase(0)

    # Serialize the model and get its weights, for quick re-building.
    config = model.get_config()
    weights = model.get_weights()

    # Re-build a model where the learning phase is now hard-coded to 0.
    #new_model = model.from_config(config)
    #new_model.set_weights(weights)

    temp_dir = "graph"
    checkpoint_prefix = os.path.join(temp_dir, "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Temporarily save the graph to disk, without the weights.
    saver = tf.train.Saver()
    checkpoint_path = saver.save(K.get_session(), checkpoint_prefix,
                                 global_step=0,
                                 latest_filename=checkpoint_state_name)
    tf.train.write_graph(K.get_session().graph, temp_dir, input_graph_name)

    input_graph_path = os.path.join(temp_dir, input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "dense_1/Softmax,dense_2/Softmax"  # model-dependent
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(temp_dir, output_graph_name)
    clear_devices = False

    # Embed the weights inside the graph and save it to disk.
    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path,
                              output_node_names, restore_op_name,
                              filename_tensor_name, output_graph_path,
                              clear_devices, "")
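Since the correct output_node_names depend on how Keras happened to name the layers, one way to double-check them is to dump the node names of the frozen file and confirm that dense_1/Softmax and dense_2/Softmax actually exist. A minimal sketch, assuming the output_graph.pb produced above:

    import tensorflow as tf

    graph_def = tf.GraphDef()
    with tf.gfile.GFile("graph/output_graph.pb", "rb") as f:
        graph_def.ParseFromString(f.read())

    # Print the input node and any softmax-related nodes
    # to verify the names passed to freeze_graph.
    for node in graph_def.node:
        if "Softmax" in node.name or node.name == "input_1":
            print(node.name)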
I generated the model.json file directly from the demo. The main function of demo.py, which also writes model.json, is shown below:
    def main():
        args = get_args()
        depth = args.depth
        k = args.width
        weight_file = args.weight_file

        if not weight_file:
            weight_file = get_file("weights.18-4.06.hdf5", pretrained_model,
                                   cache_subdir="pretrained_models",
                                   file_hash=modhash,
                                   cache_dir=os.path.dirname(os.path.abspath(__file__)))

        # For face detection
        detector = dlib.get_frontal_face_detector()

        # Load model and weights
        img_size = 64
        model = WideResNet(img_size, depth=depth, k=k)()
        model.load_weights(weight_file)
        print(model.summary())

        # Write the model to JSON
        model_json = model.to_json()
        with open("model.json", "w") as json_file:
            json_file.write(model_json)

        for img in yield_images():
            input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_h, img_w, _ = np.shape(input_img)

            # Detect faces using the dlib detector
            detected = detector(input_img, 1)
            faces = np.empty((len(detected), img_size, img_size, 3))

            if len(detected) > 0:
                for i, d in enumerate(detected):
                    x1, y1, x2, y2, w, h = d.left(), d.top(), d.right() + 1, d.bottom() + 1, d.width(), d.height()
                    xw1 = max(int(x1 - 0.4 * w), 0)
                    yw1 = max(int(y1 - 0.4 * h), 0)
                    xw2 = min(int(x2 + 0.4 * w), img_w - 1)
                    yw2 = min(int(y2 + 0.4 * h), img_h - 1)
                    cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
                    # cv2.rectangle(img, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2)
                    faces[i, :, :, :] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1, :], (img_size, img_size))

                # Predict ages and genders of the detected faces
                results = model.predict(faces)
                predicted_genders = results[0]
                ages = np.arange(0, 101).reshape(101, 1)
                predicted_ages = results[1].dot(ages).flatten()

                # Draw results
                for i, d in enumerate(detected):
                    label = "{}, {}".format(int(predicted_ages[i]),
                                            "F" if predicted_genders[i][0] > 0.5 else "M")
                    draw_label(img, (d.left(), d.top()), label)

            cv2.imshow("result", img)
            key = cv2.waitKey(30)

            if key == 27:
                break


    if __name__ == '__main__':
        main()
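As an aside on how the age head is decoded: results[1] is a softmax over the 101 age bins 0 to 100, and the dot product with np.arange(0, 101) is the expected value of that distribution. A tiny illustration with made-up numbers:

    import numpy as np

    # Hypothetical softmax output concentrated around age 30.
    probs = np.zeros((1, 101))
    probs[0, 29:32] = [0.25, 0.5, 0.25]

    ages = np.arange(0, 101).reshape(101, 1)
    print(probs.dot(ages).flatten())  # [30.] -- the expected age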
The code ran successfully and produced several checkpoint files along with a .pb file.
Here is the model's graph summary:
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to
    ==================================================================================================
    input_1 (InputLayer)            (None, 64, 64, 3)    0
    __________________________________________________________________________________________________
    conv2d_1 (Conv2D)               (None, 64, 64, 16)   432         input_1[0][0]
    __________________________________________________________________________________________________
    batch_normalization_1 (BatchNor (None, 64, 64, 16)   64          conv2d_1[0][0]
    __________________________________________________________________________________________________
    activation_1 (Activation)       (None, 64, 64, 16)   0           batch_normalization_1[0][0]
    __________________________________________________________________________________________________
    conv2d_2 (Conv2D)               (None, 64, 64, 128)  18432       activation_1[0][0]
    __________________________________________________________________________________________________
    batch_normalization_2 (BatchNor (None, 64, 64, 128)  512         conv2d_2[0][0]
    __________________________________________________________________________________________________
    activation_2 (Activation)       (None, 64, 64, 128)  0           batch_normalization_2[0][0]
    __________________________________________________________________________________________________
    conv2d_3 (Conv2D)               (None, 64, 64, 128)  147456      activation_2[0][0]
    __________________________________________________________________________________________________
    conv2d_4 (Conv2D)               (None, 64, 64, 128)  2048        activation_1[0][0]
    __________________________________________________________________________________________________
    add_1 (Add)                     (None, 64, 64, 128)  0           conv2d_3[0][0]
                                                                     conv2d_4[0][0]
    __________________________________________________________________________________________________
    batch_normalization_3 (BatchNor (None, 64, 64, 128)  512         add_1[0][0]
    __________________________________________________________________________________________________
    activation_3 (Activation)       (None, 64, 64, 128)  0           batch_normalization_3[0][0]
    __________________________________________________________________________________________________
    conv2d_5 (Conv2D)               (None, 64, 64, 128)  147456      activation_3[0][0]
    __________________________________________________________________________________________________
    batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_5[0][0]
    __________________________________________________________________________________________________
    activation_4 (Activation)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]
    __________________________________________________________________________________________________
    conv2d_6 (Conv2D)               (None, 64, 64, 128)  147456      activation_4[0][0]
    __________________________________________________________________________________________________
    add_2 (Add)                     (None, 64, 64, 128)  0           conv2d_6[0][0]
                                                                     add_1[0][0]
    __________________________________________________________________________________________________
    batch_normalization_5 (BatchNor (None, 64, 64, 128)  512         add_2[0][0]
    __________________________________________________________________________________________________
    activation_5 (Activation)       (None, 64, 64, 128)  0           batch_normalization_5[0][0]
    __________________________________________________________________________________________________
    conv2d_7 (Conv2D)               (None, 32, 32, 256)  294912      activation_5[0][0]
    __________________________________________________________________________________________________
    batch_normalization_6 (BatchNor (None, 32, 32, 256)  1024        conv2d_7[0][0]
    __________________________________________________________________________________________________
    activation_6 (Activation)       (None, 32, 32, 256)  0           batch_normalization_6[0][0]
    __________________________________________________________________________________________________
    conv2d_8 (Conv2D)               (None, 32, 32, 256)  589824      activation_6[0][0]
    __________________________________________________________________________________________________
    conv2d_9 (Conv2D)               (None, 32, 32, 256)  32768       activation_5[0][0]
    __________________________________________________________________________________________________
    add_3 (Add)                     (None, 32, 32, 256)  0           conv2d_8[0][0]
                                                                     conv2d_9[0][0]
    __________________________________________________________________________________________________
    batch_normalization_7 (BatchNor (None, 32, 32, 256)  1024        add_3[0][0]
    __________________________________________________________________________________________________
    activation_7 (Activation)       (None, 32, 32, 256)  0           batch_normalization_7[0][0]
    __________________________________________________________________________________________________
    conv2d_10 (Conv2D)              (None, 32, 32, 256)  589824      activation_7[0][0]
    __________________________________________________________________________________________________
    batch_normalization_8 (BatchNor (None, 32, 32, 256)  1024        conv2d_10[0][0]
    __________________________________________________________________________________________________
    activation_8 (Activation)       (None, 32, 32, 256)  0           batch_normalization_8[0][0]
    __________________________________________________________________________________________________
    conv2d_11 (Conv2D)              (None, 32, 32, 256)  589824      activation_8[0][0]
    __________________________________________________________________________________________________
    add_4 (Add)                     (None, 32, 32, 256)  0           conv2d_11[0][0]
                                                                     add_3[0][0]
    __________________________________________________________________________________________________
    batch_normalization_9 (BatchNor (None, 32, 32, 256)  1024        add_4[0][0]
    __________________________________________________________________________________________________
    activation_9 (Activation)       (None, 32, 32, 256)  0           batch_normalization_9[0][0]
    __________________________________________________________________________________________________
    conv2d_12 (Conv2D)              (None, 16, 16, 512)  1179648     activation_9[0][0]
    __________________________________________________________________________________________________
    batch_normalization_10 (BatchNo (None, 16, 16, 512)  2048        conv2d_12[0][0]
    __________________________________________________________________________________________________
    activation_10 (Activation)      (None, 16, 16, 512)  0           batch_normalization_10[0][0]
    __________________________________________________________________________________________________
    conv2d_13 (Conv2D)              (None, 16, 16, 512)  2359296     activation_10[0][0]
    __________________________________________________________________________________________________
    conv2d_14 (Conv2D)              (None, 16, 16, 512)  131072      activation_9[0][0]
    __________________________________________________________________________________________________
    add_5 (Add)                     (None, 16, 16, 512)  0           conv2d_13[0][0]
                                                                     conv2d_14[0][0]
    __________________________________________________________________________________________________
    batch_normalization_11 (BatchNo (None, 16, 16, 512)  2048        add_5[0][0]
    __________________________________________________________________________________________________
    activation_11 (Activation)      (None, 16, 16, 512)  0           batch_normalization_11[0][0]
    __________________________________________________________________________________________________
    conv2d_15 (Conv2D)              (None, 16, 16, 512)  2359296     activation_11[0][0]
    __________________________________________________________________________________________________
    batch_normalization_12 (BatchNo (None, 16, 16, 512)  2048        conv2d_15[0][0]
    __________________________________________________________________________________________________
    activation_12 (Activation)      (None, 16, 16, 512)  0           batch_normalization_12[0][0]
    __________________________________________________________________________________________________
    conv2d_16 (Conv2D)              (None, 16, 16, 512)  2359296     activation_12[0][0]
    __________________________________________________________________________________________________
    add_6 (Add)                     (None, 16, 16, 512)  0           conv2d_16[0][0]
                                                                     add_5[0][0]
    __________________________________________________________________________________________________
    batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        add_6[0][0]
    __________________________________________________________________________________________________
    activation_13 (Activation)      (None, 16, 16, 512)  0           batch_normalization_13[0][0]
    __________________________________________________________________________________________________
    average_pooling2d_1 (AveragePoo (None, 16, 16, 512)  0           activation_13[0][0]
    __________________________________________________________________________________________________
    flatten_1 (Flatten)             (None, 131072)       0           average_pooling2d_1[0][0]
    __________________________________________________________________________________________________
    dense_1 (Dense)                 (None, 2)            262144      flatten_1[0][0]
    __________________________________________________________________________________________________
    dense_2 (Dense)                 (None, 101)          13238272    flatten_1[0][0]
    ==================================================================================================
    Total params: 24,463,856
    Trainable params: 24,456,656
    Non-trainable params: 7,200
    __________________________________________________________________________________________________
I then took the output graph and optimized it for inference with the following script:
python -m tensorflow.python.tools.optimize_for_inference --input output_graph.pb --output g.pb --input_names=input_1 --output_names=dense_1/Softmax,dense_2/Softmax
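Before shipping either file to Android, a quick desktop smoke test can separate "the graph is broken" from "the inputs are fed wrong". A minimal sketch, assuming the optimized file g.pb and the tensor names from the summary above:

    import numpy as np
    import tensorflow as tf

    # Load the optimized graph and try a forward pass on random data.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile("g.pb", "rb") as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

    with tf.Session(graph=graph) as sess:
        dummy = np.random.rand(1, 64, 64, 3).astype(np.float32)
        genders, ages = sess.run(["dense_1/Softmax:0", "dense_2/Softmax:0"],
                                 feed_dict={"input_1:0": dummy})
        print(genders.shape, ages.shape)  # expect (1, 2) and (1, 101)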
While optimize_for_inference was running, the terminal printed many warnings like these:
    FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
      from ._conv import register_converters as _register_converters
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (16,), for node batch_normalization_1/FusedBatchNorm
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (128,), for node batch_normalization_2/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_3/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (128,), for node batch_normalization_4/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_5/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (256,), for node batch_normalization_6/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_7/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (256,), for node batch_normalization_8/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_9/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (512,), for node batch_normalization_10/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_11/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (512,), for node batch_normalization_12/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_13/FusedBatchNorm'
These warnings look serious!
I tried using both files in my Android app. The optimized file does not work at all, while the unoptimized file executes but produces meaningless results (i.e., random guesses).
I know this question is a bit long, but it sums up a full day of work and I didn't want to leave out any detail.
I don't know where the problem lies: the output node names, freezing the graph, instantiating the model with the weights, or the optimize-for-inference script.
Answer:
After some research, the random-guess problem was finally solved. The problem was not in converting the model to a .pb file, as I had initially suspected, but in feeding the image to the model correctly on Android.
I redid the model conversion. The following points summarize my work.
- First, I took the model from the demo.py mentioned in the question above and saved it with the following code:
    # Save the model as an .h5 file.
    model.save('./saved_model/model.h5')
- Second, I converted the generated .h5 file into a .pb file using the code from this repository. In case you cannot reach it through the hyperlink, the link is https://github.com/amir-abdi/keras_to_tensorflow. The code in this repository proved reliable: it converts the model to a .pb file and optimizes it for inference in one shot (a sketch of what it does under the hood follows below). Fantastic!
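For reference, the core of what that tool does is freeze the Keras session graph with convert_variables_to_constants. A minimal sketch of the same idea, assuming the .h5 file from the previous step (the tool's actual command-line flags vary between versions, so consult its README):

    from keras import backend as K
    from keras.models import load_model
    from tensorflow.python.framework import graph_io, graph_util

    K.set_learning_phase(0)  # inference mode, before the graph is built
    model = load_model('./saved_model/model.h5')

    # Freeze: replace variables with constants and keep only
    # the subgraph needed to compute the outputs.
    output_names = [out.op.name for out in model.outputs]
    sess = K.get_session()
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_names)
    graph_io.write_graph(frozen, './saved_model', 'model.pb', as_text=False)
    print("Output nodes:", output_names)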
- Third, I placed the generated .pb file in my Android app's assets folder in order to set it up inside the application.
- Fourth, I converted the target image to pixel values and extracted the colors by bit shifting, using the code below. Keep in mind that getPixels returns packed ARGB ints, so the channels stay in RGB order; if you need to reverse the channel order (demo.py feeds the network OpenCV-style BGR crops), follow the code below. I got help from this answer.
    Bitmap bitmap = createScaledBitmap(faces[0], INPUT_SIZE, INPUT_SIZE, true);

    // Get the pixel values as packed ARGB ints.
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0,
                     bitmap.getWidth(), bitmap.getHeight());

    for (int i = 0; i < intValues.length; ++i) {
        final int val = intValues[i];

        // Extract the colors by bit shifting (RGB order).
        floatValues[i * 3 + 0] = ((val >> 16) & 0xFF);
        floatValues[i * 3 + 1] = ((val >> 8) & 0xFF);
        floatValues[i * 3 + 2] = (val & 0xFF);

        // Reverse the color order to BGR. Note this overwrites the RGB
        // values above; keep only one of the two variants, depending on
        // which channel order your model expects.
        floatValues[i * 3 + 2] = Color.red(val);
        floatValues[i * 3 + 1] = Color.green(val);
        floatValues[i * 3] = Color.blue(val);
    }
- Finally, I could feed the image to the model using TensorFlow's inference methods, run the prediction, and output the results. A desktop sanity check of this whole pipeline is sketched below.
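To make sure the Android feeding code matches what the network expects, it helps to reproduce the same preprocessing on the desktop and compare against demo.py's predictions. A minimal sketch, assuming the frozen file is named model.pb and the tensor names from the question's graph (keras_to_tensorflow may rename output nodes, e.g. to output_node0, so verify the names first with a node listing as shown earlier):

    import cv2
    import numpy as np
    import tensorflow as tf

    # Load the frozen graph (assumed name: model.pb).
    graph_def = tf.GraphDef()
    with tf.gfile.GFile("model.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

    # cv2 loads images in BGR order, which matches what demo.py feeds
    # the network and what the Android code above reconstructs.
    face = cv2.resize(cv2.imread("face.jpg"), (64, 64)).astype(np.float32)

    with tf.Session(graph=graph) as sess:
        genders, ages = sess.run(["dense_1/Softmax:0", "dense_2/Softmax:0"],
                                 feed_dict={"input_1:0": face[np.newaxis]})
        print("P(female):", genders[0][0],
              "expected age:", ages[0].dot(np.arange(0, 101)))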