在封装TensorFlow计算图的类中使用类继承

我喜欢为我的TensorFlow模型构建Python类,使它们更易于携带和使用(至少在我看来是这样)。我采用的方法是编写类似于下面的代码

class MyAwesomeModel(object):    def __init__(self, some_graph_params):        # 一系列定义张量、优化器等的代码...        # 例如 self.mytensor = tf.placeholder(tf.float32, [1])    def Train(self, tfsession, input_val):        # 一些调用tfsession的run()方法的代码等    def other_methods(self):        # 其他方法如测试、绘图等,都由MyAwesomeModel实例维护的状态很好地管理

我有两个非常相似的模型。它们的区别仅在于计算图架构的几个地方 – 我希望创建一个基础类,包含所有通用的功能,然后只需要子类来重写基础类中的几项内容。我心中的工作方式如下

假设我的基础类看起来像这样

然后我加入两个子类,它们只是重新定义了self.tensor2self.tensor1之间的关系:

class ChildClass1(BaseClass):    def __init__(self, multiplier):        BaseClass.__init__(self, multiplier)        self.tensor2 = self.tensor1 + tf.constant(5.0, dtype=tf.float32)class ChildClass2(BaseClass):    def __init__(self, multiplier):        BaseClass.__init__(self, multiplier)        self.tensor2 = self.tensor1 + tf.constant(4.0, dtype=tf.float32)

我的目标是运行以下代码:

cc1 = ChildClass1(2)   # multiplier是2mysession = tf.Session()mysession.run(tf.global_variables_initializer())print(cc1.forward_pass(mysession, 5))

如果一切按我希望的方式进行,那么结果应该是((5*2)+5)*10 = 150。如果对象cc1是类型ChildClass2(2),那么我希望结果是((5*2)+4)*10 = 140。

然而,当我运行上述代码时,结果是100,这与子类从未重写在基础类中首次遇到的self.tensor2的定义是一致的。我以为我需要那条奇怪的代码行self.tensor2 = self.tensor1,因为否则下一行会抱怨self.tensor2不存在。我真正想要的是子类重写self.tensor2的定义,而不影响其他部分。正确的做法是什么?

非常感谢!


回答:

self.tensoroutput从未被重写,因此它的值不依赖于你拥有的基础类。将其变成一个方法,这样它就会工作了。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注