通过带有FLASK的REST API在PYTHON中部署PYTORCH
在本教程中,我们将使用Flask部署PyTorch模型,并公开用于模型推断的REST API。特别是,我们将部署预训练的DenseNet 121模型来检测图像。
TIP
All the code used here is released under MIT license and is available on Github.
API Definition
我们将首先定义API端点,请求和响应类型。我们的API端点将在/predict
该端点接受带有file
包含图像的参数的HTTP POST请求 。响应将是包含预测的JSON响应:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
Dependencies
通过运行以下命令来安装所需的依赖项:
$ pip install Flask==1.0.3 torchvision-0.3.0
Simple Web Server
以下是一个简单的网络服务器,摘自Flask的文档
from flask import Flask app = Flask(__name__) @app.route('/') def hello(): return 'Hello World!'
将以上代码段保存在一个名为的文件中app.py
,现在可以通过输入以下内容来运行Flask开发服务器:
$ FLASK_ENV=development FLASK_APP=app.py flask run
当大家http://localhost:5000/
在网络浏览器中访问时,会收到文字欢迎Hello
World!
我们将对上面的代码片段进行一些更改,以使其适合我们的API定义。首先,我们将方法重命名为predict
。我们将端点路径更新为/predict
。由于图像文件将通过HTTP POST请求发送,因此我们将对其进行更新,使其也仅接受POST请求
@app.route('/predict', methods=['POST']) def predict(): return 'Hello World!'
我们还将更改响应类型,以使其返回包含ImageNet类ID和名称的JSON响应。更新后的app.py
文件将是:
from flask import Flask, jsonify app = Flask(__name__) @app.route('/predict', methods=['POST']) def predict(): return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
Inference
在下一部分中,我们将重点介绍编写推理代码。这将涉及两部分,第一部分是准备图像,以便可以将其馈送到DenseNet;第二部分,我们将编写代码以从模型中获取实际的预测。
Preparing the image
DenseNet模型要求图像为尺寸为224 x 224的3通道RGB图像。我们还将使用所需的均值和标准偏差值对图像张量进行归一化。可以在此处了解更多信息 。
我们将使用transforms
来自torchvision
库,并建立一个管道改造的要求,它改变我们的图像。可以在此处了解有关转换的更多信息。
import io import torchvision.transforms as transforms from PIL import Image def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0)
上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。要测试上述方法,请以字节模式读取图像文件(首先将../_static/img/sample_file.jpeg替换为计算机上文件的实际路径),然后查看是否获得了张量:
with open("../_static/img/sample_file.jpeg", 'rb') as f: image_bytes = f.read() tensor = transform_image(image_bytes=image_bytes) print(tensor)
输出:
tensor([[[[ 0.4508, 0.4166, 0.3994, ..., -1.3473, -1.3302, -1.3473], [ 0.5364, 0.4851, 0.4508, ..., -1.2959, -1.3130, -1.3302], [ 0.7077, 0.6392, 0.6049, ..., -1.2959, -1.3302, -1.3644], ..., [ 1.3755, 1.3927, 1.4098, ..., 1.1700, 1.3584, 1.6667], [ 1.8893, 1.7694, 1.4440, ..., 1.2899, 1.4783, 1.5468], [ 1.6324, 1.8379, 1.8379, ..., 1.4783, 1.7352, 1.4612]], [[ 0.5728, 0.5378, 0.5203, ..., -1.3704, -1.3529, -1.3529], [ 0.6604, 0.6078, 0.5728, ..., -1.3004, -1.3179, -1.3354], [ 0.8529, 0.7654, 0.7304, ..., -1.3004, -1.3354, -1.3704], ..., [ 1.4657, 1.4657, 1.4832, ..., 1.3256, 1.5357, 1.8508], [ 2.0084, 1.8683, 1.5182, ..., 1.4657, 1.6583, 1.7283], [ 1.7458, 1.9384, 1.9209, ..., 1.6583, 1.9209, 1.6408]], [[ 0.7228, 0.6879, 0.6531, ..., -1.6476, -1.6302, -1.6476], [ 0.8099, 0.7576, 0.7228, ..., -1.6476, -1.6476, -1.6650], [ 1.0017, 0.9145, 0.8797, ..., -1.6476, -1.6650, -1.6999], ..., [ 1.6291, 1.6291, 1.6465, ..., 1.6291, 1.8208, 2.1346], [ 2.1868, 2.0300, 1.6814, ..., 1.7685, 1.9428, 2.0125], [ 1.9254, 2.0997, 2.0823, ..., 1.9428, 2.2043, 1.9080]]]])
Prediction
现在将使用预训练的DenseNet 121模型来预测图像类别。我们将使用torchvision
库中的一个,加载模型并进行推断。尽管在此示例中将使用预训练的模型,但是可以对自己的模型使用相同的方法。在本教程中了解有关加载模型的更多信息。
from torchvision import models # Make sure to pass `pretrained` as `True` to use the pretrained weights: model = models.densenet121(pretrained=True) # Since we are using our model only for inference, switch to `eval` mode: model.eval() def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) return y_hat
张量y_hat
将包含预测的类ID的索引。但是,我们需要一个人类可读的类名。为此,我们需要一个类ID来进行名称映射。将该文件下载 为,imagenet_class_index.json
并记住它的保存位置(或者,如果您按照本教程中的确切步骤操作,请将其保存在 tutorials / _static中)。此文件包含ImageNet类ID到ImageNet类名称的映射。我们将加载此JSON文件并获取预测索引的类名称。
import json imagenet_class_index = json.load(open('../_static/imagenet_class_index.json')) def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx]
在使用imagenet_class_index
字典之前,首先我们将张量值转换为字符串值,因为imagenet_class_index
字典中的键 是字符串。我们将测试上述方法:
with open("../_static/img/sample_file.jpeg", 'rb') as f: image_bytes = f.read() print(get_prediction(image_bytes=image_bytes))
输出:
['n02124075', 'Egyptian_cat']
大家应该得到如下响应:
['n02124075', 'Egyptian_cat']
数组中的第一项是ImageNet类ID,第二项是人类可读的名称。
NOTE
大家是否注意到model
变量不是get_prediction
方法的一部分?还是为什么模型是全局变量?就内存和计算而言,加载模型可能是一项昂贵的操作。如果我们在get_prediction
方法中加载模型,则每次调用该方法时都会不必要地加载该模型 。由于我们正在构建一个Web服务器,因此每秒可能有成千上万的请求,因此我们不应该浪费时间为每个推断重复加载模型。因此,我们仅将模型加载到内存中一次。在生产系统中,必须有效使用计算以能够大规模处理请求,因此通常应在处理请求之前加载模型。
Integrating the model in our API Server
在最后一部分中,我们将模型添加到Flask API服务器中。由于我们的API服务器应该获取图像文件,因此我们将更新predict
方法以从请求中读取文件:
from flask import request @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': # we will get the file from the request file = request.files['file'] # convert that to bytes img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name})
该app.py
文件现已完成。以下是完整版本;将路径替换为保存文件的路径,它应运行:
import io import json from torchvision import models import torchvision.transforms as transforms from PIL import Image from flask import Flask, jsonify, request app = Flask(__name__) imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json')) model = models.densenet121(pretrained=True) model.eval() def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0) def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx] @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': file = request.files['file'] img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name}) if __name__ == '__main__': app.run()
让我们测试一下我们的Web服务器!run:
$ FLASK_ENV=development FLASK_APP=app.py flask run
我们可以使用 请求 库将POST请求发送到我们的应用程序:
import requests resp = requests.post("http://localhost:5000/predict", files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
打印resp.json()现在将显示以下内容:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
Next steps
我们编写的服务器非常琐碎,可能无法完成生产应用程序所需的一切。因此,大家可以采取一些措施来改善它:
- 端点
/predict
假定在请求中总会有一个图像文件。并非所有请求都适用。我们的用户可能发送带有其他参数的图像,或者根本不发送任何图像。 - 用户也可以发送非图像类型的文件。由于我们没有处理错误,因此这将破坏我们的服务器。添加显式的错误处理路径将引发异常,这将使我们能够更好地处理错误的输入
- 即使模型可以识别大量类别的图像,也可能无法识别所有图像。增强实现以处理模型无法识别图像中的任何情况的情况。
- 我们在开发模式下运行Flask服务器,该服务器不适合在生产环境中进行部署。大家可以查看本教程 以在生产环境中部署Flask服务器。
- 大家也可以通过创建一个页面来添加用户界面,该页面带有一个用于获取图像并显示预测的表单。查看 类似项目的演示及其源代码。
- 在本教程中,我们仅展示了如何构建可以一次返回单个图像预测的服务。我们可以修改服务以能够一次返回多个图像的预测。此外,服务流媒体 库会自动将对服务的请求排队,并将请求采样到可用于模型的微型批次中。大家可以查看本教程。
接下来,给大家介绍一下租用GPU做实验的方法,我们是在智星云租用的GPU,使用体验很好。具体大家可以参考:智星云官网: http://www.ai-galaxy.cn/,淘宝店:https://shop36573300.taobao.com/公众号: 智星AI