我正在自学一些机器学习知识,并使用MNIST数据库(http://yann.lecun.com/exdb/mnist/)进行学习。该网站的作者在1998年发表了一篇关于各种手写识别技术的论文,可在http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf找到。
提到的第十种方法是“切线距离分类器”。其思想是,如果将每张图像放置在(NxM)维向量空间中,可以通过计算两个图像形成的超平面之间的距离来计算两张图像之间的距离,其中超平面是通过对图像进行旋转、缩放、平移等操作得到的。
我无法弄清楚足够的细节来填补缺失的部分。我理解这些操作中的大多数确实是线性算子,那么如何利用这一事实来创建超平面呢?一旦我们有了超平面,我们又如何计算它与其他超平面之间的距离呢?
回答:
我会给你一些提示。你需要一些图像处理的基础知识。请参考2、3了解详情。
图像卷积
根据3,你需要做的第一步是平滑图像。下面展示了三种不同平滑操作的结果(查看3的第4节)(左列显示结果图像,右列显示原始图像和卷积算子)。这一步是为了将离散向量映射到连续向量,使其可微分。作者建议使用高斯函数。如果你需要更多关于图像卷积的背景知识,这里有一个例子。
完成这一步后,你已经计算了水平和垂直位移:
计算缩放切线
在这里,我向你展示了在2中实现的一种切线计算 – 缩放切线。从3中,我们知道变换如下所示:
/* scaling */for(k=0;k<height;k++) for(j=0;j<width;j++) { currentTangent[ind] = ((j+offsetW)*x1[ind] + (k+offsetH)*x2[ind])*factor; ind++; }
在2的td.c
实现的开头,我们知道以下定义:
factorW=((double)width*0.5);offsetW=0.5-factorW;factorW=1.0/factorW;factorH=((double)height*0.5);offsetH=0.5-factorH;factorH=1.0/factorH;factor=(factorH<factorW)?factorH:factorW; //min
作者使用的是16×16大小的图像。所以我们知道
factor=factorW=factorH=1/8,
和
offsetH=offsetW = 0.5-8 = -7.5
还要注意我们已经计算了
x1[ind]
=,
x2[ind]
=
因此,我们将这些常数代入:
currentTangent[ind] = ((j-7.5)*x1[ind] + (k-7.5)*x2[ind])/8 = x1 * (j-7.5)/8 + x2 * (k-7.5)/8.
由于j
(也包括k
)是介于0到15之间的整数(图像的宽度和高度为16像素),(j-7.5)/8
只是一个介于-0.9375
到0.9375
之间的分数。
所以我猜(j+offsetW)*factor
是每个像素的位移,它与像素到图像中心的水平距离成比例。同样,你知道垂直位移(k+offsetH)*factor
。
计算旋转切线
旋转切线在3中定义如下:
/* rotation */for(k=0;k<height;k++) for(j=0;j<width;j++) { currentTangent[ind] = ((k+offsetH)*x1[ind] - (j+offsetW)*x2[ind])*factor; ind++; }
使用之前的结论,我们知道(k+offsetH)*factor
对应于y
。同样,- (j+offsetW)*factor
对应于-x
。所以你知道这正是3中使用的公式。
你可以在2中找到3中描述的所有其他切线的实现。我喜欢3中的下图,它清楚地展示了不同变换切线的位移效果。
计算图像之间的切线距离
只需按照tangentDistance
函数中的实现进行操作:
// determine the tangents of the first imagecalculateTangents(imageOne, tangents, numTangents, height, width, choice, background);// find the orthonormal tangent subspace numTangentsRemaining = normalizeTangents(tangents, numTangents, height, width);// determine the distance to the closest point in the subspacedist=calculateDistance(imageOne, imageTwo, (const double **) tangents, numTangentsRemaining, height, width);