WEKA生成的模型似乎无法根据属性索引预测类别和分布

概述

我正在使用WEKA API 3.7.10(开发者版本)来使用我预先制作的.model文件。

我制作了25个模型:五种算法的五个结果变量。

  • J48决策树
  • 交替决策树
  • 随机森林
  • LogitBoost
  • 随机子空间

我在使用J48、随机子空间和随机森林时遇到了问题。

必要文件

以下是我创建数据后的ARFF表示:

@relation WekaData@attribute ageDiagNum numeric@attribute raceGroup {Black,Other,Unknown,White}@attribute stage3 {0,I,IIA,IIB,IIIA,IIIB,IIIC,IIINOS,IV,'UNK Stage'}@attribute m3 {M0,M1,MX}@attribute reasonNoCancerSurg {'Not performed, patient died prior to recommended surgery','Not recommended','Not recommended, contraindicated due to other conditions','Recommended but not performed, patient refused','Recommended but not performed, unknown reason','Recommended, unknown if performed','Surgery performed','Unknown; death certificate or autopsy only case'}@attribute ext2 {00,05,10,11,13,14,15,16,17,18,20,21,23,24,25,26,27,28,30,31,33,34,35,36,37,38,40,50,60,70,80,85,99}@attribute time2 {}@attribute time4 {}@attribute time6 {}@attribute time8 {}@attribute time10 {}@data65,White,IIA,MX,'Not recommended, contraindicated due to other conditions',14,?,?,?,?,?

我需要从各自的模型中获取二进制属性time2time10


以下是我用来从所有模型文件中获取预测的代码片段:

private static Map<String, Object> predict(Instances instances,        Classifier classifier, int attributeIndex) {    Map<String, Object> map = new LinkedHashMap<String, Object>();    int instanceIndex = 0; // do not change, equal to row 1    double[] percentage = { 0 };    double outcomeValue = 0;    AbstractOutput abstractOutput = null;    if(classifier.getClass() == RandomForest.class || classifier.getClass() == RandomSubSpace.class) {        // has problems predicting time2 to time10        instances.setClassIndex(5);     } else {        // works as intended in LogitBoost and ADTree        instances.setClassIndex(attributeIndex);        }    try {        outcomeValue = classifier.classifyInstance(instances.instance(0));        percentage = classifier.distributionForInstance(instances                .instance(instanceIndex));    } catch (Exception e) {        e.printStackTrace();    }    map.put("Class", outcomeValue);    if (percentage.length > 0) {        double percentageRaw = 0;        if (outcomeValue == new Double(1)) {            percentageRaw = percentage[1];        } else {            percentageRaw = 1 - percentage[0];        }        map.put("Percentage", percentageRaw);    } else {        // because J48 returns an error if percentage[i] because it's empty        map.put("Percentage", new Double(0));    }    return map;}

以下是我用来预测time2结果的模型,因此我们将使用索引6:

instances.setClassIndex(5); 

问题

  • 如我之前所说,与其他三个相比,LogitBoostADTree在这个直接方法中没有问题,因为我遵循了在Java代码中使用WEKA教程。

  • [已解决]根据我的调整,RandomForestRandomSubSpace在被要求预测time2time10时会返回ArrayOutOfBoundsException

    java.lang.ArrayIndexOutOfBoundsException: 0    at weka.classifiers.meta.Bagging.distributionForInstance(Bagging.java:586)    at weka.classifiers.trees.RandomForest.distributionForInstance(RandomForest.java:602)    at weka.classifiers.AbstractClassifier.classifyInstance(AbstractClassifier.java:70)

    堆栈跟踪指向根错误的行:

    outcomeValue = classifier.classifyInstance(instances.instance(0));

    解决方案:我在创建ARFF文件时,对于二进制变量time2time10FastVector<String>()值分配到FastVector<Attribute>()对象时犯了一些复制粘贴错误。现在我的十个RandomForestRandomSubSpace模型都正常工作了!


我希望有人能帮我解决这个问题。我真的不知道这段代码有什么问题,因为我已经查看了Javadoc和在线示例,但持续的预测问题依然存在。

(我目前正在检查WEKA GUI的主程序,但请在这里帮助我 🙂 )


回答:

我目前只看了随机森林的问题。这是由于Bagging类从数据实例本身而不是从模型中提取不同类别的数量。你在文本中说time2到time10是二进制的,但你在ARFF文件中没有说明,所以Bagging类不知道有多少类别。

所以你只需要在你的ARFF文件中指定time2是二进制的,例如:@attribute time2 {0,1}

这样你就不会再得到任何异常了。

我还没有看J48的问题,因为它可能是ARFF定义相同的问题。

测试代码:

  public static void main(String [] argv) {      try {        Classifier cls = (Classifier) weka.core.SerializationHelper.read("bosom.100k.2.j48.MODEL");        J48 c = (J48)cls;        DataSource source = new DataSource("data.arff");        Instances data = source.getDataSet();        data.setClassIndex(6);                try {            double outcomeValue = c.classifyInstance(data.instance(0));            System.out.println("outcome "+outcomeValue);            double[] p = c.distributionForInstance(data.instance(0));            System.out.println(Arrays.toString(p));        } catch (Exception e) {            e.printStackTrace();        }    } catch (Exception e) {        e.printStackTrace();    }

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注