Skip to content

Commit f023e32

Browse files
committed
remove linear and mix response
1 parent 0167c88 commit f023e32

File tree

4 files changed

+22
-89
lines changed

4 files changed

+22
-89
lines changed

Diff for: pymc/bart/bart.py

-5
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ class BART(NoDistribution):
6767
k : float
6868
Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1
6969
and 3.
70-
response : str
71-
How the leaf_node values are computed. Available options are ``constant`` (default),
72-
``linear`` or ``mix``.
7370
split_prior : array-like
7471
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
7572
1. Otherwise they will be normalized.
@@ -84,7 +81,6 @@ def __new__(
8481
m=50,
8582
alpha=0.25,
8683
k=2,
87-
response="constant",
8884
split_prior=None,
8985
**kwargs,
9086
):
@@ -103,7 +99,6 @@ def __new__(
10399
m=m,
104100
alpha=alpha,
105101
k=k,
106-
response=response,
107102
split_prior=split_prior,
108103
),
109104
)()

Diff for: pymc/bart/pgbart.py

+9-62
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
6868
self.m = self.bart.m
6969
self.alpha = self.bart.alpha
7070
self.k = self.bart.k
71-
self.response = self.bart.response
7271
self.alpha_vec = self.bart.split_prior
7372
if self.alpha_vec is None:
7473
self.alpha_vec = np.ones(self.X.shape[1])
@@ -90,10 +89,8 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
9089
self.a_tree = Tree.init_tree(
9190
leaf_node_value=self.init_mean / self.m,
9291
idx_data_points=np.arange(self.num_observations, dtype="int32"),
93-
m=self.m,
9492
)
9593
self.mean = fast_mean()
96-
self.linear_fit = fast_linear_fit()
9794

9895
self.normal = NormalSampler()
9996
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
@@ -140,11 +137,9 @@ def astep(self, _):
140137
self.sum_trees,
141138
self.X,
142139
self.mean,
143-
self.linear_fit,
144140
self.m,
145141
self.normal,
146142
self.mu_std,
147-
self.response,
148143
)
149144

150145
# The old tree and the one with new leafs do not grow so we update the weights only once
@@ -162,11 +157,9 @@ def astep(self, _):
162157
self.missing_data,
163158
self.sum_trees,
164159
self.mean,
165-
self.linear_fit,
166160
self.m,
167161
self.normal,
168162
self.mu_std,
169-
self.response,
170163
)
171164
if tree_grew:
172165
self.update_weight(p)
@@ -286,11 +279,9 @@ def sample_tree(
286279
missing_data,
287280
sum_trees,
288281
mean,
289-
linear_fit,
290282
m,
291283
normal,
292284
mu_std,
293-
response,
294285
):
295286
tree_grew = False
296287
if self.expansion_nodes:
@@ -308,11 +299,9 @@ def sample_tree(
308299
missing_data,
309300
sum_trees,
310301
mean,
311-
linear_fit,
312302
m,
313303
normal,
314304
mu_std,
315-
response,
316305
)
317306
if index_selected_predictor is not None:
318307
new_indexes = self.tree.idx_leaf_nodes[-2:]
@@ -322,9 +311,9 @@ def sample_tree(
322311

323312
return tree_grew
324313

325-
def sample_leafs(self, sum_trees, X, mean, linear_fit, m, normal, mu_std, response):
314+
def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std):
326315

327-
sample_leaf_values(self.tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, response)
316+
sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std)
328317

329318

330319
class SampleSplittingVariable:
@@ -379,11 +368,9 @@ def grow_tree(
379368
missing_data,
380369
sum_trees,
381370
mean,
382-
linear_fit,
383371
m,
384372
normal,
385373
mu_std,
386-
response,
387374
):
388375
current_node = tree.get_node(index_leaf_node)
389376
idx_data_points = current_node.idx_data_points
@@ -409,28 +396,22 @@ def grow_tree(
409396
current_node.get_idx_right_child(),
410397
)
411398

412-
if response == "mix":
413-
response = "linear" if np.random.random() >= 0.5 else "constant"
414-
415399
new_nodes = []
416400
for idx in range(2):
417401
idx_data_point = new_idx_data_points[idx]
418-
node_value, node_linear_params = draw_leaf_value(
402+
node_value = draw_leaf_value(
419403
sum_trees[idx_data_point],
420404
X[idx_data_point, selected_predictor],
421405
mean,
422-
linear_fit,
423406
m,
424407
normal,
425408
mu_std,
426-
response,
427409
)
428410

429411
new_node = LeafNode(
430412
index=current_node_children[idx],
431413
value=node_value,
432414
idx_data_points=idx_data_point,
433-
linear_params=node_linear_params,
434415
)
435416
new_nodes.append(new_node)
436417

@@ -449,26 +430,23 @@ def grow_tree(
449430
return index_selected_predictor
450431

451432

452-
def sample_leaf_values(tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, response):
433+
def sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std):
453434

454435
for idx in tree.idx_leaf_nodes:
455436
if idx > 0:
456437
leaf = tree[idx]
457438
idx_data_points = leaf.idx_data_points
458439
parent_node = tree[leaf.get_idx_parent_node()]
459440
selected_predictor = parent_node.idx_split_variable
460-
node_value, node_linear_params = draw_leaf_value(
441+
node_value = draw_leaf_value(
461442
sum_trees[idx_data_points],
462443
X[idx_data_points, selected_predictor],
463444
mean,
464-
linear_fit,
465445
m,
466446
normal,
467447
mu_std,
468-
response,
469448
)
470449
leaf.value = node_value
471-
leaf.linear_params = node_linear_params
472450

473451

474452
def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
@@ -480,24 +458,19 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
480458
return left_node_idx_data_points, right_node_idx_data_points
481459

482460

483-
def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, response):
461+
def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
484462
"""Draw Gaussian distributed leaf values"""
485-
linear_params = None
486463
if Y_mu_pred.size == 0:
487-
return 0, linear_params
464+
return 0
488465
else:
489466
norm = normal.random() * mu_std
490467
if Y_mu_pred.size == 1:
491468
mu_mean = Y_mu_pred.item() / m
492-
elif response == "constant":
469+
else:
493470
mu_mean = mean(Y_mu_pred) / m
494-
elif response == "linear":
495-
Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred)
496-
mu_mean = Y_fit / m
497-
linear_params[2] = norm
498471

499472
draw = norm + mu_mean
500-
return draw, linear_params
473+
return draw
501474

502475

503476
def fast_mean():
@@ -518,32 +491,6 @@ def mean(a):
518491
return mean
519492

520493

521-
def fast_linear_fit():
522-
"""If available use Numba to speed up the computation of the linear fit"""
523-
524-
def linear_fit(X, Y):
525-
526-
n = len(Y)
527-
xbar = np.sum(X) / n
528-
ybar = np.sum(Y) / n
529-
530-
den = X @ X - n * xbar ** 2
531-
if den > 1e-10:
532-
b = (X @ Y - n * xbar * ybar) / den
533-
else:
534-
b = 0
535-
a = ybar - b * xbar
536-
Y_fit = a + b * X
537-
return Y_fit, [a, b, 0]
538-
539-
try:
540-
from numba import jit
541-
542-
return jit(linear_fit)
543-
except ImportError:
544-
return linear_fit
545-
546-
547494
def discrete_uniform_sampler(upper_value):
548495
"""Draw from the uniform distribution with bounds [0, upper_value).
549496

Diff for: pymc/bart/tree.py

+12-21
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,10 @@ class Tree:
4646
num_observations : int, optional
4747
"""
4848

49-
def __init__(self, num_observations=0, m=0):
49+
def __init__(self, num_observations=0):
5050
self.tree_structure = {}
5151
self.idx_leaf_nodes = []
5252
self.num_observations = num_observations
53-
self.m = m
5453

5554
def __getitem__(self, index):
5655
return self.get_node(index)
@@ -97,16 +96,10 @@ def predict_out_of_sample(self, X):
9796
float
9897
Value of the leaf value where the unobserved point lies.
9998
"""
100-
leaf_node, split_variable = self._traverse_tree(X, node_index=0)
101-
linear_params = leaf_node.linear_params
102-
if linear_params is None:
103-
return leaf_node.value
104-
else:
105-
x = X[split_variable].item()
106-
y_x = (linear_params[0] + linear_params[1] * x) / self.m
107-
return y_x + linear_params[2]
108-
109-
def _traverse_tree(self, x, node_index=0, split_variable=None):
99+
leaf_node = self._traverse_tree(X, node_index=0)
100+
return leaf_node.value
101+
102+
def _traverse_tree(self, x, node_index=0):
110103
"""
111104
Traverse the tree starting from a particular node given an unobserved point.
112105
@@ -121,17 +114,16 @@ def _traverse_tree(self, x, node_index=0, split_variable=None):
121114
"""
122115
current_node = self.get_node(node_index)
123116
if isinstance(current_node, SplitNode):
124-
split_variable = current_node.idx_split_variable
125-
if x[split_variable] <= current_node.split_value:
117+
if x[current_node.idx_split_variable] <= current_node.split_value:
126118
left_child = current_node.get_idx_left_child()
127-
current_node, split_variable = self._traverse_tree(x, left_child, split_variable)
119+
current_node = self._traverse_tree(x, left_child)
128120
else:
129121
right_child = current_node.get_idx_right_child()
130-
current_node, split_variable = self._traverse_tree(x, right_child, split_variable)
131-
return current_node, split_variable
122+
current_node = self._traverse_tree(x, right_child)
123+
return current_node
132124

133125
@staticmethod
134-
def init_tree(leaf_node_value, idx_data_points, m):
126+
def init_tree(leaf_node_value, idx_data_points):
135127
"""
136128
137129
Parameters
@@ -145,7 +137,7 @@ def init_tree(leaf_node_value, idx_data_points, m):
145137
-------
146138
147139
"""
148-
new_tree = Tree(len(idx_data_points), m)
140+
new_tree = Tree(len(idx_data_points))
149141
new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
150142
return new_tree
151143

@@ -174,8 +166,7 @@ def __init__(self, index, idx_split_variable, split_value):
174166

175167

176168
class LeafNode(BaseNode):
177-
def __init__(self, index, value, idx_data_points, linear_params=None):
169+
def __init__(self, index, value, idx_data_points):
178170
super().__init__(index)
179171
self.value = value
180172
self.idx_data_points = idx_data_points
181-
self.linear_params = linear_params

Diff for: pymc/tests/test_bart.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class TestUtils:
6767
Y = np.random.normal(0, 1, size=50)
6868

6969
with pm.Model() as model:
70-
mu = pm.BART("mu", X, Y, m=10, response="mix")
70+
mu = pm.BART("mu", X, Y, m=10)
7171
sigma = pm.HalfNormal("sigma", 1)
7272
y = pm.Normal("y", mu, sigma, observed=Y)
7373
idata = pm.sample(random_seed=3415)

0 commit comments

Comments
 (0)