我正在进行多变量/特征的线性回归。我尝试使用正规方程方法(使用矩阵求逆)来获取theta(系数),以及Numpy的最小二乘numpy.linalg.lstsq工具和np.linalg.solve工具。在我的数据中,我有n = 143个特征和m = 13000个训练样本。
对于带有正则化的正规方程方法,我使用以下公式:
正则化用于解决矩阵可能不可逆的问题(XtX
矩阵可能成为奇异/不可逆矩阵)
数据准备代码:
import pandas as pdimport numpy as nppath = 'DB2.csv' data = pd.read_csv(path, header=None, delimiter=";")data.insert(0, 'Ones', 1)cols = data.shape[1]X = data.iloc[:,0:cols-1] y = data.iloc[:,cols-1:cols] IdentitySize = X.shape[1]IdentityMatrix= np.zeros((IdentitySize, IdentitySize))np.fill_diagonal(IdentityMatrix, 1)
对于最小二乘方法,我使用Numpy的numpy.linalg.lstsq。以下是Python代码:
lamb = 1th = np.linalg.lstsq(X.T.dot(X) + lamb * IdentityMatrix, X.T.dot(y))[0]
我也使用了Numpy的np.linalg.solve工具:
lamb = 1XtX_lamb = X.T.dot(X) + lamb * IdentityMatrixXtY = X.T.dot(y)x = np.linalg.solve(XtX_lamb, XtY);
对于正规方程我使用:
lamb = 1xTx = X.T.dot(X) + lamb * IdentityMatrixXtX = np.linalg.inv(xTx)XtX_xT = XtX.dot(X.T)theta = XtX_xT.dot(y)
在所有方法中我都使用了正则化。以下是结果(theta系数),以便查看这三种方法之间的差异:
正规方程: np.linalg.lstsq np.linalg.solve[-27551.99918303] [-27551.95276154] [-27551.9991855][-940.27518383] [-940.27520138] [-940.27518383][-9332.54653964] [-9332.55448263] [-9332.54654461][-3149.02902071] [-3149.03496582] [-3149.02900965][-1863.25125909] [-1863.2631435] [-1863.25126344][-2779.91105618] [-2779.92175308] [-2779.91105347][-1226.60014026] [-1226.61033117] [-1226.60014192][-920.73334259] [-920.74331432] [-920.73334194][-6278.44238081] [-6278.45496955] [-6278.44237847][-2001.48544938] [-2001.49566981] [-2001.48545349][-715.79204971] [-715.79664124] [-715.79204921][ 4039.38847472] [ 4039.38302499] [ 4039.38847515][-2362.54853195] [-2362.55280478] [-2362.54853139][-12730.8039209] [-12730.80866036] [-12730.80392076][-24872.79868125] [-24872.80203459] [-24872.79867954][-3402.50791863] [-3402.5140501] [-3402.50793382][ 253.47894001] [ 253.47177732] [ 253.47892472][-5998.2045186] [-5998.20513905] [-5998.2045184][ 198.40560401] [ 198.4049081] [ 198.4056042][ 4368.97581411] [ 4368.97175688] [ 4368.97581426][-2885.68026222] [-2885.68154407] [-2885.68026205][ 1218.76602731] [ 1218.76562838] [ 1218.7660275][-1423.73583813] [-1423.7369068] [-1423.73583793][ 173.19125007] [ 173.19086525] [ 173.19125024][-3560.81709538] [-3560.81650156] [-3560.8170952][-142.68135768] [-142.68162508] [-142.6813575][-2010.89489111] [-2010.89601322] [-2010.89489092][-4463.64701238] [-4463.64742877] [-4463.64701219][ 17074.62997704] [ 17074.62974609] [ 17074.62997723][ 7917.75662561] [ 7917.75682048] [ 7917.75662578][-4234.16758492] [-4234.16847544] [-4234.16758474][-5500.10566329] [-5500.106558] [-5500.10566309][-5997.79002683] [-5997.7904842] [-5997.79002634][ 1376.42726683] [ 1376.42629704] [ 1376.42726705][ 6056.87496151] [ 6056.87452659] [ 6056.87496175][ 8149.0123667] [ 8149.01209157] [ 8149.01236827][-7273.3450484] [-7273.34480382] [-7273.34504827][-2010.61773247] [-2010.61839251] [-2010.61773225][-7917.81185096] [-7917.81223606] [-7917.81185084][ 8247.92773739] [ 8247.92774315] [ 8247.92773722][ 1267.25067823] [ 1267.24677734] [ 1267.25067832][ 2557.6208133] [ 2557.62126916] [ 2557.62081337][-5678.53744654] [-5678.53820798] [-5678.53744647][ 3406.41697822] [ 3406.42040997] [ 3406.41697836][-8371.23657044] [-8371.2361594] [-8371.23657035][ 15010.61728285] [ 15010.61598236] [ 15010.61728304][ 11006.21920273] [ 11006.21711213] [ 11006.21920284][-5930.93274062] [-5930.93237071] [-5930.93274048][-5232.84459862] [-5232.84557665] [-5232.84459848][ 3196.89304277] [ 3196.89414431] [ 3196.8930428][ 15298.53309912] [ 15298.53496877] [ 15298.53309919][ 4742.68631183] [ 4742.6862601] [ 4742.68631172][ 4423.14798495] [ 4423.14765013] [ 4423.14798546][-16153.50854089] [-16153.51038489] [-16153.50854123][-22071.50792741] [-22071.49808389] [-22071.50792408][-688.22903323] [-688.2310229] [-688.22904006][-1060.88119863] [-1060.8829114] [-1060.88120546][-101.75750066] [-101.75776411] [-101.75750831][ 4106.77311898] [ 4106.77128502] [ 4106.77311218][ 3482.99764601] [ 3482.99518758] [ 3482.99763924][-1100.42290509] [-1100.42166312] [-1100.4229119][ 20892.42685103] [ 20892.42487476] [ 20892.42684422][-5007.54075789] [-5007.54265501] [-5007.54076473][ 11111.83929421] [ 11111.83734144] [ 11111.83928704][ 9488.57342568] [ 9488.57158677] [ 9488.57341883][-2992.3070786] [-2992.29295891] [-2992.30708529][ 17810.57005982] [ 17810.56651223] [ 17810.57005457][-2154.47389712] [-2154.47504319] [-2154.47390285][-5324.34206726] [-5324.33913623] [-5324.34207293][-14981.89224345] [-14981.8965674] [-14981.89224973][-29440.90545197] [-29440.90465897] [-29440.90545704][-6925.31991443] [-6925.32123144] [-6925.31992383][ 104.98071593] [ 104.97886085] [ 104.98071152][-5184.94477582] [-5184.9447972] [-5184.94477792][ 1555.54536625] [ 1555.54254362] [ 1555.5453638][-402.62443474] [-402.62539068] [-402.62443718][ 17746.15769322] [ 17746.15458093] [ 17746.15769074][-5512.94925026] [-5512.94980649] [-5512.94925267][-2202.8589276] [-2202.86226244] [-2202.85893056][-5549.05250407] [-5549.05416936] [-5549.05250669][-1675.87329493] [-1675.87995809] [-1675.87329255][-5274.27756529] [-5274.28093377] [-5274.2775701][-5424.10246845] [-5424.10658526] [-5424.10247326][-1014.70864363] [-1014.71145066] [-1014.70864845][ 12936.59360437] [ 12936.59168749] [ 12936.59359954][ 2912.71566077] [ 2912.71282628] [ 2912.71565599][ 6489.36648506] [ 6489.36538259] [ 6489.36648021][ 12025.06991281] [ 12025.07040848] [ 12025.06990358][ 17026.57841531] [ 17026.56827742] [ 17026.57841044][ 2220.1852193] [ 2220.18531961] [ 2220.18521579][-2886.39219026] [-2886.39015388] [-2886.39219394][-18393.24573629] [-18393.25888463] [-18393.24573872][-17591.33051471] [-17591.32838012] [-17591.33051834][-3947.18545848] [-3947.17487999] [-3947.18546459][ 7707.05472816] [ 7707.05577227] [ 7707.0547217][ 4280.72039079] [ 4280.72338194] [ 4280.72038435][-3137.48835901] [-3137.48480197] [-3137.48836531][ 6693.47303443] [ 6693.46528167] [ 6693.47302811][-13936.14265517] [-13936.14329336] [-13936.14267094][ 2684.29594641] [ 2684.29859601] [ 2684.29594183][-2193.61036078] [-2193.63086307] [-2193.610366][-10139.10424848] [-10139.11905454] [-10139.10426049][ 4475.11569903] [ 4475.12288711] [ 4475.11569421][-3037.71857269] [-3037.72118246] [-3037.71857265][-5538.71349798] [-5538.71654224] [-5538.71349794][ 8008.38521357] [ 8008.39092739] [ 8008.38521361][-1433.43859633] [-1433.44181824] [-1433.43859629][ 4212.47144667] [ 4212.47368097] [ 4212.47144686][ 19688.24263706] [ 19688.2451694] [ 19688.2426368][ 104.13434091] [ 104.13434349] [ 104.13434091][-654.02451175] [-654.02493111] [-654.02451174][-2522.8642551] [-2522.88694451] [-2522.86424254][-5011.20385919] [-5011.22742915] [-5011.20384655][-13285.64644021] [-13285.66951459] [-13285.64642763][-4254.86406891] [-4254.88695873] [-4254.86405637][-2477.42063206] [-2477.43501057] [-2477.42061727][ 0.] [ 1.23691279e-10] [ 0.][-92.79470071] [-92.79467095] [-92.79470071][ 2383.66211583] [ 2383.66209637] [ 2383.66211583][-10725.22892185] [-10725.22889937] [-10725.22892185][ 234.77560283] [ 234.77560254] [ 234.77560283][ 4739.22119578] [ 4739.22121432] [ 4739.22119578][ 43640.05854156] [ 43640.05848841] [ 43640.05854157][ 2592.3866707] [ 2592.38671547] [ 2592.3866707][-25130.02819215] [-25130.05501178] [-25130.02819515][ 4966.82173096] [ 4966.7946407] [ 4966.82172795][ 14232.97930665] [ 14232.9529959] [ 14232.97930363][-21621.77202422] [-21621.79840459] [-21621.7720272][ 9917.80960029] [ 9917.80960571] [ 9917.80960029][ 1355.79191536] [ 1355.79198092] [ 1355.79191536][-27218.44185748] [-27218.46880642] [-27218.44185719][-27218.04184348] [-27218.06875423] [-27218.04184318][ 23482.80743869] [ 23482.78043029] [ 23482.80743898][ 3401.67707434] [ 3401.65134677] [ 3401.67707463][ 3030.36383274] [ 3030.36384909] [ 3030.36383274][-30590.61847724] [-30590.63933424] [-30590.61847706][-28818.3942685] [-28818.41520495] [-28818.39426833][-25115.73726772] [-25115.7580278] [-25115.73726753][ 77174.61695995] [ 77174.59548773] [ 77174.61696016][-20201.86613672] [-20201.88871113] [-20201.86613657][ 51908.53292209] [ 51908.53446495] [ 51908.53292207][ 7710.71327865] [ 7710.71324194] [ 7710.71327865][-16206.9785119] [-16206.97851993] [-16206.9785119]
如您所见,正规方程、最小二乘和np.linalg.solve工具方法在某种程度上给出了不同的结果。问题是为什么这三种方法会给出明显不同的结果,哪种方法给出的结果更有效和更准确?
假设:正规方程方法的结果与np.linalg.solve的结果非常接近。而np.linalg.lstsq的结果与两者都不同。由于正规方程使用了逆运算,我们不期望它的结果非常准确,因此np.linalg.solve工具的结果也是如此。似乎np.linalg.lstsq给出了更好的结果。
更新:
正如Dave Hensley提到的:
在行 np.fill_diagonal(IdentityMatrix, 1)
之后应添加此代码 IdentityMatrix[0,0] = 0
。
DB2.csv可在DropBox上获取:DB2.csv
完整的Python代码可在DropBox上获取:完整代码
回答:
正如@Matthew Gunn提到的,通过计算系数矩阵的显式逆来解决线性方程组是不好的做法。直接获得解会更快更准确(参见此处)。
您之所以会看到np.linalg.solve
和np.linalg.lstsq
之间的差异,是因为这些函数对您试图解决的系统做出了不同的假设,并使用了不同的数值方法。
-
在内部,
solve
调用了DGESV LAPACK例程,它使用LU分解,然后进行前向和后向代换来找到Ax = b
的精确解。它要求系统是精确确定的,即A
是方阵且满秩。 -
lstsq
则调用DGELSD,它使用A
的奇异值分解来找到一个最小二乘解。这也适用于超定和欠定情况。
如果您的系统是完全确定的,那么您应该使用solve
,因为它需要更少的浮点运算,因此会更快更精确。在您的案例中,XtX_lamb
由于正则化步骤而保证是满秩的。