使用 Keras 中 ImageDataGenerator 的 preprocessing_function

我注意到在 tensorflow.keras.applications 中有一个 preprocess_input 函数,根据你想要使用的模型,这个函数是不同的。

我正在使用 ImageDataGenerator 类来增强我的数据。更具体地说,我使用了一个 CustomDataGenerator,它继承自 ImageDataGenerator 类并添加了颜色转换功能。

它的样子如下:

class CustomDataGenerator(ImageDataGenerator):    def __init__(self, color=False, **kwargs):        super().__init__(preprocessing_function=self.augment_color, **kwargs)        self.hue = None        if color:            self.hue = random.random()    def augment_color(self, img):        if not self.hue or random.random() < 1/3:            return img        img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)        img_hsv[:, :, 0] = self.hue        return cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR)

我之前在 ImageDataGenerator 中使用了 rescale=1./255,但有些模型需要不同的预处理方法。

所以当我尝试

CustomDataGenerator(preprocessing_function=tf.keras.applications.xception.preprocess_input)

我得到了以下错误:

__init__() got multiple values for keyword argument 'preprocessing_function'

回答:

问题在于,你已经在这里传递了一次 preprocessing_function

super().__init__(preprocessing_function=self.augment_color, **kwargs)

然后又从这里传递了一次

CustomDataGenerator(preprocessing_function=tf.keras.applications.xception.preprocess_input)

所以现在看起来像是

super().__init__(preprocessing_function=self.augment_color, preprocessing_function=tf.keras.applications.xception.preprocess_input)

移除其中一个,你就可以继续了。

编辑 1:如果你想保留这两个函数,最好将它们合并成一个预处理方法,并将其作为 preprocessing_function 传递

CustomDataGenerator 中添加以下方法

    def preprocess(self, image):        image = self.augment_color(image)        return tf.keras.applications.xception.preprocess_input(image)

使用这个作为预处理函数

super().__init__(preprocessing_function=self.preprocess, **kwargs)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注