Skip to content

Commit 7fcb95c

Browse files
Import and initialization optimizations (#2368)
* Lazy imports of graph object hierarchy for Python 3.7+. This involved splitting validators/graph object classes back into separate files * lazy submodule imports for plotly and plotly.io modules * Don't force import pandas and numpy if we're only checking whether an object belongs to these packages. If pandas isn't loaded, we don't need to check whether a value is a DataFrame, and this way we don't pay the pandas import time. * Don't auto-import all trace type classes * Don't construct validators up front for every object when the object is created. Create them lazily, and cache them for use across graph objects of the same type * Test fixes * Add optional validation using go.validate object as callable or context manager * Delay additional imports * dynamic to static docstrings for Figure io methods. Saves 200ms of startup time! * Update packages/python/plotly/_plotly_utils/importers.py Co-Authored-By: Emmanuelle Gouillart <[email protected]> * Use module-level __dir__ function to restore IPython tab completion with Python 3.7+ * Don't fail hierarchy test when ipywidgets not installed * cut commented print statements * Test fix, FigureWidget is not expected to be importable when ipywidgets is not installed Co-authored-by: Emmanuelle Gouillart <[email protected]>
1 parent 9465de3 commit 7fcb95c

File tree

10,122 files changed

+623059
-599698
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

10,122 files changed

+623059
-599698
lines changed

Diff for: packages/python/plotly/_plotly_utils/basevalidators.py

+28-24
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def to_scalar_or_list(v):
3535
# Python native scalar type ('float' in the example above).
3636
# We explicitly check if it has the 'item' method, which conventionally
3737
# converts these types to native scalars.
38-
np = get_module("numpy")
39-
pd = get_module("pandas")
38+
np = get_module("numpy", should_load=False)
39+
pd = get_module("pandas", should_load=False)
4040
if np and np.isscalar(v) and hasattr(v, "item"):
4141
return v.item()
4242
if isinstance(v, (list, tuple)):
@@ -74,7 +74,9 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
7474
Numpy array with the 'WRITEABLE' flag set to False
7575
"""
7676
np = get_module("numpy")
77-
pd = get_module("pandas")
77+
78+
# Don't force pandas to be loaded, we only want to know if it's already loaded
79+
pd = get_module("pandas", should_load=False)
7880
assert np is not None
7981

8082
# ### Process kind ###
@@ -166,8 +168,8 @@ def is_homogeneous_array(v):
166168
"""
167169
Return whether a value is considered to be a homogeneous array
168170
"""
169-
np = get_module("numpy")
170-
pd = get_module("pandas")
171+
np = get_module("numpy", should_load=False)
172+
pd = get_module("pandas", should_load=False)
171173
if (
172174
np
173175
and isinstance(v, np.ndarray)
@@ -2455,6 +2457,10 @@ def validate_coerce(self, v, skip_invalid=False):
24552457
v._plotly_name = self.plotly_name
24562458
return v
24572459

2460+
def present(self, v):
2461+
# Return compound object as-is
2462+
return v
2463+
24582464

24592465
class TitleValidator(CompoundValidator):
24602466
"""
@@ -2549,6 +2555,10 @@ def validate_coerce(self, v, skip_invalid=False):
25492555

25502556
return v
25512557

2558+
def present(self, v):
2559+
# Return compound object as tuple
2560+
return tuple(v)
2561+
25522562

25532563
class BaseDataValidator(BaseValidator):
25542564
def __init__(
@@ -2559,7 +2569,7 @@ def __init__(
25592569
)
25602570

25612571
self.class_strs_map = class_strs_map
2562-
self._class_map = None
2572+
self._class_map = {}
25632573
self.set_uid = set_uid
25642574

25652575
def description(self):
@@ -2595,21 +2605,17 @@ def description(self):
25952605

25962606
return desc
25972607

2598-
@property
2599-
def class_map(self):
2600-
if self._class_map is None:
2601-
2602-
# Initialize class map
2603-
self._class_map = {}
2604-
2605-
# Import trace classes
2608+
def get_trace_class(self, trace_name):
2609+
# Import trace classes
2610+
if trace_name not in self._class_map:
26062611
trace_module = import_module("plotly.graph_objs")
2607-
for k, class_str in self.class_strs_map.items():
2608-
self._class_map[k] = getattr(trace_module, class_str)
2612+
trace_class_name = self.class_strs_map[trace_name]
2613+
self._class_map[trace_name] = getattr(trace_module, trace_class_name)
26092614

2610-
return self._class_map
2615+
return self._class_map[trace_name]
26112616

26122617
def validate_coerce(self, v, skip_invalid=False):
2618+
from plotly.basedatatypes import BaseTraceType
26132619

26142620
# Import Histogram2dcontour, this is the deprecated name of the
26152621
# Histogram2dContour trace.
@@ -2621,13 +2627,11 @@ def validate_coerce(self, v, skip_invalid=False):
26212627
if not isinstance(v, (list, tuple)):
26222628
v = [v]
26232629

2624-
trace_classes = tuple(self.class_map.values())
2625-
26262630
res = []
26272631
invalid_els = []
26282632
for v_el in v:
26292633

2630-
if isinstance(v_el, trace_classes):
2634+
if isinstance(v_el, BaseTraceType):
26312635
# Clone input traces
26322636
v_el = v_el.to_plotly_json()
26332637

@@ -2641,25 +2645,25 @@ def validate_coerce(self, v, skip_invalid=False):
26412645
else:
26422646
trace_type = "scatter"
26432647

2644-
if trace_type not in self.class_map:
2648+
if trace_type not in self.class_strs_map:
26452649
if skip_invalid:
26462650
# Treat as scatter trace
2647-
trace = self.class_map["scatter"](
2651+
trace = self.get_trace_class("scatter")(
26482652
skip_invalid=skip_invalid, **v_copy
26492653
)
26502654
res.append(trace)
26512655
else:
26522656
res.append(None)
26532657
invalid_els.append(v_el)
26542658
else:
2655-
trace = self.class_map[trace_type](
2659+
trace = self.get_trace_class(trace_type)(
26562660
skip_invalid=skip_invalid, **v_copy
26572661
)
26582662
res.append(trace)
26592663
else:
26602664
if skip_invalid:
26612665
# Add empty scatter trace
2662-
trace = self.class_map["scatter"]()
2666+
trace = self.get_trace_class("scatter")()
26632667
res.append(trace)
26642668
else:
26652669
res.append(None)

Diff for: packages/python/plotly/_plotly_utils/importers.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import importlib
2+
3+
4+
def relative_import(parent_name, rel_modules=(), rel_classes=()):
    """
    Helper function to import submodules lazily in Python 3.7+

    Parameters
    ----------
    parent_name: str
        Name of the package that the relative paths below are resolved
        against (it is passed to importlib.import_module as the anchor
        package)
    rel_modules: list of str
        list of submodules to import, of the form .submodule
    rel_classes: list of str
        list of submodule classes/variables to import, of the form ._submodule.Foo

    Returns
    -------
    tuple
        Tuple (__all__, __getattr__, __dir__) whose items should be assigned
        to the names __all__, __getattr__ and __dir__ in the caller
    """
    # Map each public name (the last dotted component) to its relative path
    module_names = {rel_module.split(".")[-1]: rel_module for rel_module in rel_modules}
    class_names = {rel_path.split(".")[-1]: rel_path for rel_path in rel_classes}

    def __getattr__(import_name):
        # In Python 3.7+, lazy import submodules

        # Check for submodule
        if import_name in module_names:
            return importlib.import_module(module_names[import_name], parent_name)

        # Check for submodule class
        if import_name in class_names:
            # Everything before the final "." is the module holding the class
            rel_module = class_names[import_name].rpartition(".")[0]
            class_module = importlib.import_module(rel_module, parent_name)
            return getattr(class_module, import_name)

        raise AttributeError(
            "module {__name__!r} has no attribute {name!r}".format(
                name=import_name, __name__=parent_name
            )
        )

    __all__ = list(module_names) + list(class_names)

    def __dir__():
        # Expose the lazily importable names (e.g. for tab completion)
        return __all__

    return __all__, __getattr__, __dir__

Diff for: packages/python/plotly/_plotly_utils/optional_imports.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
_not_importable = set()
1313

1414

15-
def get_module(name):
15+
def get_module(name, should_load=True):
1616
"""
1717
Return module or None. Absolute import is required.
1818
@@ -23,6 +23,8 @@ def get_module(name):
2323
"""
2424
if name in sys.modules:
2525
return sys.modules[name]
26+
if not should_load:
27+
return None
2628
if name not in _not_importable:
2729
try:
2830
return import_module(name)

Diff for: packages/python/plotly/_plotly_utils/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def encode_as_sage(obj):
147147
@staticmethod
148148
def encode_as_pandas(obj):
149149
"""Attempt to convert pandas.NaT"""
150-
pandas = get_module("pandas")
150+
pandas = get_module("pandas", should_load=False)
151151
if not pandas:
152152
raise NotEncodable
153153

@@ -159,7 +159,7 @@ def encode_as_pandas(obj):
159159
@staticmethod
160160
def encode_as_numpy(obj):
161161
"""Attempt to convert numpy.ma.core.masked"""
162-
numpy = get_module("numpy")
162+
numpy = get_module("numpy", should_load=False)
163163
if not numpy:
164164
raise NotEncodable
165165

0 commit comments

Comments
 (0)