tensorflow 仅保存初始化值

我试图保存一些变量并查看是否可以稍后恢复它们。以下是我的保存代码:

   import tensorflow as tf;   my_a = tf.Variable(2,name = "my_a");   my_b = tf.Variable(3,name = "my_b");   my_c = tf.Variable(4,name = "my_c");   my_c = tf.add(my_a,my_b);   with tf.Session() as sess:       init = tf.initialize_all_variables();       sess.run(init);       print("my_c =  ",sess.run(my_c));       saver = tf.train.Saver();       saver.save(sess,"test.ckpt");

这会输出:

    my_c =   5

当我恢复它时:

   import tensorflow as tf;   c = tf.Variable(3100,dtype = tf.int32);   with tf.Session() as sess:       sess.run(tf.initialize_all_variables());       saver = tf.train.Saver({"my_c":c});       saver.restore(sess, "test.ckpt");       cc= sess.run(c);       print(cc);

这会给我:

    4

恢复的 my_c 值应该是 5,因为它是 my_a 和 my_b 的和。然而,它给我的是 4,这是 my_c 的初始化值。能有人解释为什么会这样吗,以及如何保存变量的更改?


回答:

在你的原始代码中,你实际上并没有将名为 my_c 的变量(请注意,TensorFlow 的 name)赋值为 my_a + my_b

通过编写 my_c = tf.add(my_a,my_b),Python 变量 my_c 现在与具有 name='my_c'tf.Variable 不同。

当你执行 sess.run() 时,你只是在执行操作,并没有更新那个变量。

如果你想让这段代码正确运行,请使用以下代码 – (查看注释中的更改)

import tensorflow as tfmy_a = tf.Variable(2,name = "my_a")my_b = tf.Variable(3,name = "my_b")my_c = tf.Variable(4,name="my_c")# 使用 assign() 函数设置新值add = my_c.assign(tf.add(my_a,my_b))with tf.Session() as sess:    init = tf.initialize_all_variables()    sess.run(init)    # 执行 add 操作符    sess.run(add)    print("my_c =  ",sess.run(my_c))    saver = tf.train.Saver()    saver.save(sess,"test.ckpt")

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注