Spark中train()、run()和fit()函数的区别

在使用Apache Spark(版本1.5.2)进行Java编程时,可以通过多种方式使用逻辑回归:

spark.ml:

1) LogisticRegression lr = new LogisticRegression();a) lr.train(dataFrame);b) lr.fit(dataFrame);

spark.mllib:

2) LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD();a) lr.train(rdd);b) lr.run(rdd);3) LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS();a) lr.train(rdd);b) lr.run(rdd);

我想知道a)和b)之间的区别是什么,除了run()函数返回的是GeneralizedLinearAlgorithm而不是其他函数返回的LogisticRegressionModel之外?我在Java或Scala的文档中找不到任何提示。感谢您的帮助。


回答:

Spark确实包含两个可用于机器学习的库:ML和MLLib。请问您使用的是Spark的哪个版本?

MLLib. 这是Spark的第一个机器学习库。它的结构非常简单,并且使用RDD来运行。在MLLib中,这有点混乱,因此您需要查看代码以确定使用哪个方法。我不确定您使用的是哪种语言或版本,但对于Spark 1.6.0的Scala版本,有一个单例对象:

object LogisticRegressionWithSGD {   def train(input: RDD[LabeledPoint], ...) = new LogisticRegressionWithSGD(...).run(input,...)}

这意味着train应该作为LogisticRegressionWithSGD对象上的静态方法调用,但如果您有一个LogisticRegressionWithSGD的实例,则只有run方法:

LogisticRegressionWithSGD.train(rdd, parameters) // ORval lr = new LogisticRegressionWithSGD() lr.run(rdd)

无论如何,如果您使用的是其他版本,您应该优先使用超级版本,即run方法。

ML. 这是最新的库,基于DataFrame的使用,DataFrame基本上是带有模式的RDD[Row]Row只是一个未类型化的对象序列,模式是一个包含列名、类型、元数据等信息的对象)。我强烈建议您使用这个库,因为它可以进行优化!在这种情况下,您应该使用fit方法,这是所有估计器都需要实现的方法。

解释: ML库使用了Pipeline的概念(与sci-kit learn中的类似)。一个pipeline实例基本上是一组阶段(类型为PipelineStage),每个阶段要么是Estimator,要么是Transformer(还有其他类型,例如Evaluator,但我在这里不详细讨论,因为它们较少见)。Transformer只是一个转换数据的算法,因此它的主要方法是transform(DataFrame),它输出另一个DataFrameEstimator是一种产生ModelTransformer的子类型)的算法。基本上,它是任何需要在数据上拟合的块,因此它有一个fit(DataFrame)函数,输出一个Transformer。例如,如果您想将所有数据乘以2,您只需要一个实现了将输入乘以2的transform方法的转换器。如果您需要计算平均值并减去它,您需要一个在数据上拟合以计算平均值并输出一个减去所学平均值的转换器的估计器。因此,每次使用ML时,请使用fittransform方法。这使您可以执行以下操作:

val trainingSet = // 训练DataFrameval testSet = // 测试DataFrameval lr = new LogisticRegession().setInputCol(...).setOutputCol(...) // + setParams()val stage = // 另一个阶段,即实现PipelineStage的东西val stages = Array(lr, stage)val pipeline: Pipeline = new Pipeline().setStages(stages)val model: PipelineModel = pipeline.fit(trainingSet)val result: DataFrame = model.transform(testSet)

现在,如果您真的想知道为什么存在train函数,这是由Predictor继承的函数,而Predictor本身扩展了Estimator。确实有许多可能的Estimators – 您可以计算平均值、IDF等。当您实现像逻辑回归这样的预测器时,您有一个扩展了Estimator的抽象类Predictor,它允许您进行一些快捷操作(例如,它有一个标签列、特征列和预测列)。特别是这段代码已经重写了fit以根据这些标签/特征/预测相应地更改您的DataFrame的模式,您只需实现自己的train方法:

override def fit(dataset: DataFrame): M = {   // 这处理了一些项目,例如模式验证。   // 开发者只需实现train()。   transformSchema(dataset.schema, logging = true)   copyValues(train(dataset).setParent(this))}protected def train(dataset: DataFrame): M

如您所见,train方法应该是受保护的/私有的,因此不应由外部用户使用。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注