在Theano中添加额外功能(CNN)

我正在使用Theano进行分类(卷积神经网络)

之前,我一直使用(扁平化)图像的像素值作为神经网络的特征。现在,我想添加额外的特征。
我被告知可以将额外特征的向量与扁平化的图像特征连接起来,然后将其作为全连接层的输入使用,但我在这方面遇到了麻烦。

首先,这种方法正确吗?

以下是一些代码片段和我的错误:
类似于他们网站上提供的示例,但做了一些修改

(来自构建模型的类)

 # 为数据分配符号变量 self.x = T.matrix('x')   # 数据以光栅化图像的形式呈现 self.y = T.ivector('y')  # 标签以[int]标签的一维向量形式呈现 self.f = T.matrix('f') # 额外特征

下面,变量 vrng 之前已定义。重要的是 layer2_input

layer2_input = self.layer1.output.flatten(2)layer2_input = T.concatenate([layer2_input, self.f.flatten(2)])self.layer2 = HiddenLayer(rng, input=layer2_input, n_in=v, n_out=200, activation=T.tanh)

(来自训练的类)

train_model = theano.function([index], cost, updates=updates,          givens={             model.x: train_set_x[index * batch_size: (index + 1) * batch_size],             model.y: train_set_y[index * batch_size: (index + 1) * batch_size],             model.f: train_set_f[index * batch_size: (index + 1) * batch_size]          })

然而,当调用train_model时,我得到了一个错误:

ValueError: GpuJoin: Wrong inputs for input 1 related to inputs 0.!Apply node that caused the error: GpuJoin(TensorConstant{0}, GpuElemwise{tanh,no_inplace}.0, GpuFlatten{2}.0)Inputs shapes: [(), (5, 11776), (5, 2)]Inputs strides: [(), (11776, 1), (2, 1)]Inputs types: [TensorType(int8, scalar), CudaNdarrayType(float32, matrix), CudaNdarrayType(float32, matrix)]

输入形状分别代表 xyf 的形状吗?

如果是的话,第三个似乎是正确的(批量大小=5,2个额外特征),但为什么第一个是标量而第二个是矩阵呢?

更多细节:

train_set_x.shape = (61, 19200) [61个扁平化图像(160x120),19200个像素]train_set_y.shape = (61,) [61个整数标签]train_set_f.shape = (61,2) [每个图像2个额外特征]batch_size = 5

我的想法正确吗,还是有更好的方法来实现这个目标?你对为什么我会得到错误有什么见解吗?


回答:

问题是我在错误的轴上进行了连接。

layer2_input = T.concatenate([layer2_input, self.f.flatten(2)])

应该改为

layer2_input = T.concatenate([layer2_input, self.f.flatten(2)], axis=1)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注