I am trying to convert my TensorFlow (.pb) model to the Keras (.h5) format so that I can view post-hoc attention visualizations. I tried the following code.
file_pb = "/test.pb"
file_h5 = "/test.h5"
loaded_model = tf.keras.models.load_model(file_pb)
tf.keras.models.save_keras_model(loaded_model, file_h5)
loaded_model_from_h5 = tf.keras.models.load_model(file_h5)
Can anyone help me with this? Is it possible?
Answer:
In the latest TensorFlow version (2.2), when we save a model with tf.keras.models.save_model, the model is not saved as a single pb file. Instead, it is saved into a folder that contains a variables folder and an assets folder in addition to the saved_model.pb file, as shown below:
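For reference, listing the SavedModel directory from Python shows this layout (a minimal sketch; the directory name 'Model' is a placeholder for whatever name you passed to save_model):

import os

# List the contents of a SavedModel directory ('Model' is a placeholder name)
print(os.listdir('Model'))
# Expected output (typical SavedModel layout):
# ['assets', 'variables', 'saved_model.pb']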
For example, if the model is saved under the name "Model", we need to load it using the folder name "Model" rather than saved_model.pb, as shown below:
loaded_model = tf.keras.models.load_model('Model')
instead of
loaded_model = tf.keras.models.load_model('saved_model.pb')
The other change you should make is to replace tf.keras.models.save_keras_model with tf.keras.models.save_model.
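Applied to the code in the question, that gives the following sketch (it assumes the SavedModel is a directory, here called '/test_model' as a placeholder, rather than a bare .pb file):

import tensorflow as tf

file_pb = "/test_model"  # path to the SavedModel directory (placeholder name)
file_h5 = "/test.h5"

loaded_model = tf.keras.models.load_model(file_pb)          # load the SavedModel
tf.keras.models.save_model(loaded_model, file_h5)           # save in H5 format
loaded_model_from_h5 = tf.keras.models.load_model(file_h5)  # reload from H5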
The complete working code to convert a model from the TensorFlow SavedModel format (pb) to the Keras saved model format (h5) is shown below:
import os
import tensorflow as tf
from tensorflow.keras.preprocessing import image

New_Model = tf.keras.models.load_model('Dogs_Vs_Cats_Model')  # Load the TensorFlow SavedModel (PB)
print(New_Model.summary())
The output of print(New_Model.summary()) is shown below:
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 148, 148, 32)      896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 74, 74, 32)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 72, 72, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 36, 36, 64)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 34, 34, 128)       73856
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 17, 17, 128)       0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 15, 15, 128)       147584
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 7, 7, 128)         0
_________________________________________________________________
flatten (Flatten)            (None, 6272)               0
_________________________________________________________________
dense (Dense)                (None, 512)                3211776
_________________________________________________________________
dense_1 (Dense)              (None, 1)                  513
=================================================================
Total params: 3,453,121
Trainable params: 3,453,121
Non-trainable params: 0
_________________________________________________________________
None
Continuing with the code:
# Save the model in H5 format and load it back (to check that it matches the PB model)
tf.keras.models.save_model(New_Model, 'New_Model.h5')              # save the model in H5 format
loaded_model_from_h5 = tf.keras.models.load_model('New_Model.h5')  # load the H5 model
print(loaded_model_from_h5.summary())
The output of print(loaded_model_from_h5.summary()) is shown below:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 148, 148, 32)      896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 74, 74, 32)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 72, 72, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 36, 36, 64)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 34, 34, 128)       73856
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 17, 17, 128)       0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 15, 15, 128)       147584
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 7, 7, 128)         0
_________________________________________________________________
flatten (Flatten)            (None, 6272)               0
_________________________________________________________________
dense (Dense)                (None, 512)                3211776
_________________________________________________________________
dense_1 (Dense)              (None, 1)                  513
=================================================================
Total params: 3,453,121
Trainable params: 3,453,121
Non-trainable params: 0
_________________________________________________________________
From the summaries of the two models above, we can see that the two models are identical.
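If you want a stronger check than comparing the printed summaries, you can also verify that the weights match numerically (a minimal sketch that reuses the New_Model and loaded_model_from_h5 objects from the code above):

import numpy as np

# Compare every weight tensor of the SavedModel and the reloaded H5 model
for w_pb, w_h5 in zip(New_Model.get_weights(), loaded_model_from_h5.get_weights()):
    assert np.allclose(w_pb, w_h5), "Weight mismatch between the PB and H5 models"
print("All weights match")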