diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a0f13b6f00..65e02484d7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). +## UNRELEASED + - Fixed two issues with px.imshow: [[#4330](https://github.com/plotly/plotly.py/issues/4330)] when facet_col is an earlier dimension than animation_frame for xarrays and [[#4329](https://github.com/plotly/plotly.py/issues/4329)] when facet_col has string coordinates in xarrays [[#4331](https://github.com/plotly/plotly.py/pull/4331)] + ## [5.16.1] - 2023-08-16 ### Fixed diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 6478cc6e274..de0e22284b4 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -265,14 +265,18 @@ def imshow( if xarray_imported and isinstance(img, xarray.DataArray): dims = list(img.dims) img_is_xarray = True + pop_indexes = [] if facet_col is not None: facet_slices = img.coords[img.dims[facet_col]].values - _ = dims.pop(facet_col) + pop_indexes.append(facet_col) facet_label = img.dims[facet_col] if animation_frame is not None: animation_slices = img.coords[img.dims[animation_frame]].values - _ = dims.pop(animation_frame) + pop_indexes.append(animation_frame) animation_label = img.dims[animation_frame] + # Remove indices in sorted order. + for index in sorted(pop_indexes, reverse=True): + _ = dims.pop(index) y_label, x_label = dims[0], dims[1] # np.datetime64 is not handled correctly by go.Heatmap for ax in [x_label, y_label]: @@ -541,7 +545,7 @@ def imshow( slice_label = ( "facet_col" if labels.get("facet_col") is None else labels["facet_col"] ) - col_labels = ["%s=%d" % (slice_label, i) for i in facet_slices] + col_labels = [f"{slice_label}={i}" for i in facet_slices] fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) for attr_name in ["height", "width"]: if args[attr_name]: diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py index 09f1ae1d90f..c2e863c846b 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_imshow.py @@ -194,6 +194,40 @@ def test_imshow_xarray_slicethrough(): assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"])) +def test_imshow_xarray_facet_col_string(): + img = np.random.random((3, 4, 5)) + da = xr.DataArray( + img, dims=["str_dim", "dim_1", "dim_2"], coords={"str_dim": ["A", "B", "C"]} + ) + fig = px.imshow(da, facet_col="str_dim") + # Dimensions are used for axis labels and coordinates + assert fig.layout.xaxis.title.text == "dim_2" + assert fig.layout.yaxis.title.text == "dim_1" + assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"])) + + +def test_imshow_xarray_animation_frame_string(): + img = np.random.random((3, 4, 5)) + da = xr.DataArray( + img, dims=["str_dim", "dim_1", "dim_2"], coords={"str_dim": ["A", "B", "C"]} + ) + fig = px.imshow(da, animation_frame="str_dim") + # Dimensions are used for axis labels and coordinates + assert fig.layout.xaxis.title.text == "dim_2" + assert fig.layout.yaxis.title.text == "dim_1" + assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"])) + + +def test_imshow_xarray_animation_facet_slicethrough(): + img = np.random.random((3, 4, 5, 6)) + da = xr.DataArray(img, dims=["dim_0", "dim_1", "dim_2", "dim_3"]) + fig = px.imshow(da, facet_col="dim_0", animation_frame="dim_1") + # Dimensions are used for axis labels and coordinates + assert fig.layout.xaxis.title.text == "dim_3" + assert fig.layout.yaxis.title.text == "dim_2" + assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_3"])) + + def test_imshow_labels_and_ranges(): fig = px.imshow( [[1, 2], [3, 4], [5, 6]],