我在colab上运行tensorflow 2.4。我尝试使用tf.train.Checkpoint()
保存模型,因为它包括了模型子类化,但在恢复后,我发现模型的权重没有被恢复。
以下是一些代码片段:
### 来自tensorflow教程 nmt_with_attentionclass Encoder(tf.keras.Model): def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz): ... self.gru = tf.keras.layers.GRU(self.enc_units, return_sequences=True, return_state=True, recurrent_initializer='glorot_uniform')...class NMT_Train(tf.keras.Model): def __init__(self, inp_vocab_size, tar_vocab_size, max_length_inp, max_length_tar, emb_dims, units, batch_size, source_tokenizer, target_tokenizer): super(NMT_Train, self).__init__() self.encoder = Encoder(inp_vocab_size, emb_dims, units, batch_size) ......model = NMT_Train(INP_VOCAB, TAR_VOCAB, MAXLEN, MAXLEN, EMB_DIMS, UNITS, BATCH_SIZE, english_tokenizer, hindi_tokenizer)model.compile(optimizer = tf.keras.optimizers.Adam(), loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits= True))model.fit(dataset, epochs=2)checkpoint = tf.train.Checkpoint(model = model)manager = tf.train.CheckpointManager(checkpoint, './ckpts', max_to_keep=1)manager.save()model.encoder.gru.get_weights() ### 获取输出##[array([[-0.0627057 , 0.05900152, 0.06614069, ...model.optimizer.get_weights() ### 获取输出##[90, array([[ 6.6851695e-05, -4.6736805e-06, -2.3183979e-05, ...
当我后来恢复时,我没有得到任何gru权重:
model = NMT_Train(INP_VOCAB, TAR_VOCAB, MAXLEN, MAXLEN, EMB_DIMS, UNITS, BATCH_SIZE, english_tokenizer, hindi_tokenizer)model.compile(optimizer = tf.keras.optimizers.Adam(), loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits= True))checkpoint = tf.train.Checkpoint(model = model)manager = tf.train.CheckpointManager(checkpoint, './ckpts', max_to_keep=1)manager.restore_or_initialize()model.encoder.gru.get_weights() ### 空列表## []model.optimizer.get_weights() ### 空列表## []
我也尝试了checkpoint.restore(manager.latest_checkpoint)
,但没有任何变化。
我做错了什么吗?或者建议其他保存模型的方法,以便我可以进一步训练更多的轮次。
回答:
您正在定义一个keras模型,为什么不使用keras模型检查点呢?
来自Keras文档:
model.compile(loss=..., optimizer=..., metrics=['accuracy'])EPOCHS = 10checkpoint_filepath = '/tmp/checkpoint'model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( filepath=checkpoint_filepath, save_weights_only=True, monitor='val_accuracy', mode='max', save_best_only=True)# 模型权重在每个epoch结束时保存,如果这是到目前为止看到的最佳结果# 所以模型权重(被认为是最佳的)将被加载到模型中。model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])# 模型权重(被认为是最佳的)将被加载到模型中。model.load_weights(checkpoint_filepath)