Java LightSIDE – How do I classify data with LightSIDE?

I have set up the LightSIDE plugin and it runs fine, but I don't know why I can't save my data to an empty file. Here is the simple structure I created.


  1. Activity is the list of data that needs to be classified.
  2. I have three categories, and each category has its own types.
  3. I have defined a specific word list for each category. For example: Food({Sushi, Food, Japan},{Cap Jay, Food, Chinese},{Jog, Sport, Running},…)

This is how I save my prediction results with LightSIDE.

public void predictSectionType(String[] sections, List<String> activityList) {
    LightSideService currentLightsideHelper = new LightSideService();
    Recipe newRecipe;
    // Initialize the SIDEPlugins
    currentLightsideHelper.initSIDEPlugin();
    try {
        // Load the recipe containing the extracted features and the trained model
        ClassLoader myClassLoader = getClass().getClassLoader();
        newRecipe = ConverterControl.readFromXML(new InputStreamReader(
                myClassLoader.getResourceAsStream("static/lightsideTrainingResult/trainingData.xml")));
        // Predict the result data
        Recipe recipeToPredict = currentLightsideHelper.loadNewDocumentsFromCSV(sections); // creates the DocumentList and Recipe
        currentLightsideHelper.predictLabels(recipeToPredict, newRecipe);
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    } catch (IOException e) {
        e.printStackTrace();
    }
}

I have a LightSideService class that acts as a wrapper around the LightSIDE functionality.

public class LightSideService {
    // Feature extraction parameters
    final String featureTableName = "1Grams";
    final int featureThreshold = 2;
    final String featureAnnotation = "Code";
    final Type featureType = Type.NOMINAL;
    // Model-building parameters
    final String trainingResultName = "Bayes_1Grams";
    // Label-prediction parameters
    final String predictionColumnName = featureAnnotation + "_Prediction";
    final boolean showMaxScore = false;
    final boolean showDists = true;
    final boolean overwrite = false;
    final boolean useEvaluation = false;

    public DocumentListTableModel model = new DocumentListTableModel(null);
    public Map<String, Serializable> validationSettings = new TreeMap<String, Serializable>();
    public Map<FeaturePlugin, Boolean> featurePlugins = new HashMap<FeaturePlugin, Boolean>();
    public Map<LearningPlugin, Boolean> learningPlugins = new HashMap<LearningPlugin, Boolean>();
    public Collection<ModelMetricPlugin> modelEvaluationPlugins = new ArrayList<ModelMetricPlugin>();
    public Map<WrapperPlugin, Boolean> wrapperPlugins = new HashMap<WrapperPlugin, Boolean>();

    // Initialization ==================================================
    public void initSIDEPlugin() {
        SIDEPlugin[] featureExtractors = PluginManager.getSIDEPluginArrayByType("feature_hit_extractor");
        boolean selected = true;
        for (SIDEPlugin fe : featureExtractors) {
            featurePlugins.put((FeaturePlugin) fe, selected);
            selected = false;
        }
        SIDEPlugin[] learners = PluginManager.getSIDEPluginArrayByType("model_builder");
        for (SIDEPlugin le : learners) {
            learningPlugins.put((LearningPlugin) le, true);
        }
        SIDEPlugin[] tableEvaluations = PluginManager.getSIDEPluginArrayByType("model_evaluation");
        for (SIDEPlugin fe : tableEvaluations) {
            modelEvaluationPlugins.add((ModelMetricPlugin) fe);
        }
        SIDEPlugin[] wrappers = PluginManager.getSIDEPluginArrayByType("learning_wrapper");
        for (SIDEPlugin wr : wrappers) {
            wrapperPlugins.put((WrapperPlugin) wr, false);
        }
    }

    // Used when training a model; adjust these settings to match your model
    public void initValidationSettings(Recipe currentRecipe) {
        validationSettings.put("testRecipe", currentRecipe);
        validationSettings.put("testSet", currentRecipe.getDocumentList());
        validationSettings.put("annotation", "Age");
        validationSettings.put("type", "CV");
        validationSettings.put("foldMethod", "AUTO");
        validationSettings.put("numFolds", 10);
        validationSettings.put("source", "RANDOM");
        validationSettings.put("test", "true");
    }

    // Load CSV documents ==================================================
    public Recipe loadNewDocumentsFromCSV(String filePath) {
        DocumentList testDocs = chooseDocumentList(filePath);
        if (testDocs != null) {
            testDocs.guessTextAndAnnotationColumns();
            Recipe currentRecipe = Recipe.fetchRecipe();
            currentRecipe.setDocumentList(testDocs);
            return currentRecipe;
        }
        return null;
    }

    public Recipe loadNewDocumentsFromCSV(String[] rootCauseList) {
        DocumentList testDocs = chooseDocumentList(rootCauseList);
        if (testDocs != null) {
            testDocs.guessTextAndAnnotationColumns();
            Recipe currentRecipe = Recipe.fetchRecipe();
            currentRecipe.setDocumentList(testDocs);
            return currentRecipe;
        }
        return null;
    }

    protected DocumentList chooseDocumentList(String filePath) {
        TreeSet<String> docNames = new TreeSet<String>();
        docNames.add(filePath);
        try {
            Charset encoding = Charset.forName("UTF-8");
            return ImportController.makeDocumentList(docNames, encoding);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    protected DocumentList chooseDocumentList(String[] rootCauseList) {
        try {
            DocumentList testDocs = new DocumentList();
            testDocs.setName("TestData.csv");
            List<String> codes = new ArrayList<String>();
            List<String> roots = new ArrayList<String>();
            for (String s : rootCauseList) {
                codes.add("");
                roots.add((s != null) ? s : "");
            }
            testDocs.addAnnotation("Code", codes, false);
            testDocs.addAnnotation("Root Cause Failure Description", roots, false);
            return testDocs;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    // Save/load XML ==================================================
    public void saveRecipeToXml(Recipe currentRecipe, String filePath) {
        File f = new File(filePath);
        try {
            ConverterControl.writeToXML(f, currentRecipe);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public Recipe loadRecipeFromXml(String filePath) throws FileNotFoundException, IOException {
        return ConverterControl.loadRecipe(filePath);
    }

    // Feature extraction ==================================================
    public Recipe prepareBuildFeatureTable(Recipe currentRecipe) {
        // Add the feature plugins (note: strings must be compared with equals(), not ==)
        Collection<FeaturePlugin> plugins = new TreeSet<FeaturePlugin>();
        for (FeaturePlugin plugin : featurePlugins.keySet()) {
            String pluginString = plugin.toString();
            if ("Basic Features".equals(pluginString) || "Character N-Grams".equals(pluginString)) {
                plugins.add(plugin);
            }
        }
        // Attach the plugins to the recipe
        currentRecipe = Recipe.addPluginsToRecipe(currentRecipe, plugins);
        // Configure the plugins
        OrderedPluginMap currentOrderedPluginMap = currentRecipe.getExtractors();
        for (SIDEPlugin plugin : currentOrderedPluginMap.keySet()) {
            String pluginString = plugin.toString();
            Map<String, String> currentConfigurations = currentOrderedPluginMap.get(plugin);
            if ("Basic Features".equals(pluginString)) {
                for (String s : currentConfigurations.keySet()) {
                    if (s.equals("Unigrams") || s.equals("Bigrams") || s.equals("Trigrams")
                            || s.equals("Count Occurences") || s.equals("Normalize N-Gram Counts")
                            || s.equals("Stem N-Grams") || s.equals("Skip Stopwords in N-Grams")) {
                        currentConfigurations.put(s, "true");
                    } else {
                        currentConfigurations.put(s, "false");
                    }
                }
            } else if ("Character N-Grams".equals(pluginString)) {
                for (String s : currentConfigurations.keySet()) {
                    if (s.equals("Include Punctuation")) {
                        currentConfigurations.put(s, "true");
                    } else if (s.equals("minGram")) {
                        currentConfigurations.put(s, "3");
                    } else if (s.equals("maxGram")) {
                        currentConfigurations.put(s, "4");
                    }
                }
                currentConfigurations.put("Extract Only Within Words", "true");
            }
        }
        // Build the feature table
        currentRecipe = buildFeatureTable(currentRecipe, featureTableName, featureThreshold, featureAnnotation, featureType);
        return currentRecipe;
    }

    protected Recipe buildFeatureTable(Recipe currentRecipe, String name, int threshold, String annotation, Type type) {
        FeaturePlugin activeExtractor = null;
        try {
            Collection<FeatureHit> hits = new HashSet<FeatureHit>();
            for (SIDEPlugin plug : currentRecipe.getExtractors().keySet()) {
                activeExtractor = (FeaturePlugin) plug;
                hits.addAll(activeExtractor.extractFeatureHits(currentRecipe.getDocumentList(),
                        currentRecipe.getExtractors().get(plug)));
            }
            FeatureTable ft = new FeatureTable(currentRecipe.getDocumentList(), hits, threshold, annotation, type);
            ft.setName(name);
            currentRecipe.setFeatureTable(ft);
        } catch (Exception e) {
            System.err.println("Feature extraction failed");
            e.printStackTrace();
        }
        return currentRecipe;
    }

    // Build model ==================================================
    public Recipe prepareBuildModel(Recipe currentRecipe) {
        try {
            // Pick the learning plugin
            LearningPlugin learner = null;
            for (LearningPlugin plugin : learningPlugins.keySet()) {
                /* if ("Naive Bayes".equals(plugin.toString())) */
                if ("Logistic Regression".equals(plugin.toString())) {
                    learner = plugin;
                }
            }
            if (Boolean.TRUE.toString().equals(validationSettings.get("test"))) {
                if (validationSettings.get("type").equals("CV")) {
                    validationSettings.put("testSet", currentRecipe.getDocumentList());
                }
            }
            Map<String, String> settings = learner.generateConfigurationSettings();
            currentRecipe = Recipe.addLearnerToRecipe(currentRecipe, learner, settings);
            currentRecipe.setValidationSettings(new TreeMap<String, Serializable>(validationSettings));
            for (WrapperPlugin wrap : wrapperPlugins.keySet()) {
                if (wrapperPlugins.get(wrap)) {
                    currentRecipe.addWrapper(wrap, wrap.generateConfigurationSettings());
                }
            }
            buildModel(currentRecipe, validationSettings);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return currentRecipe;
    }

    protected void buildModel(Recipe currentRecipe, Map<String, Serializable> validationSettings) {
        try {
            if (currentRecipe != null) {
                FeatureTable currentFeatureTable = currentRecipe.getTrainingTable();
                TrainingResult results = null;
                /*
                 * if (validationSettings.get("type").equals("SUPPLY")) {
                 *     DocumentList test = (DocumentList) validationSettings.get("testSet");
                 *     FeatureTable extractTestFeatures = prepareTestFeatureTable(currentRecipe, validationSettings, test);
                 *     validationSettings.put("testFeatureTable", extractTestFeatures);
                 *
                 *     // if we've already trained the exact same model, don't do it again. Just evaluate.
                 *     Recipe cached = checkForCachedModel();
                 *     if (cached != null) {
                 *         results = evaluateUsingCachedModel(currentFeatureTable, extractTestFeatures, cached, currentRecipe);
                 *     }
                 * }
                 */
                if (results == null) {
                    results = currentRecipe.getLearner().train(currentFeatureTable,
                            currentRecipe.getLearnerSettings(), validationSettings, currentRecipe.getWrappers());
                }
                if (results != null) {
                    currentRecipe.setTrainingResult(results);
                    results.setName(trainingResultName);
                    currentRecipe.setLearnerSettings(currentRecipe.getLearner().generateConfigurationSettings());
                    currentRecipe.setValidationSettings(new TreeMap<String, Serializable>(validationSettings));
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    protected static FeatureTable prepareTestFeatureTable(Recipe recipe, Map<String, Serializable> validationSettings, DocumentList test) {
        prepareDocuments(recipe, validationSettings, test); // assign classes and annotations
        Collection<FeatureHit> hits = new TreeSet<FeatureHit>();
        OrderedPluginMap extractors = recipe.getExtractors();
        for (SIDEPlugin plug : extractors.keySet()) {
            Collection<FeatureHit> extractorHits = ((FeaturePlugin) plug).extractFeatureHits(test, extractors.get(plug));
            hits.addAll(extractorHits);
        }
        FeatureTable originalTable = recipe.getTrainingTable();
        FeatureTable ft = new FeatureTable(test, hits, 0, originalTable.getAnnotation(), originalTable.getClassValueType());
        for (SIDEPlugin plug : recipe.getFilters().keySet()) {
            ft = ((RestructurePlugin) plug).filterTestSet(originalTable, ft, recipe.getFilters().get(plug),
                    recipe.getFilteredTable().getThreshold());
        }
        ft.reconcileFeatures(originalTable.getFeatureSet());
        return ft;
    }

    protected static Map<String, Serializable> prepareDocuments(Recipe currentRecipe,
            Map<String, Serializable> validationSettings, DocumentList test) throws IllegalStateException {
        DocumentList train = currentRecipe.getDocumentList();
        try {
            test.setCurrentAnnotation(currentRecipe.getTrainingTable().getAnnotation(),
                    currentRecipe.getTrainingTable().getClassValueType());
            test.setTextColumns(new HashSet<String>(train.getTextColumns()));
            test.setDifferentiateTextColumns(train.getTextColumnsAreDifferentiated());
            Collection<String> trainColumns = train.allAnnotations().keySet();
            Collection<String> testColumns = test.allAnnotations().keySet();
            if (!testColumns.containsAll(trainColumns)) {
                ArrayList<String> missing = new ArrayList<String>(trainColumns);
                missing.removeAll(testColumns);
                throw new java.lang.IllegalStateException(
                        "Test set annotations do not match the training set.\nMissing columns: " + missing);
            }
            validationSettings.put("testSet", test);
        } catch (Exception e) {
            e.printStackTrace();
            throw new java.lang.IllegalStateException("Could not prepare test set.\n" + e.getMessage(), e);
        }
        return validationSettings;
    }

    // Predict labels ==================================================
    public void predictLabels(Recipe recipeToPredict, Recipe currentRecipe) {
        DocumentList newDocs = null;
        DocumentList originalDocs;
        if (useEvaluation) {
            originalDocs = recipeToPredict.getTrainingResult().getEvaluationTable().getDocumentList();
            TrainingResult results = currentRecipe.getTrainingResult();
            List<String> predictions = (List<String>) results.getPredictions();
            newDocs = addLabelsToDocs(predictionColumnName, showDists, overwrite, originalDocs, results,
                    predictions, currentRecipe.getTrainingTable());
        } else {
            originalDocs = recipeToPredict.getDocumentList();
            Predictor predictor = new Predictor(currentRecipe, predictionColumnName);
            newDocs = predictor.predict(originalDocs, predictionColumnName, showDists, overwrite);
        }
        // Store the prediction results
        model.setDocumentList(newDocs);
    }

    protected DocumentList addLabelsToDocs(final String name, final boolean showDists, final boolean overwrite,
            DocumentList docs, TrainingResult results, List<String> predictions, FeatureTable currentFeatureTable) {
        Map<String, List<Double>> distributions = results.getDistributions();
        DocumentList newDocs = docs.clone();
        newDocs.addAnnotation(name, predictions, overwrite);
        if (distributions != null && showDists) {
            for (String label : currentFeatureTable.getLabelArray()) {
                List<String> dist = new ArrayList<String>();
                for (int i = 0; i < predictions.size(); i++) {
                    dist.add(String.format("%.3f", distributions.get(label).get(i)));
                }
                newDocs.addAnnotation(name + "_" + label + "_score", dist, overwrite);
            }
        }
        return newDocs;
    }
    // ==================================================
}
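One Java pitfall worth calling out in code like this: comparing strings with `==` tests reference identity, not character content, so checks such as `pluginString == "Basic Features"` can silently fail for strings built at runtime; `.equals()` is what is needed. A minimal, self-contained demonstration:

```java
public class StringEqualsDemo {
    public static void main(String[] args) {
        // Build the string at runtime so it is a distinct object
        // from the interned "Basic Features" literal.
        String pluginName = new String("Basic Features");

        // == compares object references: different objects, so false
        System.out.println(pluginName == "Basic Features");      // false

        // equals() compares character content, so true
        System.out.println(pluginName.equals("Basic Features")); // true
    }
}
```

This is why a plugin-name check that happens to work with interned literals can break as soon as the name comes from a parsed file or a `toString()` call.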

Answer:

@[name withheld], it looks like the code above duplicates a lot of the functionality of the edu.cmu.side.recipe package. However, it appears that your predictSectionType() method never actually outputs the model's predictions anywhere.

If what you really want is to save a trained model's predictions on new data, have a look at the edu.cmu.side.recipe.Predictor class. It takes the path to a trained model as input, and it backs the scripts/predict.sh convenience script, but you could reuse its main method if you need to call it programmatically.
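To make that concrete, here is a hedged sketch of predicting and then actually persisting the result. It reuses only calls that already appear in the question's own code (the Predictor constructor and predict(), and ConverterControl's loadRecipe()/writeToXML()); the file paths, the "Code_Prediction" column name, and the surrounding variables are placeholders, and the LightSIDE jars are assumed to be on the classpath:

```java
// Sketch only: paths and names are placeholders, not tested against a real model.
Recipe trained = ConverterControl.loadRecipe("model/trained.side.xml");
Recipe toPredict = currentLightsideHelper.loadNewDocumentsFromCSV(sections);

Predictor predictor = new Predictor(trained, "Code_Prediction");
DocumentList labeled = predictor.predict(toPredict.getDocumentList(),
        "Code_Prediction", true, false);

// Persist the labeled documents so the predictions are not lost
toPredict.setDocumentList(labeled);
ConverterControl.writeToXML(new File("output/predictions.xml"), toPredict);
```

The key point is the last two lines: without some explicit write step, the predicted DocumentList only lives in memory (in the question's code, inside the table model) and never reaches disk.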

Hope this helps!

