在Keras/Tensorflow中实现可训练的通用凸包函数层

我正在尝试编写以下凸包函数的变体,并将其按组件应用:

generalized bump function equation,

其中σ是可训练的;但它不起作用(下方报告了错误)。


我的尝试:

这是我迄今为止编写的代码(如果有帮助)。假设我有两个函数(例如):

  def f_True(x):    # 计算凸包函数    bump_value = 1-tf.math.pow(x,2)    bump_value = -tf.math.pow(bump_value,-1)    bump_value = tf.math.exp(bump_value)    return(bump_value)  def f_False(x):    # 计算凸包函数    x_out = 0*x    return(x_out)class trainable_bump_layer(tf.keras.layers.Layer):    def __init__(self, *args, **kwargs):        super(trainable_bump_layer, self).__init__(*args, **kwargs)    def build(self, input_shape):        self.threshold_level = self.add_weight(name='threshlevel',                                    shape=[1],                                    initializer='GlorotUniform',                                    trainable=True)    def call(self, input):        # 确定阈值逻辑        The_Logic = tf.math.less(input,self.threshold_level)        # 应用逻辑        output_step_3 = tf.cond(The_Logic,                                 lambda: f_True(input),                                lambda: f_False(input))        return output_step_3

错误报告:

    在100个样本上训练Epoch 1/10WARNING:tensorflow:当最小化损失时,变量['reconfiguration_unit_steps_3_3/threshlevel:0']的梯度不存在。WARNING:tensorflow:当最小化损失时,变量['reconfiguration_unit_steps_3_3/threshlevel:0']的梯度不存在。 32/100 [========>.....................] - ETA: 3s

tensorflow:变量的梯度不存在 

此外,似乎它并没有按组件应用(除了不可训练的问题)。问题可能出在哪里?


回答:

我有点惊讶居然没有人提到导致该警告的主要(也是唯一)的原因!似乎那段代码是用来实现凸包函数的通用变体的;然而,再次查看实现的函数:

def f_True(x):    # 计算凸包函数    bump_value = 1-tf.math.pow(x,2)    bump_value = -tf.math.pow(bump_value,-1)    bump_value = tf.math.exp(bump_value)    return(bump_value)def f_False(x):    # 计算凸包函数    x_out = 0*x    return(x_out)

错误很明显:这些函数中没有使用层的可训练权重! 因此,你得到的关于不存在梯度的消息并不令人惊讶:你根本没有使用它,因此没有梯度来更新它!相反,这正是原始的凸包函数(即,没有可训练权重)。

但是,你可能会说:“至少,我在tf.cond的条件中使用了可训练权重,所以应该有一些梯度吧?!”;然而,情况并非如此,让我来澄清一下混淆的地方:

  • 首先,正如你也注意到的,我们对元素级条件感兴趣。所以你需要使用tf.where代替tf.cond

  • 另一个误解是声称由于使用了tf.less作为条件,并且由于它不可微分,即相对于其输入没有定义的梯度(这是真的:对于具有布尔输出相对于其实值输入的函数没有定义的梯度!),那么这会导致给定的警告!

    • 这完全是错误的!这里的导数将是相对于可训练权重的层的输出,而选择条件并不存在于输出中。相反,它只是一个布尔张量,用于确定要选择的输出分支。仅此而已!条件的导数不会被计算,也永远不需要。所以这不是给定警告的原因;原因仅仅是我上面提到的:可训练权重对层的输出没有贡献。(注意:如果你对条件点的看法有点惊讶,那么想想一个简单的例子:ReLU函数,它被定义为relu(x) = 0 if x < 0 else x。如果考虑/需要条件的导数,即x < 0,它不存在,那么我们将无法在我们的模型中使用ReLU,并使用基于梯度的优化方法来训练它们!)

(注意:从这里开始,我将把阈值称为sigma,就像方程中一样)。

好的!我们找到了实现中错误的原因。我们能修复这个问题吗?当然可以!这是更新后的工作实现:

import tensorflow as tffrom tensorflow.keras.initializers import RandomUniformfrom tensorflow.keras.constraints import NonNegclass BumpLayer(tf.keras.layers.Layer):    def __init__(self, *args, **kwargs):        super(BumpLayer, self).__init__(*args, **kwargs)    def build(self, input_shape):        self.sigma = self.add_weight(            name='sigma',            shape=[1],            initializer=RandomUniform(minval=0.0, maxval=0.1),            trainable=True,            constraint=tf.keras.constraints.NonNeg()        )        super().build(input_shape)    def bump_function(self, x):        return tf.math.exp(-self.sigma / (self.sigma - tf.math.pow(x, 2)))    def call(self, inputs):        greater = tf.math.greater(inputs, -self.sigma)        less = tf.math.less(inputs, self.sigma)        condition = tf.logical_and(greater, less)        output = tf.where(            condition,             self.bump_function(inputs),            0.0        )        return output

关于此实现的一些要点:

  • 我们用tf.where替换了tf.cond,以便进行元素级条件判断。

  • 此外,如你所见,与你的实现不同的是,你的实现只检查了不等式的一侧,我们使用tf.math.lesstf.math.greatertf.logical_and来找出输入值的幅度是否小于sigma(或者,我们也可以只使用tf.math.abstf.math.less;没有区别!)。让我们重复一遍:以这种方式使用布尔输出函数不会引起任何问题,也与导数/梯度无关。

  • 我们还在层学习的sigma值上使用了非负约束。为什么?因为sigma值小于零是没有意义的(即,当sigma为负时,范围(-sigma, sigma)是未定义的)。

  • 考虑到前一点,我们注意适当初始化sigma值(即,初始化为一个小的非负值)。

  • 还有,不要做像0.0 * inputs这样的事情!它是多余的(而且有点奇怪),它等同于0.0;两者的梯度都为0.0(相对于inputs)。将零与张量相乘不会增加任何东西或解决任何现有问题,至少在这种情况下不会!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注