Exporting a Keras model to a .pb file and optimizing it for inference gives random guesses on Android

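I'm developing an Android app for age and gender recognition. I found a useful model on GitHub. It is a Keras model (TensorFlow backend) built from a paper that won first place. The repository provides Python modules for training and building the network, trained weight files you can download and use, and a working webcam demo.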

I want to convert the model from their demo, together with its weights, into a .pb file so that it can also run on Android.

I used this code for the conversion, with a few small changes to fit the model:

    from keras.models import Sequential
    from keras.models import model_from_json
    from keras import backend as K
    import tensorflow as tf
    from tensorflow.python.tools import freeze_graph
    import os

    # Load the existing model.
    with open("model.json", 'r') as f:
        modelJSON = f.read()
    model = model_from_json(modelJSON)
    model.load_weights("weights.18-4.06.hdf5")
    print(model.summary())

    # All new operations will be in test mode from now on.
    K.set_learning_phase(0)

    # Serialize the model and get its weights, for quick re-building.
    config = model.get_config()
    weights = model.get_weights()

    # Re-build a model where the learning phase is now hard-coded to 0.
    #new_model = model.from_config(config)
    #new_model.set_weights(weights)

    temp_dir = "graph"
    checkpoint_prefix = os.path.join(temp_dir, "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Temporarily save the graph to disk, without the weights.
    saver = tf.train.Saver()
    checkpoint_path = saver.save(K.get_session(), checkpoint_prefix, global_step=0,
                                 latest_filename=checkpoint_state_name)
    tf.train.write_graph(K.get_session().graph, temp_dir, input_graph_name)

    input_graph_path = os.path.join(temp_dir, input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "dense_1/Softmax,dense_2/Softmax"  # model dependent
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(temp_dir, output_graph_name)
    clear_devices = False

    # Embed the weights in the graph and save it to disk.
    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path,
                              output_node_names, restore_op_name,
                              filename_tensor_name, output_graph_path,
                              clear_devices, "")
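To make the role of `output_node_names` in the freezing step concrete: conceptually, freezing keeps only the nodes that the listed output nodes depend on, and replaces each variable with a constant holding its checkpoint value. The sketch below models that on a toy dictionary graph. The graph format, the `freeze` helper and the `EXAMPLE_GRAPH` nodes are hypothetical illustrations, not TensorFlow's API.

```python
from typing import Dict, List, Tuple

# Hypothetical toy graph format (illustration only, not TensorFlow's GraphDef):
# node name -> (op type, names of the node's inputs)
Graph = Dict[str, Tuple[str, List[str]]]


def freeze(graph: Graph, checkpoint: Dict[str, float],
           output_nodes: List[str]) -> Tuple[Graph, Dict[str, float]]:
    """Keep only nodes the outputs depend on and turn variables into constants."""
    # Walk backwards from the requested outputs to collect every node they need.
    needed = set()
    stack = list(output_nodes)
    while stack:
        name = stack.pop()
        if name in needed:
            continue
        needed.add(name)
        stack.extend(graph[name][1])

    frozen: Graph = {}
    constants: Dict[str, float] = {}
    for name in needed:
        op, inputs = graph[name]
        if op == "VariableV2":
            # The value now comes from the checkpoint, baked in as a constant.
            frozen[name] = ("Const", [])
            constants[name] = checkpoint[name]
        else:
            frozen[name] = (op, list(inputs))
    return frozen, constants


# A tiny hypothetical graph shaped like one output head of the model.
EXAMPLE_GRAPH: Graph = {
    "input_1": ("Placeholder", []),
    "dense_1/kernel": ("VariableV2", []),
    "dense_1/MatMul": ("MatMul", ["input_1", "dense_1/kernel"]),
    "dense_1/Softmax": ("Softmax", ["dense_1/MatMul"]),
    "save/restore_all": ("NoOp", ["dense_1/kernel"]),
}

if __name__ == "__main__":
    frozen, constants = freeze(EXAMPLE_GRAPH, {"dense_1/kernel": 0.5}, ["dense_1/Softmax"])
    print(sorted(frozen))
```

Nodes that the listed outputs do not need, such as the saver's restore op here, are dropped. A misspelled output name would raise a `KeyError` in this toy version.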

I generated the model.json file directly from the demo. Here is the main function of demo.py, including the part that writes model.json:

    def main():
        args = get_args()
        depth = args.depth
        k = args.width
        weight_file = args.weight_file

        if not weight_file:
            weight_file = get_file("weights.18-4.06.hdf5", pretrained_model, cache_subdir="pretrained_models",
                                   file_hash=modhash, cache_dir=os.path.dirname(os.path.abspath(__file__)))

        # for face detection
        detector = dlib.get_frontal_face_detector()

        # load model and weights
        img_size = 64
        model = WideResNet(img_size, depth=depth, k=k)()
        model.load_weights(weight_file)
        print(model.summary())

        # write the model to JSON
        model_json = model.to_json()
        with open("model.json", "w") as json_file:
            json_file.write(model_json)

        for img in yield_images():
            input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_h, img_w, _ = np.shape(input_img)

            # detect faces using the dlib detector
            detected = detector(input_img, 1)
            faces = np.empty((len(detected), img_size, img_size, 3))

            if len(detected) > 0:
                for i, d in enumerate(detected):
                    x1, y1, x2, y2, w, h = d.left(), d.top(), d.right() + 1, d.bottom() + 1, d.width(), d.height()
                    xw1 = max(int(x1 - 0.4 * w), 0)
                    yw1 = max(int(y1 - 0.4 * h), 0)
                    xw2 = min(int(x2 + 0.4 * w), img_w - 1)
                    yw2 = min(int(y2 + 0.4 * h), img_h - 1)
                    cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
                    # cv2.rectangle(img, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2)
                    faces[i, :, :, :] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1, :], (img_size, img_size))

                # predict ages and genders of the detected faces
                results = model.predict(faces)
                predicted_genders = results[0]
                ages = np.arange(0, 101).reshape(101, 1)
                predicted_ages = results[1].dot(ages).flatten()

                # draw results
                for i, d in enumerate(detected):
                    label = "{}, {}".format(int(predicted_ages[i]),
                                            "F" if predicted_genders[i][0] > 0.5 else "M")
                    draw_label(img, (d.left(), d.top()), label)

            cv2.imshow("result", img)
            key = cv2.waitKey(30)

            if key == 27:
                break


    if __name__ == '__main__':
        main()
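Any port of this demo to another platform has to reproduce two steps of `main()` outside the network itself. The first is the 40% margin added around each dlib face box before cropping. The second is the decoding of the two softmax outputs: age is the expected value over the 101 age bins, and gender is "F" when the first `dense_1` output exceeds 0.5. The sketch below restates that arithmetic in plain Python. The function names `expand_box` and `decode` are hypothetical.

```python
from typing import Sequence, Tuple


def expand_box(x1: int, y1: int, x2: int, y2: int, img_w: int, img_h: int,
               margin: float = 0.4) -> Tuple[int, int, int, int]:
    """Grow a face box by `margin` of its size on every side, clamped to the image.

    Mirrors the xw1/yw1/xw2/yw2 arithmetic in demo.py. x2 and y2 are exclusive
    (right() + 1, bottom() + 1), so the box width is x2 - x1.
    """
    w, h = x2 - x1, y2 - y1
    xw1 = max(int(x1 - margin * w), 0)
    yw1 = max(int(y1 - margin * h), 0)
    xw2 = min(int(x2 + margin * w), img_w - 1)
    yw2 = min(int(y2 + margin * h), img_h - 1)
    # demo.py then crops img[yw1:yw2 + 1, xw1:xw2 + 1] and resizes it to 64x64.
    return xw1, yw1, xw2, yw2


def decode(gender_probs: Sequence[float], age_probs: Sequence[float]) -> Tuple[int, str]:
    """Turn the two softmax outputs into (age, gender) the way demo.py does."""
    # dense_2 has 101 outputs, one per age 0..100; the age is their expected value.
    age = sum(i * p for i, p in enumerate(age_probs))
    # dense_1 has 2 outputs; index 0 is read as the "F" probability.
    gender = "F" if gender_probs[0] > 0.5 else "M"
    return int(age), gender
```

For example, a 100x100 box at (100, 100) in a 640x480 frame becomes the crop (60, 60) to (240, 240).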

The code ran successfully and produced several checkpoint files and a .pb file.

Here is the model summary:

    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to
    ==================================================================================================
    input_1 (InputLayer)            (None, 64, 64, 3)    0
    conv2d_1 (Conv2D)               (None, 64, 64, 16)   432         input_1[0][0]
    batch_normalization_1 (BatchNor (None, 64, 64, 16)   64          conv2d_1[0][0]
    activation_1 (Activation)       (None, 64, 64, 16)   0           batch_normalization_1[0][0]
    conv2d_2 (Conv2D)               (None, 64, 64, 128)  18432       activation_1[0][0]
    batch_normalization_2 (BatchNor (None, 64, 64, 128)  512         conv2d_2[0][0]
    activation_2 (Activation)       (None, 64, 64, 128)  0           batch_normalization_2[0][0]
    conv2d_3 (Conv2D)               (None, 64, 64, 128)  147456      activation_2[0][0]
    conv2d_4 (Conv2D)               (None, 64, 64, 128)  2048        activation_1[0][0]
    add_1 (Add)                     (None, 64, 64, 128)  0           conv2d_3[0][0]
                                                                     conv2d_4[0][0]
    batch_normalization_3 (BatchNor (None, 64, 64, 128)  512         add_1[0][0]
    activation_3 (Activation)       (None, 64, 64, 128)  0           batch_normalization_3[0][0]
    conv2d_5 (Conv2D)               (None, 64, 64, 128)  147456      activation_3[0][0]
    batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_5[0][0]
    activation_4 (Activation)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]
    conv2d_6 (Conv2D)               (None, 64, 64, 128)  147456      activation_4[0][0]
    add_2 (Add)                     (None, 64, 64, 128)  0           conv2d_6[0][0]
                                                                     add_1[0][0]
    batch_normalization_5 (BatchNor (None, 64, 64, 128)  512         add_2[0][0]
    activation_5 (Activation)       (None, 64, 64, 128)  0           batch_normalization_5[0][0]
    conv2d_7 (Conv2D)               (None, 32, 32, 256)  294912      activation_5[0][0]
    batch_normalization_6 (BatchNor (None, 32, 32, 256)  1024        conv2d_7[0][0]
    activation_6 (Activation)       (None, 32, 32, 256)  0           batch_normalization_6[0][0]
    conv2d_8 (Conv2D)               (None, 32, 32, 256)  589824      activation_6[0][0]
    conv2d_9 (Conv2D)               (None, 32, 32, 256)  32768       activation_5[0][0]
    add_3 (Add)                     (None, 32, 32, 256)  0           conv2d_8[0][0]
                                                                     conv2d_9[0][0]
    batch_normalization_7 (BatchNor (None, 32, 32, 256)  1024        add_3[0][0]
    activation_7 (Activation)       (None, 32, 32, 256)  0           batch_normalization_7[0][0]
    conv2d_10 (Conv2D)              (None, 32, 32, 256)  589824      activation_7[0][0]
    batch_normalization_8 (BatchNor (None, 32, 32, 256)  1024        conv2d_10[0][0]
    activation_8 (Activation)       (None, 32, 32, 256)  0           batch_normalization_8[0][0]
    conv2d_11 (Conv2D)              (None, 32, 32, 256)  589824      activation_8[0][0]
    add_4 (Add)                     (None, 32, 32, 256)  0           conv2d_11[0][0]
                                                                     add_3[0][0]
    batch_normalization_9 (BatchNor (None, 32, 32, 256)  1024        add_4[0][0]
    activation_9 (Activation)       (None, 32, 32, 256)  0           batch_normalization_9[0][0]
    conv2d_12 (Conv2D)              (None, 16, 16, 512)  1179648     activation_9[0][0]
    batch_normalization_10 (BatchNo (None, 16, 16, 512)  2048        conv2d_12[0][0]
    activation_10 (Activation)      (None, 16, 16, 512)  0           batch_normalization_10[0][0]
    conv2d_13 (Conv2D)              (None, 16, 16, 512)  2359296     activation_10[0][0]
    conv2d_14 (Conv2D)              (None, 16, 16, 512)  131072      activation_9[0][0]
    add_5 (Add)                     (None, 16, 16, 512)  0           conv2d_13[0][0]
                                                                     conv2d_14[0][0]
    batch_normalization_11 (BatchNo (None, 16, 16, 512)  2048        add_5[0][0]
    activation_11 (Activation)      (None, 16, 16, 512)  0           batch_normalization_11[0][0]
    conv2d_15 (Conv2D)              (None, 16, 16, 512)  2359296     activation_11[0][0]
    batch_normalization_12 (BatchNo (None, 16, 16, 512)  2048        conv2d_15[0][0]
    activation_12 (Activation)      (None, 16, 16, 512)  0           batch_normalization_12[0][0]
    conv2d_16 (Conv2D)              (None, 16, 16, 512)  2359296     activation_12[0][0]
    add_6 (Add)                     (None, 16, 16, 512)  0           conv2d_16[0][0]
                                                                     add_5[0][0]
    batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        add_6[0][0]
    activation_13 (Activation)      (None, 16, 16, 512)  0           batch_normalization_13[0][0]
    average_pooling2d_1 (AveragePoo (None, 16, 16, 512)  0           activation_13[0][0]
    flatten_1 (Flatten)             (None, 131072)       0           average_pooling2d_1[0][0]
    dense_1 (Dense)                 (None, 2)            262144      flatten_1[0][0]
    dense_2 (Dense)                 (None, 101)          13238272    flatten_1[0][0]
    ==================================================================================================
    Total params: 24,463,856
    Trainable params: 24,456,656
    Non-trainable params: 7,200
    __________________________________________________________________________________________________
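The parameter counts in the summary can be checked by hand. They are consistent with convolution and dense layers that have no bias terms (for example, 432 = 3·3·3·16 for `conv2d_1`). The 7,200 non-trainable parameters are exactly the batch-norm moving means and variances, two per channel over 3,600 channels. Those moving statistics are what an inference-mode graph has to carry. The sketch below recomputes the totals. The `(kernel, in, out)` triples are read off the summary, with the kernel sizes inferred from the per-layer counts.

```python
# (kernel size, input channels, output channels) for conv2d_1 ... conv2d_16,
# inferred from the per-layer parameter counts in the summary.
CONVS = [
    (3, 3, 16),
    (3, 16, 128), (3, 128, 128), (1, 16, 128), (3, 128, 128), (3, 128, 128),
    (3, 128, 256), (3, 256, 256), (1, 128, 256), (3, 256, 256), (3, 256, 256),
    (3, 256, 512), (3, 512, 512), (1, 256, 512), (3, 512, 512), (3, 512, 512),
]
BN_CHANNELS = [16] + [128] * 4 + [256] * 4 + [512] * 4  # batch_normalization_1 ... _13
FLATTEN = 16 * 16 * 512                                 # flatten_1 output size
DENSE_UNITS = [2, 101]                                  # dense_1 (gender), dense_2 (age)


def conv_params(k: int, c_in: int, c_out: int) -> int:
    # Conv2D without bias: kernel_h * kernel_w * in_channels * out_channels.
    return k * k * c_in * c_out


conv_total = sum(conv_params(*c) for c in CONVS)
# Each BN channel has gamma, beta (trainable) plus moving mean and variance (non-trainable).
bn_total = sum(4 * c for c in BN_CHANNELS)
non_trainable = sum(2 * c for c in BN_CHANNELS)
dense_total = sum(FLATTEN * units for units in DENSE_UNITS)  # no bias either
total = conv_total + bn_total + dense_total

print(f"Total params: {total:,}")                            # prints Total params: 24,463,856
print(f"Trainable params: {total - non_trainable:,}")        # prints Trainable params: 24,456,656
print(f"Non-trainable params: {non_trainable:,}")            # prints Non-trainable params: 7,200
```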

I took the resulting model and optimized it for inference with the following command:

    python -m tensorflow.python.tools.optimize_for_inference --input output_graph.pb --output g.pb --input_names=input_1 --output_names=dense_1/Softmax,dense_2/Softmax

While it ran, the terminal printed many warnings like these:

    FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
      from ._conv import register_converters as _register_converters
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (16,), for node batch_normalization_1/FusedBatchNorm
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (128,), for node batch_normalization_2/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_3/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (128,), for node batch_normalization_4/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_5/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (256,), for node batch_normalization_6/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_7/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (256,), for node batch_normalization_8/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_9/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (512,), for node batch_normalization_10/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_11/FusedBatchNorm'
    WARNING:tensorflow:Incorrect shape for mean, found (0,), expected (512,), for node batch_normalization_12/FusedBatchNorm
    WARNING:tensorflow:Didn't find expected Conv2D input to 'batch_normalization_13/FusedBatchNorm'

These warnings look serious!
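Both kinds of warning come from the batch-norm folding step of `optimize_for_inference` (`fold_batch_norms` in `optimize_for_inference_lib`). That step tries to merge each `FusedBatchNorm` into the `Conv2D` that feeds it. This requires two things. First, the batch norm's input must be a `Conv2D`. In this WideResNet, `batch_normalization_3`, `5`, `7`, `9`, `11` and `13` are fed by `add_*` layers, which matches the "Didn't find expected Conv2D input" lines exactly. Second, the mean and variance must be constants of the expected shape, so an empty `(0,)` mean leaves nothing to fold. The plain-Python sketch below shows the arithmetic of the fold for one output channel of a bias-free convolution. The helper names and the `eps=1e-3` value are assumptions for illustration.

```python
import math
from typing import List, Sequence, Tuple


def batch_norm(x: float, mean: float, var: float, gamma: float, beta: float,
               eps: float = 1e-3) -> float:
    """Inference-mode batch norm for one channel, using stored moving statistics."""
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


def conv_channel(weights: Sequence[float], inputs: Sequence[float], bias: float = 0.0) -> float:
    """One output value of a convolution: a dot product over its receptive field."""
    return sum(w * x for w, x in zip(weights, inputs)) + bias


def fold_into_conv(weights: Sequence[float], mean: float, var: float, gamma: float,
                   beta: float, eps: float = 1e-3) -> Tuple[List[float], float]:
    """Rewrite BN(conv(x)) as one conv: scale the weights and add a new bias.

    gamma * (w.x - mean) / s + beta == (gamma / s) * w.x + (beta - mean * gamma / s)
    where s = sqrt(var + eps). Assumes the conv has no bias, as in this model.
    """
    scale = gamma / math.sqrt(var + eps)
    return [w * scale for w in weights], beta - mean * scale
```

Because the fold rewrites the weights of the preceding convolution, a batch norm that sits after an `Add` has no single set of weights to absorb it. The tool warns and leaves such nodes unfolded.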

I tried both files in my Android app. The optimized file does not work at all, while the unoptimized one runs but produces meaningless results, i.e. random guesses.

I know this question is a bit long, but it summarizes a full day of work and I didn't want to leave out any details.

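I don't know where the problem lies. Is it the output node names, freezing the graph, instantiating the model with its weights, or the optimize_for_inference script?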


Answer:

After some research, I finally solved the random-guess problem.

The problem was not the conversion of the model to a .pb file, as I had first suspected, but how the image was fed into the model on Android.

I redid the model conversion. The following points summarize what I did.

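  • First, I took the model from the demo.py mentioned in the question and saved it with the following code:

        import os

        # Save the model as an .h5 file (the target directory must exist first).
        os.makedirs('./saved_model', exist_ok=True)
        model.save('./saved_model/model.h5')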
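  • Second, I converted the generated .h5 file to a .pb file using the code from this repository (if the hyperlink doesn't work, the link is https://github.com/amir-abdi/keras_to_tensorflow). The code in this repository has proven reliable: it converts the model to a .pb file and optimizes it for inference in a single step. It works great!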

  • Third, I put the generated .pb file in the Android assets folder so that my app can load it.

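  • Fourth, I converted the target image into pixel values and extracted the color channels with bit shifts, using the code below. Keep in mind that getPixels returns each pixel as a packed ARGB int, so the channels come out in RGB order. If your model expects the channels reversed (BGR), write them in reverse order, as the code below does. I got help from this answer.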

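        // Scale the detected face to the model's input size.
        Bitmap bitmap = Bitmap.createScaledBitmap(faces[0], INPUT_SIZE, INPUT_SIZE, true);

        int[] intValues = new int[INPUT_SIZE * INPUT_SIZE];
        float[] floatValues = new float[INPUT_SIZE * INPUT_SIZE * 3];

        // Get the pixel values; each int is a packed ARGB color.
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            // Extract the colors with bit shifts and store them in reversed (BGR) order,
            // the same as Color.blue(val), Color.green(val), Color.red(val).
            // For RGB order, swap the values written to index 0 and index 2.
            floatValues[i * 3 + 0] = (val & 0xFF);         // blue
            floatValues[i * 3 + 1] = ((val >> 8) & 0xFF);  // green
            floatValues[i * 3 + 2] = ((val >> 16) & 0xFF); // red
        }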
  • Finally, I used TensorFlow's inference methods on Android to feed the image to the model, run inference, and read the results.
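Why the channel order matters here: in demo.py, dlib detects faces on `input_img` (the RGB copy), but the crops passed to `model.predict` are cut from `img`, the BGR frame OpenCV returns. The crops are also passed as raw 0 to 255 values without scaling. So in the demo the network sees unscaled BGR pixels, which the reversed Java loop in the fourth step reproduces. The Python sketch below mirrors that loop so the byte layout can be checked on a desktop. The function name is hypothetical, and it assumes packed ARGB ints as produced by `Bitmap.getPixels`.

```python
from typing import List, Sequence


def argb_to_bgr_floats(pixels: Sequence[int]) -> List[float]:
    """Unpack packed ARGB ints into a flat [B, G, R, B, G, R, ...] float list."""
    out: List[float] = []
    for val in pixels:
        # Same shifts and masks as the Java loop; alpha (bits 24-31) is ignored.
        # Works for Java's signed ints too, since & 0xFF keeps only the low byte.
        r = (val >> 16) & 0xFF
        g = (val >> 8) & 0xFF
        b = val & 0xFF
        # Reversed channel order, raw 0..255 values, no normalization.
        out.extend((float(b), float(g), float(r)))
    return out
```

For a 64x64 input this yields 64 * 64 * 3 = 12,288 floats, the same length as `floatValues` in the Java code.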
