Skip to content

Commit a004c59

Browse files
author
Tom Augspurger
committed
Merge pull request #9574 from schmohlio/fix-plot-legends
fixing pandas.DataFrame.plot(): labels do not appear in legend and label kwd
2 parents a01810e + c4b4cb7 commit a004c59

File tree

3 files changed

+21
-11
lines changed

3 files changed

+21
-11
lines changed

Diff for: doc/source/whatsnew/v0.16.1.txt

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ Performance Improvements
5050
Bug Fixes
5151
~~~~~~~~~
5252

53+
- Fixed bug (:issue:`9542`) where labels did not appear properly in legend of ``DataFrame.plot()``. Passing ``label=`` args also now works, and series indices are no longer mutated.
54+
5355

5456

5557

Diff for: pandas/tests/test_graphics.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -1064,12 +1064,6 @@ def test_implicit_label(self):
10641064
ax = df.plot(x='a', y='b')
10651065
self._check_text_labels(ax.xaxis.get_label(), 'a')
10661066

1067-
@slow
1068-
def test_explicit_label(self):
1069-
df = DataFrame(randn(10, 3), columns=['a', 'b', 'c'])
1070-
ax = df.plot(x='a', y='b', label='LABEL')
1071-
self._check_text_labels(ax.xaxis.get_label(), 'LABEL')
1072-
10731067
@slow
10741068
def test_donot_overwrite_index_name(self):
10751069
# GH 8494
@@ -2542,6 +2536,20 @@ def test_df_legend_labels(self):
25422536
ax = df3.plot(kind='scatter', x='g', y='h', label='data3', ax=ax)
25432537
self._check_legend_labels(ax, labels=['data1', 'data3'])
25442538

2539+
# ensure label args pass through and
2540+
# index name does not mutate
2541+
# column names don't mutate
2542+
df5 = df.set_index('a')
2543+
ax = df5.plot(y='b')
2544+
self._check_legend_labels(ax, labels=['b'])
2545+
ax = df5.plot(y='b', label='LABEL_b')
2546+
self._check_legend_labels(ax, labels=['LABEL_b'])
2547+
self._check_text_labels(ax.xaxis.get_label(), 'a')
2548+
ax = df5.plot(y='c', label='LABEL_c', ax=ax)
2549+
self._check_legend_labels(ax, labels=['LABEL_b','LABEL_c'])
2550+
self.assertTrue(df5.columns.tolist() == ['b','c'])
2551+
2552+
25452553
def test_legend_name(self):
25462554
multi = DataFrame(randn(4, 4),
25472555
columns=[np.array(['a', 'a', 'b', 'b']),

Diff for: pandas/tools/plotting.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -886,10 +886,11 @@ def _iter_data(self, data=None, keep_index=False, fillna=None):
886886

887887
from pandas.core.frame import DataFrame
888888
if isinstance(data, (Series, np.ndarray, Index)):
889+
label = self.label if self.label is not None else data.name
889890
if keep_index is True:
890-
yield self.label, data
891+
yield label, data
891892
else:
892-
yield self.label, np.asarray(data)
893+
yield label, np.asarray(data)
893894
elif isinstance(data, DataFrame):
894895
if self.sort_columns:
895896
columns = com._try_sort(data.columns)
@@ -2306,10 +2307,9 @@ def _plot(data, x=None, y=None, subplots=False,
23062307
if y is not None:
23072308
if com.is_integer(y) and not data.columns.holds_integer():
23082309
y = data.columns[y]
2309-
label = x if x is not None else data.index.name
2310-
label = kwds.pop('label', label)
2310+
label = kwds['label'] if 'label' in kwds else y
23112311
series = data[y].copy() # Don't modify
2312-
series.index.name = label
2312+
series.name = label
23132313

23142314
for kw in ['xerr', 'yerr']:
23152315
if (kw in kwds) and \

0 commit comments

Comments
 (0)