
Commit 41ca7ca

add onnx infer for ppyoloe_r (#7457)
1 parent 964643e commit 41ca7ca

File tree

4 files changed: 340 additions, 2 deletions


configs/rotate/ppyoloe_r/README.md

Lines changed: 15 additions & 0 deletions
@@ -123,6 +123,21 @@ python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_di

**Note:**

- When deploying with Paddle-TRT, make sure that **the PaddlePaddle version is the develop version and the TensorRT version is greater than 8.2**.

**To deploy with ONNX Runtime**, run the following commands:

```bash
# export model
python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams export_onnx=True

# install paddle2onnx
pip install paddle2onnx

# convert to onnx model
paddle2onnx --model_dir output_inference/ppyoloe_r_crn_l_3x_dota --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 11 --save_file ppyoloe_r_crn_l_3x_dota.onnx

# inference on a single image
python configs/rotate/tools/onnx_infer.py --infer_cfg output_inference/ppyoloe_r_crn_l_3x_dota/infer_cfg.yml --onnx_file ppyoloe_r_crn_l_3x_dota.onnx --image_file demo/P0072__1.0__0___0.png
```
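As an optional sanity check (not part of the repository's documented workflow), the exported ONNX file can be inspected with ONNX Runtime before running inference. The model path below assumes the paddle2onnx command above was used:

```python
# Optional sanity check of the exported model (path assumes the commands above).
from onnxruntime import InferenceSession

sess = InferenceSession("ppyoloe_r_crn_l_3x_dota.onnx")
print("inputs :", [(i.name, i.shape) for i in sess.get_inputs()])
print("outputs:", [(o.name, o.shape) for o in sess.get_outputs()])
```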

## Appendix

configs/rotate/ppyoloe_r/README_en.md

Lines changed: 17 additions & 1 deletion
@@ -114,7 +114,7 @@ python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota

python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_dir=output_inference/ppyoloe_r_crn_l_3x_dota --run_mode=paddle --device=gpu
```

-**Using Paddle-TRT** to for deployment, run following command
+**Using Paddle-TRT** for deployment, run the following commands:

``` bash
# export inference model

@@ -126,6 +126,22 @@ python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_di

**Notes:**

- When using Paddle-TRT for speed testing, make sure that **the version of TensorRT is greater than 8.2 and the version of PaddlePaddle is the develop version**.

**Using ONNX Runtime** for deployment, run the following commands:

``` bash
# export inference model
python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams export_onnx=True

# install paddle2onnx
pip install paddle2onnx

# convert to onnx model
paddle2onnx --model_dir output_inference/ppyoloe_r_crn_l_3x_dota --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 11 --save_file ppyoloe_r_crn_l_3x_dota.onnx

# inference on a single image
python configs/rotate/tools/onnx_infer.py --infer_cfg output_inference/ppyoloe_r_crn_l_3x_dota/infer_cfg.yml --onnx_file ppyoloe_r_crn_l_3x_dota.onnx --image_file demo/P0072__1.0__0___0.png
```
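Besides single images, the script's --image_dir flag runs inference over a whole folder. The snippet below is an illustrative sketch of the same flow in Python, reusing helpers from configs/rotate/tools/onnx_infer.py (shown later in this commit); it assumes that script's directory is on PYTHONPATH and that the export and conversion commands above have already been run:

```python
# Illustrative batch inference, equivalent to passing --image_dir to the script.
import yaml
from onnxruntime import InferenceSession
from onnx_infer import get_test_images, predict_image  # from configs/rotate/tools

with open("output_inference/ppyoloe_r_crn_l_3x_dota/infer_cfg.yml") as f:
    infer_config = yaml.safe_load(f)

predictor = InferenceSession("ppyoloe_r_crn_l_3x_dota.onnx")
img_list = get_test_images("demo", None)  # every jpg/jpeg/png/bmp under demo/
predict_image(infer_config, predictor, img_list)
```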
## Appendix

Ablation experiments of PP-YOLOE-R

configs/rotate/tools/onnx_infer.py

Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import six
import glob
import copy
import yaml
import argparse
import cv2
import numpy as np
from shapely.geometry import Polygon
from onnxruntime import InferenceSession


# preprocess ops
def decode_image(img_path):
    with open(img_path, 'rb') as f:
        im_read = f.read()
    data = np.frombuffer(im_read, dtype='uint8')
    im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    img_info = {
        "im_shape": np.array(
            im.shape[:2], dtype=np.float32),
        "scale_factor": np.array(
            [1., 1.], dtype=np.float32)
    }
    return im, img_info


class Resize(object):
    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.keep_ratio = keep_ratio
        self.interp = interp

    def __call__(self, im, im_info):
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        im_channel = im.shape[2]
        im_scale_y, im_scale_x = self.generate_scale(im)
        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale_x,
            fy=im_scale_y,
            interpolation=self.interp)
        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
        im_info['scale_factor'] = np.array(
            [im_scale_y, im_scale_x]).astype('float32')
        return im, im_info

    def generate_scale(self, im):
        origin_shape = im.shape[:2]
        im_c = im.shape[2]
        if self.keep_ratio:
            im_size_min = np.min(origin_shape)
            im_size_max = np.max(origin_shape)
            target_size_min = np.min(self.target_size)
            target_size_max = np.max(self.target_size)
            im_scale = float(target_size_min) / float(im_size_min)
            if np.round(im_scale * im_size_max) > target_size_max:
                im_scale = float(target_size_max) / float(im_size_max)
            im_scale_x = im_scale
            im_scale_y = im_scale
        else:
            resize_h, resize_w = self.target_size
            im_scale_y = resize_h / float(origin_shape[0])
            im_scale_x = resize_w / float(origin_shape[1])
        return im_scale_y, im_scale_x


class Permute(object):
    def __init__(self, ):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
        # HWC -> CHW
        im = im.transpose((2, 0, 1))
        return im, im_info


class NormalizeImage(object):
    def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
        self.mean = mean
        self.std = std
        self.is_scale = is_scale
        self.norm_type = norm_type

    def __call__(self, im, im_info):
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            scale = 1.0 / 255.0
            im *= scale

        if self.norm_type == 'mean_std':
            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
            std = np.array(self.std)[np.newaxis, np.newaxis, :]
            im -= mean
            im /= std
        return im, im_info


class PadStride(object):
    def __init__(self, stride=0):
        self.coarsest_stride = stride

    def __call__(self, im, im_info):
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return im, im_info
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return padding_im, im_info


class Compose:
    def __init__(self, transforms):
        self.transforms = []
        for op_info in transforms:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            self.transforms.append(eval(op_type)(**new_op_info))

    def __call__(self, img_path):
        img, im_info = decode_image(img_path)
        for t in self.transforms:
            img, im_info = t(img, im_info)
        inputs = copy.deepcopy(im_info)
        inputs['image'] = img
        return inputs


# postprocess
def rbox_iou(g, p):
    """IoU of two rotated boxes given as 8-value corner-point polygons."""
    g = np.array(g)
    p = np.array(p)
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))
    g = g.buffer(0)
    p = p.buffer(0)
    if not g.is_valid or not p.is_valid:
        return 0
    inter = Polygon(g).intersection(Polygon(p)).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    else:
        return inter / union


def multiclass_nms_rotated(pred_bboxes,
                           pred_scores,
                           iou_threshlod=0.1,
                           score_threshold=0.1):
    """
    Args:
        pred_bboxes (numpy.ndarray): [B, N, 8]
        pred_scores (numpy.ndarray): [B, C, N]

    Return:
        bboxes (numpy.ndarray): [N, 10]
        bbox_num (numpy.ndarray): [B]
    """
    bbox_num = []
    bboxes = []
    for bbox_per_img, score_per_img in zip(pred_bboxes, pred_scores):
        num_per_img = 0
        for cls_id, score_per_cls in enumerate(score_per_img):
            keep_mask = score_per_cls > score_threshold
            bbox = bbox_per_img[keep_mask]
            score = score_per_cls[keep_mask]

            idx = score.argsort()[::-1]
            bbox = bbox[idx]
            score = score[idx]
            keep_idx = []
            # greedy NMS: keep a box only if it does not overlap an already kept box
            for i, b in enumerate(bbox):
                supressed = False
                for gi in keep_idx:
                    g = bbox[gi]
                    if rbox_iou(b, g) > iou_threshlod:
                        supressed = True
                        break

                if supressed:
                    continue

                keep_idx.append(i)

            keep_box = bbox[keep_idx]
            keep_score = score[keep_idx]
            keep_cls_ids = np.ones(len(keep_idx)) * cls_id
            bboxes.append(
                np.concatenate(
                    [keep_cls_ids[:, None], keep_score[:, None], keep_box],
                    axis=-1))
            num_per_img += len(keep_idx)

        bbox_num.append(num_per_img)

    return np.concatenate(bboxes, axis=0), np.array(bbox_num)


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert infer_img is not None or infer_dir is not None, \
        "--image_file or --image_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    images = set()
    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))

    return images


def predict_image(infer_config, predictor, img_list):
    # load preprocess transforms
    transforms = Compose(infer_config['Preprocess'])
    # predict image
    for img_path in img_list:
        inputs = transforms(img_path)
        inputs_name = [var.name for var in predictor.get_inputs()]
        inputs = {k: inputs[k][None, ] for k in inputs_name}

        outputs = predictor.run(output_names=None, input_feed=inputs)

        bboxes, bbox_num = multiclass_nms_rotated(
            np.array(outputs[0]), np.array(outputs[1]))
        print("ONNXRuntime predict: ")
        for bbox in bboxes:
            if bbox[0] > -1 and bbox[1] > infer_config['draw_threshold']:
                print(f"{int(bbox[0])} {bbox[1]} "
                      f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]} "
                      f"{bbox[6]} {bbox[7]} {bbox[8]} {bbox[9]}")


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml")
    parser.add_argument(
        '--onnx_file',
        type=str,
        default="model.onnx",
        help="onnx model file path")
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--image_file", type=str)
    return parser.parse_args()


if __name__ == '__main__':
    FLAGS = parse_args()
    # load image list
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    # load predictor
    predictor = InferenceSession(FLAGS.onnx_file)
    # load infer config
    with open(FLAGS.infer_cfg) as f:
        infer_config = yaml.safe_load(f)

    predict_image(infer_config, predictor, img_list)
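The preprocessing in this script is config-driven: Compose reads the Preprocess list from the exported infer_cfg.yml and instantiates each op by its type name. Below is a minimal sketch of that mechanism; the transform parameters are illustrative placeholders, not the values the exporter actually writes, and the import assumes the script's directory is on PYTHONPATH.

```python
# Sketch only: 'Preprocess' normally comes from infer_cfg.yml; the parameters
# below are placeholders chosen for illustration.
from onnx_infer import Compose  # assumes configs/rotate/tools is on PYTHONPATH

preprocess_ops = [
    {'type': 'Resize', 'target_size': [1024, 1024], 'keep_ratio': True},
    {'type': 'NormalizeImage', 'mean': [0.485, 0.456, 0.406],
     'std': [0.229, 0.224, 0.225], 'is_scale': True},
    {'type': 'Permute'},
]
transforms = Compose(preprocess_ops)
inputs = transforms('demo/P0072__1.0__0___0.png')
print(inputs['image'].shape)                       # CHW float32 tensor
print(inputs['im_shape'], inputs['scale_factor'])  # resized shape and scale
```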

ppdet/modeling/heads/ppyoloe_r_head.py

Lines changed: 6 additions & 1 deletion
@@ -44,7 +44,7 @@ def forward(self, feat, avg_feat):
 
 @register
 class PPYOLOERHead(nn.Layer):
-    __shared__ = ['num_classes', 'trt']
+    __shared__ = ['num_classes', 'trt', 'export_onnx']
     __inject__ = ['static_assigner', 'assigner', 'nms']
 
     def __init__(self,
@@ -57,6 +57,7 @@ def __init__(self,
                  use_varifocal_loss=True,
                  static_assigner_epoch=4,
                  trt=False,
+                 export_onnx=False,
                  static_assigner='ATSSAssigner',
                  assigner='TaskAlignedAssigner',
                  nms='MultiClassNMS',
@@ -84,6 +85,8 @@ def __init__(self,
         self.stem_cls = nn.LayerList()
         self.stem_reg = nn.LayerList()
         self.stem_angle = nn.LayerList()
+        trt = False if export_onnx else trt
+        self.export_onnx = export_onnx
         act = get_act_fn(
             act, trt=trt) if act is None or isinstance(act,
                                                        (str, dict)) else act
@@ -415,5 +418,7 @@ def post_process(self, head_outs, scale_factor):
             ],
             axis=-1).reshape([-1, 1, 8])
         pred_bboxes /= scale_factor
+        if self.export_onnx:
+            return pred_bboxes, pred_scores
         bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
         return bbox_pred, bbox_num
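The net effect of this change: when the head is built with export_onnx=True (the -o export_onnx=True override used in the README commands), post_process skips the in-graph MultiClassNMS and returns the raw decoded boxes and scores, so rotated NMS runs in NumPy inside onnx_infer.py. A toy sketch of that contract, with hand-made arrays standing in for the two ONNX outputs (the import assumes onnx_infer.py is on PYTHONPATH):

```python
# Toy demonstration of the exported model's raw-output contract; shapes follow
# the multiclass_nms_rotated docstring. Assumes onnx_infer.py is importable.
import numpy as np
from onnx_infer import multiclass_nms_rotated

pred_bboxes = np.array([[[0., 0., 4., 0., 4., 2., 0., 2.],        # box A
                         [0.5, 0., 4.5, 0., 4.5, 2., 0.5, 2.]]])  # overlaps A
pred_scores = np.array([[[0.9, 0.6]]])  # [B=1, C=1, N=2] per-class scores

bboxes, bbox_num = multiclass_nms_rotated(pred_bboxes, pred_scores)
print(bboxes)    # one row: [class_id, score, x1, y1, x2, y2, x3, y3, x4, y4]
print(bbox_num)  # [1] -> the lower-scoring, heavily overlapping box is dropped
```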
