使用多层网络预测简单函数

我正在尝试培养对机器学习的直觉。我查看了https://github.com/deeplearning4j/dl4j-0.4-examples上的示例,并想开发自己的示例。基本上,我只是采用了一个简单的函数:a * a + b * b + c * c – a * b * c + a + b + c,并为随机的a、b、c生成了10000个输出,并尝试用90%的输入来训练我的网络。问题是无论我做什么,我的网络都无法预测剩余的示例。

这是我的代码:

public class BasicFunctionNN {    private static Logger log = LoggerFactory.getLogger(MlPredict.class);    public static DataSetIterator generateFunctionDataSet() {        Collection<DataSet> list = new ArrayList<>();        for (int i = 0; i < 100000; i++) {            double a = Math.random();            double b = Math.random();            double c = Math.random();            double output = a * a + b * b + c * c - a * b * c + a + b + c;            INDArray in = Nd4j.create(new double[]{a, b, c});            INDArray out = Nd4j.create(new double[]{output});            list.add(new DataSet(in, out));        }        return new ListDataSetIterator(list, list.size());    }    public static void main(String[] args) throws Exception {        DataSetIterator iterator = generateFunctionDataSet();        Nd4j.MAX_SLICES_TO_PRINT = 10;        Nd4j.MAX_ELEMENTS_PER_SLICE = 10;        final int numInputs = 3;        int outputNum = 1;        int iterations = 100;        log.info("Build model....");        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()                .iterations(iterations).weightInit(WeightInit.XAVIER).updater(Updater.SGD).dropOut(0.5)                .learningRate(.8).regularization(true)                .l1(1e-1).l2(2e-4)                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)                .list(3)                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(8)                        .activation("identity")                        .build())                .layer(1, new DenseLayer.Builder().nIn(8).nOut(8)                        .activation("identity")                        .build())                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.RMSE_XENT)//LossFunctions.LossFunction.RMSE_XENT)                        .activation("identity")                        .weightInit(WeightInit.XAVIER)                        .nIn(8).nOut(outputNum).build())                .backprop(true).pretrain(false)                .build();        //run the model        MultiLayerNetwork model = new MultiLayerNetwork(conf);        model.init();        model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(iterations)));        //get the dataset using the record reader. The datasetiterator handles vectorization        DataSet next = iterator.next();        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(0.9);        System.out.println(testAndTrain.getTrain());        model.fit(testAndTrain.getTrain());        //evaluate the model        Evaluation eval = new Evaluation(10);        DataSet test = testAndTrain.getTest();        INDArray output = model.output(test.getFeatureMatrix());        eval.eval(test.getLabels(), output);        log.info(">>>>>>>>>>>>>>");        log.info(eval.stats());    }}

我还尝试调整了学习率,并且多次发现分数没有改善:

10:48:51.404 [main] DEBUG o.d.o.solvers.BackTrackLineSearch - Exited line search after maxIterations termination condition; score did not improve (bestScore=0.8522868127536543, scoreAtStart=0.8522868127536543). Resetting parameters

我还尝试了ReLU作为激活函数


回答:

一个明显的问题是,你试图用线性模型来模拟非线性函数。你的神经网络没有激活函数,因此它只能有效地表达形式为W1a + W2b + W3c + W4的函数。不管你创建多少隐藏单元,只要没有使用非线性激活函数,你的网络就会退化为简单的线性模型。

更新

还有许多“小怪事”,包括但不限于:

  • 你使用了非常高的学习率(0.8)
  • 你对一个不需要正则化的复杂问题使用了大量的正则化(同时使用L1和L2正则化器进行回归并不是常见的做法,尤其是在神经网络中)
  • 整流单元可能不是表达平方运算以及你所寻找的乘法的最佳选择。整流单元在分类中非常好,特别是在更深的架构中,但在浅层回归中并非如此。尝试使用类似sigmoid的激活函数(如tanh、sigmoid)
  • 我不完全确定在这个实现中“迭代”意味着什么,但通常这是用于训练的样本/小批量数量。因此,仅使用100次迭代对于梯度下降学习来说可能要小几个数量级

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注