WEKA生成的模型似乎无法根据属性索引预测类别和分布

概述

我正在使用WEKA API 3.7.10(开发者版本)来使用我预先制作的.model文件。

我制作了25个模型:五种算法的五个结果变量。

  • J48决策树
  • 交替决策树
  • 随机森林
  • LogitBoost
  • 随机子空间

我在使用J48、随机子空间和随机森林时遇到了问题。

必要文件

以下是我创建数据后的ARFF表示:

@relation WekaData@attribute ageDiagNum numeric@attribute raceGroup {Black,Other,Unknown,White}@attribute stage3 {0,I,IIA,IIB,IIIA,IIIB,IIIC,IIINOS,IV,'UNK Stage'}@attribute m3 {M0,M1,MX}@attribute reasonNoCancerSurg {'Not performed, patient died prior to recommended surgery','Not recommended','Not recommended, contraindicated due to other conditions','Recommended but not performed, patient refused','Recommended but not performed, unknown reason','Recommended, unknown if performed','Surgery performed','Unknown; death certificate or autopsy only case'}@attribute ext2 {00,05,10,11,13,14,15,16,17,18,20,21,23,24,25,26,27,28,30,31,33,34,35,36,37,38,40,50,60,70,80,85,99}@attribute time2 {}@attribute time4 {}@attribute time6 {}@attribute time8 {}@attribute time10 {}@data65,White,IIA,MX,'Not recommended, contraindicated due to other conditions',14,?,?,?,?,?

我需要从各自的模型中获取二进制属性time2time10


以下是我用来从所有模型文件中获取预测的代码片段:

private static Map<String, Object> predict(Instances instances,        Classifier classifier, int attributeIndex) {    Map<String, Object> map = new LinkedHashMap<String, Object>();    int instanceIndex = 0; // do not change, equal to row 1    double[] percentage = { 0 };    double outcomeValue = 0;    AbstractOutput abstractOutput = null;    if(classifier.getClass() == RandomForest.class || classifier.getClass() == RandomSubSpace.class) {        // has problems predicting time2 to time10        instances.setClassIndex(5);     } else {        // works as intended in LogitBoost and ADTree        instances.setClassIndex(attributeIndex);        }    try {        outcomeValue = classifier.classifyInstance(instances.instance(0));        percentage = classifier.distributionForInstance(instances                .instance(instanceIndex));    } catch (Exception e) {        e.printStackTrace();    }    map.put("Class", outcomeValue);    if (percentage.length > 0) {        double percentageRaw = 0;        if (outcomeValue == new Double(1)) {            percentageRaw = percentage[1];        } else {            percentageRaw = 1 - percentage[0];        }        map.put("Percentage", percentageRaw);    } else {        // because J48 returns an error if percentage[i] because it's empty        map.put("Percentage", new Double(0));    }    return map;}

以下是我用来预测time2结果的模型,因此我们将使用索引6:

instances.setClassIndex(5); 

问题

  • 如我之前所说,与其他三个相比,LogitBoostADTree在这个直接方法中没有问题,因为我遵循了在Java代码中使用WEKA教程。

  • [已解决]根据我的调整,RandomForestRandomSubSpace在被要求预测time2time10时会返回ArrayOutOfBoundsException

    java.lang.ArrayIndexOutOfBoundsException: 0    at weka.classifiers.meta.Bagging.distributionForInstance(Bagging.java:586)    at weka.classifiers.trees.RandomForest.distributionForInstance(RandomForest.java:602)    at weka.classifiers.AbstractClassifier.classifyInstance(AbstractClassifier.java:70)

    堆栈跟踪指向根错误的行:

    outcomeValue = classifier.classifyInstance(instances.instance(0));

    解决方案:我在创建ARFF文件时,对于二进制变量time2time10FastVector<String>()值分配到FastVector<Attribute>()对象时犯了一些复制粘贴错误。现在我的十个RandomForestRandomSubSpace模型都正常工作了!


我希望有人能帮我解决这个问题。我真的不知道这段代码有什么问题,因为我已经查看了Javadoc和在线示例,但持续的预测问题依然存在。

(我目前正在检查WEKA GUI的主程序,但请在这里帮助我 🙂 )


回答:

我目前只看了随机森林的问题。这是由于Bagging类从数据实例本身而不是从模型中提取不同类别的数量。你在文本中说time2到time10是二进制的,但你在ARFF文件中没有说明,所以Bagging类不知道有多少类别。

所以你只需要在你的ARFF文件中指定time2是二进制的,例如:@attribute time2 {0,1}

这样你就不会再得到任何异常了。

我还没有看J48的问题,因为它可能是ARFF定义相同的问题。

测试代码:

  public static void main(String [] argv) {      try {        Classifier cls = (Classifier) weka.core.SerializationHelper.read("bosom.100k.2.j48.MODEL");        J48 c = (J48)cls;        DataSource source = new DataSource("data.arff");        Instances data = source.getDataSet();        data.setClassIndex(6);                try {            double outcomeValue = c.classifyInstance(data.instance(0));            System.out.println("outcome "+outcomeValue);            double[] p = c.distributionForInstance(data.instance(0));            System.out.println(Arrays.toString(p));        } catch (Exception e) {            e.printStackTrace();        }    } catch (Exception e) {        e.printStackTrace();    }

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注