Gradient descent on a logarithmic decay curve in Python

I want to run gradient descent on a logarithmic decay curve of the form:

y = y0 - a * ln(b + x).

In this example, y0 is 800.

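I tried to do this using the partial derivatives with respect to a and b, but although this should clearly minimize the squared error, it does not converge. I know this isn't vectorized, and I may be taking the wrong approach entirely. Am I making a simple mistake, or am I completely off track?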

```python
import numpy as np

# constants my gradient descent model should find:
a = 4
b = 4

# function to fit on!
def function(x, a, b):
    y0 = 800
    return y0 - a * np.log(b + x)

# Generates data
def gen_data(numpoints):
    a = 4
    b = 4
    x = np.array(range(0, numpoints))
    y = function(x, a, b)
    return x, y

x, y = gen_data(600)

def grad_model(x, y, iterations):
    converged = False

    # length of dataset
    m = len(x)

    # guess   a ,  b
    theta = [0.1, 0.1]
    alpha = 0.001

    # initial error
    e = np.sum((np.square(function(x, theta[0], theta[1])) - y))

    for iteration in range(iterations):
        hypothesis = function(x, theta[0], theta[1])
        loss = hypothesis - y

        # compute partial derivatives to find slope to "fall" into
        theta0_grad = (np.mean(np.sum(-np.log(x + y)))) / (m)
        theta1_grad = (np.mean((((np.log(theta[1] + x)) / theta[0]) - (x*(np.log(theta[1] + x)) / theta[0])))) / (2*m)

        theta0 = theta[0] - (alpha * theta0_grad)
        theta1 = theta[1] - (alpha * theta1_grad)

        theta[1] = theta1
        theta[0] = theta0

        new_e = np.sum(np.square((function(x, theta[0], theta[1])) - y))
        if new_e > e:
            print("AHHHH!")
            print("Iteration: " + str(iteration))
            break
        print(theta)
    return theta[0], theta[1]
```

Answer:

I found several errors in your code. The line

e = np.sum((np.square(function(x, theta[0], theta[1])) - y))

is wrong and should be replaced with

e = np.sum((np.square(function(x, theta[0], theta[1]) - y)))

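The formula for new_e contains the same error: np.square is applied before y is subtracted, so the code sums $f(x_i)^2 - y_i$ instead of the squared residuals $(f(x_i) - y_i)^2$.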
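A tiny numeric example shows why the parenthesis placement matters: the buggy expression squares the predictions and only then subtracts y, so it does not measure the fit at all. The arrays below are made-up values chosen purely for illustration.

```python
import numpy as np

# Made-up predictions and targets, purely for illustration.
pred = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 2.0, 4.0])

# Buggy: squares the predictions, then subtracts y -> sum(pred**2 - y)
buggy = np.sum(np.square(pred) - y)

# Correct: squares the residuals -> sum((pred - y)**2)
correct = np.sum(np.square(pred - y))

print(buggy, correct)
```

Here the buggy expression gives 6.0 while the actual sum of squared residuals is 2.0.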

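In addition, the gradient formulas are wrong. Your loss function is $L(a,b) = \sum_{i=1}^N \left(y_0 - a \log(b + x_i) - y_i\right)^2$, so you have to compute the partial derivatives of $L$ with respect to $a$ and $b$. Finally, gradient descent only converges when the step size is small enough, so the learning rate alpha must not be too large; the version below lowers it from 0.001 to 0.00001. Here is a version of your code that works better: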

```python
import numpy as np
import matplotlib.pyplot as plt

# constants my gradient descent model should find:
a = 4.0
b = 4.0
y0 = 800.0

# function to fit on!
def function(x, a, b):
    # y0 = 800
    return y0 - a * np.log(b + x)

# Generates data
def gen_data(numpoints):
    # a = 4
    # b = 4
    x = np.array(range(0, numpoints))
    y = function(x, a, b)
    return x, y

x, y = gen_data(600)

def grad_model(x, y, iterations):
    converged = False

    # length of dataset
    m = len(x)

    # guess   a ,  b
    theta = [0.1, 0.1]
    alpha = 0.00001

    # initial error
    # e = np.sum((np.square(function(x, theta[0], theta[1])) - y))  # This was a bug
    e = np.sum((np.square(function(x, theta[0], theta[1]) - y)))
    costs = np.zeros(iterations)

    for iteration in range(iterations):
        hypothesis = function(x, theta[0], theta[1])
        loss = hypothesis - y

        # compute partial derivatives to find slope to "fall" into
        # theta0_grad = (np.mean(np.sum(-np.log(x + y)))) / (m)
        # theta1_grad = (np.mean((((np.log(theta[1] + x)) / theta[0]) - (x*(np.log(theta[1] + x)) / theta[0])))) / (2*m)
        theta0_grad = 2*np.sum((y0 - theta[0]*np.log(theta[1] + x) - y)*(-np.log(theta[1] + x)))
        # use the current estimate theta[1] here, not the true (global) b
        theta1_grad = 2*np.sum((y0 - theta[0]*np.log(theta[1] + x) - y)*(-theta[0]/(theta[1] + x)))

        theta0 = theta[0] - (alpha * theta0_grad)
        theta1 = theta[1] - (alpha * theta1_grad)

        theta[1] = theta1
        theta[0] = theta0

        # new_e = np.sum(np.square((function(x, theta[0], theta[1])) - y))  # This was a bug
        new_e = np.sum(np.square((function(x, theta[0], theta[1]) - y)))
        costs[iteration] = new_e
        if new_e > e:
            print("AHHHH!")
            print("Iteration: " + str(iteration))
            # break
        print(theta)
    return theta[0], theta[1], costs

(theta0, theta1, costs) = grad_model(x, y, 100000)
plt.semilogy(costs)
plt.show()
```
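The gradient lines in the corrected code are meant to implement the partial derivatives of the squared-error loss. Writing the residual as $r_i = y_0 - a\log(b + x_i) - y_i$, the chain rule gives $\partial L/\partial a = 2\sum_i r_i \cdot (-\log(b + x_i))$ and $\partial L/\partial b = 2\sum_i r_i \cdot (-a/(b + x_i))$. A cheap way to catch wrong gradient formulas, like the ones in the question, is to compare them with a central finite-difference estimate. The following is a minimal vectorized sketch of that check; the helper names model, loss, grad and numeric_grad and the test point (a, b) = (2, 3) are illustrative assumptions, not part of the original post.

```python
import numpy as np

Y0 = 800.0  # fixed offset, as in the question

# Hypothetical helper names, chosen for this sketch only.
def model(x, a, b):
    # y = y0 - a * ln(b + x)
    return Y0 - a * np.log(b + x)

def loss(x, y, a, b):
    # Sum of squared residuals.
    r = model(x, a, b) - y
    return np.sum(r ** 2)

def grad(x, y, a, b):
    # Analytic partial derivatives of the loss w.r.t. a and b.
    r = model(x, a, b) - y
    da = 2.0 * np.sum(r * -np.log(b + x))
    db = 2.0 * np.sum(r * -a / (b + x))
    return np.array([da, db])

def numeric_grad(x, y, a, b, h=1e-6):
    # Central finite differences, used only to check grad().
    da = (loss(x, y, a + h, b) - loss(x, y, a - h, b)) / (2 * h)
    db = (loss(x, y, a, b + h) - loss(x, y, a, b - h)) / (2 * h)
    return np.array([da, db])

x = np.arange(600, dtype=float)
y = model(x, 4.0, 4.0)  # synthetic data with a = b = 4

analytic = grad(x, y, 2.0, 3.0)
numeric = numeric_grad(x, y, 2.0, 3.0)
print(analytic, numeric)
```

If the two printed arrays disagree beyond rounding error, the analytic gradient is wrong. At the true parameters (a, b) = (4, 4) every residual is zero, so grad returns exactly zero there.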
