-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathknnclassifier.js
194 lines (163 loc) · 5.48 KB
/
knnclassifier.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
// Webcam capture element (p5.MediaElement); created in setup().
let video;
// Create a KNN classifier (ml5.js) that learns labeled pose examples.
const knnClassifier = ml5.KNNClassifier();
// PoseNet model handle; assigned in setup().
let poseNet;
// Latest PoseNet detections; refreshed by the 'pose' event in setup().
let poses = [];
// p5.js entry point: builds the canvas, starts the webcam capture,
// wires the UI buttons, and loads PoseNet so that the global `poses`
// array is refreshed on every new detection.
function setup() {
  const canvas = createCanvas(1280, 720);
  canvas.parent('videoContainer');

  video = createCapture(VIDEO);
  video.size(width, height);

  // Hook up all the UI button handlers.
  createButtons();

  // Single-detection PoseNet; modelReady fires once loading completes.
  poseNet = ml5.poseNet(video, modelReady);

  // Keep the global `poses` array in sync with the latest detections.
  poseNet.on('pose', (results) => {
    poses = results;
  });

  // The raw <video> element stays hidden; frames are drawn on the canvas.
  video.hide();
}
// p5.js draw loop: paint the current video frame, then overlay the
// detected keypoints and skeleton connections on top of it.
function draw() {
  image(video, 0, 0, width, height);
  drawKeypoints();
  drawSkeleton();
}
// Callback invoked once PoseNet has finished loading; updates the status text.
function modelReady() {
  select('#status').html('model Loaded');
}
// Add the pose in the current video frame to the classifier under `label`.
// Fix: the original dereferenced poses[0].pose unconditionally and threw a
// TypeError when a button was pressed before PoseNet produced any detection.
function addExample(label) {
  if (poses.length === 0 || !poses[0].pose) {
    console.warn('No pose detected yet; example not added');
    return;
  }
  // Convert the pose to a 2d array [[score0, x0, y0],...,[score16, x16, y16]]
  const poseArray = poses[0].pose.keypoints.map(p => [p.score, p.position.x, p.position.y]);
  // Add an example with a label to the classifier
  knnClassifier.addExample(poseArray, label);
  updateCounts();
}
// Classify the pose in the current video frame against the trained labels.
// Results (or errors) are delivered to the `gotResults` callback.
// Fix: guard against poses[0] being undefined before any detection arrives,
// which previously threw a TypeError.
function classify() {
  // Bail out if no examples have been trained yet.
  const numLabels = knnClassifier.getNumLabels();
  if (numLabels <= 0) {
    console.error('There are no examples in any label');
    return;
  }
  // Bail out if PoseNet has not detected a pose yet.
  if (poses.length === 0 || !poses[0].pose) {
    console.warn('No pose detected yet; nothing to classify');
    return;
  }
  // Convert the pose to a 2d array [[score0, x0, y0],...,[score16, x16, y16]]
  const poseArray = poses[0].pose.keypoints.map(p => [p.score, p.position.x, p.position.y]);
  // Ask the classifier which label these features belong to;
  // `gotResults` receives (err, result).
  knnClassifier.classify(poseArray, gotResults);
}
// Wire up every UI button to its handler.
// Fix: the original assigned button handles (buttonA, resetBtnA, ...) as
// implicit globals; they are only used inside this function, so they are
// now proper const locals.
function createButtons() {
  // "Add example" buttons: capture the current pose under the given label.
  const buttonA = select('#addClassA');
  buttonA.mousePressed(function() {
    addExample('A');
  });
  const buttonB = select('#addClassB');
  buttonB.mousePressed(function() {
    addExample('B');
  });

  // Reset buttons: clear all examples for one label.
  const resetBtnA = select('#resetA');
  resetBtnA.mousePressed(function() {
    clearLabel('A');
  });
  const resetBtnB = select('#resetB');
  resetBtnB.mousePressed(function() {
    clearLabel('B');
  });

  // Predict button: classify the current frame.
  const buttonPredict = select('#buttonPredict');
  buttonPredict.mousePressed(classify);

  // Clear-all button: wipe every label's examples.
  const buttonClearAll = select('#clearAll');
  buttonClearAll.mousePressed(clearAllLabels);

  // Load a saved classifier dataset.
  const buttonSetData = select('#load');
  buttonSetData.mousePressed(loadMyKNN);

  // Save the current classifier dataset.
  const buttonGetData = select('#save');
  buttonGetData.mousePressed(saveMyKNN);
}
// Display the classification result, then immediately classify the next
// frame so prediction runs continuously.
// Fix: the original logged errors but fell through to result.confidencesByLabel,
// which throws when the callback is invoked with an error (result is undefined);
// the error path now returns early, which also stops the classify loop instead
// of looping on a failing call.
function gotResults(err, result) {
  if (err) {
    console.error(err);
    return;
  }
  if (result.confidencesByLabel) {
    const confidences = result.confidencesByLabel;
    // result.label is the label that has the highest confidence
    if (result.label) {
      select('#result').html(result.label);
      select('#confidence').html(`${confidences[result.label] * 100} %`);
    }
    select('#confidenceA').html(`${confidences['A'] ? confidences['A'] * 100 : 0} %`);
    select('#confidenceB').html(`${confidences['B'] ? confidences['B'] * 100 : 0} %`);
  }
  // Kick off the next prediction to keep classifying continuously.
  classify();
}
// Refresh the per-label example counters shown in the UI.
function updateCounts() {
  const counts = knnClassifier.getCountByLabel();
  const countA = counts['A'] || 0;
  const countB = counts['B'] || 0;
  select('#exampleA').html(countA);
  select('#exampleB').html(countB);
}
// Remove every stored example for a single label, then refresh the counters.
function clearLabel(classLabel) {
  knnClassifier.clearLabel(classLabel);
  updateCounts();
}
// Wipe the classifier completely (all examples, all labels) and refresh the UI.
function clearAllLabels() {
  knnClassifier.clearAllLabels();
  updateCounts();
}
// Download the current classifier dataset as myKNNDataset.json.
function saveMyKNN() {
  knnClassifier.save('myKNNDataset');
}
// Load a previously saved dataset into the classifier, then refresh the counters.
function loadMyKNN() {
  knnClassifier.load('http://sugarsnap.co.uk/elements/myKNNDatasetFinal.json', updateCounts);
}
// Draw a red dot over every confidently detected body keypoint.
function drawKeypoints() {
  // Walk every detected pose...
  for (const detection of poses) {
    // ...and every keypoint (a body part such as rightArm or leftShoulder).
    for (const keypoint of detection.pose.keypoints) {
      // Only draw an ellipse if the keypoint probability exceeds 0.2.
      if (keypoint.score > 0.2) {
        fill(255, 0, 0);
        noStroke();
        ellipse(keypoint.position.x, keypoint.position.y, 10, 10);
      }
    }
  }
}
// Draw a red line for every body connection in each detected skeleton.
function drawSkeleton() {
  for (const detection of poses) {
    // Each skeleton entry is a pair of connected keypoints.
    for (const [partA, partB] of detection.skeleton) {
      stroke(255, 0, 0);
      line(partA.position.x, partA.position.y, partB.position.x, partB.position.y);
    }
  }
}