Commit 811fa20

Added struct2depth model
1 parent 4d4eb85 commit 811fa20

15 files changed: +3864 -0 lines changed

CODEOWNERS

+1
```diff
@@ -48,6 +48,7 @@
 /research/slim/ @sguada @nathansilberman
 /research/steve/ @buckman-google
 /research/street/ @theraysmith
+/research/struct2depth/ @aneliaangelova
 /research/swivel/ @waterson
 /research/syntaxnet/ @calberti @andorardo @bogatyy @markomernick
 /research/tcn/ @coreylynch @sermanet
```

research/README.md

+1
```diff
@@ -74,6 +74,7 @@ request.
 - [slim](slim): image classification models in TF-Slim.
 - [street](street): identify the name of a street (in France) from an image
   using a Deep RNN.
+- [struct2depth](struct2depth): unsupervised learning of depth and ego-motion.
 - [swivel](swivel): the Swivel algorithm for generating word embeddings.
 - [syntaxnet](syntaxnet): neural models of natural language syntax.
 - [tcn](tcn): Self-supervised representation learning from multi-view video.
```

research/struct2depth/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
package(default_visibility = ["//visibility:public"])

research/struct2depth/README.md

+147
# struct2depth

This is a method for unsupervised learning of depth and ego-motion from monocular video. It achieves new state-of-the-art results on both tasks by explicitly modeling 3D object motion, performing on-line refinement, and improving quality for moving objects through novel loss formulations. It will appear in the following paper:

**V. Casser, S. Pirk, R. Mahjourian, A. Angelova, Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos, AAAI Conference on Artificial Intelligence, 2019**

https://arxiv.org/pdf/1811.06152.pdf

This code is implemented and supported by Vincent Casser (git username: VincentCa) and Anelia Angelova (git username: AneliaAngelova). Please contact [email protected] for questions.

Project website: https://sites.google.com/view/struct2depth.
## Quick start: Running training

Before running training, run the gen_data_* script for the respective dataset in order to generate the data in the appropriate format for KITTI or Cityscapes. It is assumed that motion masks have already been generated and stored as images. Models are trained starting from an ImageNet-pretrained checkpoint.
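As a rough sketch of the data-generation step (a hypothetical invocation: the exact script file name is not shown in this view, and the Cityscapes generator included in this commit is configured through module-level constants rather than command-line flags):

```shell
# Edit the constants at the top of the generator script first, for example:
#   SUB_FOLDER = 'train'
#   INPUT_DIR  = '/path/to/CITYSCAPES_FULL/'
#   OUTPUT_DIR = '/path/to/CITYSCAPES_Processed/'
# then run it (script name assumed):
python gen_data_cityscapes.py
```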
Training can then be launched as follows:

```shell
ckpt_dir="your/checkpoint/folder"
data_dir="KITTI_SEQ2_LR/"        # Set for KITTI
data_dir="CITYSCAPES_SEQ2_LR/"   # Set for Cityscapes
imagenet_ckpt="resnet_pretrained/model.ckpt"

python train.py \
  --logtostderr \
  --checkpoint_dir $ckpt_dir \
  --data_dir $data_dir \
  --architecture resnet \
  --imagenet_ckpt $imagenet_ckpt \
  --imagenet_norm true \
  --joint_encoder false
```
## Running depth/egomotion inference on an image folder

KITTI models are trained on the raw image data resized to 416 x 128, with inputs standardized before being fed to the network; Cityscapes images are additionally cropped using the cropping parameters (192, 1856, 256, 768). If a different crop is used, additional training is likely necessary, so please follow the inference example shown below when using one of the models. The right choice of model can depend on a variety of factors. For example, if a checkpoint should be used for odometry, be aware that for improved odometry on motion models, using segmentation masks can be advantageous (set *use_masks=true* for inference). On the other hand, all models can be used for single-frame depth estimation without any additional information.
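For illustration, one plausible reading of the Cityscapes preprocessing is sketched below. Interpreting (192, 1856, 256, 768) as (x_start, x_end, y_start, y_end) is an assumption, chosen because it yields a 1664 x 512 region with the same aspect ratio as the 416 x 128 network input; the function name is hypothetical.

```python
import cv2

def preprocess_cityscapes(image_path):
  """Crops a 2048x1024 Cityscapes frame and resizes it to 416x128 (sketch)."""
  img = cv2.imread(image_path)
  cropped = img[256:768, 192:1856]        # 512 x 1664 crop (rows, then columns)
  return cv2.resize(cropped, (416, 128))  # cv2.resize takes (width, height)
```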
```shell
input_dir="your/image/folder"
output_dir="your/output/folder"
model_checkpoint="your/model/checkpoint"

python inference.py \
  --logtostderr \
  --file_extension png \
  --depth \
  --egomotion true \
  --input_dir $input_dir \
  --output_dir $output_dir \
  --model_ckpt $model_checkpoint
```
Note that the egomotion prediction expects the files in the input directory to form a consecutive sequence, and that sorting the filenames alphabetically puts them in the right order.

One can also run inference on KITTI by providing

```shell
--input_list_file ~/kitti-raw-uncompressed/test_files_eigen.txt
```

and on Cityscapes by passing

```shell
--input_list_file CITYSCAPES_FULL/test_files_cityscapes.txt
```

instead of *input_dir*. Alternatively, inference can also be run on pre-processed images.
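As a side note on the ordering requirement: the data-generation script in this commit writes zero-padded frame names (str(ct).zfill(10)), so alphabetical and temporal order coincide. A minimal illustration of why the padding matters (the file names here are made up):

```python
padded = ['{:010d}.png'.format(i) for i in (1, 2, 10)]
assert sorted(padded) == padded  # zero-padded names keep temporal order

unpadded = ['1.png', '2.png', '10.png']
assert sorted(unpadded) == ['1.png', '10.png', '2.png']  # '10' sorts before '2'
```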
## Running on-line refinement

On-line refinement is executed on top of an existing inference folder, so make sure to run regular inference first. Then you can run the on-line fusion procedure as follows:

```shell
prediction_dir="some/prediction/dir"
model_ckpt="checkpoints/checkpoints_baseline/model-199160"
handle_motion="false"
size_constraint_weight="0"  # This must be zero when not handling motion.

# If running on KITTI, set as follows:
data_dir="KITTI_SEQ2_LR_EIGEN/"
triplet_list_file="$data_dir/test_files_eigen_triplets.txt"
triplet_list_file_remains="$data_dir/test_files_eigen_triplets_remains.txt"
ft_name="kitti"

# If running on Cityscapes, set as follows:
data_dir="CITYSCAPES_SEQ2_LR_TEST/"
triplet_list_file="$data_dir/test_files_cityscapes_triplets.txt"
triplet_list_file_remains="$data_dir/test_files_cityscapes_triplets_remains.txt"
ft_name="cityscapes"

python optimize.py \
  --logtostderr \
  --output_dir $prediction_dir \
  --data_dir $data_dir \
  --triplet_list_file $triplet_list_file \
  --triplet_list_file_remains $triplet_list_file_remains \
  --ft_name $ft_name \
  --model_ckpt $model_ckpt \
  --file_extension png \
  --handle_motion $handle_motion \
  --size_constraint_weight $size_constraint_weight
```
## Running evaluation

```shell
prediction_dir="some/prediction/dir"

# Use these settings for KITTI:
eval_list_file="KITTI_FULL/kitti-raw-uncompressed/test_files_eigen.txt"
eval_crop="garg"
eval_mode="kitti"

# Use these settings for Cityscapes:
eval_list_file="CITYSCAPES_FULL/test_files_cityscapes.txt"
eval_crop="none"
eval_mode="cityscapes"

python evaluate.py \
  --logtostderr \
  --prediction_dir $prediction_dir \
  --eval_list_file $eval_list_file \
  --eval_crop $eval_crop \
  --eval_mode $eval_mode
```
## Credits

This code is implemented and supported by Vincent Casser and Anelia Angelova and can be found at
https://sites.google.com/view/struct2depth.
The core implementation is derived from [vid2depth](https://github.com/tensorflow/models/tree/master/research/vid2depth)
by [Reza Mahjourian]([email protected]), which in turn is based on [SfMLearner](https://github.com/tinghuiz/SfMLearner)
by [Tinghui Zhou](https://github.com/tinghuiz).

research/struct2depth/alignment.py

+54
```python
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Common utilities for data pre-processing, e.g. matching moving objects across frames."""

import numpy as np


def compute_overlap(mask1, mask2):
  """Computes the intersection-over-union (IoU) of two boolean masks."""
  # Cast to float so the ratio is not truncated under integer division.
  return np.sum(mask1 & mask2) / float(np.sum(mask1 | mask2))


def align(seg_img1, seg_img2, seg_img3, threshold_same=0.3):
  """Aligns object segment IDs across a triplet of segmentation images.

  For every segment in seg_img1, the best-overlapping segment in seg_img2 and
  then in seg_img3 is found; if both overlaps exceed threshold_same, the three
  segments receive the same ID (the one from seg_img1) in the output images.
  """
  res_img1 = np.zeros_like(seg_img1)
  res_img2 = np.zeros_like(seg_img2)
  res_img3 = np.zeros_like(seg_img3)
  remaining_objects2 = list(np.unique(seg_img2.flatten()))
  remaining_objects3 = list(np.unique(seg_img3.flatten()))
  for seg_id in np.unique(seg_img1):
    # See if we can find correspondences to seg_id in seg_img2.
    max_overlap2 = float('-inf')
    max_segid2 = -1
    for seg_id2 in remaining_objects2:
      overlap = compute_overlap(seg_img1 == seg_id, seg_img2 == seg_id2)
      if overlap > max_overlap2:
        max_overlap2 = overlap
        max_segid2 = seg_id2
    if max_overlap2 > threshold_same:
      max_overlap3 = float('-inf')
      max_segid3 = -1
      for seg_id3 in remaining_objects3:
        overlap = compute_overlap(seg_img2 == max_segid2, seg_img3 == seg_id3)
        if overlap > max_overlap3:
          max_overlap3 = overlap
          max_segid3 = seg_id3
      if max_overlap3 > threshold_same:
        res_img1[seg_img1 == seg_id] = seg_id
        res_img2[seg_img2 == max_segid2] = seg_id
        res_img3[seg_img3 == max_segid3] = seg_id
        remaining_objects2.remove(max_segid2)
        remaining_objects3.remove(max_segid3)
  return res_img1, res_img2, res_img3
```
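A minimal usage sketch of align (the toy arrays below are hypothetical; in this commit the function is driven by the Cityscapes data-generation script shown next, which calls it on triplets of consecutive segmentation frames):

```python
import numpy as np
from alignment import align

# Three toy instance-segmentation maps; 0 is background, other values are object IDs.
seg1 = np.zeros((4, 4), dtype=np.uint8); seg1[1:3, 1:3] = 5
seg2 = np.zeros((4, 4), dtype=np.uint8); seg2[1:3, 2:4] = 9  # same object, new ID, shifted
seg3 = np.zeros((4, 4), dtype=np.uint8); seg3[2:4, 2:4] = 7

a1, a2, a3 = align(seg1, seg2, seg3)
# Wherever the pairwise overlaps exceed the threshold, all three outputs
# reuse the ID from seg1 (here 5), so the object is tracked across the triplet.
```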
+158
```python
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Offline data generation for the Cityscapes dataset."""

import glob
import os

from absl import app
from absl import flags
from absl import logging
import cv2
import numpy as np

from alignment import align
from alignment import compute_overlap


SKIP = 2
WIDTH = 416
HEIGHT = 128
SUB_FOLDER = 'train'
INPUT_DIR = '/usr/local/google/home/anelia/struct2depth/CITYSCAPES_FULL/'
OUTPUT_DIR = '/usr/local/google/home/anelia/struct2depth/CITYSCAPES_Processed/'


def crop(img, segimg, fx, fy, cx, cy):
  # Perform center cropping, preserving 50% vertically.
  middle_perc = 0.50
  left = 1 - middle_perc
  half = left / 2
  a = img[int(img.shape[0] * half):int(img.shape[0] * (1 - half)), :]
  aseg = segimg[int(segimg.shape[0] * half):int(segimg.shape[0] * (1 - half)), :]
  cy /= (1 / middle_perc)

  # Resize to match target height while preserving aspect ratio.
  wdt = int((float(HEIGHT) * a.shape[1] / a.shape[0]))
  x_scaling = float(wdt) / a.shape[1]
  y_scaling = float(HEIGHT) / a.shape[0]
  b = cv2.resize(a, (wdt, HEIGHT))
  bseg = cv2.resize(aseg, (wdt, HEIGHT))

  # Adjust intrinsics.
  fx *= x_scaling
  fy *= y_scaling
  cx *= x_scaling
  cy *= y_scaling

  # Perform center cropping horizontally.
  remain = b.shape[1] - WIDTH
  cx /= (b.shape[1] / WIDTH)
  c = b[:, int(remain / 2):b.shape[1] - int(remain / 2)]
  cseg = bseg[:, int(remain / 2):b.shape[1] - int(remain / 2)]

  return c, cseg, fx, fy, cx, cy


def run_all():
  dir_name = INPUT_DIR + '/leftImg8bit_sequence/' + SUB_FOLDER + '/*'
  print('Processing directory', dir_name)
  for location in glob.glob(dir_name):
    location_name = os.path.basename(location)
    print('Processing location', location_name)
    files = sorted(glob.glob(location + '/*.png'))
    files = [file for file in files if '-seg.png' not in file]
    # Break down into sequences of consecutive frames.
    sequences = {}
    seq_nr = 0
    last_seq = ''
    last_imgnr = -1

    for i in range(len(files)):
      seq = os.path.basename(files[i]).split('_')[1]
      nr = int(os.path.basename(files[i]).split('_')[2])
      if seq != last_seq or last_imgnr + 1 != nr:
        seq_nr += 1
      last_imgnr = nr
      last_seq = seq
      if seq_nr not in sequences:
        sequences[seq_nr] = []
      sequences[seq_nr].append(files[i])

    for (k, v) in sequences.items():
      print('Processing sequence', k, 'with', len(v), 'elements...')
      output_dir = OUTPUT_DIR + '/' + location_name + '_' + str(k)
      if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
      files = sorted(v)
      triplet = []
      seg_triplet = []
      ct = 1

      # Find applicable intrinsics.
      for j in range(len(files)):
        osegname = os.path.basename(files[j]).split('_')[1]
        oimgnr = os.path.basename(files[j]).split('_')[2]
        applicable_intrinsics = (INPUT_DIR + '/camera/' + SUB_FOLDER + '/' +
                                 location_name + '/' + location_name + '_' +
                                 osegname + '_' + oimgnr + '_camera.json')
        # Get the intrinsics from one of the files of the sequence.
        if os.path.isfile(applicable_intrinsics):
          f = open(applicable_intrinsics, 'r')
          lines = f.readlines()
          f.close()
          lines = [line.rstrip() for line in lines]

          fx = float(lines[11].split(': ')[1].replace(',', ''))
          fy = float(lines[12].split(': ')[1].replace(',', ''))
          cx = float(lines[13].split(': ')[1].replace(',', ''))
          cy = float(lines[14].split(': ')[1].replace(',', ''))

      for j in range(0, len(files), SKIP):
        img = cv2.imread(files[j])
        segimg = cv2.imread(files[j].replace('.png', '-seg.png'))

        smallimg, segimg, fx_this, fy_this, cx_this, cy_this = crop(
            img, segimg, fx, fy, cx, cy)
        triplet.append(smallimg)
        seg_triplet.append(segimg)
        if len(triplet) == 3:
          # Write the image triplet, the aligned segmentation triplet and the
          # adjusted intrinsics side by side.
          cmb = np.hstack(triplet)
          align1, align2, align3 = align(seg_triplet[0], seg_triplet[1],
                                         seg_triplet[2])
          cmb_seg = np.hstack([align1, align2, align3])
          cv2.imwrite(os.path.join(output_dir, str(ct).zfill(10) + '.png'), cmb)
          cv2.imwrite(os.path.join(output_dir, str(ct).zfill(10) + '-fseg.png'),
                      cmb_seg)
          f = open(os.path.join(output_dir, str(ct).zfill(10) + '_cam.txt'), 'w')
          f.write(str(fx_this) + ',0.0,' + str(cx_this) + ',0.0,' +
                  str(fy_this) + ',' + str(cy_this) + ',0.0,0.0,1.0')
          f.close()
          del triplet[0]
          del seg_triplet[0]
        ct += 1

  # Create file list for training. Be careful as it collects and includes all
  # files recursively.
  fn = open(OUTPUT_DIR + '/' + SUB_FOLDER + '.txt', 'w')
  for f in glob.glob(OUTPUT_DIR + '/*/*.png'):
    if '-seg.png' in f or '-fseg.png' in f:
      continue
    folder_name = f.split('/')[-2]
    img_name = f.split('/')[-1].replace('.png', '')
    fn.write(folder_name + ' ' + img_name + '\n')
  fn.close()


def main(_):
  run_all()


if __name__ == '__main__':
  app.run(main)
```
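The `_cam.txt` files written above contain the camera intrinsics as nine comma-separated values in row-major order, i.e. the matrix [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]. A small sketch of reading one back (the helper name load_intrinsics is hypothetical, not part of this commit):

```python
import numpy as np

def load_intrinsics(cam_txt_path):
  """Reads a _cam.txt file written by the generator back into a 3x3 matrix."""
  with open(cam_txt_path) as f:
    values = [float(v) for v in f.read().split(',')]
  # Row-major layout: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
  return np.array(values).reshape(3, 3)
```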
