根据文档,创建一个transformer模型的方式如下:
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)src = torch.rand((10, 32, 512))tgt = torch.rand((20, 32, 512)) # tgt是什么?out = transformer_model(src, tgt)
tgt的含义是什么?tgt应该和src相同吗?
回答:
变换器结构由两个部分组成,编码器和解码器。src是输入到编码器的数据,而tgt是输入到解码器的数据。
例如,在执行将英语句子翻译成法语的机器翻译任务时,src是英语序列的ID,而tgt是法语序列的ID。