1
+ from open3d import *
2
+ import h5py
3
+ import sys
4
+ import logging
5
+ import json
6
+ import os
7
+ from shutil import copyfile
8
+ import numpy as np
9
+ import torch .optim as optim
10
+ import torch .utils .data
11
+ from torch .autograd import Variable
12
+ from torch .optim .lr_scheduler import ReduceLROnPlateau
13
+ from src .PointNet import PrimitivesEmbeddingDGCNGn
14
+ from matplotlib import pyplot as plt
15
+ from src .utils import visualize_uv_maps , visualize_fitted_surface
16
+ from src .utils import chamfer_distance
17
+ from read_config import Config
18
+ from src .utils import fit_surface_sample_points
19
+ from src .dataset_segments import Dataset
20
+ from torch .utils .data import DataLoader
21
+ from src .utils import chamfer_distance
22
+ from src .segment_loss import EmbeddingLoss
23
+ from src .segment_utils import cluster
24
+ import time
25
+ from src .segment_loss import (
26
+ EmbeddingLoss ,
27
+ primitive_loss ,
28
+ evaluate_miou ,
29
+ )
30
+ from src .segment_utils import to_one_hot , SIOU_matched_segments
31
+ from src .utils import visualize_point_cloud_from_labels , visualize_point_cloud
32
+ from src .dataset import generator_iter
33
+ from src .mean_shift import MeanShift
34
+ from src .segment_utils import SIOU_matched_segments
35
+ from src .residual_utils import Evaluation
36
+ import time
37
+ from src .primitives import SaveParameters
38
+
39
+ # Use only one gpu.
40
+ os .environ ["CUDA_VISIBLE_DEVICES" ] = "0"
41
+ config = Config (sys .argv [1 ])
42
+ if_normals = config .normals
43
+
44
+ userspace = ""
45
+ Loss = EmbeddingLoss (margin = 1.0 )
46
+
47
+ if config .mode == 0 :
48
+ # Just using points for training
49
+ model = PrimitivesEmbeddingDGCNGn (
50
+ embedding = True ,
51
+ emb_size = 128 ,
52
+ primitives = True ,
53
+ num_primitives = 10 ,
54
+ loss_function = Loss .triplet_loss ,
55
+ mode = config .mode ,
56
+ num_channels = 3 ,
57
+ )
58
+ elif config .mode == 5 :
59
+ # Using points and normals for training
60
+ model = PrimitivesEmbeddingDGCNGn (
61
+ embedding = True ,
62
+ emb_size = 128 ,
63
+ primitives = True ,
64
+ num_primitives = 10 ,
65
+ loss_function = Loss .triplet_loss ,
66
+ mode = config .mode ,
67
+ num_channels = 6 ,
68
+ )
69
+
70
+ saveparameters = SaveParameters ()
71
+
72
+ model_bkp = model
73
+ model_bkp .l_permute = np .arange (10000 )
74
+ model = torch .nn .DataParallel (model , device_ids = [0 ])
75
+ model .cuda ()
76
+
77
+ split_dict = {"train" : config .num_train , "val" : config .num_val , "test" : config .num_test }
78
+ ms = MeanShift ()
79
+
80
+ dataset = Dataset (
81
+ config .batch_size ,
82
+ config .num_train ,
83
+ config .num_val ,
84
+ config .num_test ,
85
+ normals = True ,
86
+ primitives = True ,
87
+ if_train_data = False ,
88
+ prefix = userspace
89
+ )
90
+
91
+ get_test_data = dataset .get_test (align_canonical = True , anisotropic = False , if_normal_noise = True )
92
+
93
+ loader = generator_iter (get_test_data , int (1e10 ))
94
+ get_test_data = iter (
95
+ DataLoader (
96
+ loader ,
97
+ batch_size = 1 ,
98
+ shuffle = False ,
99
+ collate_fn = lambda x : x ,
100
+ num_workers = 0 ,
101
+ pin_memory = False ,
102
+ )
103
+ )
104
+
105
+ os .makedirs (userspace + "logs/results/{}/results/" .format (config .pretrain_model_path ), exist_ok = True )
106
+
107
+ evaluation = Evaluation ()
108
+ alt_gpu = 0
109
+ model .eval ()
110
+
111
+ iterations = 50
112
+ quantile = 0.015
113
+
114
+ model .load_state_dict (
115
+ torch .load (userspace + "logs/pretrained_models/" + config .pretrain_model_path )
116
+ )
117
+ test_res = []
118
+ test_s_iou = []
119
+ test_p_iou = []
120
+ test_g_res = []
121
+ test_s_res = []
122
+ PredictedLabels = []
123
+ PredictedPrims = []
124
+
125
+ for val_b_id in range (config .num_test // config .batch_size - 1 ):
126
+ points_ , labels , normals , primitives_ = next (get_test_data )[0 ]
127
+ points = Variable (torch .from_numpy (points_ .astype (np .float32 ))).cuda ()
128
+ normals = torch .from_numpy (normals ).cuda ()
129
+
130
+ # with torch.autograd.detect_anomaly():
131
+ with torch .no_grad ():
132
+ if if_normals :
133
+ input = torch .cat ([points , normals ], 2 )
134
+ embedding , primitives_log_prob , embed_loss = model (
135
+ input .permute (0 , 2 , 1 ), torch .from_numpy (labels ).cuda (), True
136
+ )
137
+ else :
138
+ embedding , primitives_log_prob , embed_loss = model (
139
+ points .permute (0 , 2 , 1 ), torch .from_numpy (labels ).cuda (), True
140
+ )
141
+ pred_primitives = torch .max (primitives_log_prob [0 ], 0 )[1 ].data .cpu ().numpy ()
142
+ embedding = torch .nn .functional .normalize (embedding [0 ].T , p = 2 , dim = 1 )
143
+ _ , _ , cluster_ids = evaluation .guard_mean_shift (
144
+ embedding , quantile , iterations , kernel_type = "gaussian"
145
+ )
146
+ weights = to_one_hot (cluster_ids , np .unique (cluster_ids .data .data .cpu ().numpy ()).shape [
147
+ 0 ])
148
+ cluster_ids = cluster_ids .data .cpu ().numpy ()
149
+
150
+ s_iou , p_iou , _ , _ = SIOU_matched_segments (
151
+ labels [0 ],
152
+ cluster_ids ,
153
+ pred_primitives ,
154
+ primitives_ [0 ],
155
+ weights ,
156
+ )
157
+ # print(s_iou, p_iou)
158
+ PredictedLabels .append (cluster_ids )
159
+ PredictedPrims .append (pred_primitives )
160
+ if val_b_id == 3 :
161
+ break
162
+
163
+ with h5py .File (userspace + "logs/results/{}/results/" .format (config .pretrain_model_path ) + "predictions.h5" , "w" ) as hf :
164
+ hf .create_dataset (name = "seg_id" , data = np .stack (PredictedLabels , 0 ))
165
+ hf .create_dataset (name = "pred_primitives" , data = np .stack (PredictedPrims , 0 ))
0 commit comments