为什么每个神经网络层都有一个激活函数,而不仅仅是最后一层?

我正在自学机器学习,我有一个类似于这个的问题。

我的理解是否正确:

例如,如果我有一个输入矩阵,其中X1、X2和X3是三个数值特征(例如,假设它们是花瓣长度、茎长度、花长度,我试图标记样本是否属于特定花种):

x1  x2  x3  label5   1   2   yes3   9   8   no1   2   3   yes9   9   9   no  

你会将上述表格的第一行(而不是列)的向量输入到网络中,像这样:

即,将有三个神经元(每个值对应一个神经元),然后随机选择w1、w2和w3,接着计算下一列的第一个神经元时,进行我描述的乘法运算,并加上随机选择的偏置项。这就得到了该节点的值。

对于一组节点都是如此(即,每列实际上会有四个节点(三个加上一个偏置),为了简化,我从第二列中移除了其他三个节点),然后在输出前的最后一个节点上,有一个激活函数将总和转换为一个值(例如,sigmoid函数的0-1),该值告诉你分类是yes还是no。

我为这个问题如此基础而道歉,我真的想理解这个过程,我是通过免费资源学习的。因此,一般来说,你应该选择网络中节点的数量是特征数量的倍数,例如,在这种情况下,编写如下代码是有意义的:

from keras.models import Sequentialfrom keras.models import Densemodel = Sequential()model.add(Dense(6,input_dim=3,activation='relu'))model.add(Dense(6,input_dim=3,activation='relu'))model.add(Dense(3,activation='softmax'))

我不理解的是为什么Keras模型在网络的每一层都有激活函数,而不仅仅是在最后,这就是我怀疑我的理解是否正确/为什么我添加了图片的原因。

编辑1:我注意到在偏置神经元上,我在边上标注了’b=1’,这可能会引起混淆,我知道偏置没有权重,所以这只是提醒自己偏置节点的权重是1。


回答:

除了标题中的问题外,这里还有几个问题,但由于现在不是进行完整教程的时间和场合,我将讨论限制在你的一些观点上,同时考虑到至少已经有一个答案存在。

因此,一般来说,你应该选择网络中节点的数量是特征数量的倍数,

不是的。

特征的数量是在input_dim参数中传递的,仅用于模型的第一层设置;除了第一层之外,每一层的输入数量仅仅是前一层的输出数量。你写的Keras模型是无效的,它会产生错误,因为对于你的第二层,你要求input_dim=3,而前一层显然有6个输出(节点)。

除了input_dim参数之外,数据特征的数量与网络节点的数量之间没有任何其他关系;由于你似乎在考虑iris数据(4个特征),这里是一个简单的可复现示例,展示了如何将Keras模型应用于它们。

在你使用的Keras顺序API中,有一个隐含的输入层,其节点数量就是输入的维度;详情请参见Keras顺序模型输入层中的回答。

所以,你在便笺本上绘制的模型实际上对应于使用顺序API编写的以下Keras模型:

model = Sequential()model.add(Dense(1,input_dim=3,activation='linear'))

在函数式API中,它将被编写为:

inputs = Input(shape=(3,))                outputs = Dense(1, activation='linear')(inputs)     model = Model(inputs, outputs)

仅此而已,即实际上只是线性回归。

我知道偏置没有权重

偏置确实有权重。同样,有用的类比是与线性(或逻辑)回归的常数项:偏置“输入”本身总是1,其对应的系数(权重)是通过拟合过程学习的。

为什么Keras模型在网络的每一层都有激活函数,而不仅仅是在最后

我相信这个问题在另一个答案中已经得到了充分的解释。

我为这个问题如此基础而道歉,我真的想理解这个过程,我是通过免费资源学习的。

我们都这样做过;不过,没有理由不利用Andrew Ng在Coursera上提供的免费且优秀的机器学习MOOC课程。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注