Problem summary
In the example below, my NMT model incurs a high loss because it correctly predicts target_input instead of target_output.
Targetin : 1 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 9 10 10 10 3 3 10 10 3 10 3 3 10 10 3 9 9 4 4 4 4 4 3 10 3 3 9 9 3 6 6 6 6 6 6 10 9 9 10 10 4 4 4 4 4 4 4 4 4 4 4 4 9 9 9 9 3 3 3 6 6 6 6 6 9 9 10 3 4 4 4 4 4 4 4 4 4 4 4 4 9 9 10 3 10 9 9 3 4 4 4 4 4 4 4 4 4 10 10 4 4 4 4 4 4 4 4 4 4 9 9 10 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 9 3 3 10 6 6 6 6 6 3 9 9 3 3 3 3 3 3 3 10 10 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 5 3 3 3 3 10 10 10 3 9 9 5 10 3 3 3 3 9 9 9 5 10 10 10 10 10 4 4 4 4 3 10 6 6 6 6 6 6 3 5 10 10 10 10 3 9 9 6 6 6 6 6 6 6 6 6 9 9 9 3 3 3 6 6 6 6 6 6 6 6 3 9 9 9 3 3 6 6 6 3 3 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0Targetout : 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 9 10 10 10 3 3 10 10 3 10 3 3 10 10 3 9 9 4 4 4 4 4 3 10 3 3 9 9 3 6 6 6 6 6 6 10 9 9 10 10 4 4 4 4 4 4 4 4 4 4 4 4 9 9 9 9 3 3 3 6 6 6 6 6 9 9 10 3 4 4 4 4 4 4 4 4 4 4 4 4 9 9 10 3 10 9 9 3 4 4 4 4 4 4 4 4 4 10 10 4 4 4 4 4 4 4 4 4 4 9 9 10 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 9 3 3 10 6 6 6 6 6 3 9 9 3 3 3 3 3 3 3 10 10 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 5 3 3 3 3 10 10 10 3 9 9 5 10 3 3 3 3 9 9 9 5 10 10 10 10 10 4 4 4 4 3 10 6 6 6 6 6 6 3 5 10 10 10 10 3 9 9 6 6 6 6 6 6 6 6 6 9 9 9 3 3 3 6 6 6 6 6 6 6 6 3 9 9 9 3 3 6 6 6 3 3 3 3 3 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0Prediction : 3 3 3 3 3 6 6 6 9 7 7 7 4 4 4 4 4 9 3 3 3 3 3 3 10 3 3 10 3 3 10 3 3 9 3 4 4 4 4 4 3 10 3 3 9 3 3 6 6 6 6 6 6 10 9 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 3 3 3 6 6 6 6 6 9 6 3 3 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 10 9 3 3 4 4 4 4 4 4 4 4 4 3 10 4 4 4 4 4 4 4 4 4 4 9 3 3 3 6 6 6 6 3 3 3 10 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 3 3 10 6 6 6 6 6 3 9 3 3 3 3 3 3 3 3 3 3 3 9 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 9 3 6 6 6 6 6 6 3 3 3 3 3 3 10 3 3 3 9 3 3 10 3 3 3 3 9 3 9 3 10 3 3 3 3 4 4 4 4 3 10 6 6 6 6 6 6 3 3 10 3 3 3 3 9 3 6 6 6 6 6 6 6 6 6 9 6 9 3 3 3 6 6 6 6 6 6 6 6 3 9 3 9 3 3 6 6 6 3 3 3 3 3 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 
6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6Source : 9 16 4 7 22 22 19 1 12 19 12 18 5 18 9 18 5 8 12 19 19 5 5 19 22 7 12 12 6 19 7 3 20 7 9 14 4 11 20 12 7 1 18 7 7 5 22 9 13 22 20 19 7 19 7 13 7 11 19 20 6 22 18 17 17 1 12 17 23 7 20 1 13 7 11 11 22 7 12 1 13 12 5 5 19 22 5 5 20 1 5 4 12 9 7 12 8 14 18 22 18 12 18 17 19 4 19 12 11 18 5 9 9 5 14 7 11 6 4 17 23 6 4 5 12 6 7 14 4 20 6 8 12 25 4 19 6 1 5 1 5 20 4 18 12 12 1 11 12 1 25 13 18 19 7 12 7 3 4 22 9 9 12 4 8 9 19 9 22 22 19 1 19 7 5 19 4 5 18 11 13 9 4 14 12 13 20 11 12 11 7 6 1 11 19 20 7 22 22 12 22 22 9 3 8 12 11 14 16 4 11 7 11 1 8 5 5 7 18 16 22 19 9 20 4 12 18 7 19 7 1 12 18 17 12 19 4 20 9 9 1 12 5 18 14 17 17 7 4 13 16 14 12 22 12 22 18 9 12 11 3 18 6 20 7 4 20 7 9 1 7 25 13 5 25 14 11 5 20 7 23 12 5 16 19 19 25 19 7 -1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
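As is evident, the prediction matches target_input almost 100% of the time, rather than target_output as it should (it is off by one). The loss and gradients are computed against target_output, so it is strange that the predictions match target_input.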
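One way to quantify the off-by-one behaviour is to count how often the prediction agrees with each target sequence over the non-padded timesteps. The following is a minimal sketch, not part of my model: the helper name `match_rates` is hypothetical, the padding id is assumed to be 0 (as the dump above suggests), and the arrays are just the first twelve tokens of the sequences printed above.

```python
import numpy as np

def match_rates(prediction, target_in, target_out, pad_id=0):
    """Fraction of non-padded steps where prediction equals target_in / target_out.

    pad_id=0 is an assumption based on the zero padding in the printed example.
    """
    prediction = np.asarray(prediction)
    target_in = np.asarray(target_in)
    target_out = np.asarray(target_out)
    valid = target_out != pad_id          # ignore padded timesteps
    n_valid = valid.sum()
    in_rate = float(((prediction == target_in) & valid).sum() / n_valid)
    out_rate = float(((prediction == target_out) & valid).sum() / n_valid)
    return in_rate, out_rate

# First twelve tokens of each sequence from the example above
target_in = [1, 3, 3, 3, 3, 6, 6, 6, 9, 7, 7, 7]
target_out = [3, 3, 3, 3, 6, 6, 6, 9, 7, 7, 7, 4]
prediction = [3, 3, 3, 3, 3, 6, 6, 6, 9, 7, 7, 7]

in_rate, out_rate = match_rates(prediction, target_in, target_out)
print("matches target_in: %.2f, matches target_out: %.2f" % (in_rate, out_rate))
```

On this slice the prediction disagrees with target_in only at the first position (the start token), while it misses target_out at every point where the token changes.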
Model overview
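An NMT (neural machine translation) model uses a primary sequence of words in a source language to predict a sequence of words in a target language. This is the framework behind Google Translate. Because NMT uses coupled RNNs, it is trained with supervision and requires labelled target input and output.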
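NMT uses a source sequence, a target_input sequence, and a target_output sequence. In the example below, the encoder RNN (blue) uses the source input words to produce a meaning vector and passes it to the decoder RNN (red), which uses that meaning vector to produce the output.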
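When making new predictions (inference), the decoder RNN uses its own previous output to seed the prediction at the next timestep. To improve training, however, it is allowed to seed each new timestep with the correct previous prediction instead. This is why target_input is necessary for training.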
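The difference between the two decoding modes can be sketched in plain Python. This is a toy stand-in, not the TensorFlow model: `make_decoder_pair`, `run_decoder`, and `copy_step` are hypothetical names, and the ids `SOS = 1` and `EOS = 2` are assumptions read off the example dump. The "decoder" here deliberately just copies its input, to show what that degenerate behaviour produces under each mode.

```python
SOS, EOS = 1, 2  # assumed start/end token ids, consistent with the example dump

def make_decoder_pair(target):
    """Build the shifted pair: target_in = <sos> + target, target_out = target + <eos>."""
    target_in = [SOS] + list(target)
    target_out = list(target) + [EOS]
    return target_in, target_out

def run_decoder(step_fn, num_steps, teacher_inputs=None):
    """Run a toy decoder for num_steps.

    With teacher_inputs (training), each step is fed the ground-truth previous
    token. Without it (inference), each step is fed the decoder's own last output.
    """
    outputs = []
    token = SOS
    for t in range(num_steps):
        if teacher_inputs is not None:
            token = teacher_inputs[t]   # teacher forcing
        token = step_fn(token)          # otherwise the previous output is reused
        outputs.append(token)
    return outputs

def copy_step(token):
    """Degenerate decoder step that simply repeats whatever it was fed."""
    return token

target_in, target_out = make_decoder_pair([3, 3, 6, 9])
forced = run_decoder(copy_step, len(target_in), teacher_inputs=target_in)
free = run_decoder(copy_step, len(target_in))
print("target_in :", target_in)
print("target_out:", target_out)
print("forced    :", forced)
print("free      :", free)
```

Under teacher forcing, a decoder that only echoes its input reproduces target_in exactly, which is right wherever the target repeats and wrong wherever it changes.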
Code to get an iterator containing source, target_in, target_out
    import os
    import tensorflow as tf  # TensorFlow 1.x API

    # `utils` is my own helper module (its import is not shown here)

    def get_batched_iterator(hparams, src_loc, tgt_loc):
        if not (os.path.exists('primary.csv') and os.path.exists('secondary.csv')):
            utils.integerize_raw_data()

        source_dataset = tf.data.TextLineDataset(src_loc)
        target_dataset = tf.data.TextLineDataset(tgt_loc)
        dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
        dataset = dataset.shuffle(hparams.shuffle_buffer_size, seed=hparams.shuffle_seed)

        # Parse each comma-separated line into a vector of int32 token ids
        dataset = dataset.map(lambda source, target: (
            tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
            tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
        # target_in = <sos> + target, target_out = target + <eos>
        dataset = dataset.map(lambda source, target: (
            source,
            tf.concat(([hparams.sos], target), axis=0),
            tf.concat((target, [hparams.eos]), axis=0)))
        # Append the sequence lengths
        dataset = dataset.map(lambda source, target_in, target_out: (
            source, target_in, target_out, tf.size(source), tf.size(target_in)))

        # Proceed to batch and return iterator
NMT model core code
    import tensorflow as tf  # TensorFlow 1.x API
    from tensorflow.python.layers import core as layers_core

    def __init__(self, hparams, iterator, mode):
        source, target_in, target_out, source_lengths, target_lengths = iterator.get_next()

        # Lookup embeddings
        embedding_encoder = tf.get_variable("embedding_encoder", [hparams.src_vsize, hparams.src_emsize])
        encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)
        embedding_decoder = tf.get_variable("embedding_decoder", [hparams.tgt_vsize, hparams.tgt_emsize])
        decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, target_in)

        # Build and run Encoder LSTM
        encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            encoder_cell, encoder_emb_inp, sequence_length=source_lengths, dtype=tf.float32)

        # Build and run Decoder LSTM with TrainingHelper and output projection layer
        decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
        projection_layer = layers_core.Dense(hparams.tgt_vsize, use_bias=False)
        helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, sequence_length=target_lengths)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            decoder_cell, helper, encoder_state, output_layer=projection_layer)
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
        logits = outputs.rnn_output

        # Strings are compared with == / in, not `is` (identity is not guaranteed)
        if mode in ('TRAIN', 'EVAL'):  # then calculate loss
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_out, logits=logits)
            target_weights = tf.sequence_mask(
                target_lengths, maxlen=tf.shape(target_out)[1], dtype=logits.dtype)
            self.loss = tf.reduce_sum((crossent * target_weights) / hparams.batch_size)

        if mode == 'TRAIN':  # then calculate/clip gradients, then optimize model
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, hparams.max_gradient_norm)
            optimizer = tf.train.AdamOptimizer(hparams.l_rate)
            self.update_step = optimizer.apply_gradients(zip(clipped_gradients, params))

        if mode == 'EVAL':  # then allow access to input/output tensors to print out
            self.src = source
            self.tgt_in = target_in
            self.tgt_out = target_out
            self.logits = logits
Answer:
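The core issue with using an NMT model to predict a language-like syntax with a repetitive structure is that the model is incentivized to simply predict whatever the previous prediction was. Because TrainingHelper feeds it the correct previous prediction at every step to speed up training, this artificially creates a local minimum that the model cannot escape.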
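The best solution I have found is to weight the loss function so that the key points in the output sequence, where the output is not repetitive, are weighted more heavily. This gives the model an incentive to predict those points correctly rather than just repeating the previous prediction.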
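A rough sketch of how such weights might be built, using NumPy for illustration rather than the TensorFlow graph itself: timesteps where target_out differs from target_in (that is, where the next token is not a repeat of the previous one) get a larger weight, and padded steps get zero. The function name `transition_weights`, the `boost` value of 5.0, and the padding id 0 are all assumptions for the example, not values from my model.

```python
import numpy as np

def transition_weights(target_in, target_out, boost=5.0, pad_id=0):
    """Per-timestep loss weights that emphasise non-repetitive positions.

    boost and pad_id are illustrative assumptions and would need tuning.
    """
    target_in = np.asarray(target_in)
    target_out = np.asarray(target_out)
    weights = np.ones(target_out.shape, dtype=np.float32)
    weights[target_out != target_in] = boost   # the token changes here: weight it more
    weights[target_out == pad_id] = 0.0        # padded steps contribute nothing
    return weights

# Toy sequences: <sos>=1, <eos>=2, padding=0
target_in = [1, 3, 3, 6, 9, 0, 0]
target_out = [3, 3, 6, 9, 2, 0, 0]
weights = transition_weights(target_in, target_out)
print(weights)
```

Weights of this shape would take the place of the plain sequence mask in the `crossent * target_weights` product. Inside the graph, the same comparison could presumably be expressed with `tf.not_equal(target_in, target_out)`.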