如何从TreeExplainer中获取shap_values的特征名称?

我在进行一个SHAP教程,尝试获取数据集中每个人的SHAP值

from sklearn.model_selection import train_test_splitimport xgboostimport shapimport numpy as npimport pandas as pdimport matplotlib.pylab as plX,y = shap.datasets.adult()X_display,y_display = shap.datasets.adult(display=True)# create a train/test splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)d_train = xgboost.DMatrix(X_train, label=y_train)d_test = xgboost.DMatrix(X_test, label=y_test)params = {    "eta": 0.01,    "objective": "binary:logistic",    "subsample": 0.5,    "base_score": np.mean(y_train),    "eval_metric": "logloss"}#model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=100, early_stopping_rounds=20)xg_clf = xgboost.XGBClassifier()xg_clf.fit(X_train, y_train)explainer = shap.TreeExplainer(xg_clf, X_train)#shap_values = explainer(X)shap_values = explainer.shap_values(X)

在Python3解释器中运行,shap_values是一个包含32,561个人的巨大数组,每个人有12个特征的SHAP值。

例如,第一个人的SHAP值如下:

>>> shap_values[0]array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,       -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,       -0.26587735,  0.02700199])

然而,哪个值对应哪个特征对我来说完全是个谜。

文档中说明:

For models with a single output this returns a matrix of SHAP values        (# samples x # features). Each row sums to the difference between the model output for that        sample and the expected value of the model output (which is stored in the expected_value        attribute of the explainer when it is constant). For models with vector outputs this returns        a list of such matrices, one for each output

当我查看生成shap_valuesexplainer时,我发现我可以获取特征名称:

explainer.data_feature_names['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

但我在Python解释器中看不到如何在shap_values中获取特征名称,如果它们确实存在的话:

>>> shap_values.shap_values.all(           shap_values.compress(      shap_values.dump(          shap_values.max(           shap_values.ravel(         shap_values.sort(          shap_values.tostring(shap_values.any(           shap_values.conj(          shap_values.dumps(         shap_values.mean(          shap_values.real           shap_values.squeeze(       shap_values.trace(shap_values.argmax(        shap_values.conjugate(     shap_values.fill(          shap_values.min(           shap_values.repeat(        shap_values.std(           shap_values.transpose(shap_values.argmin(        shap_values.copy(          shap_values.flags          shap_values.nbytes         shap_values.reshape(       shap_values.strides        shap_values.var(shap_values.argpartition(  shap_values.ctypes         shap_values.flat           shap_values.ndim           shap_values.resize(        shap_values.sum(           shap_values.view(shap_values.argsort(       shap_values.cumprod(       shap_values.flatten(       shap_values.newbyteorder(  shap_values.round(         shap_values.swapaxes(      shap_values.astype(        shap_values.cumsum(        shap_values.getfield(      shap_values.nonzero(       shap_values.searchsorted(  shap_values.T              shap_values.base           shap_values.data           shap_values.imag           shap_values.partition(     shap_values.setfield(      shap_values.take(          shap_values.byteswap(      shap_values.diagonal(      shap_values.item(          shap_values.prod(          shap_values.setflags(      shap_values.tobytes(       shap_values.choose(        shap_values.dot(           shap_values.itemset(       shap_values.ptp(           shap_values.shape          shap_values.tofile(        shap_values.clip(          shap_values.dtype          shap_values.itemsize       shap_values.put(           shap_values.size           shap_values.tolist(    

我的主要问题:如何弄清楚

['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

中的哪个特征对应shap_values的每一行中的哪个数字?

>>> shap_values[0]array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,       -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,       -0.26587735,  0.02700199])

我假设特征的顺序是相同的,但我没有证据证明这一点。

我的次要问题:如何在shap_values中找到特征名称?


回答:

正如你所假设的,特征的顺序确实是相同的;参见GitHub上的如何提取最重要的特征名称?如何从解释器中获取特征名称问题。

要找到特征名称,你只需访问具有相同索引的名称数组的元素即可

例如:

shap_values = np.array([    0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,   -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,   -0.26587735,  0.02700199])features_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation',                  'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss',                  'Hours per week', 'Country']features_names[shap_values.argmin()]  # 索引8 -> Capital Gainfeatures_names[shap_values.argmax()]  # 索引0 -> Age

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注