我已经训练了一个多层网络,但我在如何对额外的时间步进行预测上遇到了困难。
我尝试通过创建以下方法来遵循字符迭代的示例 –
public float[] sampleFromNetwork(INDArray testingData, int numTimeSteps, DataSetIterator iter){ int inputCount = this.getNumOfInputs(); int outputCount = this.getOutputCount(); float[] samples = new float[numTimeSteps]; //从网络中采样(并将样本反馈到输入中),一次一个值(对于所有样本)
//此处并行进行采样
this.network.rnnClearPreviousState(); INDArray output = this.network.rnnTimeStep(testingData); output = output.tensorAlongDimension(output.size(2)-1,1,0); //获取最后一个时间步的输出 for( int i=0; i<numTimeSteps; ++i ){ //通过从前一个输出中采样来设置下一个输入(单个时间步)
INDArray nextInput = Nd4j.zeros(1,inputCount); //输出是一个概率分布。对于我们想要生成的每个示例,从中采样,并将其添加到新输入中
double[] outputProbDistribution = new double[outputCount]; for( int j=0; j<outputProbDistribution.length; j++ ) { outputProbDistribution[j] = output.getDouble(j); } int nextValue = sampleFromDistribution(outputProbDistribution, new Random()); nextInput.putScalar(new int[]{0,nextValue}, 1.0f); //准备下一个时间步的输入
samples[i] = (nextValue); //将采样的字符添加到StringBuilder中(人类可读的输出)
output = this.network.rnnTimeStep(nextInput); //执行一次前向传递的时间步 } return samples;}
但是sampleFromDistribution()方法不适用,因为我没有使用离散类别。
有什么想法吗?
回答:
我通过调整我的网络使用IDENTITY激活函数并直接使用结果值解决了这个问题。仍然有很多调整要做,但它已经可以运行了。
public float[] sampleFromNetwork(INDArray priori, int numTimeSteps){ int inputCount = this.getNumOfInputs(); float[] samples = new float[numTimeSteps]; if(priori.size(1) != inputCount) { String format = String.format("the priori should have the same number of inputs [%s] as the trained network [%s]", priori.size(1), inputCount); throw new RuntimeException(format); } if(priori.size(2) < inputCount) { String format = String.format("the priori should have enough timesteps [%s] to prime the new inputs [%s]", priori.size(2), inputCount); throw new RuntimeException(format); } this.network.rnnClearPreviousState(); INDArray output = this.network.rnnTimeStep(priori); output = output.ravel(); // 存储输出以用于输入
LinkedList<Float> prevOutput = new LinkedList<>(); for (int i = 0; i < output.length(); i++) { prevOutput.add(output.getFloat(0, i)); } for( int i=0; i<numTimeSteps; ++i ){ samples[i] = (prevOutput.peekLast()); //通过从前一个输出中采样来设置下一个输入(单个时间步)
INDArray nextInput = Nd4j.zeros(1,inputCount); float[] newInputs = new float[inputCount]; newInputs[inputCount-1] = prevOutput.peekLast(); for( int j=0; j<newInputs.length-1; j++ ) { newInputs[j] = prevOutput.get(prevOutput.size()-inputCount-j); } nextInput.assign(Nd4j.create(newInputs)); //准备下一个时间步的输入
output = this.network.rnnTimeStep(nextInput); //执行一次前向传递的时间步
// 将输出添加到前一个输出的队列末尾
prevOutput.addLast(output.ravel().getFloat(0, output.length()-1)); } return samples;}