I'm working through a SHAP tutorial and trying to get the SHAP values for each person in the dataset:
from sklearn.model_selection import train_test_split
import xgboost
import shap
import numpy as np
import pandas as pd
import matplotlib.pylab as pl

X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)

# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
d_train = xgboost.DMatrix(X_train, label=y_train)
d_test = xgboost.DMatrix(X_test, label=y_test)

params = {
    "eta": 0.01,
    "objective": "binary:logistic",
    "subsample": 0.5,
    "base_score": np.mean(y_train),
    "eval_metric": "logloss"
}
#model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=100, early_stopping_rounds=20)

xg_clf = xgboost.XGBClassifier()
xg_clf.fit(X_train, y_train)

explainer = shap.TreeExplainer(xg_clf, X_train)
#shap_values = explainer(X)
shap_values = explainer.shap_values(X)
Running this in the Python 3 interpreter, shap_values is a huge array covering 32,561 people, with a SHAP value for each of 12 features.
For example, here are the SHAP values for the first person:
>>> shap_values[0]
array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,
       -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,
       -0.26587735,  0.02700199])
However, which value belongs to which feature is a complete mystery to me.
The documentation says:
For models with a single output this returns a matrix of SHAP values (# samples x # features). Each row sums to the difference between the model output for that sample and the expected value of the model output (which is stored in the expected_value attribute of the explainer when it is constant). For models with vector outputs this returns a list of such matrices, one for each output
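The row-sum property in the quoted documentation can be checked on a toy case. This is a minimal sketch, assuming a linear model f(x) = w·x + b with independent features, for which the exact SHAP value of feature j for sample i is w_j (x_ij − mean of feature j), using the data itself as the background. The data, the weights w and the intercept b are synthetic and hypothetical; nothing here comes from the adult dataset or from shap itself.

```python
import numpy as np

# Synthetic data: 5 samples x 3 features (illustration only)
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
w = np.array([0.5, -1.0, 2.0])  # hypothetical weights
b = 0.3                         # hypothetical intercept

def f(data):
    """Hypothetical linear model output for each row of `data`."""
    return data @ w + b

# Closed-form SHAP values for a linear model with independent features:
# phi[i, j] = w[j] * (X[i, j] - mean of column j)
phi = w * (X - X.mean(axis=0))

# Expected model output over the background data (here, X itself)
expected_value = f(X).mean()

print(phi.shape)  # (5, 3): one row per sample, one column per feature
# Each row sums to f(x) - expected_value, as the documentation states
print(np.allclose(phi.sum(axis=1), f(X) - expected_value))  # True
```

The column layout is the point here: column j of phi belongs to column j of X, which is the same convention the question below is asking about.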
Looking at the explainer that produced shap_values, I can see that it exposes the feature names:
>>> explainer.data_feature_names
['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']
But I can't see how to get the feature names from shap_values itself in the interpreter, if they are stored there at all:
>>> shap_values.
shap_values.all( shap_values.compress( shap_values.dump( shap_values.max( shap_values.ravel( shap_values.sort( shap_values.tostring(
shap_values.any( shap_values.conj( shap_values.dumps( shap_values.mean( shap_values.real shap_values.squeeze( shap_values.trace(
shap_values.argmax( shap_values.conjugate( shap_values.fill( shap_values.min( shap_values.repeat( shap_values.std( shap_values.transpose(
shap_values.argmin( shap_values.copy( shap_values.flags shap_values.nbytes shap_values.reshape( shap_values.strides shap_values.var(
shap_values.argpartition( shap_values.ctypes shap_values.flat shap_values.ndim shap_values.resize( shap_values.sum( shap_values.view(
shap_values.argsort( shap_values.cumprod( shap_values.flatten( shap_values.newbyteorder( shap_values.round( shap_values.swapaxes(
shap_values.astype( shap_values.cumsum( shap_values.getfield( shap_values.nonzero( shap_values.searchsorted( shap_values.T
shap_values.base shap_values.data shap_values.imag shap_values.partition( shap_values.setfield( shap_values.take(
shap_values.byteswap( shap_values.diagonal( shap_values.item( shap_values.prod( shap_values.setflags( shap_values.tobytes(
shap_values.choose( shap_values.dot( shap_values.itemset( shap_values.ptp( shap_values.shape shap_values.tofile(
shap_values.clip( shap_values.dtype shap_values.itemsize shap_values.put( shap_values.size shap_values.tolist(
My main question: how can I work out which feature in ['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country'] corresponds to which number in each row of shap_values, for example:
>>> shap_values[0]
array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,
       -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,
       -0.26587735,  0.02700199])
I assume the features are in the same order, but I have no evidence for this.
My secondary question: how can I find the feature names in shap_values?
Answer:
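As you assumed, the features are in the same order. shap_values is a plain NumPy array (which is why tab completion only shows ndarray methods) and carries no feature names; column j holds the SHAP values for the j-th column of X, so X.columns, or explainer.data_feature_names, gives the names in that order. See the GitHub issues "How to extract the most important feature names?" and "How to get feature names from explainer".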
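Because the order matches, one common way to keep the names attached is to wrap the matrix in a pandas DataFrame whose columns are the feature names, e.g. pd.DataFrame(shap_values, columns=X.columns), assuming X is the same DataFrame you passed to explainer.shap_values. The sketch below does this with only the first row from the question, so it runs without shap or xgboost; the 1 × 12 reshape stands in for the full 32,561 × 12 matrix, and the names row and shap_df are illustrative.

```python
import numpy as np
import pandas as pd

# First row of shap_values as printed in the question
row = np.array([0.76437867, -0.11881508, 0.57451954, -0.41974955,
                -0.20982443, -0.38079952, -0.00986504, 0.32272505,
                -3.04392116, 0.00411322, -0.26587735, 0.02700199])

# Same order as X.columns / explainer.data_feature_names
feature_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status',
                 'Occupation', 'Relationship', 'Race', 'Sex',
                 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

# With the full matrix you would pass shap_values and X.columns instead
shap_df = pd.DataFrame(row.reshape(1, -1), columns=feature_names)

# Look up a value by person (row label) and feature name
print(shap_df.loc[0, 'Capital Gain'])
```

With the full matrix, shap_df.loc[i] then gives person i's SHAP values as a Series indexed by feature name.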
To find a feature's name, just index the list of names with the same index, for example:
import numpy as np

# The first person's SHAP values (shap_values[0] from the question)
shap_values = np.array([0.76437867, -0.11881508, 0.57451954, -0.41974955,
                        -0.20982443, -0.38079952, -0.00986504, 0.32272505,
                        -3.04392116, 0.00411322, -0.26587735, 0.02700199])
features_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status',
                  'Occupation', 'Relationship', 'Race', 'Sex',
                  'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

features_names[shap_values.argmin()]  # index 8 -> 'Capital Gain'
features_names[shap_values.argmax()]  # index 0 -> 'Age'
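argmin and argmax only pick out the two extremes. To see every feature's contribution for one person ranked by magnitude, you can sort by absolute value. This is a minimal sketch using the first row from the question; the names row and ranked are illustrative. Assuming TreeExplainer's default model_output="raw", the values for this binary:logistic XGBoost model are in log-odds, so a positive value pushes the prediction toward the positive class and a negative value pushes it away.

```python
import numpy as np

# First row of shap_values as printed in the question
row = np.array([0.76437867, -0.11881508, 0.57451954, -0.41974955,
                -0.20982443, -0.38079952, -0.00986504, 0.32272505,
                -3.04392116, 0.00411322, -0.26587735, 0.02700199])
feature_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status',
                 'Occupation', 'Relationship', 'Race', 'Sex',
                 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

# Indices sorted by |SHAP value|, largest contribution first
order = np.argsort(-np.abs(row))
ranked = [feature_names[i] for i in order]

# Print name and signed contribution, so the direction stays visible
for i in order:
    print(f"{feature_names[i]:>15}: {row[i]:+.4f}")
```

For a ranking over the whole dataset rather than one person, a common choice is np.abs(shap_values).mean(axis=0) on the full matrix followed by the same argsort; this mean absolute SHAP value is what SHAP's bar summary plot displays.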