Skip to content

type annotations for chainable Figure methods #3708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
27 changes: 16 additions & 11 deletions packages/python/plotly/codegen/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class {datatype_class}(_{node.name_base_datatype}):\n"""
buffer.write(
f"""
_subplotid_prop_names = {repr(subplot_names)}

import re
_subplotid_prop_re = re.compile(
'^(' + '|'.join(_subplotid_prop_names) + r')(\d+)$')
Expand Down Expand Up @@ -147,7 +147,7 @@ def _subplotid_validators(self):
from plotly.validators.layout import ({validator_csv})

return {subplot_dict_str}

def _subplot_re_match(self, prop):
return self._subplotid_prop_re.match(prop)
"""
Expand Down Expand Up @@ -208,7 +208,7 @@ def _subplot_re_match(self, prop):
# #### Combine to form property docstring ####
if property_description.strip():
property_docstring = f"""{property_description}

{validator_description}"""
else:
property_docstring = f" {validator_description}"
Expand Down Expand Up @@ -342,8 +342,8 @@ def __init__(self"""
arg = _copy.copy(arg)
else:
raise ValueError(\"\"\"\\
The first argument to the {class_name}
constructor must be a dict or
The first argument to the {class_name}
constructor must be a dict or
an instance of :class:`{class_name}`\"\"\")

# Handle skip_invalid
Expand Down Expand Up @@ -389,11 +389,11 @@ def __init__(self"""

buffer.write(
f"""

# Process unknown kwargs
# ----------------------
self._process_kwargs(**dict(arg, **kwargs))

# Reset skip_invalid
# ------------------
self._skip_invalid = False
Expand Down Expand Up @@ -429,7 +429,9 @@ def reindent_validator_description(validator, extra_indent):
return ("\n" + " " * extra_indent).join(validator.description().strip().split("\n"))


def add_constructor_params(buffer, subtype_nodes, prepend_extras=(), append_extras=()):
def add_constructor_params(
buffer, subtype_nodes, prepend_extras=(), append_extras=(), output_type=None
):
"""
Write datatype constructor params to a buffer

Expand Down Expand Up @@ -470,9 +472,12 @@ def add_constructor_params(buffer, subtype_nodes, prepend_extras=(), append_extr
**kwargs"""
)
buffer.write(
f"""
):"""
"""
)"""
)
if output_type:
buffer.write(f"-> '{output_type}'")
buffer.write(":")


def add_docstring(
Expand Down Expand Up @@ -525,7 +530,7 @@ def add_docstring(
f"""
\"\"\"
{header}

{node_description} Parameters
----------"""
)
Expand Down
111 changes: 93 additions & 18 deletions packages/python/plotly/codegen/figure.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from io import StringIO
from os import path as opath

from _plotly_utils.basevalidators import (
BaseDataValidator,
CompoundValidator,
CompoundArrayValidator,
)
from codegen.datatypes import (
reindent_validator_description,
add_constructor_params,
add_docstring,
)
from codegen.utils import PlotlyNode, write_source_py
from codegen.utils import write_source_py

import inflect
from plotly.basedatatypes import BaseFigure


def build_figure_py(
Expand Down Expand Up @@ -118,6 +114,82 @@ def __init__(self, data=None, layout=None,
"""
)

def add_wrapper(wrapped_name, full_params, param_list):
buffer.write(
f"""
def {wrapped_name}(self, {full_params}) -> "{fig_classname}":
'''
{getattr(BaseFigure, wrapped_name).__doc__}
'''
return super({fig_classname}, self).{wrapped_name}({param_list})
"""
)

add_wrapper(
"update",
"dict1=None, overwrite=False, **kwargs",
"dict1, overwrite, **kwargs",
)

add_wrapper(
"update_traces",
"patch=None, selector=None, row=None, col=None, secondary_y=None, overwrite=False, **kwargs",
"patch, selector, row, col, secondary_y, overwrite, **kwargs",
)

add_wrapper(
"update_layout",
"dict1=None, overwrite=False, **kwargs",
"dict1, overwrite, **kwargs",
)

add_wrapper(
"for_each_trace",
"fn, selector=None, row=None, col=None, secondary_y=None",
"fn, selector, row, col, secondary_y",
)

add_wrapper(
"add_trace",
"trace, row=None, col=None, secondary_y=None, exclude_empty_subplots=False",
"trace, row, col, secondary_y, exclude_empty_subplots",
)

add_wrapper(
"add_traces",
"data,rows=None,cols=None,secondary_ys=None,exclude_empty_subplots=False",
"data,rows,cols,secondary_ys,exclude_empty_subplots",
)

add_wrapper(
"add_vline",
'x,row="all",col="all",exclude_empty_subplots=True,annotation=None,**kwargs',
"x,row,col,exclude_empty_subplots,annotation,**kwargs",
)

add_wrapper(
"add_hline",
'y,row="all",col="all",exclude_empty_subplots=True,annotation=None,**kwargs',
"y,row,col,exclude_empty_subplots,annotation,**kwargs",
)

add_wrapper(
"add_vrect",
'x0,x1,row="all",col="all",exclude_empty_subplots=True,annotation=None,**kwargs',
"x0,x1,row,col,exclude_empty_subplots,annotation,**kwargs",
)

add_wrapper(
"add_hrect",
'y0,y1,row="all",col="all",exclude_empty_subplots=True,annotation=None,**kwargs',
"y0,y1,row,col,exclude_empty_subplots,annotation,**kwargs",
)
add_wrapper(
"set_subplots",
"rows=None, cols=None, **make_subplots_args",
"rows, cols, **make_subplots_args",
)

# ### add_trace methods for each trace type ###
for trace_node in trace_nodes:

Expand All @@ -136,7 +208,10 @@ def add_{trace_node.plotly_name}(self"""
if include_secondary_y:
param_extras.append("secondary_y")
add_constructor_params(
buffer, trace_node.child_datatypes, append_extras=param_extras
buffer,
trace_node.child_datatypes,
append_extras=param_extras,
output_type=fig_classname,
)

# #### Docstring ####
Expand Down Expand Up @@ -193,15 +268,15 @@ def add_{trace_node.plotly_name}(self"""
"""
)

for i, subtype_node in enumerate(trace_node.child_datatypes):
for _, subtype_node in enumerate(trace_node.child_datatypes):
subtype_prop_name = subtype_node.name_property
buffer.write(
f"""
{subtype_prop_name}={subtype_prop_name},"""
)

buffer.write(
f"""
"""
**kwargs)"""
)

Expand All @@ -226,7 +301,7 @@ def add_{trace_node.plotly_name}(self"""
if singular_name == "yaxis":
secondary_y_1 = ", secondary_y=None"
secondary_y_2 = ", secondary_y=secondary_y"
secondary_y_docstring = f"""
secondary_y_docstring = """
secondary_y: boolean or None (default None)
* If True, only select yaxis objects associated with the secondary
y-axis of the subplot.
Expand Down Expand Up @@ -283,7 +358,7 @@ def select_{plural_name}(
'{singular_name}', selector, row, col{secondary_y_2})

def for_each_{singular_name}(
self, fn, selector=None, row=None, col=None{secondary_y_1}):
self, fn, selector=None, row=None, col=None{secondary_y_1}) -> '{fig_classname}':
\"\"\"
Apply a function to all {singular_name} objects that satisfy the
specified selection criteria
Expand Down Expand Up @@ -311,7 +386,7 @@ def for_each_{singular_name}(
Returns
-------
self
Returns the Figure object that the method was called on
Returns the {fig_classname} object that the method was called on
\"\"\"
for obj in self.select_{plural_name}(
selector=selector, row=row, col=col{secondary_y_2}):
Expand All @@ -325,7 +400,7 @@ def update_{plural_name}(
selector=None,
overwrite=False,
row=None, col=None{secondary_y_1},
**kwargs):
**kwargs) -> '{fig_classname}':
\"\"\"
Perform a property update operation on all {singular_name} objects
that satisfy the specified selection criteria
Expand Down Expand Up @@ -363,7 +438,7 @@ def update_{plural_name}(
Returns
-------
self
Returns the Figure object that the method was called on
Returns the {fig_classname} object that the method was called on
\"\"\"
for obj in self.select_{plural_name}(
selector=selector, row=row, col=col{secondary_y_2}):
Expand Down Expand Up @@ -477,7 +552,7 @@ def for_each_{method_prefix}{singular_name}(
Returns
-------
self
Returns the Figure object that the method was called on
Returns the {fig_classname} object that the method was called on
\"\"\"
for obj in self._select_annotations_like(
prop='{plural_name}',
Expand All @@ -498,7 +573,7 @@ def update_{method_prefix}{plural_name}(
col=None,
secondary_y=None,
**kwargs
):
) -> '{fig_classname}':
\"\"\"
Perform a property update operation on all {plural_name} that satisfy the
specified selection criteria
Expand Down Expand Up @@ -545,7 +620,7 @@ def update_{method_prefix}{plural_name}(
Returns
-------
self
Returns the Figure object that the method was called on
Returns the {fig_classname} object that the method was called on
\"\"\"
for obj in self._select_annotations_like(
prop='{plural_name}',
Expand Down Expand Up @@ -610,7 +685,7 @@ def add_{method_prefix}{singular_name}(self"""
"""
)

for i, subtype_node in enumerate(node.child_datatypes):
for _, subtype_node in enumerate(node.child_datatypes):
subtype_prop_name = subtype_node.name_property
buffer.write(
f"""
Expand Down
Loading