我尝试使用线性判别分析来自scikit-learn库,以对我的数据进行降维处理,我的数据有超过200个特征。但是我在LDA类中找不到inverse_transform
函数。
我想问的是,如何从LDA域中的一个点重建原始数据?
根据@bogatron和@kazemakase的回答进行编辑:
我想“原始数据”这个术语是错误的,我应该使用“原始坐标”或“原始空间”。我知道没有所有主成分分析(PCA)我们无法重建原始数据,但当我们构建形状空间时,我们借助PCA将数据投影到更低的维度。PCA试图用仅2或3个成分来解释数据,这些成分能够捕捉数据的大部分方差,如果我们基于这些成分重建数据,应该能显示出导致这种分离的形状部分。
我再次检查了scikit-learn LDA的源代码,我注意到特征向量存储在scalings_
变量中。当我们使用svd
求解器时,无法对特征向量(scalings_
)矩阵进行逆运算,但当我尝试矩阵的伪逆时,我能够重建形状。
这里有两张图片,分别是从[4.28, 0.52]和[0, 0]点重建的:
我认为如果有人能深入解释LDA逆变换的数学限制,那将是很棒的。
回答:
LDA的逆变换并不一定有意义,因为它会丢失大量信息。
作为对比,考虑PCA。我们得到一个用于变换数据的系数矩阵。我们可以通过从矩阵中删除行来进行降维。为了获得逆变换,我们首先对完整的矩阵进行逆运算,然后删除与删除的行对应的列。
LDA不会给我们一个完整的矩阵。我们只能得到一个无法直接逆运算的降维矩阵。可以取伪逆,但这远不如拥有完整矩阵时高效。
考虑一个简单的例子:
C = np.ones((3, 3)) + np.eye(3) # 完整变换矩阵U = C[:2, :] # 降维矩阵V1 = np.linalg.inv(C)[:, :2] # PCA风格的重建矩阵print(V1)#array([[ 0.75, -0.25],# [-0.25, 0.75],# [-0.25, -0.25]])V2 = np.linalg.pinv(U) # LDA风格的重建矩阵print(V2)#array([[ 0.63636364, -0.36363636],# [-0.36363636, 0.63636364],# [ 0.09090909, 0.09090909]])
如果我们有完整的矩阵,我们得到的逆变换(V1
)与简单逆变换(V2
)不同。这是因为在第二种情况下,我们丢失了所有关于被丢弃成分的信息。
已警告。如果你仍然想进行LDA逆变换,这里有一个函数:
import matplotlib.pyplot as pltfrom sklearn import datasetsfrom sklearn.decomposition import PCAfrom sklearn.discriminant_analysis import LinearDiscriminantAnalysisfrom sklearn.utils.validation import check_is_fittedfrom sklearn.utils import check_array, check_X_yimport numpy as npdef inverse_transform(lda, x): if lda.solver == 'lsqr': raise NotImplementedError("(inverse) transform not implemented for 'lsqr' " "solver (use 'svd' or 'eigen').") check_is_fitted(lda, ['xbar_', 'scalings_'], all_or_any=any) inv = np.linalg.pinv(lda.scalings_) x = check_array(x) if lda.solver == 'svd': x_back = np.dot(x, inv) + lda.xbar_ elif lda.solver == 'eigen': x_back = np.dot(x, inv) return x_backiris = datasets.load_iris()X = iris.datay = iris.targettarget_names = iris.target_nameslda = LinearDiscriminantAnalysis()Z = lda.fit(X, y).transform(X)Xr = inverse_transform(lda, Z)# 绘制原始数据和重建数据的前两个维度plt.plot(X[:, 0], X[:, 1], '.', label='原始数据')plt.plot(Xr[:, 0], Xr[:, 1], '.', label='重建数据')plt.legend()
你会看到,逆变换的结果与原始数据没有太大关系(当然,可以猜测投影的方向)。相当一部分变化已经永远消失了。