我正在尝试在 Java 中(具体是在 Android Studio 中)使用 Weka 对一个实例进行分类。最初,我在桌面版 Weka GUI 中保存了一个模型,并尝试将其导入到我的项目目录中。如果我没记错的话,这样做行不通,因为 PC 和 Android 上使用的 Weka/JDK 版本不同。
现在我尝试通过导入训练数据集,直接在 Android 设备上训练模型(因为我看不到其他选择)。问题就出在这里:当我运行 `Test.java` 时,得到一个提示未指定数据源的错误,指向第 23 行,也就是调用 `loadDataset` 方法的地方:`java.io.IOException: No source has been specified`。但我显然已经指定了一个路径。这个路径写得对吗?我不确定自己哪里做错了。我查看了其他示例和博客,但都没有详细说明这一点。
我的最终目标是:在Android/Java中训练模型,并使用Weka开发的模型在Android/Java中对实例进行分类。
我的代码如下(共三个文件):
ModelGenerator.java
package com.example.owner.introductoryapplication;import android.support.v7.app.AppCompatActivity;import android.widget.ImageView;import android.widget.TextView;import java.util.logging.Level;import java.util.logging.Logger;import weka.classifiers.Classifier;import weka.classifiers.Evaluation;import weka.classifiers.functions.MultilayerPerceptron;import weka.core.Instances;import weka.core.SerializationHelper;import weka.core.converters.ConverterUtils;public class ModelGenerator{ //String trainingSetPath = "JavaTrainingSet.arff"; //com/example/owner/introductoryapplication/JavaTrainingSet.arff //String modelSavedPath = "com/example/owner/introductoryapplication/JavaTrainingSet.csv"; //从ARFF文件加载数据集并将其保存到Instances对象中 public Instances loadDataset(String path) { //声明并初始化空的训练集 Instances dataset = null; //将数据集加载到程序中 try { //读取数据集 dataset = ConverterUtils.DataSource.read(path); if (dataset.classIndex() == -1) { dataset.setClassIndex(dataset.numAttributes() - 1); } } catch (Exception ex) { Logger.getLogger(ModelGenerator.class.getName()).log(Level.SEVERE, null, ex); } return dataset; } //使用MultilayerPerceptron(神经网络)为训练集构建分类器 public Classifier buildClassifier(Instances traindataset) { MultilayerPerceptron m = new MultilayerPerceptron(); try { m.buildClassifier(traindataset); } catch (Exception ex) { Logger.getLogger(ModelGenerator.class.getName()).log(Level.SEVERE, null, ex); } return m; } //使用测试集评估生成模型的准确性 public String evaluateModel(Classifier model, Instances traindataset, Instances testdataset) { Evaluation eval = null; try { //使用测试数据集评估分类器 eval = new Evaluation(traindataset); eval.evaluateModel(model, testdataset); } catch (Exception ex) { Logger.getLogger(ModelGenerator.class.getName()).log(Level.SEVERE, null, ex); } return eval.toSummaryString("", true); } //将生成的模型保存到路径以便将来预测使用 public void saveModel(Classifier model, String modelpath) { try { SerializationHelper.write(modelpath, model); } catch (Exception ex) { 
Logger.getLogger(ModelGenerator.class.getName()).log(Level.SEVERE, null, ex); } }}
ModelClassifier.java
package com.example.owner.introductoryapplication;import android.support.v7.app.AppCompatActivity;import java.util.ArrayList;import java.util.logging.Level;import java.util.logging.Logger;import weka.classifiers.Classifier;import weka.classifiers.functions.MultilayerPerceptron;import weka.core.Attribute;import weka.core.DenseInstance;import weka.core.Instances;import weka.core.SerializationHelper;public class ModelClassifier{ private Attribute Age; private Attribute Height; private Attribute Weight; private Attribute UPDRS; private Attribute TUAG; private Attribute Speed; private Attribute Gender; private ArrayList attributes; private ArrayList classVal; private Instances dataRaw; public ModelClassifier() { Age = new Attribute("Age"); Height = new Attribute("Height"); Weight = new Attribute("Weight"); UPDRS = new Attribute("UPDRS"); TUAG = new Attribute("TUAG"); Speed = new Attribute("Speed"); Gender = new Attribute("Gender"); attributes = new ArrayList(); classVal = new ArrayList(); classVal.add("PD"); classVal.add("CO"); attributes.add(Age); attributes.add(Height); attributes.add(Weight); attributes.add(UPDRS); attributes.add(TUAG); attributes.add(Speed); attributes.add(Gender); attributes.add(new Attribute("class", classVal)); dataRaw = new Instances("TestInstances", attributes, 0); dataRaw.setClassIndex(dataRaw.numAttributes() - 1); } public Instances createInstance(double Age, double Height, double Weight, double UPDRS, double TUAG, double Speed, double Gender, double result) { dataRaw.clear(); double[] instanceValue1 = new double[]{Age, Height, 0}; dataRaw.add(new DenseInstance(1.0, instanceValue1)); return dataRaw; } public String classifiy(Instances insts, String path) { String result = "Not classified!!"; Classifier cls = null; try { cls = (MultilayerPerceptron) SerializationHelper.read(path); result = (String) classVal.get((int) cls.classifyInstance(insts.firstInstance())); } catch (Exception ex) { 
Logger.getLogger(ModelClassifier.class.getName()).log(Level.SEVERE, null, ex); } return result; } public Instances getInstance() { return dataRaw; }}
Test.java
package com.example.owner.introductoryapplication;import com.example.owner.introductoryapplication.ModelGenerator;import com.example.owner.introductoryapplication.ModelClassifier;import android.support.v7.app.AppCompatActivity;import weka.classifiers.functions.MultilayerPerceptron;import weka.core.Debug;import weka.core.Instances;import weka.filters.Filter;import weka.filters.unsupervised.attribute.Normalize;public class Test{ String DATASETPATH = "com/example/owner/introductoryapplication/JavaTrainingSet.arff"; String MODElPATH = "com/example/owner/introductoryapplication/model.bin"; public static void main(String[] args) throws Exception { ModelGenerator mg = new ModelGenerator(); Instances dataset = mg.loadDataset("/com/example/owner/introductoryapplication/JavaTrainingSet.arff"); Filter filter = new Normalize(); // 将数据集分为80%的训练数据集和20%的测试数据集 int trainSize = (int) Math.round(dataset.numInstances() * 0.8); int testSize = dataset.numInstances() - trainSize; dataset.randomize(new Debug.Random(1));// 如果注释掉这一行,模型的准确率将从96.6%下降到80% // 标准化数据集 filter.setInputFormat(dataset); Instances datasetnor = Filter.useFilter(dataset, filter); Instances traindataset = new Instances(datasetnor, 0, trainSize); Instances testdataset = new Instances(datasetnor, trainSize, testSize); // 使用训练数据集构建分类器 MultilayerPerceptron ann = (MultilayerPerceptron) mg.buildClassifier(traindataset); // 使用测试数据集评估分类器 String evalsummary = mg.evaluateModel(ann, traindataset, testdataset); System.out.println("评估: " + evalsummary); // 保存模型 mg.saveModel(ann, "/com/example/owner/introductoryapplication/model.bin"); // 分类单个实例 ModelClassifier cls = new ModelClassifier(); String classname = cls.classifiy(Filter.useFilter(cls.createInstance(50, 20, 30, 14, 16, 10.42, 2, 0), filter), "/com/example/owner/introductoryapplication/model.bin"); System.out.println("\n 实例的类别名称为: " + classname); }}
如有任何建议,请尽快告诉我,谢谢。
回答:
详细且深入的回答位于此链接:
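简而言之,你需要在 `res` 目录下创建一个 `raw` 文件夹,把数据文件(例如训练集 ARFF)放进去,然后通过资源 ID 来访问这些文件。例如,先将文件重命名为 `javatrainingset.arff`(资源文件名只能包含小写字母、数字和下划线)。再用 `InputStream in = context.getResources().openRawResource(R.raw.javatrainingset);` 打开它,并把这个流传给 Weka 的 `ConverterUtils.DataSource.read(in)`。注意,`res/raw` 在运行时是只读的,训练好的模型应保存到 `context.getFilesDir()` 等可写目录中。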