基础Tensorflow模型显示随机结果

我正在开发一些新的模型,想回顾一些基础知识。因此,我决定编写一个分类器,将[1, 1]分类为1,其它组合分类为0。

我已经编写了几个不同的版本,但结果总是参差不齐。

from tensorflow.keras import layers, modelsfrom tensorflow import kerasdata = [[1., 1.], [1., 0.], [0., 1.], [0., 0.]]results = [[1.], [0.], [0.], [0.]]def build_model():  model = models.Sequential()  model.add(layers.Dense(len(data[0]), activation='relu'))  model.add(layers.Dense(128, activation='relu'))  model.add(layers.Dense(1))  model.compile(loss=keras.losses.BinaryCrossentropy(), metrics=[keras.metrics.Accuracy()], optimizer='adam')  return modelmodel = build_model()model.fit(data, results, epochs=1000)model.summary()print(model.predict([data[0]]))print(model.predict([data[1]]))print(model.predict([data[2]]))print(model.predict([data[3]]))

有时候输出完全错误:

[[0.]][[0.]][[0.]][[0.]]

模型从未达到准确性。

Epoch 1000/10001/1 [==============================] - 0s 910us/step - loss: 3.8562 - accuracy: 0.7500

有时候显示较低的准确性并产生不好的结果:

Epoch 1000/10001/1 [==============================] - 0s 918us/step - loss: 3.8562 - accuracy: 0.2500
[[-0.1101699]][[-0.13835455]][[-0.03829439]][[0.]]

其他时候它“有点”工作:

Epoch 1000/10001/1 [==============================] - 0s 898us/step - loss: 0.0000e+00 - accuracy: 0.0000e+00

(虽然我期望准确性达到1)

[[1.1292353]][[-0.167045]][[-0.03134967]][[-0.3522459]]

能否有人帮助我理解我的结果的差异。我已经创建了几个版本的这个模型,增加了更多的层,减少了层,使用了不同大小的Dense()层。我尝试了几种损失和度量方法,然而,此时我只是在随机尝试。


回答:

由于您有一个二元分类问题(即二元交叉熵损失和准确性度量),您不应该在最后一层使用线性激活函数,这是默认的,如果您不指定任何东西,就像这里一样;根据文档

激活:要使用的激活函数。如果您不指定任何东西,则不应用激活(即“线性”激活:a(x) = x)。

对于(二元)分类问题,在最后一层使用线性激活是没有意义的;因此,将您的最后一个模型层更改为:

model.add(layers.Dense(1, activation='sigmoid'))

您应该没问题了(另见训练准确性随训练损失降低模型未学习)。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注