将TensorFlow的LSTM转换为synapticjs

我正在尝试实现一个接口,将一个已经训练好的TensorFlow基本LSTM模型与可以在浏览器中运行的JavaScript版本连接起来。问题在于,我所阅读的所有文献中,LSTM都被建模为小型网络(仅使用连接、节点和门),而TensorFlow似乎有更多的东西在运作。

我有两个问题:

  1. TensorFlow模型是否可以轻松地转换为更传统的神经网络结构?

  2. 是否有实际的方法将TensorFlow提供的可训练变量映射到这个结构中?

我可以从TensorFlow中提取‘可训练变量’,问题是每个LSTM节点的偏置似乎只有一个值,而我见过的多数模型会为记忆单元、输入和输出包含多个偏置。


回答:

在内部,LSTMCell类为了效率,将LSTM权重存储为一个大矩阵,而不是8个较小的矩阵。将它水平和垂直分割以获得更传统的表示是相当容易的。然而,如果你的库也进行类似的优化,可能会更容易和更有效。

这是BasicLSTMCell的相关代码片段:

concat = linear([inputs, h], 4 * self._num_units, True)# i = input_gate, j = new_input, f = forget_gate, o = output_gatei, j, f, o = array_ops.split(1, 4, concat)

linear函数执行矩阵乘法,将拼接的输入和之前的h状态转换为4个形状为[batch_size, self._num_units]的矩阵。线性变换使用了一个单一的矩阵和偏置变量,这些是你问题中提到的。结果随后被分割成LSTM变换使用的不同门。

如果你想明确地获取每个门的变换,你可以将那个矩阵和偏置分割成4个块。使用4个或8个线性变换从头开始实现它也相当容易。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注