Converting a Keras Sequential Model to the Functional API

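I'm just starting out with deep learning and am trying to convert a Keras Sequential API model into the Functional API, running on the CIFAR10 image dataset, but I've run into trouble. I've converted the model and, apart from the input layer, it looks the same, yet the Sequential model averages about 70% accuracy while my Functional model averages about 10%. I'd really appreciate some help figuring out what went wrong. Here is my Functional code: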

import tensorflow as tf
from tensorflow import keras
from keras import datasets, layers, models
from keras.models import Model, Input, Sequential
import matplotlib.pyplot as plt

Download and prepare the data:

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
# Normalize pixel values to the range 0 to 1
train_images, test_images = train_images / 255.0, test_images / 255.0
input_shape = train_images[0,:,:,:].shape

Create the model:

input = layers.Input(shape=input_shape)
x = layers.Conv2D(32, (3, 3), activation='relu', padding='valid')(input)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
x = layers.Dense(10)(x)
model = Model(input, x, name='Functional')
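As a sanity check on the architecture itself, the output sizes and parameter counts of the layer stack above can be worked out by hand. The sketch below is plain Python, not Keras code; the helper names (`conv2d_valid`, `max_pool`, `dense`) are hypothetical, and it assumes the standard 32x32x3 CIFAR10 input, stride-1 'valid' convolutions, and pooling with a stride equal to the pool size.

```python
def conv2d_valid(size, channels_in, filters, kernel=3):
    """Stride-1 'valid' convolution: returns (output size, output channels, parameter count)."""
    out_size = size - kernel + 1
    # Each filter has kernel*kernel*channels_in weights plus one bias.
    params = (kernel * kernel * channels_in + 1) * filters
    return out_size, filters, params


def max_pool(size, pool=2):
    """Non-overlapping max pooling: the spatial size is floor-divided by the pool size."""
    return size // pool


def dense(units_in, units_out):
    """Fully connected layer: one weight per input-output pair plus one bias per output."""
    return (units_in + 1) * units_out


size, channels, total = 32, 3, 0                  # assumed CIFAR10 input: 32x32x3

size, channels, p = conv2d_valid(size, channels, 32)
total += p                                        # 30x30x32
size = max_pool(size)                             # 15x15x32
size, channels, p = conv2d_valid(size, channels, 64)
total += p                                        # 13x13x64
size = max_pool(size)                             # 6x6x64
size, channels, p = conv2d_valid(size, channels, 64)
total += p                                        # 4x4x64

flat = size * size * channels                     # Flatten
total += dense(flat, 64)
total += dense(64, 10)

print(flat, total)
```

If these numbers agree with what `model.summary()` reports, the Functional graph has the same layers as the Sequential one, which points the search toward the loss, metric, or label handling rather than the layers.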

Compile and train the model:

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))

Here is a link to the original Sequential CNN, and here is a Google Colab notebook. I'd really appreciate any help understanding and fixing the problem. Thanks in advance.


Answer:

There seems to be a problem with the SparseCategoricalCrossentropy loss function.

See this link: https://github.com/tensorflow/tensorflow/issues/38632
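One possible explanation for an accuracy stuck at exactly 10%, offered as an assumption rather than a confirmed diagnosis: if the string 'accuracy' ends up resolved to the one-hot (categorical) accuracy while the labels are still integers of shape (N, 1), the metric compares argmax of a length-1 label row, which is always 0, against the predicted class. The pure-Python sketch below mimics that mismatch with hypothetical helper functions; it does not call Keras, and whether this is exactly what the linked issue describes is not verified here.

```python
def argmax(row):
    """Index of the largest element in a list (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def categorical_accuracy(y_true, y_pred):
    """One-hot style accuracy: compares argmax of the label row with argmax of the prediction."""
    hits = [argmax(t) == argmax(p) for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


def sparse_categorical_accuracy(y_true, y_pred):
    """Integer-label accuracy: compares the label value itself with argmax of the prediction."""
    hits = [t[0] == argmax(p) for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


# Integer labels shaped (N, 1), as CIFAR10 labels are: one sample per class.
y_true = [[c] for c in range(10)]
# A "perfect" model: the largest score sits at the true class.
y_pred = [[1.0 if j == c else 0.0 for j in range(10)] for c in range(10)]

cat_acc = categorical_accuracy(y_true, y_pred)            # label argmax is always 0
sparse_acc = sparse_categorical_accuracy(y_true, y_pred)  # uses the label value

print(cat_acc, sparse_acc)  # 0.1 1.0
```

Under this assumption, even a perfectly trained ten-class model would report about 10%, and passing an explicit metric that matches the label format would be the thing to try.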

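The following model, which uses one-hot labels with a softmax output and CategoricalCrossentropy, reaches a much higher accuracy: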

import tensorflow as tf
from tensorflow import keras
from keras import datasets, layers, models
from keras.models import Model, Input, Sequential
import matplotlib.pyplot as plt
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
# Normalize pixel values to the range 0 to 1
train_images, test_images = train_images / 255.0, test_images / 255.0
# One-hot encode the integer labels (10 classes)
train_labels, test_labels = tf.keras.utils.to_categorical(train_labels, 10), tf.keras.utils.to_categorical(test_labels, 10)
input_shape = train_images[0,:,:,:].shape
input = layers.Input(shape=input_shape)
x = layers.Conv2D(32, (3, 3), activation='relu', padding='valid')(input)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
# Softmax output, so the loss below works on probabilities rather than logits
x = layers.Dense(10, activation='softmax')(x)
model = Model(input, x, name='Functional')
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))
conv2d_16 (Conv2D)           (None, 30, 30, 32)        896
_________________________________________________________________
max_pooling2d_11 (MaxPooling (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 13, 13, 64)        18496
_________________________________________________________________
max_pooling2d_12 (MaxPooling (None, 6, 6, 64)          0
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 4, 4, 64)          36928
_________________________________________________________________
flatten_6 (Flatten)          (None, 1024)              0
_________________________________________________________________
dense_11 (Dense)             (None, 64)                65600
_________________________________________________________________
dense_12 (Dense)             (None, 10)                650
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________
Train on 50000 samples, validate on 10000 samples
Epoch 1/10
50000/50000 [==============================] - 15s 305us/step - loss: 1.4870 - accuracy: 0.4600 - val_loss: 1.2874 - val_accuracy: 0.5488
Epoch 2/10
50000/50000 [==============================] - 15s 301us/step - loss: 1.1365 - accuracy: 0.5989 - val_loss: 1.0789 - val_accuracy: 0.6191
Epoch 3/10
50000/50000 [==============================] - 15s 301us/step - loss: 0.9869 - accuracy: 0.6547 - val_loss: 0.9506 - val_accuracy: 0.6700
Epoch 4/10
50000/50000 [==============================] - 15s 301us/step - loss: 0.8896 - accuracy: 0.6907 - val_loss: 0.9509 - val_accuracy: 0.6695
Epoch 5/10
50000/50000 [==============================] - 16s 311us/step - loss: 0.8135 - accuracy: 0.7151 - val_loss: 0.8688 - val_accuracy: 0.7046
Epoch 6/10
50000/50000 [==============================] - 15s 303us/step - loss: 0.7566 - accuracy: 0.7351 - val_loss: 0.8411 - val_accuracy: 0.7141
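For what it's worth, the two loss setups compute the same number for a given sample: sparse cross-entropy on raw logits with an integer label, and categorical cross-entropy on softmax probabilities with a one-hot label. The sketch below checks this in plain Python with made-up logits; the function names are hypothetical stand-ins, not the Keras implementations.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_one_hot(label, num_classes):
    """Integer label -> one-hot list, in the spirit of to_categorical."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def sparse_ce_from_logits(label, logits):
    """Cross-entropy from raw logits and an integer label: logsumexp(logits) - logits[label]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]


def categorical_ce(one_hot, probs):
    """Cross-entropy from probabilities and a one-hot label: -sum(t * log(p))."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


logits = [2.0, -1.0, 0.5]   # made-up scores for a 3-class example
label = 0

loss_sparse = sparse_ce_from_logits(label, logits)
loss_categorical = categorical_ce(to_one_hot(label, len(logits)), softmax(logits))

print(abs(loss_sparse - loss_categorical) < 1e-9)  # True
```

Since the losses agree mathematically, the jump in reported accuracy is more plausibly tied to how the labels and the accuracy metric line up than to the loss formula itself.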
