Skip to content

Update webcam-transfer-learning with tf.data.webcam API #267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 42 additions & 30 deletions webcam-transfer-learning/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
*/

import * as tf from '@tensorflow/tfjs';
import * as tfd from '@tensorflow/tfjs-data';

import {ControllerDataset} from './controller_dataset';
import * as ui from './ui';
import {Webcam} from './webcam';

// The number of classes we want to predict. In this example, we will be
// predicting 4 classes for up, down, left, and right.
const NUM_CLASSES = 4;

// A webcam class that generates Tensors from the images from the webcam.
const webcam = new Webcam(document.getElementById('webcam'));
// A webcam iterator that generates Tensors from the images from the webcam.
let webcam;

// The dataset object where we will store activations.
const controllerDataset = new ControllerDataset(NUM_CLASSES);
Expand All @@ -48,15 +48,15 @@ async function loadTruncatedMobileNet() {
// When the UI buttons are pressed, read a frame from the webcam and associate
// it with the class label given by the button. up, down, left, right are
// labels 0, 1, 2, 3 respectively.
ui.setExampleHandler(label => {
tf.tidy(() => {
const img = webcam.capture();
controllerDataset.addExample(truncatedMobileNet.predict(img), label);
ui.setExampleHandler(async label => {
let img = await getImage();

// Draw the preview thumbnail.
ui.drawThumb(img, label);
});
});
controllerDataset.addExample(truncatedMobileNet.predict(img), label);

// Draw the preview thumbnail.
ui.drawThumb(img, label);
img.dispose();
})

/**
* Sets up and trains the classifier.
Expand Down Expand Up @@ -129,32 +129,41 @@ let isPredicting = false;
async function predict() {
ui.isPredicting();
while (isPredicting) {
const predictedClass = tf.tidy(() => {
// Capture the frame from the webcam.
const img = webcam.capture();
// Capture the frame from the webcam.
const img = await getImage();

// Make a prediction through mobilenet, getting the internal activation of
// the mobilenet model, i.e., "embeddings" of the input images.
const embeddings = truncatedMobileNet.predict(img);
// Make a prediction through mobilenet, getting the internal activation of
// the mobilenet model, i.e., "embeddings" of the input images.
const embeddings = truncatedMobileNet.predict(img);

// Make a prediction through our newly-trained model using the embeddings
// from mobilenet as input.
const predictions = model.predict(embeddings);

// Returns the index with the maximum probability. This number corresponds
// to the class the model thinks is the most probable given the input.
return predictions.as1D().argMax();
});
// Make a prediction through our newly-trained model using the embeddings
// from mobilenet as input.
const predictions = model.predict(embeddings);

// Returns the index with the maximum probability. This number corresponds
// to the class the model thinks is the most probable given the input.
const predictedClass = predictions.as1D().argMax();
const classId = (await predictedClass.data())[0];
predictedClass.dispose();
img.dispose();

ui.predictClass(classId);
await tf.nextFrame();
}
ui.donePredicting();
}

/**
 * Captures one frame from the webcam and rescales it for the model.
 *
 * The captured frame gets a leading batch axis of size 1 (`expandDims(0)`),
 * is cast to float, divided by 127 and then shifted down by 1.
 * NOTE(review): assuming 8-bit pixel values (0-255), this yields roughly
 * [-1, 1]; 255 maps to ~1.008 rather than exactly 1 — confirm this is intended.
 *
 * @returns {Promise<tf.Tensor>} The batched, normalized image. The returned
 *     tensor is not disposed here; callers dispose it when they are done.
 */
async function getImage() {
  const img = await webcam.capture();
  try {
    // tf.tidy releases the intermediate tensors created by the chained ops and
    // keeps only the tensor returned from the callback.
    return tf.tidy(() => img.expandDims(0).toFloat().div(127).sub(1));
  } finally {
    // Release the raw frame on every path, including when preprocessing
    // throws, so a failed capture does not leak the tensor.
    img.dispose();
  }
}

document.getElementById('train').addEventListener('click', async () => {
ui.trainStatus('Training...');
await tf.nextFrame();
Expand All @@ -170,18 +179,21 @@ document.getElementById('predict').addEventListener('click', () => {

async function init() {
try {
await webcam.setup();
webcam = await tfd.webcam(document.getElementById('webcam'));
} catch (e) {
console.log(e);
document.getElementById('no-webcam').style.display = 'block';
}
truncatedMobileNet = await loadTruncatedMobileNet();

ui.init();

// Warm up the model. This uploads weights to the GPU and compiles the WebGL
// programs so the first time we collect data from the webcam it will be
// quick.
tf.tidy(() => truncatedMobileNet.predict(webcam.capture()));

ui.init();
const screenShot = await webcam.capture();
truncatedMobileNet.predict(screenShot.expandDims(0));
screenShot.dispose();
}

// Initialize the application.
Expand Down
2 changes: 1 addition & 1 deletion webcam-transfer-learning/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"node": ">=8.9.0"
},
"dependencies": {
"@tensorflow/tfjs": "^1.0.4",
"@tensorflow/tfjs": "^1.1.0",
"vega-embed": "^3.0.0"
},
"scripts": {
Expand Down
106 changes: 0 additions & 106 deletions webcam-transfer-learning/webcam.js

This file was deleted.

Loading