使用Spark 3加载PipelineModel时出现AnalysisException

我在将Spark版本从2.4.5升级到3.0.1时,无法再加载使用”DecisionTreeClassifier”阶段的PipelineModel对象。

在我的代码中,我加载了多个PipelineModel,所有使用阶段[“CountVectorizer_[uid]”, “LinearSVC_[uid]”]的PipelineModel都能正常加载,而使用阶段[“CountVectorizer_[uid]”,”DecisionTreeClassifier_[uid]”]的模型则会抛出以下异常:

AnalysisException: 无法解析’rawCount‘,给定输入列为:[gain, id, impurity, impurityStats, leftChild, prediction, rightChild,split]

这是我使用的代码和完整的堆栈跟踪:

from pyspark.ml.pipeline import PipelineModelPipelineModel.load("/path/to/model")AnalysisException                         Traceback (most recent call last)<command-1278858167154148> in <module>----> 1 RalentModel = PipelineModel.load(MODELES_ATTRIBUTS + "RalentModel_DT")/databricks/spark/python/pyspark/ml/util.py in load(cls, path)    368     def load(cls, path):    369         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""--> 370         return cls.read().load(path)    371     372 /databricks/spark/python/pyspark/ml/pipeline.py in load(self, path)    289         metadata = DefaultParamsReader.loadMetadata(path, self.sc)    290         if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':--> 291             return JavaMLReader(self.cls).load(path)    292         else:    293             uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)/databricks/spark/python/pyspark/ml/util.py in load(self, path)    318         if not isinstance(path, basestring):    319             raise TypeError("path should be a basestring, got type %s" % type(path))--> 320         java_obj = self._jread.load(path)    321         if not hasattr(self._clazz, "_from_java"):    322             raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)   1303         answer = self.gateway_client.send_command(command)   1304         return_value = get_return_value(-> 1305             answer, self.gateway_client, self.target_id, self.name)   1306    1307         for temp_arg in temp_args:/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)    131                 # Hide where the exception came from that shows a non-Pythonic    132                 # JVM exception message.--> 133                 raise_from(converted)    134             else:    135                 raise/databricks/spark/python/pyspark/sql/utils.py in raise_from(e)AnalysisException: cannot resolve '`rawCount`' given input columns: [gain, id, impurity, impurityStats, leftChild, prediction, rightChild, split];

这些Pipeline模型是使用Spark 2.4.3保存的,我可以使用Spark 2.4.5正常加载它们。

我尝试进一步调查并单独加载每个阶段。使用以下代码加载CountVectorizerModel:

from pyspark.ml.feature import CountVectorizerModelCountVectorizerModel.read().load("/path/to/model/stages/0_CountVectorizer_efce893314a9")

可以得到一个CountVectorizerModel,因此这是可行的,但我的代码在尝试加载DecisionTreeClassificationModel时失败:

DecisionTreeClassificationModel.read().load("/path/to/model/stages/1_DecisionTreeClassifier_4d2a76c565b0")AnalysisException: cannot resolve '`rawCount`' given input columns: [gain, id, impurity, impurityStats, leftChild, prediction, rightChild, split];

这是我的决策树分类器的”data”内容:

spark.read.parquet("/path/to/model/stages/1_DecisionTreeClassifier_4d2a76c565b0/data").show()+---+----------+--------------------+-------------+--------------------+---------+----------+----------------+| id|prediction|            impurity|impurityStats|                gain|leftChild|rightChild|           split|+---+----------+--------------------+-------------+--------------------+---------+----------+----------------+|  0|       0.0|  0.3926234384295062| [90.0, 33.0]| 0.16011830963990054|        1|        16|[190, [0.5], -1]||  1|       0.0|  0.2672722508516028| [90.0, 17.0]| 0.11434106988303855|        2|        15|[512, [0.5], -1]||  2|       0.0|  0.1652892561983472|  [90.0, 9.0]| 0.06959547629404085|        3|        14|[583, [0.5], -1]||  3|       0.0| 0.09972299168975082|  [90.0, 5.0]|0.026984966852376356|        4|        11|[480, [0.5], -1]||  4|       0.0|0.043933846736523306|  [87.0, 2.0]|0.021717299239076976|        5|        10|[555, [1.5], -1]||  5|       0.0|0.022469008264462766|  [87.0, 1.0]|0.011105371900826402|        6|         7|[833, [0.5], -1]||  6|       0.0|                 0.0|  [86.0, 0.0]|                -1.0|       -1|        -1|    [-1, [], -1]||  7|       0.0|                 0.5|   [1.0, 1.0]|                 0.5|        8|         9|  [0, [0.5], -1]||  8|       0.0|                 0.0|   [1.0, 0.0]|                -1.0|       -1|        -1|    [-1, [], -1]||  9|       1.0|                 0.0|   [0.0, 1.0]|                -1.0|       -1|        -1|    [-1, [], -1]|| 10|       1.0|                 0.0|   [0.0, 1.0]|                -1.0|       -1|        -1|    [-1, [], -1]|| 11|       0.0|                 0.5|   [3.0, 3.0]|                 0.5|       12|        13| [14, [1.5], -1]|| 12|       0.0|                 0.0|   [3.0, 0.0]|                -1.0|       -1|        -1|    [-1, [], -1]|| 13|       1.0|                 0.0|   [0.0, 3.0]|                -1.0|       -1|        -1|    [-1, [], -1]|| 14|       1.0|                 0.0|   [0.0, 4.0]|                -1.0|       -1|        -1|    [-1, [], -1]|| 15|       1.0|                 0.0|   [0.0, 8.0]|                -1.0|       -1|        -1|    [-1, [], -1]|| 16|       1.0|                 0.0|  [0.0, 16.0]|                -1.0|       -1|        -1|    [-1, [], -1]|+---+----------+--------------------+-------------+--------------------+---------+----------+----------------+

回答:

这是一个错误,我在这里提交了一个问题:https://issues.apache.org/jira/browse/SPARK-33398,它在这个PR中得到了解决:https://github.com/apache/spark/pull/30889

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注