Tensorflow 2: 加载后无法跟踪子类模型属性的问题

以下是我在Tensorflow 2.5中实现的一个子类模型的代码:

from tensorflow.keras import Model, Inputfrom tensorflow.keras.applications import DenseNet201from tensorflow.keras.applications.densenet import preprocess_inputfrom tensorflow.keras.layers import Conv2D, Flatten, Densefrom tensorflow.random import uniformfrom tensorflow.keras.models import load_model class Detector(Model):        def __init__(self, num_classes=3, name="DenseNet201"):        super(Detector, self).__init__(name=name)        self.feature_extractor = DenseNet201(            include_top=False,            weights="imagenet",        )        self.feature_extractor.trainable = False        self.flatten_layer = Flatten()        self.prediction_layer = Dense(num_classes, activation=None)    def call(self, inputs):        x = preprocess_input(inputs)        self.extracted_feature = self.feature_extractor(x, training=False)        x = self.flatten_layer(self.extracted_feature)        x = self.prediction_layer(x)        return x

在测试我的代码时,我发现了一个让我很困惑的问题。

detector = Detector()print(detector.extracted_feature)

这会引发一个错误: AttributeError: ‘Detector’ 对象没有属性 ‘extracted_feature’,这是可以理解的,因为我从未调用过这个模型。在调用模型后,Detector 对象现在有了 extracted_feature 属性。因此,以下代码将不会报错:

image_tensor_1 = uniform(shape=(1, 600, 600, 3))y_hat = detector(image_tensor_1)print(detector.extracted_feature.shape)

然而,在尝试通过运行 detector.save("my_model") 保存模型并将模型加载回一个新变量 new_detector = load_model("my_model") 后,运行以下代码时我得到了一个错误:

image_tensor_2 = uniform(shape=(1, 600, 600, 3))y_hat = new_detector(image_tensor_2)print(new_detector.extracted_feature.shape)

AttributeError: ‘Detector’ 对象没有属性 ‘extracted_feature’.

self.extracted_feature 是用来计算梯度的。我需要继续跟踪它以确保梯度不会是 None。我该怎么做才能访问 extracted_feature 属性呢?


回答:

你可以这样做

    def call(self, inputs):        x = preprocess_input(inputs)        extracted_feature = self.feature_extractor(x, training=False)        x = self.flatten_layer(extracted_feature)        x = self.prediction_layer(x)        return extracted_feature, x

检查

image_tensor_1 = uniform(shape=(1, 32, 32, 3))detector = Detector()ex_feat, y_hat = detector(image_tensor_1)print(ex_feat.shape)(1, 1, 1, 512)

保存并重新加载。

detector.save("my_model")new_detector = load_model("my_model")image_tensor_2 = uniform(shape=(1, 32, 32, 3))ex_feat, y_hat = new_detector(image_tensor_2)print(ex_feat.shape)(1, 1, 1, 512)

仅供参考,如果你想从基础模型中获取中间层的输出,那么你可能需要在 __init__ 方法中以这种方式初始化你的基础模型。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注