REF: Refactor plugin.py to use modular dataclasses in tree-like structure to represent parsed data #121

Merged: 23 commits, Jul 25, 2020
Changes from 15 commits
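
The refactor replaces the dict-of-dicts template data with a tree of dataclasses: a Request holds one OutputTemplate per output package, each OutputTemplate holds the Messages, EnumDefinitions and Services parsed from its input files, and each Message holds its Fields. The sketch below illustrates that parent/child layout; the class names and constructor keywords mirror the plugin_dataclasses imports and calls in the diff, but the attribute types and the self-registration in __post_init__ are assumptions for illustration, not the module's actual definitions.

# Illustrative sketch of the tree-like dataclass layout (assumed, simplified).
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Request:
    # Root of the tree: one per CodeGeneratorRequest.
    plugin_request_obj: object
    output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)


@dataclass
class OutputTemplate:
    # One output package, rendered to a single __init__.py.
    parent_request: Request
    package_proto_obj: object
    input_files: List[object] = field(default_factory=list)
    messages: List["Message"] = field(default_factory=list)


@dataclass
class Message:
    # A message definition; attaches itself to its parent package.
    parent: OutputTemplate
    proto_obj: object
    path: List[int] = field(default_factory=list)
    fields: List["Field"] = field(default_factory=list)

    def __post_init__(self) -> None:
        self.parent.messages.append(self)


@dataclass
class Field:
    # A single message field; attaches itself to its parent message.
    parent: Message
    proto_obj: object
    path: List[int] = field(default_factory=list)

    def __post_init__(self) -> None:
        self.parent.fields.append(self)

With this shape, generate_code only needs to build a Request and an OutputTemplate per package, then let read_protobuf_type and read_protobuf_service attach Message, Field and Service nodes as they walk each input file.
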
321 changes: 61 additions & 260 deletions src/betterproto/plugin.py
@@ -3,21 +3,13 @@
import itertools
import os.path
import pathlib
import re
import sys
import textwrap
from typing import List, Union

from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest

import betterproto
from betterproto.compile.importing import get_type_reference, parse_source_type_name
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
from betterproto.lib.google.protobuf import ServiceDescriptorProto
from betterproto.compile.importing import get_type_reference

try:
# betterproto[compiler] specific dependencies
@@ -28,7 +20,7 @@
EnumDescriptorProto,
FieldDescriptorProto,
)
import google.protobuf.wrappers_pb2 as google_wrappers
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
import jinja2
except ImportError as err:
missing_import = err.args[0][17:-1]
@@ -42,6 +34,20 @@
)
raise SystemExit(1)

from .plugin_dataclasses import (
Request,
OutputTemplate,
Message,
Field,
OneOfField,
MapField,
EnumDefinition,
Service,
ServiceMethod,
is_map,
is_oneof,
)


def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
if field.type in [1, 2]:
@@ -133,70 +139,52 @@ def generate_code(request, response):
loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)),
)
template = env.get_template("template.py.j2")

request_data = Request(plugin_request_obj=request)
# Gather output packages
output_package_files = collections.defaultdict()
for proto_file in request.proto_file:
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
# If not INCLUDE_GOOGLE,
# skip re-compiling Google's well-known types
continue

output_package = proto_file.package
output_package_files.setdefault(
output_package, {"input_package": proto_file.package, "files": []}
)
output_package_files[output_package]["files"].append(proto_file)

# Initialize Template data for each package
for output_package_name, output_package_content in output_package_files.items():
template_data = {
"input_package": output_package_content["input_package"],
"files": [f.name for f in output_package_content["files"]],
"imports": set(),
"datetime_imports": set(),
"typing_imports": set(),
"messages": [],
"enums": [],
"services": [],
}
output_package_content["template_data"] = template_data
output_package_name = proto_file.package
if output_package_name not in request_data.output_packages:
# Create a new output if there is no output for this package
request_data.output_packages[output_package_name] = OutputTemplate(
parent_request=request_data, package_proto_obj=proto_file
)
# Add this input file to the output corresponding to this package
request_data.output_packages[output_package_name].input_files.append(proto_file)

# Read Messages and Enums
output_types = []
for output_package_name, output_package_content in output_package_files.items():
for proto_file in output_package_content["files"]:
for item, path in traverse(proto_file):
type_data = read_protobuf_type(
item, path, proto_file, output_package_content
)
output_types.append(type_data)
# We need to read Messages before Services so that we can
# get the references to input/output messages for each service
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for item, path in traverse(proto_input_file):
read_protobuf_type(item=item, path=path, output_package=output_package)

# Read Services
for output_package_name, output_package_content in output_package_files.items():
for proto_file in output_package_content["files"]:
for index, service in enumerate(proto_file.service):
read_protobuf_service(
service, index, proto_file, output_package_content, output_types
)

# Render files
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for index, service in enumerate(proto_input_file.service):
read_protobuf_service(service, index, output_package)

# Generate output files
output_paths = set()
for output_package_name, output_package_content in output_package_files.items():
template_data = output_package_content["template_data"]
template_data["imports"] = sorted(template_data["imports"])
template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
template_data["typing_imports"] = sorted(template_data["typing_imports"])
for output_package_name, template_data in request_data.output_packages.items():

# Fill response
# Add files to the response object
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
output_paths.add(output_path)

f = response.file.add()
f.name = str(output_path)

# Render and then format the output file.
# Render and then format the output file
f.content = black.format_str(
template.render(description=template_data),
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
@@ -220,226 +208,39 @@ def generate_code(request, response):
print(f"Writing {output_package_name}", file=sys.stderr)


def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content):
input_package_name = content["input_package"]
template_data = content["template_data"]
data = {
"name": item.name,
"py_name": pythonize_class_name(item.name),
"descriptor": item,
"package": input_package_name,
}
def read_protobuf_type(
item: DescriptorProto, path: List[int], output_package: OutputTemplate
):
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
return

data.update(
{
"type": "Message",
"comment": get_comment(proto_file, path),
"properties": [],
}
)

for i, f in enumerate(item.field):
t = py_type(input_package_name, template_data["imports"], f)
zero = get_py_zero(f.type)

repeated = False
packed = False

field_type = f.Type.Name(f.type).lower()[5:]

field_wraps = ""
match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name)
if match_wrapper:
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
if hasattr(betterproto, wrapped_type):
field_wraps = f"betterproto.{wrapped_type}"

map_types = None
if f.type == 11:
# This might be a map...
message_type = f.type_name.split(".").pop().lower()
# message_type = py_type(package)
map_entry = f"{f.name.replace('_', '').lower()}entry"

if message_type == map_entry:
for nested in item.nested_type:
if nested.name.replace("_", "").lower() == map_entry:
if nested.options.map_entry:
# print("Found a map!", file=sys.stderr)
k = py_type(
input_package_name,
template_data["imports"],
nested.field[0],
)
v = py_type(
input_package_name,
template_data["imports"],
nested.field[1],
)
t = f"Dict[{k}, {v}]"
field_type = "map"
map_types = (
f.Type.Name(nested.field[0].type),
f.Type.Name(nested.field[1].type),
)
template_data["typing_imports"].add("Dict")

if f.label == 3 and field_type != "map":
# Repeated field
repeated = True
t = f"List[{t}]"
zero = "[]"
template_data["typing_imports"].add("List")

if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
packed = True

one_of = ""
if f.HasField("oneof_index"):
one_of = item.oneof_decl[f.oneof_index].name

if "Optional[" in t:
template_data["typing_imports"].add("Optional")

if "timedelta" in t:
template_data["datetime_imports"].add("timedelta")
elif "datetime" in t:
template_data["datetime_imports"].add("datetime")

data["properties"].append(
{
"name": f.name,
"py_name": pythonize_field_name(f.name),
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
"field_type": field_type,
"field_wraps": field_wraps,
"map_types": map_types,
"type": t,
"zero": zero,
"repeated": repeated,
"packed": packed,
"one_of": one_of,
}
)
# print(f, file=sys.stderr)

template_data["messages"].append(data)
return data
# Process Message
message_data = Message(parent=output_package, proto_obj=item, path=path)
for index, field in enumerate(item.field):
if is_map(field, item):
MapField(parent=message_data, proto_obj=field, path=path + [2, index])
elif is_oneof(field):
OneOfField(parent=message_data, proto_obj=field, path=path + [2, index])
else:
Field(parent=message_data, proto_obj=field, path=path + [2, index])
elif isinstance(item, EnumDescriptorProto):
# print(item.name, path, file=sys.stderr)
data.update(
{
"type": "Enum",
"comment": get_comment(proto_file, path),
"entries": [
{
"name": v.name,
"value": v.number,
"comment": get_comment(proto_file, path + [2, i]),
}
for i, v in enumerate(item.value)
],
}
)

template_data["enums"].append(data)
return data


def lookup_method_input_type(method, types):
package, name = parse_source_type_name(method.input_type)

for known_type in types:
if known_type["type"] != "Message":
continue

# Nested types are currently flattened without dots.
# Todo: keep a fully qualified name in types that is comparable with method.input_type
if (
package == known_type["package"]
and name.replace(".", "") == known_type["name"]
):
return known_type


def is_mutable_field_type(field_type: str) -> bool:
return field_type.startswith("List[") or field_type.startswith("Dict[")
# Enum
EnumDefinition(parent=output_package, proto_obj=item, path=path)


def read_protobuf_service(
service: ServiceDescriptorProto, index, proto_file, content, output_types
service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
):
input_package_name = content["input_package"]
template_data = content["template_data"]
# print(service, file=sys.stderr)
data = {
"name": service.name,
"py_name": pythonize_class_name(service.name),
"comment": get_comment(proto_file, [6, index]),
"methods": [],
}
service_data = Service(parent=output_package, proto_obj=service, path=[6, index],)
for j, method in enumerate(service.method):
method_input_message = lookup_method_input_type(method, output_types)

# This section ensures that method arguments whose default value is
# initialised as a List/Dict (mutable) have that default replaced
# with None, with initialisation deferred to the beginning of the
# method definition. This is done to avoid side-effects.
# Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
mutable_default_args = []

if method_input_message:
for field in method_input_message["properties"]:
if (
not method.client_streaming
and field["zero"] != "None"
and is_mutable_field_type(field["type"])
):
mutable_default_args.append((field["py_name"], field["zero"]))
field["zero"] = "None"

if field["zero"] == "None":
template_data["typing_imports"].add("Optional")

data["methods"].append(
{
"name": method.name,
"py_name": pythonize_method_name(method.name),
"comment": get_comment(proto_file, [6, index, 2, j], indent=8),
"route": f"/{input_package_name}.{service.name}/{method.name}",
"input": get_type_reference(
input_package_name, template_data["imports"], method.input_type
).strip('"'),
"input_message": method_input_message,
"output": get_type_reference(
input_package_name,
template_data["imports"],
method.output_type,
unwrap=False,
),
"client_streaming": method.client_streaming,
"server_streaming": method.server_streaming,
"mutable_default_args": mutable_default_args,
}
ServiceMethod(
parent=service_data, proto_obj=method, path=[6, index, 2, j],
)

if method.client_streaming:
template_data["typing_imports"].add("AsyncIterable")
template_data["typing_imports"].add("Iterable")
template_data["typing_imports"].add("Union")
if method.server_streaming:
template_data["typing_imports"].add("AsyncIterator")
template_data["services"].append(data)


def main():

"""The plugin's main entry point."""
# Read request message from stdin
data = sys.stdin.buffer.read()
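
For context on the mutable_default_args handling removed above (and the python-guide link it cites): a List or Dict default is created once at function definition time and shared across calls, which is why the old code swapped such defaults for None and deferred initialisation into the method body. A minimal standalone illustration of the gotcha, not code from this PR:

# Mutable-default-argument gotcha, shown in isolation.
from typing import List, Optional


def broken(values: List[int] = []) -> List[int]:
    # The default list is created once and shared by every call that omits `values`.
    values.append(1)
    return values


def fixed(values: Optional[List[int]] = None) -> List[int]:
    # Deferring initialisation gives each call its own fresh list.
    if values is None:
        values = []
    values.append(1)
    return values


print(broken())  # [1]
print(broken())  # [1, 1]  <- state leaks between calls
print(fixed())   # [1]
print(fixed())   # [1]
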