Skip to content

Commit 223ed93

Browse files
committed
adds test script
1 parent 934ec2a commit 223ed93

9 files changed

+499
-15
lines changed

download_dataset.sh

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
echo "Downloading dataset"
2-
wget http://neghvar.cs.umass.edu/public_data/parsenet/data.zip
3-
2+
#wget http://neghvar.cs.umass.edu/public_data/parsenet/data.zip
3+
wget http://neghvar.cs.umass.edu/public_data/parsenet/predictions.h5
44
echo "unzipping"
5-
unzip data.zip
5+
#unzip data.zip
6+
mkdir logs
7+
mkdir logs/results
8+
mkdir logs/results/parsenet_with_normals.pth
9+
mkdir logs/results/parsenet_with_normals.pth/results
10+
mv predictions.h5 logs/results/parsenet_with_normals.pth/results/predictions.h5

generate_predictions.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from open3d import *
2+
import h5py
3+
import sys
4+
import logging
5+
import json
6+
import os
7+
from shutil import copyfile
8+
import numpy as np
9+
import torch.optim as optim
10+
import torch.utils.data
11+
from torch.autograd import Variable
12+
from torch.optim.lr_scheduler import ReduceLROnPlateau
13+
from src.PointNet import PrimitivesEmbeddingDGCNGn
14+
from matplotlib import pyplot as plt
15+
from src.utils import visualize_uv_maps, visualize_fitted_surface
16+
from src.utils import chamfer_distance
17+
from read_config import Config
18+
from src.utils import fit_surface_sample_points
19+
from src.dataset_segments import Dataset
20+
from torch.utils.data import DataLoader
21+
from src.utils import chamfer_distance
22+
from src.segment_loss import EmbeddingLoss
23+
from src.segment_utils import cluster
24+
import time
25+
from src.segment_loss import (
26+
EmbeddingLoss,
27+
primitive_loss,
28+
evaluate_miou,
29+
)
30+
from src.segment_utils import to_one_hot, SIOU_matched_segments
31+
from src.utils import visualize_point_cloud_from_labels, visualize_point_cloud
32+
from src.dataset import generator_iter
33+
from src.mean_shift import MeanShift
34+
from src.segment_utils import SIOU_matched_segments
35+
from src.residual_utils import Evaluation
36+
import time
37+
from src.primitives import SaveParameters
38+
39+
# Use only one gpu.
40+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
41+
config = Config(sys.argv[1])
42+
if_normals = config.normals
43+
44+
userspace = ""
45+
Loss = EmbeddingLoss(margin=1.0)
46+
47+
if config.mode == 0:
48+
# Just using points for training
49+
model = PrimitivesEmbeddingDGCNGn(
50+
embedding=True,
51+
emb_size=128,
52+
primitives=True,
53+
num_primitives=10,
54+
loss_function=Loss.triplet_loss,
55+
mode=config.mode,
56+
num_channels=3,
57+
)
58+
elif config.mode == 5:
59+
# Using points and normals for training
60+
model = PrimitivesEmbeddingDGCNGn(
61+
embedding=True,
62+
emb_size=128,
63+
primitives=True,
64+
num_primitives=10,
65+
loss_function=Loss.triplet_loss,
66+
mode=config.mode,
67+
num_channels=6,
68+
)
69+
70+
saveparameters = SaveParameters()
71+
72+
model_bkp = model
73+
model_bkp.l_permute = np.arange(10000)
74+
model = torch.nn.DataParallel(model, device_ids=[0])
75+
model.cuda()
76+
77+
split_dict = {"train": config.num_train, "val": config.num_val, "test": config.num_test}
78+
ms = MeanShift()
79+
80+
dataset = Dataset(
81+
config.batch_size,
82+
config.num_train,
83+
config.num_val,
84+
config.num_test,
85+
normals=True,
86+
primitives=True,
87+
if_train_data=False,
88+
prefix=userspace
89+
)
90+
91+
get_test_data = dataset.get_test(align_canonical=True, anisotropic=False, if_normal_noise=True)
92+
93+
loader = generator_iter(get_test_data, int(1e10))
94+
get_test_data = iter(
95+
DataLoader(
96+
loader,
97+
batch_size=1,
98+
shuffle=False,
99+
collate_fn=lambda x: x,
100+
num_workers=0,
101+
pin_memory=False,
102+
)
103+
)
104+
105+
os.makedirs(userspace + "logs/results/{}/results/".format(config.pretrain_model_path), exist_ok=True)
106+
107+
evaluation = Evaluation()
108+
alt_gpu = 0
109+
model.eval()
110+
111+
iterations = 50
112+
quantile = 0.015
113+
114+
model.load_state_dict(
115+
torch.load(userspace + "logs/pretrained_models/" + config.pretrain_model_path)
116+
)
117+
test_res = []
118+
test_s_iou = []
119+
test_p_iou = []
120+
test_g_res = []
121+
test_s_res = []
122+
PredictedLabels = []
123+
PredictedPrims = []
124+
125+
for val_b_id in range(config.num_test // config.batch_size - 1):
126+
points_, labels, normals, primitives_ = next(get_test_data)[0]
127+
points = Variable(torch.from_numpy(points_.astype(np.float32))).cuda()
128+
normals = torch.from_numpy(normals).cuda()
129+
130+
# with torch.autograd.detect_anomaly():
131+
with torch.no_grad():
132+
if if_normals:
133+
input = torch.cat([points, normals], 2)
134+
embedding, primitives_log_prob, embed_loss = model(
135+
input.permute(0, 2, 1), torch.from_numpy(labels).cuda(), True
136+
)
137+
else:
138+
embedding, primitives_log_prob, embed_loss = model(
139+
points.permute(0, 2, 1), torch.from_numpy(labels).cuda(), True
140+
)
141+
pred_primitives = torch.max(primitives_log_prob[0], 0)[1].data.cpu().numpy()
142+
embedding = torch.nn.functional.normalize(embedding[0].T, p=2, dim=1)
143+
_, _, cluster_ids = evaluation.guard_mean_shift(
144+
embedding, quantile, iterations, kernel_type="gaussian"
145+
)
146+
weights = to_one_hot(cluster_ids, np.unique(cluster_ids.data.data.cpu().numpy()).shape[
147+
0])
148+
cluster_ids = cluster_ids.data.cpu().numpy()
149+
150+
s_iou, p_iou, _, _ = SIOU_matched_segments(
151+
labels[0],
152+
cluster_ids,
153+
pred_primitives,
154+
primitives_[0],
155+
weights,
156+
)
157+
# print(s_iou, p_iou)
158+
PredictedLabels.append(cluster_ids)
159+
PredictedPrims.append(pred_primitives)
160+
if val_b_id == 3:
161+
break
162+
163+
with h5py.File(userspace + "logs/results/{}/results/".format(config.pretrain_model_path) + "predictions.h5", "w") as hf:
164+
hf.create_dataset(name="seg_id", data=np.stack(PredictedLabels, 0))
165+
hf.create_dataset(name="pred_primitives", data=np.stack(PredictedPrims, 0))

readme.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,16 @@ python train_parsenet.py configs/config_parsenet.yml
8181
python train_parsenet.py configs/config_parsenet_normals.yml
8282
```
8383

84-
* To train ParseNet in an end to end manner (note that you need to first pretrain the above models), then specify the path to the trained model in `configs/config_parsenet_e2e.yml` (with 2 gpus). Further note that, this part of the training requires dynamic amount of gpu memory because a shape can have variable number of segment and corresponding number of fitting module. Training is done using Nvidia m40 (24 Gb gpu). Testing can be done using `test_parsenet.py`.
84+
* To train ParseNet in an end to end manner (note that you need to first pretrain the above models), then specify the path to the trained model in `configs/config_parsenet_e2e.yml` (with 2 gpus). Further note that, this part of the training requires dynamic amount of gpu memory because a shape can have variable number of segment and corresponding number of fitting module. Training is done using Nvidia m40 (24 Gb gpu).
8585

8686
```
8787
python train_parsenet_e2e.py configs/config_parsenet_e2e.yml
8888
```
8989

90+
* Testing can be done using `test.py`
91+
```
92+
python test.py 0 3998
93+
```
9094
------
9195

9296

src/dataset_segments.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def __init__(self,
1919
test_size=None,
2020
normals=False,
2121
primitives=False,
22-
if_train_data=True):
22+
if_train_data=True,
23+
prefix=""):
2324
"""
2425
Dataset of point cloud from ABC dataset.
2526
:param root_path:
@@ -34,7 +35,7 @@ def __init__(self,
3435
random_scale_point_cloud, rotate_point_cloud]
3536

3637
if if_train_data:
37-
with h5py.File("data/shapes/train_data.h5", "r") as hf:
38+
with h5py.File(prefix + "data/shapes/train_data.h5", "r") as hf:
3839
train_points = np.array(hf.get("points"))
3940
train_labels = np.array(hf.get("labels"))
4041
if normals:
@@ -51,15 +52,15 @@ def __init__(self,
5152
self.train_points = (train_points - means)
5253
self.train_labels = train_labels
5354

54-
with h5py.File("data/shapes/val_data.h5", "r") as hf:
55+
with h5py.File(prefix + "data/shapes/val_data.h5", "r") as hf:
5556
val_points = np.array(hf.get("points"))
5657
val_labels = np.array(hf.get("labels"))
5758
if normals:
5859
val_normals = np.array(hf.get("normals"))
5960
if primitives:
6061
val_primitives = np.array(hf.get("prim"))
6162

62-
with h5py.File("data/shapes/test_data.h5", "r") as hf:
63+
with h5py.File(prefix + "data/shapes/test_data.h5", "r") as hf:
6364
test_points = np.array(hf.get("points"))
6465
test_labels = np.array(hf.get("labels"))
6566
if normals:

src/fitting_utils.py

+110
Original file line numberDiff line numberDiff line change
@@ -708,3 +708,113 @@ def remove_outliers(points, viz=False):
708708
if viz:
709709
display_inlier_outlier(voxel_down_pcd, ind)
710710
return np.array(cl.points)
711+
712+
713+
def visualize_bit_mapping_shape(data_, weights, recon_points, parameters=None, bit_map=True, epsilon=0.05):
714+
# This steps basically gathers trimmed primitives and samples points and normals on trimmed surfaces.
715+
# TODO: better way to do it is to not tesellate but directly find the
716+
# grid point that are occupied.
717+
pred_meshes = []
718+
719+
for index, g in enumerate(data_):
720+
if (recon_points[index] is None):
721+
# corresponds to degenrate cases
722+
continue
723+
if isinstance(recon_points[index], np.ndarray):
724+
if recon_points[index].shape[0] == 0:
725+
continue
726+
727+
points, _, l, _, _, i = g
728+
if not isinstance(points, np.ndarray):
729+
points = points.data.cpu().numpy()
730+
731+
part_points = points
732+
733+
if l in [11]:
734+
# torus
735+
part_points = up_sample_points_torch_memory_efficient(torch.from_numpy(points).cuda(), 2).data.cpu().numpy()
736+
737+
if bit_map:
738+
if epsilon:
739+
e = epsilon
740+
else:
741+
e = 0.03
742+
pred_mesh = bit_mapping_points_torch(part_points, recon_points[index], e, 100, 60)
743+
744+
if l in [0, 9, 6, 7]:
745+
# closed bspline surface
746+
if not isinstance(recon_points[index], np.ndarray):
747+
recon_points_ = recon_points[index].data.cpu().numpy()[0]
748+
else:
749+
recon_points_ = recon_points[index]
750+
part_points = up_sample_points_torch_memory_efficient(torch.from_numpy(points).cuda(), 2).data.cpu().numpy()
751+
try:
752+
pred_mesh = tessalate_points_fast(recon_points_, 31, 30)
753+
except:
754+
import ipdb;
755+
ipdb.set_trace()
756+
757+
if bit_map:
758+
if epsilon:
759+
e = epsilon
760+
else:
761+
e = 0.06
762+
pred_mesh = bit_mapping_points_torch(part_points, recon_points_, e, 31, 30)
763+
764+
elif l in [2, 8]:
765+
# open bspline surface
766+
part_points = up_sample_points_torch_memory_efficient(torch.from_numpy(points).cuda(), 2).data.cpu().numpy()
767+
if not isinstance(recon_points[index], np.ndarray):
768+
recon_points_ = recon_points[index].data.cpu().numpy()[0]
769+
else:
770+
recon_points_ = recon_points[index]
771+
pred_mesh = tessalate_points_fast(recon_points_, 30, 30)
772+
if bit_map:
773+
if epsilon:
774+
e = epsilon
775+
else:
776+
e = 0.06
777+
pred_mesh = bit_mapping_points_torch(part_points, recon_points_, e, 30, 30)
778+
779+
elif l == 1:
780+
# Fit plane
781+
part_points = up_sample_points_torch_memory_efficient(torch.from_numpy(points).cuda(), 3).data.cpu().numpy()
782+
if epsilon:
783+
e = epsilon
784+
else:
785+
e = 0.02
786+
pred_mesh = bit_mapping_points_torch(part_points, recon_points[index], e, 120, 120)
787+
788+
elif l == 3:
789+
# Cone
790+
part_points = up_sample_points_torch_memory_efficient(torch.from_numpy(points).cuda(), 3).data.cpu().numpy()
791+
if epsilon:
792+
e = epsilon
793+
else:
794+
e = 0.03
795+
try:
796+
N = recon_points[index].shape[0] // 51
797+
pred_mesh = bit_mapping_points_torch(part_points, recon_points[index], e, N, 51)
798+
except:
799+
import ipdb;
800+
ipdb.set_trace()
801+
802+
elif l == 4:
803+
# cylinder
804+
part_points = up_sample_points_torch_memory_efficient(torch.from_numpy(points).cuda(), 3).data.cpu().numpy()
805+
806+
if epsilon:
807+
e = epsilon
808+
else:
809+
e = 0.03
810+
pred_mesh = bit_mapping_points_torch(part_points, recon_points[index], e, 200, 60)
811+
812+
elif l == 5:
813+
part_points = up_sample_points_torch_memory_efficient(torch.from_numpy(points).cuda(), 2).data.cpu().numpy()
814+
if epsilon:
815+
e = epsilon
816+
else:
817+
e = 0.03
818+
pred_mesh = bit_mapping_points_torch(part_points, recon_points[index], e, 100, 100)
819+
pred_meshes.append(pred_mesh)
820+
return pred_meshes

src/residual_utils.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
match,
1616
)
1717
from open3d import *
18+
from src.fitting_utils import visualize_bit_mapping_shape
19+
1820

1921
Vector3dVector, Vector3iVector = utility.Vector3dVector, utility.Vector3iVector
2022
from src.mean_shift import MeanShift
@@ -127,7 +129,7 @@ def fitting_loss(
127129
lamb=lamb
128130
)
129131
else:
130-
loss, parameters, pred_mesh, gtpoints, distance, _, _ = self.residual_eval_mode(
132+
loss, parameters, pred_mesh = self.residual_eval_mode(
131133
points[b],
132134
normals[b],
133135
labels[b],
@@ -322,10 +324,11 @@ def residual_eval_mode(
322324
else:
323325
distance = None
324326
if sample_points:
325-
pred_meshes = None
327+
pred_meshes = visualize_bit_mapping_shape(
328+
data_, weights, recon_points, self.fitter.fitting.parameters, epsilon=epsilon)
326329
else:
327330
pred_meshes = None
328-
return Loss, self.fitter.fitting.parameters, pred_meshes, gt_points, distance, rows, cols
331+
return Loss, self.fitter.fitting.parameters, pred_meshes
329332

330333
def separate_losses(self, distance, gt_points, lamb=1.0):
331334
"""

0 commit comments

Comments
 (0)