### Pycaret中的概率与预测标签不匹配

我正在使用pycaret构建一个分类模型,使用的代码如下:

sample = pd.DataFrame(sample)exp_clf = setup(sample, target = 'match',fix_imbalance = True)clf_model = create_model('lightgbm')tuned_clf_model = tune_model(clf_model, optimize = 'Recall')tuned_tuned_clf_model_pred = predict_model(tuned_clf_model, data = sample)

现在问题出现在这里,因为标签1和0的分数有重叠:

enter image description here

这是我使用的数据,可以转换为字典并按上述代码的第一行转换为数据框。

sample =  {'same_add_number': {1521: False,  1756: False,  2456: False,  589: False,  51: False,  668: False,  3030: False,  864: True,  681: False,  372: False,  2768: False,  3519: False,  2212: True,  2424: False,  672: False,  1802: False,  3910: False,  1174: False,  1556: False,  922: False,  3416: False,  719: False,  641: False,  1364: False,  3153: False,  775: False,  967: False,  4054: False,  518: False,  121: False,  1027: False,  4447: True,  257: False,  706: True,  3219: False,  3009: True,  3980: False,  483: False,  3154: False,  4399: True,  2085: False,  373: False,  1469: False,  768: False,  1491: True,  2734: False,  2623: False,  746: True,  1647: False,  3806: False,  4351: False,  925: False,  602: False,  992: False,  2041: False,  1911: False,  615: False,  759: False,  835: False,  2139: False,  56: False,  1980: False,  995: True,  1696: False,  166: False,  114: True,  275: False,  2973: False,  1313: False,  1039: False,  1573: False,  771: False,  3193: False,  2292: False,  2597: False,  1747: False,  1939: False,  2598: False,  1998: False,  3288: False,  528: False,  829: False,  3591: False,  973: False,  4383: False,  1689: False,  1286: False,  4388: False,  491: False,  3920: False,  449: False,  2840: False,  1324: False,  2801: False,  1605: False,  1355: False,  1444: False,  941: False,  4109: False,  1767: False,  839: False,  188: False,  3939: False,  1186: False,  540: False,  1456: False,  3925: True,  1782: False,  1733: False,  64: True,  2710: False,  893: False,  1434: False,  1244: False,  503: False,  3044: False,  1617: False,  2878: False,  913: False,  799: False,  2202: False,  3503: False,  4063: False,  3756: False,  659: False,  1287: False,  3843: False,  2026: True,  1224: False,  705: False,  900: False,  500: False,  614: False,  2766: False,  8: False,  981: False,  1919: False,  2790: False,  1098: False,  1442: False,  2634: False,  3346: False,  652: True,  2324: False,  972: False,  287: False,  2481: False,  2486: False,  4272: False,  4011: False,  4: False,  1645: False,  863: False,  688: False,  2365: False,  3522: False,  13: False,  3251: False,  1410: False,  2306: False,  443: False,  221: False,  632: True,  2549: False,  783: False,  3221: False,  3183: False,  410: False,  1289: False,  1691: False,  2015: False,  1022: True,  455: False,  572: False,  2747: False,  3670: False,  4441: False,  2559: False,  159: False,  91: False,  263: False,  3012: False,  1234: False,  4040: False,  288: False,  89: False,  1029: False,  1180: False,  1083: False,  3970: False,  4201: False,  709: False,  2401: False,  1071: False,  2954: True,  29: True}, 'same_add_name': {1521: False,  1756: False,  2456: False,  589: False,  51: False,  668: False,  3030: False,  864: False,  681: False,  372: False,  2768: False,  3519: False,  2212: False,  2424: False,  672: False,  1802: False,  3910: False,  1174: False,  1556: False,  922: False,  3416: False,  719: False,  641: False,  1364: False,  3153: False,  775: False,  967: False,  4054: False,  518: False,  121: False,  1027: False,  4447: False,  257: False,  706: False,  3219: False,  3009: False,  3980: False,  483: False,  3154: False,  4399: False,  2085: False,  373: False,  1469: False,  768: False,  1491: False,  2734: False,  2623: False,  746: False,  1647: False,  3806: False,  4351: False,  925: False,  602: False,  992: False,  2041: False,  1911: False,  615: True,  759: False,  835: False,  2139: False,  56: False,  1980: False,  995: False,  1696: False,  166: False,  114: False,  275: False,  2973: False,  1313: False,  1039: False,  1573: False,  771: False,  3193: False,  2292: False,  2597: False,  1747: False,  1939: False,  2598: False,  1998: False,  3288: False,  528: False,  829: False,  3591: False,  973: False,  4383: False,  1689: False,  1286: False,  4388: False,  491: False,  3920: False,  449: False,  2840: False,  1324: False,  2801: False,  1605: False,  1355: False,  1444: False,  941: False,  4109: False,  1767: False,  839: False,  188: False,  3939: False,  1186: False,  540: False,  1456: False,  3925: False,  1782: False,  1733: False,  64: False,  2710: False,  893: False,  1434: False,  1244: False,  503: False,  3044: False,  1617: False,  2878: False,  913: False,  799: False,  2202: False,  3503: False,  4063: False,  3756: False,  659: False,  1287: False,  3843: False,  2026: False,  1224: False,  705: False,  900: False,  500: False,  614: False,  2766: False,  8: False,  981: False,  1919: False,  2790: False,  1098: False,  1442: False,  2634: False,  3346: False,  652: False,  2324: False,  972: False,  287: False,  2481: False,  2486: False,  4272: False,  4011: False,  4: False,  1645: False,  863: False,  688: False,  2365: False,  3522: False,  13: False,  3251: False,  1410: False,  2306: False,  443: False,  221: False,  632: False,  2549: False,  783: False,  3221: False,  3183: False,  410: False,  1289: False,  1691: False,  2015: False,  1022: False,  455: False,  572: False,  2747: False,  3670: False,  4441: False,  2559: False,  159: False,  91: False,  263: True,  3012: False,  1234: False,  4040: False,  288: False,  89: False,  1029: False,  1180: False,  1083: False,  3970: False,  4201: False,  709: False,  2401: False,  1071: False,  2954: False,  29: False}, 'name_score_fuzzy': {1521: 78,  1756: 71,  2456: 73,  589: 38,  51: 71,  668: 49,  3030: 75,  864: 47,  681: 75,  372: 72,  2768: 73,  3519: 85,  2212: 100,  2424: 85,  672: 74,  1802: 46,  3910: 73,  1174: 47,  1556: 80,  922: 73,  3416: 71,  719: 55,  641: 71,  1364: 79,  3153: 74,  775: 54,  967: 73,  4054: 100,  518: 72,  121: 49,  1027: 38,  4447: 100,  257: 74,  706: 40,  3219: 71,  3009: 93,  3980: 72,  483: 46,  3154: 68,  4399: 100,  2085: 80,  373: 77,  1469: 23,  768: 50,  1491: 100,  2734: 79,  2623: 79,  746: 88,  1647: 73,  3806: 79,  4351: 72,  925: 65,  602: 83,  992: 46,  2041: 78,  1911: 77,  615: 45,  759: 52,  835: 77,  2139: 77,  56: 81,  1980: 71,  995: 59,  1696: 83,  166: 71,  114: 50,  275: 47,  2973: 80,  1313: 73,  1039: 75,  1573: 70,  771: 53,  3193: 100,  2292: 79,  2597: 71,  1747: 78,  1939: 84,  2598: 71,  1998: 77,  3288: 85,  528: 44,  829: 72,  3591: 80,  973: 47,  4383: 80,  1689: 85,  1286: 41,  4388: 75,  491: 77,  3920: 70,  449: 73,  2840: 79,  1324: 81,  2801: 73,  1605: 47,  1355: 72,  1444: 72,  941: 62,  4109: 79,  1767: 34,  839: 35,  188: 63,  3939: 75,  1186: 49,  540: 44,  1456: 41,  3925: 91,  1782: 43,  1733: 74,  64: 21,  2710: 71,  893: 57,  1434: 75,  1244: 77,  503: 75,  3044: 71,  1617: 73,  2878: 71,  913: 63,  799: 78,  2202: 71,  3503: 77,  4063: 75,  3756: 77,  659: 51,  1287: 76,  3843: 73,  2026: 100,  1224: 71,  705: 81,  900: 65,  500: 42,  614: 81,  2766: 76,  8: 71,  981: 73,  1919: 73,  2790: 71,  1098: 76,  1442: 73,  2634: 73,  3346: 81,  652: 100,  2324: 84,  972: 73,  287: 63,  2481: 76,  2486: 76,  4272: 64,  4011: 73,  4: 74,  1645: 17,  863: 46,  688: 71,  2365: 76,  3522: 73,  13: 52,  3251: 74,  1410: 80,  2306: 71,  443: 71,  221: 73,  632: 65,  2549: 80,  783: 53,  3221: 71,  3183: 75,  410: 53,  1289: 71,  1691: 85,  2015: 71,  1022: 67,  455: 100,  572: 100,  2747: 77,  3670: 74,  4441: 81,  2559: 84,  159: 22,  91: 79,  263: 41,  3012: 76,  1234: 77,  4040: 73,  288: 82,  89: 71,  1029: 82,  1180: 78,  1083: 77,  3970: 75,  4201: 76,  709: 46,  2401: 76,  1071: 83,  2954: 93,  29: 52}, 'name_score_cos': {1521: 0.805341232815891,  1756: 1.0000000156276607,  2456: 0.7146280288550899,  589: 0.4944973860854622,  51: 0.16448994174134138,  668: 0.6680419517655739,  3030: 0.5178230596082453,  864: 0.34284966537760764,  681: 0.8220122172271629,  372: 0.7372570578072887,  2768: 1.0000000748631144,  3519: 0.6544869126589294,  2212: 1.0,  2424: 0.9999999107799844,  672: 0.8006864625973021,  1802: 0.008748746635272902,  3910: 0.6029157847994123,  1174: 0.43891392720221256,  1556: 0.4592255006317409,  922: 0.602017340163112,  3416: 0.7887549792307141,  719: 0.13458379717430374,  641: 0.8221775985370106,  1364: 0.8349841579827227,  3153: 0.6395051509895127,  775: 0.4861694445439952,  967: 0.6240594839420581,  4054: 1.0,  518: 0.8274708074953143,  121: 0.4156175285346006,  1027: 0.4172238782731538,  4447: 1.0,  257: 0.7144798398523643,  706: 0.2914152988288179,  3219: 0.4892006725361837,  3009: 0.8732375138387463,  3980: 0.5371502775293667,  483: 0.6532926383429954,  3154: 0.7500245353516992,  4399: 1.0,  2085: 0.6994934983150074,  373: 0.0,  1469: 0.13834207989466868,  768: 0.0,  1491: 1.0,  2734: 0.5744607478435466,  2623: 0.521054474126365,  746: 0.900627520280279,  1647: 0.46841195036889005,  3806: 0.5245533025793365,  4351: 0.7190153036645236,  925: 0.602017340163112,  602: 0.8180017827481202,  992: 0.6552306767756036,  2041: 0.8416265969822513,  1911: 0.5760342064839252,  615: 0.3142721314062845,  759: 0.29937879126297773,  835: 0.4814135508437952,  2139: 0.8103994874531241,  56: 0.4777649573427413,  1980: 0.4501770315717141,  995: 0.3185447219204094,  1696: 0.9999999289827698,  166: 0.0,  114: 0.0,  275: -0.059108179802214694,  2973: 0.0,  1313: 0.4103695338595878,  1039: 0.4158014949799697,  1573: 0.7687119146546476,  771: -0.038431693364239676,  3193: 1.0,  2292: 0.9999999289827698,  2597: 0.7014107947566588,  1747: 0.613680567239729,  1939: 0.8930406720693059,  2598: 1.0000000156276607,  1998: 0.9999999107799844,  3288: 0.6015149463851227,  528: 0.48037545624105144,  829: 0.3520640350139409,  3591: 0.5123337954949542,  973: 0.29920325457748886,  4383: 0.605345098540998,  1689: 0.699458791765087,  1286: 0.26151465192863704,  4388: 0.5996518099075245,  491: 0.8274708074953143,  3920: 0.5561721737068668,  449: 0.5309349410096579,  2840: 0.6964415538329863,  1324: 0.8352363777690135,  2801: 0.0,  1605: 0.3992469760734788,  1355: 0.5092696449238323,  1444: 0.7013725048779127,  941: 0.0,  4109: 0.7371134488841004,  1767: 0.32686654729234066,  839: 0.28650412696593686,  188: 0.11578000694274473,  3939: 0.5182830082849388,  1186: 0.5399906358163992,  540: 0.23601516039791495,  1456: 0.4462820528772964,  3925: 0.39035408504387764,  1782: 0.17470256029413367,  1733: 0.9999999289827698,  64: 0.47240949440644947,  2710: 0.21737616101123375,  893: 0.3889650515319831,  1434: 0.3144768136655605,  1244: 0.8456850404860974,  503: 0.8274708074953143,  3044: 0.5604645740029809,  1617: 0.8343403856383358,  2878: 0.6624314741881498,  913: 0.3665973835032023,  799: 0.5785308541963937,  2202: 0.584334176199583,  3503: 0.7330193052968511,  4063: 0.633698984756138,  3756: 0.588157437279164,  659: 0.8040106952622528,  1287: 0.6826384100268522,  3843: 0.7287410320020241,  2026: 1.0,  1224: 0.0,  705: 0.7278133754982946,  900: 0.592942126263229,  500: 0.5038847249789867,  614: 0.6417445279680914,  2766: 0.9999999574199627,  8: 0.722455004886235,  981: 0.6168699100990872,  1919: 0.6551439293796956,  2790: 0.0,  1098: 0.5890947178422432,  1442: 0.39311307805458195,  2634: 0.5434702892550847,  3346: 0.5956843029692919,  652: 1.0,  2324: 0.7619312086149606,  972: 0.5067710204705025,  287: 0.6569573257912408,  2481: 0.5829629588847571,  2486: 0.436286219251023,  4272: 0.5408064181796995,  4011: 0.9999999289827698,  4: 0.7647923556190919,  1645: 0.4139532701675873,  863: 0.40369910836161105,  688: 0.0,  2365: 0.7371134488841004,  3522: 0.6205927634025437,  13: 0.6688829431116972,  3251: 0.7114075759658299,  1410: 0.3589092268079449,  2306: 1.0000000396582405,  443: 0.6808489866836555,  221: 0.5811068730506951,  632: 0.5470606107366598,  2549: 0.7123831914993078,  783: 0.46296630135808603,  3221: 0.5883753355908442,  3183: 0.7371134488841004,  410: 0.7604057492722187,  1289: 0.5855230248645426,  1691: 0.727210015672603,  2015: 0.9999999107799844,  1022: 0.0,  455: 1.0,  572: 1.0,  2747: 0.7761666318621021,  3670: 0.5560044398288135,  4441: 0.7697792208927854,  2559: 0.5788817989918374,  159: 0.27027908726745226,  91: 0.5462872872864122,  263: 0.3015316394560223,  3012: 0.6611230100784922,  1234: 0.6639184765411582,  4040: 0.9999999768133089,  288: 0.7681366994965638,  89: 0.7030570621995992,  1029: 0.5322036652128525,  1180: 0.3590668280085605,  1083: 0.7805410171946893,  3970: 0.47446565960369524,  4201: 0.813152589308668,  709: 0.37964467582959255,  2401: 0.6551620258724654,  1071: 0.21475894870778542,  2954: 0.8452728458129916,  29: 0.5138088947304236}, 'match': {1521: 0,  1756: 0,  2456: 0,  589: 0,  51: 0,  668: 0,  3030: 0,  864: 1,  681: 0,  372: 0,  2768: 0,  3519: 0,  2212: 1,  2424: 0,  672: 0,  1802: 0,  3910: 0,  1174: 0,  1556: 0,  922: 0,  3416: 0,  719: 0,  641: 0,  1364: 0,  3153: 0,  775: 0,  967: 0,  4054: 1,  518: 0,  121: 0,  1027: 0,  4447: 1,  257: 0,  706: 0,  3219: 0,  3009: 0,  3980: 0,  483: 0,  3154: 0,  4399: 1,  2085: 0,  373: 0,  1469: 0,  768: 0,  1491: 1,  2734: 0,  2623: 0,  746: 1,  1647: 0,  3806: 0,  4351: 0,  925: 0,  602: 0,  992: 0,  2041: 0,  1911: 0,  615: 0,  759: 0,  835: 0,  2139: 0,  56: 0,  1980: 0,  995: 1,  1696: 0,  166: 0,  114: 1,  275: 0,  2973: 0,  1313: 0,  1039: 0,  1573: 0,  771: 0,  3193: 0,  2292: 0,  2597: 0,  1747: 0,  1939: 0,  2598: 0,  1998: 0,  3288: 0,  528: 0,  829: 0,  3591: 0,  973: 0,  4383: 0,  1689: 0,  1286: 0,  4388: 0,  491: 0,  3920: 0,  449: 0,  2840: 0,  1324: 0,  2801: 0,  1605: 0,  1355: 0,  1444: 0,  941: 0,  4109: 0,  1767: 0,  839: 0,  188: 0,  3939: 0,  1186: 0,  540: 0,  1456: 0,  3925: 1,  1782: 0,  1733: 0,  64: 0,  2710: 0,  893: 0,  1434: 0,  1244: 0,  503: 0,  3044: 0,  1617: 0,  2878: 0,  913: 0,  799: 0,  2202: 0,  3503: 0,  4063: 0,  3756: 0,  659: 0,  1287: 0,  3843: 0,  2026: 1,  1224: 0,  705: 0,  900: 0,  500: 0,  614: 0,  2766: 0,  8: 0,  981: 0,  1919: 0,  2790: 0,  1098: 0,  1442: 0,  2634: 0,  3346: 0,  652: 1,  2324: 0,  972: 0,  287: 0,  2481: 0,  2486: 0,  4272: 0,  4011: 0,  4: 0,  1645: 0,  863: 0,  688: 0,  2365: 0,  3522: 0,  13: 0,  3251: 0,  1410: 0,  2306: 0,  443: 0,  221: 0,  632: 0,  2549: 0,  783: 0,  3221: 0,  3183: 0,  410: 0,  1289: 0,  1691: 0,  2015: 0,  1022: 1,  455: 1,  572: 1,  2747: 0,  3670: 0,  4441: 0,  2559: 0,  159: 0,  91: 0,  263: 0,  3012: 0,  1234: 0,  4040: 0,  288: 0,  89: 0,  1029: 0,  1180: 0,  1083: 0,  3970: 0,  4201: 0,  709: 0,  2401: 0,  1071: 0,  2954: 0,  29: 1}}


回答:

奇怪的是,分数被设置为标签的概率。换句话说,如果模型的原始输出为0.01,数据框将显示为Label = 0 | Score = 0.99。如果模型的原始输出为0.99,数据框将显示为Label = 1 | Score = 0.99。我认为这在进行二元分类以外的分类时可能更有意义。

如果你不满足于仅仅相信我的话(我不会怪你),你可以通过更改预测行来获取原始分数:

tuned_tuned_clf_model_pred = predict_model(tuned_clf_model, raw_score=True, data = sample)

注意raw_score=True。然后你的数据框将有两个分数列(Score_0Score_1)。从那里,你可以通过以下方式获取你想要的直方图:

tuned_tuned_clf_model_pred[tuned_tuned_clf_model_pred["Label"]==0].Score_1.hist()tuned_tuned_clf_model_pred[tuned_tuned_clf_model_pred["Label"]==1].Score_1.hist() 

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注