如何计算LSTM网络的参数数量?

是否有办法计算LSTM网络中的总参数数量?

我找到一个例子,但我不确定这个是否正确,或者我是否正确理解了它。

例如,考虑以下例子:

from keras.models import Sequentialfrom keras.layers import Dense, Dropout, Activationfrom keras.layers import Embeddingfrom keras.layers import LSTMmodel = Sequential()model.add(LSTM(256, input_dim=4096, input_length=16))model.summary()

输出

____________________________________________________________________________________________________Layer (type)                       Output Shape        Param #     Connected to                     ====================================================================================================lstm_1 (LSTM)                      (None, 256)         4457472     lstm_input_1[0][0]               ====================================================================================================Total params: 4457472____________________________________________________________________________________________________

根据我的理解,n是输入向量的长度,m是时间步数。在这个例子中,他们认为隐藏层的数量为1。

因此,根据帖子中的公式,我的例子中4(nm+n^2)m=16n=4096num_of_units=256

4*((4096*16)+(4096*4096))*256 = 17246978048

为什么会有这么大的差异?我是否误解了这个例子,还是公式错了?


回答:

不,Keras中LSTM层的参数数量等于:

params = 4 * ((size_of_input + 1) * size_of_output + size_of_output^2)

额外的1来自偏置项。因此,n是输入的大小(增加了偏置项),m是LSTM层的输出大小。

所以最终是:

4 * (4097 * 256 + 256^2) = 4457472

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注