diff --git a/packages/python/plotly/plotly/figure_factory/_dendrogram.py b/packages/python/plotly/plotly/figure_factory/_dendrogram.py index 380487e9a82..4164d0a05b6 100644 --- a/packages/python/plotly/plotly/figure_factory/_dendrogram.py +++ b/packages/python/plotly/plotly/figure_factory/_dendrogram.py @@ -25,14 +25,21 @@ def create_dendrogram( color_threshold=None, ): """ - Function that returns a dendrogram Plotly figure object. + Function that returns a dendrogram Plotly figure object. This is a thin + wrapper around scipy.cluster.hierarchy.dendrogram. See also https://dash.plot.ly/dash-bio/clustergram. :param (ndarray) X: Matrix of observations as array of arrays :param (str) orientation: 'top', 'right', 'bottom', or 'left' :param (list) labels: List of axis category labels(observation labels) - :param (list) colorscale: Optional colorscale for dendrogram tree + :param (list) colorscale: Optional colorscale for the dendrogram tree. + Requires 8 colors to be specified, the 7th of + which is ignored. With scipy>=1.5.0, the 2nd, 3rd + and 6th are used twice as often as the others. + Given a shorter list, the missing values are + replaced with defaults and with a longer list the + extra values are ignored. :param (function) distfun: Function to compute the pairwise distance from the observations :param (function) linkagefun: Function to compute the linkage matrix from @@ -160,8 +167,8 @@ def __init__( if len(self.zero_vals) > len(yvals) + 1: # If the length of zero_vals is larger than the length of yvals, # it means that there are wrong vals because of the identicial samples. - # Three and more identicial samples will make the yvals of spliting center into 0 and it will \ - # accidentally take it as leaves. + # Three and more identicial samples will make the yvals of spliting + # center into 0 and it will accidentally take it as leaves. l_border = int(min(self.zero_vals)) r_border = int(max(self.zero_vals)) correct_leaves_pos = range( @@ -185,6 +192,9 @@ def get_color_dict(self, colorscale): # These are the color codes returned for dendrograms # We're replacing them with nicer colors + # This list is the colors that can be used by dendrogram, which were + # determined as the combination of the default above_threshold_color and + # the default color palette (see scipy/cluster/hierarchy.py) d = { "r": "red", "g": "green", @@ -193,12 +203,14 @@ def get_color_dict(self, colorscale): "m": "magenta", "y": "yellow", "k": "black", + # TODO: 'w' doesn't seem to be in the default color + # palette in scipy/cluster/hierarchy.py "w": "white", } default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0])) if colorscale is None: - colorscale = [ + rgb_colorscale = [ "rgb(0,116,217)", # blue "rgb(35,205,205)", # cyan "rgb(61,153,112)", # green @@ -206,13 +218,43 @@ def get_color_dict(self, colorscale): "rgb(133,20,75)", # magenta "rgb(255,65,54)", # red "rgb(255,255,255)", # white - "rgb(255,220,0)", - ] # yellow + "rgb(255,220,0)", # yellow + ] + else: + rgb_colorscale = colorscale for i in range(len(default_colors.keys())): k = list(default_colors.keys())[i] # PY3 won't index keys - if i < len(colorscale): - default_colors[k] = colorscale[i] + if i < len(rgb_colorscale): + default_colors[k] = rgb_colorscale[i] + + # add support for cyclic format colors as introduced in scipy===1.5.0 + # before this, the colors were named 'r', 'b', 'y' etc., now they are + # named 'C0', 'C1', etc. To keep the colors consistent regardless of the + # scipy version, we try as much as possible to map the new colors to the + # old colors + # this mapping was found by inpecting scipy/cluster/hierarchy.py (see + # comment above). + new_old_color_map = [ + ("C0", "b"), + ("C1", "g"), + ("C2", "r"), + ("C3", "c"), + ("C4", "m"), + ("C5", "y"), + ("C6", "k"), + ("C7", "g"), + ("C8", "r"), + ("C9", "c"), + ] + for nc, oc in new_old_color_map: + try: + default_colors[nc] = default_colors[oc] + except KeyError: + # it could happen that the old color isn't found (if a custom + # colorscale was specified), in this case we set it to an + # arbitrary default. + default_colors[n] = "rgb(0,116,217)" return default_colors