Below is the code for a subclassed model I implemented in TensorFlow 2.5:
    from tensorflow.keras import Model, Input
    from tensorflow.keras.applications import DenseNet201
    from tensorflow.keras.applications.densenet import preprocess_input
    from tensorflow.keras.layers import Conv2D, Flatten, Dense
    from tensorflow.random import uniform
    from tensorflow.keras.models import load_model


    class Detector(Model):
        def __init__(self, num_classes=3, name="DenseNet201"):
            super(Detector, self).__init__(name=name)
            # Frozen ImageNet backbone used purely as a feature extractor.
            self.feature_extractor = DenseNet201(
                include_top=False,
                weights="imagenet",
            )
            self.feature_extractor.trainable = False
            self.flatten_layer = Flatten()
            self.prediction_layer = Dense(num_classes, activation=None)

        def call(self, inputs):
            x = preprocess_input(inputs)
            # The attribute is assigned here, so it exists only after call() has run.
            self.extracted_feature = self.feature_extractor(x, training=False)
            x = self.flatten_layer(self.extracted_feature)
            x = self.prediction_layer(x)
            return x
While testing my code, I ran into a problem that confuses me.
    detector = Detector()
    print(detector.extracted_feature)
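This raises an error: AttributeError: 'Detector' object has no attribute 'extracted_feature'. That is understandable, because I have not called the model yet. Once the model has been called, the Detector object has the extracted_feature attribute, so the following code runs without error: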
    image_tensor_1 = uniform(shape=(1, 600, 600, 3))
    y_hat = detector(image_tensor_1)
    print(detector.extracted_feature.shape)
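The behaviour described above is ordinary Python attribute semantics rather than anything specific to Keras: an attribute assigned inside a method exists on the instance only after that method has run. The sketch below illustrates this with a hypothetical Tiny class (an assumption for illustration, not part of the question's code and not a Keras class), so it runs without TensorFlow.

```python
class Tiny:
    """Hypothetical stand-in for a model; it is not a Keras class."""

    def __call__(self, x):
        # The attribute is created here, at call time, not in __init__.
        self.cached = x * 2
        return self.cached + 1


tiny = Tiny()
before = hasattr(tiny, "cached")  # checked before the first call
result = tiny(3)
after = hasattr(tiny, "cached")   # checked after the first call
print(before, result, after)
```

Any new instance, however it is created, starts without the attribute until its own call body has executed.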
However, after saving the model with detector.save("my_model") and loading it back into a new variable with new_detector = load_model("my_model"), I get an error when I run the following code:
    image_tensor_2 = uniform(shape=(1, 600, 600, 3))
    y_hat = new_detector(image_tensor_2)
    print(new_detector.extracted_feature.shape)
    AttributeError: 'Detector' object has no attribute 'extracted_feature'
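self.extracted_feature is used to compute gradients. I need to keep track of it so that the gradients are not None. What can I do to access the extracted_feature attribute on the loaded model?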
Answer:
Instead of storing the features as an attribute, you can return them from call alongside the predictions:
        def call(self, inputs):
            x = preprocess_input(inputs)
            extracted_feature = self.feature_extractor(x, training=False)
            x = self.flatten_layer(extracted_feature)
            x = self.prediction_layer(x)
            # Return the features as a second output instead of caching them on self.
            return extracted_feature, x
Check:
    image_tensor_1 = uniform(shape=(1, 32, 32, 3))
    detector = Detector()
    ex_feat, y_hat = detector(image_tensor_1)
    print(ex_feat.shape)
    # (1, 1, 1, 512)
Save and reload:
    detector.save("my_model")
    new_detector = load_model("my_model")
    image_tensor_2 = uniform(shape=(1, 32, 32, 3))
    ex_feat, y_hat = new_detector(image_tensor_2)
    print(ex_feat.shape)
    # (1, 1, 1, 512)
FYI, if you want to get the outputs of intermediate layers from the base model, you may need to initialize your base model in the __init__ method in this way.
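One common pattern for exposing intermediate outputs, offered here as an assumption rather than as the answer's own snippet, is to wrap the base network in a functional Model that returns several layer outputs. In this sketch the name multi_output_extractor, the choice of the pool4_pool layer, the 64x64 input size, and weights=None (used only to avoid downloading the ImageNet weights) are all illustrative choices, not taken from the original post.

```python
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.applications import DenseNet201

# Assumption: weights=None keeps the sketch light; the question uses weights="imagenet".
base = DenseNet201(include_top=False, weights=None, input_shape=(64, 64, 3))

# Hypothetical choice of layers: one intermediate pooling layer plus the final output.
multi_output_extractor = Model(
    inputs=base.input,
    outputs=[base.get_layer("pool4_pool").output, base.output],
)
multi_output_extractor.trainable = False

images = tf.random.uniform(shape=(1, 64, 64, 3))
intermediate, final = multi_output_extractor(images, training=False)
print(intermediate.shape, final.shape)
```

Inside a subclassed model, an extractor built like this could take the place of self.feature_extractor, and call could return whichever of its outputs are needed together with the predictions.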