使用tf.Keras.Sequential API和LSTM在评估模型时激活循环丢弃

我之前训练了一个没有使用丢弃机制的LSTM编码器/解码器,以便使用GPU加速训练。现在我重建了模型并添加了丢弃和循环丢弃,并且希望在推理过程中激活这些机制来估计我的预测的置信区间。我不确定在使用顺序API的模型中如何或在哪里添加training=True参数。我以为可以在评估模型时这样做,但似乎并非如此:

encoder_decoder_dropout.evaluate(val, steps=VALIDATION_STEPS, training=True)evaluate() got an unexpected keyword argument 'training'

我的模型如下。关于如何在推理过程中激活丢弃机制的任何建议将非常受欢迎。

encoder = encoder_decoder_dropout_trained.layers[0]decoder_dropout = tf.keras.Sequential([  tf.keras.layers.RepeatVector(look_back, input_shape=[60]),  tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(100, dropout=0.2,                                                     recurrent_dropout=0.2,                                                     return_sequences=False)),   tf.keras.layers.Dense(look_forward*num_features,                        kernel_initializer=tf.initializers.glorot_normal()),  tf.keras.layers.Reshape([look_forward, num_features])], name = 'decoder')encoder_decoder_dropout = tf.keras.Sequential([encoder, decoder_dropout])encoder_decoder_dropout.set_weights(encoder_decoder_dropout_trained.get_weights()) 

回答:

要在推理时激活丢弃机制,你只需在感兴趣的层中指定training=True(在你的情况下是LSTM层)

使用training=False

inp = Input(shape=(10, 1))x = LSTM(1, recurrent_dropout=0.3)(inp, training=False)m = Model(inp,x)# m.compile(...)# m.fit(...)X = np.random.uniform(0,1, (1,10,1))output = []for i in range(0,100):    output.append(m.predict(X)) # 每次结果相同

使用training=True

inp = Input(shape=(10, 1))x = LSTM(1, recurrent_dropout=0.3)(inp, training=True)m = Model(inp,x)# m.compile(...)# m.fit(...)X = np.random.uniform(0,1, (1,10,1))output = []for i in range(0,100):    output.append(m.predict(X)) # 每次结果不同

你需要使用Keras的函数式API来实现这一点

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注