|
| 1 | +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import absolute_import |
| 16 | +from __future__ import division |
| 17 | +from __future__ import print_function |
| 18 | + |
| 19 | +import os |
| 20 | +import sys |
| 21 | +import six |
| 22 | +import glob |
| 23 | +import copy |
| 24 | +import yaml |
| 25 | +import argparse |
| 26 | +import cv2 |
| 27 | +import numpy as np |
| 28 | +from shapely.geometry import Polygon |
| 29 | +from onnxruntime import InferenceSession |
| 30 | + |
| 31 | + |
| 32 | +# preprocess ops |
| 33 | +def decode_image(img_path): |
| 34 | + with open(img_path, 'rb') as f: |
| 35 | + im_read = f.read() |
| 36 | + data = np.frombuffer(im_read, dtype='uint8') |
| 37 | + im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode |
| 38 | + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) |
| 39 | + img_info = { |
| 40 | + "im_shape": np.array( |
| 41 | + im.shape[:2], dtype=np.float32), |
| 42 | + "scale_factor": np.array( |
| 43 | + [1., 1.], dtype=np.float32) |
| 44 | + } |
| 45 | + return im, img_info |
| 46 | + |
| 47 | + |
| 48 | +class Resize(object): |
| 49 | + def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR): |
| 50 | + if isinstance(target_size, int): |
| 51 | + target_size = [target_size, target_size] |
| 52 | + self.target_size = target_size |
| 53 | + self.keep_ratio = keep_ratio |
| 54 | + self.interp = interp |
| 55 | + |
| 56 | + def __call__(self, im, im_info): |
| 57 | + assert len(self.target_size) == 2 |
| 58 | + assert self.target_size[0] > 0 and self.target_size[1] > 0 |
| 59 | + im_channel = im.shape[2] |
| 60 | + im_scale_y, im_scale_x = self.generate_scale(im) |
| 61 | + im = cv2.resize( |
| 62 | + im, |
| 63 | + None, |
| 64 | + None, |
| 65 | + fx=im_scale_x, |
| 66 | + fy=im_scale_y, |
| 67 | + interpolation=self.interp) |
| 68 | + im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') |
| 69 | + im_info['scale_factor'] = np.array( |
| 70 | + [im_scale_y, im_scale_x]).astype('float32') |
| 71 | + return im, im_info |
| 72 | + |
| 73 | + def generate_scale(self, im): |
| 74 | + origin_shape = im.shape[:2] |
| 75 | + im_c = im.shape[2] |
| 76 | + if self.keep_ratio: |
| 77 | + im_size_min = np.min(origin_shape) |
| 78 | + im_size_max = np.max(origin_shape) |
| 79 | + target_size_min = np.min(self.target_size) |
| 80 | + target_size_max = np.max(self.target_size) |
| 81 | + im_scale = float(target_size_min) / float(im_size_min) |
| 82 | + if np.round(im_scale * im_size_max) > target_size_max: |
| 83 | + im_scale = float(target_size_max) / float(im_size_max) |
| 84 | + im_scale_x = im_scale |
| 85 | + im_scale_y = im_scale |
| 86 | + else: |
| 87 | + resize_h, resize_w = self.target_size |
| 88 | + im_scale_y = resize_h / float(origin_shape[0]) |
| 89 | + im_scale_x = resize_w / float(origin_shape[1]) |
| 90 | + return im_scale_y, im_scale_x |
| 91 | + |
| 92 | + |
| 93 | +class Permute(object): |
| 94 | + def __init__(self, ): |
| 95 | + super(Permute, self).__init__() |
| 96 | + |
| 97 | + def __call__(self, im, im_info): |
| 98 | + im = im.transpose((2, 0, 1)) |
| 99 | + return im, im_info |
| 100 | + |
| 101 | + |
| 102 | +class NormalizeImage(object): |
| 103 | + def __init__(self, mean, std, is_scale=True, norm_type='mean_std'): |
| 104 | + self.mean = mean |
| 105 | + self.std = std |
| 106 | + self.is_scale = is_scale |
| 107 | + self.norm_type = norm_type |
| 108 | + |
| 109 | + def __call__(self, im, im_info): |
| 110 | + im = im.astype(np.float32, copy=False) |
| 111 | + if self.is_scale: |
| 112 | + scale = 1.0 / 255.0 |
| 113 | + im *= scale |
| 114 | + |
| 115 | + if self.norm_type == 'mean_std': |
| 116 | + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] |
| 117 | + std = np.array(self.std)[np.newaxis, np.newaxis, :] |
| 118 | + im -= mean |
| 119 | + im /= std |
| 120 | + return im, im_info |
| 121 | + |
| 122 | + |
| 123 | +class PadStride(object): |
| 124 | + def __init__(self, stride=0): |
| 125 | + self.coarsest_stride = stride |
| 126 | + |
| 127 | + def __call__(self, im, im_info): |
| 128 | + coarsest_stride = self.coarsest_stride |
| 129 | + if coarsest_stride <= 0: |
| 130 | + return im, im_info |
| 131 | + im_c, im_h, im_w = im.shape |
| 132 | + pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride) |
| 133 | + pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride) |
| 134 | + padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32) |
| 135 | + padding_im[:, :im_h, :im_w] = im |
| 136 | + return padding_im, im_info |
| 137 | + |
| 138 | + |
| 139 | +class Compose: |
| 140 | + def __init__(self, transforms): |
| 141 | + self.transforms = [] |
| 142 | + for op_info in transforms: |
| 143 | + new_op_info = op_info.copy() |
| 144 | + op_type = new_op_info.pop('type') |
| 145 | + self.transforms.append(eval(op_type)(**new_op_info)) |
| 146 | + |
| 147 | + def __call__(self, img_path): |
| 148 | + img, im_info = decode_image(img_path) |
| 149 | + for t in self.transforms: |
| 150 | + img, im_info = t(img, im_info) |
| 151 | + inputs = copy.deepcopy(im_info) |
| 152 | + inputs['image'] = img |
| 153 | + return inputs |
| 154 | + |
| 155 | + |
| 156 | +# postprocess |
| 157 | +def rbox_iou(g, p): |
| 158 | + g = np.array(g) |
| 159 | + p = np.array(p) |
| 160 | + g = Polygon(g[:8].reshape((4, 2))) |
| 161 | + p = Polygon(p[:8].reshape((4, 2))) |
| 162 | + g = g.buffer(0) |
| 163 | + p = p.buffer(0) |
| 164 | + if not g.is_valid or not p.is_valid: |
| 165 | + return 0 |
| 166 | + inter = Polygon(g).intersection(Polygon(p)).area |
| 167 | + union = g.area + p.area - inter |
| 168 | + if union == 0: |
| 169 | + return 0 |
| 170 | + else: |
| 171 | + return inter / union |
| 172 | + |
| 173 | + |
| 174 | +def multiclass_nms_rotated(pred_bboxes, |
| 175 | + pred_scores, |
| 176 | + iou_threshlod=0.1, |
| 177 | + score_threshold=0.1): |
| 178 | + """ |
| 179 | + Args: |
| 180 | + pred_bboxes (numpy.ndarray): [B, N, 8] |
| 181 | + pred_scores (numpy.ndarray): [B, C, N] |
| 182 | + |
| 183 | + Return: |
| 184 | + bboxes (numpy.ndarray): [N, 10] |
| 185 | + bbox_num (numpy.ndarray): [B] |
| 186 | + """ |
| 187 | + bbox_num = [] |
| 188 | + bboxes = [] |
| 189 | + for bbox_per_img, score_per_img in zip(pred_bboxes, pred_scores): |
| 190 | + num_per_img = 0 |
| 191 | + for cls_id, score_per_cls in enumerate(score_per_img): |
| 192 | + keep_mask = score_per_cls > score_threshold |
| 193 | + bbox = bbox_per_img[keep_mask] |
| 194 | + score = score_per_cls[keep_mask] |
| 195 | + |
| 196 | + idx = score.argsort()[::-1] |
| 197 | + bbox = bbox[idx] |
| 198 | + score = score[idx] |
| 199 | + keep_idx = [] |
| 200 | + for i, b in enumerate(bbox): |
| 201 | + supressed = False |
| 202 | + for gi in keep_idx: |
| 203 | + g = bbox[gi] |
| 204 | + if rbox_iou(b, g) > iou_threshlod: |
| 205 | + supressed = True |
| 206 | + break |
| 207 | + |
| 208 | + if supressed: |
| 209 | + continue |
| 210 | + |
| 211 | + keep_idx.append(i) |
| 212 | + |
| 213 | + keep_box = bbox[keep_idx] |
| 214 | + keep_score = score[keep_idx] |
| 215 | + keep_cls_ids = np.ones(len(keep_idx)) * cls_id |
| 216 | + bboxes.append( |
| 217 | + np.concatenate( |
| 218 | + [keep_cls_ids[:, None], keep_score[:, None], keep_box], |
| 219 | + axis=-1)) |
| 220 | + num_per_img += len(keep_idx) |
| 221 | + |
| 222 | + bbox_num.append(num_per_img) |
| 223 | + |
| 224 | + return np.concatenate(bboxes, axis=0), np.array(bbox_num) |
| 225 | + |
| 226 | + |
| 227 | +def get_test_images(infer_dir, infer_img): |
| 228 | + """ |
| 229 | + Get image path list in TEST mode |
| 230 | + """ |
| 231 | + assert infer_img is not None or infer_dir is not None, \ |
| 232 | + "--image_file or --image_dir should be set" |
| 233 | + assert infer_img is None or os.path.isfile(infer_img), \ |
| 234 | + "{} is not a file".format(infer_img) |
| 235 | + assert infer_dir is None or os.path.isdir(infer_dir), \ |
| 236 | + "{} is not a directory".format(infer_dir) |
| 237 | + |
| 238 | + # infer_img has a higher priority |
| 239 | + if infer_img and os.path.isfile(infer_img): |
| 240 | + return [infer_img] |
| 241 | + |
| 242 | + images = set() |
| 243 | + infer_dir = os.path.abspath(infer_dir) |
| 244 | + assert os.path.isdir(infer_dir), \ |
| 245 | + "infer_dir {} is not a directory".format(infer_dir) |
| 246 | + exts = ['jpg', 'jpeg', 'png', 'bmp'] |
| 247 | + exts += [ext.upper() for ext in exts] |
| 248 | + for ext in exts: |
| 249 | + images.update(glob.glob('{}/*.{}'.format(infer_dir, ext))) |
| 250 | + images = list(images) |
| 251 | + |
| 252 | + assert len(images) > 0, "no image found in {}".format(infer_dir) |
| 253 | + print("Found {} inference images in total.".format(len(images))) |
| 254 | + |
| 255 | + return images |
| 256 | + |
| 257 | + |
| 258 | +def predict_image(infer_config, predictor, img_list): |
| 259 | + # load preprocess transforms |
| 260 | + transforms = Compose(infer_config['Preprocess']) |
| 261 | + # predict image |
| 262 | + for img_path in img_list: |
| 263 | + inputs = transforms(img_path) |
| 264 | + inputs_name = [var.name for var in predictor.get_inputs()] |
| 265 | + inputs = {k: inputs[k][None, ] for k in inputs_name} |
| 266 | + |
| 267 | + outputs = predictor.run(output_names=None, input_feed=inputs) |
| 268 | + |
| 269 | + bboxes, bbox_num = multiclass_nms_rotated( |
| 270 | + np.array(outputs[0]), np.array(outputs[1])) |
| 271 | + print("ONNXRuntime predict: ") |
| 272 | + for bbox in bboxes: |
| 273 | + if bbox[0] > -1 and bbox[1] > infer_config['draw_threshold']: |
| 274 | + print(f"{int(bbox[0])} {bbox[1]} " |
| 275 | + f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}" |
| 276 | + f"{bbox[6]} {bbox[7]} {bbox[8]} {bbox[9]}") |
| 277 | + |
| 278 | + |
| 279 | +def parse_args(): |
| 280 | + parser = argparse.ArgumentParser(description=__doc__) |
| 281 | + parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml") |
| 282 | + parser.add_argument( |
| 283 | + '--onnx_file', |
| 284 | + type=str, |
| 285 | + default="model.onnx", |
| 286 | + help="onnx model file path") |
| 287 | + parser.add_argument("--image_dir", type=str) |
| 288 | + parser.add_argument("--image_file", type=str) |
| 289 | + return parser.parse_args() |
| 290 | + |
| 291 | + |
| 292 | +if __name__ == '__main__': |
| 293 | + FLAGS = parse_args() |
| 294 | + # load image list |
| 295 | + img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) |
| 296 | + # load predictor |
| 297 | + predictor = InferenceSession(FLAGS.onnx_file) |
| 298 | + # load infer config |
| 299 | + with open(FLAGS.infer_cfg) as f: |
| 300 | + infer_config = yaml.safe_load(f) |
| 301 | + |
| 302 | + predict_image(infer_config, predictor, img_list) |
0 commit comments