在封装TensorFlow计算图的类中使用类继承

我喜欢为我的TensorFlow模型构建Python类,使它们更易于携带和使用(至少在我看来是这样)。我采用的方法是编写类似于下面的代码

class MyAwesomeModel(object):    def __init__(self, some_graph_params):        # 一系列定义张量、优化器等的代码...        # 例如 self.mytensor = tf.placeholder(tf.float32, [1])    def Train(self, tfsession, input_val):        # 一些调用tfsession的run()方法的代码等    def other_methods(self):        # 其他方法如测试、绘图等,都由MyAwesomeModel实例维护的状态很好地管理

我有两个非常相似的模型。它们的区别仅在于计算图架构的几个地方 – 我希望创建一个基础类,包含所有通用的功能,然后只需要子类来重写基础类中的几项内容。我心中的工作方式如下

假设我的基础类看起来像这样

然后我加入两个子类,它们只是重新定义了self.tensor2self.tensor1之间的关系:

class ChildClass1(BaseClass):    def __init__(self, multiplier):        BaseClass.__init__(self, multiplier)        self.tensor2 = self.tensor1 + tf.constant(5.0, dtype=tf.float32)class ChildClass2(BaseClass):    def __init__(self, multiplier):        BaseClass.__init__(self, multiplier)        self.tensor2 = self.tensor1 + tf.constant(4.0, dtype=tf.float32)

我的目标是运行以下代码:

cc1 = ChildClass1(2)   # multiplier是2mysession = tf.Session()mysession.run(tf.global_variables_initializer())print(cc1.forward_pass(mysession, 5))

如果一切按我希望的方式进行,那么结果应该是((5*2)+5)*10 = 150。如果对象cc1是类型ChildClass2(2),那么我希望结果是((5*2)+4)*10 = 140。

然而,当我运行上述代码时,结果是100,这与子类从未重写在基础类中首次遇到的self.tensor2的定义是一致的。我以为我需要那条奇怪的代码行self.tensor2 = self.tensor1,因为否则下一行会抱怨self.tensor2不存在。我真正想要的是子类重写self.tensor2的定义,而不影响其他部分。正确的做法是什么?

非常感谢!


回答:

self.tensoroutput从未被重写,因此它的值不依赖于你拥有的基础类。将其变成一个方法,这样它就会工作了。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注