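I am creating a series of custom TensorFlow (version 2.4.1) layers, but I have run into a problem: the model summary shows zero trainable parameters. Below is a series of examples showing that everything works until I add the last custom layer.
Here are the imports and the custom classes: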
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import (BatchNormalization, Conv2D, Input, ReLU, Layer)

    class basic_conv_stack(Layer):
        def __init__(self, filters, kernel_size, strides):
            super(basic_conv_stack, self).__init__()
            self.conv1 = Conv2D(filters, kernel_size, strides, padding='same')
            self.bn1 = BatchNormalization()
            self.relu = ReLU()

        def call(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            return x

    class basic_residual(Layer):
        def __init__(self, filters, kernel_size, strides):
            super(basic_residual, self).__init__()
            self.bcs1 = basic_conv_stack(filters, kernel_size, strides)
            self.bcs2 = basic_conv_stack(filters, kernel_size, strides)

        def call(self, x):
            x = self.bcs1(x)
            x = self.bcs2(x)
            return x

    class basic_module(Layer):
        def __init__(self, filters, kernel_size, strides):
            super(basic_module, self).__init__()
            self.res = basic_residual
            self.args = (filters, kernel_size, strides)

        def call(self, x):
            for _ in range(4):
                x = self.res(*self.args)(x)
            return x
Now, if I do this, everything works and I get 300 trainable parameters:

    input_layer = Input((128, 128, 3))
    conv = basic_conv_stack(10, 3, 1)(input_layer)
    model = Model(input_layer, conv)
    model.summary()  # summary() prints the table itself and returns None
Likewise, if I do this, I get 1,230 trainable parameters:

    input_layer = Input((128, 128, 3))
    conv = basic_residual(10, 3, 1)(input_layer)
    model = Model(input_layer, conv)
    model.summary()
However, if I try the basic_module class, I get zero trainable parameters:

    input_layer = Input((128, 128, 3))
    conv = basic_module(10, 3, 1)(input_layer)
    model = Model(input_layer, conv)
    model.summary()
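Does anyone know why this happens?
Edited to add:
I found that the layers used in call must be created in the class's __init__ for this to work. So if I change basic_module to this: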
    class basic_module(Layer):
        def __init__(self, filters, kernel_size, strides):
            super(basic_module, self).__init__()
            self.clayers = [basic_residual(filters, kernel_size, strides) for _ in range(4)]

        def call(self, x):
            for idx in range(4):
                x = self.clayers[idx](x)
            return x
everything works. I don't know why that is, so I will leave this question open in case someone can explain the reason.
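One way to confirm that the fix above really registers the sublayers is to count the trainable weights directly instead of reading them off the summary. The following is a minimal sketch, assuming a TensorFlow 2.x install; the helper `count_trainable` is a name made up for this illustration and is not part of Keras. With four residual blocks the count should come to 6,810: 1,230 for the first block (3 input channels) plus 3 × 1,860 for the remaining blocks (10 input channels each).

```python
import math

from tensorflow.keras.layers import BatchNormalization, Conv2D, Input, Layer, ReLU
from tensorflow.keras.models import Model


class basic_conv_stack(Layer):
    def __init__(self, filters, kernel_size, strides):
        super().__init__()
        self.conv1 = Conv2D(filters, kernel_size, strides, padding='same')
        self.bn1 = BatchNormalization()
        self.relu = ReLU()

    def call(self, x):
        return self.relu(self.bn1(self.conv1(x)))


class basic_residual(Layer):
    def __init__(self, filters, kernel_size, strides):
        super().__init__()
        self.bcs1 = basic_conv_stack(filters, kernel_size, strides)
        self.bcs2 = basic_conv_stack(filters, kernel_size, strides)

    def call(self, x):
        return self.bcs2(self.bcs1(x))


class basic_module(Layer):
    def __init__(self, filters, kernel_size, strides):
        super().__init__()
        # Instances are created once, here, and stored as an attribute,
        # so the parent layer can find them and their weights.
        self.clayers = [basic_residual(filters, kernel_size, strides) for _ in range(4)]

    def call(self, x):
        for layer in self.clayers:
            x = layer(x)
        return x


def count_trainable(model):
    """Hypothetical helper: total number of scalars in the trainable weights."""
    return sum(math.prod(w.shape) for w in model.trainable_weights)


input_layer = Input((128, 128, 3))
model = Model(input_layer, basic_module(10, 3, 1)(input_layer))
trainable = count_trainable(model)
print(trainable)
```

Running the same check against the original basic_module is a quick way to see the difference without scanning the summary table.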
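Answer:
You have to create the layer instances inside basic_module in __init__, passing the required arguments such as filters, kernel_size, and strides. In your version, self.res stores the basic_residual class itself rather than an instance, so new layers are created on every call and the parent layer never tracks them. Also note that these hyperparameters determine the properties of the trainable weights.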
    # >>> a = basic_module
    # >>> a
    # __main__.basic_module
    # >>> a = basic_module(10, 3, 1)
    # >>> a
    # <__main__.basic_module at 0x7f6123eed510>

    class basic_module(Layer):
        def __init__(self, filters, kernel_size, strides):
            super(basic_module, self).__init__()
            self.res = basic_residual  # <--- the class itself, not an instance
            self.args = (filters, kernel_size, strides)

        def call(self, x):
            for _ in range(4):
                # A brand-new, untracked basic_residual is built on every pass.
                x = self.res(*self.args)(x)
            return x
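The class-versus-instance point can be illustrated with a toy model of attribute tracking in plain Python. This is only a sketch of the idea and is not Keras's actual implementation; `ToyLayer`, `ToyDense`, `CreatesInCall`, and `CreatesInInit` are all hypothetical names invented for the illustration. The toy parent registers a sublayer only when an instance (or a list of instances) is assigned to one of its attributes, which is why a layer built inside call goes unnoticed.

```python
class ToyLayer:
    """Toy stand-in for a layer base class; NOT the real Keras Layer."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the registries.
        object.__setattr__(self, "_sublayers", [])
        object.__setattr__(self, "_own_weights", [])

    def __setattr__(self, name, value):
        # Register layer instances (or lists of them) assigned as attributes.
        candidates = value if isinstance(value, (list, tuple)) else [value]
        for item in candidates:
            if isinstance(item, ToyLayer):
                self._sublayers.append(item)
        object.__setattr__(self, name, value)

    def add_weight(self, size):
        self._own_weights.append(size)

    def count_params(self):
        own = sum(self._own_weights)
        return own + sum(sub.count_params() for sub in self._sublayers)


class ToyDense(ToyLayer):
    def __init__(self, size):
        super().__init__()
        self.add_weight(size)


class CreatesInCall(ToyLayer):
    def __init__(self, size):
        super().__init__()
        self.res = ToyDense  # a class, not an instance: nothing is registered
        self.size = size

    def call(self, x):
        for _ in range(4):
            self.res(self.size)  # fresh layer, discarded right after use
        return x


class CreatesInInit(ToyLayer):
    def __init__(self, size):
        super().__init__()
        # Four instances stored on the parent, so all four are registered.
        self.blocks = [ToyDense(size) for _ in range(4)]

    def call(self, x):
        return x


broken = CreatesInCall(10)
broken.call(None)
fixed = CreatesInInit(10)
fixed.call(None)

broken_params = broken.count_params()
fixed_params = fixed.count_params()
print(broken_params, fixed_params)  # 0 40
```

The toy mirrors the shape of the problem only: the weights exist somewhere in both cases, but only the second variant gives the parent a reference it can follow to find them.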