@@ -68,7 +68,6 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
68
68
self .m = self .bart .m
69
69
self .alpha = self .bart .alpha
70
70
self .k = self .bart .k
71
- self .response = self .bart .response
72
71
self .alpha_vec = self .bart .split_prior
73
72
if self .alpha_vec is None :
74
73
self .alpha_vec = np .ones (self .X .shape [1 ])
@@ -90,10 +89,8 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
90
89
self .a_tree = Tree .init_tree (
91
90
leaf_node_value = self .init_mean / self .m ,
92
91
idx_data_points = np .arange (self .num_observations , dtype = "int32" ),
93
- m = self .m ,
94
92
)
95
93
self .mean = fast_mean ()
96
- self .linear_fit = fast_linear_fit ()
97
94
98
95
self .normal = NormalSampler ()
99
96
self .prior_prob_leaf_node = compute_prior_probability (self .alpha )
@@ -140,11 +137,9 @@ def astep(self, _):
140
137
self .sum_trees ,
141
138
self .X ,
142
139
self .mean ,
143
- self .linear_fit ,
144
140
self .m ,
145
141
self .normal ,
146
142
self .mu_std ,
147
- self .response ,
148
143
)
149
144
150
145
# The old tree and the one with new leafs do not grow so we update the weights only once
@@ -162,11 +157,9 @@ def astep(self, _):
162
157
self .missing_data ,
163
158
self .sum_trees ,
164
159
self .mean ,
165
- self .linear_fit ,
166
160
self .m ,
167
161
self .normal ,
168
162
self .mu_std ,
169
- self .response ,
170
163
)
171
164
if tree_grew :
172
165
self .update_weight (p )
@@ -286,11 +279,9 @@ def sample_tree(
286
279
missing_data ,
287
280
sum_trees ,
288
281
mean ,
289
- linear_fit ,
290
282
m ,
291
283
normal ,
292
284
mu_std ,
293
- response ,
294
285
):
295
286
tree_grew = False
296
287
if self .expansion_nodes :
@@ -308,11 +299,9 @@ def sample_tree(
308
299
missing_data ,
309
300
sum_trees ,
310
301
mean ,
311
- linear_fit ,
312
302
m ,
313
303
normal ,
314
304
mu_std ,
315
- response ,
316
305
)
317
306
if index_selected_predictor is not None :
318
307
new_indexes = self .tree .idx_leaf_nodes [- 2 :]
@@ -322,9 +311,9 @@ def sample_tree(
322
311
323
312
return tree_grew
324
313
325
def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std):
    """Resample the leaf values of this object's tree.

    Thin wrapper that forwards ``self.tree`` together with all received
    arguments, unchanged and in the same order, to ``sample_leaf_values``.

    Parameters
    ----------
    sum_trees, X, mean, m, normal, mu_std
        Passed through as-is; see ``sample_leaf_values`` for how they are
        used. The former ``linear_fit`` and ``response`` parameters are no
        longer part of the signature.
    """
    sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std)
328
317
329
318
330
319
class SampleSplittingVariable :
@@ -379,11 +368,9 @@ def grow_tree(
379
368
missing_data ,
380
369
sum_trees ,
381
370
mean ,
382
- linear_fit ,
383
371
m ,
384
372
normal ,
385
373
mu_std ,
386
- response ,
387
374
):
388
375
current_node = tree .get_node (index_leaf_node )
389
376
idx_data_points = current_node .idx_data_points
@@ -409,28 +396,22 @@ def grow_tree(
409
396
current_node .get_idx_right_child (),
410
397
)
411
398
412
- if response == "mix" :
413
- response = "linear" if np .random .random () >= 0.5 else "constant"
414
-
415
399
new_nodes = []
416
400
for idx in range (2 ):
417
401
idx_data_point = new_idx_data_points [idx ]
418
- node_value , node_linear_params = draw_leaf_value (
402
+ node_value = draw_leaf_value (
419
403
sum_trees [idx_data_point ],
420
404
X [idx_data_point , selected_predictor ],
421
405
mean ,
422
- linear_fit ,
423
406
m ,
424
407
normal ,
425
408
mu_std ,
426
- response ,
427
409
)
428
410
429
411
new_node = LeafNode (
430
412
index = current_node_children [idx ],
431
413
value = node_value ,
432
414
idx_data_points = idx_data_point ,
433
- linear_params = node_linear_params ,
434
415
)
435
416
new_nodes .append (new_node )
436
417
@@ -449,26 +430,23 @@ def grow_tree(
449
430
return index_selected_predictor
450
431
451
432
452
def sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std):
    """Draw a new value for every leaf of ``tree`` whose index is positive.

    For each index in ``tree.idx_leaf_nodes`` greater than zero, the leaf's
    data-point indices select rows of ``sum_trees`` and ``X``; the column of
    ``X`` is the parent's ``idx_split_variable``. The result of
    ``draw_leaf_value`` is stored in ``leaf.value`` (in-place mutation).

    Parameters
    ----------
    tree
        Object exposing ``idx_leaf_nodes`` and item access by node index.
    sum_trees
        Indexable by a leaf's ``idx_data_points``.
    X
        Two-dimensionally indexable as ``X[idx_data_points, predictor]``.
    mean, m, normal, mu_std
        Forwarded unchanged to ``draw_leaf_value``.
    """
    for idx in tree.idx_leaf_nodes:
        # Index 0 is skipped (presumably the root, which has no parent to
        # read a split variable from -- confirm against the Tree class).
        if idx > 0:
            leaf = tree[idx]
            idx_data_points = leaf.idx_data_points
            parent_node = tree[leaf.get_idx_parent_node()]
            selected_predictor = parent_node.idx_split_variable
            leaf.value = draw_leaf_value(
                sum_trees[idx_data_points],
                X[idx_data_points, selected_predictor],
                mean,
                m,
                normal,
                mu_std,
            )
472
450
473
451
474
452
def get_new_idx_data_points (split_value , idx_data_points , selected_predictor , X ):
@@ -480,24 +458,19 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
480
458
return left_node_idx_data_points , right_node_idx_data_points
481
459
482
460
483
def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
    """Draw Gaussian distributed leaf values.

    Parameters
    ----------
    Y_mu_pred
        Values assigned to the leaf; must expose ``.size`` and ``.item()``
        (presumably a NumPy array -- confirm against callers).
    X_mu
        Not used by this function; kept in the signature so existing
        callers keep working.
    mean
        Callable returning the mean of ``Y_mu_pred``.
    m
        Divisor applied to the mean.
    normal
        Object whose ``random()`` method supplies the noise draw.
    mu_std
        Multiplier applied to the noise draw.

    Returns
    -------
    ``0`` when ``Y_mu_pred`` is empty (no noise is drawn in that case);
    otherwise ``normal.random() * mu_std`` plus the mean of ``Y_mu_pred``
    divided by ``m``.
    """
    if Y_mu_pred.size == 0:
        return 0

    norm = normal.random() * mu_std
    if Y_mu_pred.size == 1:
        # Single observation: read the scalar directly instead of calling mean().
        mu_mean = Y_mu_pred.item() / m
    else:
        mu_mean = mean(Y_mu_pred) / m

    return norm + mu_mean
501
474
502
475
503
476
def fast_mean ():
@@ -518,32 +491,6 @@ def mean(a):
518
491
return mean
519
492
520
493
521
- def fast_linear_fit ():
522
- """If available use Numba to speed up the computation of the linear fit"""
523
-
524
- def linear_fit (X , Y ):
525
-
526
- n = len (Y )
527
- xbar = np .sum (X ) / n
528
- ybar = np .sum (Y ) / n
529
-
530
- den = X @ X - n * xbar ** 2
531
- if den > 1e-10 :
532
- b = (X @ Y - n * xbar * ybar ) / den
533
- else :
534
- b = 0
535
- a = ybar - b * xbar
536
- Y_fit = a + b * X
537
- return Y_fit , [a , b , 0 ]
538
-
539
- try :
540
- from numba import jit
541
-
542
- return jit (linear_fit )
543
- except ImportError :
544
- return linear_fit
545
-
546
-
547
494
def discrete_uniform_sampler (upper_value ):
548
495
"""Draw from the uniform distribution with bounds [0, upper_value).
549
496
0 commit comments