如何解释h2o.predict()结果的概率(p0, p1)

我想了解H2o R包中h2o.predict()函数的返回值(结果)的含义。我发现,在某些情况下,当predict列为1时,p1列的值反而低于p0列。我对p0p1列的理解是它们代表每个事件的概率,因此我期望当predict=1时,p1的概率应该高于相反事件(p0)的概率,但事实并非总是如此,如下面的例子所示:使用前列腺数据集

以下是可执行的示例:

library(h2o)h2o.init(max_mem_size = "12g", nthreads = -1)prostate.hex <- h2o.importFile("https://h2o-public-test-data.s3.amazonaws.com/smalldata/prostate/prostate.csv")prostate.hex$CAPSULE  <- as.factor(prostate.hex$CAPSULE)prostate.hex$RACE     <- as.factor(prostate.hex$RACE)prostate.hex$DCAPS    <- as.factor(prostate.hex$DCAPS)prostate.hex$DPROS    <- as.factor(prostate.hex$DPROS)prostate.hex.split = h2o.splitFrame(data = prostate.hex,  ratios = c(0.70, 0.20, 0.10), seed = 1234)train.hex     <- prostate.hex.split[[1]]validate.hex  <- prostate.hex.split[[2]]test.hex      <- prostate.hex.split[[3]]fit <- h2o.glm(y = "CAPSULE", x = c("AGE", "RACE", "PSA", "DCAPS"),  training_frame = train.hex,  validation_frame = validate.hex,  family = "binomial", nfolds = 0, alpha = 0.5)prostate.predict = h2o.predict(object = fit, newdata = test.hex)result <- as.data.frame(prostate.predict)subset(result, predict == 1 & p1 < 0.4)

我得到了以下subset函数的结果输出:

   predict        p0        p111       1 0.6355974 0.364402617       1 0.6153021 0.384697923       1 0.6289063 0.371093725       1 0.6007919 0.399208131       1 0.6239587 0.3760413

对于test.hex数据集中上述所有观测,预测结果均为1,但p0 > p1

总共有predict=1p1 < p0的观测数为:

>   nrow(subset(result, predict == 1 & p1 < p0))[1] 14

相反,没有predict=0p0 < p1的情况

>   nrow(subset(result, predict == 0 & p0 < p1))[1] 0

以下是predicttable信息表:

> table(result$predict) 0  1 18 23 

我们使用CAPSULE作为决策变量,其值如下:

> levels(as.data.frame(prostate.hex)$CAPSULE)[1] "0" "1"

有什么建议吗?

注意:与此主题类似的提问:如何解释h2o.predict的结果并未解决这一特定问题。


回答:

似乎(也请参见这里)在validation数据集上最大化F1 score的阈值被用作h2o.glm()分类的默认阈值。我们可以观察到以下几点:

  1. 在验证数据集上最大化F1 score的阈值为0.363477
  2. 所有预测的p1概率低于该阈值的数据点都被分类为0类(预测为0类的数据点具有最高的p1概率=0.3602365 < 0.363477)。
  3. 所有预测的p1概率高于该阈值的数据点都被分类为1类(预测为1类的数据点具有最低的p1概率=0.3644026 > 0.363477)。

    min(result[result$predict==1,]$p1)# [1] 0.3644026max(result[result$predict==0,]$p1)# [1] 0.3602365# Thresholds found by maximizing the metrics on the training datasetfit@model$training_metrics@metrics$max_criteria_and_metric_scores #Maximum Metrics: Maximum metrics at their respective thresholds#                        metric threshold    value idx#1                       max f1  0.314699 0.641975 200#2                       max f2  0.215203 0.795148 262#3                 max f0point5  0.451965 0.669856  74#4                 max accuracy  0.451965 0.707581  74#5                max precision  0.998285 1.000000   0#6                   max recall  0.215203 1.000000 262#7              max specificity  0.998285 1.000000   0#8             max absolute_mcc  0.451965 0.395147  74#9   max min_per_class_accuracy  0.360174 0.652542 127#10 max mean_per_class_accuracy  0.391279 0.683269  97# Thresholds found by maximizing the metrics on the validation datasetfit@model$validation_metrics@metrics$max_criteria_and_metric_scores #Maximum Metrics: Maximum metrics at their respective thresholds#                        metric threshold    value idx#1                       max f1  0.363477 0.607143  33#2                       max f2  0.292342 0.785714  51#3                 max f0point5  0.643382 0.725806   9#4                 max accuracy  0.643382 0.774194   9#5                max precision  0.985308 1.000000   0#6                   max recall  0.292342 1.000000  51#7              max specificity  0.985308 1.000000   0#8             max absolute_mcc  0.643382 0.499659   9#9   max min_per_class_accuracy  0.379602 0.650000  28#10 max mean_per_class_accuracy  0.618286 0.702273  11result[order(result$predict),]#   predict          p0        p1#5        0 0.703274569 0.2967254#6        0 0.639763460 0.3602365#13       0 0.689557497 0.3104425#14       0 0.656764541 0.3432355#15       0 0.696248328 0.3037517#16       0 0.707069611 0.2929304#18       0 0.692137408 0.3078626#19       0 0.701482762 0.2985172#20       0 0.705973644 0.2940264#21       0 0.701156961 0.2988430#22       0 0.671778898 0.3282211#24       0 0.646735016 0.3532650#26       0 0.646582708 0.3534173#27       0 0.690402957 0.3095970#32       0 0.649945017 0.3500550#37       0 0.804937468 0.1950625#40       0 0.717706731 0.2822933#41       0 0.642094040 0.3579060#1        1 0.364577068 0.6354229#2        1 0.503432724 0.4965673#3        1 0.406771233 0.5932288#4        1 0.551801718 0.4481983#7        1 0.339600779 0.6603992#8        1 0.002978593 0.9970214#9        1 0.378034417 0.6219656#10       1 0.596298925 0.4037011#11       1 0.635597359 0.3644026#12       1 0.552662241 0.4473378#17       1 0.615302107 0.3846979#23       1 0.628906297 0.3710937#25       1 0.600791894 0.3992081#28       1 0.216571552 0.7834284#29       1 0.559174924 0.4408251#30       1 0.489514642 0.5104854#31       1 0.623958696 0.3760413#33       1 0.504691497 0.4953085#34       1 0.582509462 0.4174905#35       1 0.504136056 0.4958639#36       1 0.463076505 0.5369235#38       1 0.510908093 0.4890919#39       1 0.469376828 0.5306232

Related Posts

如何使用Google Protobuf解析、编辑和生成object_detection/pipeline.config文件

我在一个常见的集成学习范式中训练多个模型,目前我在处理…

我的GridSearchCV不起作用,我不知道为什么

大家好,我在使用GridSearchCV时遇到了问题,…

Keras: 两个同时进行的层,其中一个对前一层的输出进行卷积

我想实现这样的模型连接: 输入图像1 -> 卷积层1 …

如何将行数据转换为列数据而不使用独热编码

我有一个如下所示的数据集。 MonthDate Day…

使用 ML Kit 与 NNAPI

我正在尝试在运行 Android 9 的设备上使用新的…

Vowpal Wabbit 可能的哈希冲突

我在VW中生成了一个模型,并且在相同的数据上生成了两个…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注