如何在TensorFlow中使用Keras高层API实现批量归一化

批量归一化(BatchNormalization,BN)在训练和推理时的操作略有不同。在训练过程中,它使用当前小批量的平均值和方差来缩放其输入;这意味着批量归一化的具体结果不仅取决于当前输入,还取决于小批量中的所有其他元素。这在推理模式下显然是不希望的,因为我们希望得到一个确定性的结果。因此,在这种情况下,会使用整个训练集的全局平均值和方差的固定统计数据。

在TensorFlow中,这种行为由一个布尔开关training控制,该开关在调用层时需要指定,详见https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization。当使用Keras高层API时,如何处理这个开关?我是否可以假设它会根据我们是使用model.fit(x, ...)还是model.predict(x, ...)来自动处理?


为了测试这一点,我编写了这个示例。我们从一个随机分布开始,我们希望对输入是正数还是负数进行分类。然而,我们还有一个来自不同分布的测试数据集,其输入偏移了2(因此标签检查x是否大于2)。

运行代码后打印的值为0.5,这意味着有一半的例子被正确标记。这是我期望的结果,如果系统使用训练集的全局统计数据来实现BN的话。

如果我们将BN层改为

x = Input(shape=(1,))a = BatchNormalization()(x, training=True)a = Dense(8, activation='sigmoid')(a)a = BatchNormalization()(a, training=True)y = Dense(1, activation='sigmoid')(a)

然后再次运行代码,我们发现结果为0.87。强制始终处于训练状态,正确预测的百分比发生了变化。这与model.predict(x, ...)现在使用小批量的统计数据来实现BN的想法是一致的,因此能够稍微“纠正”训练数据和测试数据之间源分布的不匹配。

这是正确的吗?


回答:

如果我正确理解了你的问题,那么是的,Keras确实会根据fitpredict/evaluate自动管理训练与推理行为。这个标志被称为learning_phase,它决定了批量归一化、丢弃(dropout)以及其他可能事物的行为。可以通过keras.backend.learning_phase()查看当前的学习阶段,并通过keras.backend.set_learning_phase()设置。

https://keras.io/backend/#learning_phase

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注