Tensorflow 2:如何在Keras函数式API中使用堆叠的全连接层?

我正在构建一个模型,希望能够将no_of_dense_layers作为参数传递,并期望函数在循环中创建全连接层。

在循环中创建全连接层并不是问题,我的困惑在于如何在Keras中向堆叠的全连接层传递值?

假设我想要3个全连接层:

def get_layers(no_of_dense_layers  , dense_size):    return [tf.keras.layers.Dense(dense_size[i], activation = 'elu',                                       kernel_initializer=tf.keras.initializers.glorot_uniform(seed=200)) for i in range(no_of_dense_layers)]

如果我想使用Sequential API,可以这样做:

perceptron = tf.keras.Sequential(get_layers(3,[1000,500,300]))

但如果我想使用函数式API,如何实现相同的功能呢?

类似这样的操作:

input_layer = tf.keras.Input(shape=(1024), dtype='float32', name='embedding_input')## 此层应该传递输入到第一个全连接层 >> 第二个全连接层 >> 第三个全连接层   dense_layers = get_layers(3,[1000,500,300])# 上述层应等同于: # x = tf.keras.layers.Dense(1000)# x = tf.keras.layers.Dense(500)# x = tf.keras.layers.Dense(300)# 然后简单地将所有三个全连接层的输出传递到最后的分类层# 分类层 cls_layer  = tf.keras.layers.Dense(1, activation= 'elu')(dense_layers)

我尝试过的方法:

first_layer = dense_layers[0](input_layer)for k in dense_layers[1:]:    print(k(first_layer))

还有其他方法吗?

谢谢!


回答:

这里有一个可能的解决方案:

def get_layers(inp, no_of_dense_layers, dense_size):        for i in range(no_of_dense_layers):            x = Dense(dense_size[i])(inp)        inp = x            return xinp = Input((1024,))x = get_layers(inp, 3, [1000,500,300])    out = Dense(1)(x)m = Model(inp, out)m.summary()_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================input_45 (InputLayer)        [(None, 1024)]            0         _________________________________________________________________dense_88 (Dense)             (None, 1000)              1025000   _________________________________________________________________dense_89 (Dense)             (None, 500)               500500    _________________________________________________________________dense_90 (Dense)             (None, 300)               150300    _________________________________________________________________dense_91 (Dense)             (None, 1)                 301       

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注