回归在DL4J中的应用 – 预测下一个时间步

我已经训练了一个多层网络,但我在如何对额外的时间步进行预测上遇到了困难。

我尝试通过创建以下方法来遵循字符迭代的示例 –

public float[] sampleFromNetwork(INDArray testingData, int numTimeSteps, DataSetIterator iter){    int inputCount = this.getNumOfInputs();    int outputCount = this.getOutputCount();    float[] samples = new float[numTimeSteps];    //从网络中采样(并将样本反馈到输入中),一次一个值(对于所有样本)
    //此处并行进行采样
    this.network.rnnClearPreviousState();    INDArray output = this.network.rnnTimeStep(testingData);    output = output.tensorAlongDimension(output.size(2)-1,1,0); //获取最后一个时间步的输出    for( int i=0; i<numTimeSteps; ++i ){        //通过从前一个输出中采样来设置下一个输入(单个时间步)
        INDArray nextInput = Nd4j.zeros(1,inputCount);        //输出是一个概率分布。对于我们想要生成的每个示例,从中采样,并将其添加到新输入中
        double[] outputProbDistribution = new double[outputCount];        for( int j=0; j<outputProbDistribution.length; j++ ) {            outputProbDistribution[j] = output.getDouble(j);        }        int nextValue = sampleFromDistribution(outputProbDistribution, new Random());        nextInput.putScalar(new int[]{0,nextValue}, 1.0f);      //准备下一个时间步的输入
        samples[i] = (nextValue);   //将采样的字符添加到StringBuilder中(人类可读的输出)
        output = this.network.rnnTimeStep(nextInput);   //执行一次前向传递的时间步    }    return samples;}

但是sampleFromDistribution()方法不适用,因为我没有使用离散类别。

有什么想法吗?


回答:

我通过调整我的网络使用IDENTITY激活函数并直接使用结果值解决了这个问题。仍然有很多调整要做,但它已经可以运行了。

public float[] sampleFromNetwork(INDArray priori, int numTimeSteps){    int inputCount = this.getNumOfInputs();    float[] samples = new float[numTimeSteps];    if(priori.size(1) != inputCount) {        String format = String.format("the priori should have the same number of inputs [%s] as the trained network [%s]", priori.size(1), inputCount);        throw new RuntimeException(format);    }    if(priori.size(2) < inputCount) {        String format = String.format("the priori should have enough timesteps [%s] to prime the new inputs [%s]", priori.size(2), inputCount);        throw new RuntimeException(format);    }    this.network.rnnClearPreviousState();    INDArray output = this.network.rnnTimeStep(priori);    output = output.ravel();    // 存储输出以用于输入
    LinkedList<Float> prevOutput = new LinkedList<>();    for (int i = 0; i < output.length(); i++) {        prevOutput.add(output.getFloat(0, i));    }    for( int i=0; i<numTimeSteps; ++i ){        samples[i] = (prevOutput.peekLast());        //通过从前一个输出中采样来设置下一个输入(单个时间步)
        INDArray nextInput = Nd4j.zeros(1,inputCount);        float[] newInputs = new float[inputCount];        newInputs[inputCount-1] = prevOutput.peekLast();        for( int j=0; j<newInputs.length-1; j++ ) {            newInputs[j] = prevOutput.get(prevOutput.size()-inputCount-j);        }        nextInput.assign(Nd4j.create(newInputs)); //准备下一个时间步的输入
        output = this.network.rnnTimeStep(nextInput); //执行一次前向传递的时间步
        // 将输出添加到前一个输出的队列末尾
        prevOutput.addLast(output.ravel().getFloat(0, output.length()-1));    }    return samples;}

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注