自定义激活函数带参数

我正在尝试在Keras中创建一个可以接受参数beta的激活函数,如下所示:

from keras import backend as Kfrom keras.utils.generic_utils import get_custom_objectsfrom keras.layers import Activationclass Swish(Activation):    def __init__(self, activation, beta, **kwargs):        super(Swish, self).__init__(activation, **kwargs)        self.__name__ = 'swish'        self.beta = betadef swish(x):    return (K.sigmoid(beta*x) * x)get_custom_objects().update({'swish': Swish(swish, beta=1.)})

如果没有beta参数,代码可以正常运行,但是我如何在激活函数定义中包含这个参数呢?我还希望在使用model.to_json()时,像ELU激活函数一样保存这个值。


更新: 基于@今天的回答,我编写了以下代码:

from keras.layers import Layerfrom keras import backend as Kclass Swish(Layer):    def __init__(self, beta, **kwargs):        super(Swish, self).__init__(**kwargs)        self.beta = K.cast_to_floatx(beta)        self.__name__ = 'swish'    def call(self, inputs):        return K.sigmoid(self.beta * inputs) * inputs    def get_config(self):        config = {'beta': float(self.beta)}        base_config = super(Swish, self).get_config()        return dict(list(base_config.items()) + list(config.items()))    def compute_output_shape(self, input_shape):        return input_shapefrom keras.utils.generic_utils import get_custom_objectsget_custom_objects().update({'swish': Swish(beta=1.)})gnn = keras.models.load_model("Model.h5")arch = gnn.to_json()with open(directory + 'architecture.json', 'w') as arch_file:    arch_file.write(arch)

然而,目前它并不会将beta值保存到.json文件中。我该如何使其保存这个值呢?


回答:

由于您希望在序列化模型时保存激活函数的参数,我认为最好将激活函数定义为层,就像Keras中已经定义的高级激活函数一样。您可以这样做:

from keras.layers import Layerfrom keras import backend as Kclass Swish(Layer):    def __init__(self, beta, **kwargs):        super(Swish, self).__init__(**kwargs)        self.beta = K.cast_to_floatx(beta)    def call(self, inputs):        return K.sigmoid(self.beta * inputs) * inputs    def get_config(self):        config = {'beta': float(self.beta)}        base_config = super(Swish, self).get_config()        return dict(list(base_config.items()) + list(config.items()))    def compute_output_shape(self, input_shape):        return input_shape

然后您可以像使用Keras层一样使用它:

# ...model.add(Swish(beta=0.3))

由于在其定义中实现了get_config()方法,使用to_json()save()等方法时,参数beta将会被保存。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注