使用序列训练神经网络(当前未收敛)

由于其递归特性,我已经能够激活一个仅有一个输入神经元的LSTM,通过一次输入一个项目来处理序列。

然而,当我尝试使用相同技术训练网络时,它从未收敛。训练过程会一直持续下去。

我所做的是,将一个自然语言字符串转换为二进制,然后一次输入一个数字。我转换成二进制的原因是网络只能接受0到1之间的值。

我知道训练是有效的,因为当我用与输入神经元数量相同的值数组进行训练时,在这种情况下是1个值,如:[0],它能够收敛并正常训练。

我想我可以单独传递每个数字,但这样每个数字都会有一个单独的理想输出。当这个数字在另一个训练集中再次出现并有另一个理想输出时,它不会收敛,因为例如0怎么可能既属于类0又属于类1?请告诉我我的这个假设是否有误。

如何训练这个LSTM以序列方式,使得激活时相似的序列能被分类为相似?

这是我的完整训练文件:https://github.com/theirf/synaptic/blob/master/src/trainer.js

这是用于在工作线程上训练网络的代码:

workerTrain: function(set, callback, options) {    var that = this;    var error = 1;    var iterations = bucketSize = 0;    var input, output, target, currentRate;    var length = set.length;    var start = Date.now();    if (options) {        if (options.shuffle) {            function shuffle(o) { //v1.0                for (var j, x, i = o.length; i; j = Math.floor(Math.random() *          i), x = o[--i], o[i] = o[j], o[j] = x);                return o;            };          }          if(options.iterations) this.iterations = options.iterations;          if(options.error) this.error = options.error;          if(options.rate) this.rate = options.rate;          if(options.cost) this.cost = options.cost;          if(options.schedule) this.schedule = options.schedule;          if (options.customLog){            // for backward compatibility with code that used customLog            console.log('Deprecated: use schedule instead of customLog')            this.schedule = options.customLog;          }    }    // dynamic learning rate    currentRate = this.rate;    if(Array.isArray(this.rate)) {        bucketSize = Math.floor(this.iterations / this.rate.length);    }    // create a worker    var worker = this.network.worker();    // activate the network    function activateWorker(input)        {            worker.postMessage({                 action: "activate",                input: input,                memoryBuffer: that.network.optimized.memory            }, [that.network.optimized.memory.buffer]);        }        // backpropagate the network        function propagateWorker(target){            if(bucketSize > 0) {                  var currentBucket = Math.floor(iterations / bucketSize);                  currentRate = this.rate[currentBucket];            }            worker.postMessage({                 action: "propagate",                target: target,                rate: currentRate,                memoryBuffer: that.network.optimized.memory            }, [that.network.optimized.memory.buffer]);        }        // train the worker        worker.onmessage = function(e){            // give control of the memory back to the network            that.network.optimized.ownership(e.data.memoryBuffer);            if(e.data.action == "propagate"){                if(index >= length){                    index = 0;                    iterations++;                    error /= set.length;                    // log                    if(options){                        if(this.schedule && this.schedule.every && iterations % this.schedule.every == 0)                        abort_training = this.schedule.do({                            error: error,                            iterations: iterations                        });                        else if(options.log && iterations % options.log == 0){                            console.log('iterations', iterations, 'error', error);                        };                        if(options.shuffle) shuffle(set);                    }                    if(!abort_training && iterations < that.iterations && error > that.error){                        activateWorker(set[index].input);                    }                    else{                        // callback                        callback({                           error: error,                           iterations: iterations,                           time: Date.now() - start                        })                    }                    error = 0;                }                else{                    activateWorker(set[index].input);               }        }        if(e.data.action == "activate"){            error += that.cost(set[index].output, e.data.output);            propagateWorker(set[index].output);             index++;        }    }

回答:

自然语言字符串不应转换为二进制来进行归一化。应使用独热编码(one-hot encoding):

enter image description here

此外,我建议你考虑使用Neataptic而不是Synaptic。它修复了Synaptic中的许多错误,并且提供了更多的功能供你使用。它在训练过程中有一个特殊选项,称为clear。这个选项告诉网络在每次训练迭代时重置上下文,以便它知道是从头开始的。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注