将Python可调用对象包装在Keras层中

在Keras/TensorFlow中,通常很容易将层直接描述为将输入映射到输出的函数,如下所示:

def resnet_block(x, kernel_size):   ch = x.shape[-1]   out = Conv2D(ch, kernel_size, strides = (1,1), padding='same', activation='relu')(x)   out = Conv2D(ch, kernel_size, strides = (1,1), padding='same', activation='relu')(out)   out = Add()([x,out])   return out

而通过子类化Layer来实现类似于

r = ResNetBlock(kernel_size=(3,3))y = r(x)

则会稍微麻烦一些(对于更复杂的例子甚至会麻烦很多)。

由于Keras似乎很乐意在层首次被调用时构建其底层的权重,我在想是否可以只包装如上所示的函数,让Keras在有输入时自行处理,即我希望它看起来像这样:

r = FunctionWrapperLayer(lambda x:resnet_block(x, kernel_size=(3,3)))y = r(x)

我尝试实现了FunctionWrapperLayer,如下所示:

class FunctionWrapperLayer(Layer):    def __init__(self, fn):        super(FunctionWrapperLayer, self).__init__()        self.fn = fn        def build(self, input_shape):        shape = input_shape[1:]        inputs = Input(shape)        outputs = self.fn(inputs)        self.model = Model(inputs=inputs, outputs=outputs)        self.model.compile()            def call(self, x):        return self.model(x)

这看起来似乎可以工作,但是当我使用激活函数时,遇到了一些奇怪的问题,例如使用

def bad(x):    out = tf.keras.activations.sigmoid(x)    out = Conv2D(1, (1,1), strides=(1,1), padding='same')(out)    return outx = tf.constant(tf.reshape(tf.range(48,dtype=tf.float32),[1,4,-1,1])w = FunctionWrapperLayer(bad)w(x)

我得到了以下错误

FailedPreconditionError:  Error while reading resource variable _AnonymousVar34 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar34/class tensorflow::Var does not exist.     [[node conv2d_6/BiasAdd/ReadVariableOp (defined at <ipython-input-33-fc380d9255c5>:12) ]] [Op:__inference_keras_scratch_graph_353]

这让我觉得在构建方法中像这样初始化模型存在某种固有的问题。也许有人对可能发生的事情有更好的见解,或者有其他方法可以实现我想要的功能。

更新:正如jr15提到的,上述方法在涉及的函数仅使用Keras层时确实有效。然而,以下方法也同样有效,这让我有点困惑:

i = Input(x.shape[1:])o = bad(i)model = Model(inputs=i, outputs=o)model(x)

顺便提一下,model.submodules返回

(<tensorflow.python.keras.engine.input_layer.InputLayer at 0x219d80c77c0>, <tensorflow.python.keras.engine.base_layer.TensorFlowOpLayer at 0x219d7afc820>, <tensorflow.python.keras.layers.convolutional.Conv2D at 0x219d7deafa0>)

这意味着激活函数在以这种方式处理时会自动转换为“TensorFlowOpLayer”。

另一个更新:查看原始错误消息,似乎激活函数并不是唯一的问题。如果我移除卷积并使用包装器,一切都能正常工作,并且再次检查子模块时,我发现了“TensorFlowOpLayer”。


回答:

在TensorFlow 2.4中,现在似乎可以直接工作。子模块现在显示为“TFOpLambda”层。

对任何感兴趣的人,这里有一些稍加改进的包装器代码,也适用于多输入模型:

class FunctionWrapperLayer(Layer):    def __init__(self, fn):        super(FunctionWrapperLayer, self).__init__()        self.fn = fn            def build(self, input_shapes):        super(FunctionWrapperLayer, self).build(input_shapes)        if type(input_shapes) is list:            inputs = [Input(shape[1:]) for shape in input_shapes]        else:            inputs = Input(input_shapes[1:])        outputs = self.fn(inputs)        self.fn_model = Model(inputs=inputs, outputs=outputs)        self.fn_model.compile()            def call(self, x):        return self.fn_model(x)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注