TensorFlow: 保存和恢复模型

我在尝试在训练结束时保存模型,并在每次训练开始时恢复它。我按照这个链接的指示操作。

    saver = tf.train.Saver()    with tf.Session(graph=graph) as session:        # 初始化权重和偏置        tf.global_variables_initializer().run()        new_saver = tf.train.import_meta_graph('model.meta')        new_saver.restore(sess,tf.train.latest_checkpoint('./'))        W1 = session.run(W)        print(W1)        for curr_epoch in range(num_epochs):            train_cost = train_ler = 0            start = time.time()            for batch in range(num_batches_per_epoch):                ...一些训练代码...        W2 = session.run(W)        print(W2)        save_path = saver.save(session, "models/model")

但它出现了以下错误:

--->  new_saver.restore(session, tf.train.latest_checkpoint('./'))SystemError: <built-in function TF_Run> returned a result with an error set

谁能帮帮我吗?非常感谢!


回答:

如果你要使用 ./ 来加载模型,你必须确保你的控制台(你用来启动Python程序的)实际设置在那个目录(models/)。但在这种情况下,它会将你的新数据保存到一个新目录中。所以应该使用 ./models/ 来加载。

(另外,你不需要初始化变量,恢复操作会帮你完成这个步骤。)

Related Posts

关于k折交叉验证的直观问题

我在使用交叉验证检查预测能力时遇到了一些直观问题,我认…

调整numpy数组大小以使用sklearn的train_test_split函数?

我正在尝试使用sklearn中的test_train_…

如何转换二维张量和索引张量以便用于torch.nn.utils.rnn.pack_sequence

我有一组序列,格式如下: sequences = to…

模型预测值的含义是什么?

我在网上找到一个数字识别器的CNN模型并进行了训练,当…

锯齿张量作为LSTM的输入

了解锯齿张量以及如何在TensorFlow中使用它们。…

如何告诉SciKit的LinearRegression模型预测值不能小于零?

我有以下代码,尝试根据非价格基础特征来估值股票。 pr…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注