Skip to content

REF: Refactor plugin.py to use modular dataclasses in tree-like structure to represent parsed data #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jul 25, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ help: ## - Show this help.
generate: ## - Generate test cases (do this once before running test)
poetry run python -m tests.generate

test: ## - Run tests
poetry run pytest --cov betterproto
test: ## - Run tests, ignoring collection errors (e.g. from missing imports)
poetry run pytest --cov betterproto --continue-on-collection-errors
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is obviously not needed for this PR, but I found it very useful, will remove before final review.


types: ## - Check types with mypy
poetry run mypy src/betterproto --ignore-missing-imports
Expand Down
292 changes: 50 additions & 242 deletions src/betterproto/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,13 @@
import itertools
import os.path
import pathlib
import re
import sys
import textwrap
from typing import List, Union

from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest

import betterproto
from betterproto.compile.importing import get_type_reference, parse_source_type_name
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
from betterproto.lib.google.protobuf import ServiceDescriptorProto
from betterproto.compile.importing import get_type_reference

try:
# betterproto[compiler] specific dependencies
Expand All @@ -28,7 +20,7 @@
EnumDescriptorProto,
FieldDescriptorProto,
)
import google.protobuf.wrappers_pb2 as google_wrappers
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
import jinja2
except ImportError as err:
missing_import = err.args[0][17:-1]
Expand All @@ -42,6 +34,20 @@
)
raise SystemExit(1)

from .plugin_dataclasses import (
OutputTemplate,
ProtoInputFile,
Message,
Field,
OneOfField,
MapField,
EnumDefinition,
Service,
ServiceMethod,
is_map,
is_oneof
)


def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
if field.type in [1, 2]:
Expand Down Expand Up @@ -151,43 +157,27 @@ def generate_code(request, response):

# Initialize Template data for each package
for output_package_name, output_package_content in output_package_files.items():
template_data = {
"input_package": output_package_content["input_package"],
"files": [f.name for f in output_package_content["files"]],
"imports": set(),
"datetime_imports": set(),
"typing_imports": set(),
"messages": [],
"enums": [],
"services": [],
}
template_data = OutputTemplate(input_package=output_package_content["input_package"])
for input_proto_file in output_package_content["files"]:
ProtoInputFile(parent=template_data, proto_obj=input_proto_file)
output_package_content["template_data"] = template_data

# Read Messages and Enums
output_types = []
for output_package_name, output_package_content in output_package_files.items():
for proto_file in output_package_content["files"]:
for item, path in traverse(proto_file):
type_data = read_protobuf_type(
item, path, proto_file, output_package_content
)
output_types.append(type_data)
for proto_file_data in output_package_content["template_data"].input_files:
for item, path in traverse(proto_file_data.proto_obj):
read_protobuf_type(item=item, path=path, proto_file_data=proto_file_data)

# Read Services
for output_package_name, output_package_content in output_package_files.items():
for proto_file in output_package_content["files"]:
for index, service in enumerate(proto_file.service):
read_protobuf_service(
service, index, proto_file, output_package_content, output_types
)
for proto_file_data in output_package_content["template_data"].input_files:
for index, service in enumerate(proto_file_data.proto_obj.service):
read_protobuf_service(service, index, proto_file_data)

# Render files
output_paths = set()
for output_package_name, output_package_content in output_package_files.items():
template_data = output_package_content["template_data"]
template_data["imports"] = sorted(template_data["imports"])
template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
template_data["typing_imports"] = sorted(template_data["typing_imports"])

# Fill response
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
Expand Down Expand Up @@ -220,224 +210,42 @@ def generate_code(request, response):
print(f"Writing {output_package_name}", file=sys.stderr)


def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content):
input_package_name = content["input_package"]
template_data = content["template_data"]
data = {
"name": item.name,
"py_name": pythonize_class_name(item.name),
"descriptor": item,
"package": input_package_name,
}
def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file_data: ProtoInputFile):
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
return

data.update(
{
"type": "Message",
"comment": get_comment(proto_file, path),
"properties": [],
}
# Process Message
message_data = Message(
parent=proto_file_data,
proto_obj=item,
path=path
)

for i, f in enumerate(item.field):
t = py_type(input_package_name, template_data["imports"], f)
zero = get_py_zero(f.type)

repeated = False
packed = False

field_type = f.Type.Name(f.type).lower()[5:]

field_wraps = ""
match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name)
if match_wrapper:
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
if hasattr(betterproto, wrapped_type):
field_wraps = f"betterproto.{wrapped_type}"

map_types = None
if f.type == 11:
# This might be a map...
message_type = f.type_name.split(".").pop().lower()
# message_type = py_type(package)
map_entry = f"{f.name.replace('_', '').lower()}entry"

if message_type == map_entry:
for nested in item.nested_type:
if nested.name.replace("_", "").lower() == map_entry:
if nested.options.map_entry:
# print("Found a map!", file=sys.stderr)
k = py_type(
input_package_name,
template_data["imports"],
nested.field[0],
)
v = py_type(
input_package_name,
template_data["imports"],
nested.field[1],
)
t = f"Dict[{k}, {v}]"
field_type = "map"
map_types = (
f.Type.Name(nested.field[0].type),
f.Type.Name(nested.field[1].type),
)
template_data["typing_imports"].add("Dict")

if f.label == 3 and field_type != "map":
# Repeated field
repeated = True
t = f"List[{t}]"
zero = "[]"
template_data["typing_imports"].add("List")

if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
packed = True

one_of = ""
if f.HasField("oneof_index"):
one_of = item.oneof_decl[f.oneof_index].name

if "Optional[" in t:
template_data["typing_imports"].add("Optional")

if "timedelta" in t:
template_data["datetime_imports"].add("timedelta")
elif "datetime" in t:
template_data["datetime_imports"].add("datetime")

data["properties"].append(
{
"name": f.name,
"py_name": pythonize_field_name(f.name),
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
"field_type": field_type,
"field_wraps": field_wraps,
"map_types": map_types,
"type": t,
"zero": zero,
"repeated": repeated,
"packed": packed,
"one_of": one_of,
}
)
# print(f, file=sys.stderr)

template_data["messages"].append(data)
return data
for index, field in enumerate(item.field):
if is_map(field, item):
MapField(parent=message_data, proto_obj=field, path=path+[2, index])
elif is_oneof(field):
OneOfField(parent=message_data, proto_obj=field, path=path+[2, index])
else:
Field(parent=message_data, proto_obj=field, path=path+[2, index])
elif isinstance(item, EnumDescriptorProto):
# print(item.name, path, file=sys.stderr)
data.update(
{
"type": "Enum",
"comment": get_comment(proto_file, path),
"entries": [
{
"name": v.name,
"value": v.number,
"comment": get_comment(proto_file, path + [2, i]),
}
for i, v in enumerate(item.value)
],
}
)

template_data["enums"].append(data)
return data

# Enum
EnumDefinition(proto_obj=item, parent=proto_file_data, path=path)

def lookup_method_input_type(method, types):
package, name = parse_source_type_name(method.input_type)

for known_type in types:
if known_type["type"] != "Message":
continue

# Nested types are currently flattened without dots.
# Todo: keep a fully qualified name in types, that is comparable with method.input_type
if (
package == known_type["package"]
and name.replace(".", "") == known_type["name"]
):
return known_type


def is_mutable_field_type(field_type: str) -> bool:
return field_type.startswith("List[") or field_type.startswith("Dict[")


def read_protobuf_service(
service: ServiceDescriptorProto, index, proto_file, content, output_types
):
input_package_name = content["input_package"]
template_data = content["template_data"]
# print(service, file=sys.stderr)
data = {
"name": service.name,
"py_name": pythonize_class_name(service.name),
"comment": get_comment(proto_file, [6, index]),
"methods": [],
}
def read_protobuf_service(service: ServiceDescriptorProto, index: int, proto_file_data: ProtoInputFile):
service_data = Service(
parent=proto_file_data,
proto_obj=service,
path=[6, index],
)
for j, method in enumerate(service.method):
method_input_message = lookup_method_input_type(method, output_types)

# This section ensures that method arguments having a default
# value that is initialised as a List/Dict (mutable) is replaced
# with None and initialisation is deferred to the beginning of the
# method definition. This is done so to avoid any side-effects.
# Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
mutable_default_args = []

if method_input_message:
for field in method_input_message["properties"]:
if (
not method.client_streaming
and field["zero"] != "None"
and is_mutable_field_type(field["type"])
):
mutable_default_args.append((field["py_name"], field["zero"]))
field["zero"] = "None"

if field["zero"] == "None":
template_data["typing_imports"].add("Optional")

data["methods"].append(
{
"name": method.name,
"py_name": pythonize_method_name(method.name),
"comment": get_comment(proto_file, [6, index, 2, j], indent=8),
"route": f"/{input_package_name}.{service.name}/{method.name}",
"input": get_type_reference(
input_package_name, template_data["imports"], method.input_type
).strip('"'),
"input_message": method_input_message,
"output": get_type_reference(
input_package_name,
template_data["imports"],
method.output_type,
unwrap=False,
),
"client_streaming": method.client_streaming,
"server_streaming": method.server_streaming,
"mutable_default_args": mutable_default_args,
}
ServiceMethod(
parent=service_data,
proto_obj=method,
path=[6, index, 2, j],
)

if method.client_streaming:
template_data["typing_imports"].add("AsyncIterable")
template_data["typing_imports"].add("Iterable")
template_data["typing_imports"].add("Union")
if method.server_streaming:
template_data["typing_imports"].add("AsyncIterator")
template_data["services"].append(data)


def main():
"""The plugin's main entry point."""
Expand Down
Loading