Skip to content

Commit c838fec

Browse files
authored
Merge pull request #4331 from Karl-Krauth/imshow-xarray-bugfixes
Fix facet_col and animation_frame in px.imshow for xarrays.
2 parents 83e5cfa + d833fd5 commit c838fec

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

Diff for: CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
All notable changes to this project will be documented in this file.
33
This project adheres to [Semantic Versioning](http://semver.org/).
44

5+
## UNRELEASED
6+
- Fixed two issues with px.imshow: [[#4330](https://github.com/plotly/plotly.py/issues/4330)] when facet_col is an earlier dimension than animation_frame for xarrays and [[#4329](https://github.com/plotly/plotly.py/issues/4329)] when facet_col has string coordinates in xarrays [[#4331](https://github.com/plotly/plotly.py/pull/4331)]
7+
58
## [5.16.1] - 2023-08-16
69

710
### Fixed

Diff for: packages/python/plotly/plotly/express/_imshow.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -265,14 +265,18 @@ def imshow(
265265
if xarray_imported and isinstance(img, xarray.DataArray):
266266
dims = list(img.dims)
267267
img_is_xarray = True
268+
pop_indexes = []
268269
if facet_col is not None:
269270
facet_slices = img.coords[img.dims[facet_col]].values
270-
_ = dims.pop(facet_col)
271+
pop_indexes.append(facet_col)
271272
facet_label = img.dims[facet_col]
272273
if animation_frame is not None:
273274
animation_slices = img.coords[img.dims[animation_frame]].values
274-
_ = dims.pop(animation_frame)
275+
pop_indexes.append(animation_frame)
275276
animation_label = img.dims[animation_frame]
277+
# Remove indices in sorted order.
278+
for index in sorted(pop_indexes, reverse=True):
279+
_ = dims.pop(index)
276280
y_label, x_label = dims[0], dims[1]
277281
# np.datetime64 is not handled correctly by go.Heatmap
278282
for ax in [x_label, y_label]:
@@ -541,7 +545,7 @@ def imshow(
541545
slice_label = (
542546
"facet_col" if labels.get("facet_col") is None else labels["facet_col"]
543547
)
544-
col_labels = ["%s=%d" % (slice_label, i) for i in facet_slices]
548+
col_labels = [f"{slice_label}={i}" for i in facet_slices]
545549
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
546550
for attr_name in ["height", "width"]:
547551
if args[attr_name]:

Diff for: packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py

+34
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,40 @@ def test_imshow_xarray_slicethrough():
194194
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))
195195

196196

197+
def test_imshow_xarray_facet_col_string():
198+
img = np.random.random((3, 4, 5))
199+
da = xr.DataArray(
200+
img, dims=["str_dim", "dim_1", "dim_2"], coords={"str_dim": ["A", "B", "C"]}
201+
)
202+
fig = px.imshow(da, facet_col="str_dim")
203+
# Dimensions are used for axis labels and coordinates
204+
assert fig.layout.xaxis.title.text == "dim_2"
205+
assert fig.layout.yaxis.title.text == "dim_1"
206+
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))
207+
208+
209+
def test_imshow_xarray_animation_frame_string():
210+
img = np.random.random((3, 4, 5))
211+
da = xr.DataArray(
212+
img, dims=["str_dim", "dim_1", "dim_2"], coords={"str_dim": ["A", "B", "C"]}
213+
)
214+
fig = px.imshow(da, animation_frame="str_dim")
215+
# Dimensions are used for axis labels and coordinates
216+
assert fig.layout.xaxis.title.text == "dim_2"
217+
assert fig.layout.yaxis.title.text == "dim_1"
218+
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))
219+
220+
221+
def test_imshow_xarray_animation_facet_slicethrough():
222+
img = np.random.random((3, 4, 5, 6))
223+
da = xr.DataArray(img, dims=["dim_0", "dim_1", "dim_2", "dim_3"])
224+
fig = px.imshow(da, facet_col="dim_0", animation_frame="dim_1")
225+
# Dimensions are used for axis labels and coordinates
226+
assert fig.layout.xaxis.title.text == "dim_3"
227+
assert fig.layout.yaxis.title.text == "dim_2"
228+
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_3"]))
229+
230+
197231
def test_imshow_labels_and_ranges():
198232
fig = px.imshow(
199233
[[1, 2], [3, 4], [5, 6]],

0 commit comments

Comments
 (0)