TensorFlow文本分类示例中使用的激活层是什么

我正在尝试理解TensorFlow的文本分类示例,位于https://www.tensorflow.org/tutorials/keras/text_classification。他们定义模型如下:

model = tf.keras.Sequential([  layers.Embedding(max_features + 1, embedding_dim),  layers.Dropout(0.2),  layers.GlobalAveragePooling1D(),  layers.Dropout(0.2),  layers.Dense(1)])

据我所知,深度学习模型使用激活函数,我想知道上面这个分类模型内部使用的是什么激活函数。谁能帮我理解一下?


回答:

正如你所看到的,模型定义是这样的:

model = tf.keras.Sequential([  layers.Embedding(max_features + 1, embedding_dim),  layers.Dropout(0.2),  layers.GlobalAveragePooling1D(),  layers.Dropout(0.2),  layers.Dense(1)])

那个教程中使用的数据集是一个二分类的,标记为zeroone。通过不对模型的最后一层定义任何激活函数,原始作者希望得到logits而不是概率。这就是为什么他们使用loss函数如下:

model.compile(loss=losses.BinaryCrossentropy(from_logits=True),              ... 

现在,如果我们将最后一层的激活函数设置为sigmoid(通常用于二分类),那么我们必须设置from_logits=False。所以,这里有两个选项可供选择:

带logit: True

我们从最后一层获取logit,这就是为什么我们设置from_logits=True

model = tf.keras.Sequential([  layers.Embedding(max_features + 1, embedding_dim),  layers.Dropout(0.2),  layers.GlobalAveragePooling1D(),  layers.Dropout(0.2),  layers.Dense(1, activation=None)])model.compile(loss=losses.BinaryCrossentropy(from_logits=True),              optimizer='adam',              metrics=['accuracy'])history = model.fit(    train_ds, verbose=2,    validation_data=val_ds,    epochs=epochs)
7ms/step - loss: 0.6828 - accuracy: 0.5054 - val_loss: 0.6148 - val_accuracy: 0.5452Epoch 2/37ms/step - loss: 0.5797 - accuracy: 0.6153 - val_loss: 0.4976 - val_accuracy: 0.7406Epoch 3/37ms/step - loss: 0.4664 - accuracy: 0.7734 - val_loss: 0.4197 - val_accuracy: 0.8096

不带logit: False

在这里,我们从最后一层获取probability,这就是为什么我们设置from_logits=False

model = tf.keras.Sequential([  layers.Embedding(max_features + 1, embedding_dim),  layers.Dropout(0.2),  layers.GlobalAveragePooling1D(),  layers.Dropout(0.2),  layers.Dense(1, activation='sigmoid')])model.compile(loss=losses.BinaryCrossentropy(from_logits=False),              optimizer='adam',              metrics=['accuracy'])history = model.fit(    train_ds, verbose=2,    validation_data=val_ds,    epochs=epochs)
Epoch 1/38ms/step - loss: 0.6818 - accuracy: 0.6163 - val_loss: 0.6135 - val_accuracy: 0.7736Epoch 2/37ms/step - loss: 0.5787 - accuracy: 0.7871 - val_loss: 0.4973 - val_accuracy: 0.8226Epoch 3/38ms/step - loss: 0.4650 - accuracy: 0.8365 - val_loss: 0.4195 - val_accuracy: 0.8472

现在,你可能会想,为什么这个教程使用logit(或不对最后一层使用激活函数)?简短的回答是,这通常无关紧要,我们可以选择任何选项。关键是,在使用from_logits=False的情况下,可能会有数值不稳定性的风险。查看这个回答以获取更多详情。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注