Skip to content

Commit 2f33a36

Browse files
auto-generate aggregation classes
1 parent bacfc74 commit 2f33a36

File tree

1 file changed

+90
-7
lines changed

1 file changed

+90
-7
lines changed

Diff for: utils/generator.py

+90-7
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
lstrip_blocks=True,
3333
)
3434
query_py = jinja_env.get_template("query.py.tpl")
35+
aggs_py = jinja_env.get_template("aggs.py.tpl")
3536
types_py = jinja_env.get_template("types.py.tpl")
3637

3738
# map with name replacements for Elasticsearch attributes
@@ -43,6 +44,22 @@
4344
"_types.query_dsl:DistanceFeatureQuery": "_types.query_dsl:DistanceFeatureQueryBase",
4445
}
4546

47+
# the parent class of some aggregations is complicated to determine from the
48+
# schema, so it is hard-coded here, keyed by the aggregation's property name
49+
AGG_TYPES = {
50+
"bucket_count_ks_test": "Pipeline",
51+
"bucket_correlation": "Pipeline",
52+
"bucket_sort": "Bucket",
53+
"categorize_text": "Bucket",
54+
"filter": "Bucket",
55+
"moving_avg": "Pipeline",
56+
"variable_width_histogram": "Bucket",
57+
}
58+
59+
60+
def property_to_class_name(name):
    """Convert a snake_case property name into a CamelCase class name.

    The name is split on underscores and every word is title-cased, with the
    exception of the word "ip", which is rendered as "IP".
    """
    words = []
    for word in name.split("_"):
        if word == "ip":
            words.append("IP")
        else:
            words.append(word.title())
    return "".join(words)
62+
4663

4764
def wrapped_doc(text, width=70, initial_indent="", subsequent_indent=""):
4865
"""Formats a docstring as a list of lines of up to the request width."""
@@ -101,6 +118,18 @@ def find_type(self, name, namespace=None):
101118
):
102119
return t
103120

121+
def inherits_from(self, type_, name, namespace=None):
    """Check whether a schema type has a given type among its ancestors.

    Walks the ``inherits`` chain of ``type_``, resolving each parent with
    ``self.find_type()``. Only the ancestors are compared against ``name``;
    ``type_`` itself is not.

    :param type_: the schema type definition (a dict) to start from.
    :param name: the name of the ancestor type to look for.
    :param namespace: the namespace of the ancestor type, or ``None`` to
        accept a match on the name in any namespace.
    :return: ``True`` if a matching ancestor is found, ``False`` otherwise.
    """
    # an "inherits" entry that carries no "type" ends the chain, the same way
    # interface_to_python_class() stops walking in that situation
    while "inherits" in type_ and "type" in type_["inherits"]:
        parent = type_["inherits"]["type"]
        type_ = self.find_type(parent["name"], parent["namespace"])
        if type_["name"]["name"] == name and (
            namespace is None or type_["name"]["namespace"] == namespace
        ):
            return True
    return False
132+
104133
def get_python_type(self, schema_type):
105134
"""Obtain Python typing details for a given schema type
106135
@@ -156,7 +185,9 @@ def get_python_type(self, schema_type):
156185
# for dicts we use Mapping[key_type, value_type]
157186
key_type, key_param = self.get_python_type(schema_type["key"])
158187
value_type, value_param = self.get_python_type(schema_type["value"])
159-
return f"Mapping[{key_type}, {value_type}]", None
188+
return f"Mapping[{key_type}, {value_type}]", (
189+
{**value_param, "hash": True} if value_param else None
190+
)
160191

161192
elif schema_type["kind"] == "union_of":
162193
if (
@@ -334,17 +365,38 @@ def property_to_python_class(self, p):
334365
"""
335366
k = {
336367
"property_name": p["name"],
337-
"name": "".join([w.title() for w in p["name"].split("_")]),
368+
"name": property_to_class_name(p["name"]),
338369
}
339370
k["docstring"] = wrapped_doc(p.get("description") or "")
371+
other_classes = []
340372
kind = p["type"]["kind"]
341373
if kind == "instance_of":
342374
namespace = p["type"]["type"]["namespace"]
343375
name = p["type"]["type"]["name"]
344376
if f"{namespace}:{name}" in TYPE_REPLACEMENTS:
345377
namespace, name = TYPE_REPLACEMENTS[f"{namespace}:{name}"].split(":")
346-
type_ = schema.find_type(name, namespace)
378+
if name == "QueryContainer" and namespace == "_types.query_dsl":
379+
type_ = {
380+
"kind": "interface",
381+
"properties": [p],
382+
}
383+
else:
384+
type_ = schema.find_type(name, namespace)
385+
if p["name"] in AGG_TYPES:
386+
k["parent"] = AGG_TYPES[p["name"]]
387+
347388
if type_["kind"] == "interface":
389+
# set the correct parent for bucket and pipeline aggregations
390+
if self.inherits_from(
391+
type_, "PipelineAggregationBase", "_types.aggregations"
392+
):
393+
k["parent"] = "Pipeline"
394+
elif self.inherits_from(
395+
type_, "BucketAggregationBase", "_types.aggregations"
396+
):
397+
k["parent"] = "Bucket"
398+
399+
# generate class attributes
348400
k["args"] = []
349401
k["params"] = []
350402
if "behaviors" in type_:
@@ -397,6 +449,21 @@ def property_to_python_class(self, p):
397449
)
398450
else:
399451
break
452+
453+
elif type_["kind"] == "type_alias":
454+
if type_["type"]["kind"] == "union_of":
455+
# for unions we create sub-classes
456+
for other in type_["type"]["items"]:
457+
other_class = self.interface_to_python_class(
458+
other["type"]["name"], self.interfaces, for_types_py=False
459+
)
460+
other_class["parent"] = k["name"]
461+
other_classes.append(other_class)
462+
else:
463+
raise RuntimeError(
464+
"Cannot generate code for instances of type_alias instances that are not unions."
465+
)
466+
400467
else:
401468
raise RuntimeError(
402469
f"Cannot generate code for instances of kind '{type_['kind']}'"
@@ -444,9 +511,9 @@ def property_to_python_class(self, p):
444511

445512
else:
446513
raise RuntimeError(f"Cannot generate code for type {p['type']}")
447-
return k
514+
return [k] + other_classes
448515

449-
def interface_to_python_class(self, interface, interfaces):
516+
def interface_to_python_class(self, interface, interfaces, for_types_py=True):
450517
"""Return a dictionary with template data necessary to render an
451518
interface a Python class.
452519
@@ -477,7 +544,7 @@ def interface_to_python_class(self, interface, interfaces):
477544
k = {"name": interface, "args": []}
478545
while True:
479546
for arg in type_["properties"]:
480-
schema.add_attribute(k, arg, for_types_py=True)
547+
schema.add_attribute(k, arg, for_types_py=for_types_py)
481548

482549
if "inherits" not in type_ or "type" not in type_["inherits"]:
483550
break
@@ -500,13 +567,28 @@ def generate_query_py(schema, filename):
500567
classes = []
501568
query_container = schema.find_type("QueryContainer", "_types.query_dsl")
502569
for p in query_container["properties"]:
503-
classes.append(schema.property_to_python_class(p))
570+
classes += schema.property_to_python_class(p)
504571

505572
with open(filename, "wt") as f:
506573
f.write(query_py.render(classes=classes, parent="Query"))
507574
print(f"Generated {filename}.")
508575

509576

577+
def generate_aggs_py(schema, filename):
    """Generate aggs.py with all the properties of `AggregationContainer` as
    Python classes.

    :param schema: the schema object; its ``find_type()`` and
        ``property_to_python_class()`` methods are used.
    :param filename: the path of the Python module to write.
    """
    aggs_container = schema.find_type("AggregationContainer", "_types.aggregations")
    classes = []
    for p in aggs_container["properties"]:
        # properties flagged with a truthy "containerProperty" are skipped
        if not p.get("containerProperty"):
            # a single property can expand to more than one class
            classes += schema.property_to_python_class(p)

    # write the generated source as UTF-8 instead of relying on the locale's
    # default encoding
    with open(filename, "wt", encoding="utf-8") as f:
        f.write(aggs_py.render(classes=classes, parent="Agg"))
    print(f"Generated {filename}.")
590+
591+
510592
def generate_types_py(schema, filename):
511593
"""Generate types.py"""
512594
classes = {}
@@ -542,4 +624,5 @@ def generate_types_py(schema, filename):
542624
if __name__ == "__main__":
543625
schema = ElasticsearchSchema()
544626
generate_query_py(schema, "elasticsearch_dsl/query.py")
627+
generate_aggs_py(schema, "elasticsearch_dsl/aggs.py")
545628
generate_types_py(schema, "elasticsearch_dsl/types.py")

0 commit comments

Comments
 (0)