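I have a dataset that I use to train a KNN model. Later I want to update the model with new training data. I found that the updated model uses only the new training data and ignores the data it was trained on before.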
Vectorizer vec = new DummyVectorizer<Integer>(1, 2).labeled(0);
DatasetTrainer<KNNClassificationModel, Double> trainer = new KNNClassificationTrainer();
KNNClassificationModel model;
KNNClassificationModel modelUpdated;
Map<Integer, Vector> trainingData = new HashMap<Integer, Vector>();
Map<Integer, Vector> trainingDataNew = new HashMap<Integer, Vector>();
Double[][] data1 = new Double[][] {
    {0.136, 0.644, 0.154},
    {0.302, 0.634, 0.779},
    {0.806, 0.254, 0.211},
    {0.241, 0.951, 0.744},
    {0.542, 0.893, 0.612},
    {0.334, 0.277, 0.486},
    {0.616, 0.259, 0.121},
    {0.738, 0.585, 0.017},
    {0.124, 0.567, 0.358},
    {0.934, 0.346, 0.863}};
Double[][] data2 = new Double[][] {
    {0.300, 0.236, 0.193}};
Double[] observationData = new Double[] {0.8, 0.7};
// Fill the dataset (in the cache)
for (int i = 0; i < data1.length; i++)
    trainingData.put(i, new DenseVector(data1[i]));
// First training / prediction
model = trainer.fit(trainingData, 1, vec);
System.out.println("First prediction : " + model.predict(new DenseVector(observationData)));
// New training data
for (int i = 0; i < data2.length; i++)
    trainingDataNew.put(data1.length + i, new DenseVector(data2[i]));
// Second training / prediction
modelUpdated = trainer.update(model, trainingDataNew, 1, vec);
System.out.println("Second prediction: " + modelUpdated.predict(new DenseVector(observationData)));
The output I get is:
First prediction : 0.124
Second prediction: 0.3
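It looks like the second prediction used only data2 (a single row whose label, in column 0, is 0.300), which is why the result is 0.3.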
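How does model update work? If I have to add data2 to data1 and then train again on data1, what is the difference compared with training a fresh model on all the merged data?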
Answer:
How does model update work?
For KNN: add data2 to data1, then call trainer.update(...) on the merged data.
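The effect can be illustrated without Ignite. Below is a minimal, hypothetical sketch in plain Java (none of these names are Ignite API). It assumes, as the question's output suggests, that a KNN "update" builds the new model only from the rows passed to it. It applies a 1-nearest-neighbour rule to the question's data, with the label in column 0 and the features in columns 1 and 2, like DummyVectorizer<Integer>(1, 2).labeled(0).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical, Ignite-free 1-NN sketch. Rows are {label, feature1, feature2}. */
final class KnnUpdateSketch {

    /** A KNN "model" here is simply the list of stored training rows. */
    static List<double[]> fit(Map<Integer, double[]> data) {
        return new ArrayList<>(data.values());
    }

    /**
     * Assumed update semantics: the new model is built only from the data passed in;
     * rows stored in the previous model are NOT carried over.
     */
    static List<double[]> update(List<double[]> previousModel, Map<Integer, double[]> data) {
        return fit(data);
    }

    /** Returns the label (index 0) of the row whose features (indices 1, 2) are closest to (x, y). */
    static double predict(List<double[]> model, double x, double y) {
        double bestDist = Double.POSITIVE_INFINITY;
        double bestLabel = Double.NaN;
        for (double[] row : model) {
            double dx = row[1] - x;
            double dy = row[2] - y;
            double dist = dx * dx + dy * dy; // squared Euclidean distance is enough for ranking
            if (dist < bestDist) {
                bestDist = dist;
                bestLabel = row[0];
            }
        }
        return bestLabel;
    }

    static Map<Integer, double[]> data1() {
        double[][] rows = {
            {0.136, 0.644, 0.154}, {0.302, 0.634, 0.779}, {0.806, 0.254, 0.211},
            {0.241, 0.951, 0.744}, {0.542, 0.893, 0.612}, {0.334, 0.277, 0.486},
            {0.616, 0.259, 0.121}, {0.738, 0.585, 0.017}, {0.124, 0.567, 0.358},
            {0.934, 0.346, 0.863}};
        Map<Integer, double[]> m = new HashMap<>();
        for (int i = 0; i < rows.length; i++) m.put(i, rows[i]);
        return m;
    }

    static Map<Integer, double[]> data2() {
        Map<Integer, double[]> m = new HashMap<>();
        m.put(10, new double[] {0.300, 0.236, 0.193}); // key = data1.length + i, as in the question
        return m;
    }

    /** Merges old and new rows into one map before calling update. */
    static Map<Integer, double[]> merge(Map<Integer, double[]> oldData, Map<Integer, double[]> newData) {
        Map<Integer, double[]> merged = new HashMap<>(oldData);
        merged.putAll(newData); // equal keys would overwrite old rows
        return merged;
    }

    public static void main(String[] args) {
        List<double[]> model = fit(data1());
        List<double[]> onlyNew = update(model, data2());
        List<double[]> onMerged = update(model, merge(data1(), data2()));
        System.out.println("data1 only:        " + predict(model, 0.8, 0.7));
        System.out.println("update with data2: " + predict(onlyNew, 0.8, 0.7));
        System.out.println("update with both:  " + predict(onMerged, 0.8, 0.7));
    }
}
```

Running main prints 0.542, 0.3 and 0.542. Updating with data2 alone leaves a single point, so every query returns its label; updating with the merged map keeps all eleven rows. Because this sketch uses k = 1, its first prediction need not match the 0.124 Ignite returned in the question. Also note that putAll overwrites entries with equal keys, which is why the question's data1.length + i keys matter when merging.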
Set up your trainer as shown in that test:
KNNClassificationTrainer trainer = new KNNClassificationTrainer()
    .withK(3)
    .withDistanceMeasure(new EuclideanDistance())
    .withWeighted(false);
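To make these settings concrete, here is a hedged plain-Java sketch of what withK(3), withDistanceMeasure(new EuclideanDistance()) and withWeighted(false) describe in general k-NN terms: take the k nearest training points by Euclidean distance and give each one an equal vote. It is not Ignite's implementation; the class and method names are hypothetical, and the tie-breaking rule (the nearer label wins) is an assumption of this sketch.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch of unweighted k-NN classification with Euclidean distance. */
final class KnnVoteSketch {

    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /**
     * Sorts training points by distance to the query, takes the k nearest,
     * and returns the label with the most votes (one vote per neighbour).
     */
    static double predict(double[][] features, double[] labels, double[] query, int k) {
        Integer[] order = new Integer[features.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        Arrays.sort(order, Comparator.comparingDouble(i -> euclidean(features[i], query)));

        Map<Double, Integer> votes = new HashMap<>();
        double best = Double.NaN;
        int bestCount = 0;
        for (int j = 0; j < Math.min(k, order.length); j++) {
            double label = labels[order[j]];
            int count = votes.merge(label, 1, Integer::sum);
            // On a tie, the label that reached the top count first (the nearer one) is kept.
            if (count > bestCount) {
                bestCount = count;
                best = label;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] features = {{0, 0}, {1, 0}, {1.1, 0}};
        double[] labels = {1.0, 2.0, 2.0};
        double[] query = {0.4, 0};
        System.out.println("k=1: " + predict(features, labels, query, 1)); // nearest point alone
        System.out.println("k=3: " + predict(features, labels, query, 3)); // two of three votes
    }
}
```

With the query at 0.4 the single nearest point has label 1.0, but with k = 3 the two slightly farther points outvote it, so the prediction becomes 2.0. A weighted variant would scale each vote by some function of distance instead of counting every neighbour equally.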
Then set up your vectorizer (note how the label coordinate is specified):
model = trainer.fit(
    trainingData,
    parts,
    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);
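The label coordinate tells the vectorizer which entry of each raw row is the label. The question's data1 stores the label first (hence labeled(0)), while LabelCoordinate.LAST in the answer expects it in the last position, so the data layout and the vectorizer must agree. The hypothetical helper below (not an Ignite class) sketches that split in plain Java; for simplicity it treats every non-label entry as a feature, whereas DummyVectorizer(1, 2) selects feature coordinates explicitly.

```java
import java.util.Arrays;

/** Hypothetical sketch: split a raw row into (features, label) given the label's index. */
final class LabelCoordinateSketch {

    static final class Row {
        final double[] features;
        final double label;

        Row(double[] features, double label) {
            this.features = features;
            this.label = label;
        }
    }

    /** labelIdx = raw.length - 1 mimics a "last" label coordinate; labelIdx = 0 mimics labeled(0). */
    static Row split(double[] raw, int labelIdx) {
        double[] features = new double[raw.length - 1];
        int j = 0;
        for (int i = 0; i < raw.length; i++) {
            if (i != labelIdx) features[j++] = raw[i]; // every other entry becomes a feature
        }
        return new Row(features, raw[labelIdx]);
    }

    public static void main(String[] args) {
        double[] raw = {0.136, 0.644, 0.154};          // first row of data1
        Row first = split(raw, 0);                     // label-first layout, as in the question
        Row last = split(raw, raw.length - 1);         // label-last layout, as in the answer
        System.out.println(Arrays.toString(first.features) + " -> " + first.label);
        System.out.println(Arrays.toString(last.features) + " -> " + last.label);
    }
}
```

For the row {0.136, 0.644, 0.154}, label index 0 gives features [0.644, 0.154] and label 0.136, while the last index gives features [0.136, 0.644] and label 0.154: the same row, read with the wrong label coordinate, trains on different data.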
Then call trainer.update(...) as needed:
KNNClassificationModel updatedOnData = trainer.update(
    originalMdlOnEmptyDataset,
    newData,
    parts,
    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);
KNN classification documentation: https://ignite.apache.org/docs/latest/machine-learning/binary-classification/knn-classification