使用TensorFlow实现批量归一化

我正在尝试在tensor-flow中实现一个批量归一化层。我使用tf.moments来获取均值方差,在训练步骤中运行没有问题。

在测试时,我希望设置一个指数移动平均来跟踪均值和方差。我尝试这样做:

def batch_normalized_linear_layer(state_below, scope_name, n_inputs, n_outputs, stddev, wd, eps=.0001):  with tf.variable_scope(scope_name) as scope:    weight = _variable_with_weight_decay(      "weights", shape=[n_inputs, n_outputs],      stddev=stddev, wd=wd    )    act = tf.matmul(state_below, weight)    # 获取矩    act_mean, act_variance = tf.nn.moments(act, [0])    # 获取均值和方差变量    mean = _variable_on_cpu('bn_mean', [n_outputs], tf.constant_initializer(0.0))    variance = _variable_on_cpu('bn_variance', [n_outputs], tf.constant_initializer(1.0))    # 分配矩    assign_mean = mean.assign(act_mean)    assign_variance = variance.assign(act_variance)    act_bn = tf.mul((act - mean), tf.rsqrt(variance + eps), name=scope.name+"_bn")    beta = _variable_on_cpu("beta", [n_outputs], tf.constant_initializer(0.0))    gamma = _variable_on_cpu("gamma", [n_outputs], tf.constant_initializer(1.0))    bn = tf.add(tf.mul(act_bn, gamma), beta)    output = tf.nn.relu(bn, name=scope.name)    _activation_summary(output)  return output, mean, variance

其中_variable_on_cpu定义如下:

def _variable_on_cpu(name, shape, initializer):  """在CPU内存上创建变量的辅助函数。  Args:    name: 变量名称    shape: 整数列表    initializer: 变量的初始化器  Returns:    Variable Tensor  """  with tf.device('/cpu:0'):    var = tf.get_variable(name, shape, initializer=initializer)  return var

我认为我设置

assign_mean = mean.assign(act_mean)assign_variance = variance.assign(act_variance)

的方式可能不正确,但我并不确定具体问题在哪里。当我使用TensorBoard来跟踪这些均值和方差变量时,它们的值始终保持在初始化的值上。


回答:

@的人的评论指出了问题的核心:你没有运行分配节点。你可以尝试使用我在另一个回答中发布的批量归一化辅助函数 – 如何在TensorFlow中使用批量归一化? – 或者你可以按照他的建议,通过添加with_dependencies来强制执行分配操作。

一般原则是,只有当数据或控制依赖关系“通过”节点时,你才能指望该节点被运行。with_dependencies确保在使用输出操作之前,指定的依赖关系已经完成。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注