I am trying to train a logistic regression model (LogisticRegressionWithSGD), but I get the following error:
org.apache.spark.SparkException: Input validation failed.
If I provide binary labels (0 and 1 rather than 0, 1 and 2), training succeeds.
Sample input:
parsed_data = [
    LabeledPoint(0.0, [4.6, 3.6, 1.0, 0.2]),
    LabeledPoint(0.0, [5.7, 4.4, 1.5, 0.4]),
    LabeledPoint(1.0, [6.7, 3.1, 4.4, 1.4]),
    LabeledPoint(0.0, [4.8, 3.4, 1.6, 0.2]),
    LabeledPoint(2.0, [4.4, 3.2, 1.3, 0.2]),
]
Code: model = LogisticRegressionWithSGD.train(parsed_data)
Is the logistic regression model in Spark meant only for binary classification?
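Answer:
Although the documentation does not say so explicitly (you have to dig into the source code to find out), LogisticRegressionWithSGD only supports binary classification (labels 0 and 1); for multinomial (multiclass) logistic regression, use LogisticRegressionWithLBFGS instead: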
from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel, LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

# sc is an existing SparkContext (for example, the one provided by the pyspark shell)
parsed_data = [
    LabeledPoint(0.0, [4.6, 3.6, 1.0, 0.2]),
    LabeledPoint(0.0, [5.7, 4.4, 1.5, 0.4]),
    LabeledPoint(1.0, [6.7, 3.1, 4.4, 1.4]),
    LabeledPoint(0.0, [4.8, 3.4, 1.6, 0.2]),
    LabeledPoint(2.0, [4.4, 3.2, 1.3, 0.2]),
]

# Binary-only trainer: fails because the data contains the label 2.0
model = LogisticRegressionWithSGD.train(sc.parallelize(parsed_data))
# raises:
# org.apache.spark.SparkException: Input validation failed.

# Multiclass trainer: pass the number of distinct classes explicitly
model = LogisticRegressionWithLBFGS.train(sc.parallelize(parsed_data), numClasses=3)
# works fine
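
As a rough illustration of why the input validation fails, here is a minimal sketch in plain Python (no Spark needed) of a label check you could run before choosing a trainer. The helper names `is_binary` and `count_classes` are hypothetical and not part of pyspark; the sketch assumes the labels are encoded as 0, 1, ..., k-1, which is the encoding `numClasses` expects.

```python
def is_binary(labels):
    """True when every label is 0 or 1, i.e. what a binary-only trainer accepts."""
    return all(label in (0, 1) for label in labels)


def count_classes(labels):
    """Return k, assuming labels are integers encoded as 0..k-1 (k = max label + 1)."""
    if not labels:
        raise ValueError("no labels given")
    distinct = sorted(set(labels))
    for label in distinct:
        # Reject negative or fractional labels such as -1.0 or 1.5
        if label < 0 or label != int(label):
            raise ValueError("labels must be non-negative integers, got %r" % label)
    return int(distinct[-1]) + 1


# Labels taken from the sample data in the question
labels = [0.0, 0.0, 1.0, 0.0, 2.0]
print(is_binary(labels))      # False: the label 2.0 is present
print(count_classes(labels))  # 3
```

With the sample data, the first check fails because of the label 2.0, and the second gives the value to pass as numClasses to LogisticRegressionWithLBFGS.train.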