我在使用keras处理图像分类问题。我尝试使用子类化API
来完成几乎所有工作。我已经创建了自己的custom
卷积块,代码如下所示:
class ConvBlock(keras.layers.Layer): def __init__(self, in_features, kernel_size=(3, 3)): super(ConvBlock, self).__init__() self.conv = keras.layers.Conv2D(in_features, kernel_size, padding="same") self.bn = keras.layers.BatchNormalization() self.relu = keras.activations.relu def call(self, x, training=False): x = self.conv(x) x = self.bn(x, training=training) return self.relu(x)
之后,我创建了一个简单的Sequential
模型用于测试,代码如下:
seq_model = keras.Sequential([ ConvBlock(64), ConvBlock(128), ConvBlock(64), keras.layers.Flatten(), keras.layers.Dense(64, activation='relu'), keras.layers.Dense(128, activation='relu'), keras.layers.Dense(64, activation='relu'), keras.layers.Dense(5, activation='softmax'),], name="seq_model")seq_model.build((None, 96, 96, 1))seq_model.summary()
到目前为止一切顺利,如果我在这个seq_model
上调用.compile()
、.train()
和.evaluate()
,它都能正常工作。问题出现在我尝试使用自定义的.compile()
、.train()
和.evaluate()
时。以下是我创建它们的代码:
class Model(keras.Model): def __init__(self, model): super().__init__() self.model = model # .compile() def compile(self, loss, optimizer, metrics): super().compile() self.loss = loss self.optimizer = optimizer self.custom_metrics = metrics # .fit() def train_step(self, data): x, y = data with tf.GradientTape() as tape: pred = self.model(x, training=True) loss = self.loss(y, pred) gradients = tape.gradient(loss, self.trainable_variables) optimizer.apply_gradients(zip(gradients, self.trainable_variables)) self.custom_metrics.update_state(y, pred) return {"loss": loss, "accuracy": self.custom_metrics.result()} # .evaluate() def test_step(self, data): x, y = data pred = self.model(x, training=False) loss = self.loss(y, pred) self.custom_metrics.update_state(y, pred) return {"loss": loss, "accuracy": self.custom_metrics.result()}
这是我调用它的方式。
yoga_model = Model(seq_model)yoga_model.compile( loss = keras.losses.CategoricalCrossentropy(from_logits=False), optimizer = keras.optimizers.Adam(lr=0.001), metrics = keras.metrics.CategoricalAccuracy(name="acc"))yoga_model.fit(train_ds, epochs=1, verbose=1)
请帮助我。任何帮助都会被非常感激。
回答:
在您的自定义模型中使用子类化API时,请按以下方式实现call
方法:
from tensorflow import keras class Model(keras.Model): def __init__: self.model = model def train_step: def test_step: def compile: # 实现call方法 def call(self, inputs, *args, **kwargs): return self.model(inputs)