根据不同类别对训练集的一部分应用不同的数据增强

我正在进行一个图像分类的机器学习项目。我的问题是我的数据集不平衡,在我的5个图像类别中,其中一个类别有大约400张图像,而其他每个类别只有大约20张图像。

我想通过仅对训练集的某些类别应用数据增强来平衡我的训练集。

以下是我创建训练和验证集的代码:

# 导入数据data_dir = pathlib.Path(r"C:\Train set")# 定义训练和验证集(80% - 20%)batch_size = 32img_height = 240img_width = 240train_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="training",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="validation",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size)

以下是我应用数据增强的代码,尽管这是针对整个训练集的:

# 应用数据增强data_augmentation = keras.Sequential(  [    layers.experimental.preprocessing.RandomFlip("horizontal",                                                  input_shape=(img_height,                                                               img_width,                                                              3)),    layers.experimental.preprocessing.RandomRotation(0.1),    layers.experimental.preprocessing.RandomZoom(0.1),  ])

有没有办法进入我的训练集,提取那些图像数量较少的类别,并只对它们应用数据增强?

提前感谢!


回答:

我建议不使用 ImageDataGenerator,而是使用自定义的 tf.data.Dataset。在映射操作中,你可以对不同类别进行不同的处理,例如:

def preprocess(filepath):    category = tf.strings.split(filepath, os.sep)[0]    read_file = tf.io.read_file(filepath)    decode = tf.image.decode_jpeg(read_file, channels=3)    resize = tf.image.resize(decode, (200, 200))    image = tf.expand_dims(resize, 0)    if tf.equal(category, 'tf_astronauts'):        image = tf.image.flip_up_down(image)        image = tf.image.flip_left_right(image)    # image = tf.image.convert_image_dtype(image, tf.float32)    # category = tf.cast(tf.equal(category, 'tf_astronauts'), tf.int32)    return image, category

让我来演示一下。让我们为你创建一个包含训练图像的文件夹:

import tensorflow as tfimport matplotlib.pyplot as pltimport cv2from skimage import datafrom glob2 import globimport oscat = data.chelsea()astronaut = data.astronaut()for category, picture in zip(['tf_cats', 'tf_astronauts'], [cat, astronaut]):    os.makedirs(category, exist_ok=True)    for i in range(5):        cv2.imwrite(os.path.join(category, category + f'_{i}.jpg'),                    cv2.cvtColor(picture, cv2.COLOR_RGB2BGR))files = glob('tf_*\\*.jpg')

现在你有了这些文件:

['tf_astronauts\\tf_astronauts_0.jpg', 'tf_astronauts\\tf_astronauts_1.jpg', 'tf_astronauts\\tf_astronauts_2.jpg', 'tf_astronauts\\tf_astronauts_3.jpg', 'tf_astronauts\\tf_astronauts_4.jpg', 'tf_cats\\tf_cats_0.jpg', 'tf_cats\\tf_cats_1.jpg', 'tf_cats\\tf_cats_2.jpg', 'tf_cats\\tf_cats_3.jpg', 'tf_cats\\tf_cats_4.jpg']

让我们只对宇航员类别应用变换。我们使用 tf.image 变换。

def preprocess(filepath):    category = tf.strings.split(filepath, os.sep)[0]    read_file = tf.io.read_file(filepath)    decode = tf.image.decode_jpeg(read_file, channels=3)    resize = tf.image.resize(decode, (200, 200))    image = tf.expand_dims(resize, 0)    if tf.equal(category, 'tf_astronauts'):        image = tf.image.flip_up_down(image)        image = tf.image.flip_left_right(image)    # image = tf.image.convert_image_dtype(image, tf.float32)    # category = tf.cast(tf.equal(category, 'tf_astronauts'), tf.int32)    return image, category

然后,我们创建 tf.data.Dataset

train = tf.data.Dataset.from_tensor_slices(files).\    shuffle(10).take(4).map(preprocess).batch(4)

当你迭代数据集时,你会发现只有宇航员的图像被翻转了:

fig = plt.figure()plt.subplots_adjust(wspace=.1, hspace=.2)images, labels = next(iter(train))for index, (image, label) in enumerate(zip(images, labels)):    ax = plt.subplot(2, 2, index + 1)    ax.set_xticks([])    ax.set_yticks([])    ax.set_title(label.numpy().decode())    ax.imshow(image[0].numpy().astype(int))plt.show()

enter image description here

请注意,为了训练,你需要取消 preprocess 函数中两行的注释,使其返回浮点数数组和整数。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注