Skip to content

Commit 52a126d

Browse files
authored
BART: add partial dependence plots and individual conditional expectation plots (#5091)
* add utils for prediction and interpretability * remove file * add tests * test mixed response * fix tests * update release notes * remove unused import
1 parent 80bf823 commit 52a126d

File tree

8 files changed

+340
-70
lines changed

8 files changed

+340
-70
lines changed

RELEASE-NOTES.md

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
- `pm.DensityDist` no longer accepts the `logp` as its first position argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
1616
- `pm.DensityDist` now accepts distribution parameters as positional arguments. Passing them as a dictionary in the `observed` keyword argument is no longer supported and will raise an error (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
1717
- The signature of the `logp` and `random` functions that can be passed into a `pm.DensityDist` has been changed (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
18+
- Generalize BART. A BART variable can be combined with other random variables. The `inv_link` argument has been removed (see [4914](https://github.com/pymc-devs/pymc3/pull/4914)).
19+
- Move BART to its own module (see [5058](https://github.com/pymc-devs/pymc3/pull/5058)).
1820
- ...
1921

2022
### New Features
@@ -32,6 +34,8 @@
3234
- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
3335
- `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cummulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
3436
- `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
37+
- BART: add linear response, increase number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
38+
- BART: add partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
3539
- ...
3640

3741
### Maintenance

pymc/bart/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515

1616
from pymc.bart.bart import BART
1717
from pymc.bart.pgbart import PGBART
18+
from pymc.bart.utils import plot_dependence, predict
1819

1920
__all__ = ["BART", "PGBART"]

pymc/bart/bart.py

+1-31
Original file line numberDiff line numberDiff line change
@@ -41,35 +41,7 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
4141

4242
@classmethod
4343
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
44-
size = kwargs.pop("size", None)
45-
X_new = kwargs.pop("X_new", None)
46-
all_trees = cls.all_trees
47-
if all_trees:
48-
49-
if size is None:
50-
size = ()
51-
elif isinstance(size, int):
52-
size = [size]
53-
54-
flatten_size = 1
55-
for s in size:
56-
flatten_size *= s
57-
58-
idx = rng.randint(len(all_trees), size=flatten_size)
59-
60-
if X_new is None:
61-
pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
62-
for ind, p in enumerate(pred):
63-
for tree in all_trees[idx[ind]]:
64-
p += tree.predict_output()
65-
else:
66-
pred = np.zeros((flatten_size, X_new.shape[0]))
67-
for ind, p in enumerate(pred):
68-
for tree in all_trees[idx[ind]]:
69-
p += np.array([tree.predict_out_of_sample(x, cls.m) for x in X_new])
70-
return pred.reshape((*size, -1))
71-
else:
72-
return np.full_like(cls.Y, cls.Y.mean())
44+
return np.full_like(cls.Y, cls.Y.mean())
7345

7446

7547
bart = BARTRV()
@@ -117,15 +89,13 @@ def __new__(
11789
**kwargs,
11890
):
11991

120-
cls.all_trees = []
12192
X, Y = preprocess_XY(X, Y)
12293

12394
bart_op = type(
12495
f"BART_{name}",
12596
(BARTRV,),
12697
dict(
12798
name="BART",
128-
all_trees=cls.all_trees,
12999
inplace=False,
130100
initval=Y.mean(),
131101
X=X,

pymc/bart/pgbart.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import logging
1616

17+
from copy import copy
1718
from typing import Any, Dict, List, Tuple
1819

1920
import aesara
@@ -121,7 +122,7 @@ class PGBART(ArrayStepShared):
121122
name = "bartsampler"
122123
default_blocked = False
123124
generates_stats = True
124-
stats_dtypes = [{"variable_inclusion": np.ndarray}]
125+
stats_dtypes = [{"variable_inclusion": np.ndarray, "bart_trees": np.ndarray}]
125126

126127
def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None):
127128
_log.warning("BART is experimental. Use with caution.")
@@ -159,6 +160,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
159160
tree_id=0,
160161
leaf_node_value=self.init_mean / self.m,
161162
idx_data_points=np.arange(self.num_observations, dtype="int32"),
163+
m=self.m,
162164
)
163165
self.mean = fast_mean()
164166
self.linear_fit = fast_linear_fit()
@@ -169,8 +171,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
169171

170172
self.tune = True
171173
self.idx = 0
172-
self.iter = 0
173-
self.sum_trees = []
174174
self.batch = batch
175175

176176
if self.batch == "auto":
@@ -193,12 +193,12 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
193193
self.init_likelihood,
194194
)
195195
self.all_particles.append(p)
196+
self.all_trees = np.array([p.tree for p in self.all_particles])
196197
super().__init__(vars, shared)
197198

198199
def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
199200
point_map_info = q.point_map_info
200201
sum_trees_output = q.data
201-
202202
variable_inclusion = np.zeros(self.num_variates, dtype="int")
203203

204204
if self.idx == self.m:
@@ -212,7 +212,6 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
212212
particles = self.init_particles(tree_id)
213213
# Compute the sum of trees without the tree we are attempting to replace
214214
self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
215-
self.idx += 1
216215

217216
# The old tree is not growing so we update the weights only once.
218217
self.update_weight(particles[0])
@@ -258,6 +257,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
258257
# Get the new tree and update
259258
new_particle = np.random.choice(particles, p=normalized_weights)
260259
new_tree = new_particle.tree
260+
self.all_trees[self.idx] = new_tree
261261
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
262262
self.all_particles[tree_id] = new_particle
263263
sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
@@ -268,17 +268,11 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
268268
self.ssv = SampleSplittingVariable(self.split_prior)
269269
else:
270270
self.batch = max(1, int(self.m * 0.2))
271-
self.iter += 1
272-
self.sum_trees.append(new_tree)
273-
if not self.iter % self.m:
274-
# XXX update the all_trees variable in BARTRV to be used in the rng_fn method
275-
# this fails for chains > 1 as the variable is not shared between proccesses
276-
self.bart.all_trees.append(self.sum_trees)
277-
self.sum_trees = []
278271
for index in new_particle.used_variates:
279272
variable_inclusion[index] += 1
273+
self.idx += 1
280274

281-
stats = {"variable_inclusion": variable_inclusion}
275+
stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
282276
sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
283277
return sum_trees_output, [stats]
284278

@@ -526,11 +520,11 @@ def linear_fit(X, Y):
526520
xbar = np.sum(X) / n
527521
ybar = np.sum(Y) / n
528522

529-
if np.all(X == xbar):
530-
b = 0
523+
den = X @ X - n * xbar ** 2
524+
if den > 1e-10:
525+
b = (X @ Y - n * xbar * ybar) / den
531526
else:
532-
b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)
533-
527+
b = 0
534528
a = ybar - b * xbar
535529
Y_fit = a + b * X
536530
return Y_fit, [a, b, 0]

pymc/bart/tree.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,23 @@ class Tree:
4545
Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
4646
num_observations : int
4747
Number of observations used to fit BART.
48-
48+
m : int
49+
Number of trees
4950
5051
Parameters
5152
----------
5253
tree_id : int, optional
5354
num_observations : int, optional
5455
"""
5556

56-
def __init__(self, tree_id=0, num_observations=0):
57+
def __init__(self, tree_id=0, num_observations=0, m=0):
5758
self.tree_structure = {}
5859
self.num_nodes = 0
5960
self.idx_leaf_nodes = []
6061
self.idx_prunable_split_nodes = []
6162
self.tree_id = tree_id
6263
self.num_observations = num_observations
64+
self.m = m
6365

6466
def __getitem__(self, index):
6567
return self.get_node(index)
@@ -94,16 +96,14 @@ def predict_output(self):
9496

9597
return output.astype(aesara.config.floatX)
9698

97-
def predict_out_of_sample(self, X, m):
99+
def predict_out_of_sample(self, X):
98100
"""
99101
Predict output of tree for an unobserved point x.
100102
101103
Parameters
102104
----------
103105
X : numpy array
104106
Unobserved point
105-
m : int
106-
Number of trees
107107
108108
Returns
109109
-------
@@ -116,7 +116,7 @@ def predict_out_of_sample(self, X, m):
116116
return leaf_node.value
117117
else:
118118
x = X[split_variable].item()
119-
y_x = (linear_params[0] + linear_params[1] * x) / m
119+
y_x = (linear_params[0] + linear_params[1] * x) / self.m
120120
return y_x + linear_params[2]
121121

122122
def _traverse_tree(self, x, node_index=0, split_variable=None):
@@ -170,20 +170,22 @@ def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_no
170170
self.idx_prunable_split_nodes.remove(parent_index)
171171

172172
@staticmethod
173-
def init_tree(tree_id, leaf_node_value, idx_data_points):
173+
def init_tree(tree_id, leaf_node_value, idx_data_points, m):
174174
"""
175175
176176
Parameters
177177
----------
178178
tree_id
179179
leaf_node_value
180180
idx_data_points
181+
m : int
182+
number of trees in BART
181183
182184
Returns
183185
-------
184186
185187
"""
186-
new_tree = Tree(tree_id, len(idx_data_points))
188+
new_tree = Tree(tree_id, len(idx_data_points), m)
187189
new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
188190
return new_tree
189191

0 commit comments

Comments
 (0)