为Spark ML的LabeledPoint添加自定义字段

如何在预测结果中添加一些自定义字段(例如用户ID)?

        List<org.apache.spark.mllib.regression.LabeledPoint> localTesting = ... ;//        // 我想为每个LabeledPoint添加一个标识符        DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);        DataFrame predictions = model.transform(localTestDF);        Row[] collect = predictions.select("label", "probability", "prediction").collect();        for (Row r : collect) {            // 并且我想在这里返回标识符。            // 所以我是否要保存到数据库中。            int userNo = Integer.parseInt(r.get(0).toString());            double prob = Double.parseDouble(r.get(1).toString());            int prediction = Integer.parseInt(r.get(2).toString());            log.debug(userNo + "," + prob + ", " + prediction);        }

但是当我使用这个类来代替LabeledPoint进行localTesting时,

class NoLabeledPoint extends LabeledPoint implements Serializable {    private static final long serialVersionUID = -2488661810406135403L;    int userNo;    public NoLabeledPoint(double label, Vector features) {        super(label, features);    }    public int getUserNo() {        return userNo;    }    public void setUserNo(int userNo) {        this.userNo = userNo;    }}        List<NoLabeledPoint> localTesting = ... ;// 为每个用户设置userNo字段        // 我想为每个LabeledPoint添加一个标识符        DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);        DataFrame predictions = model.transform(localTestDF);        Row[] collect = predictions.select("userNo", "probability", "prediction").collect();        for (Row r : collect) {            // 并且我想在这里返回标识符。            // 所以我是否要保存到数据库中。            int userNo = Integer.parseInt(r.get(0).toString());            double prob = Double.parseDouble(r.get(1).toString());            int prediction = Integer.parseInt(r.get(2).toString());            log.debug(userNo + "," + prob + ", " + prediction);        }

抛出了异常

org.apache.spark.sql.AnalysisException: cannot resolve 'userNo' given input columns rawPrediction, probability, features, label, prediction;        at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)        at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:63)        at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:52)        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286)        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286)        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)

我的意思是我不仅想得到预测数据(特征、标签、概率等),还想从结果中得到一些我想要的自定义字段。例如,userNo、user_id等:predictions.select(“…… “)

更新

已解决。需要修改一行代码

            DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);

改为

            DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), NoLabeledPoint.class);

回答:

由于您没有使用低级的MLlib API,因此根本不需要使用LabeledPoint。在创建DataFrame之后,您得到的只是一个带有某些值的Row,重要的是类型和列名要与您的管道中的参数匹配。

在Scala中,您可以使用任何案例类

org.apache.spark.mllib.linalg.Vector; case class case class LabeledPointWithMeta(userNo: String, label: Double, features: Vector)val rdd: RDD[LabeledPointWithMeta] = ???val df = rdd.toDF

为了能够从中使用它,您可能需要添加@BeanInfo注解:

import scala.beans.BeanInfo@BeanInfocase class LabeledPointWithMeta(...)

根据Spark SQL和DataFrame指南,看起来在普通Java中您可以这样做**:

import org.apache.spark.mllib.linalg.Vector;public static class LabeledPointWithMeta implements Serializable {  private int userNo;  private double label;  private Vector vector;  public int getUserNo() {    return userNo;  }  public void setUserNo(int userNo) {    this.userNo = userNo;  }  public double getLabel() {    return label;  }  public void setLabel(double label) {    this.label = label;  }  public Vector getVector() {    return vector;  }  public void seVector(Vector vector) {    this.vector = vector;  }}

然后:

JavaRDD<LabeledPointWithMeta> myPoints = ...;DataFrame df = sqlContext.createDataFrame(myPoints LabeledPointWithMeta.class);

我认为简单地更改您的代码也应该可以工作:

DataFrame localTestDF = jsql.createDataFrame(    jsc.parallelize(studyData.localTesting),    NoLabeledPoint.class); 

如果您想使用MLlib,这不会帮到您,但这部分可以很容易地通过简单的RDD转换如zip来处理。


* 还有一些元数据,但您不会从LabeledPoint中得到这些

** 我没有测试上面的代码,所以它可能包含一些错误。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注