Skip to content

Commit 0f2b688

Browse files
committed
ENH: Allow elementwise coloring in background_gradient with axis=None #15204
1 parent ab6aaf7 commit 0f2b688

File tree

3 files changed

+63
-24
lines changed

3 files changed

+63
-24
lines changed

Diff for: doc/source/whatsnew/v0.24.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ Other Enhancements
1515
- :func:`to_datetime` now supports the ``%Z`` and ``%z`` directive when passed into ``format`` (:issue:`13486`)
1616
- :func:`Series.mode` and :func:`DataFrame.mode` now support the ``dropna`` parameter which can be used to specify whether NaN/NaT values should be considered (:issue:`17534`)
1717
- :func:`to_csv` now supports ``compression`` keyword when a file handle is passed. (:issue:`21227`)
18+
- Allow elementwise coloring in ``style.background_gradient`` with ``axis=None`` (:issue:`15204`)
19+
-
1820
-
1921

2022
.. _whatsnew_0240.api_breaking:

Diff for: pandas/io/formats/style.py

+33-24
Original file line numberDiff line numberDiff line change
@@ -913,21 +913,22 @@ def background_gradient(self, cmap='PuBu', low=0, high=0, axis=0,
913913
def _background_gradient(s, cmap='PuBu', low=0, high=0,
914914
text_color_threshold=0.408):
915915
"""Color background in a range according to the data."""
916+
if (not isinstance(text_color_threshold, (float, int)) or
917+
not 0 <= text_color_threshold <= 1):
918+
msg = "`text_color_threshold` must be a value from 0 to 1."
919+
raise ValueError(msg)
920+
916921
with _mpl(Styler.background_gradient) as (plt, colors):
917-
rng = s.max() - s.min()
922+
smin = s.values.min()
923+
smax = s.values.max()
924+
rng = smax - smin
918925
# extend lower / upper bounds, compresses color range
919-
norm = colors.Normalize(s.min() - (rng * low),
920-
s.max() + (rng * high))
921-
# matplotlib modifies inplace?
926+
norm = colors.Normalize(smin - (rng * low), smax + (rng * high))
927+
# matplotlib colors.Normalize modifies inplace?
922928
# https://github.com/matplotlib/matplotlib/issues/5427
923-
normed = norm(s.values)
924-
c = [colors.rgb2hex(x) for x in plt.cm.get_cmap(cmap)(normed)]
925-
if (not isinstance(text_color_threshold, (float, int)) or
926-
not 0 <= text_color_threshold <= 1):
927-
msg = "`text_color_threshold` must be a value from 0 to 1."
928-
raise ValueError(msg)
929+
rgbas = plt.cm.get_cmap(cmap)(norm(s.values))
929930

930-
def relative_luminance(color):
931+
def relative_luminance(rgba):
931932
"""
932933
Calculate relative luminance of a color.
933934
@@ -936,25 +937,33 @@ def relative_luminance(color):
936937
937938
Parameters
938939
----------
939-
color : matplotlib color
940-
Hex code, rgb-tuple, or HTML color name.
940+
color : rgb or rgba tuple
941941
942942
Returns
943943
-------
944944
float
945945
The relative luminance as a value from 0 to 1
946946
"""
947-
rgb = colors.colorConverter.to_rgba_array(color)[:, :3]
948-
rgb = np.where(rgb <= .03928, rgb / 12.92,
949-
((rgb + .055) / 1.055) ** 2.4)
950-
lum = rgb.dot([.2126, .7152, .0722])
951-
return lum.item()
952-
953-
text_colors = ['#f1f1f1' if relative_luminance(x) <
954-
text_color_threshold else '#000000' for x in c]
955-
956-
return ['background-color: {color};color: {tc}'.format(
957-
color=color, tc=tc) for color, tc in zip(c, text_colors)]
947+
r, g, b = (
948+
x / 12.92 if x <= 0.03928 else ((x + 0.055) / 1.055 ** 2.4)
949+
for x in rgba[:3]
950+
)
951+
return 0.2126 * r + 0.7152 * g + 0.0722 * b
952+
953+
def css(rgba):
954+
dark = relative_luminance(rgba) < text_color_threshold
955+
text_color = '#f1f1f1' if dark else '#000000'
956+
return 'background-color: {b};color: {c};'.format(
957+
b=colors.rgb2hex(rgba), c=text_color
958+
)
959+
960+
if s.ndim == 1:
961+
return [css(rgba) for rgba in rgbas]
962+
else:
963+
return pd.DataFrame(
964+
[[css(rgba) for rgba in row] for row in rgbas],
965+
index=s.index, columns=s.columns
966+
)
958967

959968
def set_properties(self, subset=None, **kwargs):
960969
"""

Diff for: pandas/tests/io/formats/test_style.py

+28
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,34 @@ def test_text_color_threshold_raises(self, text_color_threshold):
10561056
df.style.background_gradient(
10571057
text_color_threshold=text_color_threshold)._compute()
10581058

1059+
@td.skip_if_no_mpl
1060+
def test_background_gradient_axis(self):
1061+
df = pd.DataFrame([[1, 2], [2, 4]], columns=['A', 'B'])
1062+
1063+
low = ['background-color: #f7fbff', 'color: #000000']
1064+
high = ['background-color: #08306b', 'color: #f1f1f1']
1065+
mid = ['background-color: #abd0e6', 'color: #000000']
1066+
result = df.style.background_gradient(cmap='Blues',
1067+
axis=0)._compute().ctx
1068+
assert result[(0, 0)] == low
1069+
assert result[(0, 1)] == low
1070+
assert result[(1, 0)] == high
1071+
assert result[(1, 1)] == high
1072+
1073+
result = df.style.background_gradient(cmap='Blues',
1074+
axis=1)._compute().ctx
1075+
assert result[(0, 0)] == low
1076+
assert result[(0, 1)] == high
1077+
assert result[(1, 0)] == low
1078+
assert result[(1, 1)] == high
1079+
1080+
result = df.style.background_gradient(cmap='Blues',
1081+
axis=None)._compute().ctx
1082+
assert result[(0, 0)] == low
1083+
assert result[(0, 1)] == mid
1084+
assert result[(1, 0)] == mid
1085+
assert result[(1, 1)] == high
1086+
10591087

10601088
def test_block_names():
10611089
# catch accidental removal of a block

0 commit comments

Comments
 (0)