Skip to content

Commit d89e0b6

Browse files
authored
Support Polars (#179)
1 parent e8f258a commit d89e0b6

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

pymc_bart/bart.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ class BART(Distribution):
7474
7575
Parameters
7676
----------
77-
X : TensorLike
77+
X : PyTensor Variable, Pandas/Polars DataFrame or Numpy array
7878
The covariate matrix.
79-
Y : TensorLike
79+
Y : PyTensor Variable, Pandas/Polar DataFrame/Series,or Numpy array
8080
The response vector.
8181
m : int
8282
Number of trees.
@@ -204,6 +204,16 @@ def preprocess_xy(
204204
if isinstance(X, (Series, DataFrame)):
205205
X = X.to_numpy()
206206

207+
try:
208+
import polars as pl
209+
210+
if isinstance(X, (pl.Series, pl.DataFrame)):
211+
X = X.to_numpy()
212+
if isinstance(Y, (pl.Series, pl.DataFrame)):
213+
Y = Y.to_numpy()
214+
except ImportError:
215+
pass
216+
207217
Y = Y.astype(float)
208218
X = X.astype(float)
209219

pymc_bart/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def _prepare_plot_data(
546546
547547
Parameters
548548
----------
549-
X : PyTensor Variable, Pandas DataFrame or Numpy array
549+
X : PyTensor Variable, Pandas DataFrame, Polars DataFrame or Numpy array
550550
Input data.
551551
Y : array-like
552552
Target data.
@@ -585,9 +585,9 @@ def _prepare_plot_data(
585585
if isinstance(X, Variable):
586586
X = X.eval()
587587

588-
if hasattr(X, "columns") and hasattr(X, "values"):
588+
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
589589
x_names = list(X.columns)
590-
X = X.values
590+
X = X.to_numpy()
591591
else:
592592
x_names = []
593593

@@ -750,9 +750,9 @@ def plot_variable_importance( # noqa: PLR0915
750750
else:
751751
shape = bartrv.eval().shape[0]
752752

753-
if hasattr(X, "columns") and hasattr(X, "values"):
753+
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
754754
labels = X.columns
755-
X = X.values
755+
X = X.to_numpy()
756756

757757
n_vars = X.shape[1]
758758

0 commit comments

Comments
 (0)