多输入深度学习中的平均层

我正在使用Keras创建一个多输入卷积神经网络(CNN)模型,用于图像分类,该模型接受两张图片并输出一个结果,即这两张图片的类别。

我有两个数据集:type1和type2,每个数据集包含相同的类别。模型应该从Type1数据集中取一张图片,从Type2数据集中取另一张图片,然后将这些图片分类到一个类别中(ClassA或ClassB或其他)。

我想创建一个模型来预测这两张图片,然后计算预测的平均值,类似于下图所示:

enter image description here

我如何创建这个模型?我如何在fit_generator中创建生成器?


回答:

选项1 – 两侧使用相同的模型,只是使用不同的输入

假设你有一个模型一直到“预测”,称为predModel
创建两个输入张量:

input1 = Input(shape)   input2 = Input(shape)

获取每个输入的输出:

pred1 = predModel(input1)pred2 = predModel(input2)   

平均输出:

output = Average()([pred1,pred2])

创建最终模型:

model = Model([input1,input2], output)

选项2 – 两侧使用相似的模型,但使用不同的权重

基本与上述相同,但为每侧单独创建层。

def createCommonPart(inputTensor):    out = ZeroPadding2D(...)(inputTensor)    out = Conv2D(...)(out)    ...    out = Flatten()(out)    return Dense(...)(out)

创建两个输入:

input1 = Input(shape)   input2 = Input(shape)

获取两个输出:

pred1 = createCommonPart(input1)pred2 = createCommonPart(input2)

平均输出:

output = Average()([pred1,pred2])

创建最终模型:

model = Model([input1,input2], output)

生成器

任何能够产生[xTrain1,xTrain2], y的内容都可以。

你可以这样创建一个:

def generator(files1,files2, batch_size):    while True: #必须是无限的        for i in range(len(files1)//batch_size)):            bStart = i*batch_size            bEnd = bStart+batch_size            x1 = loadImagesSomehow(files1[bStart:bEnd])            x2 = loadImagesSomehow(files2[bStart:bEnd])            y = loadPredictionsSomeHow(forSamples[bStart:bEnd])            yield [x1,x2], y

你也可以以类似的方式实现一个keras.utils.Sequence

class gen(Sequence):    def __init__(self, files1, files2, batchSize):        self.files1 = files1        self.files2 = files2        self.batchSize = batchSize    def __len__(self):        return self.len(files1) // self.batchSize    def __getitem__(self,i):        bStart = i*self.batchSize        bEnd = bStart+self.batchSize         x1 = loadImagesSomehow(files1[bStart:bEnd])        x2 = loadImagesSomehow(files2[bStart:bEnd])        y = loadPredictionsSomeHow(forSamples[bStart:bEnd])        return [x1,x2], y

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注