tf.train.Checkpoint是否正在恢复?

我在colab上运行tensorflow 2.4。我尝试使用tf.train.Checkpoint()保存模型,因为它包括了模型子类化,但在恢复后,我发现模型的权重没有被恢复。

以下是一些代码片段:

### 来自tensorflow教程 nmt_with_attentionclass Encoder(tf.keras.Model):  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):    ...    self.gru = tf.keras.layers.GRU(self.enc_units,                                   return_sequences=True,                                   return_state=True,                                   recurrent_initializer='glorot_uniform')...class NMT_Train(tf.keras.Model):  def __init__(self, inp_vocab_size, tar_vocab_size, max_length_inp, max_length_tar, emb_dims, units, batch_size, source_tokenizer, target_tokenizer):    super(NMT_Train, self).__init__()    self.encoder = Encoder(inp_vocab_size, emb_dims, units, batch_size)    ......model = NMT_Train(INP_VOCAB, TAR_VOCAB, MAXLEN, MAXLEN, EMB_DIMS, UNITS, BATCH_SIZE, english_tokenizer, hindi_tokenizer)model.compile(optimizer = tf.keras.optimizers.Adam(),              loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits= True))model.fit(dataset, epochs=2)checkpoint = tf.train.Checkpoint(model = model)manager = tf.train.CheckpointManager(checkpoint, './ckpts', max_to_keep=1)manager.save()model.encoder.gru.get_weights() ### 获取输出##[array([[-0.0627057 ,  0.05900152,  0.06614069, ...model.optimizer.get_weights() ### 获取输出##[90, array([[ 6.6851695e-05, -4.6736805e-06, -2.3183979e-05, ...

当我后来恢复时,我没有得到任何gru权重:

model = NMT_Train(INP_VOCAB, TAR_VOCAB, MAXLEN, MAXLEN, EMB_DIMS, UNITS, BATCH_SIZE, english_tokenizer, hindi_tokenizer)model.compile(optimizer = tf.keras.optimizers.Adam(),              loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits= True))checkpoint = tf.train.Checkpoint(model = model)manager = tf.train.CheckpointManager(checkpoint, './ckpts', max_to_keep=1)manager.restore_or_initialize()model.encoder.gru.get_weights() ### 空列表## []model.optimizer.get_weights() ### 空列表## []

我也尝试了checkpoint.restore(manager.latest_checkpoint),但没有任何变化。

我做错了什么吗?或者建议其他保存模型的方法,以便我可以进一步训练更多的轮次。


回答:

您正在定义一个keras模型,为什么不使用keras模型检查点呢?

来自Keras文档

model.compile(loss=..., optimizer=...,              metrics=['accuracy'])EPOCHS = 10checkpoint_filepath = '/tmp/checkpoint'model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(    filepath=checkpoint_filepath,    save_weights_only=True,    monitor='val_accuracy',    mode='max',    save_best_only=True)# 模型权重在每个epoch结束时保存,如果这是到目前为止看到的最佳结果# 所以模型权重(被认为是最佳的)将被加载到模型中。model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])# 模型权重(被认为是最佳的)将被加载到模型中。model.load_weights(checkpoint_filepath)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注