如何为自定义Keras层创建可训练的权重变量

我正在实现一个自定义的平均池化层,其中每个神经元计算其输入的平均值,然后将结果乘以一个可学习的系数并加上一个可学习的偏置项,最后应用激活函数

from tensorflow.keras.layers import Layerfrom keras import backend as Kclass Average_Pooling_Layer(Layer):    def __init__(self, output_dimension, **kwargs):        super(Average_Pooling_Layer, self).__init__(**kwargs)        self.output_dimension = output_dimension    def build(self, input_shape):        self.weights = self.add_weight(name='weights2',                              shape=(input_shape[0],                                int(self.output_dimension[0]),                                int(self.output_dimension[1]),                                int(self.output_dimension[2])),                               initializer='uniform',                               trainable=True)    super(Average_Pooling_Layer, self).build(input_shape)   def call(self, inputs):        return K.tanh((inputs * self.weights))    def compute_output_shape(self, input_shape):        return (input_shape)

代码使用

model = tf.keras.Sequential()stride = 1c1 = model.add(Conv2D(6, kernel_size=[5,5], strides=(stride,stride), padding="valid", input_shape=(32,32,1),                   activation = 'tanh'))s2_before_activation = model.add(AveragePooling2D(pool_size=(2, 2), strides=(2, 2)))s2 = model.add(Average_Pooling_Layer(output_dimension = (14, 14, 6)))

我遇到错误:“无法将类型对象转换为Tensor。内容:(Dimension(None), 14, 14, 6)。请考虑将元素转换为支持的类型。” “None”是批量大小,我是从前一层获取的。

如何解决这个问题?


回答:

您的错误是由数据类型引起的。input_shape[0]返回的是<class 'tensorflow.python.framework.tensor_shape.Dimension'>而不是int

您可以用tf.TensorShape(input_shape).as_list()[0]替换input_shape[0]。但是您的维度设置不正确,您需要根据需要进行调整和修改。

编辑

如果您遇到“无法设置属性”的错误,您应该重命名您的权重变量,而不是使用self.weights。例如,改为self.weights_new

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注