我想将一个预训练的Lasagne(Theano)模型转换为Keras(Tensorflow)模型,因此所有层都需要具有完全相同的配置。从这两个文档中,我不清楚参数如何对应。假设有一个Lasagne的BatchNormLayer,使用默认设置:
class lasagne.layers.BatchNormLayer(incoming, axes='auto', epsilon=1e-4, alpha=0.1, beta=lasagne.init.Constant(0), gamma=lasagne.init.Constant(1), mean=lasagne.init.Constant(0), inv_std=lasagne.init.Constant(1), **kwargs)
这是Keras BatchNormalization层的API:
keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', moving_mean_initializer='zeros', moving_variance_initializer='ones', beta_regularizer=None, gamma_regularizer=None, beta_constraint=None, gamma_constraint=None)
大部分都清楚了,所以我在这里提供相应的参数以供将来参考:
(Lasagne -> Keras)incoming -> (不需要,自动)axes -> axisepsilon -> epsilonalpha -> ?beta -> beta_initializergamma -> gamma_initializermean -> moving_mean_initializerinv_std -> moving_variance_initializer? -> momentum? -> center? -> scale? -> beta_regularizer? -> gamma_regularizer? -> beta_constraint? -> gamma_constraint
我认为Lasagne根本不支持beta_regularizer、gamma_regularizer、beta_constraint和gamma_constraint,所以Keras中默认的None是正确的。我还认为在Lasagne中,center和scale总是开启的,不能关闭。
这就剩下Lasagne的alpha和Keras的momentum。从Lasagne文档中关于alpha的说明:
用于训练期间计算的批次均值和标准差的指数移动平均的系数;越接近1,它将越依赖于最近看到的批次
从Keras文档中关于momentum的说明:
移动均值和移动方差的动量
它们似乎对应——但通过什么公式呢?
回答:
从Lasagne代码中我们看到alpha
的使用如下:
running_mean.default_update = ((1 - self.alpha) * running_mean + self.alpha * input_mean)running_inv_std.default_update = ((1 - self.alpha) * running_inv_std + self.alpha * input_inv_std)
从这个讨论Keras批归一化’momentum’的问题中我们可以看到:
def assign_moving_average(variable, value, decay, zero_debias=True, name=None): """计算变量的移动平均。 'variable'的移动平均用'value'更新为: variable * decay + value * (1 - decay) ...
正如问题所指出的,TensorFlow术语’decay’是Keras中’momentum’的值。
由此看来,Lasagne所称的’alpha’等于1 – ‘momentum’,因为在Keras中,’momentum’是现有变量(现有移动平均)的乘数,而在Lasagne中这个乘数是1 - alpha
。
不得不承认这很 confusing,因为
- TensorFlow操作在Keras下使用术语’decay’,但这是Keras直接命名的’momentum’。
- TensorFlow代码仅将事物命名为’variable’和’value’,这使得很难知道哪个是存储的移动平均,哪个是需要结合的新数据。