我正在使用这个库 https://github.com/cazala/synaptic
我试图预测以下序列中的下一个值(值 X):
0 0 0 1 0 0 0 1 0 0 0 X
这个值应该是 1
。
这是我的代码:
const options = { peepholes: Layer.connectionType.ALL_TO_ALL, hiddenToHidden: false, outputToHidden: false, outputToGates: false, inputToOutput: true,};// 1个输入,3个隐藏层(每层4个节点),1个输出const lstm = new Architect.LSTM(1, 4, 4, 4, 1, options);const trainingArray = [ { input: [0], output: [0], }, { input: [0], output: [0], }, { input: [0], output: [1], }, { input: [1], output: [0], }, { input: [0], output: [0], }, { input: [0], output: [0], }, { input: [0], output: [1], }, { input: [1], output: [0], },];const trainingOptions = { rate: 0.1, iterations: 100000, error: 0.05, cost: null, crossValidate: null,};let results = lstm.trainer.train(trainingArray, trainingOptions);console.log(results);array = [ 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,];results = lstm.activate(array);console.log(results);
控制台的输出是:
{ error: 0.049765018466871494, iterations: 673, time: 392 }[ 0.05010961302724895 ]
我期望激活的结果更接近1而不是0(更接近)。我不知道这是库的问题还是我对LSTM了解不足。有人能指导我正确的方向吗?
回答:
我阅读了源代码并找到了解决方法。
const synaptic = require('synaptic');const Architect = synaptic.Architect;const Layer = synaptic.Layer;const lstmOptions = { peepholes: Layer.connectionType.ALL_TO_ALL, hiddenToHidden: false, outputToHidden: false, outputToGates: false, inputToOutput: true,};const lstm = new Architect.LSTM(1, 4, 4, 4, 1, lstmOptions);const trainSet = [ { input: [0], output: [0.1] }, { input: [1], output: [0.2] }, { input: [0], output: [0.3] }, { input: [1], output: [0.4] }, { input: [0], output: [0.5] },];const trainOptions = { rate: 0.2, iterations: 10000, error: 0.005, cost: null, crossValidate: null,};const trainResults = lstm.trainer.train(trainSet, trainOptions);console.log(trainResults);const testResults = [];testResults[0] = lstm.activate([0]);testResults[1] = lstm.activate([1]);testResults[2] = lstm.activate([0]);testResults[3] = lstm.activate([1]);testResults[4] = lstm.activate([0]);console.log(testResults);
结果是:
{ error: 0.004982436660844655, iterations: 2010, time: 384 }[ [ 0.18288280009908592 ], [ 0.2948083898027347 ], [ 0.35061782593064206 ], [ 0.3900799575806566 ], [ 0.49454852760556606 ] ]
这是准确的。