使用kNN算法加载数据集时遇到问题 – Java

我已经应用了kNN算法来对 handwritten digits 进行分类。这些数字最初是以8*8的向量格式存在,并被拉伸成1*64的向量,每组数据都有一个从0到9的类别代码。

据我所知,我的代码理论上应该是可行的,但这是我第一次尝试使用这个算法。我的问题出现在尝试通过我的算法输入数据集时,我在代码中高亮的行上遇到了错误。训练数据集可以在这里找到,验证集可以在这里找到。我还保留了之前可用的主函数,如果有帮助的话。

ImageMatrix.java

import java.util.*;public class ImageMatrix {    private int[] data;    private int classCode;public ImageMatrix(int[] data, int classCode) {    assert data.length == 64; //maximum array length of 64    this.data = data;    this.classCode = classCode;}    public String toString() {        return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; //outputs readable    }    public int[] getData() {        return data;    }    public int getClassCode() {        return classCode;    }}

ImageMatrixDB.java

import java.util.*;import java.io.*;public class ImageMatrixDB implements Iterable<ImageMatrix> {    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();    public ImageMatrixDB load(String f) throws IOException {        try (            FileReader fr = new FileReader(f);            BufferedReader br = new BufferedReader(fr)) {            String line = null;            while((line = br.readLine()) != null) {                int lastComma = line.lastIndexOf(',');                int classCode = Integer.parseInt(line.substring(1 + lastComma));                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))                                   .mapToInt(Integer::parseInt)                                   .toArray();                ImageMatrix matrix = new ImageMatrix(data, classCode);                list.add(matrix);            }        }        return this;    }    public void printResults(){ //output results         for(ImageMatrix matrix: list){            System.out.println(matrix);        }    }    public Iterator<ImageMatrix> iterator() {        return this.list.iterator();    }    /// kNN implementation ///    public static int distance(int[] a, int[] b) {        int sum = 0;        for(int i = 0; i < a.length; i++) {            sum += (a[i] - b[i]) * (a[i] - b[i]);        }        return (int)Math.sqrt(sum); //Euclidean sqrt of the sum     }    public static int classify(List<ImageMatrix> trainingSet, int[] curData) {        int label = 0, bestDistance = Integer.MAX_VALUE;        for(ImageMatrix matrix: trainingSet) {            int dist = distance(matrix.getData(), curData);            if(dist < bestDistance) {                bestDistance = dist;                curData = matrix.getData();            }        }        return label;    }    public static void main(String[] argv) throws IOException {        ImageMatrixDB i = new ImageMatrixDB();        List<ImageMatrix> trainingSet = i.load("cw2DataSet1.csv"); // << ERROR HERE        List<ImageMatrix> validationSet = i.load("cw2DataSet2.csv"); //<< ERROR HERE        int numCorrect = 0;        for(ImageMatrix matrix:validationSet) {            if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;        }        System.out.println("Accuracy: " + (double)numCorrect / validationSet.size() * 100 + "%");    }    //////////////////////////////////////////    // Previous working dataset Load // /*   public static void main(String[] args){        ImageMatrixDB i = new ImageMatrixDB();        try{            i.load("cw2DataSet1.csv");             i.printResults();        }        catch(Exception ex){            ex.printStackTrace();        }    } */}

EDIT///

当前错误信息显示:

Exception in thread "main" java.lang.Error: Unresolved compilation problems:     Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>    Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>    at ImageMatrixDB.main(ImageMatrixDB.java:64)

但我在测试时也遇到了其他错误。


回答:

根据您的类设计,您应该这样使用它:

ImageMatrixDB trainingSet = new ImageMatrixDB();ImageMatrixDB validationSet = new ImageMatrixDB();trainingSet.load("cw2DataSet1.csv");validationSet.load("cw2DataSet2.csv");

请注意,这里使用了两个ImageMatrixDB实例,而不是一个,这样可以确保训练数据和验证数据被加载到不同的列表中。

顺便提一下,在计算kNN的距离时,您可以使用平方距离来提高效率(sqrt是一个昂贵的操作)。所以return (int)Math.sqrt(sum);不需要进行平方根运算。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注