Tensorflow ValueError: logits和labels必须具有相同的形状((None, 42) vs (None, 1))

当我运行model.fit()时出现Tensorflow错误。这是我的代码。

train_data = pd.read_csv('train.csv')train_data = shuffle(train_data).reset_index(drop=True)split_data = np.array_split(train_data, 50)train_image = []for i in tqdm(range(split_data[0].shape[0])):    path = 'train/train/'+str(train_data['category'][i]).zfill(2)+'/'+train_data['filename'][i]    img = image.load_img(path,target_size=(400,400,3))    img = image.img_to_array(img)    img = img/255    train_image.append(img)X = np.array(train_image)  # X.shape (2108, 400, 400, 3)y = np.array(split_data[0]['category'])   # y.shape (2108,)X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.1)

这是我的CNN模型。

model = Sequential()model.add(Conv2D(filters=16, kernel_size=(5, 5), activation="relu", input_shape=(400,400,3)))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(0.25))...model.add(Flatten())model.add(Dense(128, activation='relu'))model.add(Dropout(0.5))model.add(Dense(42, activation='sigmoid'))model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test), batch_size=64)

运行model.fit()时出现错误

ValueError: logits和labels必须具有相同的形状((None, 42) vs (None, 1))

X_train的值

array([[[[0.99607843, 0.99607843, 0.99607843],         [0.99607843, 0.99607843, 0.99607843],         [0.99607843, 0.99607843, 0.99607843],         ...,         [1.        , 1.        , 1.        ],         [1.        , 1.        , 1.        ],         [1.        , 1.        , 1.        ]],         ...,       ]]], dtype=float32)

y_train的值

array([ 5, 41, 24, ..., 41, 19, 40], dtype=int64)

回答:

您正在进行多分类问题,您的标签也是整数编码的

使用softmax作为最后一层的激活函数:Dense(42, activation='softmax')

并使用sparse_categorical_crossentropy作为损失函数

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注