神经网络没有隐藏层且使用线性激活函数时,应能近似线性回归?

据我所知,如果神经网络没有隐藏层且使用线性激活函数,那么它应该会产生与线性回归相同的方程式形式。即 y = SUM(w_i * x_i + b_i),其中 i 从0到特征数量。

我尝试通过使用线性回归的权重和偏置,将它们输入到神经网络中,看看结果是否相同。但结果并非如此。

我想知道是我的理解有误,还是我的代码有问题,或者两者都有问题。


from sklearn.linear_model import LinearRegressionimport tensorflow as tffrom tensorflow import kerasimport numpy as nplinearModel = LinearRegression()linearModel.fit(np.array(normTrainFeaturesDf), np.array(trainLabelsDf))# 获取线性模型的权重和截距,以便能传递给神经网络linearWeights = np.array(linearModel.coef_)intercept = np.array([linearModel.intercept_])trialWeights = np.reshape(linearWeights, (len(linearWeights), 1))trialWeights = trialWeights.astype('float32')intercept = intercept.astype('float32')newTrialWeights = [trialWeights, intercept]# 创建一个神经网络并将模型的权重设置为线性模型的权重nnModel = keras.Sequential([keras.layers.Dense(1, activation='linear', input_shape=[len(normTrainFeaturesDf.keys())]),])nnModel.set_weights(newTrialWeights)# 打印两个模型的预测结果(结果差异很大)print(linearModel.predict(np.array(normTestFeaturesDf))print(nnModel.predict(normTestFeaturesDf).flatten())


回答:

是的,一个只有单层且没有激活函数的神经网络等同于线性回归。

定义一些你未包含的变量:

normTrainFeaturesDf = np.random.rand(100, 10)normTestFeaturesDf = np.random.rand(10, 10)trainLabelsDf = np.random.rand(100)

然后输出结果如预期:

>>> linear_model_preds = linearModel.predict(np.array(normTestFeaturesDf))>>> nn_model_preds = nnModel.predict(normTestFeaturesDf).flatten()>>> print(linear_model_preds)>>> print(nn_model_preds)[0.46030349 0.69676376 0.43064266 0.4583325  0.50750268 0.51753189 0.47254946 0.50654825 0.52998559 0.35908762][0.46030346 0.69676375 0.43064266 0.45833248 0.5075026  0.5175319 0.47254944 0.50654817 0.52998555 0.3590876 ]

这些数字是相同的,除了由于浮点精度导致的小差异。

>>> np.allclose(linear_model_preds, nn_model_preds)True

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注