I set out to build a machine learning model that classifies paragraphs across a set of documents. I wrote the model and the results looked good! However, when I tried feeding in a CSV file that does not contain the labelCol (the labeled column, i.e. the column I'm trying to predict), it threw an error: "Field tagIndexed does not exist."
This is strange. The column I'm predicting is "tag", so why would a "tagIndexed" column be expected when model.transform(df) is called (in Predict.scala)? I don't have much machine learning experience, but every DecisionTreeClassifier example I've seen seems to work without the labelCol being present in the test data. What am I missing here?
I build the model, validate it with test data, and save it to disk. Then, in a separate Scala object, I load the model and pass my CSV file into it.
//Train.scala
package com.secret.classifier

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.Column
import org.apache.spark.ml.feature.{HashingTF, IDF, StringIndexer, Tokenizer, VectorAssembler, Word2Vec}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

...

val colSeq = Seq("font", "tag")
val indexSeq = colSeq.map(col =>
  new StringIndexer().setInputCol(col).setOutputCol(col + "Indexed").fit(dfNoNan))

val tokenizer = new Tokenizer().setInputCol("soup").setOutputCol("words")
//val wordsData = tokenizer.transform(dfNoNan)

val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("rawFeatures")
  .setNumFeatures(20)

val featuresCol = "features"
val assembler = new VectorAssembler()
  .setInputCols((numericCols ++ colSeq.map(_ + "Indexed")).toArray)
  .setOutputCol(featuresCol)

val labelCol = "tagIndexed"
val decisionTree = new DecisionTreeClassifier()
  .setLabelCol(labelCol)
  .setFeaturesCol(featuresCol)

val pipeline = new Pipeline()
  .setStages((indexSeq :+ tokenizer :+ hashingTF :+ assembler :+ decisionTree).toArray)

val Array(training, test) = dfNoNan.randomSplit(Array(0.8, 0.2), seed = 420420)

val model = pipeline.fit(training)
model.write.overwrite().save("tmp/spark-model")

//Predict.scala
package com.secret.classifier

import org.apache.spark.sql.functions._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.Column
import org.apache.spark.ml.feature.{HashingTF, IDF, StringIndexer, Tokenizer, VectorAssembler, Word2Vec}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.types
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

...

val dfImport = spark.read
  .format("csv")
  .option("header", "true")
  //.option("mode", "DROPMALFORMED")
  .schema(customSchema)
  .load(csvLocation)

val df = dfImport.drop("_c0", "doc_name")
df.show(20)

val model = PipelineModel.load("tmp/spark-model")
val predictions = model.transform(df)
predictions.show(20)

//pom.xml -> Spark/Scala specific dependencies
<properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
    <encoding>UTF-8</encoding>
    <scala.version>2.11.12</scala.version>
    <scala.compat.version>2.11</scala.compat.version>
    <spec2.version>4.2.0</spec2.version>
</properties>

<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.11</artifactId>
        <version>2.3.1</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/com.databricks/spark-csv -->
    <dependency>
        <groupId>com.databricks</groupId>
        <artifactId>spark-csv_2.11</artifactId>
        <version>1.5.0</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql -->
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
        <version>2.3.1</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.11</artifactId>
        <version>2.3.1</version>
    </dependency>
    <dependency>
        <groupId>com.univocity</groupId>
        <artifactId>univocity-parsers</artifactId>
        <version>2.8.0</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib -->
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.11</artifactId>
        <version>2.3.1</version>
    </dependency>
</dependencies>
The expected result is that the prediction model does not throw an error. Instead, it throws the error "Field "tagIndexed" does not exist."
Answer:
It looks like you have included the label field among the features as well, since it is part of the colSeq output columns. In this step you should include only the feature columns:
.setInputCols((numericCols ++ colSeq.map(_+"Indexed")).toArray)
I have found the .filterNot() function helpful for this.
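A minimal sketch of the corrected VectorAssembler along those lines, assuming the same colSeq and numericCols as in Train.scala (numericCols is stubbed here as a hypothetical placeholder) — .filterNot() drops the raw label column "tag" before the "Indexed" suffix is appended, so "tagIndexed" never enters the feature list:

```scala
import org.apache.spark.ml.feature.VectorAssembler

val colSeq = Seq("font", "tag")
val numericCols = Seq("someNumeric") // hypothetical stand-in for the question's numericCols

// Assemble features from everything *except* the label's indexed column.
val assembler = new VectorAssembler()
  .setInputCols((numericCols ++ colSeq.filterNot(_ == "tag").map(_ + "Indexed")).toArray)
  .setOutputCol("features")
```

With the label excluded from the feature vector, the fitted assembler stage no longer demands a "tagIndexed" field on the DataFrame you pass to model.transform() at prediction time.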