使用Numpy权重矩阵初始化TensorFlow CNN模型

我正在手动将预训练的matconvnet模型转换为TensorFlow模型。我使用scipy.io从matconvnet模型的mat文件中提取了权重和偏置,并获得了权重和偏置的numpy矩阵。

代码片段,其中data是scipy.io返回的字典:

for i in data['net2']['layers']:    if i.type == 'conv':        model.append({'weights': i.weights[0], 'bias': i.weights[1], 'stride': i.stride, 'padding': i.pad, 'momentum': i.momentum,'lr': i.learningRate,'weight_decay': i.weightDecay})

weights = {    'wc1': tf.Variable(model[0]['weights']),     'wc2': tf.Variable(model[2]['weights']),    'wc3': tf.Variable(model[4]['weights']),    'wc4': tf.Variable(model[6]['weights'])}

其中model[0]['weights']是从matconvnet模型中提取的4x4x60的numpy矩阵,用于某一层。例如,这就是我定义9×9输入的占位符的方式。

X = tf.placeholder(tf.float32, [None, 9, 9]) #也尝试过使用[None, 81]和tf.reshape,以及[None, 9, 9, 1]

当前问题:我无法使秩匹配。我持续收到ValueError:

ValueError: Shape must be rank 4 but is rank 3 for 'Conv2D' (op: 'Conv2D') with input shapes: [?,9,9], [4,4,60]  

总结

  • 是否可以从numpy数组中明确定义tensorflow模型的权重?
  • 为什么我的权重矩阵的秩是4?我的numpy数组是否应该更像[?, 4, 4, 60],我可以这样做吗?

尝试过的但未成功的方法:

  • 旋转numpy矩阵:我知道matlab和python有不同的索引方式(基于0的索引与基于1的索引,以及行优先与列优先)。尽管我相信我已经适当地转换了一切,但我仍然尝试使用像np.rot90()这样的库,将4x4x60数组转换为60x4x4。
  • 使用tf.reshape:我尝试在用tf.Variable包装权重后使用tf.reshape,但我得到的错误是“Variable没有属性’reshape’”。

注意:请注意,我知道有许多脚本可以从matconvnet转换到caffe,再从caffe转换到tensorflow(例如,这里描述的,https://github.com/vlfeat/matconvnet/issues/1021)。我的问题与tensorflow权重初始化选项有关:


回答:

我通过使用tf.reshape(...)(而不是调用weights['wc1'].reshape(...))克服了这个障碍。我仍然不确定性能,或者这是否是一个非常幼稚的尝试。

更新 进一步的测试显示,这种方法至少在功能上是可行的(也就是说,我已经创建了一个TensorFlow CNN模型,它可以运行并产生与MatConvNet模型一致的预测。我对两者之间的准确性不做任何声明)。

我分享了我的代码。在我的情况下,这是一个非常小的网络 – 如果你试图将这个代码用于你自己的matconvnet到tensorflow项目,你可能会需要更多的修改:https://github.com/melissadale/MatConv2TensorFlow

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注