我正在尝试将tensorflow 1中保存的模型转换为tensorflow 2。我正在将代码迁移到tensorflow 2,如tensorflow文档中所强调的。然而,我希望简单地将我的model_weights.ckpt
更新到tensorflow 2。一些权重(如Linear
、Embdedding
)的形状与tensorflow 2语法类似,但我很难将我的GRUCell
的权重进行转换。
如何将GRUCell
权重从compat.v1.nn.rnn_cell.GRUCell
转换到keras.layers.GRUCell
?
GRUCell
有四个权重:
gru_cell/gates/kernel:0
形状为(S + H, 2 x H)
,gru_cell/gates/bias:0
形状为(2 x H, )
,gru_cell/candidate/kernel:0
形状为(S + H, H)
,gru_cell/candidate/bias:0
形状为(H, )
我希望得到与tensorflow 2 API(或PyTorch API)相似的形状的权重,即具有以下权重的GRUCell
:
gru_cell/kernel:0
形状为(S, 3 x H)
gru_cell/recurrent_kernel:0
形状为(H, 3 x H)
gru_cell/bias:0
形状为(2, 3 x H)
为了说明,您可以重现这些结果:
1. 使用tensorflow 1 API的GRUCell
import tensorflow as tfSEQ_LENGTH = 4HIDDEN_SIZE = 512BATCH_SIZE = 1inputs = tf.random.normal([BATCH_SIZE, SEQ_LENGTH])# GRU cellgru = tf.compat.v1.nn.rnn_cell.GRUCell(HIDDEN_SIZE)# Hidden statestate = gru.zero_state(BATCH_SIZE, tf.float32)# Forwardoutput, state = gru(inputs, state)for weight in gru.weights: print(weight.name, weight.shape)
输出:
gru_cell/gates/kernel:0 (516, 1024)gru_cell/gates/bias:0 (1024,)gru_cell/candidate/kernel:0 (516, 512)gru_cell/candidate/bias:0 (512,)
2. 使用tensorflow 2 API的GRUCell
import tensorflow as tfSEQ_LENGTH = 4HIDDEN_SIZE = 512BATCH_SIZE = 1inputs = tf.random.normal([BATCH_SIZE , SEQ_LENGTH])# GRU cellgru = tf.keras.layers.GRUCell(HIDDEN_SIZE)# Hidden statestate = tf.zeros((BATCH_SIZE, HIDDEN_SIZE), dtype=tf.float32)# Forwardoutput, state = gru(inputs, state)# Display the weigthsfor weight in gru.weights: print(weight.name, weight.shape)
输出:
gru_cell/kernel:0 (4, 1536)gru_cell/recurrent_kernel:0 (512, 1536)gru_cell/bias:0 (2, 1536)
注意
- 我尝试使用
_convert_rnn_weights
tensorflow函数来转换所需的权重。它有效,但仅适用于CuDNN
权重,所以在我的情况下无法使用。
回答:
为了社区的利益,这里提供解决方案,尽管它已经在Github上展示了。
简而言之,compat.v1.nn.rnn_cell.GRUCell
和keras.layers.GRUCell
之间的权重彼此不兼容。我们没有用于在它们之间转换的函数,如果你真的想这样做,你需要手动完成。
从数学上讲,如果你有v1权重的numpy值,公式如下:
B = batch_size
H = state_size
- all_kernel = np.concat([gru_cell/gates/kernel, gru_cell/candidate/kernel], axis=1) # 形状 (B+H, 3 * H)
- kernel = all_kernel[:B] # 形状(B, 3 * H)
- recurrent_kernel = all_kernel[B:] # 形状 (H, 3 * H)
- bias = np.concat([gru_cell/gates/bias, gru_cell/candidate/bias], axis=0) # 形状 (B, 3 * H)
- zero_bias = np.zeros([B, 3 * H])
- bias = np.concat([bias, zero_bias], axis=0)