Commit 18cada0

Merge pull request #537 from avinassh/rest-api

Add a new tutorial: Building a Flask API server with PyTorch

2 parents c8d46ad + 1401e89

File tree

4 files changed: +352 −0 lines changed

_static/img/flask.png (173 KB)
index.rst

Lines changed: 5 additions & 0 deletions

@@ -225,6 +225,10 @@ Production Usage
    :description: :doc:`/beginner/aws_distributed_training_tutorial`
    :figure: _static/img/distributed/DistPyTorch.jpg

+.. customgalleryitem::
+   :tooltip: Deploying PyTorch and Building a REST API using Flask
+   :description: :doc:`/intermediate/flask_rest_api_tutorial`
+   :figure: _static/img/flask.png

 .. raw:: html

@@ -326,6 +330,7 @@ PyTorch in Other Languages
    intermediate/model_parallel_tutorial
    intermediate/ddp_tutorial
    intermediate/dist_tuto
+   intermediate/flask_rest_api_tutorial
    beginner/aws_distributed_training_tutorial
    advanced/cpp_export
intermediate_source/README.txt

Lines changed: 4 additions & 0 deletions

@@ -24,3 +24,7 @@ Intermediate tutorials
 6. spatial_transformer_tutorial
        Spatial Transformer Networks Tutorial
        https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
+
+7. flask_rest_api_tutorial.py
+       Deploying PyTorch and Building a REST API using Flask
+       https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html
intermediate_source/flask_rest_api_tutorial.py

Lines changed: 343 additions & 0 deletions

@@ -0,0 +1,343 @@
# -*- coding: utf-8 -*-
"""
Deploying PyTorch and Building a REST API using Flask
=====================================================
**Author**: `Avinash Sajjanshetty <https://avi.im>`_


In this tutorial, we will deploy a PyTorch model using Flask and expose a
REST API for model inference. In particular, we will deploy a pretrained
DenseNet 121 model that classifies images.

.. tip:: All the code used here is released under the MIT license and is available on `Github <https://github.com/avinassh/pytorch-flask-api>`_.

"""


######################################################################
# API Definition
# --------------
#
# We will first define our API endpoint, and the request and response types.
# Our API endpoint will be at ``/predict``, which takes HTTP POST requests
# with a ``file`` parameter containing the image. The response will be a
# JSON response containing the prediction:
#
# ::
#
#     {"class_id": "n02124075", "class_name": "Egyptian_cat"}
#
#
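On the client side, a response body of this shape can be parsed with Python's standard ``json`` module. A quick sketch using the sample response above:

```python
import json

# Sample response body from the /predict endpoint (from the spec above).
body = '{"class_id": "n02124075", "class_name": "Egyptian_cat"}'

prediction = json.loads(body)
print(prediction['class_id'], prediction['class_name'])  # n02124075 Egyptian_cat
```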
######################################################################
# Dependencies
# ------------
#
# Install the required dependencies by running the following command:
#
# ::
#
#     $ pip install Flask==1.0.3 torchvision==0.3.0

######################################################################
# Simple Web Server
# -----------------
#
# Following is a simple web server, taken from Flask's documentation:


from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

###############################################################################
# Save the above snippet in a file called ``app.py``. You can now run a
# Flask development server by typing:
#
# ::
#
#     $ FLASK_ENV=development FLASK_APP=app.py flask run

###############################################################################
# When you visit ``http://localhost:5000/`` in your web browser, you will be
# greeted with ``Hello World!`` text.

###############################################################################
# We will make slight changes to the above snippet, so that it suits our API
# definition. First, we will rename the method to ``predict``. We will update
# the endpoint path to ``/predict``. Since the image files will be sent via
# HTTP POST requests, we will also update it so that it accepts only POST
# requests:


@app.route('/predict', methods=['POST'])
def predict():
    return 'Hello World!'

###############################################################################
# We will also change the response type, so that it returns a JSON response
# containing the ImageNet class id and name. The updated ``app.py`` file will
# now be:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

######################################################################
# Inference
# -----------------
#
# In the next sections we will focus on writing the inference code. This will
# involve two parts: one where we prepare the image so that it can be fed
# to DenseNet, and next, where we write the code to get the actual prediction
# from the model.
#
# Preparing the image
# ~~~~~~~~~~~~~~~~~~~
#
# The DenseNet model requires the image to be a 3-channel RGB image of size
# 224 x 224. We will also normalise the image tensor with the required mean
# and standard deviation values. You can read more about it
# `here <https://pytorch.org/docs/stable/torchvision/models.html>`_.
#
# We will use ``transforms`` from the ``torchvision`` library and build a
# transform pipeline, which transforms our images as required. You
# can read more about transforms `here <https://pytorch.org/docs/stable/torchvision/transforms.html>`_.

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


######################################################################
# The above method takes image data in bytes, applies the series of
# transforms, and returns a tensor. To test the method, read an image file in
# bytes mode and check that you get a tensor back:

with open('sample_file.jpeg', 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)
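As an aside, the ``Normalize`` step in the pipeline above is simple per-channel arithmetic: each channel value ``x`` becomes ``(x - mean) / std``. A pure-Python sketch of that arithmetic (not part of the server code):

```python
# ImageNet channel statistics, as used in the transform pipeline above.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize(pixel):
    """Normalize one (r, g, b) pixel with channel values in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(pixel, mean, std))

# A pixel exactly equal to the channel means maps to all zeros:
print(normalize((0.485, 0.456, 0.406)))  # (0.0, 0.0, 0.0)
```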

######################################################################
# Prediction
# ~~~~~~~~~~~~~~~~~~~
#
# Now we will use a pretrained DenseNet 121 model to predict the image class.
# We will use one from the ``torchvision`` library, load the model, and get an
# inference. While we'll be using a pretrained model in this example, you can
# use this same approach for your own models. See more about loading your
# models in this :doc:`tutorial </beginner/saving_loading_models>`.

from torchvision import models

# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model(tensor)
    _, y_hat = outputs.max(1)
    return y_hat
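``outputs.max(1)`` returns a pair of ``(values, indices)`` along the class dimension, so ``y_hat`` holds the argmax over the class logits. The same selection can be sketched in pure Python with a hypothetical 3-class output row:

```python
# A hypothetical row of 3 class logits, standing in for the 1000-class output.
logits = [0.1, 2.5, 0.3]

# Index of the largest logit -- what outputs.max(1) returns as `indices`.
y_hat = max(range(len(logits)), key=lambda i: logits[i])
print(y_hat)  # 1
```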


######################################################################
# The tensor ``y_hat`` will contain the index of the predicted class id.
# However, we need a human-readable class name. For that we need a class id
# to name mapping. Download
# `this file <https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json>`_
# and place it in the current directory as ``imagenet_class_index.json``.
# This file contains the mapping of ImageNet class ids to ImageNet class
# names. We will load this JSON file and get the class name of the
# predicted index.

import json

imagenet_class_index = json.load(open('imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

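To make the mapping concrete: the file's keys are stringified class indices, and the values are ``[wordnet_id, class_name]`` pairs. A lookup against a hypothetical two-entry subset of ``imagenet_class_index.json``:

```python
import json

# A hypothetical two-entry subset of imagenet_class_index.json.
sample = '{"281": ["n02123045", "tabby"], "285": ["n02124075", "Egyptian_cat"]}'
mapping = json.loads(sample)

class_id, class_name = mapping["285"]
print(class_id, class_name)  # n02124075 Egyptian_cat
```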
######################################################################
# Before using the ``imagenet_class_index`` dictionary, we first convert the
# tensor value to a string value, since the keys in the
# ``imagenet_class_index`` dictionary are strings.
# We will test our above method:


with open('sample_file.jpeg', 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

######################################################################
# You should get a response like this:

['n02124075', 'Egyptian_cat']
######################################################################
# The first item in the array is the ImageNet class id and the second item is
# the human-readable name.
#
# .. Note ::
#    Did you notice why the ``model`` variable is not part of the
#    ``get_prediction`` method? Or why the model is a global variable? Loading
#    a model can be an expensive operation in terms of memory and compute. If
#    we loaded the model in the ``get_prediction`` method, then it would get
#    unnecessarily loaded every time the method is called. Since we are
#    building a web server, there could be thousands of requests per second,
#    so we should not waste time redundantly loading the model for every
#    inference. Hence, we keep the model loaded in memory just once. In
#    production systems, it's necessary to be efficient about your use of
#    compute to be able to serve requests at scale, so you should generally
#    load your model before serving requests.

######################################################################
# Integrating the model in our API Server
# ---------------------------------------
#
# In this final part we will add our model to our Flask API server. Since
# our API server is supposed to take an image file, we will update our
# ``predict`` method to read the file from the request:

from flask import request


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

######################################################################
# The ``app.py`` file is now complete. Following is the full version:


import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

######################################################################
# Let's test our web server! Run:
#
# ::
#
#     $ FLASK_ENV=development FLASK_APP=app.py flask run

#######################################################################
# We can use a command line tool like curl or `Postman <https://www.getpostman.com/>`_
# to send requests to this webserver:
#
# ::
#
#     $ curl -X POST -F file=@cat_pic.jpeg http://localhost:5000/predict
#
# You will get a response in the form:
#
# ::
#
#     {"class_id": "n02124075", "class_name": "Egyptian_cat"}
#
#
######################################################################
# Next steps
# --------------
#
# The server we wrote is quite trivial and may not do everything
# you need for your production application. Here are some things you
# can do to make it better:
#
# - The endpoint ``/predict`` assumes that there will always be an image file
#   in the request. This may not hold true for all requests. Our user may
#   send an image with a different parameter, or send no image at all.
#
# - The user may send non-image type files too. Since we are not handling
#   errors, this will break our server. Adding an explicit error handling
#   path would allow us to respond to bad inputs gracefully.
#
# - Even though the model can recognize a large number of classes of images,
#   it may not be able to recognize all images. Enhance the implementation
#   to handle cases when the model does not recognize anything in the image.
#
# - We run the Flask server in the development mode, which is not suitable for
#   deploying in production. You can check out `this tutorial <https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/>`_
#   for deploying a Flask server in production.
#
# - You can also add a UI by creating a page with a form which takes the image and
#   displays the prediction. Check out the `demo <https://pytorch-imagenet.herokuapp.com/>`_
#   of a similar project and its `source code <https://github.com/avinassh/pytorch-flask-api-heroku>`_.
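As a sketch of the first two suggestions, the validation could live in a small helper that the endpoint calls before running inference. This is a hypothetical addition, not part of the tutorial's code; ``files`` stands in for Flask's ``request.files`` as a plain mapping from parameter name to uploaded filename:

```python
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png'}  # hypothetical whitelist

def validation_error(files):
    """Return an error message if the request can't be served, else None.

    `files` is a stand-in for Flask's request.files: a mapping from
    parameter name to the uploaded file's name.
    """
    if 'file' not in files:
        return 'no file part in the request'
    ext = files['file'].rsplit('.', 1)[-1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        return 'unsupported file type'
    return None

print(validation_error({}))                    # no file part in the request
print(validation_error({'file': 'cat.jpeg'}))  # None
```

In the real endpoint, the error message would be returned as a JSON body with an HTTP 400 status instead of letting a ``KeyError`` crash the request.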
