我刚开始学习机器学习,现在遇到了这个问题。首先,我使用线性回归来拟合训练集,但得到了非常大的均方根误差(RMSE)。然后,我尝试使用多项式回归来减少偏差。
import numpy as npfrom sklearn.linear_model import LinearRegressionfrom sklearn.preprocessing import PolynomialFeaturesfrom sklearn.metrics import mean_squared_errorpoly_features = PolynomialFeatures(degree=2, include_bias=False)X_poly = poly_features.fit_transform(X)poly_reg = LinearRegression()poly_reg.fit(X_poly, y)poly_predict = poly_reg.predict(X_poly)poly_mse = mean_squared_error(X, poly_predict)poly_rmse = np.sqrt(poly_mse)poly_rmse
然后,我得到了比线性回归略好的结果,接着我继续设置阶数为3/4/5,结果不断改善。但随着阶数的增加,可能会有些过拟合。
最佳多项式阶数应该是能在交叉验证集上生成最低RMSE的阶数。但我不知道该如何实现这一点。我应该使用GridSearchCV吗?还是其他方法?
如果您能帮助我解决这个问题,我将非常感激。
回答:
下次您应该提供X/Y的数据,或者一些虚构的数据,这样会更快,并且能为您提供具体的解决方案。现在,我已经创建了一个形式为y = X**4 + X**3 + X + 1
的虚构方程。
有很多方法可以改进这一点,但快速找到最佳阶数的一个迭代方法是简单地在每个阶数上拟合数据,并选择表现最佳(例如,最低RMSE)的阶数。
您还可以尝试决定如何划分训练/测试/验证数据。
import numpy as npimport matplotlib.pyplot as plt from sklearn.linear_model import LinearRegressionfrom sklearn.preprocessing import PolynomialFeaturesfrom sklearn.metrics import mean_squared_errorfrom sklearn.model_selection import train_test_splitX = np.arange(100).reshape(100, 1)y = X**4 + X**3 + X + 1x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.3)rmses = []degrees = np.arange(1, 10)min_rmse, min_deg = 1e10, 0for deg in degrees: # 训练特征 poly_features = PolynomialFeatures(degree=deg, include_bias=False) x_poly_train = poly_features.fit_transform(x_train) # 线性回归 poly_reg = LinearRegression() poly_reg.fit(x_poly_train, y_train) # 与测试数据比较 x_poly_test = poly_features.fit_transform(x_test) poly_predict = poly_reg.predict(x_poly_test) poly_mse = mean_squared_error(y_test, poly_predict) poly_rmse = np.sqrt(poly_mse) rmses.append(poly_rmse) # 阶数的交叉验证 if min_rmse > poly_rmse: min_rmse = poly_rmse min_deg = deg# 绘制和展示结果print('最佳阶数 {},RMSE 为 {}'.format(min_deg, min_rmse)) fig = plt.figure()ax = fig.add_subplot(111)ax.plot(degrees, rmses)ax.set_yscale('log')ax.set_xlabel('阶数')ax.set_ylabel('RMSE')
这将打印:
最佳阶数 4,RMSE 为 1.27689038706e-08
或者,您还可以构建一个新的类来执行多项式拟合,并将其传递给GridSearchCV,并设置一组参数。