如何在内存容量不足的情况下训练神经网络?

尝试使用大量数据来训练带有循环层的分类器,结果发现所有数据无法装入内存。这导致了以下错误:

Error using zerosRequested 1x2114046976 (15.8GB) array exceeds maximum array size preference. Creation of arrays greater than this limit may take a longtime and cause MATLAB to become unresponsive. See array size limit or preference panel for more information.Error in nnMex.perfsGrad (line 3)TEMP = zeros(1,ceil(hints.tempSizeBG/8)*8);Error in nnCalcLib/perfsGrad (line 294)                lib.calcMode.perfsGrad(calcNet,lib.calcData,lib.calcHints);Error in trainscg>initializeTraining (line 153)[worker.perf,worker.vperf,worker.tperf,worker.gWB,worker.gradient] = calcLib.perfsGrad(calcNet);Error in nnet.train.trainNetwork>trainNetworkInMainThread (line 28)worker = localFcns.initializeTraining(archNet,calcLib,calcNet,tr);Error in nnet.train.trainNetwork (line 16)    [archNet,tr] = trainNetworkInMainThread(archNet,rawData,calcLib,calcNet,tr,feedback,localFcns);Error in trainscg>train_network (line 147)[archNet,tr] = nnet.train.trainNetwork(archNet,rawData,calcLib,calcNet,tr,localfunctions);Error in trainscg (line 59)            [out1,out2] = train_network(varargin{2:end});Error in network/train (line 369)    [net,tr] = feval(trainFcn,'apply',net,data,calcLib,calcNet,tr);

需要注意的是,目前我的训练输入是11×52266,网络由于循环层大约有3千个权重元素。然而,我希望提供15倍于当前的数据量来进行训练。

我该如何应对?有没有什么技术可以将它试图初始化的本地变量映射到我的SSD上,而不是内存?

虽然有“reduction”选项用于训练,但似乎在这方面没有任何效果。无论如何,同样的错误仍然发生。


回答:

一般来说,如果你的数据集太大而无法装入内存,你需要分块处理。对于训练大型网络,通常使用随机梯度下降(只需要一次访问一个数据点)或小批量训练(只需要访问小批量中的数据点)。除了需要较少的内存外,这些方法通常比批量梯度下降(每次权重更新使用整个数据集)收敛得更快。磁盘访问速度慢,因此即使每次更新只需要少数数据点,你也应该尽可能多地加载数据点,然后将它们分成小批量等等。你还可以使用其他技巧来减少磁盘读取次数,比如在加载下一组数据之前进行多次更新。

另一个特定于循环神经网络(RNNs)的要点。当你使用时间反向传播(BPTT)训练RNN时,网络必须在时间上“展开”,并被视为在每个时间步都有一个循环层副本的非常深的前馈网络。这意味着在更多时间步上执行BPTT需要更多的内存(和更多的计算时间)。一个解决方案是使用截断的BPTT,其中梯度只在固定数量的时间步上反向传播。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注