在封装TensorFlow计算图的类中使用类继承

我喜欢为我的TensorFlow模型构建Python类,使它们更易于携带和使用(至少在我看来是这样)。我采用的方法是编写类似于下面的代码

class MyAwesomeModel(object):    def __init__(self, some_graph_params):        # 一系列定义张量、优化器等的代码...        # 例如 self.mytensor = tf.placeholder(tf.float32, [1])    def Train(self, tfsession, input_val):        # 一些调用tfsession的run()方法的代码等    def other_methods(self):        # 其他方法如测试、绘图等,都由MyAwesomeModel实例维护的状态很好地管理

我有两个非常相似的模型。它们的区别仅在于计算图架构的几个地方 – 我希望创建一个基础类,包含所有通用的功能,然后只需要子类来重写基础类中的几项内容。我心中的工作方式如下

假设我的基础类看起来像这样

然后我加入两个子类,它们只是重新定义了self.tensor2self.tensor1之间的关系:

class ChildClass1(BaseClass):    def __init__(self, multiplier):        BaseClass.__init__(self, multiplier)        self.tensor2 = self.tensor1 + tf.constant(5.0, dtype=tf.float32)class ChildClass2(BaseClass):    def __init__(self, multiplier):        BaseClass.__init__(self, multiplier)        self.tensor2 = self.tensor1 + tf.constant(4.0, dtype=tf.float32)

我的目标是运行以下代码:

cc1 = ChildClass1(2)   # multiplier是2mysession = tf.Session()mysession.run(tf.global_variables_initializer())print(cc1.forward_pass(mysession, 5))

如果一切按我希望的方式进行,那么结果应该是((5*2)+5)*10 = 150。如果对象cc1是类型ChildClass2(2),那么我希望结果是((5*2)+4)*10 = 140。

然而,当我运行上述代码时,结果是100,这与子类从未重写在基础类中首次遇到的self.tensor2的定义是一致的。我以为我需要那条奇怪的代码行self.tensor2 = self.tensor1,因为否则下一行会抱怨self.tensor2不存在。我真正想要的是子类重写self.tensor2的定义,而不影响其他部分。正确的做法是什么?

非常感谢!


回答:

self.tensoroutput从未被重写,因此它的值不依赖于你拥有的基础类。将其变成一个方法,这样它就会工作了。

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注