Skip to content

Commit 997f787

Browse files
committed
[samples]: Samples using the Java API.
1 parent 364f96d commit 997f787

File tree

22 files changed

+2539
-0
lines changed

22 files changed

+2539
-0
lines changed

samples/java/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# TensorFlow for Java: Examples
2+
3+
Examples using the TensorFlow Java API.

samples/java/docker/Dockerfile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
FROM tensorflow/tensorflow:1.4.0
2+
WORKDIR /
3+
RUN apt-get update
4+
RUN apt-get -y install maven openjdk-8-jdk
5+
RUN mvn dependency:get -Dartifact=org.tensorflow:tensorflow:1.4.0
6+
RUN mvn dependency:get -Dartifact=org.tensorflow:proto:1.4.0
7+
CMD ["/bin/bash", "-l"]

samples/java/docker/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
Dockerfile for building an image suitable for running the Java examples.
2+
3+
Typical usage:
4+
5+
```
6+
docker build -t java-tensorflow .
7+
docker run -it --rm -v ${PWD}/..:/examples java-tensorflow
8+
```
9+
10+
That second command will pop you into a shell which has all
11+
the dependencies required to execute the scripts and Java
12+
examples.

samples/java/label_image/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
images
2+
src/main/resources
3+
target

samples/java/label_image/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Image Classification Example
2+
3+
1. Download the model:
4+
- If you have [TensorFlow 1.4+ for Python installed](https://www.tensorflow.org/install/),
5+
run `python ./download.py`
6+
- If not, but you have [docker](https://www.docker.com/get-docker) installed,
7+
run `download.sh`.
8+
9+
2. Compile [`LabelImage.java`](src/main/java/LabelImage.java):
10+
11+
```
12+
mvn compile
13+
```
14+
15+
3. Download some sample images:
16+
If you already have some images, great. Otherwise `download_sample_images.sh`
17+
gets a few.
18+
19+
3. Classify!
20+
21+
```
22+
mvn -q exec:java -Dexec.args="<path to image file>"
23+
```

samples/java/label_image/download.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Create an image classification graph.
2+
3+
Script to download a pre-trained image classifier and tweak it so that
4+
the model accepts raw bytes of an encoded image.
5+
6+
Doing so involves some model-specific normalization of an image.
7+
Ideally, this would have been part of the image classifier model,
8+
but the particular model being used didn't include this normalization,
9+
so this script does the necessary tweaking.
10+
"""
11+
12+
from __future__ import absolute_import
13+
from __future__ import division
14+
from __future__ import print_function
15+
16+
from six.moves import urllib
17+
import os
18+
import zipfile
19+
import tensorflow as tf
20+
21+
URL = 'https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip'
22+
LABELS_FILE = 'imagenet_comp_graph_label_strings.txt'
23+
GRAPH_FILE = 'tensorflow_inception_graph.pb'
24+
25+
GRAPH_INPUT_TENSOR = 'input:0'
26+
GRAPH_PROBABILITIES_TENSOR = 'output:0'
27+
28+
IMAGE_HEIGHT = 224
29+
IMAGE_WIDTH = 224
30+
MEAN = 117
31+
SCALE = 1
32+
33+
LOCAL_DIR = 'src/main/resources'
34+
35+
36+
def download():
37+
print('Downloading %s' % URL)
38+
zip_filename, _ = urllib.request.urlretrieve(URL)
39+
with zipfile.ZipFile(zip_filename) as zip:
40+
zip.extract(LABELS_FILE)
41+
zip.extract(GRAPH_FILE)
42+
os.rename(LABELS_FILE, os.path.join(LOCAL_DIR, 'labels.txt'))
43+
os.rename(GRAPH_FILE, os.path.join(LOCAL_DIR, 'graph.pb'))
44+
45+
46+
def create_graph_to_decode_and_normalize_image():
47+
"""See file docstring.
48+
49+
Returns:
50+
input: The placeholder to feed the raw bytes of an encoded image.
51+
y: A Tensor (the decoded, normalized image) to be fed to the graph.
52+
"""
53+
image = tf.placeholder(tf.string, shape=(), name='encoded_image_bytes')
54+
with tf.name_scope("preprocess"):
55+
y = tf.image.decode_image(image, channels=3)
56+
y = tf.cast(y, tf.float32)
57+
y = tf.expand_dims(y, axis=0)
58+
y = tf.image.resize_bilinear(y, (IMAGE_HEIGHT, IMAGE_WIDTH))
59+
y = (y - MEAN) / SCALE
60+
return (image, y)
61+
62+
63+
def patch_graph():
64+
"""Create graph.pb that applies the model in URL to raw image bytes."""
65+
with tf.Graph().as_default() as g:
66+
input_image, image_normalized = create_graph_to_decode_and_normalize_image()
67+
original_graph_def = tf.GraphDef()
68+
with open(os.path.join(LOCAL_DIR, 'graph.pb')) as f:
69+
original_graph_def.ParseFromString(f.read())
70+
softmax = tf.import_graph_def(
71+
original_graph_def,
72+
name='inception',
73+
input_map={GRAPH_INPUT_TENSOR: image_normalized},
74+
return_elements=[GRAPH_PROBABILITIES_TENSOR])
75+
# We're constructing a graph that accepts a single image (as opposed to a
76+
# batch of images), so might as well make the output be a vector of
77+
# probabilities, instead of a batch of vectors with batch size 1.
78+
output_probabilities = tf.squeeze(softmax, name='probabilities')
79+
# Overwrite the graph.
80+
with open(os.path.join(LOCAL_DIR, 'graph.pb'), 'w') as f:
81+
f.write(g.as_graph_def().SerializeToString())
82+
print('------------------------------------------------------------')
83+
print('MODEL GRAPH : graph.pb')
84+
print('LABELS : labels.txt')
85+
print('INPUT TENSOR : %s' % input_image.op.name)
86+
print('OUTPUT TENSOR: %s' % output_probabilities.op.name)
87+
88+
89+
if __name__ == '__main__':
90+
if not os.path.exists(LOCAL_DIR):
91+
os.makedirs(LOCAL_DIR)
92+
download()
93+
patch_graph()

samples/java/label_image/download.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
3+
DIR="$(cd "$(dirname "$0")" && pwd -P)"
4+
docker run -it -v ${DIR}:/x -w /x --rm tensorflow/tensorflow:1.4.0 python download.py
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
2+
DIR=$(dirname $0)
3+
mkdir -p ${DIR}/images
4+
cd ${DIR}/images
5+
6+
# Some random images
7+
curl -o "porcupine.jpg" -L "https://cdn.pixabay.com/photo/2014/11/06/12/46/porcupines-519145_960_720.jpg"
8+
curl -o "whale.jpg" -L "https://static.pexels.com/photos/417196/pexels-photo-417196.jpeg"
9+
curl -o "terrier1u.jpg" -L "https://upload.wikimedia.org/wikipedia/commons/3/34/Australian_Terrier_Melly_%282%29.JPG"
10+
curl -o "terrier2.jpg" -L "https://cdn.pixabay.com/photo/2014/05/13/07/44/yorkshire-terrier-343198_960_720.jpg"

samples/java/label_image/pom.xml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
<project>
2+
<modelVersion>4.0.0</modelVersion>
3+
<groupId>org.myorg</groupId>
4+
<artifactId>label-image</artifactId>
5+
<version>1.0-SNAPSHOT</version>
6+
<properties>
7+
<exec.mainClass>LabelImage</exec.mainClass>
8+
<!-- The sample code requires at least JDK 1.7. -->
9+
<!-- The maven compiler plugin defaults to a lower version -->
10+
<maven.compiler.source>1.7</maven.compiler.source>
11+
<maven.compiler.target>1.7</maven.compiler.target>
12+
</properties>
13+
<dependencies>
14+
<dependency>
15+
<groupId>org.tensorflow</groupId>
16+
<artifactId>tensorflow</artifactId>
17+
<version>1.4.0</version>
18+
</dependency>
19+
<!-- For ByteStreams.toByteArray: https://google.github.io/guava/releases/23.0/api/docs/com/google/common/io/ByteStreams.html -->
20+
<dependency>
21+
<groupId>com.google.guava</groupId>
22+
<artifactId>guava</artifactId>
23+
<version>23.6-jre</version>
24+
</dependency>
25+
</dependencies>
26+
</project>
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
import com.google.common.io.ByteStreams;
17+
import java.io.BufferedReader;
18+
import java.io.IOException;
19+
import java.io.InputStream;
20+
import java.io.InputStreamReader;
21+
import java.nio.file.Files;
22+
import java.nio.file.Path;
23+
import java.nio.file.Paths;
24+
import java.util.ArrayList;
25+
import java.util.List;
26+
import org.tensorflow.Graph;
27+
import org.tensorflow.Session;
28+
import org.tensorflow.Tensor;
29+
import org.tensorflow.Tensors;
30+
31+
/**
32+
* Simplified version of
33+
* https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
34+
*/
35+
public class LabelImage {
36+
public static void main(String[] args) throws Exception {
37+
if (args.length < 1) {
38+
System.err.println("USAGE: Provide a list of image filenames");
39+
System.exit(1);
40+
}
41+
final List<String> labels = loadLabels();
42+
try (Graph graph = new Graph();
43+
Session session = new Session(graph)) {
44+
graph.importGraphDef(loadGraphDef());
45+
46+
float[] probabilities = null;
47+
for (String filename : args) {
48+
byte[] bytes = Files.readAllBytes(Paths.get(filename));
49+
try (Tensor<String> input = Tensors.create(bytes);
50+
Tensor<Float> output =
51+
session
52+
.runner()
53+
.feed("encoded_image_bytes", input)
54+
.fetch("probabilities")
55+
.run()
56+
.get(0)
57+
.expect(Float.class)) {
58+
if (probabilities == null) {
59+
probabilities = new float[(int) output.shape()[0]];
60+
}
61+
output.copyTo(probabilities);
62+
int label = argmax(probabilities);
63+
System.out.printf(
64+
"%-30s --> %-15s (%.2f%% likely)\n",
65+
filename, labels.get(label), probabilities[label] * 100.0);
66+
}
67+
}
68+
}
69+
}
70+
71+
private static byte[] loadGraphDef() throws IOException {
72+
try (InputStream is = LabelImage.class.getClassLoader().getResourceAsStream("graph.pb")) {
73+
return ByteStreams.toByteArray(is);
74+
}
75+
}
76+
77+
private static ArrayList<String> loadLabels() throws IOException {
78+
ArrayList<String> labels = new ArrayList<String>();
79+
String line;
80+
final InputStream is = LabelImage.class.getClassLoader().getResourceAsStream("labels.txt");
81+
try (BufferedReader reader = new BufferedReader(new InputStreamReader(is))) {
82+
while ((line = reader.readLine()) != null) {
83+
labels.add(line);
84+
}
85+
}
86+
return labels;
87+
}
88+
89+
private static int argmax(float[] probabilities) {
90+
int best = 0;
91+
for (int i = 1; i < probabilities.length; ++i) {
92+
if (probabilities[i] > probabilities[best]) {
93+
best = i;
94+
}
95+
}
96+
return best;
97+
}
98+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
images
2+
labels
3+
models
4+
src/main/protobuf
5+
target
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Object Detection in Java
2+
3+
Example of using pre-trained models of the [TensorFlow Object Detection
4+
API](https://github.com/tensorflow/models/tree/master/research/object_detection)
5+
in Java.
6+
7+
## Quickstart
8+
9+
1. Download some metadata files:
10+
```
11+
./download.sh
12+
```
13+
14+
2. Download a model from the [object detection API model
15+
zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md).
16+
For example:
17+
```
18+
mkdir -p models
19+
curl -L \
20+
http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz \
21+
| tar -xz -C models/
22+
```
23+
24+
3. Locate the corresponding labels file in the `data/` directory.
25+
26+
3. Have some test images handy. For example:
27+
```
28+
mkdir -p images
29+
curl -L -o images/test.jpg \
30+
https://pixnio.com/free-images/people/mother-father-and-children-washing-dog-labrador-retriever-outside-in-the-fresh-air-725x483.jpg
31+
```
32+
33+
4. Compile and run!
34+
```
35+
mvn -q compile exec:java \
36+
-Dexec.args="models/ssd_inception_v2_coco_2017_11_17/saved_model labels/mscoco_label_map.pbtxt images/test.jpg"
37+
```
38+
39+
## Notes
40+
41+
- This example demonstrates the use of the TensorFlow [SavedModel
42+
format](https://www.tensorflow.org/programmers_guide/saved_model). If you have
43+
TensorFlow for Python installed, you could explore the model to get the names
44+
of the tensors using `saved_model_cli` command. For example:
45+
```
46+
saved_model_cli show --dir models/ssd_inception_v2_coco_2017_11_17/saved_model/ --all
47+
```
48+
49+
- The file in `src/main/object_detection/protos/` was generated using:
50+
51+
```
52+
./download.sh
53+
protoc -Isrc/main/protobuf --java_out=src/main/java src/main/protobuf/string_int_label_map.proto
54+
```
55+
56+
Where `protoc` was downloaded from
57+
https://github.com/google/protobuf/releases/tag/v3.5.1
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/bin/bash
2+
3+
set -ex
4+
5+
DIR="$(cd "$(dirname "$0")" && pwd -P)"
6+
cd "${DIR}"
7+
8+
# The protobuf file needed for mapping labels to human readable names.
9+
# From:
10+
# https://github.com/tensorflow/models/blob/f87a58c/research/object_detection/protos/string_int_label_map.proto
11+
mkdir -p src/main/protobuf
12+
curl -L -o src/main/protobuf/string_int_label_map.proto "https://raw.githubusercontent.com/tensorflow/models/f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/object_detection/protos/string_int_label_map.proto"
13+
14+
# Labels from:
15+
# https://github.com/tensorflow/models/tree/865c14c/research/object_detection/data
16+
mkdir -p labels
17+
curl -L -o labels/mscoco_label_map.pbtxt "https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/mscoco_label_map.pbtxt"
18+
curl -L -o labels/oid_bbox_trainable_label_map.pbtxt "https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/oid_bbox_trainable_label_map.pbtxt"

samples/java/object_detection/pom.xml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<project>
2+
<modelVersion>4.0.0</modelVersion>
3+
<groupId>org.myorg</groupId>
4+
<artifactId>detect-objects</artifactId>
5+
<version>1.0-SNAPSHOT</version>
6+
<properties>
7+
<exec.mainClass>DetectObjects</exec.mainClass>
8+
<!-- The sample code requires at least JDK 1.7. -->
9+
<!-- The maven compiler plugin defaults to a lower version -->
10+
<maven.compiler.source>1.7</maven.compiler.source>
11+
<maven.compiler.target>1.7</maven.compiler.target>
12+
</properties>
13+
<dependencies>
14+
<dependency>
15+
<groupId>org.tensorflow</groupId>
16+
<artifactId>tensorflow</artifactId>
17+
<version>1.4.0</version>
18+
</dependency>
19+
<dependency>
20+
<groupId>org.tensorflow</groupId>
21+
<artifactId>proto</artifactId>
22+
<version>1.4.0</version>
23+
</dependency>
24+
</dependencies>
25+
</project>

0 commit comments

Comments
 (0)