使用多层网络预测简单函数

我正在尝试培养对机器学习的直觉。我查看了https://github.com/deeplearning4j/dl4j-0.4-examples上的示例,并想开发自己的示例。基本上,我只是采用了一个简单的函数:a * a + b * b + c * c – a * b * c + a + b + c,并为随机的a、b、c生成了10000个输出,并尝试用90%的输入来训练我的网络。问题是无论我做什么,我的网络都无法预测剩余的示例。

这是我的代码:

public class BasicFunctionNN {    private static Logger log = LoggerFactory.getLogger(MlPredict.class);    public static DataSetIterator generateFunctionDataSet() {        Collection<DataSet> list = new ArrayList<>();        for (int i = 0; i < 100000; i++) {            double a = Math.random();            double b = Math.random();            double c = Math.random();            double output = a * a + b * b + c * c - a * b * c + a + b + c;            INDArray in = Nd4j.create(new double[]{a, b, c});            INDArray out = Nd4j.create(new double[]{output});            list.add(new DataSet(in, out));        }        return new ListDataSetIterator(list, list.size());    }    public static void main(String[] args) throws Exception {        DataSetIterator iterator = generateFunctionDataSet();        Nd4j.MAX_SLICES_TO_PRINT = 10;        Nd4j.MAX_ELEMENTS_PER_SLICE = 10;        final int numInputs = 3;        int outputNum = 1;        int iterations = 100;        log.info("Build model....");        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()                .iterations(iterations).weightInit(WeightInit.XAVIER).updater(Updater.SGD).dropOut(0.5)                .learningRate(.8).regularization(true)                .l1(1e-1).l2(2e-4)                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)                .list(3)                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(8)                        .activation("identity")                        .build())                .layer(1, new DenseLayer.Builder().nIn(8).nOut(8)                        .activation("identity")                        .build())                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.RMSE_XENT)//LossFunctions.LossFunction.RMSE_XENT)                        .activation("identity")                        .weightInit(WeightInit.XAVIER)                        .nIn(8).nOut(outputNum).build())                .backprop(true).pretrain(false)                .build();        //run the model        MultiLayerNetwork model = new MultiLayerNetwork(conf);        model.init();        model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(iterations)));        //get the dataset using the record reader. The datasetiterator handles vectorization        DataSet next = iterator.next();        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(0.9);        System.out.println(testAndTrain.getTrain());        model.fit(testAndTrain.getTrain());        //evaluate the model        Evaluation eval = new Evaluation(10);        DataSet test = testAndTrain.getTest();        INDArray output = model.output(test.getFeatureMatrix());        eval.eval(test.getLabels(), output);        log.info(">>>>>>>>>>>>>>");        log.info(eval.stats());    }}

我还尝试调整了学习率,并且多次发现分数没有改善:

10:48:51.404 [main] DEBUG o.d.o.solvers.BackTrackLineSearch - Exited line search after maxIterations termination condition; score did not improve (bestScore=0.8522868127536543, scoreAtStart=0.8522868127536543). Resetting parameters

我还尝试了ReLU作为激活函数


回答:

一个明显的问题是,你试图用线性模型来模拟非线性函数。你的神经网络没有激活函数,因此它只能有效地表达形式为W1a + W2b + W3c + W4的函数。不管你创建多少隐藏单元,只要没有使用非线性激活函数,你的网络就会退化为简单的线性模型。

更新

还有许多“小怪事”,包括但不限于:

  • 你使用了非常高的学习率(0.8)
  • 你对一个不需要正则化的复杂问题使用了大量的正则化(同时使用L1和L2正则化器进行回归并不是常见的做法,尤其是在神经网络中)
  • 整流单元可能不是表达平方运算以及你所寻找的乘法的最佳选择。整流单元在分类中非常好,特别是在更深的架构中,但在浅层回归中并非如此。尝试使用类似sigmoid的激活函数(如tanh、sigmoid)
  • 我不完全确定在这个实现中“迭代”意味着什么,但通常这是用于训练的样本/小批量数量。因此,仅使用100次迭代对于梯度下降学习来说可能要小几个数量级

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注