@@ -61,7 +61,7 @@ class PGBART(ArrayStepShared):
     generates_stats = True
     stats_dtypes = [{"variable_inclusion": np.ndarray}]

-    def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None):
+    def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", model=None):
         _log.warning("BART is experimental. Use with caution.")
         model = modelcontext(model)
         initial_values = model.initial_point
@@ -78,7 +78,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None):

         self.init_mean = self.Y.mean()
         # if data is binary
-        if np.all(np.unique(self.Y) == [0, 1]):
+        Y_unique = np.unique(self.Y)
+        if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
             self.mu_std = 6 / (self.k * self.m ** 0.5)
         # maybe we need to check for count data
         else:
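The guard matters because `np.unique(self.Y) == [0, 1]` compares arrays whose lengths may differ: when `self.Y` has more than two distinct values, NumPy's element-wise comparison of mismatched shapes is unreliable (a scalar False plus a DeprecationWarning on older releases, an error on newer ones). A minimal standalone sketch of the guarded check; the helper name `is_binary` is illustrative, not part of the codebase:

import numpy as np

def is_binary(Y):
    # np.unique returns the distinct values in sorted order,
    # so a 0/1 array yields exactly [0, 1].
    Y_unique = np.unique(Y)
    # Check the size first so the == comparison below is always
    # between two length-2 arrays.
    return Y_unique.size == 2 and bool(np.all(Y_unique == [0, 1]))

print(is_binary(np.array([0, 1, 1, 0])))  # True
print(is_binary(np.array([0, 1, 2])))     # False: three distinct values
print(is_binary(np.zeros(3)))             # False: a single distinct value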
@@ -97,6 +98,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None):
         self.mean = fast_mean()
         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
+        self.ssv = SampleSplittingVariable(self.split_prior)

         self.tune = True
         self.idx = 0
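The splitting-variable sampler is now built once here, removed from the top of `astep` (see the hunk below), and refreshed only when `split_prior` changes during tuning. The internals of `SampleSplittingVariable` are outside this diff; assuming it draws a predictor index with probability proportional to the `split_prior` weights, a minimal sketch of such a sampler and of why precomputing it pays off:

import numpy as np

class SampleSplittingVariable:
    # A sketch under the assumption that `weights` holds one non-negative
    # weight per predictor; the real class in pymc3 may differ in detail.
    def __init__(self, weights):
        weights = np.asarray(weights, dtype="float64")
        # Normalizing and building the CDF is the costly part, so it pays
        # to do it once per change of `weights` rather than once per step.
        self.cdf = np.cumsum(weights / weights.sum())

    def rvs(self):
        # Inverse-CDF sampling of a predictor index.
        return int(np.searchsorted(self.cdf, np.random.random()))

ssv = SampleSplittingVariable([1.0, 3.0, 1.0])
print(ssv.rvs())  # index 1 is drawn ~60% of the time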
@@ -120,7 +122,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None):
             self.a_tree.tree_id = i
             p = ParticleTree(
                 self.a_tree,
-                self.prior_prob_leaf_node,
                 self.init_log_weight,
                 self.init_likelihood,
             )
@@ -132,29 +133,31 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         sum_trees_output = q.data

         variable_inclusion = np.zeros(self.num_variates, dtype="int")
-        self.ssv = SampleSplittingVariable(self.split_prior)

         if self.idx == self.m:
             self.idx = 0

         for idx in range(self.idx, self.idx + self.chunk):
             if idx >= self.m:
                 break
-            self.idx += 1
             tree = self.all_particles[idx].tree
             sum_trees_output_noi = sum_trees_output - tree.predict_output()
+            self.idx += 1
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree.tree_id)

-            for t in range(1, self.max_stages):
+            for t in range(self.max_stages):
                 # Get old particle at stage t
-                particles[0] = self.get_old_tree_particle(tree.tree_id, t)
+                if t > 0:
+                    particles[0] = self.get_old_tree_particle(tree.tree_id, t)
                 # sample each particle (try to grow each tree)
-                for c in range(1, self.num_particles):
-                    particles[c].sample_tree_sequential(
+                compute_logp = [True]
+                for p in particles[1:]:
+                    clp = p.sample_tree_sequential(
                         self.ssv,
                         self.available_predictors,
+                        self.prior_prob_leaf_node,
                         self.X,
                         self.missing_data,
                         sum_trees_output,
@@ -163,34 +166,34 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                         self.normal,
                         self.mu_std,
                     )
+                    compute_logp.append(clp)
                 # Update weights. Since the prior is used as the proposal, the weights
                 # are updated additively (in log space) by the difference of the new and old log-likelihoods
-                for p in particles:
-                    new_likelihood = self.likelihood_logp(
-                        sum_trees_output_noi + p.tree.predict_output()
-                    )
-                    p.log_weight += new_likelihood - p.old_likelihood_logp
-                    p.old_likelihood_logp = new_likelihood
-
+                for clp, p in zip(compute_logp, particles):
+                    if clp:  # Compute the likelihood when p has changed from the previous iteration
+                        new_likelihood = self.likelihood_logp(
+                            sum_trees_output_noi + p.tree.predict_output()
+                        )
+                        p.log_weight += new_likelihood - p.old_likelihood_logp
+                        p.old_likelihood_logp = new_likelihood
                 # Normalize weights
                 W_t, normalized_weights = self.normalize(particles)

-                # Set the new weights
-                for p in particles:
-                    p.log_weight = W_t
-
                 # Resample all but first particle
                 re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
-
                 new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
                 particles[1:] = particles[new_indices]

+                # Set the new weights
+                for p in particles:
+                    p.log_weight = W_t
+
                 # Check if particles can keep growing, otherwise stop iterating
-                non_available_nodes_for_expansion = np.ones(self.num_particles - 1)
-                for c in range(1, self.num_particles):
-                    if len(particles[c].expansion_nodes) != 0:
-                        non_available_nodes_for_expansion[c - 1] = 0
-                if np.all(non_available_nodes_for_expansion):
+                non_available_nodes_for_expansion = []
+                for p in particles[1:]:
+                    if p.expansion_nodes:
+                        non_available_nodes_for_expansion.append(0)
+                if all(non_available_nodes_for_expansion):
                     break

            # Get the new tree and update
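Two details in this hunk are easy to miss. `compute_logp` starts as `[True]` so particle 0 (the retained old particle, which changes at every stage) is always re-evaluated, while any particle whose tree did not grow keeps its cached `old_likelihood_logp` and skips the `likelihood_logp` call. The stopping test then relies on `all([])` being True: a 0 is appended for each particle that can still grow, so the list is empty, and the loop breaks, exactly when no expansion nodes remain. `normalize` itself is outside this diff; the following is a self-contained sketch, under the assumption that `W_t` is the log of the mean unnormalized weight, of a stable log-weight normalization followed by the resampling step used above:

import numpy as np

def normalize(log_weights):
    # Shift by the max before exponentiating to avoid overflow/underflow.
    log_w = np.asarray(log_weights, dtype="float64")
    shift = log_w.max()
    w = np.exp(log_w - shift)
    normalized = w / w.sum()
    # Log of the mean unnormalized weight; the real normalize() in
    # pgbart.py may define W_t differently.
    W_t = shift + np.log(w.sum()) - np.log(log_w.size)
    return W_t, normalized

rng = np.random.default_rng(0)
log_weights = np.array([-3.1, -2.7, -4.0, -2.9])
W_t, normalized_weights = normalize(log_weights)

# Resample all but the first particle, mirroring astep above.
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
indices = np.arange(1, log_weights.size)
new_indices = rng.choice(indices, size=indices.size, p=re_n_w)
print(W_t, new_indices)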
@@ -203,6 +206,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             if self.tune:
                 for index in new_particle.used_variates:
                     self.split_prior[index] += 1
+                self.ssv = SampleSplittingVariable(self.split_prior)
             else:
                 self.iter += 1
                 self.sum_trees.append(new_tree)
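Refreshing `self.ssv` immediately after the counts change keeps the split-variable proposal in sync with `split_prior` while tuning, without rebuilding it on every `astep` call as before. A self-contained toy of the feedback, with hypothetical counts:

import numpy as np

num_predictors = 5
split_prior = np.ones(num_predictors)  # uniform to start
used_variates = [2, 2, 4]              # hypothetical predictors used by the accepted tree

for index in used_variates:
    split_prior[index] += 1

# What a refreshed SampleSplittingVariable would sample from:
probs = split_prior / split_prior.sum()
print(probs)  # predictor 2 is now favored: [0.125 0.125 0.375 0.125 0.25 ]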
@@ -253,14 +257,16 @@ def init_particles(self, tree_id):
         """
         Initialize particles
         """
-        particles = [self.get_old_tree_particle(tree_id, 0)]
+        p = self.get_old_tree_particle(tree_id, 0)
+        p.log_weight = self.init_log_weight
+        p.old_likelihood_logp = self.init_likelihood
+        particles = [p]

-        for _ in range(1, self.num_particles):
+        for _ in self.indices:
             self.a_tree.tree_id = tree_id
             particles.append(
                 ParticleTree(
                     self.a_tree,
-                    self.prior_prob_leaf_node,
                     self.init_log_weight,
                     self.init_likelihood,
                 )
@@ -274,20 +280,20 @@ class ParticleTree:
     Particle tree
     """

-    def __init__(self, tree, prior_prob_leaf_node, log_weight=0, likelihood=0):
+    def __init__(self, tree, log_weight, likelihood):
         self.tree = tree.copy()  # keeps the tree we care about at the moment
         self.expansion_nodes = [0]
         self.tree_history = [self.tree]
         self.expansion_nodes_history = [self.expansion_nodes]
         self.log_weight = log_weight
-        self.prior_prob_leaf_node = prior_prob_leaf_node
         self.old_likelihood_logp = likelihood
         self.used_variates = []

     def sample_tree_sequential(
         self,
         ssv,
         available_predictors,
+        prior_prob_leaf_node,
         X,
         missing_data,
         sum_trees_output,
@@ -296,13 +302,14 @@ def sample_tree_sequential(
         normal,
         mu_std,
     ):
+        clp = False
         if self.expansion_nodes:
             index_leaf_node = self.expansion_nodes.pop(0)
             # Probability that this node will remain a leaf node
-            prob_leaf = self.prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]

             if prob_leaf < np.random.random():
-                index_selected_predictor = grow_tree(
+                clp, index_selected_predictor = grow_tree(
                     self.tree,
                     index_leaf_node,
                     ssv,
@@ -315,21 +322,20 @@ def sample_tree_sequential(
                     normal,
                     mu_std,
                 )
-                if index_selected_predictor is not None:
+                if clp:
                     new_indexes = self.tree.idx_leaf_nodes[-2:]
                     self.expansion_nodes.extend(new_indexes)
                     self.used_variates.append(index_selected_predictor)

         self.tree_history.append(self.tree)
         self.expansion_nodes_history.append(self.expansion_nodes)
+        return clp

     def set_particle_to_step(self, t):
         if len(self.tree_history) <= t:
-            self.tree = self.tree_history[-1]
-            self.expansion_nodes = self.expansion_nodes_history[-1]
-        else:
-            self.tree = self.tree_history[t]
-            self.expansion_nodes = self.expansion_nodes_history[t]
+            t = -1
+        self.tree = self.tree_history[t]
+        self.expansion_nodes = self.expansion_nodes_history[t]


 def preprocess_XY(X, Y):
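The remaining two hunks change `grow_tree` to return an explicit `(grew, index_selected_predictor)` pair instead of overloading `None`, which is what lets `sample_tree_sequential` hand `clp` back up to `astep`. A toy sketch of that return contract; the helper `try_grow` is illustrative, not part of the codebase:

from typing import Optional, Tuple

def try_grow(available_values) -> Tuple[bool, Optional[int]]:
    # Mirrors the new return convention of grow_tree: a success flag plus
    # the selected predictor, or (False, None) when no split is possible.
    if len(available_values) == 0:
        return False, None
    return True, int(available_values[0])

clp, index_selected_predictor = try_grow([3, 1])
if clp:  # the same test sample_tree_sequential now uses
    print("grew using predictor", index_selected_predictor)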
@@ -410,7 +416,7 @@ def grow_tree(
     ]

     if available_splitting_values.size == 0:
-        return None
+        return False, None

     idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
     selected_splitting_rule = available_splitting_values[idx_selected_splitting_values]
@@ -443,7 +449,7 @@ def grow_tree(
     )
     tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)

-    return index_selected_predictor
+    return True, index_selected_predictor


 def get_new_idx_data_points(current_split_node, idx_data_points, X):