1
1
#! /usr/bin/python
2
2
# -*- coding: utf-8 -*-
3
3
4
- from tensorlayer .app import YOLOv4 , get_anchors , decode , filter_boxes
4
+ from tensorlayer .app import YOLOv4
5
5
from tensorlayer .app import CGCNN
6
- import numpy as np
7
- import tensorflow as tf
8
6
from tensorlayer import logging
9
- import cv2
7
+ from tensorlayer . app import yolo4_input_processing , yolo4_output_processing , result_to_json
10
8
11
9
12
10
class object_detection (object ):
@@ -42,8 +40,6 @@ def __init__(self, model_name='yolo4-mscoco'):
42
40
self .model_name = model_name
43
41
if self .model_name == 'yolo4-mscoco' :
44
42
self .model = YOLOv4 (NUM_CLASS = 80 , pretrained = True )
45
- elif self .model_name == 'lcn' :
46
- self .model = CGCNN (pretrained = True )
47
43
else :
48
44
raise ("The model does not support." )
49
45
@@ -53,8 +49,6 @@ def __call__(self, input_data):
53
49
feature_maps = self .model (batch_data , is_train = False )
54
50
pred_bbox = yolo4_output_processing (feature_maps )
55
51
output = result_to_json (input_data , pred_bbox )
56
- elif self .model_name == 'lcn' :
57
- output = self .model (input_data )
58
52
else :
59
53
raise NotImplementedError
60
54
@@ -70,78 +64,55 @@ def list(self):
70
64
logging .info ("The model name list: 'yolov4-mscoco', 'lcn'" )
71
65
72
66
73
- def yolo4_input_processing (original_image ):
74
- image_data = cv2 .resize (original_image , (416 , 416 ))
75
- image_data = image_data / 255.
76
- images_data = []
77
- for i in range (1 ):
78
- images_data .append (image_data )
79
- images_data = np .asarray (images_data ).astype (np .float32 )
80
- batch_data = tf .constant (images_data )
81
- return batch_data
82
-
83
-
84
- def yolo4_output_processing (feature_maps ):
85
- STRIDES = [8 , 16 , 32 ]
86
- ANCHORS = get_anchors ([12 , 16 , 19 , 36 , 40 , 28 , 36 , 75 , 76 , 55 , 72 , 146 , 142 , 110 , 192 , 243 , 459 , 401 ])
87
- NUM_CLASS = 80
88
- XYSCALE = [1.2 , 1.1 , 1.05 ]
89
- iou_threshold = 0.45
90
- score_threshold = 0.25
91
-
92
- bbox_tensors = []
93
- prob_tensors = []
94
- score_thres = 0.2
95
- for i , fm in enumerate (feature_maps ):
96
- if i == 0 :
97
- output_tensors = decode (fm , 416 // 8 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
98
- elif i == 1 :
99
- output_tensors = decode (fm , 416 // 16 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
67
+ class human_pose_estimation (object ):
68
+ """Model encapsulation.
69
+
70
+ Parameters
71
+ ----------
72
+ model_name : str
73
+ Choose the model to inference.
74
+
75
+ Methods
76
+ ---------
77
+ __init__()
78
+ Initializing the model.
79
+ __call__()
80
+ (1)Formatted input and output. (2)Inference model.
81
+ list()
82
+ Abstract method. Return available a list of model_name.
83
+
84
+ Examples
85
+ ---------
86
+ LCN to estimate 3D human poses from 2D poses, see `tutorial_human_3dpose_estimation_LCN.py
87
+ <https://github.com/tensorlayer/tensorlayer/blob/master/example/app_tutorials/tutorial_human_3dpose_estimation_LCN.py>`__
88
+ With TensorLayer
89
+
90
+ >>> # get the whole model
91
+ >>> net = tl.app.computer_vision.human_pose_estimation('3D-pose')
92
+ >>> # use for inferencing
93
+ >>> output = net(img)
94
+ """
95
+
96
+ def __init__ (self , model_name = '3D-pose' ):
97
+ self .model_name = model_name
98
+ if self .model_name == '3D-pose' :
99
+ self .model = CGCNN (pretrained = True )
100
100
else :
101
- output_tensors = decode (fm , 416 // 32 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
102
- bbox_tensors .append (output_tensors [0 ])
103
- prob_tensors .append (output_tensors [1 ])
104
- pred_bbox = tf .concat (bbox_tensors , axis = 1 )
105
- pred_prob = tf .concat (prob_tensors , axis = 1 )
106
- boxes , pred_conf = filter_boxes (
107
- pred_bbox , pred_prob , score_threshold = score_thres , input_shape = tf .constant ([416 , 416 ])
108
- )
109
- pred = {'concat' : tf .concat ([boxes , pred_conf ], axis = - 1 )}
110
-
111
- for key , value in pred .items ():
112
- boxes = value [:, :, 0 :4 ]
113
- pred_conf = value [:, :, 4 :]
114
-
115
- boxes , scores , classes , valid_detections = tf .image .combined_non_max_suppression (
116
- boxes = tf .reshape (boxes , (tf .shape (boxes )[0 ], - 1 , 1 , 4 )),
117
- scores = tf .reshape (pred_conf , (tf .shape (pred_conf )[0 ], - 1 , tf .shape (pred_conf )[- 1 ])),
118
- max_output_size_per_class = 50 , max_total_size = 50 , iou_threshold = iou_threshold , score_threshold = score_threshold
119
- )
120
- output = [boxes .numpy (), scores .numpy (), classes .numpy (), valid_detections .numpy ()]
121
- return output
122
-
123
-
124
- def result_to_json (image , pred_bbox ):
125
- image_h , image_w , _ = image .shape
126
- out_boxes , out_scores , out_classes , num_boxes = pred_bbox
127
- class_names = {}
128
- json_result = []
129
- with open ('model/coco.names' , 'r' ) as data :
130
- for ID , name in enumerate (data ):
131
- class_names [ID ] = name .strip ('\n ' )
132
- nums_class = len (class_names )
133
-
134
- for i in range (num_boxes [0 ]):
135
- if int (out_classes [0 ][i ]) < 0 or int (out_classes [0 ][i ]) > nums_class : continue
136
- coor = out_boxes [0 ][i ]
137
- coor [0 ] = int (coor [0 ] * image_h )
138
- coor [2 ] = int (coor [2 ] * image_h )
139
- coor [1 ] = int (coor [1 ] * image_w )
140
- coor [3 ] = int (coor [3 ] * image_w )
141
-
142
- score = float (out_scores [0 ][i ])
143
- class_ind = int (out_classes [0 ][i ])
144
- bbox = np .array ([coor [1 ], coor [0 ], coor [3 ], coor [2 ]]).tolist () # [x1,y1,x2,y2]
145
- json_result .append ({'image' : None , 'category_id' : class_ind , 'bbox' : bbox , 'score' : score })
146
-
147
- return json_result
101
+ raise ("The model does not support." )
102
+
103
+ def __call__ (self , input_data ):
104
+ if self .model_name == '3D-pose' :
105
+ output = self .model (input_data , is_train = False )
106
+ else :
107
+ raise NotImplementedError
108
+
109
+ return output
110
+
111
+ def __repr__ (self ):
112
+ s = ('(model_name={model_name}, model_structure={model}' )
113
+ s += ')'
114
+ return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
115
+
116
+ @property
117
+ def list (self ):
118
+ logging .info ("The model name list: '3D-pose'" )
0 commit comments