自定义激活函数带参数

我正在尝试在Keras中创建一个可以接受参数beta的激活函数,如下所示:

from keras import backend as Kfrom keras.utils.generic_utils import get_custom_objectsfrom keras.layers import Activationclass Swish(Activation):    def __init__(self, activation, beta, **kwargs):        super(Swish, self).__init__(activation, **kwargs)        self.__name__ = 'swish'        self.beta = betadef swish(x):    return (K.sigmoid(beta*x) * x)get_custom_objects().update({'swish': Swish(swish, beta=1.)})

如果没有beta参数,代码可以正常运行,但是我如何在激活函数定义中包含这个参数呢?我还希望在使用model.to_json()时,像ELU激活函数一样保存这个值。


更新: 基于@今天的回答,我编写了以下代码:

from keras.layers import Layerfrom keras import backend as Kclass Swish(Layer):    def __init__(self, beta, **kwargs):        super(Swish, self).__init__(**kwargs)        self.beta = K.cast_to_floatx(beta)        self.__name__ = 'swish'    def call(self, inputs):        return K.sigmoid(self.beta * inputs) * inputs    def get_config(self):        config = {'beta': float(self.beta)}        base_config = super(Swish, self).get_config()        return dict(list(base_config.items()) + list(config.items()))    def compute_output_shape(self, input_shape):        return input_shapefrom keras.utils.generic_utils import get_custom_objectsget_custom_objects().update({'swish': Swish(beta=1.)})gnn = keras.models.load_model("Model.h5")arch = gnn.to_json()with open(directory + 'architecture.json', 'w') as arch_file:    arch_file.write(arch)

然而,目前它并不会将beta值保存到.json文件中。我该如何使其保存这个值呢?


回答:

由于您希望在序列化模型时保存激活函数的参数,我认为最好将激活函数定义为层,就像Keras中已经定义的高级激活函数一样。您可以这样做:

from keras.layers import Layerfrom keras import backend as Kclass Swish(Layer):    def __init__(self, beta, **kwargs):        super(Swish, self).__init__(**kwargs)        self.beta = K.cast_to_floatx(beta)    def call(self, inputs):        return K.sigmoid(self.beta * inputs) * inputs    def get_config(self):        config = {'beta': float(self.beta)}        base_config = super(Swish, self).get_config()        return dict(list(base_config.items()) + list(config.items()))    def compute_output_shape(self, input_shape):        return input_shape

然后您可以像使用Keras层一样使用它:

# ...model.add(Swish(beta=0.3))

由于在其定义中实现了get_config()方法,使用to_json()save()等方法时,参数beta将会被保存。

Related Posts

在使用k近邻算法时,有没有办法获取被使用的“邻居”?

我想找到一种方法来确定在我的knn算法中实际使用了哪些…

Theano在Google Colab上无法启用GPU支持

我在尝试使用Theano库训练一个模型。由于我的电脑内…

准确性评分似乎有误

这里是代码: from sklearn.metrics…

Keras Functional API: “错误检查输入时:期望input_1具有4个维度,但得到形状为(X, Y)的数组”

我在尝试使用Keras的fit_generator来训…

如何使用sklearn.datasets.make_classification在指定范围内生成合成数据?

我想为分类问题创建合成数据。我使用了sklearn.d…

如何处理预测时不在训练集中的标签

已关闭。 此问题与编程或软件开发无关。目前不接受回答。…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注