我最近开始使用TensorFlow(TF),遇到一个需要帮助的问题。基本上,我已经恢复了一个预训练模型,并且在我重新测试其准确性之前,需要修改其一层中的权重和偏置。现在,我的具体问题是:如何使用TF中的assign
方法来更改权重和偏置?在TF中修改恢复模型的权重是否可行?
这是我的代码:
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # 导入MINST数据集# 数据集:# ---------mnist = input_data.read_data_sets("/home/frr/MNIST_data", one_hot=True)# 存储数据的对象ImVecDim = 784# 图像向量中的元素数量(将28x28的2D图像展平)NumOfClasses = 10g = tf.get_default_graph()with tf.Session() as sess: LoadMod = tf.train.import_meta_graph('simple_mnist.ckpt.meta') # 此对象加载模型 LoadMod.restore(sess, tf.train.latest_checkpoint('./'))# 将权重、偏置和其他内容加载到模型中 # (在这里我想修改第一层的权重和偏置,例如将它们设置为1,然后再进行准确性测试) # # 测试模型的准确性: X = g.get_tensor_by_name('ImageIn:0') Y = g.get_tensor_by_name('LabelIn:0') KP = g.get_tensor_by_name('KeepProb:0') Accuracy = g.get_tensor_by_name('NetAccuracy:0') feed_dict = { X: mnist.test.images[:256], Y: mnist.test.labels[:256], KP: 1.0 } print( '模型准确性 = ' ) print( sess.run( Accuracy, feed_dict ) )
回答:
除了现有的回答外,还可以通过tf.assign
函数来执行张量更新。
v1 = sess.graph.get_tensor_by_name('v1:0')print(sess.run(v1)) # 1.0sess.run(tf.assign(v1, v1 + 1))print(sess.run(v1)) # 2.0