如何在Keras中实现RBF激活函数?

我正在创建一个自定义的激活函数,特别是RBF激活函数:

from keras import backend as Kfrom keras.layers import Lambdal2_norm = lambda a,b:  K.sqrt(K.sum(K.pow((a-b),2), axis=0, keepdims=True))def rbf2(x):X = #这里我需要从前一层接收的输入 Y = #这里我需要应用于这一层的权重l2 = l2_norm(X,Y)res = K.exp(-1 * gamma * K.pow(l2,2))return res

函数rbf2接收前一层作为输入:

#一些Keras层model.add(Dense(84, activation='tanh')) #层1model.add(Dense(10, activation = rbf2)) #层2

为了创建自定义激活函数,我应该如何从layer1获取输入,并从layer2获取权重?

我实际想要做的,是实现LeNet5神经网络的输出层。LeNet-5的输出层有点特殊,不是计算输入和权重向量的点积,而是每个神经元输出其输入向量和权重向量之间的欧几里得距离的平方。

例如,layer1有84个神经元,layer2有10个神经元。一般情况下,为了计算layer2的10个神经元的输出,我们会计算layer1的84个神经元与layer1layer2之间的84个权重的点积。然后我们对其应用softmax激活函数。

但在这里,不是进行点积操作,而是layer2的每个神经元输出其输入向量和权重向量之间的欧几里得距离的平方(我想将此用作我的激活函数)。

关于创建RBF激活函数(计算层接收的输入和权重的欧几里得距离)并在层中使用它的任何帮助也将是有用的。


回答:

你可以简单地定义一个自定义层来实现这个目的:

from keras.layers import Layerfrom keras import backend as Kclass RBFLayer(Layer):    def __init__(self, units, gamma, **kwargs):        super(RBFLayer, self).__init__(**kwargs)        self.units = units        self.gamma = K.cast_to_floatx(gamma)    def build(self, input_shape):        self.mu = self.add_weight(name='mu',                                  shape=(int(input_shape[1]), self.units),                                  initializer='uniform',                                  trainable=True)        super(RBFLayer, self).build(input_shape)    def call(self, inputs):        diff = K.expand_dims(inputs) - self.mu        l2 = K.sum(K.pow(diff,2), axis=1)        res = K.exp(-1 * self.gamma * l2)        return res    def compute_output_shape(self, input_shape):        return (input_shape[0], self.units)

使用示例:

model = Sequential()model.add(Dense(20, input_shape=(100,)))model.add(RBFLayer(10, 0.5))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注