TensorFlow: 将GRUCell权重从compat.v1转换到tensorflow 2

我正在尝试将tensorflow 1中保存的模型转换为tensorflow 2。我正在将代码迁移到tensorflow 2,如tensorflow文档中所强调的。然而,我希望简单地将我的model_weights.ckpt更新到tensorflow 2。一些权重(如LinearEmbdedding)的形状与tensorflow 2语法类似,但我很难将我的GRUCell的权重进行转换。

如何将GRUCell权重从compat.v1.nn.rnn_cell.GRUCell转换到keras.layers.GRUCell

GRUCell有四个权重:

  • gru_cell/gates/kernel:0形状为(S + H, 2 x H)
  • gru_cell/gates/bias:0形状为(2 x H, )
  • gru_cell/candidate/kernel:0形状为(S + H, H)
  • gru_cell/candidate/bias:0形状为(H, )

我希望得到与tensorflow 2 API(或PyTorch API)相似的形状的权重,即具有以下权重的GRUCell

  • gru_cell/kernel:0形状为(S, 3 x H)
  • gru_cell/recurrent_kernel:0形状为(H, 3 x H)
  • gru_cell/bias:0形状为(2, 3 x H)

为了说明,您可以重现这些结果:

1. 使用tensorflow 1 API的GRUCell

import tensorflow as tfSEQ_LENGTH = 4HIDDEN_SIZE = 512BATCH_SIZE = 1inputs = tf.random.normal([BATCH_SIZE, SEQ_LENGTH])# GRU cellgru = tf.compat.v1.nn.rnn_cell.GRUCell(HIDDEN_SIZE)# Hidden statestate = gru.zero_state(BATCH_SIZE, tf.float32)# Forwardoutput, state = gru(inputs, state)for weight in gru.weights:    print(weight.name, weight.shape)

输出:

gru_cell/gates/kernel:0 (516, 1024)gru_cell/gates/bias:0 (1024,)gru_cell/candidate/kernel:0 (516, 512)gru_cell/candidate/bias:0 (512,)

2. 使用tensorflow 2 API的GRUCell

import tensorflow as tfSEQ_LENGTH = 4HIDDEN_SIZE = 512BATCH_SIZE = 1inputs = tf.random.normal([BATCH_SIZE , SEQ_LENGTH])# GRU cellgru = tf.keras.layers.GRUCell(HIDDEN_SIZE)# Hidden statestate = tf.zeros((BATCH_SIZE, HIDDEN_SIZE), dtype=tf.float32)# Forwardoutput, state = gru(inputs, state)# Display the weigthsfor weight in gru.weights:    print(weight.name, weight.shape)

输出:

gru_cell/kernel:0 (4, 1536)gru_cell/recurrent_kernel:0 (512, 1536)gru_cell/bias:0 (2, 1536)

注意


回答:

为了社区的利益,这里提供解决方案,尽管它已经在Github上展示了。

简而言之,compat.v1.nn.rnn_cell.GRUCellkeras.layers.GRUCell之间的权重彼此不兼容。我们没有用于在它们之间转换的函数,如果你真的想这样做,你需要手动完成。

从数学上讲,如果你有v1权重的numpy值,公式如下:

B = batch_size

H = state_size

  1. all_kernel = np.concat([gru_cell/gates/kernel, gru_cell/candidate/kernel], axis=1) # 形状 (B+H, 3 * H)
  2. kernel = all_kernel[:B] # 形状(B, 3 * H)
  3. recurrent_kernel = all_kernel[B:] # 形状 (H, 3 * H)
  4. bias = np.concat([gru_cell/gates/bias, gru_cell/candidate/bias], axis=0) # 形状 (B, 3 * H)
  5. zero_bias = np.zeros([B, 3 * H])
  6. bias = np.concat([bias, zero_bias], axis=0)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注