如何创建自定义模型类

如果这是一个简单的问题或我做错了什么,请原谅,因为我对Keras/TensorFlow和Python都还不太熟悉。我正在尝试基于迁移学习测试一些不同的图像分类模型。为此,我希望创建一个函数来构建模型,只需指定一些参数,它就能自动生成所需的模型。我已经编写了以下代码:

class modelMaker(tf.keras.Model):  def __init__(self, img_height, img_width, trained='None'):    super(modelMaker, self).__init__()    self.x = tf.keras.Input(shape=(img_height, img_width, 3),name="input_layer")    if (trained == 'None'):      pass    elif (trained == 'ResNet50'):      self.x = tf.keras.applications.resnet50.preprocess_input(self.x)      IMG_SHAPE = (img_height,img_width) + (3,)      base_model = tf.keras.applications.ResNet50(input_shape=IMG_SHAPE,                                                  include_top=False,                                                  weights='imagenet')      base_model.trainable = False      for layer in base_model.layers:        if isinstance(layer, keras.layers.BatchNormalization):          layer.trainable = True        else:          layer.trainable = False      self.x = base_model(self.x)  def call(self, inputs):    return self.x(inputs)

目前我只实现了ResNet50和一个空选项,但我计划添加更多。我尝试使用 self.x = LAYER(self.x)添加层的理由是因为模型可以根据未来的参数具有不同的层数。

然而,当我尝试使用model.summary()获取模型的摘要时,我得到了以下错误:

ValueError: 该模型尚未构建。首先通过调用build()或使用一些数据调用fit()来构建模型,或者在第一层或多层中指定input_shape参数以自动构建。

这样构建模型是可能的吗?感谢您的帮助


回答:

model.summary()需要一些关于输入形状和模型结构(层)的信息,以便为您打印这些信息。因此,您应该在某个地方将这些信息提供给model对象。

如果您使用顺序模型或函数式API,只需为运行model.summary()指定input_shape参数就足够了。如果您不指定input_shape,那么您可以调用您的模型或使用model.build来提供这些信息。

但是,当您使用子类化(就像您所做的那样)时,除非您调用call()函数(因为您在call函数中定义了层的结构并向其传递输入),否则此类对象没有关于形状和层的任何信息。

有三种方法可以调用call()函数:

  1. model.fit():在训练时调用
    • 可能不符合您的需求,因为您必须首先训练您的模型。
  2. model.build():内部调用
    • 只需像model.build((1,128,128,3))一样传递输入的形状
  3. model():直接调用
    • 您需要至少传递一个样本(张量),如model(tf.random.uniform((1,128,128,3))

修改后的代码应该如下所示:

class modelMaker(tf.keras.Model):    def __init__(self, img_height, img_width, num_classes=1, trained='dense'):        super(modelMaker, self).__init__()        self.trained = trained        self.IMG_SHAPE = (img_height,img_width) + (3,)        # 定义通用层        self.flat = tf.keras.layers.Flatten(name="flatten")        self.classify = tf.keras.layers.Dense(num_classes, name="classify")        # 定义当"trained" != "resnet"时的层        if self.trained == "dense":            self.dense = tf.keras.layers.Dense(128, name="dense128")                 # 当"trained" == "resnet"时的层        else:            self.pre_resnet = tf.keras.applications.resnet50.preprocess_input            self.base_model = tf.keras.applications.ResNet50(input_shape=self.IMG_SHAPE, include_top=False, weights='imagenet')            self.base_model.trainable = False            for layer in self.base_model.layers:                if isinstance(layer, tf.keras.layers.BatchNormalization):                    layer.trainable = True                else:                    layer.trainable = False        def call(self, inputs):        # 定义不含ResNet的模型         if self.trained == "dense":            x = self.flat(inputs)            x = self.dense(x)            x = self.classify(x)            return x        # 定义含ResNet的模型        else:            x = self.pre_resnet(inputs)            x = self.base_model(x)            x = self.flat(x)            x = self.classify(x)            return x            # 添加此函数以获取模型摘要的正确输出    def summary(self):        x = tf.keras.Input(shape=self.IMG_SHAPE, name="input_layer")        model = tf.keras.Model(inputs=[x], outputs=self.call(x))        return model.summary()    model = modelMaker(128, 128, trained="resnet") # 创建对象model.build((10,128,128,3))                    # 构建模型model.summary()                                # 打印摘要

输出为:

Model: "model_9"_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================input_layer (InputLayer)           [(None, 128, 128, 3)]     0         _________________________________________________________________tf.__operators__.getitem_6 ( (None, 128, 128, 3)       0         _________________________________________________________________tf.nn.bias_add_6 (TFOpLambda (None, 128, 128, 3)       0         _________________________________________________________________resnet50 (Functional)        (None, 4, 4, 2048)        23587712  _________________________________________________________________flatten (Flatten)            (None, 32768)             0         _________________________________________________________________classify (Dense)             (None, 1)                 32769     =================================================================Total params: 23,620,481Trainable params: 32,769Non-trainable params: 23,587,712_________________________________________________________________

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注