使用二次函数未知常数项,如何使用梯度下降法找到这些未知常数?

大家好。

我是机器学习的新手,目前正在学习梯度下降法。然而,我遇到了一个大问题。问题如下:

给定数字 [0,0],[1,1],[1,2],[2,1] 和方程 [ f=(a2)*x^2 + (a1)*x + a0 ]

通过手动计算,我得到了答案 [-1,5/2,0],但使用这些给定数据通过梯度下降法编写Python代码来找出解决方案却很困难。

在我的情况下,我尝试以最简单和最快的方式使用梯度下降法编写代码,如下所示:

学习率 = 0.1制作 **一系列的x值初始化给定的1,1,1为a2,a1,a0a2,a1,a0的偏导数 (a2_p:2x, a1_p:x, a0_p:1)梯度下降法:(例如) a2 = a2 - (学习率)( y - [(a2)*x^2 + (a1)*x + a0] )(a2_p)

附注:老实说,我不知道应该将什么放入 ‘x' 和 'y' 或 a2, a1, a0 中。

然而,我每次得到的答案都不正确且结果不同。所以,我想得到正确的方程或代码顺序的提示。

感谢您阅读我这个初级水平的问题。


回答:

您的方程中有一些错误

对于函数 f(x) = a2*x^2+a1*x+a0a2a1a0 的偏导数分别是 x^2x1

假设损失函数(1/2)*(y-f(x))^2

损失函数对ai的偏导数-(y-f(x))* f(x)对ai的偏导数,其中 i 属于 [0,2]

因此,梯度下降方程是:
ai = ai + 学习率*(y-f(x)) * f(x)对ai的偏导数,其中 i 属于 [0,2]

希望这个代码对您有帮助

#训练样本sample = [(0,0),(1,1),(1,2),(2,1)]#我们的函数 => a2*x^2+a1*x+a0class Function():    def __init__(self, a2, a1, a0):        self.a2 = a2        self.a1 = a1        self.a0 = a0        def eval(self, x):        return self.a2*x**2+self.a1*x+self.a0        def partial_a2(self, x):        return x**2        def partial_a1(self, x):        return x        def partial_a0(self, x):        return 1#初始化函数f = Function(1,1,1)#计算样本的损失def loss(sample, f):    return sum([(y-f.eval(x))**2 for x,y in sample])/len(sample)epochs = 100000lr = 0.0005#记录最佳值best_values = (0,0,0)for epoch in range(epochs):    min_loss = 100    for x, y in sample:       #梯度下降       f.a2 = f.a2+lr*(y-f.eval(x))*f.partial_a2(x)       f.a1 = f.a1+lr*(y-f.eval(x))*f.partial_a1(x)       f.a0 = f.a0+lr*(y-f.eval(x))*f.partial_a0(x)        #存储最佳值    epoch_loss = loss(sample, f)    if min_loss > epoch_loss:        min_loss = epoch_loss        best_values = (f.a2, f.a1, f.a0)       print("损失:", min_loss)print("最佳值 (a2,a1,a0):", best_values)

输出

损失: 0.12500004789165717最佳值 (a2,a1,a0): (-1.0001922562970325, 2.5003368582261487, 0.00014521557599919338)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注