|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +Deploying PyTorch and Building a REST API using Flask |
| 4 | +===================================================== |
| 5 | +**Author**: `Avinash Sajjanshetty <https://avi.im>`_ |
| 6 | +
|
| 7 | +
|
| 8 | +In this tutorial, we will deploy a PyTorch model using Flask and expose a |
| 9 | +REST API for model inference. In particular, we will deploy a pretrained |
| 10 | +DenseNet 121 model that classifies images. |
| 11 | +
|
| 12 | +.. tip:: All the code used here is released under MIT license and is available on `Github <https://github.com/avinassh/pytorch-flask-api>`_. |
| 13 | +
|
| 14 | +""" |
| 15 | + |
| 16 | + |
| 17 | +###################################################################### |
| 18 | +# API Definition |
| 19 | +# -------------- |
| 20 | +# |
| 21 | +# We will first define our API endpoints, the request and response types. Our |
| 22 | +# API endpoint will be at ``/predict`` which takes HTTP POST requests with a |
| 23 | +# ``file`` parameter which contains the image. The response will be of JSON |
| 24 | +# response containing the prediction: |
| 25 | +# |
| 26 | +# :: |
| 27 | +# |
| 28 | +# {"class_id": "n02124075", "class_name": "Egyptian_cat"} |
| 29 | +# |
| 30 | +# |
| 31 | + |
| 32 | +###################################################################### |
| 33 | +# Dependencies |
| 34 | +# ------------ |
| 35 | +# |
| 36 | +# Install the required dependencies by running the following command: |
| 37 | +# |
| 38 | +# :: |
| 39 | +# |
| 40 | +# $ pip install Flask==1.0.3 torchvision==0.3.0 |
| 41 | + |
| 42 | + |
| 43 | +###################################################################### |
| 44 | +# Simple Web Server |
| 45 | +# ----------------- |
| 46 | +# |
| 47 | +# Following is a simple web server, taken from Flask's documentation |
| 48 | + |
| 49 | + |
from flask import Flask

# The WSGI application object; routes below are registered on it.
app = Flask(__name__)


@app.route('/')
def hello():
    """Return a plain-text greeting for the root URL."""
    greeting = 'Hello World!'
    return greeting
| 57 | + |
| 58 | +############################################################################### |
| 59 | +# Save the above snippet in a file called ``app.py`` and you can now run a |
| 60 | +# Flask development server by typing: |
| 61 | +# |
| 62 | +# :: |
| 63 | +# |
| 64 | +# $ FLASK_ENV=development FLASK_APP=app.py flask run |
| 65 | + |
| 66 | +############################################################################### |
| 67 | +# When you visit ``http://localhost:5000/`` in your web browser, you will be |
| 68 | +# greeted with ``Hello World!`` text |
| 69 | + |
| 70 | +############################################################################### |
| 71 | +# We will make slight changes to the above snippet, so that it suits our API |
| 72 | +# definition. First, we will rename the method to ``predict``. We will update |
| 73 | +# the endpoint path to ``/predict``. Since the image files will be sent via |
| 74 | +# HTTP POST requests, we will update it so that it also accepts only POST |
| 75 | +# requests: |
| 76 | + |
| 77 | + |
# Placeholder endpoint: accepts only POST requests at ``/predict`` but
# does not run any inference yet.
@app.route('/predict', methods=['POST'])
def predict():
    """Stub handler for the prediction endpoint."""
    return 'Hello World!'
| 81 | + |
| 82 | +############################################################################### |
| 83 | +# We will also change the response type, so that it returns a JSON response |
| 84 | +# containing ImageNet class id and name. The updated ``app.py`` file will |
| 85 | +# be now: |
| 86 | + |
from flask import Flask, jsonify

app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
    """Return a hard-coded JSON prediction (placeholder response)."""
    payload = {'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'}
    return jsonify(payload)
| 93 | + |
| 94 | + |
| 95 | +###################################################################### |
| 96 | +# Inference |
| 97 | +# ----------------- |
| 98 | +# |
| 99 | +# In the next sections we will focus on writing the inference code. This will |
| 100 | +# involve two parts, one where we prepare the image so that it can be fed |
| 101 | +# to DenseNet and next, we will write the code to get the actual prediction |
| 102 | +# from the model. |
| 103 | +# |
| 104 | +# Preparing the image |
| 105 | +# ~~~~~~~~~~~~~~~~~~~ |
| 106 | +# |
| 107 | +# DenseNet model requires the image to be of 3 channel RGB image of size |
| 108 | +# 224 x 224. We will also normalise the image tensor with the required mean |
| 109 | +# and standard deviation values. You can read more about it |
| 110 | +# `here <https://pytorch.org/docs/stable/torchvision/models.html>`_. |
| 111 | +# |
| 112 | +# We will use ``transforms`` from ``torchvision`` library and build a |
| 113 | +# transform pipeline, which transforms our images as required. You |
| 114 | +# can read more about transforms `here <https://pytorch.org/docs/stable/torchvision/transforms.html>`_. |
| 115 | + |
import io

import torchvision.transforms as transforms
from PIL import Image


def transform_image(image_bytes):
    """Decode raw image bytes into a normalized 1 x 3 x 224 x 224 tensor.

    The pipeline resizes, center-crops to 224 x 224, converts to a tensor
    and normalizes with the ImageNet mean / std expected by DenseNet.
    """
    pipeline = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    image = Image.open(io.BytesIO(image_bytes))
    # unsqueeze(0) adds the leading batch dimension the model expects.
    return pipeline(image).unsqueeze(0)
| 130 | + |
| 131 | + |
| 132 | +###################################################################### |
| 133 | +# Above method takes image data in bytes, applies the series of transforms |
| 134 | +# and returns a tensor. To test the above method, read an image file in |
| 135 | +# bytes mode and see if you get a tensor back: |
| 136 | + |
# Sanity check: read an image file in binary mode and verify that
# ``transform_image`` returns a tensor for it.
with open('sample_file.jpeg', 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)
| 141 | + |
| 142 | +###################################################################### |
| 143 | +# Prediction |
| 144 | +# ~~~~~~~~~~~~~~~~~~~ |
| 145 | +# |
| 146 | +# Now we will use a pretrained DenseNet 121 model to predict the image class. We |
| 147 | +# will use one from ``torchvision`` library, load the model and get an |
| 148 | +# inference. While we'll be using a pretrained model in this example, you can |
| 149 | +# use this same approach for your own models. See more about loading your |
| 150 | +# models in this :doc:`tutorial </beginner/saving_loading_models>`. |
| 151 | + |
import torch
from torchvision import models

# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode
# (disables dropout and puts batch-norm layers in evaluation mode):
model.eval()


def get_prediction(image_bytes):
    """Run the model on raw image bytes and return the tensor holding the
    index of the predicted ImageNet class.
    """
    tensor = transform_image(image_bytes=image_bytes)
    # Call the module itself rather than ``.forward`` so registered hooks
    # run, and skip autograd bookkeeping — no gradients are needed here.
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    return y_hat
| 165 | + |
| 166 | + |
| 167 | +###################################################################### |
| 168 | +# The tensor ``y_hat`` will contain the index of the predicted class id. |
| 169 | +# However, we need a human readable class name. For that we need a class id |
| 170 | +# to name mapping. Download |
| 171 | +# `this file <https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json>`_ |
| 172 | +# and place it in current directory as file ``imagenet_class_index.json``. |
| 173 | +# This file contains the mapping of ImageNet class id to ImageNet class |
| 174 | +# name. We will load this JSON file and get the class name of the |
| 175 | +# predicted index. |
| 176 | + |
import json

import torch

# Load the class-index -> [imagenet_id, class_name] mapping once at import
# time; the ``with`` block closes the file instead of leaking the handle.
with open('imagenet_class_index.json') as f:
    imagenet_class_index = json.load(f)


def get_prediction(image_bytes):
    """Return ``[imagenet_id, class_name]`` predicted for *image_bytes*."""
    tensor = transform_image(image_bytes=image_bytes)
    # Call the module directly (runs hooks) and disable gradient tracking
    # since this is inference only.
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    # The JSON mapping is keyed by stringified class indices.
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]
| 187 | + |
| 188 | + |
| 189 | +###################################################################### |
| 190 | +# Before using ``imagenet_class_index`` dictionary, first we will convert |
| 191 | +# tensor value to a string value, since the keys in the |
| 192 | +# ``imagenet_class_index`` dictionary are strings. |
| 193 | +# We will test our above method: |
| 194 | + |
| 195 | + |
# Quick end-to-end check of ``get_prediction``: read a sample image in
# binary mode and print the predicted [class_id, class_name] pair.
with open('sample_file.jpeg', 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))
| 199 | + |
| 200 | +###################################################################### |
| 201 | +# You should get a response like this: |
| 202 | + |
# Example of the expected return value (a bare expression so the rendered
# tutorial shows it as output):
['n02124075', 'Egyptian_cat']
| 204 | + |
| 205 | +###################################################################### |
| 206 | +# The first item in array is ImageNet class id and second item is the human |
| 207 | +# readable name. |
| 208 | +# |
| 209 | +# .. Note :: |
| 210 | +# Did you notice that why ``model`` variable is not part of ``get_prediction`` |
| 211 | +# method? Or why is model a global variable? Loading a model can be an |
| 212 | +# expensive operation in terms of memory and compute. If we loaded the model in the |
| 213 | +# ``get_prediction`` method, then it would get unnecessarily loaded every |
| 214 | +# time the method is called. Since, we are building a web server, there |
| 215 | +# could be thousands of requests per second, we should not waste time |
| 216 | +# redundantly loading the model for every inference. So, we keep the model |
| 217 | +# loaded in memory just once. In |
| 218 | +# production systems, it's necessary to be efficient about your use of |
| 219 | +# compute to be able to serve requests at scale, so you should generally |
| 220 | +# load your model before serving requests. |
| 221 | + |
| 222 | +###################################################################### |
| 223 | +# Integrating the model in our API Server |
| 224 | +# --------------------------------------- |
| 225 | +# |
| 226 | +# In this final part we will add our model to our Flask API server. Since |
| 227 | +# our API server is supposed to take an image file, we will update our ``predict`` |
| 228 | +# method to read files from the requests: |
| 229 | + |
from flask import request


@app.route('/predict', methods=['POST'])
def predict():
    """Read the uploaded image from the request, classify it, and return
    the ImageNet class id and name as JSON.
    """
    if request.method == 'POST':
        # The uploaded image arrives under the ``file`` form field.
        upload = request.files['file']
        img_bytes = upload.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})
| 242 | + |
| 243 | +###################################################################### |
| 244 | +# The ``app.py`` file is now complete. Following is the full version: |
| 245 | +# |
| 246 | + |
import io
import json

from flask import Flask, jsonify, request
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models


app = Flask(__name__)

# Load the ImageNet class-index mapping once at startup; the context
# manager closes the file instead of leaking the handle.
with open('imagenet_class_index.json') as f:
    imagenet_class_index = json.load(f)

# Load the pretrained model once at startup (loading per request would be
# prohibitively expensive) and switch to eval mode for inference.
model = models.densenet121(pretrained=True)
model.eval()
| 260 | + |
| 261 | + |
def transform_image(image_bytes):
    """Decode raw image bytes into a normalized 1 x 3 x 224 x 224 tensor
    ready to be fed to DenseNet.
    """
    pipeline = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    image = Image.open(io.BytesIO(image_bytes))
    # Add the leading batch dimension expected by the model.
    return pipeline(image).unsqueeze(0)
| 271 | + |
| 272 | + |
def get_prediction(image_bytes):
    """Return ``[imagenet_id, class_name]`` predicted for *image_bytes*."""
    tensor = transform_image(image_bytes=image_bytes)
    # Call the module directly instead of ``.forward`` so that any
    # registered hooks are honoured.
    outputs = model(tensor)
    _, y_hat = outputs.max(1)
    # The JSON mapping is keyed by stringified class indices.
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]
| 279 | + |
| 280 | + |
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the ``file`` field and answer with
    a JSON object holding the ImageNet class id and name.
    """
    if request.method == 'POST':
        uploaded = request.files['file']
        img_bytes = uploaded.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        result = {'class_id': class_id, 'class_name': class_name}
        return jsonify(result)


# Start the Flask development server when the file is executed directly.
if __name__ == '__main__':
    app.run()
| 292 | + |
| 293 | +###################################################################### |
| 294 | +# Let's test our web server! Run: |
| 295 | +# |
| 296 | +# :: |
| 297 | +# |
| 298 | +# $ FLASK_ENV=development FLASK_APP=app.py flask run |
| 299 | + |
| 300 | +####################################################################### |
| 301 | +# We can use a command line tool like curl or `Postman <https://www.getpostman.com/>`_ to send requests to |
| 302 | +# this webserver: |
| 303 | +# |
| 304 | +# :: |
| 305 | +# |
| 306 | +# $ curl -X POST -F file=@cat_pic.jpeg http://localhost:5000/predict |
| 307 | +# |
| 308 | +# You will get a response in the form: |
| 309 | +# |
| 310 | +# :: |
| 311 | +# |
| 312 | +# {"class_id": "n02124075", "class_name": "Egyptian_cat"} |
| 313 | +# |
| 314 | +# |
| 315 | + |
| 316 | +###################################################################### |
| 317 | +# Next steps |
| 318 | +# -------------- |
| 319 | +# |
| 320 | +# The server we wrote is quite trivial and may not do everything |
| 321 | +# you need for your production application. So, here are some things you |
| 322 | +# can do to make it better: |
| 323 | +# |
| 324 | +# - The endpoint ``/predict`` assumes that there will always be an image file |
| 325 | +# in the request. This may not hold true for all requests. Our user may |
| 326 | +# send image with a different parameter or send no images at all. |
| 327 | +# |
| 328 | +# - The user may send non-image type files too. Since we are not handling |
| 329 | +# errors, this will break our server. Adding an explicit error handling |
| 330 | +# path that will throw an exception would allow us to better handle |
| 331 | +# the bad inputs |
| 332 | +# |
| 333 | +# - Even though the model can recognize a large number of classes of images, |
| 334 | +# it may not be able to recognize all images. Enhance the implementation |
| 335 | +# to handle cases when the model does not recognize anything in the image. |
| 336 | +# |
| 337 | +# - We run the Flask server in the development mode, which is not suitable for |
| 338 | +# deploying in production. You can check out `this tutorial <https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/>`_ |
| 339 | +# for deploying a Flask server in production. |
| 340 | +# |
| 341 | +# - You can also add a UI by creating a page with a form which takes the image and |
| 342 | +# displays the prediction. Check out the `demo <https://pytorch-imagenet.herokuapp.com/>`_ |
| 343 | +# of a similar project and its `source code <https://github.com/avinassh/pytorch-flask-api-heroku>`_. |
0 commit comments