我需要为10个类别创建一个模型。这应该是一个单层线性分类器,带有softmax激活函数。我从各种教程中编写了一些代码,但似乎没有得到我需要的结果。
这是我编写的函数:
def build_classifier(): model = models.Sequential([ layers.Input(shape=(2,)), layers.Dense(1, activation='softmax'), ]) return model
它通过以下方式被调用:
newModel = build_classifier()newModel.summary()
数据和标签定义如下:
labels = ['a','b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']data = keras.datasets.cifar10.load_data()
我一直得到以下结果:
Model: "sequential_3"_________________________________________________________________Layer (type) Output Shape Param # =================================================================dense_4 (Dense) (None, 1) 3 =================================================================Total params: 3Trainable params: 3Non-trainable params: 0
但我需要得到:
Model: "linear_classifier"_________________________________________________________________Layer (type) Output Shape Param # =================================================================flatten (Flatten) (None, 3072) 0 _________________________________________________________________dense (Dense) (None, 10) 30730 =================================================================Total params: 30,730Trainable params: 30,730Non-trainable params: 0
我特别不确定如何从sequential_3转换到linear_classifier,因为我似乎只能找到models.Sequential
,而没有找到线性分类器的版本。
回答:
你是否在寻找类似这样的内容
import tensorflow as tfdef build_classifier(): model = tf.keras.Sequential([ tf.keras.layers.Input(shape=(32,32,3)), tf.keras.layers.Flatten(name='flatten'), tf.keras.layers.Dense(10, activation='linear', name='dense'), ], name='linear_classifier') return modelmodel = build_classifier()model.summary()Model: "linear_classifier"_________________________________________________________________Layer (type) Output Shape Param # =================================================================flatten (Flatten) (None, 3072) 0 _________________________________________________________________dense (Dense) (None, 10) 30730 =================================================================Total params: 30,730Trainable params: 30,730Non-trainable params: 0_________________________________________________________________