
Commit a6428c2

feat(lite): add FaceParsingBiSeNet model ORT/MNN C++ (#332)

1 parent a02d1d0

12 files changed: +516 -17 lines

examples/lite/cv/test_lite_face_parsing_bisenet.cpp

+40
@@ -77,12 +77,52 @@ static void test_mnn()
 static void test_ncnn()
 {
 #ifdef ENABLE_NCNN
+  std::string proto_path = "../../../hub/ncnn/cv/face_parsing_512x512.opt.param";
+  std::string bin_path = "../../../hub/ncnn/cv/face_parsing_512x512.opt.bin";
+  std::string test_img_path = "../../../examples/lite/resources/test_lite_face_parsing.png";
+  std::string save_img_path = "../../../logs/test_lite_face_parsing_bisenet_ncnn.jpg";
+
+  lite::ncnn::cv::segmentation::FaceParsingBiSeNet *face_parsing_bisenet =
+      new lite::ncnn::cv::segmentation::FaceParsingBiSeNet(
+          proto_path, bin_path, 4, 512, 512);
+
+  lite::types::FaceParsingContent content;
+  cv::Mat img_bgr = cv::imread(test_img_path);
+  face_parsing_bisenet->detect(img_bgr, content);
+
+  if (content.flag)
+  {
+    if (!content.merge.empty()) cv::imwrite(save_img_path, content.merge);
+    std::cout << "NCNN Version FaceParsingBiSeNet Done!" << std::endl;
+  }
+
+  delete face_parsing_bisenet;
 #endif
 }

 static void test_tnn()
 {
 #ifdef ENABLE_TNN
+  std::string proto_path = "../../../hub/ncnn/cv/face_parsing_512x512.opt.tnnproto";
+  std::string model_path = "../../../hub/ncnn/cv/face_parsing_512x512.opt.tnnmodel";
+  std::string test_img_path = "../../../examples/lite/resources/test_lite_face_parsing.png";
+  std::string save_img_path = "../../../logs/test_lite_face_parsing_bisenet_tnn.jpg";
+
+  lite::tnn::cv::segmentation::FaceParsingBiSeNet *face_parsing_bisenet =
+      new lite::tnn::cv::segmentation::FaceParsingBiSeNet(
+          proto_path, model_path, 4);
+
+  lite::types::FaceParsingContent content;
+  cv::Mat img_bgr = cv::imread(test_img_path);
+  face_parsing_bisenet->detect(img_bgr, content);
+
+  if (content.flag)
+  {
+    if (!content.merge.empty()) cv::imwrite(save_img_path, content.merge);
+    std::cout << "TNN Version FaceParsingBiSeNet Done!" << std::endl;
+  }
+
+  delete face_parsing_bisenet;
 #endif
 }
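
This hunk only shows the new NCNN and TNN test cases; the default ONNXRuntime path named in the commit title follows the same pattern elsewhere in this test file. A minimal sketch, assuming the default typedef lives at lite::cv::segmentation::FaceParsingBiSeNet, the usual (onnx_path, num_threads) constructor, and an ONNX model under hub/onnx/cv/ (all assumptions, not shown in this diff):

static void test_default()
{
  // Illustration only: the paths and typedef below are assumed, not taken from this hunk.
  std::string onnx_path = "../../../hub/onnx/cv/face_parsing_512x512.onnx";
  std::string test_img_path = "../../../examples/lite/resources/test_lite_face_parsing.png";
  std::string save_img_path = "../../../logs/test_lite_face_parsing_bisenet.jpg";

  lite::cv::segmentation::FaceParsingBiSeNet *face_parsing_bisenet =
      new lite::cv::segmentation::FaceParsingBiSeNet(onnx_path, 4); // 4 threads

  lite::types::FaceParsingContent content;
  cv::Mat img_bgr = cv::imread(test_img_path);
  face_parsing_bisenet->detect(img_bgr, content);

  if (content.flag)
  {
    if (!content.merge.empty()) cv::imwrite(save_img_path, content.merge);
    std::cout << "Default Version FaceParsingBiSeNet Done!" << std::endl;
  }

  delete face_parsing_bisenet;
}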

examples/lite/cv/test_lite_modnet.cpp

+2 -2

@@ -143,7 +143,7 @@ static void test_ncnn()
 {
 #ifdef ENABLE_NCNN
   std::string proto_path = "../../../hub/ncnn/cv/modnet_photographic_portrait_matting-512x512.opt.param";
-  std::string model_path = "../../../hub/ncnn/cv/modnet_photographic_portrait_matting-512x512.opt.bin";
+  std::string bin_path = "../../../hub/ncnn/cv/modnet_photographic_portrait_matting-512x512.opt.bin";
   std::string test_img_path = "../../../examples/lite/resources/test_lite_matting_input.jpg";
   std::string test_bgr_path = "../../../examples/lite/resources/test_lite_matting_bgr.jpg";
   std::string save_fgr_path = "../../../logs/test_lite_modnet_fgr_ncnn.jpg";
@@ -152,7 +152,7 @@ static void test_ncnn()
   std::string save_swap_path = "../../../logs/test_lite_modnet_swap_ncnn.jpg";

   lite::ncnn::cv::matting::MODNet *modnet =
-      new lite::ncnn::cv::matting::MODNet(proto_path, model_path, 16, 512, 512); // 16 threads
+      new lite::ncnn::cv::matting::MODNet(proto_path, bin_path, 16, 512, 512); // 16 threads

   lite::types::MattingContent content;
   cv::Mat img_bgr = cv::imread(test_img_path);

lite/mnn/cv/mnn_face_parsing_bisenet.cpp

-1

@@ -3,7 +3,6 @@
 //

 #include "mnn_face_parsing_bisenet.h"
-#include "lite/utils.h"

 using mnncv::MNNFaceParsingBiSeNet;
lite/models.h

+4
@@ -278,6 +278,7 @@
 #include "lite/ncnn/cv/ncnn_modnet.h"
 #include "lite/ncnn/cv/ncnn_female_photo2cartoon.h"
 #include "lite/ncnn/cv/ncnn_yolov6.h"
+#include "lite/ncnn/cv/ncnn_face_parsing_bisenet.h"

 #endif

@@ -358,6 +359,7 @@
 #include "lite/tnn/cv/tnn_head_seg.h"
 #include "lite/tnn/cv/tnn_female_photo2cartoon.h"
 #include "lite/tnn/cv/tnn_yolov6.h"
+#include "lite/tnn/cv/tnn_face_parsing_bisenet.h"

 #endif

@@ -1324,6 +1326,7 @@ namespace lite
 {
   typedef ncnncv::NCNNDeepLabV3ResNet101 DeepLabV3ResNet101;
   typedef ncnncv::NCNNFCNResNet101 FCNResNet101;
+  typedef ncnncv::NCNNFaceParsingBiSeNet FaceParsingBiSeNet;
 }
 // reid
 namespace reid
@@ -1474,6 +1477,7 @@ namespace lite
   typedef tnncv::TNNDeepLabV3ResNet101 DeepLabV3ResNet101;
   typedef tnncv::TNNFCNResNet101 FCNResNet101;
   typedef tnncv::TNNHeadSeg HeadSeg;
+  typedef tnncv::TNNFaceParsingBiSeNet FaceParsingBiSeNet;
 }
 // reid
 namespace reid

lite/ncnn/core/ncnn_core.h

+6 -6

@@ -13,9 +13,9 @@
 namespace ncnncv
 {
   class LITE_EXPORTS NCNNNanoDet; // [0] * reference: https://github.com/RangiLyu/nanodet
-  class LITE_EXPORTS NCNNNanoDetEfficientNetLite; // [1] * reference: https://github.com/RangiLyu/nanodet
+  class LITE_EXPORTS NCNNNanoDetEfficientNetLite; // [1] * reference: https://github.com/RangiLyu/nanodet
   class LITE_EXPORTS NCNNNanoDetDepreciated; // [2] * reference: https://github.com/RangiLyu/nanodet
-  class LITE_EXPORTS NCNNNanoDetEfficientNetLiteDepreciated; // [3] * reference: https://github.com/RangiLyu/nanodet
+  class LITE_EXPORTS NCNNNanoDetEfficientNetLiteDepreciated; // [3] * reference: https://github.com/RangiLyu/nanodet
   class LITE_EXPORTS NCNNRobustVideoMatting; // [4] * reference: https://github.com/PeterL1n/RobustVideoMatting
   class LITE_EXPORTS NCNNYoloX; // [5] * reference: https://github.com/Megvii-BaseDetection/YOLOX
   class LITE_EXPORTS NCNNYOLOP; // [6] * reference: https://github.com/hustvl/YOLOP
@@ -51,11 +51,11 @@ namespace ncnncv
   class LITE_EXPORTS NCNNAgeGoogleNet; // [36] * reference: https://github.com/onnx/models/tree/master/vision/body_analysis/age_gender
   class LITE_EXPORTS NCNNGenderGoogleNet; // [37] * reference: https://github.com/onnx/models/tree/master/vision/body_analysis/age_gender
   class LITE_EXPORTS NCNNEmotionFerPlus; // [38] * reference: https://github.com/onnx/models/blob/master/vision/body_analysis/emotion_ferplus
-  class LITE_EXPORTS NCNNEfficientEmotion7; // [39] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
-  class LITE_EXPORTS NCNNEfficientEmotion8; // [40] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
+  class LITE_EXPORTS NCNNEfficientEmotion7; // [39] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
+  class LITE_EXPORTS NCNNEfficientEmotion8; // [40] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
   class LITE_EXPORTS NCNNMobileEmotion7; // [41] * reference: https://github.com/HSE-asavchenko/face-emotion-recognition
-  class LITE_EXPORTS NCNNEfficientNetLite4; // [42] * reference: https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4
-  class LITE_EXPORTS NCNNShuffleNetV2; // [43] * reference: https://github.com/onnx/models/blob/master/vision/classification/shufflenet
+  class LITE_EXPORTS NCNNEfficientNetLite4; // [42] * reference: https://github.com/onnx/models/blob/master/vision/classification/efficientnet-lite4
+  class LITE_EXPORTS NCNNShuffleNetV2; // [43] * reference: https://github.com/onnx/models/blob/master/vision/classification/shufflenet
   class LITE_EXPORTS NCNNDenseNet; // [44] * reference: https://pytorch.org/hub/pytorch_vision_densenet/
   class LITE_EXPORTS NCNNGhostNet; // [45] * reference:https://pytorch.org/hub/pytorch_vision_ghostnet/
   class LITE_EXPORTS NCNNHdrDNet; // [46] * reference: https://pytorch.org/hub/pytorch_vision_hardnet/

lite/ncnn/cv/ncnn_face_parsing_bisenet.cpp

+182
@@ -0,0 +1,182 @@
+//
+// Created by DefTruth on 2022/7/2.
+//
+
+#include "ncnn_face_parsing_bisenet.h"
+
+using ncnncv::NCNNFaceParsingBiSeNet;
+
+NCNNFaceParsingBiSeNet::NCNNFaceParsingBiSeNet(const std::string &_param_path,
+                                               const std::string &_bin_path,
+                                               unsigned int _num_threads,
+                                               unsigned int _input_height,
+                                               unsigned int _input_width) :
+    BasicNCNNHandler(_param_path, _bin_path, _num_threads),
+    input_height(_input_height), input_width(_input_width)
+{
+}
+
+void NCNNFaceParsingBiSeNet::transform(const cv::Mat &mat, ncnn::Mat &in)
+{
+  cv::Mat mat_rs;
+  cv::resize(mat, mat_rs, cv::Size(input_width, input_height));
+  // will do deepcopy inside ncnn
+  in = ncnn::Mat::from_pixels(mat_rs.data, ncnn::Mat::PIXEL_BGR2RGB, input_width, input_height);
+  in.substract_mean_normalize(mean_vals, norm_vals);
+}
+
+void NCNNFaceParsingBiSeNet::detect(const cv::Mat &mat, types::FaceParsingContent &content,
+                                    bool minimum_post_process)
+{
+  if (mat.empty()) return;
+
+  // 1. make input tensor
+  ncnn::Mat input;
+  this->transform(mat, input);
+  // 2. inference & extract
+  auto extractor = net->create_extractor();
+  extractor.set_light_mode(false); // default
+  extractor.set_num_threads(num_threads);
+  extractor.input("input", input);
+  // 3. generate mask
+  this->generate_mask(extractor, mat, content, minimum_post_process);
+}
+
+static inline uchar __argmax_find(float *mutable_ptr, const unsigned int &step)
+{
+  std::vector<float> logits(19, 0.f);
+  for (unsigned int i = 0; i < 19; ++i)
+    logits[i] = *(mutable_ptr + i * step);
+  uchar label = 0;
+  float max_logit = logits[0];
+  for (unsigned int i = 1; i < 19; ++i)
+  {
+    if (logits[i] > max_logit)
+    {
+      max_logit = logits[i];
+      label = (uchar) i;
+    }
+  }
+  return label;
+}
+
+static const uchar part_colors[20][3] = {
+    {255, 0, 0},
+    {255, 85, 0},
+    {255, 170, 0},
+    {255, 0, 85},
+    {255, 0, 170},
+    {0, 255, 0},
+    {85, 255, 0},
+    {170, 255, 0},
+    {0, 255, 85},
+    {0, 255, 170},
+    {0, 0, 255},
+    {85, 0, 255},
+    {170, 0, 255},
+    {0, 85, 255},
+    {0, 170, 255},
+    {255, 255, 0},
+    {255, 255, 85},
+    {255, 255, 170},
+    {255, 0, 255},
+    {255, 85, 255}
+};
+
+void NCNNFaceParsingBiSeNet::generate_mask(ncnn::Extractor &extractor, const cv::Mat &mat,
+                                           types::FaceParsingContent &content,
+                                           bool minimum_post_process)
+{
+  ncnn::Mat output;
+  extractor.extract("out", output);
+#ifdef LITENCNN_DEBUG
+  BasicNCNNHandler::print_shape(output, "out");
+#endif
+  const unsigned int h = mat.rows;
+  const unsigned int w = mat.cols;
+
+  const unsigned int out_h = input_height;
+  const unsigned int out_w = input_width;
+  const unsigned int channel_step = out_h * out_w;
+
+  float *output_ptr = (float *) output.data;
+  std::vector<uchar> elements(channel_step, 0); // allocate
+  for (unsigned int i = 0; i < channel_step; ++i)
+    elements[i] = __argmax_find(output_ptr + i, channel_step);
+
+  cv::Mat label(out_h, out_w, CV_8UC1, elements.data());
+
+  if (!minimum_post_process)
+  {
+    const uchar *label_ptr = label.data;
+    cv::Mat color_mat(out_h, out_w, CV_8UC3, cv::Scalar(255, 255, 255));
+    for (unsigned int i = 0; i < color_mat.rows; ++i)
+    {
+      cv::Vec3b *p = color_mat.ptr<cv::Vec3b>(i);
+      for (unsigned int j = 0; j < color_mat.cols; ++j)
+      {
+        if (label_ptr[i * out_w + j] == 0) continue;
+        p[j][0] = part_colors[label_ptr[i * out_w + j]][0];
+        p[j][1] = part_colors[label_ptr[i * out_w + j]][1];
+        p[j][2] = part_colors[label_ptr[i * out_w + j]][2];
+      }
+    }
+    if (out_h != h || out_w != w)
+      cv::resize(color_mat, color_mat, cv::Size(w, h));
+    cv::addWeighted(mat, 0.4, color_mat, 0.6, 0., content.merge);
+  }
+  if (out_h != h || out_w != w) cv::resize(label, label, cv::Size(w, h));
+
+  content.label = label;
+  content.flag = true;
+}
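
The __argmax_find helper above reads one pixel's logits by striding through ncnn's CHW output: the score for class c at pixel p sits at output_ptr[c * channel_step + p], with channel_step = out_h * out_w. A small self-contained sketch of the same indexing, written only as an illustration (names such as argmax_chw are not part of the commit):

#include <cstddef>
#include <iostream>
#include <vector>

// Illustration only: per-pixel argmax over a CHW logits buffer, mirroring the
// stride trick used by __argmax_find() in generate_mask(). Class c for pixel p
// lives at logits[c * plane + p], where plane = height * width.
static std::vector<unsigned char> argmax_chw(const std::vector<float> &logits,
                                             std::size_t num_classes, std::size_t plane)
{
  std::vector<unsigned char> labels(plane, 0);
  for (std::size_t p = 0; p < plane; ++p)
  {
    std::size_t best = 0;
    float best_logit = logits[p]; // class 0 at offset 0 * plane + p
    for (std::size_t c = 1; c < num_classes; ++c)
    {
      const float v = logits[c * plane + p];
      if (v > best_logit) { best_logit = v; best = c; }
    }
    labels[p] = static_cast<unsigned char>(best);
  }
  return labels;
}

int main()
{
  // 19 classes over a tiny 2x2 "image": class 3 wins at pixel 0, class 0 elsewhere.
  const std::size_t classes = 19, plane = 4;
  std::vector<float> logits(classes * plane, 0.f);
  logits[3 * plane + 0] = 5.f;
  const auto labels = argmax_chw(logits, classes, plane);
  std::cout << (int) labels[0] << " " << (int) labels[1] << std::endl; // prints "3 0"
  return 0;
}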

lite/ncnn/cv/ncnn_face_parsing_bisenet.h

+43
@@ -0,0 +1,43 @@
+//
+// Created by DefTruth on 2022/7/2.
+//
+
+#ifndef LITE_AI_TOOLKIT_NCNN_CV_NCNN_FACE_PARSING_BISENET_H
+#define LITE_AI_TOOLKIT_NCNN_CV_NCNN_FACE_PARSING_BISENET_H
+
+#include "lite/ncnn/core/ncnn_core.h"
+
+namespace ncnncv
+{
+  class LITE_EXPORTS NCNNFaceParsingBiSeNet : public BasicNCNNHandler
+  {
+  public:
+    explicit NCNNFaceParsingBiSeNet(const std::string &_param_path,
+                                    const std::string &_bin_path,
+                                    unsigned int _num_threads = 1,
+                                    unsigned int _input_height = 512,
+                                    unsigned int _input_width = 512);
+
+    ~NCNNFaceParsingBiSeNet() override = default;
+
+  private:
+    const int input_height;
+    const int input_width;
+    const float mean_vals[3] = {0.485f * 255.f, 0.456f * 255.f, 0.406f * 255.f}; // RGB
+    const float norm_vals[3] = {1.f / (0.229f * 255.f), 1.f / (0.224f * 255.f), 1.f / (0.225f * 255.f)};
+
+  private:
+    void transform(const cv::Mat &mat, ncnn::Mat &in) override;
+
+    void generate_mask(ncnn::Extractor &extractor,
+                       const cv::Mat &mat, types::FaceParsingContent &content,
+                       bool minimum_post_process = false);
+
+  public:
+    void detect(const cv::Mat &mat, types::FaceParsingContent &content,
+                bool minimum_post_process = false);
+
+  };
+}
+
+#endif //LITE_AI_TOOLKIT_NCNN_CV_NCNN_FACE_PARSING_BISENET_H

lite/ncnn/cv/ncnn_modnet.h

+1 -2

@@ -16,8 +16,7 @@ namespace ncnncv
                     const std::string &_bin_path,
                     unsigned int _num_threads = 1,
                     unsigned int _input_height = 512,
-                    unsigned int _input_width = 512
-    );
+                    unsigned int _input_width = 512);

 ~NCNNMODNet() override = default;
lite/ort/cv/face_parsing_bisenet.cpp

-1

@@ -4,7 +4,6 @@

 #include "face_parsing_bisenet.h"
 #include "lite/ort/core/ort_utils.h"
-#include "lite/utils.h"

 using ortcv::FaceParsingBiSeNet;