为Spark ML的LabeledPoint添加自定义字段

如何在预测结果中添加一些自定义字段(例如用户ID)?

        List<org.apache.spark.mllib.regression.LabeledPoint> localTesting = ... ;//        // 我想为每个LabeledPoint添加一个标识符        DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);        DataFrame predictions = model.transform(localTestDF);        Row[] collect = predictions.select("label", "probability", "prediction").collect();        for (Row r : collect) {            // 并且我想在这里返回标识符。            // 所以我是否要保存到数据库中。            int userNo = Integer.parseInt(r.get(0).toString());            double prob = Double.parseDouble(r.get(1).toString());            int prediction = Integer.parseInt(r.get(2).toString());            log.debug(userNo + "," + prob + ", " + prediction);        }

但是当我使用这个类来代替LabeledPoint进行localTesting时,

class NoLabeledPoint extends LabeledPoint implements Serializable {    private static final long serialVersionUID = -2488661810406135403L;    int userNo;    public NoLabeledPoint(double label, Vector features) {        super(label, features);    }    public int getUserNo() {        return userNo;    }    public void setUserNo(int userNo) {        this.userNo = userNo;    }}        List<NoLabeledPoint> localTesting = ... ;// 为每个用户设置userNo字段        // 我想为每个LabeledPoint添加一个标识符        DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);        DataFrame predictions = model.transform(localTestDF);        Row[] collect = predictions.select("userNo", "probability", "prediction").collect();        for (Row r : collect) {            // 并且我想在这里返回标识符。            // 所以我是否要保存到数据库中。            int userNo = Integer.parseInt(r.get(0).toString());            double prob = Double.parseDouble(r.get(1).toString());            int prediction = Integer.parseInt(r.get(2).toString());            log.debug(userNo + "," + prob + ", " + prediction);        }

抛出了异常

org.apache.spark.sql.AnalysisException: cannot resolve 'userNo' given input columns rawPrediction, probability, features, label, prediction;        at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)        at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:63)        at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:52)        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286)        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286)        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)

我的意思是我不仅想得到预测数据(特征、标签、概率等),还想从结果中得到一些我想要的自定义字段。例如,userNo、user_id等:predictions.select(“…… “)

更新

已解决。需要修改一行代码

            DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);

改为

            DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), NoLabeledPoint.class);

回答:

由于您没有使用低级的MLlib API,因此根本不需要使用LabeledPoint。在创建DataFrame之后,您得到的只是一个带有某些值的Row,重要的是类型和列名要与您的管道中的参数匹配。

在Scala中,您可以使用任何案例类

org.apache.spark.mllib.linalg.Vector; case class case class LabeledPointWithMeta(userNo: String, label: Double, features: Vector)val rdd: RDD[LabeledPointWithMeta] = ???val df = rdd.toDF

为了能够从中使用它,您可能需要添加@BeanInfo注解:

import scala.beans.BeanInfo@BeanInfocase class LabeledPointWithMeta(...)

根据Spark SQL和DataFrame指南,看起来在普通Java中您可以这样做**:

import org.apache.spark.mllib.linalg.Vector;public static class LabeledPointWithMeta implements Serializable {  private int userNo;  private double label;  private Vector vector;  public int getUserNo() {    return userNo;  }  public void setUserNo(int userNo) {    this.userNo = userNo;  }  public double getLabel() {    return label;  }  public void setLabel(double label) {    this.label = label;  }  public Vector getVector() {    return vector;  }  public void seVector(Vector vector) {    this.vector = vector;  }}

然后:

JavaRDD<LabeledPointWithMeta> myPoints = ...;DataFrame df = sqlContext.createDataFrame(myPoints LabeledPointWithMeta.class);

我认为简单地更改您的代码也应该可以工作:

DataFrame localTestDF = jsql.createDataFrame(    jsc.parallelize(studyData.localTesting),    NoLabeledPoint.class); 

如果您想使用MLlib,这不会帮到您,但这部分可以很容易地通过简单的RDD转换如zip来处理。


* 还有一些元数据,但您不会从LabeledPoint中得到这些

** 我没有测试上面的代码,所以它可能包含一些错误。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注