File tree 2 files changed +17
-7
lines changed 2 files changed +17
-7
lines changed Original file line number Diff line number Diff line change @@ -74,9 +74,9 @@ class BART(Distribution):
74
74
75
75
Parameters
76
76
----------
77
- X : TensorLike
77
+ X : PyTensor Variable, Pandas/Polars DataFrame or Numpy array
78
78
The covariate matrix.
79
- Y : TensorLike
79
+ Y : PyTensor Variable, Pandas/Polar DataFrame/Series,or Numpy array
80
80
The response vector.
81
81
m : int
82
82
Number of trees.
@@ -204,6 +204,16 @@ def preprocess_xy(
204
204
if isinstance (X , (Series , DataFrame )):
205
205
X = X .to_numpy ()
206
206
207
+ try :
208
+ import polars as pl
209
+
210
+ if isinstance (X , (pl .Series , pl .DataFrame )):
211
+ X = X .to_numpy ()
212
+ if isinstance (Y , (pl .Series , pl .DataFrame )):
213
+ Y = Y .to_numpy ()
214
+ except ImportError :
215
+ pass
216
+
207
217
Y = Y .astype (float )
208
218
X = X .astype (float )
209
219
Original file line number Diff line number Diff line change @@ -546,7 +546,7 @@ def _prepare_plot_data(
546
546
547
547
Parameters
548
548
----------
549
- X : PyTensor Variable, Pandas DataFrame or Numpy array
549
+ X : PyTensor Variable, Pandas DataFrame, Polars DataFrame or Numpy array
550
550
Input data.
551
551
Y : array-like
552
552
Target data.
@@ -585,9 +585,9 @@ def _prepare_plot_data(
585
585
if isinstance (X , Variable ):
586
586
X = X .eval ()
587
587
588
- if hasattr (X , "columns" ) and hasattr (X , "values " ):
588
+ if hasattr (X , "columns" ) and hasattr (X , "to_numpy " ):
589
589
x_names = list (X .columns )
590
- X = X .values
590
+ X = X .to_numpy ()
591
591
else :
592
592
x_names = []
593
593
@@ -750,9 +750,9 @@ def plot_variable_importance( # noqa: PLR0915
750
750
else :
751
751
shape = bartrv .eval ().shape [0 ]
752
752
753
- if hasattr (X , "columns" ) and hasattr (X , "values " ):
753
+ if hasattr (X , "columns" ) and hasattr (X , "to_numpy " ):
754
754
labels = X .columns
755
- X = X .values
755
+ X = X .to_numpy ()
756
756
757
757
n_vars = X .shape [1 ]
758
758
You can’t perform that action at this time.
0 commit comments