我正在尝试在Keras中创建一个可以接受参数beta
的激活函数,如下所示:
from keras import backend as Kfrom keras.utils.generic_utils import get_custom_objectsfrom keras.layers import Activationclass Swish(Activation): def __init__(self, activation, beta, **kwargs): super(Swish, self).__init__(activation, **kwargs) self.__name__ = 'swish' self.beta = betadef swish(x): return (K.sigmoid(beta*x) * x)get_custom_objects().update({'swish': Swish(swish, beta=1.)})
如果没有beta
参数,代码可以正常运行,但是我如何在激活函数定义中包含这个参数呢?我还希望在使用model.to_json()
时,像ELU激活函数一样保存这个值。
更新: 基于@今天的回答,我编写了以下代码:
from keras.layers import Layerfrom keras import backend as Kclass Swish(Layer): def __init__(self, beta, **kwargs): super(Swish, self).__init__(**kwargs) self.beta = K.cast_to_floatx(beta) self.__name__ = 'swish' def call(self, inputs): return K.sigmoid(self.beta * inputs) * inputs def get_config(self): config = {'beta': float(self.beta)} base_config = super(Swish, self).get_config() return dict(list(base_config.items()) + list(config.items())) def compute_output_shape(self, input_shape): return input_shapefrom keras.utils.generic_utils import get_custom_objectsget_custom_objects().update({'swish': Swish(beta=1.)})gnn = keras.models.load_model("Model.h5")arch = gnn.to_json()with open(directory + 'architecture.json', 'w') as arch_file: arch_file.write(arch)
然而,目前它并不会将beta
值保存到.json文件中。我该如何使其保存这个值呢?
回答:
由于您希望在序列化模型时保存激活函数的参数,我认为最好将激活函数定义为层,就像Keras中已经定义的高级激活函数一样。您可以这样做:
from keras.layers import Layerfrom keras import backend as Kclass Swish(Layer): def __init__(self, beta, **kwargs): super(Swish, self).__init__(**kwargs) self.beta = K.cast_to_floatx(beta) def call(self, inputs): return K.sigmoid(self.beta * inputs) * inputs def get_config(self): config = {'beta': float(self.beta)} base_config = super(Swish, self).get_config() return dict(list(base_config.items()) + list(config.items())) def compute_output_shape(self, input_shape): return input_shape
然后您可以像使用Keras层一样使用它:
# ...model.add(Swish(beta=0.3))
由于在其定义中实现了get_config()
方法,使用to_json()
或save()
等方法时,参数beta
将会被保存。