fmin 性能极差。fmin_bfgs 精度丢失。最小化不适合

目前我正在实现逻辑回归。没有什么复杂的,只是使用一个简单的数据库(Andrew Ng 的房屋购买预测)。这是我正在做的事情:

我的成本函数:

def Cost(theta, X, Y):  m = Y.size   h = Sigmoid(X.dot(theta.T))  J = (1.0/m) * ((-Y.T.dot(log(h))) - ((1.0 - Y.T).dot(log(1.0-h))))  return J.sum()

调用 fmin:

initial_theta = zeros(shape = (X.shape[1],1))theta = fmin(Cost2,initial_theta, args = (X,Y))

使用 fmin 时,我得到的最终 theta 值对于预测来说太大了。在预测时,我总是得到大约 0.62 和 0.71 的值,这总是会预测为真。也许通过更多的迭代,我可以得到更好的结果,但我对此并不确定。

使用 fmin_bfgs 时,成本会收敛到 NaN,使其无法使用。

还有一些其他数据:

最终 theta:

[ 0.00126059  0.01033406]

最终成本:

[ 0.62079972]

预测值:

[ 0.63422573  0.6727308   0.62957501  0.66757524  0.64503653  0.622457270.67765315  0.68966732  0.72525886  0.73487524  0.67716454  0.709740590.7142225   0.70415933  0.62892863  0.69232142  0.70645758  0.641526050.62052863  0.69538731]

真实评级(如果为 1,预测值应大于或等于 0.5;如果为 0,预测值应小于 0.5)。这是我应该得到的结果:

[0 0 0 1 0 0 0 0 1 1 1 1 1 1 0 1 1 1 0 1]

有什么改进的建议吗?


回答:

经过一些研究和测试后,我找到了我的代码无法正常工作的原因。

由于 fmin_bfgs 会转换为 NaN,我决定查看原因,并考虑如何解决这个问题。我所做的并不是最佳方法,但解决了问题,现在我的代码可以正常工作了。

基本上,fmin_bfgs 生成的数字太小,导致溢出,结果变成 NaN。我所做的(再次强调,这不是解决问题的最佳方式,但确实有效):

首先:将成本函数分成三部分:

def Cost(theta, X, Y):   m = Y.size  Y = Y.flatten()  for i in range(X.shape[0]):  X[i]=X[i].flatten()  h = Sigmoid(X.dot(theta.T))  a = (-Y.T.dot(log(h)))  b = ((1.0 - Y.T).dot(log(1.0-h)))

如你所见,代码

J = (1.0/m) * ((-Y.T.dot(log(h))) - ((1.0 - Y.T).dot(log(1.0-h))))

被替换为 a、b 和 h,其中 h 是对向量应用的 sigmoid 函数。经过一些测试,我发现问题出在 b 项上。log 计算产生了 -∞,因为 h 被计算为每个项都是 1,导致 log(0.0),对于那些了解一些基本数学的人来说,这是 -∞。所以,我这样解决了这个问题:

if math.isnan(b):  #溢出情况  b = -999999J =  (a-b)/mreturn J/m 

我的想法是:“嗯,我在这里接收到一个 -∞。这是一个非常小的数字,但会导致溢出。所以,让我们用一个不会导致溢出的非常小的数字来替换它!”

再次强调,这可能不是最佳方法,但对我来说确实有效。

在这之后,我的代码运行得非常顺畅,实际上效果也非常好。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注