自定义激活函数带参数

我正在尝试在Keras中创建一个可以接受参数beta的激活函数,如下所示:

from keras import backend as Kfrom keras.utils.generic_utils import get_custom_objectsfrom keras.layers import Activationclass Swish(Activation):    def __init__(self, activation, beta, **kwargs):        super(Swish, self).__init__(activation, **kwargs)        self.__name__ = 'swish'        self.beta = betadef swish(x):    return (K.sigmoid(beta*x) * x)get_custom_objects().update({'swish': Swish(swish, beta=1.)})

如果没有beta参数,代码可以正常运行,但是我如何在激活函数定义中包含这个参数呢?我还希望在使用model.to_json()时,像ELU激活函数一样保存这个值。


更新: 基于@今天的回答,我编写了以下代码:

from keras.layers import Layerfrom keras import backend as Kclass Swish(Layer):    def __init__(self, beta, **kwargs):        super(Swish, self).__init__(**kwargs)        self.beta = K.cast_to_floatx(beta)        self.__name__ = 'swish'    def call(self, inputs):        return K.sigmoid(self.beta * inputs) * inputs    def get_config(self):        config = {'beta': float(self.beta)}        base_config = super(Swish, self).get_config()        return dict(list(base_config.items()) + list(config.items()))    def compute_output_shape(self, input_shape):        return input_shapefrom keras.utils.generic_utils import get_custom_objectsget_custom_objects().update({'swish': Swish(beta=1.)})gnn = keras.models.load_model("Model.h5")arch = gnn.to_json()with open(directory + 'architecture.json', 'w') as arch_file:    arch_file.write(arch)

然而,目前它并不会将beta值保存到.json文件中。我该如何使其保存这个值呢?


回答:

由于您希望在序列化模型时保存激活函数的参数,我认为最好将激活函数定义为层,就像Keras中已经定义的高级激活函数一样。您可以这样做:

from keras.layers import Layerfrom keras import backend as Kclass Swish(Layer):    def __init__(self, beta, **kwargs):        super(Swish, self).__init__(**kwargs)        self.beta = K.cast_to_floatx(beta)    def call(self, inputs):        return K.sigmoid(self.beta * inputs) * inputs    def get_config(self):        config = {'beta': float(self.beta)}        base_config = super(Swish, self).get_config()        return dict(list(base_config.items()) + list(config.items()))    def compute_output_shape(self, input_shape):        return input_shape

然后您可以像使用Keras层一样使用它:

# ...model.add(Swish(beta=0.3))

由于在其定义中实现了get_config()方法,使用to_json()save()等方法时,参数beta将会被保存。

Related Posts

神经网络中长特征向量的大小

我正在设计一个神经网络,希望特征向量的大小与输入向量的…

如何使用cross_val_score来拟合我的测试数据?

我正在尝试理解cross_val_score()的使用…

knn.fit() 错误:ValueError:发现输入变量的样本数量不一致

我在 DataCamp 上学习监督学习课程,并尝试在 …

我在人脸验证中应该使用0类吗?

我正在进行人脸验证的实现工作,我有很多人的大量照片,在…

PMML GBDTLRClassifier中的分类特征设置错误

我尝试按照这里的说明设置我的GBDTLRClassif…

Tensorflow – model.fit中的值错误 – 如何修复

我正在尝试使用MNIST数据集训练一个深度神经网络。 …

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注