标题:Caffe是否支持16位图像?如果不支持,如何实现支持?

背景信息:我需要加载一些16位灰度PNG图像。

Caffe通过ImageDataLayer是否支持加载16位图像?

经过一些搜索,答案似乎是否定的。ImageDataLayer依赖于这个io例程

cv::Mat ReadImageToCVMat(const string& filename,    const int height, const int width, const bool is_color) {  cv::Mat cv_img;  int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :    CV_LOAD_IMAGE_GRAYSCALE);  cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag);  if (!cv_img_origin.data) {    LOG(ERROR) << "Could not open or find file " << filename;    return cv_img_origin;  }  if (height > 0 && width > 0) {    cv::resize(cv_img_origin, cv_img, cv::Size(width, height));  } else {    cv_img = cv_img_origin;  }  return cv_img;}

它使用了OpenCV的cv::imread函数。除非设置了适当的标志,否则此函数会将输入读取为8位

CV_LOAD_IMAGE_ANYDEPTH – 如果设置了此标志,当输入具有相应的深度时,返回16位/32位图像,否则将其转换为8位。

简单地添加适当的标志并不会起作用,因为在代码的后续部分[io.cpp]中,他们检查了8位深度:

void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";... }

我可以删除这个检查,但我担心它是有原因的,可能会导致不可预测的结果。谁能对此问题提供一些见解?


回答:

您可以像这样修补ImageDataLayer以读取16位图像:

  1. 如您所述,添加适当的标志(io.cpp):

int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :    CV_LOAD_IMAGE_GRAYSCALE);

之后添加

cv_read_flag |= CV_LOAD_IMAGE_ANYDEPTH;
  1. 修改您提到的检查(data_transformer.cpp):

CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

改为

CHECK(cv_img.depth() == CV_8U || cv_img.depth() == CV_16U) << "Image data type must be uint8 or uint16";bool is16bit = cv_img.depth() == CV_16U;
  1. 修改DataTransformer读取cv::Mat的方式(在同一函数中):

添加uint16_t类型的指针到:

const uchar* ptr = cv_cropped_img.ptr<uchar>(h);

像这样

const uint16_t* ptr_16 = cv_cropped_img.ptr<uint16_t>(h);

然后使用适当的指针读取:

Dtype pixel = static_cast<Dtype>(ptr[img_index++]);

改为

Dtype pixel;if(is16bit)    pixel = static_cast<Dtype>(ptr_16[img_index++]);else    pixel = static_cast<Dtype>(ptr[img_index++]);

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注