如何初始化现有Keras模型的变量?

我想从头开始重新训练Keras模型Inception_v3。

模型在这里定义:https://github.com/keras-team/keras-applications/blob/master/keras_applications/inception_v3.py

我读了一些帖子,

列出的解决方案有:

  1. 冻结层(这不是我想要的…)

    for layer in model.layers:    layer.trainable = False
  2. https://stackoverflow.com/a/51727616/7748163

通过检查初始化器重置所有层:

def reset_weights(model):    session = K.get_session()    for layer in model.layers:         if hasattr(layer, 'kernel_initializer'):            layer.kernel_initializer.run(session=session)        if hasattr(layer, 'bias_initializer'):            layer.bias_initializer.run(session=session)     
  1. 使用tf.variables_initializer

      model = InceptionV3()  for layer in model.layers:      sess.run(tf.variables_initializer(layer.weights))

    参考:https://stackoverflow.com/a/56634827/7748163

  2. 我认为这是最好的方法,但它引发了错误。

    sess = tf.Session()for layer in model.layers:      for v in layer.__dict__:         v_arg = getattr(layer,v)         if hasattr(v_arg,'initializer'):             initializer_method = getattr(v_arg, 'initializer')             initializer_method.run(session=sess)             print('reinitializing layer {}.{}'.format(layer.name, v))

然而,这些方法对于Inception_v3都不起作用。

错误信息涉及BatchNorm层:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable batch_normalization_9/moving_mean from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/batch_normalization_9/moving_mean/N10tensorflow3VarE does not exist.     [[{{node batch_normalization_9_1/AssignMovingAvg/ReadVariableOp}}]]     [[metrics_1/categorical_accuracy/Identity/_469]]

那么,如何重新训练现有的Keras模型,并初始化变量?从Keras应用中重新训练模型的最佳实践是什么?

进一步讨论:

https://github.com/keras-team/keras/issues/341


回答:

为什么不简单地不请求权重呢?

model = Inception_V3(..., weights=None,...)

https://github.com/keras-team/keras-applications/blob/master/keras_applications/inception_v3.py/#L100

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注