@@ -408,14 +408,14 @@ def plot_pdp(
408
408
fig , axes , shape = _get_axes (bartrv , var_idx , grid , sharey , figsize , ax )
409
409
410
410
count = 0
411
+ fake_X = _create_pdp_data (X , xs_interval , xs_values )
411
412
for var in range (len (var_idx )):
412
413
excluded = indices [:]
413
414
excluded .remove (var )
414
- fake_X , new_x = _create_pdp_data (X , xs_interval , var , xs_values , var_discrete )
415
415
p_d = _sample_posterior (
416
416
all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
417
417
)
418
-
418
+ new_x = fake_X [:, var ]
419
419
for s_i in range (shape ):
420
420
p_di = func (p_d [:, :, s_i ])
421
421
if var in var_discrete :
@@ -621,10 +621,8 @@ def _prepare_plot_data(
621
621
def _create_pdp_data (
622
622
X : npt .NDArray [np .float_ ],
623
623
xs_interval : str ,
624
- var : int ,
625
624
xs_values : Optional [Union [int , List [float ]]] = None ,
626
- var_discrete : Optional [List [int ]] = None ,
627
- ) -> Tuple [npt .NDArray [np .float_ ], npt .NDArray [np .float_ ]]:
625
+ ) -> npt .NDArray [np .float_ ]:
628
626
"""
629
627
Create data for partial dependence plot.
630
628
@@ -636,28 +634,23 @@ def _create_pdp_data(
636
634
Interval for x-axis. Available options are 'insample', 'linear' or 'quantiles'.
637
635
xs_values : int or list
638
636
Number of points for 'linear' or list of quantiles for 'quantiles'.
639
- var : int
640
- Index of variable of interest
641
- var_discrete : None or list
642
- Indices of discrete variables.
643
637
644
638
Returns
645
639
-------
646
- Tuple[ npt.NDArray[np.float_], npt.NDArray[np.float_] ]
647
- A tuple containing a 2D array for the fake_X data and 1D array for new_x data.
640
+ npt.NDArray[np.float_]
641
+ A 2D array for the fake_X data.
648
642
"""
649
643
if xs_interval == "insample" :
650
- return X , X [:, var ]
644
+ return X
651
645
else :
652
- if var_discrete is not None and var in var_discrete :
653
- new_x = np .unique (X [:, var ])
654
- else :
655
- if xs_interval == "linear" and isinstance (xs_values , int ):
656
- new_x = np .linspace (np .nanmin (X [:, var ]), np .nanmax (X [:, var ]), xs_values )
657
- elif xs_interval == "quantiles" and isinstance (xs_values , list ):
658
- new_x = np .quantile (X [:, var ], q = xs_values )
659
-
660
- return np .tile (new_x [:, None ], X .shape [1 ]), new_x
646
+ if xs_interval == "linear" and isinstance (xs_values , int ):
647
+ min_vals = np .min (X , axis = 0 )
648
+ max_vals = np .max (X , axis = 0 )
649
+ fake_X = np .linspace (min_vals , max_vals , num = xs_values , axis = 0 )
650
+ elif xs_interval == "quantiles" and isinstance (xs_values , list ):
651
+ fake_X = np .quantile (X , q = xs_values , axis = 0 )
652
+
653
+ return fake_X
661
654
662
655
663
656
def _smooth_mean (
0 commit comments