Skip to content

Commit e430257

Browse files
authored
Merge pull request #4286 from MarcoGorelli/dont-convert-everything
only interchange necessary columns
2 parents da860db + f189576 commit e430257

File tree

3 files changed

+97
-31
lines changed

3 files changed

+97
-31
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
2323
this feature was anonymously sponsored: thank you to our sponsor!
2424
- Add `legend.xref` and `legend.yref` to enable container-referenced positioning of legends [[#6589](https://github.com/plotly/plotly.js/pull/6589)], with thanks to [Gamma Technologies](https://www.gtisoft.com/) for sponsoring the related development.
2525
- Add `colorbar.xref` and `colorbar.yref` to enable container-referenced positioning of colorbars [[#6593](https://github.com/plotly/plotly.js/pull/6593)], with thanks to [Gamma Technologies](https://www.gtisoft.com/) for sponsoring the related development.
26-
- `px` methods now accept data-frame-like objects that support a `to_pandas()` method, such as polars, cudf, vaex etc
26+
- `px` methods now accept data-frame-like objects that support a `to_pandas()` method, such as polars, cudf, vaex etc [[#4244](https://github.com/plotly/plotly.py/pull/4244)], [[#4286](https://github.com/plotly/plotly.py/pull/4286)]
2727

2828
### Fixed
2929
- Fixed another compatibility issue with Pandas 2.0, just affecting `px.*(line_close=True)` [[#4190](https://github.com/plotly/plotly.py/pull/4190)]

packages/python/plotly/plotly/express/_core.py

+56-27
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,7 @@ def _get_reserved_col_names(args):
10181018
return reserved_names
10191019

10201020

1021-
def _is_col_list(df_input, arg):
1021+
def _is_col_list(columns, arg):
10221022
"""Returns True if arg looks like it's a list of columns or references to columns
10231023
in df_input, and False otherwise (in which case it's assumed to be a single column
10241024
or reference to a column).
@@ -1033,7 +1033,7 @@ def _is_col_list(df_input, arg):
10331033
return False # not iterable
10341034
for c in arg:
10351035
if isinstance(c, str) or isinstance(c, int):
1036-
if df_input is None or c not in df_input.columns:
1036+
if columns is None or c not in columns:
10371037
return False
10381038
else:
10391039
try:
@@ -1059,8 +1059,8 @@ def _isinstance_listlike(x):
10591059
return True
10601060

10611061

1062-
def _escape_col_name(df_input, col_name, extra):
1063-
while df_input is not None and (col_name in df_input.columns or col_name in extra):
1062+
def _escape_col_name(columns, col_name, extra):
1063+
while columns is not None and (col_name in columns or col_name in extra):
10641064
col_name = "_" + col_name
10651065
return col_name
10661066

@@ -1307,37 +1307,36 @@ def build_dataframe(args, constructor):
13071307

13081308
# Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
13091309
df_provided = args["data_frame"] is not None
1310+
needs_interchanging = False
13101311
if df_provided and not isinstance(args["data_frame"], pd.DataFrame):
13111312
if hasattr(args["data_frame"], "__dataframe__") and version.parse(
13121313
pd.__version__
13131314
) >= version.parse("2.0.2"):
13141315
import pandas.api.interchange
13151316

13161317
df_not_pandas = args["data_frame"]
1317-
try:
1318-
df_pandas = pandas.api.interchange.from_dataframe(df_not_pandas)
1319-
except (ImportError, NotImplementedError) as exc:
1320-
# temporary workaround; developers of third-party libraries themselves
1321-
# should try a different implementation, if available. For example:
1322-
# def __dataframe__(self, ...):
1323-
# if not some_condition:
1324-
# self.to_pandas(...)
1325-
if not hasattr(df_not_pandas, "to_pandas"):
1326-
raise exc
1327-
df_pandas = df_not_pandas.to_pandas()
1328-
args["data_frame"] = df_pandas
1318+
args["data_frame"] = df_not_pandas.__dataframe__()
1319+
columns = args["data_frame"].column_names()
1320+
needs_interchanging = True
13291321
elif hasattr(args["data_frame"], "to_pandas"):
13301322
args["data_frame"] = args["data_frame"].to_pandas()
1323+
columns = args["data_frame"].columns
13311324
else:
13321325
args["data_frame"] = pd.DataFrame(args["data_frame"])
1326+
columns = args["data_frame"].columns
1327+
elif df_provided:
1328+
columns = args["data_frame"].columns
1329+
else:
1330+
columns = None
1331+
13331332
df_input = args["data_frame"]
13341333

13351334
# now we handle special cases like wide-mode or x-xor-y specification
13361335
# by rearranging args to tee things up for process_args_into_dataframe to work
13371336
no_x = args.get("x") is None
13381337
no_y = args.get("y") is None
1339-
wide_x = False if no_x else _is_col_list(df_input, args["x"])
1340-
wide_y = False if no_y else _is_col_list(df_input, args["y"])
1338+
wide_x = False if no_x else _is_col_list(columns, args["x"])
1339+
wide_y = False if no_y else _is_col_list(columns, args["y"])
13411340

13421341
wide_mode = False
13431342
var_name = None # will likely be "variable" in wide_mode
@@ -1352,15 +1351,18 @@ def build_dataframe(args, constructor):
13521351
)
13531352
if df_provided and no_x and no_y:
13541353
wide_mode = True
1355-
if isinstance(df_input.columns, pd.MultiIndex):
1354+
if isinstance(columns, pd.MultiIndex):
13561355
raise TypeError(
13571356
"Data frame columns is a pandas MultiIndex. "
13581357
"pandas MultiIndex is not supported by plotly express "
13591358
"at the moment."
13601359
)
1361-
args["wide_variable"] = list(df_input.columns)
1362-
var_name = df_input.columns.name
1363-
if var_name in [None, "value", "index"] or var_name in df_input:
1360+
args["wide_variable"] = list(columns)
1361+
if isinstance(columns, pd.Index):
1362+
var_name = columns.name
1363+
else:
1364+
var_name = None
1365+
if var_name in [None, "value", "index"] or var_name in columns:
13641366
var_name = "variable"
13651367
if constructor == go.Funnel:
13661368
wide_orientation = args.get("orientation") or "h"
@@ -1371,12 +1373,12 @@ def build_dataframe(args, constructor):
13711373
elif wide_x != wide_y:
13721374
wide_mode = True
13731375
args["wide_variable"] = args["y"] if wide_y else args["x"]
1374-
if df_provided and args["wide_variable"] is df_input.columns:
1375-
var_name = df_input.columns.name
1376+
if df_provided and args["wide_variable"] is columns:
1377+
var_name = columns.name
13761378
if isinstance(args["wide_variable"], pd.Index):
13771379
args["wide_variable"] = list(args["wide_variable"])
13781380
if var_name in [None, "value", "index"] or (
1379-
df_provided and var_name in df_input
1381+
df_provided and var_name in columns
13801382
):
13811383
var_name = "variable"
13821384
if hist1d_orientation:
@@ -1389,8 +1391,35 @@ def build_dataframe(args, constructor):
13891391
wide_cross_name = "__x__" if wide_y else "__y__"
13901392

13911393
if wide_mode:
1392-
value_name = _escape_col_name(df_input, "value", [])
1393-
var_name = _escape_col_name(df_input, var_name, [])
1394+
value_name = _escape_col_name(columns, "value", [])
1395+
var_name = _escape_col_name(columns, var_name, [])
1396+
1397+
if needs_interchanging:
1398+
try:
1399+
if wide_mode or not hasattr(args["data_frame"], "select_columns_by_name"):
1400+
args["data_frame"] = pd.api.interchange.from_dataframe(
1401+
args["data_frame"]
1402+
)
1403+
else:
1404+
# Save precious resources by only interchanging columns that are
1405+
# actually going to be plotted.
1406+
columns = [
1407+
i for i in args.values() if isinstance(i, str) and i in columns
1408+
]
1409+
args["data_frame"] = pd.api.interchange.from_dataframe(
1410+
args["data_frame"].select_columns_by_name(columns)
1411+
)
1412+
except (ImportError, NotImplementedError) as exc:
1413+
# temporary workaround; developers of third-party libraries themselves
1414+
# should try a different implementation, if available. For example:
1415+
# def __dataframe__(self, ...):
1416+
# if not some_condition:
1417+
# self.to_pandas(...)
1418+
if not hasattr(df_not_pandas, "to_pandas"):
1419+
raise exc
1420+
args["data_frame"] = df_not_pandas.to_pandas()
1421+
1422+
df_input = args["data_frame"]
13941423

13951424
missing_bar_dim = None
13961425
if (

packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py

+40-3
Original file line numberDiff line numberDiff line change
@@ -252,21 +252,58 @@ def test_build_df_with_index():
252252
def test_build_df_using_interchange_protocol_mock(
253253
add_interchange_module_for_old_pandas,
254254
):
255+
class InterchangeDataFrame:
256+
def __init__(self, columns):
257+
self._columns = columns
258+
259+
def column_names(self):
260+
return self._columns
261+
262+
interchange_dataframe = InterchangeDataFrame(
263+
["petal_width", "sepal_length", "sepal_width"]
264+
)
265+
interchange_dataframe_reduced = InterchangeDataFrame(
266+
["petal_width", "sepal_length"]
267+
)
268+
interchange_dataframe.select_columns_by_name = mock.MagicMock(
269+
return_value=interchange_dataframe_reduced
270+
)
271+
interchange_dataframe_reduced.select_columns_by_name = mock.MagicMock(
272+
return_value=interchange_dataframe_reduced
273+
)
274+
255275
class CustomDataFrame:
256276
def __dataframe__(self):
257-
pass
277+
return interchange_dataframe
278+
279+
class CustomDataFrameReduced:
280+
def __dataframe__(self):
281+
return interchange_dataframe_reduced
258282

259283
input_dataframe = CustomDataFrame()
260-
args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length")
284+
input_dataframe_reduced = CustomDataFrameReduced()
261285

262286
iris_pandas = px.data.iris()
263287

264288
with mock.patch("pandas.__version__", "2.0.2"):
289+
args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length")
265290
with mock.patch(
266291
"pandas.api.interchange.from_dataframe", return_value=iris_pandas
267292
) as mock_from_dataframe:
268293
build_dataframe(args, go.Scatter)
269-
mock_from_dataframe.assert_called_once_with(input_dataframe)
294+
mock_from_dataframe.assert_called_once_with(interchange_dataframe_reduced)
295+
interchange_dataframe.select_columns_by_name.assert_called_with(
296+
["petal_width", "sepal_length"]
297+
)
298+
299+
args = dict(data_frame=input_dataframe_reduced, color=None)
300+
with mock.patch(
301+
"pandas.api.interchange.from_dataframe",
302+
return_value=iris_pandas[["petal_width", "sepal_length"]],
303+
) as mock_from_dataframe:
304+
build_dataframe(args, go.Scatter)
305+
mock_from_dataframe.assert_called_once_with(interchange_dataframe_reduced)
306+
interchange_dataframe_reduced.select_columns_by_name.assert_not_called()
270307

271308

272309
@pytest.mark.skipif(

0 commit comments

Comments
 (0)