I'm completely new to Keras (I've only been using it for a few days), so my experience is very limited.
I successfully trained a U-Net that handles a single class: it takes RGB images as input and grayscale masks as targets, using the following code:
```python
import os

import cv2
import numpy as np


def train_generator():
    while True:
        for start in range(0, len(ids_train_split), batch_size):
            x_batch = []
            y_batch = []
            end = min(start + batch_size, len(ids_train_split))
            ids_train_batch = ids_train_split[start:end]
            for id in ids_train_batch.values:
                img_name = 'IMG_' + str(id).split('_')[2]
                image_path = os.path.join("input", "train", "{}.JPG".format(str(img_name)))
                mca_mask_path = os.path.join("input", "train_mask", "{}.png".format(id))
                # Load the RGB image and the single-class (Mca) grayscale mask.
                img = cv2.imread(image_path)
                img = cv2.resize(img, (input_size, input_size))
                mask = cv2.imread(mca_mask_path, cv2.IMREAD_GRAYSCALE)
                mask = cv2.resize(mask, (input_size, input_size))
                # Color augmentation on the image only; geometric
                # augmentations applied to image and mask together.
                img = randomHueSaturationValue(img,
                                               hue_shift_limit=(-50, 50),
                                               sat_shift_limit=(-5, 5),
                                               val_shift_limit=(-15, 15))
                img, mask = randomShiftScaleRotate(img, mask,
                                                   shift_limit=(-0.0625, 0.0625),
                                                   scale_limit=(-0.1, 0.1),
                                                   rotate_limit=(-0, 0))
                img, mask = randomHorizontalFlip(img, mask)
                mask = np.expand_dims(mask, axis=2)
                x_batch.append(img)
                y_batch.append(mask)
            # Scale both images and masks to [0, 1].
            x_batch = np.array(x_batch, np.float32) / 255
            y_batch = np.array(y_batch, np.float32) / 255
            yield x_batch, y_batch
```
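For reference, this is roughly how the generator is consumed during training. A minimal sketch only: it assumes the same `batch_size`, `input_size`, and `ids_train_split` globals the generator uses, `get_unet_1` as defined below, and `fit_generator`, the generator-based training API in this version of Keras:

```python
import numpy as np

# Sketch: one full pass over the training split per epoch.
model = get_unet_1(input_shape=(input_size, input_size, 3), num_classes=1)
steps = int(np.ceil(len(ids_train_split) / float(batch_size)))
model.fit_generator(train_generator(), steps_per_epoch=steps, epochs=10)
```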
Here is my U-Net model:
```python
from keras.layers import (Activation, BatchNormalization, Conv2D, Input,
                          MaxPooling2D, UpSampling2D, concatenate)
from keras.models import Model
from keras.optimizers import RMSprop


def get_unet_1(pretrained_weights=None, input_shape=(1024, 1024, 3),
               num_classes=1, learning_rate=0.0001):
    inputs = Input(shape=input_shape)
    # 1024
    down0b = Conv2D(8, (3, 3), padding='same')(inputs)
    down0b = BatchNormalization()(down0b)
    down0b = Activation('relu')(down0b)
    down0b = Conv2D(8, (3, 3), padding='same')(down0b)
    down0b = BatchNormalization()(down0b)
    down0b = Activation('relu')(down0b)
    down0b_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0b)
    # 512
    down0a = Conv2D(16, (3, 3), padding='same')(down0b_pool)
    down0a = BatchNormalization()(down0a)
    down0a = Activation('relu')(down0a)
    down0a = Conv2D(16, (3, 3), padding='same')(down0a)
    down0a = BatchNormalization()(down0a)
    down0a = Activation('relu')(down0a)
    down0a_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0a)
    # 256
    down0 = Conv2D(32, (3, 3), padding='same')(down0a_pool)
    down0 = BatchNormalization()(down0)
    down0 = Activation('relu')(down0)
    down0 = Conv2D(32, (3, 3), padding='same')(down0)
    down0 = BatchNormalization()(down0)
    down0 = Activation('relu')(down0)
    down0_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0)
    # 128
    down1 = Conv2D(64, (3, 3), padding='same')(down0_pool)
    down1 = BatchNormalization()(down1)
    down1 = Activation('relu')(down1)
    down1 = Conv2D(64, (3, 3), padding='same')(down1)
    down1 = BatchNormalization()(down1)
    down1 = Activation('relu')(down1)
    down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)
    # 64
    down2 = Conv2D(128, (3, 3), padding='same')(down1_pool)
    down2 = BatchNormalization()(down2)
    down2 = Activation('relu')(down2)
    down2 = Conv2D(128, (3, 3), padding='same')(down2)
    down2 = BatchNormalization()(down2)
    down2 = Activation('relu')(down2)
    down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2)
    # 32
    down3 = Conv2D(256, (3, 3), padding='same')(down2_pool)
    down3 = BatchNormalization()(down3)
    down3 = Activation('relu')(down3)
    down3 = Conv2D(256, (3, 3), padding='same')(down3)
    down3 = BatchNormalization()(down3)
    down3 = Activation('relu')(down3)
    down3_pool = MaxPooling2D((2, 2), strides=(2, 2))(down3)
    # 16
    down4 = Conv2D(512, (3, 3), padding='same')(down3_pool)
    down4 = BatchNormalization()(down4)
    down4 = Activation('relu')(down4)
    down4 = Conv2D(512, (3, 3), padding='same')(down4)
    down4 = BatchNormalization()(down4)
    down4 = Activation('relu')(down4)
    down4_pool = MaxPooling2D((2, 2), strides=(2, 2))(down4)
    # 8
    center = Conv2D(1024, (3, 3), padding='same')(down4_pool)
    center = BatchNormalization()(center)
    center = Activation('relu')(center)
    center = Conv2D(1024, (3, 3), padding='same')(center)
    center = BatchNormalization()(center)
    center = Activation('relu')(center)
    # center
    up4 = UpSampling2D((2, 2))(center)
    up4 = concatenate([down4, up4], axis=3)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    # 16
    up3 = UpSampling2D((2, 2))(up4)
    up3 = concatenate([down3, up3], axis=3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    # 32
    up2 = UpSampling2D((2, 2))(up3)
    up2 = concatenate([down2, up2], axis=3)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    # 64
    up1 = UpSampling2D((2, 2))(up2)
    up1 = concatenate([down1, up1], axis=3)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    # 128
    up0 = UpSampling2D((2, 2))(up1)
    up0 = concatenate([down0, up0], axis=3)
    up0 = Conv2D(32, (3, 3), padding='same')(up0)
    up0 = BatchNormalization()(up0)
    up0 = Activation('relu')(up0)
    up0 = Conv2D(32, (3, 3), padding='same')(up0)
    up0 = BatchNormalization()(up0)
    up0 = Activation('relu')(up0)
    up0 = Conv2D(32, (3, 3), padding='same')(up0)
    up0 = BatchNormalization()(up0)
    up0 = Activation('relu')(up0)
    # 256
    up0a = UpSampling2D((2, 2))(up0)
    up0a = concatenate([down0a, up0a], axis=3)
    up0a = Conv2D(16, (3, 3), padding='same')(up0a)
    up0a = BatchNormalization()(up0a)
    up0a = Activation('relu')(up0a)
    up0a = Conv2D(16, (3, 3), padding='same')(up0a)
    up0a = BatchNormalization()(up0a)
    up0a = Activation('relu')(up0a)
    up0a = Conv2D(16, (3, 3), padding='same')(up0a)
    up0a = BatchNormalization()(up0a)
    up0a = Activation('relu')(up0a)
    # 512
    up0b = UpSampling2D((2, 2))(up0a)
    up0b = concatenate([down0b, up0b], axis=3)
    up0b = Conv2D(8, (3, 3), padding='same')(up0b)
    up0b = BatchNormalization()(up0b)
    up0b = Activation('relu')(up0b)
    up0b = Conv2D(8, (3, 3), padding='same')(up0b)
    up0b = BatchNormalization()(up0b)
    up0b = Activation('relu')(up0b)
    up0b = Conv2D(8, (3, 3), padding='same')(up0b)
    up0b = BatchNormalization()(up0b)
    up0b = Activation('relu')(up0b)
    # 1024
    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0b)

    model = Model(inputs=inputs, outputs=classify)
    # make_loss and dice_coef are my own helpers defined elsewhere.
    model.compile(optimizer=RMSprop(lr=learning_rate),
                  loss=make_loss('bce_dice'),
                  metrics=[dice_coef, 'accuracy'])

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model
```
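Note that the final 1x1 convolution gives the network `num_classes` output channels per pixel, so the targets produced by the generator must have the same channel count. A quick shape check (sketch):

```python
model = get_unet_1(input_shape=(1024, 1024, 3), num_classes=2)
print(model.output_shape)  # (None, 1024, 1024, 2): one channel per class
```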
Now I need to change the problem into a multi-class classifier, so instead of a single mask I use two. I therefore have two kinds of grayscale masks (Mca_mask and NotMca_mask, both belonging to the same training image). What is the standard practice in this case? Should I merge the two masks into one?
Answer:
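The usual practice is to give the target tensor one channel per class rather than merging the masks into a single grayscale image: stack the two masks along the channel axis so each target has shape `(H, W, 2)`, and build the model with `num_classes=2`. A minimal sketch of the change inside the generator (`mask_mca` and `mask_not_mca` are placeholder names for the two loaded-and-resized masks):

```python
import numpy as np

# Stack the per-class masks into one (H, W, 2) target: channel 0 is Mca,
# channel 1 is NotMca. Geometric augmentations must be applied identically
# to both channels, e.g. by transforming the stacked array as one unit.
mask = np.stack([mask_mca, mask_not_mca], axis=2).astype(np.float32) / 255
```

Which output activation and loss to pair with this depends on whether the classes can overlap. If every pixel belongs to exactly one of Mca/NotMca, the usual choice is a `'softmax'` final activation with `categorical_crossentropy`; if a pixel may belong to both (or neither), keep the per-channel `'sigmoid'` with a binary loss such as your `bce_dice`. One more thing to check: if NotMca is literally the complement of Mca, the problem is still binary, and a single mask with `num_classes=1` already encodes both classes.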