|
| 1 | +// Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +#include "yolo_head.h" |
| 3 | + |
| 4 | +#include <math.h> |
| 5 | + |
| 6 | +#include <algorithm> |
| 7 | +#include <numeric> |
| 8 | + |
| 9 | +#include "mmdeploy/core/model.h" |
| 10 | +#include "mmdeploy/core/utils/device_utils.h" |
| 11 | +#include "mmdeploy/core/utils/formatter.h" |
| 12 | +#include "utils.h" |
| 13 | + |
| 14 | +namespace mmdeploy::mmdet { |
| 15 | + |
| 16 | +YOLOHead::YOLOHead(const Value& cfg) : MMDetection(cfg) { |
| 17 | + auto init = [&]() -> Result<void> { |
| 18 | + auto model = cfg["context"]["model"].get<Model>(); |
| 19 | + if (cfg.contains("params")) { |
| 20 | + nms_pre_ = cfg["params"].value("nms_pre", -1); |
| 21 | + score_thr_ = cfg["params"].value("score_thr", 0.02f); |
| 22 | + min_bbox_size_ = cfg["params"].value("min_bbox_size", 0); |
| 23 | + iou_threshold_ = cfg["params"].contains("nms") |
| 24 | + ? cfg["params"]["nms"].value("iou_threshold", 0.45f) |
| 25 | + : 0.45f; |
| 26 | + if (cfg["params"].contains("anchor_generator")) { |
| 27 | + from_value(cfg["params"]["anchor_generator"]["base_sizes"], anchors_); |
| 28 | + from_value(cfg["params"]["anchor_generator"]["strides"], strides_); |
| 29 | + } |
| 30 | + } |
| 31 | + return success(); |
| 32 | + }; |
| 33 | + init().value(); |
| 34 | +} |
| 35 | + |
| 36 | +Result<Value> YOLOHead::operator()(const Value& prep_res, const Value& infer_res) { |
| 37 | + MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res); |
| 38 | + try { |
| 39 | + const Device kHost{0, 0}; |
| 40 | + std::vector<Tensor> pred_maps; |
| 41 | + for (auto iter = infer_res.begin(); iter != infer_res.end(); iter++) { |
| 42 | + auto pred_map = iter->get<Tensor>(); |
| 43 | + OUTCOME_TRY(auto _pred_map, MakeAvailableOnDevice(pred_map, kHost, stream())); |
| 44 | + pred_maps.push_back(_pred_map); |
| 45 | + } |
| 46 | + OUTCOME_TRY(stream().Wait()); |
| 47 | + // reorder pred_maps according to strides and anchors, mainly for rknpu yolov3 |
| 48 | + if ((pred_maps.size() > 1) && |
| 49 | + !((strides_[0] < strides_[1]) ^ (pred_maps[0].shape(3) < pred_maps[1].shape(3)))) { |
| 50 | + std::reverse(pred_maps.begin(), pred_maps.end()); |
| 51 | + } |
| 52 | + OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], pred_maps)); |
| 53 | + return to_value(result); |
| 54 | + } catch (...) { |
| 55 | + return Status(eFail); |
| 56 | + } |
| 57 | +} |
| 58 | + |
| 59 | +inline static int clamp(float val, int min, int max) { |
| 60 | + return val > min ? (val < max ? val : max) : min; |
| 61 | +} |
| 62 | + |
| 63 | +static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); } |
| 64 | + |
| 65 | +static float unsigmoid(float y) { return -1.0 * logf((1.0 / y) - 1.0); } |
| 66 | + |
| 67 | +int YOLOHead::YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor, |
| 68 | + int grid_h, int grid_w, int height, int width, int stride, |
| 69 | + std::vector<float>& boxes, std::vector<float>& obj_probs, |
| 70 | + std::vector<int>& class_id, float threshold) const { |
| 71 | + auto input = const_cast<float*>(feat_map.data<float>()); |
| 72 | + auto prop_box_size = feat_map.shape(1) / anchor.size(); |
| 73 | + const int kClasses = prop_box_size - 5; |
| 74 | + int valid_count = 0; |
| 75 | + int grid_len = grid_h * grid_w; |
| 76 | + float thres = unsigmoid(threshold); |
| 77 | + for (int a = 0; a < anchor.size(); a++) { |
| 78 | + for (int i = 0; i < grid_h; i++) { |
| 79 | + for (int j = 0; j < grid_w; j++) { |
| 80 | + float box_confidence = input[(prop_box_size * a + 4) * grid_len + i * grid_w + j]; |
| 81 | + if (box_confidence >= thres) { |
| 82 | + int offset = (prop_box_size * a) * grid_len + i * grid_w + j; |
| 83 | + float* in_ptr = input + offset; |
| 84 | + |
| 85 | + float box_x = sigmoid(*in_ptr); |
| 86 | + float box_y = sigmoid(in_ptr[grid_len]); |
| 87 | + float box_w = in_ptr[2 * grid_len]; |
| 88 | + float box_h = in_ptr[3 * grid_len]; |
| 89 | + auto box = yolo_decode(box_x, box_y, box_w, box_h, stride, anchor, j, i, a); |
| 90 | + |
| 91 | + box_x = box[0]; |
| 92 | + box_y = box[1]; |
| 93 | + box_w = box[2]; |
| 94 | + box_h = box[3]; |
| 95 | + |
| 96 | + box_x -= (box_w / 2.0); |
| 97 | + box_y -= (box_h / 2.0); |
| 98 | + boxes.push_back(box_x); |
| 99 | + boxes.push_back(box_y); |
| 100 | + boxes.push_back(box_x + box_w); |
| 101 | + boxes.push_back(box_y + box_h); |
| 102 | + |
| 103 | + float max_class_probs = in_ptr[5 * grid_len]; |
| 104 | + int max_class_id = 0; |
| 105 | + for (int k = 1; k < kClasses; ++k) { |
| 106 | + float prob = in_ptr[(5 + k) * grid_len]; |
| 107 | + if (prob > max_class_probs) { |
| 108 | + max_class_id = k; |
| 109 | + max_class_probs = prob; |
| 110 | + } |
| 111 | + } |
| 112 | + obj_probs.push_back(sigmoid(max_class_probs) * sigmoid(box_confidence)); |
| 113 | + class_id.push_back(max_class_id); |
| 114 | + valid_count++; |
| 115 | + } |
| 116 | + } |
| 117 | + } |
| 118 | + } |
| 119 | + return valid_count; |
| 120 | +} |
| 121 | + |
| 122 | +Result<Detections> YOLOHead::GetBBoxes(const Value& prep_res, |
| 123 | + const std::vector<Tensor>& pred_maps) const { |
| 124 | + std::vector<float> filter_boxes; |
| 125 | + std::vector<float> obj_probs; |
| 126 | + std::vector<int> class_id; |
| 127 | + |
| 128 | + int model_in_h = prep_res["img_shape"][1].get<int>(); |
| 129 | + int model_in_w = prep_res["img_shape"][2].get<int>(); |
| 130 | + |
| 131 | + for (int i = 0; i < pred_maps.size(); i++) { |
| 132 | + int stride = strides_[i]; |
| 133 | + int grid_h = model_in_h / stride; |
| 134 | + int grid_w = model_in_w / stride; |
| 135 | + YOLOFeatDecode(pred_maps[i], anchors_[i], grid_h, grid_w, model_in_h, model_in_w, stride, |
| 136 | + filter_boxes, obj_probs, class_id, score_thr_); |
| 137 | + } |
| 138 | + |
| 139 | + std::vector<int> indexArray; |
| 140 | + for (int i = 0; i < obj_probs.size(); ++i) { |
| 141 | + indexArray.push_back(i); |
| 142 | + } |
| 143 | + Sort(obj_probs, class_id, indexArray); |
| 144 | + |
| 145 | + Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT, |
| 146 | + TensorShape{int(filter_boxes.size() / 4), 4}, "dets"}); |
| 147 | + std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>()); |
| 148 | + NMS(dets, iou_threshold_, indexArray); |
| 149 | + |
| 150 | + Detections objs; |
| 151 | + std::vector<float> scale_factor; |
| 152 | + if (prep_res.contains("scale_factor")) { |
| 153 | + from_value(prep_res["scale_factor"], scale_factor); |
| 154 | + } else { |
| 155 | + scale_factor = {1.f, 1.f, 1.f, 1.f}; |
| 156 | + } |
| 157 | + int ori_width = prep_res["ori_shape"][2].get<int>(); |
| 158 | + int ori_height = prep_res["ori_shape"][1].get<int>(); |
| 159 | + auto det_ptr = dets.data<float>(); |
| 160 | + for (int i = 0; i < indexArray.size(); ++i) { |
| 161 | + if (indexArray[i] == -1) { |
| 162 | + continue; |
| 163 | + } |
| 164 | + int j = indexArray[i]; |
| 165 | + auto x1 = clamp(det_ptr[j * 4 + 0], 0, model_in_w); |
| 166 | + auto y1 = clamp(det_ptr[j * 4 + 1], 0, model_in_h); |
| 167 | + auto x2 = clamp(det_ptr[j * 4 + 2], 0, model_in_w); |
| 168 | + auto y2 = clamp(det_ptr[j * 4 + 3], 0, model_in_h); |
| 169 | + int label_id = class_id[i]; |
| 170 | + float score = obj_probs[i]; |
| 171 | + |
| 172 | + MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score); |
| 173 | + |
| 174 | + auto rect = MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height); |
| 175 | + if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) { |
| 176 | + MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}", rect[2] - rect[0], |
| 177 | + rect[3] - rect[1]); |
| 178 | + continue; |
| 179 | + } |
| 180 | + Detection det{}; |
| 181 | + det.index = i; |
| 182 | + det.label_id = label_id; |
| 183 | + det.score = score; |
| 184 | + det.bbox = rect; |
| 185 | + objs.push_back(std::move(det)); |
| 186 | + } |
| 187 | + |
| 188 | + return objs; |
| 189 | +} |
| 190 | + |
| 191 | +Result<Value> YOLOV3Head::operator()(const Value& prep_res, const Value& infer_res) { |
| 192 | + return YOLOHead::operator()(prep_res, infer_res); |
| 193 | +} |
| 194 | + |
| 195 | +std::array<float, 4> YOLOV3Head::yolo_decode(float box_x, float box_y, float box_w, float box_h, |
| 196 | + float stride, |
| 197 | + const std::vector<std::vector<float>>& anchor, int j, |
| 198 | + int i, int a) const { |
| 199 | + box_x = (box_x + j) * stride; |
| 200 | + box_y = (box_y + i) * stride; |
| 201 | + box_w = expf(box_w) * anchor[a][0]; |
| 202 | + box_h = expf(box_h) * anchor[a][1]; |
| 203 | + return std::array<float, 4>{box_x, box_y, box_w, box_h}; |
| 204 | +} |
| 205 | + |
| 206 | +Result<Value> YOLOV5Head::operator()(const Value& prep_res, const Value& infer_res) { |
| 207 | + return YOLOHead::operator()(prep_res, infer_res); |
| 208 | +} |
| 209 | + |
| 210 | +std::array<float, 4> YOLOV5Head::yolo_decode(float box_x, float box_y, float box_w, float box_h, |
| 211 | + float stride, |
| 212 | + const std::vector<std::vector<float>>& anchor, int j, |
| 213 | + int i, int a) const { |
| 214 | + box_x = box_x * 2 - 0.5; |
| 215 | + box_y = box_y * 2 - 0.5; |
| 216 | + box_w = box_w * 2 - 0.5; |
| 217 | + box_h = box_h * 2 - 0.5; |
| 218 | + box_x = (box_x + j) * stride; |
| 219 | + box_y = (box_y + i) * stride; |
| 220 | + box_w = box_w * box_w * anchor[a][0]; |
| 221 | + box_h = box_h * box_h * anchor[a][1]; |
| 222 | + return std::array<float, 4>{box_x, box_y, box_w, box_h}; |
| 223 | +} |
| 224 | + |
| 225 | +MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV3Head); |
| 226 | +MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV5Head); |
| 227 | + |
| 228 | +} // namespace mmdeploy::mmdet |
0 commit comments