I want to customize my model's fit function so that gradient descent is applied to the weights only when the model's predictions on the validation data improve. My reason for doing this is to prevent overfitting.
According to this guide, it should be possible to customize a model's fit function. However, the following code raises an error when run:
class CustomModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        ### Check and apply the gradients
        Y_pred_val = self.predict(X_val)  # This line does not work
        acc_val = calculate_accuracy(Y_val, Y_pred_val)
        if acc_val > last_acc_val:
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        ###

        self.compiled_metrics.update_state(y, y_pred)
        return_obj = {m.name: m.result() for m in self.metrics}
        return_obj["acc_val"] = acc_val
        return return_obj
How can I evaluate the model inside the fit function?
Answer:
You don't need to subclass fit() for this; you can write a custom training loop instead. Here's how I did it:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from collections import deque

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Concatenate
from tensorflow.keras.regularizers import l1, l2, l1_l2

dataset, info = tfds.load('mnist', with_info=True, split='train', as_supervised=False)

TAKE = 1_000

data = dataset.map(lambda x: (tf.cast(x['image'], tf.float32), x['label'])).shuffle(TAKE).take(TAKE)

len_train = int(8e-1 * TAKE)

train = data.take(len_train).batch(8)
test = data.skip(len_train).take(info.splits['train'].num_examples - len_train).batch(8)


class CNN(Model):
    def __init__(self):
        super(CNN, self).__init__()
        # Two parallel branches: a dense branch (layer1/layer7) and a conv
        # branch (layer2-layer5/layer8), concatenated before the output layer.
        self.layer1 = Dense(32, activation=tf.nn.relu, kernel_regularizer=l1(1e-2),
                            input_shape=info.features['image'].shape)
        self.layer2 = Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1),
                             activation='relu', input_shape=info.features['image'].shape)
        self.layer3 = MaxPooling2D(pool_size=(2, 2))
        self.layer4 = Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1),
                             activation=tf.nn.elu,
                             kernel_initializer=tf.keras.initializers.glorot_normal)
        self.layer5 = MaxPooling2D(pool_size=(2, 2))
        self.layer6 = Flatten()
        self.layer7 = Dense(units=64, activation=tf.nn.relu, kernel_regularizer=l2(1e-2))
        self.layer8 = Dense(units=64, activation=tf.nn.relu,
                            kernel_regularizer=l1_l2(l1=1e-2, l2=1e-2))
        self.layer9 = Concatenate()
        self.layer10 = Dense(units=info.features['label'].num_classes)

    def call(self, inputs, training=None, **kwargs):
        b = self.layer1(inputs)
        a = self.layer2(inputs)
        a = self.layer3(a)
        a = self.layer4(a)
        a = self.layer5(a)
        a = self.layer6(a)
        a = self.layer8(a)
        b = self.layer7(b)
        b = self.layer6(b)
        x = self.layer9([a, b])
        x = self.layer10(x)
        return x


cnn = CNN()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

optimizer = tf.keras.optimizers.Nadam()

template = 'Epoch {:3} Train Loss {:7.4f} Test Loss {:7.4f} ' \
           'Train Acc {:6.2%} Test Acc {:6.2%} '

epochs = 5
early_stop = epochs // 50  # patience in epochs; 0 disables early stopping

loss_hist = deque()
acc_hist = deque(maxlen=1)  # holds only the last accuracy that triggered an update
acc_hist.append(0)

for epoch in range(1, epochs + 1):
    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for images, labels in train:
        with tf.GradientTape() as tape:
            logits = cnn(images, training=True)
            loss = loss_object(labels, logits)

        train_loss(loss)
        train_acc(labels, logits)

        # Accuracy of the current batch only, computed with a fresh metric object.
        current_acc = tf.metrics.SparseCategoricalAccuracy()(labels, logits)

        # Apply the gradients only if accuracy improved.
        if tf.greater(current_acc, acc_hist[-1]):
            print('IMPROVEMENT.')
            gradients = tape.gradient(loss, cnn.trainable_variables)
            optimizer.apply_gradients(zip(gradients, cnn.trainable_variables))
            acc_hist.append(current_acc)

    for images, labels in test:
        logits = cnn(images, training=False)
        loss = loss_object(labels, logits)
        test_loss(loss)
        test_acc(labels, logits)

    print(template.format(epoch,
                          train_loss.result(),
                          test_loss.result(),
                          train_acc.result(),
                          test_acc.result()))

    # Record the epoch's validation loss for the sliding-window check below.
    loss_hist.append(test_loss.result())

    if early_stop and len(loss_hist) > early_stop and loss_hist.popleft() < min(loss_hist):
        print('Early stopping. No validation loss decrease in %i epochs.' % early_stop)
        break
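Note that this loop compares each batch's training accuracy against the last accuracy that triggered an update, not against validation accuracy, so adapt the condition to whatever improvement criterion you actually want.

If you would rather stay with fit() and a custom train_step as in your question, the usual stumbling block is the self.predict(X_val) call: predict() starts its own inference loop and cannot be used inside train_step. Calling the model directly with self(X_val, training=False) works instead. Below is a minimal sketch of that idea, under some assumptions I'm making: X_val and Y_val are tensors available in the enclosing scope, the task is sparse-categorical, and the model is compiled with run_eagerly=True so plain Python control flow is allowed inside train_step.

import tensorflow as tf

class CustomModel(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_acc_val = 0.0  # best validation accuracy seen so far

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)

        # Direct forward pass instead of self.predict(), which is not
        # supported inside train_step. X_val/Y_val are assumed globals.
        val_logits = self(X_val, training=False)
        acc_val = float(tf.reduce_mean(
            tf.keras.metrics.sparse_categorical_accuracy(Y_val, val_logits)))

        # Plain Python `if` works here because of compile(run_eagerly=True).
        if acc_val > self.last_acc_val:
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            self.last_acc_val = acc_val

        self.compiled_metrics.update_state(y, y_pred)
        results = {m.name: m.result() for m in self.metrics}
        results["acc_val"] = acc_val
        return results

# Hypothetical usage:
# model = CustomModel(inputs, outputs)
# model.compile(optimizer='nadam',
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#               metrics=['accuracy'],
#               run_eagerly=True)  # required for the Python-level branching above

Bear in mind that run_eagerly=True disables graph compilation, and evaluating the full validation set on every batch is expensive; scoring a small held-out subset, or rewriting the branch with tf.cond and a tf.Variable for the best accuracy, are the graph-friendly alternatives.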