Skip to content

Commit 6cebc10

Browse files
committed
clean code, refactor and small speed-up
1 parent 9ea259a commit 6cebc10

File tree

1 file changed

+47
-41
lines changed

1 file changed

+47
-41
lines changed

Diff for: pymc3/step_methods/pgbart.py

+47 -41
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ class PGBART(ArrayStepShared):
6161
generates_stats = True
6262
stats_dtypes = [{"variable_inclusion": np.ndarray}]
6363

64-
def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None):
64+
def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", model=None):
6565
_log.warning("BART is experimental. Use with caution.")
6666
model = modelcontext(model)
6767
initial_values = model.initial_point
@@ -78,7 +78,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
7878

7979
self.init_mean = self.Y.mean()
8080
# if data is binary
81-
if np.all(np.unique(self.Y) == [0, 1]):
81+
Y_unique = np.unique(self.Y)
82+
if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
8283
self.mu_std = 6 / (self.k * self.m ** 0.5)
8384
# maybe we need to check for count data
8485
else:
@@ -97,6 +98,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
9798
self.mean = fast_mean()
9899
self.normal = NormalSampler()
99100
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
101+
self.ssv = SampleSplittingVariable(self.split_prior)
100102

101103
self.tune = True
102104
self.idx = 0
@@ -120,7 +122,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
120122
self.a_tree.tree_id = i
121123
p = ParticleTree(
122124
self.a_tree,
123-
self.prior_prob_leaf_node,
124125
self.init_log_weight,
125126
self.init_likelihood,
126127
)
@@ -132,29 +133,31 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
132133
sum_trees_output = q.data
133134

134135
variable_inclusion = np.zeros(self.num_variates, dtype="int")
135-
self.ssv = SampleSplittingVariable(self.split_prior)
136136

137137
if self.idx == self.m:
138138
self.idx = 0
139139

140140
for idx in range(self.idx, self.idx + self.chunk):
141141
if idx >= self.m:
142142
break
143-
self.idx += 1
144143
tree = self.all_particles[idx].tree
145144
sum_trees_output_noi = sum_trees_output - tree.predict_output()
145+
self.idx += 1
146146
# Generate an initial set of SMC particles
147147
# at the end of the algorithm we return one of these particles as the new tree
148148
particles = self.init_particles(tree.tree_id)
149149

150-
for t in range(1, self.max_stages):
150+
for t in range(self.max_stages):
151151
# Get old particle at stage t
152-
particles[0] = self.get_old_tree_particle(tree.tree_id, t)
152+
if t > 0:
153+
particles[0] = self.get_old_tree_particle(tree.tree_id, t)
153154
# sample each particle (try to grow each tree)
154-
for c in range(1, self.num_particles):
155-
particles[c].sample_tree_sequential(
155+
compute_logp = [True]
156+
for p in particles[1:]:
157+
clp = p.sample_tree_sequential(
156158
self.ssv,
157159
self.available_predictors,
160+
self.prior_prob_leaf_node,
158161
self.X,
159162
self.missing_data,
160163
sum_trees_output,
@@ -163,34 +166,34 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
163166
self.normal,
164167
self.mu_std,
165168
)
169+
compute_logp.append(clp)
166170
# Update weights. Since the prior is used as the proposal,the weights
167171
# are updated additively as the ratio of the new and old log_likelihoods
168-
for p in particles:
169-
new_likelihood = self.likelihood_logp(
170-
sum_trees_output_noi + p.tree.predict_output()
171-
)
172-
p.log_weight += new_likelihood - p.old_likelihood_logp
173-
p.old_likelihood_logp = new_likelihood
174-
172+
for clp, p in zip(compute_logp, particles):
173+
if clp: # Compute the likelihood when p has changed from the previous iteration
174+
new_likelihood = self.likelihood_logp(
175+
sum_trees_output_noi + p.tree.predict_output()
176+
)
177+
p.log_weight += new_likelihood - p.old_likelihood_logp
178+
p.old_likelihood_logp = new_likelihood
175179
# Normalize weights
176180
W_t, normalized_weights = self.normalize(particles)
177181

178-
# Set the new weights
179-
for p in particles:
180-
p.log_weight = W_t
181-
182182
# Resample all but first particle
183183
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
184-
185184
new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
186185
particles[1:] = particles[new_indices]
187186

187+
# Set the new weights
188+
for p in particles:
189+
p.log_weight = W_t
190+
188191
# Check if particles can keep growing, otherwise stop iterating
189-
non_available_nodes_for_expansion = np.ones(self.num_particles - 1)
190-
for c in range(1, self.num_particles):
191-
if len(particles[c].expansion_nodes) != 0:
192-
non_available_nodes_for_expansion[c - 1] = 0
193-
if np.all(non_available_nodes_for_expansion):
192+
non_available_nodes_for_expansion = []
193+
for p in particles[1:]:
194+
if p.expansion_nodes:
195+
non_available_nodes_for_expansion.append(0)
196+
if all(non_available_nodes_for_expansion):
194197
break
195198

196199
# Get the new tree and update
@@ -203,6 +206,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
203206
if self.tune:
204207
for index in new_particle.used_variates:
205208
self.split_prior[index] += 1
209+
self.ssv = SampleSplittingVariable(self.split_prior)
206210
else:
207211
self.iter += 1
208212
self.sum_trees.append(new_tree)
@@ -253,14 +257,16 @@ def init_particles(self, tree_id):
253257
"""
254258
Initialize particles
255259
"""
256-
particles = [self.get_old_tree_particle(tree_id, 0)]
260+
p = self.get_old_tree_particle(tree_id, 0)
261+
p.log_weight = self.init_log_weight
262+
p.old_likelihood_logp = self.init_likelihood
263+
particles = [p]
257264

258-
for _ in range(1, self.num_particles):
265+
for _ in self.indices:
259266
self.a_tree.tree_id = tree_id
260267
particles.append(
261268
ParticleTree(
262269
self.a_tree,
263-
self.prior_prob_leaf_node,
264270
self.init_log_weight,
265271
self.init_likelihood,
266272
)
@@ -274,20 +280,20 @@ class ParticleTree:
274280
Particle tree
275281
"""
276282

277-
def __init__(self, tree, prior_prob_leaf_node, log_weight=0, likelihood=0):
283+
def __init__(self, tree, log_weight, likelihood):
278284
self.tree = tree.copy() # keeps the tree that we care at the moment
279285
self.expansion_nodes = [0]
280286
self.tree_history = [self.tree]
281287
self.expansion_nodes_history = [self.expansion_nodes]
282288
self.log_weight = log_weight
283-
self.prior_prob_leaf_node = prior_prob_leaf_node
284289
self.old_likelihood_logp = likelihood
285290
self.used_variates = []
286291

287292
def sample_tree_sequential(
288293
self,
289294
ssv,
290295
available_predictors,
296+
prior_prob_leaf_node,
291297
X,
292298
missing_data,
293299
sum_trees_output,
@@ -296,13 +302,14 @@ def sample_tree_sequential(
296302
normal,
297303
mu_std,
298304
):
305+
clp = False
299306
if self.expansion_nodes:
300307
index_leaf_node = self.expansion_nodes.pop(0)
301308
# Probability that this node will remain a leaf node
302-
prob_leaf = self.prior_prob_leaf_node[self.tree[index_leaf_node].depth]
309+
prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
303310

304311
if prob_leaf < np.random.random():
305-
index_selected_predictor = grow_tree(
312+
clp, index_selected_predictor = grow_tree(
306313
self.tree,
307314
index_leaf_node,
308315
ssv,
@@ -315,21 +322,20 @@ def sample_tree_sequential(
315322
normal,
316323
mu_std,
317324
)
318-
if index_selected_predictor is not None:
325+
if clp:
319326
new_indexes = self.tree.idx_leaf_nodes[-2:]
320327
self.expansion_nodes.extend(new_indexes)
321328
self.used_variates.append(index_selected_predictor)
322329

323330
self.tree_history.append(self.tree)
324331
self.expansion_nodes_history.append(self.expansion_nodes)
332+
return clp
325333

326334
def set_particle_to_step(self, t):
327335
if len(self.tree_history) <= t:
328-
self.tree = self.tree_history[-1]
329-
self.expansion_nodes = self.expansion_nodes_history[-1]
330-
else:
331-
self.tree = self.tree_history[t]
332-
self.expansion_nodes = self.expansion_nodes_history[t]
336+
t = -1
337+
self.tree = self.tree_history[t]
338+
self.expansion_nodes = self.expansion_nodes_history[t]
333339

334340

335341
def preprocess_XY(X, Y):
@@ -410,7 +416,7 @@ def grow_tree(
410416
]
411417

412418
if available_splitting_values.size == 0:
413-
return None
419+
return False, None
414420

415421
idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
416422
selected_splitting_rule = available_splitting_values[idx_selected_splitting_values]
@@ -443,7 +449,7 @@ def grow_tree(
443449
)
444450
tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
445451

446-
return index_selected_predictor
452+
return True, index_selected_predictor
447453

448454

449455
def get_new_idx_data_points(current_split_node, idx_data_points, X):

0 commit comments

Comments
 (0)