为什么我的 Flask API 不返回类别标签?

我有一个 Flask API,它接收图像并应该使用预训练模型和 ImageNet 类别索引输出其类别的预测结果。

我知道我的请求脚本正在调用 API 的 /predict 端点,因为我在 API 端得到了这个输出

127.0.0.1 - - [18/Dec/2020 19:15:08] "←[37mPOST /predict HTTP/1.1←[0m" 200 -

当我像下面这样硬编码时,我可以得到一个预测结果,但我不知道如何将其转换为 API 使用:

imagenet_class_index = json.load(open('./static/imagenet_class_index.json'))def get_prediction(image_bytes):    tensor = transform_image(image_bytes=image_bytes)    outputs = model.forward(tensor)    _, y_hat = outputs.max(1)    predicted_idx = str(y_hat.item())    return imagenet_class_index[predicted_idx]with open("img059.jpg", 'rb') as f:    image_bytes = f.read()    print(get_prediction(image_bytes = image_bytes))

这是我 API 的简化版本

import ioimport jsonfrom torchvision import modelsimport torchvision.transforms as transformsfrom PIL import Imagefrom flask import Flask, jsonify, requestapp = Flask(__name__)imagenet_class_index = json.load(open('./static/imagenet_class_index.json'))model = models.densenet121(pretrained=True)model.eval()def transform_image(image_bytes):    my_transforms = transforms.Compose([transforms.Resize(255),                                        transforms.CenterCrop(244),                                        transforms.ToTensor(),                                        transforms.Normalize(                                            [0.485, 0.456, 0.406],                                            [0.229, 0.224, 0.225])])    image = Image.open(io.BytesIO(image_bytes))    return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes):    tensor = transform_image(image_bytes = image_bytes)    outputs = model.forward(tensor)    _, y_hat = outputs.max(1)    predicted_idx = str(y_hat.item())    return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])def predict():    if request.method == 'POST':        # 从请求中获取文件        file = request.files['file']        # 将文件转换为字节        img_bytes = file.read()        class_id, class_name = get_prediction(image_bytes = img_bytes)        return jsonify({'class_id' : class_id, 'class_name' : class_name})@app.route('/')def base_route():    return 'Greetings, Traveller!'if __name__ == '__main__':    app.run()

编辑:基本路由日志

127.0.0.1 - - [18/Dec/2020 19:14:59] "←[37mGET / HTTP/1.1←[0m" 200 -

request.py

import requestsresp = requests.post("http://localhost:5000/predict",                     files={"file": open('img059.jpg','rb')})

回答:

我认为您只是没有打印响应。您的客户端脚本应该是

request.py

import requestsresp = requests.post("http://localhost:5000/predict",                     files={"file": open('img059.jpg','rb')})result = resp.json()print(f"Class Id:{result['class_id']}, Class Name: {result['class_name']}")

现在您应该能够看到结果了。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注