我在使用pytorch的nn.MultiheadAttention
实现变换器编码器中的自注意力部分时,对变换器的填充掩码感到困惑。
下图展示了查询(行)和键(列)的自注意力权重。
如您所见,有一些标记为”<PAD>”的标记,我已经在键中对其进行了掩码处理。因此,这些标记不会计算注意力权重。
仍然有两个问题:
-
在查询部分,除了红色方块部分之外,我也能够掩码这些”<PAD>”吗?这样做合理吗?
-
如何在查询中掩码”<PAD>”?
注意力权重还通过在src_mask
或src_key_padding_mask
参数中提供掩码,沿着行使用softmax
函数。如果我将所有”<PAD>”行设置为-inf
,softmax
将返回nan
,损失也会变成nan
。
回答:
在自注意力过程中没有必要对查询进行掩码,只要在网络的后续部分不使用对应于<PAD>
标记的状态(无论是作为隐藏状态还是键/值),它们就不会影响损失函数或网络中的其他任何部分。
如果您想确保没有因梯度通过<PAD>
标记而导致的错误,您可以在自注意力计算后使用torch.where
明确地将其置零。