是否可以使用fit_generator来实现这个功能?
我正在创建一个U-net网络,我希望输入一张高500像素,宽500像素,包含5个通道的图片,输出则为高500像素,宽500像素,1个通道的图片。
如果不可行,我可以自己创建500x500x5的numpy数组,然后我需要一个生成器从硬盘加载这些numpy对象。
我现在的代码(仅适用于RGB图片)
train_generator=datagen.flow_from_directory('/content/data/', target_size=(500,500), color_mode='rgb', batch_size=32, class_mode='categorical', shuffle=False)
回答:
你应该创建自己的生成器。关键是要有一个while True:
结构,并使用yield
来输出你的数据。代码看起来像这样
batch_size=16step_ep=data_size//batch_sizedef generator(): while True: for i in range(step_ep): 处理你的图像 X=images Y=labels yield X,Y
其中X
的形状为(batch_size,height,width,channel)
,Y
的形状为(batch_size,height,width,output_channel)
你应该使用model.fit()
而不是model.fit_generator
,因为后者很快会被弃用。
你也可以为你的验证数据创建一个生成器。
你的model.fit()
看起来像这样:
model.fit(generator(),epochs,steps_per_epoch=step_ep, validation_data=val_generator,validation_steps=val_steps)