### 逻辑回归的泛化问题

根据Andrew Ng在Coursera上的逻辑回归讲座,以下成本函数可以通过下面的更新表达式来最小化:

enter image description here

在约150个样本上运行该更新函数数百次后,我得到了以下模式,尽管每次迭代后的成本似乎都在按预期减少:

enter image description here

圆圈是我训练的样本,其中输入特征是每个点的(x, y)坐标,颜色是目标标签。红黄背景是模型预测的(x, y)输入的分类结果(红色=0,黄色=1)。

问题

  • 该更新例程是否不是对应成本函数J的正确偏导数?
  • 这种输出模式可能表明什么?

训练方法

// A single pass/epochconst lr = 0.003;let params = [0.5, 0.5, 0.5];const scores = samples.map(sample => sig(sum(sample, params));const errors = scores.map((score, i) => score - labels[i][0]);params = params.map((param, col) => {  return param - lr * errors.reduce((acc, error, row) => {    return acc + error * samples[row][col];  }, 0);});

样本训练数据

const samples = [  [1, 142, 78],  [1, 108, 182],  [1, 396, 47],  [1, 66,  102],  [1, 165, 116],  [1, 8,   106],  [1, 245, 119],  [1, 302, 17],  [1, 96,  38],  [1, 201, 132],];const labels = [  [0],  [1],  [0],  [0],  [1],  [1],  [1],  [0],  [1],];

编辑

这是这个的JSBin链接:https://jsbin.com/jinole/edit?html,js,output


回答:

你的问题完全是数值上的,因为你直接实现了逻辑损失,你的函数J需要对一个点取指数。与此同时,你的数据量很大,你的x/y坐标达到了数百。exp(400)在JS中会导致NaN,所以你的整个代码无法收敛。你只需要将你的点放在[0,2] x [0,4]的矩形中,而不是[0,200] x [0, 400],它就会正常工作。

enter image description here

例如:

function sum(x, w) {  return x.reduce((acc, _x, i) => acc + _x * w[i], 0);}function sig(z) {  return 1 / (1 + Math.exp(-z));}function cost(scores, labels) {  return -(1 / scores.length) * scores.reduce((acc, score, i) => {    var y = labels[i][0];    return y * Math.log(score) + (1 - y) * Math.log(1 - score);  }, 0);}function clear(ctx) {  ctx.clearRect(0, 0, 400, 200);}function render(ctx, points) {  points.forEach(point => {    if (point[2] > 0) {      ctx.fillStyle = '#3c5cff';    } else {      ctx.fillStyle = '#f956ff';    }    ctx.fillRect(Math.max(0, point[0] * 100 - 2), Math.max(0, point[1] * 100 - 2), 4, 4);    //      ctx.fillRect(point[0], point[1], 1, 1);  })}function renderEach(ctx, params) {  for (let y = 0; y < 200; y++) {    for (let x = 0; x < 400; x++) {      if (sig(sum([1, x / 100, y / 100], params)) < 0.5) {        ctx.fillStyle = '#b22438';      } else {        ctx.fillStyle = '#fff9b6';      }      ctx.fillRect(x, y, 1, 1);    }  }}function doEpoch(samples, params, learningRate, lastCost, cycle, maxCycles) {  var scores = samples.map(sample => sig(sum(sample, params)));  var errors = scores.map((score, i) => score - labels[i][0]);  var p = document.getElementById('log');  if (!p) {    p = document.createElement('p');    p.setAttribute('id', 'log');    document.body.appendChild(p);  }  params = params.map((param, col) => {    return param - learningRate * errors.reduce((acc, error, row) => (acc + error * samples[row][col]), 0);  });  var J = cost(scores, labels);  if (lastCost === null) {    lastCost = J;  }  if (cycle % 100 === 0) {    p.textContent = `Epoch = ${cycle}, Cost = ${J} (${J - lastCost}), Params = ${JSON.stringify(params, null, 2)}`;    clear(ctx);    renderEach(ctx, params);    render(ctx, points);  }  if (cycle < maxCycles) {    setTimeout(function() {      doEpoch(samples, params, learningRate, J, cycle + 1, maxCycles);    }, 10);  }}var canvas = document.createElement('canvas');canvas.width = 400;canvas.height = 200;document.body.appendChild(canvas);var ctx = canvas.getContext('2d');var lineY = 150;var points = [];for (let i = 0; i < 500; i++) {  var point = [parseInt(Math.random() * canvas.width, 10) / 100, parseInt(Math.random() * canvas.height, 10) / 100];  point.push(Number(point[1] <= lineY / 100));  points.push(point);}render(ctx, points);var samples = points.map(point => [point[0], point[1]]);var labels = points.map(point => [point[2]]);console.log('Samples', JSON.stringify(samples.slice(0, 10)));console.log('Labels', JSON.stringify(labels.slice(0, 10)));var params = [1].concat(samples[0].map(() => Math.random()));var withBias = samples.map(sample => [1].concat(sample));var epochs = 100000;var learningRate = 0.01;var lastCost = null;doEpoch(withBias, params, learningRate, lastCost, 0, epochs);
body {  background: #eee;  padding: 0;  margin: 0;  font-family: monospace;}canvas {  background: #fff;  width: 100%;  image-rendering: pixelated;}
<div id="plot-app"></div>

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注