我有一个 Flask API,它接收图像并应该使用预训练模型和 ImageNet 类别索引输出其类别的预测结果。
我知道我的请求脚本正在调用 API 的 /predict 端点,因为我在 API 端得到了这个输出
127.0.0.1 - - [18/Dec/2020 19:15:08] "←[37mPOST /predict HTTP/1.1←[0m" 200 -
当我像下面这样硬编码时,我可以得到一个预测结果,但我不知道如何将其转换为 API 使用:
imagenet_class_index = json.load(open('./static/imagenet_class_index.json'))def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx]with open("img059.jpg", 'rb') as f: image_bytes = f.read() print(get_prediction(image_bytes = image_bytes))
这是我 API 的简化版本
import ioimport jsonfrom torchvision import modelsimport torchvision.transforms as transformsfrom PIL import Imagefrom flask import Flask, jsonify, requestapp = Flask(__name__)imagenet_class_index = json.load(open('./static/imagenet_class_index.json'))model = models.densenet121(pretrained=True)model.eval()def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(244), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes): tensor = transform_image(image_bytes = image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])def predict(): if request.method == 'POST': # 从请求中获取文件 file = request.files['file'] # 将文件转换为字节 img_bytes = file.read() class_id, class_name = get_prediction(image_bytes = img_bytes) return jsonify({'class_id' : class_id, 'class_name' : class_name})@app.route('/')def base_route(): return 'Greetings, Traveller!'if __name__ == '__main__': app.run()
编辑:基本路由日志
127.0.0.1 - - [18/Dec/2020 19:14:59] "←[37mGET / HTTP/1.1←[0m" 200 -
request.py
import requestsresp = requests.post("http://localhost:5000/predict", files={"file": open('img059.jpg','rb')})
回答:
我认为您只是没有打印响应。您的客户端脚本应该是
request.py
import requestsresp = requests.post("http://localhost:5000/predict", files={"file": open('img059.jpg','rb')})result = resp.json()print(f"Class Id:{result['class_id']}, Class Name: {result['class_name']}")
现在您应该能够看到结果了。