神经机器翻译模型预测出现偏差

问题总结

在下面的例子中,我的NMT模型由于正确预测了target_input而不是target_output,导致了高损失。

Targetin   :  1  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  9 10 10 10  3  3 10 10  3 10  3  3 10 10  3  9  9  4  4  4  4  4  3 10  3  3  9  9  3  6  6  6  6  6  6 10  9  9 10 10  4  4  4  4  4  4  4  4  4  4  4  4  9  9  9  9  3  3  3  6  6  6  6  6  9  9 10  3  4  4  4  4  4  4  4  4  4  4  4  4  9  9 10  3 10  9  9  3  4  4  4  4  4  4  4  4  4 10 10  4  4  4  4  4  4  4  4  4  4  9  9 10  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  9  3  3 10  6  6  6  6  6  3  9  9  3  3  3  3  3  3  3 10 10  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  5  3  3  3  3 10 10 10  3  9  9  5 10  3  3  3  3  9  9  9  5 10 10 10 10 10  4  4  4  4  3 10  6  6  6  6  6  6  3  5 10 10 10 10  3  9  9  6  6  6  6  6  6  6  6  6  9  9  9  3  3  3  6  6  6  6  6  6  6  6  3  9  9  9  3  3  6  6  6  3  3  3  3  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0Targetout  :  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  9 10 10 10  3  3 10 10  3 10  3  3 10 10  3  9  9  4  4  4  4  4  3 10  3  3  9  9  3  6  6  6  6  6  6 10  9  9 10 10  4  4  4  4  4  4  4  4  4  4  4  4  9  9  9  9  3  3  3  6  6  6  6  6  9  9 10  3  4  4  4  4  4  4  4  4  4  4  4  4  9  9 10  3 10  9  9  3  4  4  4  4  4  4  4  4  4 10 10  4  4  4  4  4  4  4  4  4  4  9  9 10  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  9  3  3 10  6  6  6  6  6  3  9  9  3  3  3  3  3  3  3 10 10  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  5  3  3  3  3 10 10 10  3  9  9  5 10  3  3  3  3  9  9  9  5 10 10 10 10 10  4  4  4  4  3 10  6  6  6  6  6  6  3  5 10 10 10 10  3  9  9  6  6  6  6  6  6  6  6  6  9  9  9  3  3  3  6  6  6  6  6  6  6  6  3  9  9  9  3  3  6  6  6  3  3  3  3  3  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0Prediction :  3  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  3  3  3  3  3  3 10  3  3 10  3  3 10  3  3  9  3  4  4  4  4  4  3 10  3  3  9  3  3  6  6  6  6  6  6 10  9  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3  3  3  3  6  6  6  6  6  9  6  3  3  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3 10  9  3  3  4  4  4  4  4  4  4  4  4  3 10  4  4  4  4  4  4  4  4  4  4  9  3  3  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3 10  6  6  6  6  6  3  9  3  3  3  3  3  3  3  3  3  3  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  3  3  3  3  3 10  3  3  3  9  3  3 10  3  3  3  3  9  3  9  3 10  3  3  3  3  4  4  4  4  3 10  6  6  6  6  6  6  3  3 10  3  3  3  3  9  3  6  6  6  6  6  6  6  6  6  9  6  9  3  3  3  6  6  6  6  6  6  6  6  3  9  3  9  3  3  6  6  6  3  3  3  3  3  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6Source     :  9 16  4  7 22 22 19  1 12 19 12 18  5 18  9 18  5  8 12 19 19  5  5 19 22  7 12 12  6 19  7  3 20  7  9 14  4 11 20 12  7  1 18  7  7  5 22  9 13 22 20 19  7 19  7 13  7 11 19 20  6 22 18 17 17  1 12 17 23  7 20  1 13  7 11 11 22  7 12  1 13 12  5  5 19 22  5  5 20  1  5  4 12  9  7 12  8 14 18 22 18 12 18 17 19  4 19 12 11 18  5  9  9  5 14  7 11  6  4 17 23  6  4  5 12  6  7 14  4 20  6  8 12 25  4 19  6  1  5  1  5 20  4 18 12 12  1 11 12  1 25 13 18 19  7 12  7  3  4 22  9  9 12  4  8  9 19  9 22 22 19  1 19  7  5 19  4  5 18 11 13  9  4 14 12 13 20 11 12 11  7  6  1 11 19 20  7 22 22 12 22 22  9  3  8 12 11 14 16  4 11  7 11  1  8  5  5  7 18 16 22 19  9 20  4 12 18  7 19  7  1 12 18 17 12 19  4 20  9  9  1 12  5 18 14 17 17  7  4 13 16 14 12 22 12 22 18  9 12 11  3 18  6 20  7  4 20  7  9  1  7 25 13  5 25 14 11  5 20  7 23 12  5 16 19 19 25 19  7 -1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0

显然,预测几乎100%与target_input匹配,而不是应该匹配的target_output(偏差一个)。损失和梯度是使用target_output计算的,因此预测与target_input匹配的情况很奇怪。

模型概述

NMT模型使用源语言中的主要单词序列预测目标语言中的单词序列。这是谷歌翻译背后的框架。由于NMT使用耦合的RNN,因此它是监督学习,需要标记的目标输入和输出。

NMT使用source序列、target_input序列和target_output序列。在下面的例子中,编码器RNN(蓝色)使用源输入单词生成一个意义向量,并将其传递给解码器RNN(红色),后者使用该意义向量生成输出。

在进行新的预测(推理)时,解码器RNN使用其之前的输出作为下一个时间步预测的种子。然而,为了提高训练效果,每个新时间步允许它使用正确的之前预测作为种子。这就是为什么target_input对于训练是必要的。

enter image description here

获取包含source、target_in、target_out的迭代器的代码

def get_batched_iterator(hparams, src_loc, tgt_loc):    if not (os.path.exists('primary.csv') and os.path.exists('secondary.csv')):        utils.integerize_raw_data()    source_dataset = tf.data.TextLineDataset(src_loc)    target_dataset = tf.data.TextLineDataset(tgt_loc)    dataset = tf.data.Dataset.zip((source_dataset, target_dataset))    dataset = dataset.shuffle(hparams.shuffle_buffer_size, seed=hparams.shuffle_seed)    dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),                                                  tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))    dataset = dataset.map(lambda source, target: (source, tf.concat(([hparams.sos], target), axis=0), tf.concat((target, [hparams.eos]), axis=0)))    dataset = dataset.map(lambda source, target_in, target_out: (source, target_in, target_out, tf.size(source), tf.size(target_in)))    # 继续批处理并返回迭代器

NMT模型核心代码

def __init__(self, hparams, iterator, mode):        source, target_in, target_out, source_lengths, target_lengths = iterator.get_next()        # 查找嵌入        embedding_encoder = tf.get_variable("embedding_encoder", [hparams.src_vsize, hparams.src_emsize])        encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)        embedding_decoder = tf.get_variable("embedding_decoder", [hparams.tgt_vsize, hparams.tgt_emsize])        decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, target_in)        # 构建并运行编码器LSTM        encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, sequence_length=source_lengths, dtype=tf.float32)        # 构建并运行解码器LSTM,使用TrainingHelper和输出投影层        decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)        projection_layer = layers_core.Dense(hparams.tgt_vsize, use_bias=False)        helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, sequence_length=target_lengths)        decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=projection_layer)        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)        logits = outputs.rnn_output        if mode is 'TRAIN' or mode is 'EVAL':  # 然后计算损失            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_out, logits=logits)            target_weights = tf.sequence_mask(target_lengths, maxlen=tf.shape(target_out)[1], dtype=logits.dtype)            self.loss = tf.reduce_sum((crossent * target_weights) / hparams.batch_size)        if mode is 'TRAIN':  # 然后计算/剪裁梯度,然后优化模型            params = tf.trainable_variables()            gradients = tf.gradients(self.loss, params)            clipped_gradients, _ = tf.clip_by_global_norm(gradients, hparams.max_gradient_norm)            optimizer = tf.train.AdamOptimizer(hparams.l_rate)            self.update_step = optimizer.apply_gradients(zip(clipped_gradients, params))        if mode is 'EVAL':  # 然后允许访问输入/输出张量以打印输出            self.src = source            self.tgt_in = target_in            self.tgt_out = target_out            self.logits = logits

回答:

用于预测具有重复结构的类语言语法的NMT模型的核心问题是,它倾向于简单地预测之前的预测。由于在每个步骤中通过TrainingHelper提供正确的之前预测来加速训练,这人为地产生了一个模型无法摆脱的局部最小值。

我找到的最佳解决方案是调整损失函数的权重,使输出序列中非重复的关键点得到更高的权重。这将激励模型正确预测这些点,而不仅仅是重复之前的预测。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注