Skip to content

API for adding labels: mpf.make_addplot(..., label="myLabel") #605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 1, 2023
576 changes: 576 additions & 0 deletions examples/addplot_legends.ipynb

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions src/mplfinance/_arg_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib as mpl
import warnings


def _check_and_prepare_data(data, config):
'''
Check and Prepare the data input:
Expand Down Expand Up @@ -94,6 +95,19 @@ def _check_and_prepare_data(data, config):

return dates, opens, highs, lows, closes, volumes


def _label_validator(label_value):
''' Validates the input of [legend] label for added plots.
label_value may be a str or a sequence of str.
'''
if isinstance(label_value,str):
return True
if isinstance(label_value,(list,tuple,np.ndarray)):
if all([isinstance(v,str) for v in label_value]):
return True
return False


def _get_valid_plot_types(plottype=None):

_alias_types = {
Expand Down
2 changes: 1 addition & 1 deletion src/mplfinance/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version_info = (0, 12, 9, 'beta', 9)
version_info = (0, 12, 10, 'beta', 0)

_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}

Expand Down
43 changes: 35 additions & 8 deletions src/mplfinance/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from mplfinance import _styles

from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator, _label_validator
from mplfinance._arg_validators import _get_valid_plot_types, _fill_between_validator
from mplfinance._arg_validators import _process_kwargs, _validate_vkwargs_dict
from mplfinance._arg_validators import _kwarg_not_implemented, _bypass_kwarg_validation
Expand Down Expand Up @@ -765,6 +765,8 @@ def plot( data, **kwargs ):

elif not _list_of_dict(addplot):
raise TypeError('addplot must be `dict`, or `list of dict`, NOT '+str(type(addplot)))

contains_legend_label=[] # a list of axes that contains legend labels

for apdict in addplot:

Expand All @@ -788,10 +790,28 @@ def plot( data, **kwargs ):
else:
havedf = False # must be a single series or array
apdata = [apdata,] # make it iterable
if havedf and apdict['label']:
if not isinstance(apdict['label'],(list,tuple,np.ndarray)):
nlabels = 1
else:
nlabels = len(apdict['label'])
ncolumns = len(apdata.columns)
#print('nlabels=',nlabels,'ncolumns=',ncolumns)
if nlabels < ncolumns:
warnings.warn('\n =======================================\n'+
' addplot MISMATCH between data and labels:\n'+
' have '+str(ncolumns)+' columns to plot \n'+
' BUT '+str(nlabels)+' labels for them.\n')
colcount = 0
for column in apdata:
ydata = apdata.loc[:,column] if havedf else column
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config)
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount)
_addplot_apply_supplements(ax,apdict,xdates)
colcount += 1
if apdict['label']: # not supported for aptype == 'ohlc' or 'candle'
contains_legend_label.append(ax)
for ax in set(contains_legend_label): # there might be duplicates
ax.legend()

# fill_between is NOT supported for external_axes_mode
# (caller can easily call ax.fill_between() themselves).
Expand Down Expand Up @@ -1079,7 +1099,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
ax.autoscale_view()
return ax

def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
def _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount):
external_axes_mode = apdict['ax'] is not None
if not external_axes_mode:
secondary_y = False
Expand All @@ -1101,6 +1121,10 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
ax = apdict['ax']

aptype = apdict['type']
if isinstance(apdict['label'],(list,tuple,np.ndarray)):
label = apdict['label'][colcount]
else: # isinstance(...,str)
label = apdict['label']
if aptype == 'scatter':
size = apdict['markersize']
mark = apdict['marker']
Expand All @@ -1111,27 +1135,27 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):

if isinstance(mark,(list,tuple,np.ndarray)):
_mscatter(xdates, ydata, ax=ax, m=mark, s=size, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
else:
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
else:
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths,label=label)
elif aptype == 'bar':
width = 0.8 if apdict['width'] is None else apdict['width']
bottom = apdict['bottom']
color = apdict['color']
alpha = apdict['alpha']
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha,label=label)
elif aptype == 'line':
ls = apdict['linestyle']
color = apdict['color']
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
alpha = apdict['alpha']
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
elif aptype == 'step':
stepwhere = apdict['stepwhere']
ls = apdict['linestyle']
color = apdict['color']
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
alpha = apdict['alpha']
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha)
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
else:
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')

Expand Down Expand Up @@ -1384,6 +1408,9 @@ def _valid_addplot_kwargs():
'fill_between': { 'Default' : None, # added by Wen
'Description' : " fill region",
'Validator' : _fill_between_validator },
'label' : { 'Default' : None,
'Description' : 'Label for the added plot. One per added plot.',
'Validator' : _label_validator },

}

Expand Down