dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。
看起来我成功地训练了模型,然而,当我尝试预测测试数据并查看实际预测结果时,遇到了以下错误:

ValueError: 数据必须是一维的

这是我尝试预测数据的方式:

from dask_ml.model_selection import train_test_splitimport daskimport xgboostimport dask_xgboostfrom dask.distributed import Clientimport dask_ml.model_selection as dcv#分割数据x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.33,random_state=42)client = Client(n_workers=10, threads_per_worker=1)#尝试进行超参数运行model_xgb = xgb.XGBRegressor(seed=42,verbose=True)params={    'learning_rate':[0.1,0.01,0.05],    'max_depth':[1,5,8],    'gamma':[0,0.5,1],    'scale_pos_weight':[1,3,5]}grid_search = GridSearchCV(model_xgb, params, cv=3, scoring='neg_mean_squared_error')grid_search.fit(x_train, y_train)#使用最佳参数训练数据bst = dask_xgboost.train(client, grid_search.best_params_, x_train, y_train, num_boost_round=10)#预测数据dask_xgboost.predict(client, bst, x_test).persist()

最后一行使用 predict 可以工作,但当我在末尾添加 compute 以查看实际数组时,我得到了维度错误:

dask_xgboost.predict(client, bst, x_test).persist().compute()>>>ValueError: 数据必须是一维的

如何使用 .predict 获取预测结果?


回答:

pip 页面中对 dask-xgboost 的说明:

Dask-XGBoost 已被废弃且不再维护。此项目的功能已直接包含在 XGBoost 中。要一起使用 Dask 和 XGBoost,请使用 xgboost.dask 代替 https://xgboost.readthedocs.io/en/latest/tutorials/dask.html.

您提供的代码中缺少一些赋值和表达式(例如,x 是如何定义的,GridSearchCV 从哪里导入)。以下是一些可能需要更改的地方:

# 注意 .daskmodel_xgb = xgb.dask.DaskXGBRegressor(seed=42, verbose=True)grid_search = GridSearchCV(model_xgb, params, cv=3, scoring='neg_mean_squared_error')grid_search.fit(x_train, y_train)#使用最佳参数训练模型model_xgb.client = clientmodel_xgb.set_params(grid_search.best_params_)model_xgb.fit(X_train, y_train, eval_set=[(X_test, y_test)])

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

### 卷积神经网络 – 1D – 特征分类错误

我正在尝试修改下面的示例,以模拟我的数据集的CNN,并…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注