在Keras中使用函数式API将遮蔽层添加到GRU模型的正确方法?

我正在尝试弄清楚如何在Keras的函数式API中使用遮蔽层

使用非函数式的Keras语法,我可以轻松创建一个GRU模型,像这样遮蔽所有零值:

model = tf.keras.Sequential()model.add( tf.keras.layers.Masking( mask_value = 0.0, input_shape = ( nTimeSteps, nVariables ) ) )model.add( tf.keras.layers.GRU( 32 ) )model.add( tf.keras.layers.Dense( 10, activation = "softmax" ) )opt = tf.keras.optimizers.SGD( learning_rate = 0.001 )model.compile( loss = 'categorical_crossentropy', optimizer = opt, metrics = ['accuracy'] )

我尝试使用函数式API复制这个模型,目前看起来是这样的:

x = tf.keras.layers.Masking( mask_value = 0.0, input_shape = ( nTimeSteps, nVariables ) )x = tf.keras.layers.GRU( 32 )( x )z = tf.keras.layers.Dense( numberOfOutputs, activation = "softmax" )( x )model = tf.keras.Model( inputs = x, outputs = z )opt = tf.keras.optimizers.SGD( learning_rate = 0.001 )model.compile( loss = 'categorical_crossentropy', optimizer = opt, metrics = ['accuracy'] )

然而,这不起作用 – 它会产生以下错误:

AttributeError: 'Masking' object has no attribute 'shape'

在函数式API中使用遮蔽层正确的做法是什么?


回答:

你缺少了函数式API格式中的Input层。这里是一个示例代码:

nsamples = 10nTimeSteps, nVariables = 6, 4numberOfOutputs = 2X = np.random.randint(0,6, (nsamples ,nTimeSteps, nVariables))y = np.random.randint(0,numberOfOutputs, nsamples)inp = tf.keras.Input(shape = ( nTimeSteps, nVariables ))x = tf.keras.layers.Masking( mask_value = 0.0 )(inp)x = tf.keras.layers.GRU( 32 )( x )z = tf.keras.layers.Dense( numberOfOutputs, activation = "softmax" )( x )model = tf.keras.Model( inputs = inp, outputs = z )opt = tf.keras.optimizers.SGD( learning_rate = 0.001 )model.compile( loss = 'sparse_categorical_crossentropy',               optimizer = opt, metrics = ['accuracy'] )model.fit(X,y, epochs=3)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注