Apache Ignite 更新之前训练过的机器学习模型

我有一个用于训练KNN模型的数据集。后来我想用新的训练数据来更新模型。我发现更新后的模型只使用了新的训练数据,而忽略了之前训练的数据。

        Vectorizer                                     vec             = new DummyVectorizer<Integer>(1, 2).labeled(0);        DatasetTrainer<KNNClassificationModel, Double> trainer         = new KNNClassificationTrainer();        KNNClassificationModel                         model;        KNNClassificationModel                         modelUpdated;        Map<Integer, Vector>                           trainingData    = new HashMap<Integer, Vector>();        Map<Integer, Vector>                           trainingDataNew = new HashMap<Integer, Vector>();        Double[][] data1 = new Double[][] {            {0.136,0.644,0.154},            {0.302,0.634,0.779},            {0.806,0.254,0.211},            {0.241,0.951,0.744},            {0.542,0.893,0.612},            {0.334,0.277,0.486},            {0.616,0.259,0.121},            {0.738,0.585,0.017},            {0.124,0.567,0.358},            {0.934,0.346,0.863}};        Double[][] data2 = new Double[][] {            {0.300,0.236,0.193}};                    Double[] observationData = new Double[] { 0.8, 0.7 };                    // 填充数据集(在缓存中)        for (int i = 0; i < data1.length; i++)            trainingData.put(i, new DenseVector(data1[i]));        // 第一次训练/预测        model = trainer.fit(trainingData, 1, vec);        System.out.println("第一次预测 : " + model.predict(new DenseVector(observationData)));        // 新的训练数据        for (int i = 0; i < data2.length; i++)            trainingDataNew.put(data1.length + i, new DenseVector(data2[i]));        // 第二次训练/预测        modelUpdated = trainer.update(model, trainingDataNew, 1, vec);        System.out.println("第二次预测: " + modelUpdated.predict(new DenseVector(observationData)));

我得到的输出是这样的:

第一次预测 : 0.124第二次预测: 0.3

看起来第二次预测只使用了data2,这导致了0.3的预测结果。

模型更新是如何工作的?如果我必须将data2添加到data1中,然后再次在data1上进行训练,与在所有合并数据上进行全新训练相比,有什么不同?


回答:

模型更新是如何工作的?
对于KNN来说:将data2添加到data1中,然后在合并后的数据上调用modelUpdate。

请参考这个测试作为示例: https://github.com/apache/ignite/blob/635dafb7742673494efa6e8e91e236820156d38f/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java#L167

按照该测试中的说明设置你的训练器:

   KNNClassificationTrainer trainer = new KNNClassificationTrainer()            .withK(3)            .withDistanceMeasure(new EuclideanDistance())            .withWeighted(false);

然后设置你的向量化器:(注意如何创建标记坐标)

        model  = trainer.fit(                trainingData,                parts,                new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)        );

然后根据需要调用updateModel。

        KNNClassificationModel updatedOnData = trainer.update(            originalMdlOnEmptyDataset,            newData,            parts,            new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)        );

KNN分类文档: https://ignite.apache.org/docs/latest/machine-learning/binary-classification/knn-classification

KNN分类示例: https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注