我有一个名为 main_decoder 的 3D 张量,形状为 (None,9,256)
我想提取 9 个形状为 (None,256) 的张量
我尝试使用 Keras 的 gather 函数,以下是我的代码片段:
for i in range(0,9): sub_decoder_input = Lambda(lambda main_decoder:gather(main_decoder,(i)), name='lambda'+str(i))(main_decoder)
结果是 9 个形状为 (9,256) 的 lambda 层
如何修改代码,以便我能获取或提取 9 个形状为 (None,256) 的张量?
谢谢。
回答:
你可以将 3D 张量切片成 9 个 2D 张量,并从 Lambda
层返回一个张量列表。
main_decoder = Input(shape=(9, 256))sub_decoder_input = Lambda(lambda x: [x[:, i, :] for i in range(9)])(main_decoder)print(sub_decoder_input)[<tf.Tensor 'lambda_1/strided_slice:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lambda_1/strided_slice_1:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lambda_1/strided_slice_2:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lambda_1/strided_slice_3:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lambda_1/strided_slice_4:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lambda_1/strided_slice_5:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lambda_1/strided_slice_6:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lambda_1/strided_slice_7:0' shape=(?, 256) dtype=float32>, <tf.Tensor 'lambda_1/strided_slice_8:0' shape=(?, 256) dtype=float32>]