如何在PyTorch中实现Conv2d的棋盘式步长?

我正在尝试使用PyTorch创建一个卷积神经网络,用于处理二维矩阵的输入。我使用的是3×5的滤波器,并且希望它有一个自定义的步长,如下所示:在偶数行号上,我希望滤波器从位置0的元素(图片中的红色)开始,在奇数行号上,我希望它从位置1的元素(图片中的蓝色)开始,并且在两种情况下,x方向上的步长为2。这意味着,如果我的输入是一个如图所示的矩阵,我希望滤波器的中心只有0。我知道这在卷积神经网络中非常不寻常,但这实际上是物理学中的一个问题,所以确切的步长非常重要。 3x5 filter on an input


回答:

以下自定义的Conv2d层实现了如原始问题所述的棋盘式步长的卷积操作。这里的难点在于PyTorch并不真正支持这种不一致的步长。尽管如此,我们可以将这个操作分解为两个独立的步进卷积,一个用于偶数行,另一个用于奇数行。之后,我们可以将结果交错组合在一起。下面的代码中有一些细节,确保我们正确地填充(如果需要)。此外,这个层完全支持反向传播。


示例

这个层基本上就像一个普通的Conv2d,但使用了棋盘式步长。

>> x = torch.arange(64).view(1, 1, 8, 8).floattensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],          [ 8.,  9., 10., 11., 12., 13., 14., 15.],          [16., 17., 18., 19., 20., 21., 22., 23.],          [24., 25., 26., 27., 28., 29., 30., 31.],          [32., 33., 34., 35., 36., 37., 38., 39.],          [40., 41., 42., 43., 44., 45., 46., 47.],          [48., 49., 50., 51., 52., 53., 54., 55.],          [56., 57., 58., 59., 60., 61., 62., 63.]]]])

>> layer = AMNI_Conv2d(1, 1, (3, 5), bias=False)# set kernels to delta functions to demonstrate kernel centers>>> with torch.no_grad():...     layer.conv.weight.zero_()...     layer.conv.weight[:,:,1,2] = 1>>> result = layer(x)tensor([[[[10., 12.],          [19., 21.],          [26., 28.],          [35., 37.],          [42., 44.],          [51., 53.]]]], grad_fn=)

你也可以通过填充来获取原始图中的每一个“零”

>> layer = AMNI_Conv2d(1, 1, (3, 5), padding=(1, 2), bias=False)# set kernels to delta functions to demonstrate kernel centers>>> with torch.no_grad():...     layer.conv.weight.zero_()...     layer.conv.weight[:,:,1,2] = 1>>> result = layer(x)tensor([[[[ 1.,  3.,  5.,  7.],          [ 8., 10., 12., 14.],          [17., 19., 21., 23.],          [24., 26., 28., 30.],          [33., 35., 37., 39.],          [40., 42., 44., 46.],          [49., 51., 53., 55.],          [56., 58., 60., 62.]]]], grad_fn=)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注