向已训练的TensorFlow Keras模型添加重缩放层(或任何其他层)

我有一个使用TensorFlow 2.3训练的TensorFlow Keras模型。该模型的输入是图像,但由于模型在训练时使用了缩放的输入,因此我们必须在将图像输入模型之前将其缩放到255。

由于我们在各种平台上使用这个模型,我尝试通过修改模型来简化这一过程,即在Keras模型的开始处插入一个重缩放层(即紧跟在输入之后)。这样,未来使用这个模型时,用户只需传递图像,而无需进行缩放处理。

我在这方面遇到了很多困难。我明白我需要使用以下函数来创建一个重缩放层;

tf.keras.layers.experimental.preprocessing.Rescaling(255, 0.0, "rescaling")

但我不确定如何将它插入到模型的开始部分。

提前感谢您


回答:

您可以在已训练模型的顶部插入此层。下面是一个示例,首先我们手动缩放输入来训练一个模型,然后我们使用相同的已训练模型,但在顶部添加一个Rescaling

from tensorflow.keras.layers.experimental.preprocessing import Rescaling# 生成虚拟数据input_dim = (28,28,3)n_sample = 10X = np.random.randint(0,255, (n_sample,)+input_dim)y = np.random.uniform(0,1, (n_sample,))# 创建基础模型inp = Input(input_dim)x = Conv2D(8, (3,3))(inp)x = Flatten()(x)out = Dense(1)(x)# 使用手动缩放拟合基础模型model = Model(inp, out)model.compile('adam', 'mse')model.fit(X/255, y, epochs=3)# 创建带有预训练权重和顶部重缩放的新模型inp = Input(input_dim)scaled_input = Rescaling(1/255, 0.0, "rescaling")(inp)out = model(scaled_input)scaled_model = Model(inp, out)# 比较手动缩放与层缩放的预测结果pred = model.predict(X/255)pred_scaled = scaled_model.predict(X)(pred.round(5) == pred_scaled.round(5)).all() # True

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注