在学习Coursera的机器学习课程时,我想在另一个数据集上测试所学内容,并为不同的算法绘制学习曲线。
我(相当随机地)选择了在线新闻流行度数据集,并尝试对其应用线性回归。
注意:我知道这可能不是一个好的选择,但我希望从线性回归开始,以便后来看到其他模型如何更适合。
我训练了一个线性回归模型,并绘制了以下学习曲线:
这个结果对我来说特别令人惊讶,所以我有一些问题:
- 这种曲线是否有可能出现,还是我的代码一定有问题?
- 如果这是正确的,当增加新的训练样本时,训练误差如何能如此迅速地增长?交叉验证误差如何能低于训练误差?
- 如果不是这样,有没有提示我哪里犯了错误?
以防万一,这里是我的代码(Octave / Matlab):
绘图:
lambda = 0;startPoint = 5000;stepSize = 500;[error_train, error_val] = ... learningCurve([ones(mTrain, 1) X_train], y_train, ... [ones(size(X_val, 1), 1) X_val], y_val, ... lambda, startPoint, stepSize);plot(error_train(:,1),error_train(:,2),error_val(:,1),error_val(:,2))title('Learning curve for linear regression')legend('Train', 'Cross Validation')xlabel('Number of training examples')ylabel('Error')
学习曲线:
S = ['Reg with '];for i = startPoint:stepSize:m temp_X = X(1:i,:); temp_y = y(1:i); % Initialize Theta initial_theta = zeros(size(X, 2), 1); % Create "short hand" for the cost function to be minimized costFunction = @(t) linearRegCostFunction(X, y, t, lambda); % Now, costFunction is a function that takes in only one argument options = optimset('MaxIter', 50, 'GradObj', 'on'); % Minimize using fmincg theta = fmincg(costFunction, initial_theta, options); [J, grad] = linearRegCostFunction(temp_X, temp_y, theta, 0); error_train = [error_train; [i J]]; [J, grad] = linearRegCostFunction(Xval, yval, theta, 0); error_val = [error_val; [i J]]; fprintf('%s %6i examples \r', S, i); fflush(stdout);end
编辑:如果我在分割训练/验证集并绘制学习曲线之前对整个数据集进行洗牌,我会得到非常不同的结果,如下面的三个例子:
注意:训练集的大小始终约为24k个样本,验证集约为8k个样本。
回答:
这种曲线是否有可能出现,还是我的代码一定有问题?
这是可能的,但可能性不大。你可能总是选择难以预测的实例作为训练集,而容易预测的实例作为测试集。确保你对数据进行洗牌,并使用10折交叉验证。
即使你做了所有这些,仍然有可能发生这种情况,这并不一定表明方法或实现有问题。
如果这是正确的,当增加新的训练样本时,训练误差如何能如此迅速地增长?交叉验证误差如何能低于训练误差?
假设你的数据只能通过三次多项式来正确拟合,而你使用的是线性回归。这意味着你添加的数据越多,你的模型不充分的表现就越明显(训练误差更高)。现在,如果你为测试集选择了少量实例,误差会更小,因为对于这个特定问题,线性与三次多项式在太少的测试实例上可能不会显示出很大的差异。
例如,如果你对2D点进行回归,并且你总是为测试集选择2个点,你的线性回归总是会有0误差。一个极端的例子,但你应该明白这个意思。
你的测试集有多大?
另外,确保在绘制学习曲线时,测试集保持不变。只有训练集应该增加。
如果不是这样,有没有提示我哪里犯了错误?
你的测试集可能不够大,或者你的训练集和测试集可能没有正确随机化。你应该对数据进行洗牌,并使用10折交叉验证。
你可能还想查找关于该数据集的其他研究。其他人得到了什么结果?
关于更新
我认为这更有意义。现在测试误差通常更高。然而,这些误差对我来说看起来很大。这给你提供的最重要信息可能是线性回归在拟合这个数据方面非常差。
再次建议你对学习曲线进行10折交叉验证。可以把它看作是将你当前的所有图表平均成一个图表。另外,在运行过程之前对数据进行洗牌。