Skip to content

Commit 7351ae0

Browse files
authored
Detect mixed operation group (#3289)
* Detect mixed operation group * Guessing conflicts
1 parent dcc8ef3 commit 7351ae0

File tree

1 file changed

+44
-11
lines changed

1 file changed

+44
-11
lines changed

scripts/multiapi_init_gen.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib
2+
import logging
23
import os
34
import pkgutil
45
import re
@@ -25,6 +26,9 @@
2526

2627
_GENERATE_MARKER = "############ Generated from here ############\n"
2728

29+
_LOGGER = logging.getLogger(__name__)
30+
31+
2832
def parse_input(input_parameter):
2933
"""From a syntax like package_name#submodule, build a package name
3034
and complete module name.
@@ -48,26 +52,53 @@ def get_versionned_modules(package_name, module_name, sdk_root=None):
4852
for (_, label, ispkg) in pkgutil.iter_modules(module_to_generate.__path__)
4953
if label.startswith("v20") and ispkg]
5054

55+
def extract_api_version_from_code(function):
56+
"""Will extract from __code__ the API version. Should be use if you use this is an operation group with no constant api_version.
57+
"""
58+
try:
59+
if "api_version" in function.__code__.co_varnames:
60+
return function.__code__.co_consts[1]
61+
except Exception:
62+
pass
63+
5164
def build_operation_meta(versionned_modules):
5265
version_dict = {}
5366
mod_to_api_version = {}
5467
for versionned_label, versionned_mod in versionned_modules:
55-
extracted_api_version = None
68+
extracted_api_versions = set()
5669
client_doc = versionned_mod.__dict__[versionned_mod.__all__[0]].__doc__
5770
operations = list(re.finditer(r':ivar (?P<attr>[a-z_]+): \w+ operations\n\s+:vartype (?P=attr): .*.operations.(?P<clsname>\w+)\n', client_doc))
5871
for operation in operations:
5972
attr, clsname = operation.groups()
6073
version_dict.setdefault(attr, []).append((versionned_label, clsname))
61-
if not extracted_api_version:
62-
# Create a fake operation group to extract easily the real api version
63-
try:
64-
extracted_api_version = versionned_mod.operations.__dict__[clsname](None, None, None, None).api_version
65-
except Exception:
66-
# Should not happen. I guess it mixed operation groups like VMSS Network...
67-
pass
68-
if not extracted_api_version:
69-
sys.exit("Was not able to extract api_version of %s" % versionned_label)
70-
mod_to_api_version[versionned_label] = extracted_api_version
74+
75+
# Create a fake operation group to extract easily the real api version
76+
extracted_api_version = None
77+
try:
78+
extracted_api_version = versionned_mod.operations.__dict__[clsname](None, None, None, None).api_version
79+
except Exception:
80+
# Should not happen. I guess it mixed operation groups like VMSS Network...
81+
for func_name, function in versionned_mod.operations.__dict__[clsname].__dict__.items():
82+
if not func_name.startswith("__"):
83+
extracted_api_version = extract_api_version_from_code(function)
84+
if extracted_api_version:
85+
extracted_api_versions.add(extracted_api_version)
86+
87+
if not extracted_api_versions:
88+
sys.exit("Was not able to extract api_version of {}".format(versionned_label))
89+
if len(extracted_api_versions) >= 2:
90+
# Mixed operation group, try to figure out what we want to use
91+
final_api_version = None
92+
_LOGGER.warning("Found too much API version: {} in label {}".format(extracted_api_versions, versionned_label))
93+
for candidate_api_version in extracted_api_versions:
94+
if "v{}".format(candidate_api_version.replace("-", "_")) == versionned_label:
95+
final_api_version = candidate_api_version
96+
_LOGGER.warning("Guessing you want {} based on label {}".format(final_api_version, versionned_label))
97+
break
98+
else:
99+
sys.exit("Unble to match {} to label {}".format(extracted_api_versions, versionned_label))
100+
extracted_api_versions = {final_api_version}
101+
mod_to_api_version[versionned_label] = extracted_api_versions.pop()
71102

72103
# latest: api_version=mod_to_api_version[versions[-1][0]]
73104

@@ -160,6 +191,8 @@ def _models_dict(cls, api_version):
160191
"""
161192

162193
if __name__ == "__main__":
194+
logging.basicConfig(level=logging.INFO)
195+
163196
package_name, module_name = parse_input(sys.argv[1])
164197
versionned_modules = get_versionned_modules(package_name, module_name)
165198
version_dict, mod_to_api_version = build_operation_meta(versionned_modules)

0 commit comments

Comments
 (0)