我注意到在 tensorflow.keras.applications
中有一个 preprocess_input
函数,根据你想要使用的模型,这个函数是不同的。
我正在使用 ImageDataGenerator
类来增强我的数据。更具体地说,我使用了一个 CustomDataGenerator
,它继承自 ImageDataGenerator
类并添加了颜色转换功能。
它的样子如下:
class CustomDataGenerator(ImageDataGenerator): def __init__(self, color=False, **kwargs): super().__init__(preprocessing_function=self.augment_color, **kwargs) self.hue = None if color: self.hue = random.random() def augment_color(self, img): if not self.hue or random.random() < 1/3: return img img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) img_hsv[:, :, 0] = self.hue return cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR)
我之前在 ImageDataGenerator
中使用了 rescale=1./255
,但有些模型需要不同的预处理方法。
所以当我尝试
CustomDataGenerator(preprocessing_function=tf.keras.applications.xception.preprocess_input)
我得到了以下错误:
__init__() got multiple values for keyword argument 'preprocessing_function'
回答:
问题在于,你已经在这里传递了一次 preprocessing_function
super().__init__(preprocessing_function=self.augment_color, **kwargs)
然后又从这里传递了一次
CustomDataGenerator(preprocessing_function=tf.keras.applications.xception.preprocess_input)
所以现在看起来像是
super().__init__(preprocessing_function=self.augment_color, preprocessing_function=tf.keras.applications.xception.preprocess_input)
移除其中一个,你就可以继续了。
编辑 1:如果你想保留这两个函数,最好将它们合并成一个预处理方法,并将其作为 preprocessing_function
传递
在 CustomDataGenerator
中添加以下方法
def preprocess(self, image): image = self.augment_color(image) return tf.keras.applications.xception.preprocess_input(image)
使用这个作为预处理函数
super().__init__(preprocessing_function=self.preprocess, **kwargs)