我编写了一个Weka Java代码来训练4个分类器。我保存了分类器模型,并希望使用它们来预测新的未见实例(可以将其视为想要测试一条推文是正面还是负面的情况)。
我在训练数据上使用了StringToWordsVector过滤器。为了避免”Src and Dest differ in # of attributes“错误,我使用了以下代码在应用过滤器到新实例之前使用训练数据来训练过滤器,以尝试预测一个新实例是正面还是负面。但我始终无法正确实现。
Classifier cls = (Classifier) weka.core.SerializationHelper.read("models/myModel.model"); //读取一个训练好的分类器
BufferedReader datafile = readDataFile("Tweets/tone1.ARFF"); //读取训练数据
Instances data = new Instances(datafile);
data.setClassIndex(data.numAttributes() - 1);
Filter filter = new StringToWordVector(50);//保留50个单词
filter.setInputFormat(data);
Instances filteredData = Filter.useFilter(data, filter);
// 重新构建分类器
cls.buildClassifier(filteredData);
String testInstance= "Text that I want to use as an unseen instance and predict whether it's positive or negative";
System.out.println(">create test instance");
FastVector attributes = new FastVector(2);
attributes.addElement(new Attribute("text", (FastVector) null));
// 添加类属性。
FastVector classValues = new FastVector(2);
classValues.addElement("Negative");
classValues.addElement("Positive");
attributes.addElement(new Attribute("Tone", classValues));
// 创建数据集,初始容量为100,并设置类索引。
Instances tests = new Instances("test istance", attributes, 100);
tests.setClassIndex(tests.numAttributes() - 1);
Instance test = new Instance(2);
// 设置消息属性的值
Attribute messageAtt = tests.attribute("text");
test.setValue(messageAtt, messageAtt.addStringValue(testInstance));
test.setDataset(tests);
Filter filter2 = new StringToWordVector(50);
filter2.setInputFormat(tests);
Instances filteredTests = Filter.useFilter(tests, filter2);
System.out.println(">train Test filter using training data");
Standardize sfilter = new Standardize(); //匹配源和目标之间的属性数量。
sfilter.setInputFormat(filteredData); // 使用训练集初始化过滤器
filteredTests = Filter.useFilter(filteredData, sfilter);
// 创建新的测试集
ArffSaver saver = new ArffSaver(); //保存测试数据到ARFF文件
saver.setInstances(filteredTests);
File unseenFile = new File ("Tweets/unseen.ARFF");
saver.setFile(unseenFile);
saver.writeBatch();
当我尝试使用过滤后的训练数据标准化输入数据时,我得到了一个新的ARFF文件(unseen.ARFF),但其中包含了2000个实例(与训练数据数量相同),其中大多数值为负数。我不明白为什么会这样,或者如何删除这些实例。
System.out.println(">Evaluation"); //如果没有以下两行,我会得到ArrayIndexOutOfBoundException。
filteredData.setClassIndex(filteredData.numAttributes() - 1);
filteredTests.setClassIndex(filteredTests.numAttributes() - 1);
Evaluation eval = new Evaluation(filteredData);
eval.evaluateModel(cls, filteredTests);
System.out.println(eval.toSummaryString("\nResults\n======\n", false));
打印评估结果时,我希望看到例如这个实例是正面还是负面的百分比,但相反我得到了以下结果。我也希望看到1个实例而不是2000个。任何关于如何做到这一点的帮助都将非常有用。
> Results======
Correlation coefficient 0.0285
Mean absolute error 0.8765
Root mean squared error 1.2185
Relative absolute error 409.4123 %
Root relative squared error 121.8754 %
Total Number of Instances 2000
谢谢
回答:
我已经找到了一个很好的解决方案,在这里与大家分享我的代码。这段代码使用Weka Java代码训练一个分类器,然后用它来预测新的未见实例。有些部分 – 如路径 – 是硬编码的,但您可以轻松修改方法以接受参数。
/*** 此方法执行未见实例的分类。 * 它首先使用一系列分类器训练模型,然后对新的未标记实例进行分类。*/
public static void predict() throws Exception {
//首先提供您的训练和测试ARFF文件的路径,确保两个文件具有相同的结构和头部中的确切类别
//初始化分类器
Classifier classifier = null;
System.out.println("read training arff");
Instances train = new Instances(new BufferedReader(new FileReader("Train.arff")));
train.setClassIndex(0);//在我的情况下,类是第一个属性,因此为零,否则是属性数量 -1
System.out.println("read testing arff");
Instances unlabeled = new Instances(new BufferedReader(new FileReader("Test.arff")));
unlabeled.setClassIndex(0);
// 使用一系列分类器进行训练(朴素贝叶斯,SMO(即SVM),KNN和决策树。)
String[] algorithms = {"nb","smo","knn","j48"};
for(int w=0; w<algorithms.length;w++){
if(algorithms[w].equals("nb"))
classifier = new NaiveBayes();
if(algorithms[w].equals("smo"))
classifier = new SMO();
if(algorithms[w].equals("knn"))
classifier = new IBk();
if(algorithms[w].equals("j48"))
classifier = new J48();
System.out.println("==========================================================================");
System.out.println("training using " + algorithms[w] + " classifier");
Evaluation eval = new Evaluation(train);
//执行10折交叉验证
eval.crossValidateModel(classifier, train, 10, new Random(1));
String output = eval.toSummaryString();
System.out.println(output);
String classDetails = eval.toClassDetailsString();
System.out.println(classDetails);
classifier.buildClassifier(train);
}
Instances labeled = new Instances(unlabeled);
// 标记实例(使用训练好的分类器对新的未见实例进行分类)
for (int i = 0; i < unlabeled.numInstances(); i++) {
double clsLabel = classifier.classifyInstance(unlabeled.instance(i));
labeled.instance(i).setClassValue(clsLabel);
System.out.println(clsLabel + " -> " + unlabeled.classAttribute().value((int) clsLabel));
}
//保存模型以备将来使用
ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream("myModel.dat"));
out.writeObject(classifier);
out.close();
System.out.println("===== Saved model =====");
}