我喜欢为我的TensorFlow模型构建Python类,使它们更易于携带和使用(至少在我看来是这样)。我采用的方法是编写类似于下面的代码
class MyAwesomeModel(object): def __init__(self, some_graph_params): # 一系列定义张量、优化器等的代码... # 例如 self.mytensor = tf.placeholder(tf.float32, [1]) def Train(self, tfsession, input_val): # 一些调用tfsession的run()方法的代码等 def other_methods(self): # 其他方法如测试、绘图等,都由MyAwesomeModel实例维护的状态很好地管理
我有两个非常相似的模型。它们的区别仅在于计算图架构的几个地方 – 我希望创建一个基础类,包含所有通用的功能,然后只需要子类来重写基础类中的几项内容。我心中的工作方式如下
假设我的基础类看起来像这样
然后我加入两个子类,它们只是重新定义了self.tensor2
和self.tensor1
之间的关系:
class ChildClass1(BaseClass): def __init__(self, multiplier): BaseClass.__init__(self, multiplier) self.tensor2 = self.tensor1 + tf.constant(5.0, dtype=tf.float32)class ChildClass2(BaseClass): def __init__(self, multiplier): BaseClass.__init__(self, multiplier) self.tensor2 = self.tensor1 + tf.constant(4.0, dtype=tf.float32)
我的目标是运行以下代码:
cc1 = ChildClass1(2) # multiplier是2mysession = tf.Session()mysession.run(tf.global_variables_initializer())print(cc1.forward_pass(mysession, 5))
如果一切按我希望的方式进行,那么结果应该是((5*2)+5)*10 = 150。如果对象cc1是类型ChildClass2(2),那么我希望结果是((5*2)+4)*10 = 140。
然而,当我运行上述代码时,结果是100,这与子类从未重写在基础类中首次遇到的self.tensor2的定义是一致的。我以为我需要那条奇怪的代码行self.tensor2 = self.tensor1
,因为否则下一行会抱怨self.tensor2
不存在。我真正想要的是子类重写self.tensor2
的定义,而不影响其他部分。正确的做法是什么?
非常感谢!
回答:
self.tensoroutput从未被重写,因此它的值不依赖于你拥有的基础类。将其变成一个方法,这样它就会工作了。