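I ran this code on the well-known Iris flower dataset with 10-fold cross-validation, classifying the instances with 5 different classifiers.
Each classifier should therefore be trained on 135 instances and tested on 15, repeated 10 times, so I expect the number of misclassified instances plus the number of correctly classified instances to equal 15 in every fold.
Here are the code and the output.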
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.rules.DecisionTable;
import weka.classifiers.rules.OneR;
import weka.classifiers.rules.PART;
import weka.classifiers.trees.DecisionStump;
import weka.classifiers.trees.J48;
import weka.core.FastVector;
import weka.core.Instances;

public class WekaTest {

    public static void main(String[] args) throws Exception {
        // Comments are denoted by "//" at the beginning of the line.
        BufferedReader datafile = readDataFile("C:\\Program Files\\Weka-3-8\\data\\iris.arff");
        //BufferedReader datafile = readDataFile("C:\\hwork\\titanic\\train.arff");
        Instances data = new Instances(datafile);
        data.setClassIndex(data.numAttributes() - 1);

        // Choose a type of validation split
        Instances[][] split = crossValidationSplit(data, 10);

        // Separate split into training and testing arrays
        Instances[] trainingSplits = split[0];
        Instances[] testingSplits = split[1];

        // Choose a set of classifiers
        Classifier[] models = {
            new J48(),
            new PART(),
            new DecisionTable(),
            new OneR(),
            new DecisionStump()
        };

        // Run for each classifier model
        double[][][] predictions = new double[100][100][2];
        for(int j = 0; j < models.length; j++) {
            for(int i = 0; i < trainingSplits.length; i++) {
                Evaluation validation = new Evaluation(trainingSplits[i]);
                models[j].buildClassifier(trainingSplits[i]);
                validation.evaluateModel(models[j], testingSplits[i]);
                predictions[j][i][0] = validation.correct();
                predictions[j][i][1] = validation.incorrect();
                System.out.println("Classifier: "+models[j].getClass()+" : Correct: "+predictions[j][i][0]+", Wrong: "+predictions[i][j][1]);
            }//training foreach fold.
            System.out.println("===================================================================");
        }//training foreach classifier.
    }//main().

    public static BufferedReader readDataFile(String filename) {
        BufferedReader inputReader = null;
        try {
            inputReader = new BufferedReader(new FileReader(filename));
        } catch (FileNotFoundException ex) {
            System.err.println("File not found: " + filename);
        }
        return inputReader;
    }//readDataFile().

    public static Evaluation simpleClassify(Classifier model, Instances trainingSet, Instances testingSet) throws Exception {
        Evaluation validation = new Evaluation(trainingSet);
        model.buildClassifier(trainingSet);
        validation.evaluateModel(model, testingSet);
        return validation;
    }//simpleClassify().

    public static double calculateAccuracy(FastVector predictions) {
        double correct = 0;
        for (int i = 0; i < predictions.size(); i++) {
            NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
            if (np.predicted() == np.actual()) {
                correct++;
            }
        }
        return 100 * correct / predictions.size();
    }//calculateAccuracy().

    public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
        Instances[][] split = new Instances[2][numberOfFolds];
        for (int i = 0; i < numberOfFolds; i++) {
            split[0][i] = data.trainCV(numberOfFolds, i);
            split[1][i] = data.testCV(numberOfFolds, i);
        }
        return split;
    }//crossValidationSplit().

}//class.
====================
Output:
Classifier: class weka.classifiers.trees.J48 : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 14.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 14.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 13.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 13.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 12.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.J48 : Correct: 15.0, Wrong: 0.0
===================================================================
Classifier: class weka.classifiers.rules.PART : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 14.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 14.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 13.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 13.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 9.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.PART : Correct: 13.0, Wrong: 0.0
===================================================================
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 15.0, Wrong: 1.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 15.0, Wrong: 1.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 13.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 13.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 13.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 12.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.DecisionTable : Correct: 14.0, Wrong: 0.0
===================================================================
Classifier: class weka.classifiers.rules.OneR : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.OneR : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.OneR : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.OneR : Correct: 14.0, Wrong: 1.0
Classifier: class weka.classifiers.rules.OneR : Correct: 13.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.OneR : Correct: 12.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.OneR : Correct: 15.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.OneR : Correct: 14.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.OneR : Correct: 14.0, Wrong: 0.0
Classifier: class weka.classifiers.rules.OneR : Correct: 14.0, Wrong: 0.0
===================================================================
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 15.0, Wrong: 1.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 15.0, Wrong: 1.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 15.0, Wrong: 2.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 5.0, Wrong: 2.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 0.0, Wrong: 15.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 0.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 5.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 0.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 0.0, Wrong: 0.0
Classifier: class weka.classifiers.trees.DecisionStump : Correct: 0.0, Wrong: 0.0
===================================================================
Answer:
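In the print statement
System.out.println("Classifier: "+models[j].getClass()+" : Correct: "+predictions[j][i][0]+", Wrong: "+predictions[i][j][1]);
the part
Wrong: "+predictions[i][j][1]);
should be
Wrong: "+predictions[j][i][1]);
You swapped j and i. The results are stored as predictions[j][i] (classifier first, then fold), so predictions[i][j][1] reads the wrong count of a different classifier/fold pair, often a cell that has not been filled yet and still holds 0.0. That is why the two printed numbers need not add up to 15.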
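The effect of the swapped indices can be reproduced without Weka. The following is only a minimal sketch: the class IndexSwapDemo, the method fakeIncorrect and the constants MODELS, FOLDS and FOLD_SIZE are hypothetical stand-ins for the real classifiers and for validation.incorrect(), and the error counts are made up for illustration.

```java
public class IndexSwapDemo {
    public static final int MODELS = 2;      // hypothetical number of classifiers
    public static final int FOLDS = 3;       // hypothetical number of folds
    public static final int FOLD_SIZE = 15;  // test instances per fold, as in the question

    // Hypothetical stand-in for validation.incorrect(): a made-up error count.
    public static double fakeIncorrect(int model, int fold) {
        return 2 * model + fold + 1;
    }

    // Fills predictions[j][i] as in the question and returns, for each
    // (model, fold), the sum Correct + Wrong as it would be printed.
    // swapped == true reads Wrong from predictions[i][j] (the bug),
    // swapped == false reads it from predictions[j][i] (the fix).
    public static double[][] run(boolean swapped) {
        double[][][] predictions = new double[100][100][2];
        double[][] totals = new double[MODELS][FOLDS];
        for (int j = 0; j < MODELS; j++) {
            for (int i = 0; i < FOLDS; i++) {
                double wrong = fakeIncorrect(j, i);
                predictions[j][i][0] = FOLD_SIZE - wrong;
                predictions[j][i][1] = wrong;
                double printedWrong = swapped ? predictions[i][j][1] : predictions[j][i][1];
                totals[j][i] = predictions[j][i][0] + printedWrong;
            }
        }
        return totals;
    }

    public static void main(String[] args) {
        for (boolean swapped : new boolean[] {true, false}) {
            System.out.println(swapped ? "Swapped indices:" : "Fixed indices:");
            double[][] totals = run(swapped);
            for (int j = 0; j < MODELS; j++) {
                for (int i = 0; i < FOLDS; i++) {
                    System.out.println("model " + j + ", fold " + i + ": Correct + Wrong = " + totals[j][i]);
                }
            }
        }
    }
}
```

With the fixed indices every sum equals FOLD_SIZE. With the swapped indices the sum is only right when i == j; elsewhere it picks up either a stale value from another pair or the default 0.0 of a cell that was never written.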