Skip to content

Commit 4dd4d48

Browse files
authored
Add rv1126 yolov3 support to sdk (#1280)
* add yolov3 head to SDK * add yolov5 head to SDK * fix export-info and lint, add reverse check * fix lint * fix export info for yolo heads * add output_names to partition_config * fix typo * config * normalize config * fix * refactor config * fix lint and doc * c++ form * resolve comments * fix CI * fix CI * fix CI * float strides anchors * refine pipeline of rknn-int8 * config * rename func * refactor * rknn wrapper dict and fix typo * rknn wrapper output update, mmcls use end2end type * fix typo
1 parent 522fcc0 commit 4dd4d48

35 files changed

+579
-147
lines changed

configs/_base_/backends/rknn.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
backend_config = dict(
22
type='rknn',
33
common_config=dict(
4-
mean_values=None, # [[103.53, 116.28, 123.675]],
5-
std_values=None, # [[57.375, 57.12, 58.395]],
64
target_platform='rv1126', # 'rk3588'
75
optimization_level=1),
8-
quantization_config=dict(do_quantization=False, dataset=None))
6+
quantization_config=dict(do_quantization=True, dataset=None))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']
2+
3+
onnx_config = dict(input_shape=[224, 224])
4+
codebase_config = dict(model_type='end2end')
5+
backend_config = dict(
6+
input_size_list=[[3, 224, 224]],
7+
quantization_config=dict(do_quantization=False))
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']
22

33
onnx_config = dict(input_shape=[224, 224])
4-
codebase_config = dict(model_type='rknn')
4+
codebase_config = dict(model_type='end2end')
55
backend_config = dict(input_size_list=[[3, 224, 224]])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']
2+
3+
onnx_config = dict(input_shape=[320, 320])
4+
5+
codebase_config = dict(model_type='rknn')
6+
7+
backend_config = dict(
8+
input_size_list=[[3, 320, 320]],
9+
quantization_config=dict(do_quantization=False))
10+
11+
# # yolov3, yolox for rknn-toolkit and rknn-toolkit2
12+
# partition_config = dict(
13+
# type='rknn', # the partition policy name
14+
# apply_marks=True, # should always be set to True
15+
# partition_cfg=[
16+
# dict(
17+
# save_file='model.onnx', # name to save the partitioned onnx
18+
# start=['detector_forward:input'], # [mark_name:input, ...]
19+
# end=['yolo_head:input'], # [mark_name:output, ...]
20+
# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names
21+
# ])
22+
23+
# # retinanet, ssd, fsaf for rknn-toolkit2
24+
# partition_config = dict(
25+
# type='rknn', # the partition policy name
26+
# apply_marks=True,
27+
# partition_cfg=[
28+
# dict(
29+
# save_file='model.onnx',
30+
# start='detector_forward:input',
31+
# end=['BaseDenseHead:output'],
32+
# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
33+
# [f'BaseDenseHead.loc.{i}' for i in range(5)])
34+
# ])

configs/mmdet/detection/detection_rknn_static-320x320.py renamed to configs/mmdet/detection/detection_rknn-int8_static-320x320.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,27 @@
66

77
backend_config = dict(input_size_list=[[3, 320, 320]])
88

9-
# # yolov3, yolox
9+
# # yolov3, yolox for rknn-toolkit and rknn-toolkit2
1010
# partition_config = dict(
1111
# type='rknn', # the partition policy name
1212
# apply_marks=True, # should always be set to True
1313
# partition_cfg=[
1414
# dict(
1515
# save_file='model.onnx', # name to save the partitioned onnx
1616
# start=['detector_forward:input'], # [mark_name:input, ...]
17-
# end=['yolo_head:input']) # [mark_name:output, ...]
17+
# end=['yolo_head:input'], # [mark_name:output, ...]
18+
# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names
1819
# ])
1920

20-
# # retinanet, ssd, fsaf
21+
# # retinanet, ssd, fsaf for rknn-toolkit2
2122
# partition_config = dict(
2223
# type='rknn', # the partition policy name
2324
# apply_marks=True,
2425
# partition_cfg=[
2526
# dict(
2627
# save_file='model.onnx',
2728
# start='detector_forward:input',
28-
# end=['BaseDenseHead:output'])
29+
# end=['BaseDenseHead:output'],
30+
# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
31+
# [f'BaseDenseHead.loc.{i}' for i in range(5)])
2932
# ])

configs/mmdet/detection/yolov3_partition_onnxruntime_static.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@
88
dict(
99
save_file='yolov3.onnx',
1010
start=['detector_forward:input'],
11-
end=['yolo_head:input'])
11+
end=['yolo_head:input'],
12+
output_names=[f'pred_maps.{i}' for i in range(3)])
1213
])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
_base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
2+
3+
onnx_config = dict(input_shape=[320, 320])
4+
5+
codebase_config = dict(model_type='rknn')
6+
7+
backend_config = dict(
8+
input_size_list=[[3, 320, 320]],
9+
quantization_config=dict(do_quantization=False))
+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
#include "yolo_head.h"
3+
4+
#include <math.h>
5+
6+
#include <algorithm>
7+
#include <numeric>
8+
9+
#include "mmdeploy/core/model.h"
10+
#include "mmdeploy/core/utils/device_utils.h"
11+
#include "mmdeploy/core/utils/formatter.h"
12+
#include "utils.h"
13+
14+
namespace mmdeploy::mmdet {
15+
16+
/**
 * @brief Builds the head from the pipeline configuration.
 *
 * When the config has a "params" section, the following keys are read from it:
 *   - "nms_pre" (default -1), "score_thr" (default 0.02), "min_bbox_size" (default 0)
 *   - "nms"."iou_threshold" (default 0.45, also used when "nms" is absent)
 *   - "anchor_generator"."base_sizes" and "anchor_generator"."strides", only when
 *     "anchor_generator" is present.
 * Without a "params" section every member keeps its in-class default.
 *
 * @param cfg module configuration; must provide cfg["context"]["model"].
 */
YOLOHead::YOLOHead(const Value& cfg) : MMDetection(cfg) {
  // Looked up exactly as before; the handle is not used any further in this constructor.
  auto model = cfg["context"]["model"].get<Model>();

  if (!cfg.contains("params")) {
    return;
  }

  const auto& params = cfg["params"];
  nms_pre_ = params.value("nms_pre", -1);
  score_thr_ = params.value("score_thr", 0.02f);
  min_bbox_size_ = params.value("min_bbox_size", 0);

  if (params.contains("nms")) {
    iou_threshold_ = params["nms"].value("iou_threshold", 0.45f);
  } else {
    iou_threshold_ = 0.45f;
  }

  if (params.contains("anchor_generator")) {
    const auto& generator = params["anchor_generator"];
    from_value(generator["base_sizes"], anchors_);
    from_value(generator["strides"], strides_);
  }
}
35+
36+
/**
 * @brief Turns the raw network outputs into a list of detections.
 *
 * Every tensor in @p infer_res is made available on the host device, the list of
 * prediction maps is put into the same order as the configured strides, and the
 * boxes are decoded by GetBBoxes() using prep_res["img_metas"].
 *
 * @param prep_res  preprocessing output; must contain "img_metas".
 * @param infer_res inference output; each element must hold a Tensor.
 * @return the detections converted to a Value, or an error status.
 */
Result<Value> YOLOHead::operator()(const Value& prep_res, const Value& infer_res) {
  MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
  try {
    const Device kHost{0, 0};
    std::vector<Tensor> pred_maps;
    for (auto iter = infer_res.begin(); iter != infer_res.end(); ++iter) {
      auto pred_map = iter->get<Tensor>();
      OUTCOME_TRY(auto host_map, MakeAvailableOnDevice(pred_map, kHost, stream()));
      pred_maps.push_back(host_map);
    }
    OUTCOME_TRY(stream().Wait());

    // Reorder pred_maps according to strides and anchors, mainly for rknpu yolov3.
    // The comparison needs two entries on both sides; without the strides_ check a
    // config lacking (or with a single-level) anchor_generator would read past the
    // end of strides_.
    if (pred_maps.size() > 1 && strides_.size() > 1) {
      const bool strides_ascending = strides_[0] < strides_[1];
      const bool widths_ascending = pred_maps[0].shape(3) < pred_maps[1].shape(3);
      // A larger stride goes with a narrower map, so both sequences growing (or both
      // not growing) in the same direction means the maps arrived in the opposite order.
      if (strides_ascending == widths_ascending) {
        std::reverse(pred_maps.begin(), pred_maps.end());
      }
    }

    OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], pred_maps));
    return to_value(result);
  } catch (...) {
    // Still reported as eFail to the caller, but no longer silently.
    MMDEPLOY_ERROR("YOLOHead: exception caught while post-processing the inference result");
    return Status(eFail);
  }
}
58+
59+
// Limits `val` to the range [min, max] and returns the result truncated to int.
// Anything that is not strictly greater than `min` (NaN included) yields `min`.
inline static int clamp(float val, int min, int max) {
  float bounded;
  if (val > min) {
    bounded = (val < max) ? val : static_cast<float>(max);
  } else {
    bounded = static_cast<float>(min);
  }
  return static_cast<int>(bounded);
}
62+
63+
// Logistic function 1 / (1 + e^-x); the division is carried out in double precision
// and the result is narrowed to float on return.
static float sigmoid(float x) {
  const double denominator = 1.0 + expf(-x);
  return static_cast<float>(1.0 / denominator);
}
64+
65+
// Inverse of the logistic function: returns -log(1 / y - 1), i.e. the logit of y.
static float unsigmoid(float y) {
  const float inverse_odds = static_cast<float>((1.0 / y) - 1.0);
  return -logf(inverse_odds);
}
66+
67+
int YOLOHead::YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor,
68+
int grid_h, int grid_w, int height, int width, int stride,
69+
std::vector<float>& boxes, std::vector<float>& obj_probs,
70+
std::vector<int>& class_id, float threshold) const {
71+
auto input = const_cast<float*>(feat_map.data<float>());
72+
auto prop_box_size = feat_map.shape(1) / anchor.size();
73+
const int kClasses = prop_box_size - 5;
74+
int valid_count = 0;
75+
int grid_len = grid_h * grid_w;
76+
float thres = unsigmoid(threshold);
77+
for (int a = 0; a < anchor.size(); a++) {
78+
for (int i = 0; i < grid_h; i++) {
79+
for (int j = 0; j < grid_w; j++) {
80+
float box_confidence = input[(prop_box_size * a + 4) * grid_len + i * grid_w + j];
81+
if (box_confidence >= thres) {
82+
int offset = (prop_box_size * a) * grid_len + i * grid_w + j;
83+
float* in_ptr = input + offset;
84+
85+
float box_x = sigmoid(*in_ptr);
86+
float box_y = sigmoid(in_ptr[grid_len]);
87+
float box_w = in_ptr[2 * grid_len];
88+
float box_h = in_ptr[3 * grid_len];
89+
auto box = yolo_decode(box_x, box_y, box_w, box_h, stride, anchor, j, i, a);
90+
91+
box_x = box[0];
92+
box_y = box[1];
93+
box_w = box[2];
94+
box_h = box[3];
95+
96+
box_x -= (box_w / 2.0);
97+
box_y -= (box_h / 2.0);
98+
boxes.push_back(box_x);
99+
boxes.push_back(box_y);
100+
boxes.push_back(box_x + box_w);
101+
boxes.push_back(box_y + box_h);
102+
103+
float max_class_probs = in_ptr[5 * grid_len];
104+
int max_class_id = 0;
105+
for (int k = 1; k < kClasses; ++k) {
106+
float prob = in_ptr[(5 + k) * grid_len];
107+
if (prob > max_class_probs) {
108+
max_class_id = k;
109+
max_class_probs = prob;
110+
}
111+
}
112+
obj_probs.push_back(sigmoid(max_class_probs) * sigmoid(box_confidence));
113+
class_id.push_back(max_class_id);
114+
valid_count++;
115+
}
116+
}
117+
}
118+
}
119+
return valid_count;
120+
}
121+
122+
/**
 * @brief Decodes all prediction maps, runs NMS and maps the survivors back to the
 *        original image.
 *
 * @param prep_res  image meta; "img_shape" and "ori_shape" are read at indices 1
 *                  (height) and 2 (width), "scale_factor" is optional and defaults
 *                  to {1, 1, 1, 1}.
 * @param pred_maps host-side prediction maps, pred_maps[i] belonging to
 *                  strides_[i] / anchors_[i].
 * @return the kept detections, or eFail when the maps do not match the configured
 *         strides/anchors or a stride is not positive.
 */
Result<Detections> YOLOHead::GetBBoxes(const Value& prep_res,
                                       const std::vector<Tensor>& pred_maps) const {
  // Each map is paired with strides_[i] and anchors_[i]; refuse to index past them.
  if (pred_maps.size() > strides_.size() || pred_maps.size() > anchors_.size()) {
    MMDEPLOY_ERROR("YOLOHead: {} prediction maps, but only {} strides and {} anchor groups",
                   pred_maps.size(), strides_.size(), anchors_.size());
    return Status(eFail);
  }

  std::vector<float> filter_boxes;
  std::vector<float> obj_probs;
  std::vector<int> class_id;

  const int model_in_h = prep_res["img_shape"][1].get<int>();
  const int model_in_w = prep_res["img_shape"][2].get<int>();

  for (size_t i = 0; i < pred_maps.size(); ++i) {
    const int stride = strides_[i];
    if (stride <= 0) {
      // Would otherwise be an integer division by zero (or a negative grid size).
      MMDEPLOY_ERROR("YOLOHead: invalid stride {} at level {}", strides_[i], i);
      return Status(eFail);
    }
    const int grid_h = model_in_h / stride;
    const int grid_w = model_in_w / stride;
    YOLOFeatDecode(pred_maps[i], anchors_[i], grid_h, grid_w, model_in_h, model_in_w, stride,
                   filter_boxes, obj_probs, class_id, score_thr_);
  }

  Detections objs;
  if (obj_probs.empty()) {
    // No candidate passed the score threshold; skip the zero-row tensor and NMS.
    return objs;
  }

  // indexArray starts as 0..N-1.
  // NOTE(review): the loop below reads scores/labels at position i and boxes at
  // indexArray[i], so Sort presumably reorders obj_probs and class_id in place while
  // permuting indexArray alongside them - confirm against utils.h.
  std::vector<int> indexArray(obj_probs.size());
  std::iota(indexArray.begin(), indexArray.end(), 0);
  Sort(obj_probs, class_id, indexArray);

  Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT,
                         TensorShape{int(filter_boxes.size() / 4), 4}, "dets"});
  std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>());
  NMS(dets, iou_threshold_, indexArray);

  std::vector<float> scale_factor;
  if (prep_res.contains("scale_factor")) {
    from_value(prep_res["scale_factor"], scale_factor);
  } else {
    scale_factor = {1.f, 1.f, 1.f, 1.f};
  }
  const int ori_width = prep_res["ori_shape"][2].get<int>();
  const int ori_height = prep_res["ori_shape"][1].get<int>();
  const auto det_ptr = dets.data<float>();
  const int num_candidates = static_cast<int>(indexArray.size());
  for (int i = 0; i < num_candidates; ++i) {
    // Entries equal to -1 are skipped (presumably how NMS marks suppressed boxes).
    if (indexArray[i] == -1) {
      continue;
    }
    const int j = indexArray[i];
    // Keep the corners inside the network input before mapping them back.
    auto x1 = clamp(det_ptr[j * 4 + 0], 0, model_in_w);
    auto y1 = clamp(det_ptr[j * 4 + 1], 0, model_in_h);
    auto x2 = clamp(det_ptr[j * 4 + 2], 0, model_in_w);
    auto y2 = clamp(det_ptr[j * 4 + 3], 0, model_in_h);
    const int label_id = class_id[i];
    const float score = obj_probs[i];

    MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);

    auto rect = MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height);
    if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
      MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}'", rect[2] - rect[0],
                     rect[3] - rect[1]);
      continue;
    }
    Detection det{};
    det.index = i;
    det.label_id = label_id;
    det.score = score;
    det.bbox = rect;
    objs.push_back(std::move(det));
  }

  return objs;
}
190+
191+
// YOLOv3 entry point: delegates to the shared YOLOHead pipeline, which calls back
// into YOLOV3Head::yolo_decode for the per-box arithmetic.
Result<Value> YOLOV3Head::operator()(const Value& prep_res, const Value& infer_res) {
  return this->YOLOHead::operator()(prep_res, infer_res);
}
194+
195+
std::array<float, 4> YOLOV3Head::yolo_decode(float box_x, float box_y, float box_w, float box_h,
196+
float stride,
197+
const std::vector<std::vector<float>>& anchor, int j,
198+
int i, int a) const {
199+
box_x = (box_x + j) * stride;
200+
box_y = (box_y + i) * stride;
201+
box_w = expf(box_w) * anchor[a][0];
202+
box_h = expf(box_h) * anchor[a][1];
203+
return std::array<float, 4>{box_x, box_y, box_w, box_h};
204+
}
205+
206+
// YOLOv5 entry point: delegates to the shared YOLOHead pipeline, which calls back
// into YOLOV5Head::yolo_decode for the per-box arithmetic.
Result<Value> YOLOV5Head::operator()(const Value& prep_res, const Value& infer_res) {
  return this->YOLOHead::operator()(prep_res, infer_res);
}
209+
210+
std::array<float, 4> YOLOV5Head::yolo_decode(float box_x, float box_y, float box_w, float box_h,
211+
float stride,
212+
const std::vector<std::vector<float>>& anchor, int j,
213+
int i, int a) const {
214+
box_x = box_x * 2 - 0.5;
215+
box_y = box_y * 2 - 0.5;
216+
box_w = box_w * 2 - 0.5;
217+
box_h = box_h * 2 - 0.5;
218+
box_x = (box_x + j) * stride;
219+
box_y = (box_y + i) * stride;
220+
box_w = box_w * box_w * anchor[a][0];
221+
box_h = box_h * box_h * anchor[a][1];
222+
return std::array<float, 4>{box_x, box_y, box_w, box_h};
223+
}
224+
225+
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV3Head);
226+
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV5Head);
227+
228+
} // namespace mmdeploy::mmdet
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
#ifndef MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_
3+
#define MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_
4+
5+
#include "mmdeploy/codebase/mmdet/mmdet.h"
6+
#include "mmdeploy/core/tensor.h"
7+
8+
namespace mmdeploy::mmdet {
9+
10+
class YOLOHead : public MMDetection {
11+
public:
12+
explicit YOLOHead(const Value& cfg);
13+
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
14+
int YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor,
15+
int grid_h, int grid_w, int height, int width, int stride,
16+
std::vector<float>& boxes, std::vector<float>& obj_probs,
17+
std::vector<int>& class_id, float threshold) const;
18+
Result<Detections> GetBBoxes(const Value& prep_res, const std::vector<Tensor>& pred_maps) const;
19+
virtual std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h,
20+
float stride,
21+
const std::vector<std::vector<float>>& anchor, int j,
22+
int i, int a) const = 0;
23+
24+
private:
25+
float score_thr_{0.4f};
26+
int nms_pre_{1000};
27+
float iou_threshold_{0.45f};
28+
int min_bbox_size_{0};
29+
std::vector<std::vector<std::vector<float>>> anchors_;
30+
std::vector<float> strides_;
31+
};
32+
33+
class YOLOV3Head : public YOLOHead {
34+
public:
35+
using YOLOHead::YOLOHead;
36+
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
37+
std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride,
38+
const std::vector<std::vector<float>>& anchor, int j, int i,
39+
int a) const override;
40+
};
41+
42+
class YOLOV5Head : public YOLOHead {
43+
public:
44+
using YOLOHead::YOLOHead;
45+
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
46+
std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride,
47+
const std::vector<std::vector<float>>& anchor, int j, int i,
48+
int a) const override;
49+
};
50+
51+
} // namespace mmdeploy::mmdet
52+
53+
#endif // MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_

0 commit comments

Comments
 (0)