修改kNN算法中k值 – Java

我已经将kNN算法应用于手写数字的分类。这些数字最初是以8*8的向量格式存储,并被拉伸成1*64的向量。

目前,我的代码仅使用k=1来应用kNN算法。我尝试了几次更改k值,但总是遇到错误。如果有人能指导我正确的方向,我将非常感激。训练数据集可以在这里找到,验证集可以在这里找到。

ImageMatrix.java

import java.util.*;public class ImageMatrix {    private int[] data;    private int classCode;    private int curData;public ImageMatrix(int[] data, int classCode) {    assert data.length == 64; //maximum array length of 64    this.data = data;    this.classCode = classCode;}    public String toString() {        return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; //outputs readable    }    public int[] getData() {        return data;    }    public int getClassCode() {        return classCode;    }    public int getCurData() {        return curData;    }}

ImageMatrixDB.java

import java.util.*;import java.io.*;import java.util.ArrayList;public class ImageMatrixDB implements Iterable<ImageMatrix> {    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();    public ImageMatrixDB load(String f) throws IOException {        try (            FileReader fr = new FileReader(f);            BufferedReader br = new BufferedReader(fr)) {            String line = null;            while((line = br.readLine()) != null) {                int lastComma = line.lastIndexOf(',');                int classCode = Integer.parseInt(line.substring(1 + lastComma));                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))                                   .mapToInt(Integer::parseInt)                                   .toArray();                ImageMatrix matrix = new ImageMatrix(data, classCode); // Classcode->100% when 0 -> 0% when 1 - 9..                list.add(matrix);            }        }        return this;    }    public void printResults(){ //output results         for(ImageMatrix matrix: list){            System.out.println(matrix);        }    }    public Iterator<ImageMatrix> iterator() {        return this.list.iterator();    }    /// kNN implementation ///    public static int distance(int[] a, int[] b) {        int sum = 0;        for(int i = 0; i < a.length; i++) {            sum += (a[i] - b[i]) * (a[i] - b[i]);        }        return (int)Math.sqrt(sum);    }    public static int classify(ImageMatrixDB trainingSet, int[] curData) {        int label = 0, bestDistance = Integer.MAX_VALUE;        for(ImageMatrix matrix: trainingSet) {            int dist = distance(matrix.getData(), curData);            if(dist < bestDistance) {                bestDistance = dist;                label = matrix.getClassCode();            }        }        return label;    }    public int size() {        return list.size(); //returns size of the list        }    public static void main(String[] argv) throws IOException {        ImageMatrixDB trainingSet = new ImageMatrixDB();        ImageMatrixDB validationSet = new ImageMatrixDB();        trainingSet.load("cw2DataSet1.csv");        validationSet.load("cw2DataSet2.csv");         int numCorrect = 0;        for(ImageMatrix matrix:validationSet) {            if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;        } //285 correct        System.out.println("Accuracy: " + (double)numCorrect / validationSet.size() * 100 + "%");         System.out.println();    }}

回答:

classify方法的for循环中,你试图找到与测试点最接近的训练样本。你需要将其改为寻找与测试数据最接近的K个训练点。然后,你应该对这K个点调用getClassCode方法,并找出其中最常见的类别代码(即出现频率最高的)。classify方法将返回你找到的主要类别代码。

你可以根据需要以任何方式处理平局情况(即有两个或多个最常见的类别代码分配给相同数量的训练数据)。

我对Java的经验非常有限,但通过查阅语言参考,我提出了下面的实现方案。

public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {    int label = 0, bestDistance = Integer.MAX_VALUE;    int[][] distances = new int[trainingSet.size()][2];    int i=0;    // Place distances in an array to be sorted    for(ImageMatrix matrix: trainingSet) {        distances[i][0] = distance(matrix.getData(), curData);        distances[i][1] = matrix.getClassCode();        i++;    }    Arrays.sort(distances, (int[] lhs, int[] rhs) -> lhs[0]-rhs[0]);    // Find frequencies of each class code    i = 0;    Map<Integer,Integer> majorityMap;    majorityMap = new HashMap<Integer,Integer>();    while(i < k) {        if( majorityMap.containsKey( distances[i][1] ) ) {            int currentValue = majorityMap.get(distances[i][1]);            majorityMap.put(distances[i][1], currentValue + 1);        }        else {            majorityMap.put(distances[i][1], 1);        }        ++i;    }    // Find the class code with the highest frequency    int maxVal = -1;    for (Entry<Integer, Integer> entry: majorityMap.entrySet()) {        int entryVal = entry.getValue();        if(entryVal > maxVal) {            maxVal = entryVal;            label = entry.getKey();        }    }    return label;}

你只需将K作为参数添加进去即可。不过,请注意,上述代码并未特别处理平局情况。

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注