Skip to content

Px special inputs #2330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
@@ -55,6 +55,11 @@
get_trendline_results,
)

from ._special_inputs import ( # noqa: F401
IdentityMap,
Constant,
)

from . import data, colors # noqa: F401

__all__ = [
@@ -95,4 +100,6 @@
"colors",
"set_mapbox_access_token",
"get_trendline_results",
"IdentityMap",
"Constant",
]
29 changes: 24 additions & 5 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import plotly.graph_objs as go
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant

from _plotly_utils.basevalidators import ColorscaleValidator
from .colors import qualitative, sequential
@@ -41,6 +42,7 @@ def __init__(self):
defaults = PxDefaults()
del PxDefaults


MAPBOX_TOKEN = None


@@ -137,11 +139,15 @@ def make_mapping(args, variable):
if variable == "dash":
arg_name = "line_dash"
vprefix = "line_dash"
if args[vprefix + "_map"] == "identity":
val_map = IdentityMap()
else:
val_map = args[vprefix + "_map"].copy()
return Mapping(
show_in_trace_name=True,
variable=variable,
grouper=args[arg_name],
val_map=args[vprefix + "_map"].copy(),
val_map=val_map,
sequence=args[vprefix + "_sequence"],
updater=lambda trace, v: trace.update({parent: {variable: v}}),
facet=None,
@@ -919,6 +925,8 @@ def build_dataframe(args, attrables, array_attrables):
else:
df_output[df_input.columns] = df_input[df_input.columns]

constants = dict()

# Loop over possible arguments
for field_name in attrables:
# Massaging variables
@@ -950,8 +958,15 @@ def build_dataframe(args, attrables, array_attrables):
"pandas MultiIndex is not supported by plotly express "
"at the moment." % field
)
# ----------------- argument is a constant ----------------------
if isinstance(argument, Constant):
col_name = _check_name_not_reserved(
str(argument.label) if argument.label is not None else field,
reserved_names,
)
constants[col_name] = argument.value
# ----------------- argument is a col name ----------------------
if isinstance(argument, str) or isinstance(
elif isinstance(argument, str) or isinstance(
argument, int
): # just a column name given as str or int
if not df_provided:
@@ -1032,6 +1047,9 @@ def build_dataframe(args, attrables, array_attrables):
else:
args[field_name][i] = str(col_name)

for col_name in constants:
df_output[col_name] = constants[col_name]

args["data_frame"] = df_output
return args

@@ -1402,9 +1420,10 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
for col, val, m in zip(grouper, group_name, grouped_mappings):
if col != one_group:
key = get_label(args, col)
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if not isinstance(m.val_map, IdentityMap):
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if m.variable == "animation_frame":
frame_name = val
trace_name = ", ".join(trace_name_labels.values())
29 changes: 29 additions & 0 deletions packages/python/plotly/plotly/express/_special_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class IdentityMap(object):
"""
`dict`-like object which acts as if the value for any key is the key itself. Objects
of this class can be passed in to arguments like `color_discrete_map` to
use the provided data values as colors, rather than mapping them to colors cycled
from `color_discrete_sequence`. This works for any `_map` argument to Plotly Express
functions, such as `line_dash_map` and `symbol_map`.
"""

def __getitem__(self, key):
return key

def __contains__(self, key):
return True

def copy(self):
return self


class Constant(object):
"""
Objects of this class can be passed to Plotly Express functions that expect column
identifiers or list-like objects to indicate that this attribute should take on a
constant value. An optional label can be provided.
"""

def __init__(self, value, label=None):
self.value = value
self.label = label
Original file line number Diff line number Diff line change
@@ -323,3 +323,61 @@ def test_size_column():
df = px.data.tips()
fig = px.scatter(df, x=df["size"], y=df.tip)
assert fig.data[0].hovertemplate == "size=%{x}<br>tip=%{y}<extra></extra>"


def test_identity_map():
fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=["red", "blue"],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"
assert "color=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"

fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=["red", "blue"],
color_discrete_map="identity",
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"
assert "color=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"


def test_constants():
fig = px.scatter(x=px.Constant(1), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" in fig.data[0].hovertemplate

fig = px.scatter(x=px.Constant(1, label="time"), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" not in fig.data[0].hovertemplate
assert "time=" in fig.data[0].hovertemplate

fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=px.Constant("red", label="the_identity_label"),
hover_data=[px.Constant("data", label="the_data")],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[0].customdata[0][0] == "data"
assert fig.data[1].marker.color == "red"
assert "color=" not in fig.data[0].hovertemplate
assert "the_identity_label=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert "the_data=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"