为TensorFlow Lite C++编写read_jpeg和decode_jpeg函数

TensorFlow Lite在其仓库中有一个很好的C++图像分类示例,在这里。然而,我正在处理.jpeg文件,而这个示例仅限于使用bitmap_helpers.cc解码.bmp图像。

我正在尝试创建自己的jpeg解码器,但我在图像处理方面并不精通,所以需要一些帮助。我正在重用这个jpeg解码器作为第三方辅助库。在示例的bmp解码中,我不太明白计算row_sizes和接受头部之后的字节数组的用意是什么。谁能解释一下这对jpeg解码器的应用?或者,更好的是,是否已经有一个我尚未找到的C++ decode_jpeg函数隐藏在某处?

最终实现必须在TensorFlow Lite的C++中完成。

非常感谢!

编辑:

以下是我目前所做的。当我使用相同的输入图像和tflite模型时,与Python示例的图像分类器相比,我没有得到相同的信心值,这清楚地表明有些地方出了问题。我基本上是复制粘贴了read_bmp中的row_size计算,而没有理解它,所以我怀疑这可能是问题所在。row_size意味着什么?

std::vector<uint8_t> decode_jpeg(const uint8_t* input, int row_size, int width, int height) {    // 通道数总是3。现在硬编码它。    int channels = 3;    // 输出将包含TensorFlow处理的数据。    std::vector<uint8_t> output(height * width * channels);    // 遍历图像的每个像素。    for(int i = 0; i < height; i++) {            int src_pos;            int dst_pos;            for(int j = 0; j < width; j++) {                    src_pos = i * row_size + j * channels;                    dst_pos = (i * width + j) * channels;                    // 将RGB通道数据放入输出数组中。                    output[dst_pos] = input[src_pos + 2];                    output[dst_pos + 1] = input[src_pos + 1];                    output[dst_pos + 2] = input[src_pos];            }    }    return output;}std::vector<uint8_t> read_jpeg(const std::string& input_jpeg_name, int* width, int* height, Settings* s) {    // 大小和缓冲区。    size_t size;    unsigned char *buf;    // 打开输入文件。    FILE *f;    f = fopen(input_jpeg_name.c_str(), "rb");    if (!f) {            if (s->verbose) LOG(INFO) << "Error opening the input file\n";            exit(-1);    }    // 读取文件。    fseek(f, 0, SEEK_END);    // 获取文件大小。    size = ftell(f);    // 将文件数据读入缓冲区。    buf = (unsigned char*)malloc(size);    fseek(f, 0, SEEK_SET);    size_t read = fread(buf, 1, size, f);        // 关闭文件。    fclose(f);    // 解码文件。    Decoder decoder(buf, size);    if (decoder.GetResult() != Decoder::OK)    {            if (s->verbose) LOG(INFO) << "Error decoding the input file\n";            exit(-1);    }    // 从解码文件中获取图像。    unsigned char* img = decoder.GetImage();    // 获取图像宽度和高度。    *width = decoder.GetWidth();    *height = decoder.GetHeight();    // TODO: 理解这个row size的含义。不要只是复制粘贴。    const int row_size = (8 * *channels * *width + 31) / 32 * 4;    // 解码JPEG。    return decode_jpeg(img, row_size, *width, *height);}

回答:

您使用的库已经为您处理了解码,decoder.getImage()包含原始的rgb数据。您不需要计算任何大小。

像row_size这样的东西是BMP文件格式特有的。BMP文件可能包含一些填充字节,除了像素颜色数据外,代码处理了这些东西。

此外,BMP文件以BGR顺序存储像素值,这就是为什么您在原始代码中有反向排序的原因:

// 将RGB通道数据放入输出数组中。output[dst_pos] = input[src_pos + 2];output[dst_pos + 1] = input[src_pos + 1];output[dst_pos + 2] = input[src_pos];

下面的代码应该对您有用(注意decode_jpeg函数不执行任何解码):

std::vector<uint8_t> decode_jpeg(const uint8_t* input, int width, int height) {    // 通道数总是3。现在硬编码它。    int channels = 3;    // 输出将包含TensorFlow处理的数据。    std::vector<uint8_t> output(height * width * channels);    // 复制像素数据到输出    for (size_t i = 0; i < height*width*channels; ++i)    {        output[i] = input[i];    }        return output;}std::vector<uint8_t> read_jpeg(const std::string& input_jpeg_name, int* width, int* height, Settings* s) {    // 大小和缓冲区。    size_t size;    unsigned char *buf;    // 打开输入文件。    FILE *f;    f = fopen(input_jpeg_name.c_str(), "rb");    if (!f) {            if (s->verbose) LOG(INFO) << "Error opening the input file\n";            exit(-1);    }    // 读取文件。    fseek(f, 0, SEEK_END);    // 获取文件大小。    size = ftell(f);    // 将文件数据读入缓冲区。    buf = (unsigned char*)malloc(size);    fseek(f, 0, SEEK_SET);    size_t read = fread(buf, 1, size, f);        // 关闭文件。    fclose(f);    // 解码文件。    Decoder decoder(buf, size);    if (decoder.GetResult() != Decoder::OK)    {            if (s->verbose) LOG(INFO) << "Error decoding the input file\n";            exit(-1);    }    // 从解码文件中获取图像。    unsigned char* img = decoder.GetImage();    // 获取图像宽度和高度。    *width = decoder.GetWidth();    *height = decoder.GetHeight();    // 解码JPEG。    return decode_jpeg(img, *width, *height);}

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注