多类图像分类,如何加载掩码

我对Keras完全是新手,只用了几天时间,所以经验还很不足。

我成功训练了一个处理单一类别的U-Net网络,训练时输入RGB图像和灰度掩码,使用的代码如下:

def train_generator():
    """Yield ``(x_batch, y_batch)`` training batches forever.

    Walks ``ids_train_split`` in consecutive slices of ``batch_size``. For each
    id it reads the RGB image from ``input/train`` and the grayscale mask from
    ``input/train_mask``, resizes both to ``input_size`` x ``input_size``,
    applies the random augmentations, and collects them into a batch.

    Each id must contain at least three ``_``-separated fields: the third one
    is used to build the image file name (``IMG_<field>.JPG``), while the full
    id names the mask file (``<id>.png``).

    Yields:
        tuple: ``(x_batch, y_batch)`` as ``float32`` arrays divided by 255.
        Masks carry an extra trailing axis added with ``np.expand_dims``.

    Raises:
        ValueError: if ``ids_train_split`` is empty (otherwise the outer
            ``while True`` would spin forever without yielding anything).
    """
    # NOTE(review): ``.values`` below suggests ids_train_split is a pandas
    # Series/Index -- confirm against the caller.
    total = len(ids_train_split)
    if total == 0:
        raise ValueError("ids_train_split is empty; nothing to generate")

    while True:
        for start in range(0, total, batch_size):
            x_batch = []
            y_batch = []
            end = min(start + batch_size, total)
            ids_train_batch = ids_train_split[start:end]
            for sample_id in ids_train_batch.values:
                img_name = 'IMG_' + str(sample_id).split('_')[2]
                image_path = os.path.join("input", "train", "{}.JPG".format(str(img_name)))
                mca_mask_path = os.path.join("input", "train_mask", "{}.png".format(sample_id))

                img = cv2.imread(image_path)
                img = cv2.resize(img, (input_size, input_size))

                # The mask was previously stored as ``mask_mca`` while every
                # later step read the undefined name ``mask`` (NameError on the
                # first sample). One name is now used from load to append.
                mask = cv2.imread(mca_mask_path, cv2.IMREAD_GRAYSCALE)
                mask = cv2.resize(mask, (input_size, input_size))

                # Colour jitter touches the image only; the geometric
                # augmentations receive image and mask together.
                img = randomHueSaturationValue(img,
                                               hue_shift_limit=(-50, 50),
                                               sat_shift_limit=(-5, 5),
                                               val_shift_limit=(-15, 15))
                img, mask = randomShiftScaleRotate(img, mask,
                                                   shift_limit=(-0.0625, 0.0625),
                                                   scale_limit=(-0.1, 0.1),
                                                   rotate_limit=(-0, 0))
                img, mask = randomHorizontalFlip(img, mask)

                # Add a trailing channel axis to the 2-D grayscale mask.
                mask = np.expand_dims(mask, axis=2)
                x_batch.append(img)
                y_batch.append(mask)

            x_batch = np.array(x_batch, np.float32) / 255
            y_batch = np.array(y_batch, np.float32) / 255
            yield x_batch, y_batch

这是我的U-Net模型:

def get_unet_1(pretrained_weights=None, input_shape=(1024, 1024, 3), num_classes=1, learning_rate=0.0001):
    """Build and compile the U-Net model.

    The encoder has seven stages (8, 16, 32, 64, 128, 256, 512 filters), each
    made of two Conv-BatchNorm-ReLU units followed by 2x2 max pooling. The
    centre uses two 1024-filter units. The decoder mirrors the encoder: each
    stage upsamples by 2, concatenates the matching encoder output on axis 3
    and applies three Conv-BatchNorm-ReLU units. A final 1x1 convolution with
    sigmoid activation produces ``num_classes`` output channels.

    Args:
        pretrained_weights: Optional weights path passed to
            ``model.load_weights`` when truthy.
        input_shape: Shape given to the ``Input`` layer.
        num_classes: Number of filters of the final 1x1 convolution.
        learning_rate: Learning rate handed to ``RMSprop``.

    Returns:
        The compiled model.
    """
    encoder_filters = (8, 16, 32, 64, 128, 256, 512)

    def conv_bn_relu(tensor, filters, repeats):
        # ``repeats`` consecutive 3x3 Conv -> BatchNorm -> ReLU units.
        for _ in range(repeats):
            tensor = Conv2D(filters, (3, 3), padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    inputs = Input(shape=input_shape)

    # Encoder: keep each stage's pre-pooling output for the skip connections.
    # With the default 1024x1024 input the spatial size halves per stage,
    # reaching 8x8 at the centre.
    skip_tensors = []
    net = inputs
    for filters in encoder_filters:
        net = conv_bn_relu(net, filters, 2)
        skip_tensors.append(net)
        net = MaxPooling2D((2, 2), strides=(2, 2))(net)

    # Centre.
    net = conv_bn_relu(net, 1024, 2)

    # Decoder: walk the encoder stages deepest-first.
    for filters, skip in zip(reversed(encoder_filters), reversed(skip_tensors)):
        net = UpSampling2D((2, 2))(net)
        net = concatenate([skip, net], axis=3)
        net = conv_bn_relu(net, filters, 3)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(net)

    model = Model(inputs=inputs, outputs=classify)
    model.compile(
        optimizer=RMSprop(lr=learning_rate),
        loss=make_loss('bce_dice'),
        metrics=[dice_coef, 'accuracy'],
    )

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

现在我需要修改问题,使其成为多类分类器,所以我不再使用一个掩码,而是使用两个。因此,我有两种类型的灰度掩码(Mca_mask 和 NotMca_mask,它们属于同一个训练图像),在这种情况下,标准做法是什么?将两个掩码合并成一个吗?


回答:

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注