Commit 4c02b54

[visualize-convnet] Add new example: visualize-convnet (#201)
This TensorFlow.js example demonstrates some techniques of visualizing the internal workings of a convolutional neural network (convnet), including:

- Finding what convolutional layers' filters are sensitive to after training: calculating maximally-activating input images for convolutional filters through gradient ascent in the input space.
- Getting the internal activations of a convnet by using the functional model API of TensorFlow.js.
- Finding which part of an input image is most relevant to the classification decision made by a convnet (VGG16 in this case), using the gradient-based class activation map (CAM) approach.

Example screenshots:

![image](https://user-images.githubusercontent.com/16824702/50791933-5826b380-1291-11e9-82ec-feb883078dc7.png)
![image](https://user-images.githubusercontent.com/16824702/50791951-61b01b80-1291-11e9-9fe7-6e0534cdf7ba.png)
![image](https://user-images.githubusercontent.com/16824702/50791966-6f65a100-1291-11e9-9941-aaf6fe59902b.png)
1 parent 18d5b2f commit 4c02b54

20 files changed: +8309 -0 lines changed

visualize-convnet/.babelrc

+18
@@ -0,0 +1,18 @@
{
  "presets": [
    [
      "env",
      {
        "esmodules": false,
        "targets": {
          "browsers": [
            "> 3%"
          ]
        }
      }
    ]
  ],
  "plugins": [
    "transform-runtime"
  ]
}

visualize-convnet/.gitignore

+4
@@ -0,0 +1,4 @@
*.png
model.json
group*shard*
*.h5

visualize-convnet/README.md

+66
@@ -0,0 +1,66 @@
# TensorFlow.js Example: Visualizing Convnet Filters

## Description

This TensorFlow.js example demonstrates some techniques of visualizing
the internal workings of a convolutional neural network (convnet), including:

- Finding what convolutional layers' filters are sensitive to after
  training: calculating maximally-activating input images for
  convolutional filters through gradient ascent in the input space.
- Getting the internal activations of a convnet by using the
  functional model API of TensorFlow.js.
- Finding which part of an input image is most relevant to the
  classification decision made by a convnet (VGG16 in this case),
  using the gradient-based class activation map (CAM) approach.

## How to use this demo

Run the command:
```sh
yarn visualize
```

This will automatically:

1. install the necessary Python dependencies. If the required
   Python packages (keras, tensorflow and tensorflowjs) are already installed,
   this step will be a no-op. However, to prevent this step from
   modifying your global Python environment, you may run this demo from
   a [virtualenv](https://virtualenv.pypa.io/en/latest/) or
   [pipenv](https://pipenv.readthedocs.io/en/latest/).
2. download and convert the VGG16 model to TensorFlow.js format
3. launch a Node.js script to load the converted model and compute
   the maximally-activating input images for the convnet's filters
   using gradient ascent in the input space and save them as image
   files under the `dist/filters` directory
4. launch a Node.js script to calculate the internal convolutional
   layers' activations and the gradient-based class activation
   map (CAM) and save them as image files under the
   `dist/activation` directory.
5. compile and launch the web view using parcel

Steps 3 and 4 (especially step 3) involve relatively heavy computation
and are best done using tfjs-node-gpu instead of the default
tfjs-node. This requires that a CUDA-enabled GPU and the necessary
driver and libraries are installed on your system.

Assuming those prerequisites are met, do:

```sh
yarn visualize --gpu
```

You may also increase the number of filters to visualize per convolutional
layer from the default 8 to a larger value, e.g., 32:

```sh
yarn visualize --gpu --filters 32
```

The default image used for the internal-activation and CAM visualization is
"owl.jpg". You can switch to another image by using the "--image" flag, e.g.,

```sh
yarn visualize --image dog.jpg
```

visualize-convnet/cam.js

+129
@@ -0,0 +1,129 @@
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * This script contains a function that performs the following operations:
 *
 * Get visual interpretation of which parts of the image are most
 *   responsible for a convnet's classification decision, using the
 *   gradient-based class activation map (CAM) method.
 *   See function `gradClassActivationMap()`.
 */

const tf = require('@tensorflow/tfjs');
const utils = require('./utils');

/**
 * Calculate class activation map (CAM) and overlay it on input image.
 *
 * This function automatically finds the last convolutional layer, gets its
 * output (activation) under the input image, weights its filters by the
 * gradient of the class output with respect to them, and then collapses along
 * the filter dimension.
 *
 * @param {tf.Sequential} model A TensorFlow.js sequential model, assumed to
 *   contain at least one convolutional layer.
 * @param {number} classIndex Index to class in the model's final classification
 *   output.
 * @param {tf.Tensor4D} x Input image, assumed to have shape
 *   `[1, height, width, 3]`.
 * @param {number} overlayFactor Optional overlay factor.
 * @returns The input image with a heat-map representation of the class
 *   activation map overlaid on top of it, as float32-type `tf.Tensor4D` of
 *   shape `[1, height, width, 3]`.
 */
function gradClassActivationMap(model, classIndex, x, overlayFactor = 2.0) {
  // Try to locate the last conv layer of the model.
  let layerIndex = model.layers.length - 1;
  while (layerIndex >= 0) {
    if (model.layers[layerIndex].getClassName().startsWith('Conv')) {
      break;
    }
    layerIndex--;
  }
  tf.util.assert(
      layerIndex >= 0, `Failed to find a convolutional layer in model`);

  const lastConvLayer = model.layers[layerIndex];
  console.log(
      `Located last convolutional layer of the model at ` +
      `index ${layerIndex}: layer type = ${lastConvLayer.getClassName()}; ` +
      `layer name = ${lastConvLayer.name}`);

  // Get "sub-model 1", which goes from the original input to the output
  // of the last convolutional layer.
  const lastConvLayerOutput = lastConvLayer.output;
  const subModel1 =
      tf.model({inputs: model.inputs, outputs: lastConvLayerOutput});

  // Get "sub-model 2", which goes from the output of the last convolutional
  // layer to the original output.
  const newInput = tf.input({shape: lastConvLayerOutput.shape.slice(1)});
  layerIndex++;
  let y = newInput;
  while (layerIndex < model.layers.length) {
    y = model.layers[layerIndex++].apply(y);
  }
  const subModel2 = tf.model({inputs: newInput, outputs: y});

  return tf.tidy(() => {
    // This function runs sub-model 2 and extracts the slice of the probability
    // output that corresponds to the desired class.
    const convOutput2ClassOutput = (input) =>
        subModel2.apply(input, {training: true}).gather([classIndex], 1);
    // This is the gradient function of the output corresponding to the desired
    // class with respect to its input (i.e., the output of the last
    // convolutional layer of the original model).
    const gradFunction = tf.grad(convOutput2ClassOutput);

    // Calculate the values of the last conv layer's output.
    const lastConvLayerOutputValues = subModel1.apply(x);
    // Calculate the values of gradients of the class output w.r.t. the output
    // of the last convolutional layer.
    const gradValues = gradFunction(lastConvLayerOutputValues);

    // Pool the gradient values within each filter of the last convolutional
    // layer, resulting in a tensor of shape [numFilters].
    const pooledGradValues = tf.mean(gradValues, [0, 1, 2]);
    // Scale the convolutional layer's output by the pooled gradients, using
    // broadcasting.
    const scaledConvOutputValues =
        lastConvLayerOutputValues.mul(pooledGradValues);

    // Create heat map by averaging and collapsing over all filters.
    let heatMap = scaledConvOutputValues.mean(-1);

    // Discard negative values from the heat map and normalize it to the [0, 1]
    // interval.
    heatMap = heatMap.relu();
    heatMap = heatMap.div(heatMap.max()).expandDims(-1);

    // Up-sample the heat map to the size of the input image.
    heatMap = tf.image.resizeBilinear(heatMap, [x.shape[1], x.shape[2]]);

    // Apply an RGB colormap on the heatMap. This step is necessary because
    // the heatMap is a 1-channel (grayscale) image. It needs to be converted
    // into a color (RGB) one through this function call.
    heatMap = utils.applyColorMap(heatMap);

    // To form the final output, overlay the color heat map on the input image.
    heatMap = heatMap.mul(overlayFactor).add(x.div(255));
    return heatMap.div(heatMap.max()).mul(255);
  });
}

module.exports = {gradClassActivationMap};
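
The weighting arithmetic at the core of `gradClassActivationMap()` can be followed on a tiny example. The block below is a minimal sketch in plain JavaScript, not the committed code: it replaces tensors with nested `[height][width][channels]` arrays, and the helper name `gradCamHeatMap` and the toy values are hypothetical. It covers only the steps from gradient pooling to normalization, and leaves out up-sampling, color mapping and the overlay.

```javascript
// Plain-JavaScript sketch of the Grad-CAM weighting step (no TensorFlow.js).
// `activations` and `grads` are nested arrays of shape
// [height][width][channels]; all names here are hypothetical.
function gradCamHeatMap(activations, grads) {
  const height = activations.length;
  const width = activations[0].length;
  const channels = activations[0][0].length;

  // 1. Pool the gradients: one mean value per channel (filter).
  const pooled = new Array(channels).fill(0);
  for (let i = 0; i < height; ++i) {
    for (let j = 0; j < width; ++j) {
      for (let c = 0; c < channels; ++c) {
        pooled[c] += grads[i][j][c] / (height * width);
      }
    }
  }

  // 2. Weight each channel by its pooled gradient, average over channels,
  //    and discard negative values (ReLU).
  const heatMap = [];
  let maxValue = 0;
  for (let i = 0; i < height; ++i) {
    const row = [];
    for (let j = 0; j < width; ++j) {
      let sum = 0;
      for (let c = 0; c < channels; ++c) {
        sum += activations[i][j][c] * pooled[c];
      }
      const value = Math.max(sum / channels, 0);
      maxValue = Math.max(maxValue, value);
      row.push(value);
    }
    heatMap.push(row);
  }

  // 3. Normalize to [0, 1]. The guard against an all-zero map is an addition
  //    of this sketch; it is not part of the committed code.
  return heatMap.map(row => row.map(v => (maxValue > 0 ? v / maxValue : 0)));
}

// Toy 2x2 feature map with 2 channels. Channel 0 has positive gradients and
// channel 1 has negative gradients everywhere.
const activations = [
  [[1, 0], [2, 1]],
  [[0, 3], [4, 0]],
];
const grads = [
  [[1, -1], [1, -1]],
  [[1, -1], [1, -1]],
];
const heatMap = gradCamHeatMap(activations, grads);
console.log(heatMap);
```

In this toy input, the position where only the negatively weighted channel is active ends up at zero after the ReLU, and the position with the strongest positively weighted activation ends up at one.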

visualize-convnet/cat.jpg

40.1 KB

visualize-convnet/dog.jpg

75.8 KB

visualize-convnet/elephants.jpg

186 KB

visualize-convnet/filters.js

+174
@@ -0,0 +1,174 @@
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Algorithms for analyzing and visualizing the convolutional filters
 * internal to a convnet.
 *
 * 1. Retrieving internal activations of a convnet.
 *    See function `writeInternalActivationAndGetOutput()`.
 * 2. Calculate maximally-activating input image for convnet filters, using
 *    gradient ascent in input space.
 *    See function `inputGradientAscent()`.
 **/

const path = require('path');
const tf = require('@tensorflow/tfjs');
const utils = require('./utils');

/**
 * Write internal activation of conv layers to file; Get model output.
 *
 * @param {tf.Model} model The model of interest.
 * @param {string[]} layerNames Names of layers of interest.
 * @param {tf.Tensor4D} inputImage The input image represented as a 4D tensor
 *   of shape [1, height, width, 3].
 * @param {number} numFilters Number of filters to run for each convolutional
 *   layer. If it exceeds the number of filters of a convolutional layer, it
 *   will be cut off.
 * @param {string} outputDir Path to the directory to which the image files
 *   representing the activation will be saved.
 * @return modelOutput: final output of the model as a tf.Tensor.
 *     layerName2FilePaths: an object mapping layer name to the paths to the
 *       image files saved for the layer's activation.
 *     layerName2ImageDims: an object mapping layer name to the height
 *       and width of the layer's filter outputs.
 */
async function writeInternalActivationAndGetOutput(
    model, layerNames, inputImage, numFilters, outputDir) {
  const layerName2FilePaths = {};
  const layerName2ImageDims = {};
  const layerOutputs =
      layerNames.map(layerName => model.getLayer(layerName).output);

  // Construct a model that returns all the desired internal activations,
  // in addition to the final output of the original model.
  const compositeModel = tf.model(
      {inputs: model.input, outputs: layerOutputs.concat(model.outputs[0])});

  // `outputs` is an array of `tf.Tensor`s consisting of the internal-activation
  // values and the final output value.
  const outputs = compositeModel.predict(inputImage);

  for (let i = 0; i < outputs.length - 1; ++i) {
    const layerName = layerNames[i];
    // Split the activation of the convolutional layer by filter.
    const activationTensors =
        tf.split(outputs[i], outputs[i].shape[outputs[i].shape.length - 1], -1);
    const actualNumFilters = numFilters <= activationTensors.length ?
        numFilters :
        activationTensors.length;
    const filePaths = [];
    let imageTensorShape;
    for (let j = 0; j < actualNumFilters; ++j) {
      // Format activation tensors and write them to disk.
      const imageTensor = tf.tidy(
          () => deprocessImage(tf.tile(activationTensors[j], [1, 1, 1, 3])));
      const outputFilePath = path.join(outputDir, `${layerName}_${j + 1}.png`);
      filePaths.push(outputFilePath);
      await utils.writeImageTensorToFile(imageTensor, outputFilePath);
      imageTensorShape = imageTensor.shape;
    }
    layerName2FilePaths[layerName] = filePaths;
    layerName2ImageDims[layerName] = imageTensorShape.slice(1, 3);
    tf.dispose(activationTensors);
  }
  tf.dispose(outputs.slice(0, outputs.length - 1));
  return {
    modelOutput: outputs[outputs.length - 1],
    layerName2FilePaths,
    layerName2ImageDims
  };
}


/**
 * Generate the maximally-activating input image for a conv2d layer filter.
 *
 * Uses gradient ascent in input space.
 *
 * @param {tf.Model} model The model that the convolutional layer of interest
 *   belongs to.
 * @param {string} layerName Name of the convolutional layer.
 * @param {number} filterIndex Index to the filter of interest. Must be
 *   < number of filters of the conv2d layer.
 * @param {number} iterations Number of gradient-ascent iterations.
 * @return {tf.Tensor} The maximally-activating input image as a tensor.
 */
function inputGradientAscent(model, layerName, filterIndex, iterations = 40) {
  return tf.tidy(() => {
    const imageH = model.inputs[0].shape[1];
    const imageW = model.inputs[0].shape[2];
    const imageDepth = model.inputs[0].shape[3];

    // Create an auxiliary model of which input is the same as the original
    // model but the output is the output of the convolutional layer of
    // interest.
    const layerOutput = model.getLayer(layerName).output;
    const auxModel = tf.model({inputs: model.inputs, outputs: layerOutput});

    // This function calculates the value of the convolutional layer's
    // output at the designated filter index.
    const lossFunction = (input) =>
        auxModel.apply(input, {training: true}).gather([filterIndex], 3);

    // This function (`gradients`) calculates the gradient of the convolutional
    // filter's output with respect to the input image.
    const gradients = tf.grad(lossFunction);

    // Form a random image as the starting point of the gradient ascent.
    let image = tf.randomUniform([1, imageH, imageW, imageDepth], 0, 1)
                    .mul(20)
                    .add(128);

    for (let i = 0; i < iterations; ++i) {
      const scaledGrads = tf.tidy(() => {
        const grads = gradients(image);
        const norm =
            tf.sqrt(tf.mean(tf.square(grads))).add(tf.ENV.get('EPSILON'));
        // Important trick: scale the gradient with the magnitude (norm)
        // of the gradient.
        return grads.div(norm);
      });
      // Perform one step of gradient ascent: Update the image along the
      // direction of the gradient.
      image = tf.clipByValue(image.add(scaledGrads), 0, 255);
    }
    return deprocessImage(image);
  });
}

/** Center and scale input image so the pixel values fall into [0, 255]. */
function deprocessImage(x) {
  return tf.tidy(() => {
    const {mean, variance} = tf.moments(x);
    x = x.sub(mean);
    // Add a small positive number (EPSILON) to the denominator to prevent
    // division-by-zero.
    x = x.div(tf.sqrt(variance).add(tf.ENV.get('EPSILON')));
    // Shift by 0.5 and clip to [0, 1].
    x = x.add(0.5);
    x = tf.clipByValue(x, 0, 1);
    x = x.mul(255);
    return tf.clipByValue(x, 0, 255).asType('int32');
  });
}

module.exports = {
  inputGradientAscent,
  writeInternalActivationAndGetOutput
};
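
The "important trick" in `inputGradientAscent()`, dividing the gradient by its root-mean-square magnitude before each update, is easier to see on a small problem. The block below is a minimal sketch in plain JavaScript under stated assumptions: the two-variable objective, the `stepSize` parameter and the function name `normalizedGradientAscent` are hypothetical, and the analytic gradient stands in for `tf.grad()`. The committed code uses an implicit step size of 1 and also clips the image to `[0, 255]` after every step, which this sketch omits.

```javascript
// Plain-JavaScript sketch of gradient ascent with RMS-normalized gradients.
// The toy objective and all names here are hypothetical.
const EPSILON = 1e-7;

// Toy objective: f(x) = -(x0 - 3)^2 - (x1 + 1)^2, maximal at [3, -1].
const objective = x => -((x[0] - 3) ** 2) - (x[1] + 1) ** 2;
// Analytic gradient of the toy objective.
const gradient = x => [-2 * (x[0] - 3), -2 * (x[1] + 1)];

function normalizedGradientAscent(gradFn, start, iterations, stepSize) {
  let x = start.slice();
  for (let i = 0; i < iterations; ++i) {
    const grads = gradFn(x);
    // Root-mean-square magnitude of the gradient, plus a small constant to
    // avoid division by zero.
    const meanSquare = grads.reduce((sum, g) => sum + g * g, 0) / grads.length;
    const norm = Math.sqrt(meanSquare) + EPSILON;
    // Step along the gradient direction. After normalization the step length
    // no longer depends on the raw gradient magnitude.
    x = x.map((value, k) => value + stepSize * grads[k] / norm);
  }
  return x;
}

const start = [0, 0];
const result = normalizedGradientAscent(gradient, start, 100, 0.1);
console.log(result, objective(start), objective(result));
```

Because every step has roughly the same length, the iterate moves steadily toward the maximum and then hovers within about one step length of it, rather than stalling where raw gradients are tiny or overshooting where they are large.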
