Skip to content

Commit 0917ab4

Browse files
authored
fix bug (#104)
1 parent aedee25 commit 0917ab4

File tree

2 files changed

+16
-23
lines changed

2 files changed

+16
-23
lines changed

pymc_bart/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,11 @@ def _traverse_tree(
296296
params[0][nd_dims] + params[1][nd_dims] * X[..., idx_split_variable]
297297
)
298298
else:
299+
idx_split_variable = node.idx_split_variable
299300
left_node_index, right_node_index = get_idx_left_child(
300301
node_index
301302
), get_idx_right_child(node_index)
302-
idx_split_variable = node.idx_split_variable
303-
if excluded is not None and node.idx_split_variable in excluded:
303+
if excluded is not None and idx_split_variable in excluded:
304304
prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue
305305
stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable))
306306
stack.append(

pymc_bart/utils.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -408,14 +408,14 @@ def plot_pdp(
408408
fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax)
409409

410410
count = 0
411+
fake_X = _create_pdp_data(X, xs_interval, xs_values)
411412
for var in range(len(var_idx)):
412413
excluded = indices[:]
413414
excluded.remove(var)
414-
fake_X, new_x = _create_pdp_data(X, xs_interval, var, xs_values, var_discrete)
415415
p_d = _sample_posterior(
416416
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
417417
)
418-
418+
new_x = fake_X[:, var]
419419
for s_i in range(shape):
420420
p_di = func(p_d[:, :, s_i])
421421
if var in var_discrete:
@@ -621,10 +621,8 @@ def _prepare_plot_data(
621621
def _create_pdp_data(
622622
X: npt.NDArray[np.float_],
623623
xs_interval: str,
624-
var: int,
625624
xs_values: Optional[Union[int, List[float]]] = None,
626-
var_discrete: Optional[List[int]] = None,
627-
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]:
625+
) -> npt.NDArray[np.float_]:
628626
"""
629627
Create data for partial dependence plot.
630628
@@ -636,28 +634,23 @@ def _create_pdp_data(
636634
Interval for x-axis. Available options are 'insample', 'linear' or 'quantiles'.
637635
xs_values : int or list
638636
Number of points for 'linear' or list of quantiles for 'quantiles'.
639-
var : int
640-
Index of variable of interest
641-
var_discrete : None or list
642-
Indices of discrete variables.
643637
644638
Returns
645639
-------
646-
Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]
647-
A tuple containing a 2D array for the fake_X data and 1D array for new_x data.
640+
npt.NDArray[np.float_]
641+
A 2D array for the fake_X data.
648642
"""
649643
if xs_interval == "insample":
650-
return X, X[:, var]
644+
return X
651645
else:
652-
if var_discrete is not None and var in var_discrete:
653-
new_x = np.unique(X[:, var])
654-
else:
655-
if xs_interval == "linear" and isinstance(xs_values, int):
656-
new_x = np.linspace(np.nanmin(X[:, var]), np.nanmax(X[:, var]), xs_values)
657-
elif xs_interval == "quantiles" and isinstance(xs_values, list):
658-
new_x = np.quantile(X[:, var], q=xs_values)
659-
660-
return np.tile(new_x[:, None], X.shape[1]), new_x
646+
if xs_interval == "linear" and isinstance(xs_values, int):
647+
min_vals = np.min(X, axis=0)
648+
max_vals = np.max(X, axis=0)
649+
fake_X = np.linspace(min_vals, max_vals, num=xs_values, axis=0)
650+
elif xs_interval == "quantiles" and isinstance(xs_values, list):
651+
fake_X = np.quantile(X, q=xs_values, axis=0)
652+
653+
return fake_X
661654

662655

663656
def _smooth_mean(

0 commit comments

Comments
 (0)