在Keras中实现多任务学习

我正在尝试在Keras中实现共享层。我注意到Keras有keras.layers.concatenate,但从文档中我对它的使用还不太确定。我可以用它来创建多个共享层吗?使用Keras实现如图所示的简单共享神经网络的最佳方法是什么? 共享神经网络


编辑1:请注意,所有三个神经网络的输入、输出和共享层的形状都是相同的。这三个神经网络中有多个共享层(和非共享层)。彩色层是每个神经网络独有的,且形状相同。

基本上,该图表示了三个具有多个共享隐藏层的相同神经网络,其后是多个非共享的隐藏层。

我不确定如何像Twitter示例中那样共享多个层,因为API文档中的示例只有一个共享层。


编辑2:根据geompalik的有用评论,这是我最初想出的方案:

sharedLSTM1 = LSTM((data.shape[1]), return_sequences=True)sharedLSTM2 = LSTM(data.shape[1])def createModel(dropoutRate=0.0, numNeurons=40, optimizer='adam'):    inputLayer = Input(shape=(timesteps, data.shape[1]))    sharedLSTM1Instance = sharedLSTM1(inputLayer)    sharedLSTM2Instance =  sharedLSTM2(sharedLSTM1Instance)    dropoutLayer = Dropout(dropoutRate)(sharedLSTM2Instance)    denseLayer1 = Dense(numNeurons)(dropoutLayer)    denseLayer2 = Dense(numNeurons)(denseLayer1)    outputLayer = Dense(1, activation='sigmoid')(denseLayer2)    return (inputLayer, outputLayer)inputLayer1, outputLayer1 = createModel()inputLayer2, outputLayer2 = createModel()model = Model(inputs=[inputLayer1, inputLayer2], outputs=[outputLayer1, outputLayer2])model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

在上面的代码中,我期望两个模型中的LSTM层是共享的,而dropout和两个dense层不是共享的。这是正确的吗?

如果是这样的话,在这个例子中我不需要keras.layers.concatenate,对吗?

如果我尝试使用plot_model来可视化网络,我会得到以下图像(这正是我所期望的):

模型图


回答:

使用Keras的函数式API实现所示的架构非常简单。有关更多信息,请查看此页面

在你的情况下,输入层和第一个隐藏层是共享的,然后每个主题有一个层。现在设计你的模型取决于你的数据看起来如何:例如,如果对于给定的输入,你对每个主题有不同的输出,你应该定义一个模型,如下所示:

model = Model(inputs=[you_main_input], outputs=[subject1_output, subject2_output, subject3_output])

如果不是这种情况,并且你有对应于每个主题的训练数据,你可以定义三个神经网络,并在它们之间共享前两层。请查看上述引用的文档中的“共享层”部分。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注