我正在尝试将子词嵌入平均起来,形成词级表示。每词对应一个开始和结束索引,指示哪些子词构成该词。
sequence_output
是一个 B * 3 * 2 的张量,其中 3 是最大序列长度,2 是特征数量。
all_token_mapping
是一个 B * 3 * 2 的张量,包含开始和结束索引。
initial_reps
是一个 num_nodes * 2 的张量,num_nodes 是所有样本中词(不是子词)的总数。
sequence_output = torch.arange(2*3*2).float().reshape(2, 3, 2)tensor([[[ 0., 1.], [ 2., 3.], [ 4., 5.]], [[ 6., 7.], [ 8., 9.], [10., 11.]]])
all_token_mapping = torch.tensor([[[0,0],[1,2],[-1,-1]], [[0,2],[-1,-1],[-1,-1]]])tensor([[[ 0, 0], [ 1, 2], [-1, -1]], [[ 0, 2], [-1, -1], [-1, -1]]])
num_nodes = 0for sample in all_token_mapping: for mapping in sample: if mapping[0] != -1: num_nodes += 13
initial_reps = torch.empty((num_nodes, 2), dtype=torch.float32)current_idx = 0for i, feature_tokens_mapping in enumerate(all_token_mapping): for j, token_mapping in enumerate(feature_tokens_mapping): if token_mapping[0] == -1: # 到达该序列的末尾 break initial_reps[current_idx] = torch.mean(sequence_output[i][token_mapping[0]:token_mapping[-1] + 1], 0, keepdim=True) current_idx += 1initial_repstensor([[0., 1.], [3., 4.], [8., 9.]])
在上面的示例中,initial_reps[0] 将是 sequence_output[0][0:1] 的均值,initial_reps[1] 将是 sequence_output[0][1:3] 的均值,initial_reps[2] 将是 sequence_output[1][0:3] 的均值。
我当前的代码会创建一个长度为 num_nodes 的空张量,并使用 for 循环计算每个索引的值,通过检查 token_mapping[0] 和 token_mapping[1] 来确定 sequence_output 需要平均的正确切片。
有没有办法将这段代码向量化呢?
此外,我有一个列表,保存了每个样本的词数。即列表中所有元素的总和 == num_nodes
回答:
在 https://discuss.pytorch.org/t/vectorize-for-loop-need-to-average-slices-of-varying-size/122618/2 某人的帮助下,我找到了一个方法
initial_reps_list = []for i, sample_output in enumerate(sequence_output): token_mapping = all_token_mapping[i] token_mapping = token_mapping[token_mapping != -1] non_padded_outputs = sample_output[:num_bert_tokens[i]] initial_reps_list.append(torch_scatter.segment_coo(non_padded_outputs, token_mapping, reduce="mean"))initial_reps = torch.cat(initial_reps_list)
token_mapping 是一个按升序排列的索引列表,最大长度为序列长度,用 -1 填充。我遍历批次,对于每个样本,我获取 token mapping,只保留非负索引。
num_bert_tokens 是一个列表,保存了每个样本的标记数量(无填充)。我获取未填充的输出,使用 segment_coo 根据 token_mapping 进行归约,并将它们全部添加到一个列表中。
循环结束后,我将列表中的所有张量连接在一起。
segment_coo 方法将 src 张量中的所有值归约到 index 张量指定的索引处,沿 index 的最后一个维度进行。更多详情请查看:https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html
现在运行速度快多了!