如何在Caffe中获取二分类器的两个输出值(针对两个类别)?

我正在尝试使用LeNet网络作为二分类器(是,否)。测试配置文件中的第一层和最后几层如下所示:

    layer {      name: "data"      type: "ImageData"      top: "data"      top: "label"      include {        phase: TEST      }      transform_param {        scale: 0.00390625      }      image_data_param {        source: "examples/my_example/test_images_labels.txt"        batch_size: 1        new_height: 128        new_width: 128      }    }...    layer {      name: "ip2"      type: "InnerProduct"      bottom: "ip1"      top: "ip2"      param {        lr_mult: 1      }      param {        lr_mult: 2      }      inner_product_param {        num_output: 2        weight_filler {          type: "xavier"        }        bias_filler {          type: "constant"        }      }    }    layer {      name: "accuracy"      type: "Accuracy"      bottom: "ip2"      bottom: "label"      top: "accuracy"    }    layer {      name: "loss"      type: "SoftmaxWithLoss"      bottom: "ip2"      bottom: "label"      top: "loss"    }

在测试时,我设置了batch_size=1,因此我使用以下命令运行测试:

./build/tools/caffe test -model examples/my_example/lenet_test.prototxt -weights=examples/my_example/lenet_iter_528.caffemodel -iterations 200

我的目的是能够单独分析每张测试图像的结果。目前,每次迭代我得到以下信息:

I0310 18:30:21.889688  5952 caffe.cpp:264] Batch 41, accuracy = 1I0310 18:30:21.889739  5952 caffe.cpp:264] Batch 41, loss = 0.578524

然而,由于我的网络有两个输出,我希望在测试时看到每个输出的两个独立值:一个用于类别“0”(“否”),另一个用于类别“1”(“是”)。应该类似于以下内容:

Batch 41, class 0 output: 0.755Batch 41, class 1 output: 0.201

我应该如何修改测试配置文件来实现这一点?


回答:

你想看到"Softmax"的概率输出(不仅仅是损失)。
为此,你可以尝试使用"SoftmaxWithLoss"并添加两个"top"(我不确定这个选项是否完全功能/支持):

layer {  name: "loss"  type: "SoftmaxWithLoss"  bottom: "ip2"  bottom: "label"  top: "loss"  top: "prob" # 添加类别概率输出}

或者,如果前面的解决方案不起作用,明确添加一个"Softmax"层:

layer {  name: "prob"  type: "Softmax"  bottom: "ip2"  top: "prob"}

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注