如何在预测结果中添加一些自定义字段(例如用户ID)?
List<org.apache.spark.mllib.regression.LabeledPoint> localTesting = ... ;// // 我想为每个LabeledPoint添加一个标识符 DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class); DataFrame predictions = model.transform(localTestDF); Row[] collect = predictions.select("label", "probability", "prediction").collect(); for (Row r : collect) { // 并且我想在这里返回标识符。 // 所以我是否要保存到数据库中。 int userNo = Integer.parseInt(r.get(0).toString()); double prob = Double.parseDouble(r.get(1).toString()); int prediction = Integer.parseInt(r.get(2).toString()); log.debug(userNo + "," + prob + ", " + prediction); }
但是当我使用这个类来代替LabeledPoint进行localTesting时,
class NoLabeledPoint extends LabeledPoint implements Serializable { private static final long serialVersionUID = -2488661810406135403L; int userNo; public NoLabeledPoint(double label, Vector features) { super(label, features); } public int getUserNo() { return userNo; } public void setUserNo(int userNo) { this.userNo = userNo; }} List<NoLabeledPoint> localTesting = ... ;// 为每个用户设置userNo字段 // 我想为每个LabeledPoint添加一个标识符 DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class); DataFrame predictions = model.transform(localTestDF); Row[] collect = predictions.select("userNo", "probability", "prediction").collect(); for (Row r : collect) { // 并且我想在这里返回标识符。 // 所以我是否要保存到数据库中。 int userNo = Integer.parseInt(r.get(0).toString()); double prob = Double.parseDouble(r.get(1).toString()); int prediction = Integer.parseInt(r.get(2).toString()); log.debug(userNo + "," + prob + ", " + prediction); }
抛出了异常
org.apache.spark.sql.AnalysisException: cannot resolve 'userNo' given input columns rawPrediction, probability, features, label, prediction; at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:63) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:52) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)
我的意思是我不仅想得到预测数据(特征、标签、概率等),还想从结果中得到一些我想要的自定义字段。例如,userNo、user_id等:predictions.select(“…… “)
更新
已解决。需要修改一行代码
从
DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);
改为
DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), NoLabeledPoint.class);
回答:
由于您没有使用低级的MLlib API,因此根本不需要使用LabeledPoint
。在创建DataFrame
之后,您得到的只是一个带有某些值的Row
,重要的是类型和列名要与您的管道中的参数匹配。
在Scala中,您可以使用任何案例类
org.apache.spark.mllib.linalg.Vector; case class case class LabeledPointWithMeta(userNo: String, label: Double, features: Vector)val rdd: RDD[LabeledPointWithMeta] = ???val df = rdd.toDF
为了能够从中使用它,您可能需要添加@BeanInfo
注解:
import scala.beans.BeanInfo@BeanInfocase class LabeledPointWithMeta(...)
根据Spark SQL和DataFrame指南,看起来在普通Java中您可以这样做**:
import org.apache.spark.mllib.linalg.Vector;public static class LabeledPointWithMeta implements Serializable { private int userNo; private double label; private Vector vector; public int getUserNo() { return userNo; } public void setUserNo(int userNo) { this.userNo = userNo; } public double getLabel() { return label; } public void setLabel(double label) { this.label = label; } public Vector getVector() { return vector; } public void seVector(Vector vector) { this.vector = vector; }}
然后:
JavaRDD<LabeledPointWithMeta> myPoints = ...;DataFrame df = sqlContext.createDataFrame(myPoints LabeledPointWithMeta.class);
我认为简单地更改您的代码也应该可以工作:
DataFrame localTestDF = jsql.createDataFrame( jsc.parallelize(studyData.localTesting), NoLabeledPoint.class);
如果您想使用MLlib,这不会帮到您,但这部分可以很容易地通过简单的RDD
转换如zip
来处理。
* 还有一些元数据,但您不会从LabeledPoint
中得到这些
** 我没有测试上面的代码,所以它可能包含一些错误。