使用expand_dims()在ImageDataGenerator.flow_from_directory()中

我可以在flow_from_directory()管道中为图像添加一个维度,还是需要编写自己的实现来做到这一点?我能否在数据摄入后使用它,或者它会在训练时将图像拉入内存?谢谢


回答:

生成器根据model.fit的需要拉取图像批次,因此它不会尝试将所有图像存储在内存中,以避免OOM(内存溢出)错误。生成器确实有一个预处理函数,我尝试使用它来扩展维度,但它会抛出错误,所以这种方法行不通。猜想你必须创建自己的生成器来实现你的愿望。我认为下面展示的方法会有效

train_dir=r'c:\temp\people\test' # 指向你的训练目录
# 将数据放入ImageDataGenerator并使用flow_from_directory,我的图像是128x128x3的彩色图像
train_gen=ImageDataGenerator(rescale=1.0/255).flow_from_directory( train_dir, target_size=(128, 128),
                                                                    batch_size=10, seed=123,
                                                                    class_mode='categorical', color_mode='rgb',shuffle=True)
def img_gen(input_gen,axis):
      images, labels=next(train_gen) # 获取下一批图像和标签
    images=np.expand_dims(images,axis=axis) # 扩展图像的维度
    yield (images, labels ) # 输出一个批次的元组给model.fit
# 显示图像的维度已被扩展
    images, labels=next(img_gen(train_gen, axis=4))
print (images.shape)

现在图像的形状将变为(10,128,128,3,1),其中10是我使用的批次大小

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注