Skip to content

Commit bbb2613

Browse files
author
hanjr
committed
add tutorials
1 parent a9f77e4 commit bbb2613

File tree

7 files changed

+779
-14
lines changed

7 files changed

+779
-14
lines changed

docs/modules/visualize.rst

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ to visualize the model, activations etc. Here we provide more functions for data
1919
frame
2020
images2d
2121
tsne_embedding
22+
draw_boxes_and_labels_to_image_with_json
2223

2324

2425
Save and read images
@@ -44,6 +45,9 @@ Save image for object detection
4445
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4546
.. autofunction:: draw_boxes_and_labels_to_image
4647

48+
Save image for object detection with json
49+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
50+
.. autofunction:: draw_boxes_and_labels_to_image_with_json
4751

4852
Save image for pose estimation (MPII)
4953
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
person
2+
bicycle
3+
car
4+
motorbike
5+
aeroplane
6+
bus
7+
train
8+
truck
9+
boat
10+
traffic light
11+
fire hydrant
12+
stop sign
13+
parking meter
14+
bench
15+
bird
16+
cat
17+
dog
18+
horse
19+
sheep
20+
cow
21+
elephant
22+
bear
23+
zebra
24+
giraffe
25+
backpack
26+
umbrella
27+
handbag
28+
tie
29+
suitcase
30+
frisbee
31+
skis
32+
snowboard
33+
sports ball
34+
kite
35+
baseball bat
36+
baseball glove
37+
skateboard
38+
surfboard
39+
tennis racket
40+
bottle
41+
wine glass
42+
cup
43+
fork
44+
knife
45+
spoon
46+
bowl
47+
banana
48+
apple
49+
sandwich
50+
orange
51+
broccoli
52+
carrot
53+
hot dog
54+
pizza
55+
donut
56+
cake
57+
chair
58+
sofa
59+
potted plant
60+
bed
61+
dining table
62+
toilet
63+
tvmonitor
64+
laptop
65+
mouse
66+
remote
67+
keyboard
68+
cell phone
69+
microwave
70+
oven
71+
toaster
72+
sink
73+
refrigerator
74+
book
75+
clock
76+
vase
77+
scissors
78+
teddy bear
79+
hair drier
80+
toothbrush

examples/app_tutorials/model/yolov4_config.txt

+541
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""Tutorial: run the 'yolo4-mscoco' object detector on one image and show the result."""

from tensorlayer.app import computer_vision
from tensorlayer import visualize
from tensorlayer.app.computer_vision_object_detection.common import read_class_names
import numpy as np
import cv2
from PIL import Image

INPUT_SIZE = 416
image_path = './data/kite.jpg'

class_names = read_class_names('./model/coco.names')
original_image = cv2.imread(image_path)
# cv2.imread reports a missing/unreadable file by returning None rather than
# raising, so fail here with a clear message instead of inside cvtColor.
if original_image is None:
    raise ValueError("Read Image Failed! Check the path: %s" % image_path)

# OpenCV loads images as BGR; the drawing/PIL steps below work on an RGB copy.
image = cv2.cvtColor(np.array(original_image), cv2.COLOR_BGR2RGB)
net = computer_vision.object_detection('yolo4-mscoco')
# NOTE(review): the BGR image is fed to the network here, whereas the video
# tutorial feeds an RGB frame -- confirm which channel order the model expects.
json_result = net(original_image)
print(type(json_result))
image = visualize.draw_boxes_and_labels_to_image_with_json(image, json_result, class_names)
image = Image.fromarray(image.astype(np.uint8))
image.show()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""Tutorial: run the 'yolo4-mscoco' object detector on every frame of a video.

Press ``q`` in the result window to stop early.
"""

from tensorlayer.app import computer_vision
from tensorlayer import visualize
from tensorlayer.app.computer_vision_object_detection.common import read_class_names
import cv2

INPUT_SIZE = 416
video_path = './data/road.mp4'

class_names = read_class_names('./model/coco.names')
vid = cv2.VideoCapture(video_path)
# To read from a camera instead, pass the index of the camera on your device:
# vid = cv2.VideoCapture(0)

if not vid.isOpened():
    raise ValueError("Read Video Failed!")

try:
    net = computer_vision.object_detection('yolo4-mscoco')
    # The window only needs to be created once, not once per frame.
    cv2.namedWindow("result", cv2.WINDOW_AUTOSIZE)
    frame_id = 0
    while True:
        return_value, frame = vid.read()
        if not return_value:
            # A failed read after the last frame is the normal end of the video;
            # a failed read anywhere else means the stream could not be decoded.
            if frame_id == vid.get(cv2.CAP_PROP_FRAME_COUNT):
                print("Video processing complete")
                break
            raise ValueError("No image! Try with another video format")

        # OpenCV decodes frames as BGR; detection and drawing are done in RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        json_result = net(frame)
        image = visualize.draw_boxes_and_labels_to_image_with_json(frame, json_result, class_names)
        # Back to BGR because cv2.imshow expects OpenCV's native channel order.
        result = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        cv2.imshow("result", result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        frame_id += 1
finally:
    # Release the capture and close the window on every exit path
    # (end of video, 'q' pressed, or an exception).
    vid.release()
    cv2.destroyAllWindows()

tensorlayer/app/computer_vision.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def __call__(self, input_data):
2121
if self.model_name == 'yolo4-mscoco':
2222
batch_data = yolo4_input_processing(input_data)
2323
feature_maps = self.model(batch_data, is_train=False)
24-
output = yolo4_output_processing(feature_maps)
24+
pred_bbox = yolo4_output_processing(feature_maps)
25+
output = result_to_json(input_data, pred_bbox)
2526
else:
2627
raise NotImplementedError
2728

@@ -86,3 +87,29 @@ def yolo4_output_processing(feature_maps):
8687
)
8788
output = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
8889
return output
90+
91+
92+
def result_to_json(image, pred_bbox, class_names_path='model/coco.names'):
    """Convert raw detection output into a list of JSON-style dicts.

    Parameters
    -----------
    image : numpy.array
        The image the detections belong to; only its height and width
        (``image.shape[:2]``) are used.
    pred_bbox : sequence
        ``[boxes, scores, classes, valid_detections]``. Only batch entry 0 is
        read. Each box is ``[y1, x1, y2, x2]`` and is scaled by the image
        height / width here, so the values are presumably fractions of the
        image size.
    class_names_path : str
        Text file with one class name per line. Only the number of lines is
        used, to reject class IDs that are out of range.

    Returns
    -------
    list of dict
        One dict per accepted detection with keys ``'image'`` (always None),
        ``'category_id'`` (int), ``'bbox'`` (``[x1, y1, x2, y2]`` in pixels,
        truncated to whole numbers and stored as floats) and ``'score'`` (float).

    """
    image_h, image_w = image.shape[:2]
    out_boxes, out_scores, out_classes, num_boxes = pred_bbox

    with open(class_names_path, 'r') as data:
        nums_class = sum(1 for _ in data)

    json_result = []
    for i in range(int(num_boxes[0])):
        class_ind = int(out_classes[0][i])
        # Valid IDs are 0 .. nums_class - 1; an ID equal to nums_class is out of range too.
        if class_ind < 0 or class_ind >= nums_class:
            continue

        # Work on plain Python floats so the caller's box array is left untouched.
        y1, x1, y2, x2 = (float(v) for v in out_boxes[0][i])
        bbox = [
            float(int(x1 * image_w)),
            float(int(y1 * image_h)),
            float(int(x2 * image_w)),
            float(int(y2 * image_h)),
        ]  # [x1,y1,x2,y2]
        score = float(out_scores[0][i])
        json_result.append({'image': None, 'category_id': class_ind, 'bbox': bbox, 'score': score})

    return json_result

tensorlayer/visualize.py

+67-13
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import imageio
77
import numpy as np
8-
98
import tensorlayer as tl
109
from tensorlayer.lazy_imports import LazyImport
10+
import colorsys, random
1111

1212
cv2 = LazyImport("cv2")
1313

@@ -16,18 +16,9 @@
1616
# matplotlib.use('Agg')
1717

1818
__all__ = [
19-
'read_image',
20-
'read_images',
21-
'save_image',
22-
'save_images',
23-
'draw_boxes_and_labels_to_image',
24-
'draw_mpii_people_to_image',
25-
'frame',
26-
'CNN2d',
27-
'images2d',
28-
'tsne_embedding',
29-
'draw_weights',
30-
'W',
19+
'read_image', 'read_images', 'save_image', 'save_images', 'draw_boxes_and_labels_to_image',
20+
'draw_mpii_people_to_image', 'frame', 'CNN2d', 'images2d', 'tsne_embedding', 'draw_weights', 'W',
21+
'draw_boxes_and_labels_to_image_with_json'
3122
]
3223

3324

@@ -662,3 +653,66 @@ def draw_weights(W=None, second=10, saveable=True, shape=None, name='mnist', fig
662653

663654

664655
W = draw_weights
656+
657+
658+
def draw_boxes_and_labels_to_image_with_json(image, json_result, class_list, save_name=None):
    """Draw bboxes and class labels on image. Return the image with bboxes.

    The boxes and labels are drawn directly onto ``image`` (the same array is
    returned); pass a copy if the original must be kept.

    Parameters
    -----------
    image : numpy.array
        The RGB image [height, width, channel].
    json_result : list of dict
        The object detection result with json format. Each dict needs the keys
        ``'category_id'``, ``'bbox'`` (``[x1, y1, x2, y2]``) and ``'score'``.
        Entries whose ``category_id`` is not a valid index into ``class_list``
        are skipped.
    class_list : list of str
        For converting ID to string on image.
    save_name : None or str
        The name of image file (i.e. image.png), if None, not to save image.

    Returns
    -------
    numpy.array
        The image with boxes and labels drawn on it.

    References
    -----------
    - OpenCV rectangle and putText.
    - `scikit-image <http://scikit-image.org/docs/dev/api/skimage.draw.html#skimage.draw.rectangle>`__.

    """
    image_h, image_w = image.shape[:2]
    num_classes = len(class_list)

    # One evenly spaced hue per class, converted to 0-255 RGB tuples.
    hsv_tuples = [(1.0 * x / num_classes, 1., 1.) for x in range(num_classes)]
    colors = [
        (int(r * 255), int(g * 255), int(b * 255))
        for r, g, b in (colorsys.hsv_to_rgb(*hsv) for hsv in hsv_tuples)
    ]
    # Shuffle with a private, fixed-seed generator: the class-to-color mapping
    # stays the same on every call and the global `random` state is untouched.
    random.Random(0).shuffle(colors)

    bbox_thick = int(0.6 * (image_h + image_w) / 600)
    fontScale = 0.5

    for bbox_info in json_result:
        category_id = bbox_info['category_id']
        # Valid IDs are 0 .. num_classes - 1; anything else has no color or name.
        if category_id < 0 or category_id >= num_classes:
            continue
        bbox = bbox_info['bbox']  # the order of coordinates is [x1, y1, x2, y2]
        score = bbox_info['score']

        bbox_color = colors[category_id]
        c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(image, c1, c2, bbox_color, bbox_thick)

        # Filled label background sitting just above the box's top-left corner.
        bbox_mess = '%s: %.2f' % (class_list[category_id], score)
        t_size = cv2.getTextSize(bbox_mess, 0, fontScale, thickness=bbox_thick // 2)[0]
        # Keep point coordinates as plain ints, which is what OpenCV expects.
        c3 = (int(c1[0] + t_size[0]), int(c1[1] - t_size[1] - 3))
        cv2.rectangle(image, c1, c3, bbox_color, -1)

        cv2.putText(
            image, bbox_mess, (c1[0], c1[1] - 2), cv2.FONT_HERSHEY_SIMPLEX, fontScale, (0, 0, 0),
            bbox_thick // 2, lineType=cv2.LINE_AA
        )

    if save_name is not None:
        save_image(image, save_name)

    return image

0 commit comments

Comments
 (0)