diff --git a/betterproto/plugin.bat b/betterproto/plugin.bat new file mode 100644 index 000000000..9b837d7dc --- /dev/null +++ b/betterproto/plugin.bat @@ -0,0 +1,2 @@ +@SET plugin_dir=%~dp0 +@python %plugin_dir%/plugin.py %* \ No newline at end of file diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 597bf1a15..9c1dc340e 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -1,5 +1,5 @@ #!/usr/bin/env python - +import collections import itertools import json import os.path @@ -30,38 +30,54 @@ from betterproto.casing import safe_snake_case +import google.protobuf.wrappers_pb2 WRAPPER_TYPES = { - "google.protobuf.DoubleValue": "float", - "google.protobuf.FloatValue": "float", - "google.protobuf.Int64Value": "int", - "google.protobuf.UInt64Value": "int", - "google.protobuf.Int32Value": "int", - "google.protobuf.UInt32Value": "int", - "google.protobuf.BoolValue": "bool", - "google.protobuf.StringValue": "str", - "google.protobuf.BytesValue": "bytes", + google.protobuf.wrappers_pb2.DoubleValue: "float", + google.protobuf.wrappers_pb2.FloatValue: "float", + google.protobuf.wrappers_pb2.Int64Value: "int", + google.protobuf.wrappers_pb2.UInt64Value: "int", + google.protobuf.wrappers_pb2.Int32Value: "int", + google.protobuf.wrappers_pb2.UInt32Value: "int", + google.protobuf.wrappers_pb2.BoolValue: "bool", + google.protobuf.wrappers_pb2.StringValue: "str", + google.protobuf.wrappers_pb2.BytesValue: "bytes", } -def get_ref_type(package: str, imports: set, type_name: str) -> str: +def get_wrapper_type(type_name: str) -> (Any, str): + for wrapper, wrapped_type in WRAPPER_TYPES.items(): + if wrapper.DESCRIPTOR.full_name == type_name: + return wrapper, wrapped_type + return None, None + + +def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str: """ Return a Python type name for a proto type reference. Adds the import if - necessary. + necessary. Unwraps well known type if required. """ # If the package name is a blank string, then this should still work # because by convention packages are lowercase and message/enum types are # pascal-cased. May require refactoring in the future. type_name = type_name.lstrip(".") - if type_name in WRAPPER_TYPES: - return f"Optional[{WRAPPER_TYPES[type_name]}]" + # Check if type is wrapper. + wrapper, wrapped_type = get_wrapper_type(type_name) - if type_name == "google.protobuf.Duration": - return "timedelta" + if unwrap: + if wrapper: + return f"Optional[{wrapped_type}]" - if type_name == "google.protobuf.Timestamp": - return "datetime" + if type_name == "google.protobuf.Duration": + return "timedelta" + + if type_name == "google.protobuf.Timestamp": + return "datetime" + else: + if wrapper: + imports.add(f"from {wrapper.__module__} import {wrapper.__name__}") + return f"{wrapper.__name__}" if type_name.startswith(package): parts = type_name.lstrip(package).lstrip(".").split(".") @@ -74,10 +90,32 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str: if "." in type_name: # This is imported from another package. No need # to use a forward ref and we need to add the import. - parts = type_name.split(".") - parts[-1] = stringcase.pascalcase(parts[-1]) - imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") - type_name = f"{parts[-2]}.{parts[-1]}" + type_package, type_base_name = type_name.rsplit('.', 1) + + import_child_package = type_package.startswith(package + '.') + import_parent_package = package.startswith(type_package + '.') + + if import_child_package: + relative_package = type_package[len(package) + 1:] + relative_type_name = relative_package + '.' + type_base_name + + imports.add(f"from . import {relative_package} # relative import from {package}") + return relative_type_name + elif import_parent_package: + imports.add(f"import {type_package} # absolute parent import from {package}") + return f"'{type_package}.{type_base_name}'" + elif not package: + if type_package.count('.'): + type_packages = type_package.rsplit('.', 1) + sys.stderr.write(f"type_package = {type_packages}\n") + alias = safe_snake_case(type_package) + imports.add(f"from .{type_packages[0]} import {type_packages[1]} as {alias} # relative import from root") + return f"'{alias}.{type_base_name}'" + else: + imports.add(f"from . import {type_package} # import child from root") + return type_name + + return type_name return type_name @@ -174,28 +212,26 @@ def generate_code(request, response): ) template = env.get_template("template.py") - output_map = {} + output_files = collections.defaultdict() + for proto_file in request.proto_file: - out = proto_file.package - if out == "google.protobuf": + package = proto_file.package + if package == "google.protobuf": continue - if not out: - out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".") - - if out not in output_map: - output_map[out] = {"package": proto_file.package, "files": []} - output_map[out]["files"].append(proto_file) + if package: + output_file_name = os.path.sep.join(package.split('.') + ['__init__.py']) + else: + output_file_name = '__root__.py' - # TODO: Figure out how to handle gRPC request/response messages and add - # processing below for Service. + output_files.setdefault(output_file_name, {"package": package, "files": []}) + output_files[output_file_name]["files"].append(proto_file) - for filename, options in output_map.items(): - package = options["package"] - # print(package, filename, file=sys.stderr) - output = { + for output_file_name, output_file_data in output_files.items(): + package = output_file_data["package"] + template_data = { "package": package, - "files": [f.name for f in options["files"]], + "files": [f.name for f in output_file_data["files"]], "imports": set(), "datetime_imports": set(), "typing_imports": set(), @@ -203,237 +239,226 @@ def generate_code(request, response): "enums": [], "services": [], } + output_file_data['template_data'] = template_data - type_mapping = {} - - for proto_file in options["files"]: - # print(proto_file.message_type, file=sys.stderr) - # print(proto_file.service, file=sys.stderr) - # print(proto_file.source_code_info, file=sys.stderr) - + # read messages + global_messages = [] + for output_file_name, output_file_data in output_files.items(): + for proto_file in output_file_data["files"]: for item, path in traverse(proto_file): - # print(item, file=sys.stderr) - # print(path, file=sys.stderr) - data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)} - - if isinstance(item, DescriptorProto): - # print(item, file=sys.stderr) - if item.options.map_entry: - # Skip generated map entry messages since we just use dicts - continue - - data.update( - { - "type": "Message", - "comment": get_comment(proto_file, path), - "properties": [], - } - ) - - for i, f in enumerate(item.field): - t = py_type(package, output["imports"], item, f) - zero = get_py_zero(f.type) - - repeated = False - packed = False - - field_type = f.Type.Name(f.type).lower()[5:] - - field_wraps = "" - if f.type_name.startswith( - ".google.protobuf" - ) and f.type_name.endswith("Value"): - w = f.type_name.split(".").pop()[:-5].upper() - field_wraps = f"betterproto.TYPE_{w}" - - map_types = None - if f.type == 11: - # This might be a map... - message_type = f.type_name.split(".").pop().lower() - # message_type = py_type(package) - map_entry = f"{f.name.replace('_', '').lower()}entry" - - if message_type == map_entry: - for nested in item.nested_type: - if ( - nested.name.replace("_", "").lower() - == map_entry - ): - if nested.options.map_entry: - # print("Found a map!", file=sys.stderr) - k = py_type( - package, - output["imports"], - item, - nested.field[0], - ) - v = py_type( - package, - output["imports"], - item, - nested.field[1], - ) - t = f"Dict[{k}, {v}]" - field_type = "map" - map_types = ( - f.Type.Name(nested.field[0].type), - f.Type.Name(nested.field[1].type), - ) - output["typing_imports"].add("Dict") - - if f.label == 3 and field_type != "map": - # Repeated field - repeated = True - t = f"List[{t}]" - zero = "[]" - output["typing_imports"].add("List") - - if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: - packed = True - - one_of = "" - if f.HasField("oneof_index"): - one_of = item.oneof_decl[f.oneof_index].name - - if "Optional[" in t: - output["typing_imports"].add("Optional") - - if "timedelta" in t: - output["datetime_imports"].add("timedelta") - elif "datetime" in t: - output["datetime_imports"].add("datetime") - - data["properties"].append( - { - "name": f.name, - "py_name": safe_snake_case(f.name), - "number": f.number, - "comment": get_comment(proto_file, path + [2, i]), - "proto_type": int(f.type), - "field_type": field_type, - "field_wraps": field_wraps, - "map_types": map_types, - "type": t, - "zero": zero, - "repeated": repeated, - "packed": packed, - "one_of": one_of, - } - ) - # print(f, file=sys.stderr) - - output["messages"].append(data) - elif isinstance(item, EnumDescriptorProto): - # print(item.name, path, file=sys.stderr) - data.update( - { - "type": "Enum", - "comment": get_comment(proto_file, path), - "entries": [ - { - "name": v.name, - "value": v.number, - "comment": get_comment(proto_file, path + [2, i]), - } - for i, v in enumerate(item.value) - ], - } - ) - - output["enums"].append(data) + read_item(item, path, proto_file, output_file_data['package'], output_file_data['template_data'], global_messages) + # read services + for output_file_name, output_file_data in output_files.items(): + for proto_file in output_file_data["files"]: + sys.stderr.write(f'===== service : file {proto_file.name} ====\n') for i, service in enumerate(proto_file.service): - # print(service, file=sys.stderr) + read_service(i, output_file_data['package'], proto_file, service, output_file_data['template_data'], global_messages) - data = { - "name": service.name, - "py_name": stringcase.pascalcase(service.name), - "comment": get_comment(proto_file, [6, i]), - "methods": [], - } + for output_file_name, output_file_data in output_files.items(): + template_data = output_file_data['template_data'] - for j, method in enumerate(service.method): - if method.client_streaming: - raise NotImplementedError("Client streaming not yet supported") - - input_message = None - input_type = get_ref_type( - package, output["imports"], method.input_type - ).strip('"') - for msg in output["messages"]: - if msg["name"] == input_type: - input_message = msg - for field in msg["properties"]: - if field["zero"] == "None": - output["typing_imports"].add("Optional") - break - - data["methods"].append( - { - "name": method.name, - "py_name": stringcase.snakecase(method.name), - "comment": get_comment(proto_file, [6, i, 2, j], indent=8), - "route": f"/{package}.{service.name}/{method.name}", - "input": get_ref_type( - package, output["imports"], method.input_type - ).strip('"'), - "input_message": input_message, - "output": get_ref_type( - package, output["imports"], method.output_type - ).strip('"'), - "client_streaming": method.client_streaming, - "server_streaming": method.server_streaming, - } - ) - - if method.server_streaming: - output["typing_imports"].add("AsyncGenerator") - - output["services"].append(data) - - output["imports"] = sorted(output["imports"]) - output["datetime_imports"] = sorted(output["datetime_imports"]) - output["typing_imports"] = sorted(output["typing_imports"]) + template_data["imports"] = sorted(template_data["imports"]) + template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) + template_data["typing_imports"] = sorted(template_data["typing_imports"]) # Fill response f = response.file.add() - # print(filename, file=sys.stderr) - f.name = filename.replace(".", os.path.sep) + ".py" + f.name = output_file_name # Render and then format the output file. f.content = black.format_str( - template.render(description=output), - mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])), + template.render(description=template_data), + mode=black.FileMode(target_versions={black.TargetVersion.PY37}), ) - inits = set([""]) - for f in response.file: - # Ensure output paths exist - # print(f.name, file=sys.stderr) - dirnames = os.path.dirname(f.name) - if dirnames: - os.makedirs(dirnames, exist_ok=True) - base = "" - for part in dirnames.split(os.path.sep): - base = os.path.join(base, part) - inits.add(base) - - for base in inits: - name = os.path.join(base, "__init__.py") - - if os.path.exists(name): - # Never overwrite inits as they may have custom stuff in them. - continue - - init = response.file.add() - init.name = name - init.content = b"" - filenames = sorted([f.name for f in response.file]) for fname in filenames: print(f"Writing {fname}", file=sys.stderr) +def read_item(item, path, proto_file, package, template_data, global_messages): + item_data = {"name": item.name, "py_name": stringcase.pascalcase(item.name), 'full_name': f'{package}.{item.name}'} + + #sys.stderr.write(f' item {item.name}\n') + + if isinstance(item, DescriptorProto): + # print(item, file=sys.stderr) + if item.options.map_entry: + # Skip generated map entry messages since we just use dicts + return + + item_data.update( + { + "type": "Message", + "comment": get_comment(proto_file, path), + "properties": [], + } + ) + + for i, field in enumerate(item.field): + t = py_type(package, template_data["imports"], item, field) + zero = get_py_zero(field.type) + + repeated = False + packed = False + + field_type = field.Type.Name(field.type).lower()[5:] + + field_wraps = "" + if field.type_name.startswith( + ".google.protobuf" + ) and field.type_name.endswith("Value"): + w = field.type_name.split(".").pop()[:-5].upper() + field_wraps = f"betterproto.TYPE_{w}" + + map_types = None + if field.type == 11: + # This might be a map... + message_type = field.type_name.split(".").pop().lower() + # message_type = py_type(package) + map_entry = f"{field.name.replace('_', '').lower()}entry" + + if message_type == map_entry: + for nested in item.nested_type: + if ( + nested.name.replace("_", "").lower() + == map_entry + ): + if nested.options.map_entry: + # print("Found a map!", file=sys.stderr) + k = py_type( + package, + template_data["imports"], + item, + nested.field[0], + ) + v = py_type( + package, + template_data["imports"], + item, + nested.field[1], + ) + t = f"Dict[{k}, {v}]" + field_type = "map" + map_types = ( + field.Type.Name(nested.field[0].type), + field.Type.Name(nested.field[1].type), + ) + template_data["typing_imports"].add("Dict") + + if field.label == 3 and field_type != "map": + # Repeated field + repeated = True + t = f"List[{t}]" + zero = "[]" + template_data["typing_imports"].add("List") + + if field.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: + packed = True + + one_of = "" + if field.HasField("oneof_index"): + one_of = item.oneof_decl[field.oneof_index].name + + if "Optional[" in t: + template_data["typing_imports"].add("Optional") + + if "timedelta" in t: + template_data["datetime_imports"].add("timedelta") + elif "datetime" in t: + template_data["datetime_imports"].add("datetime") + + item_data["properties"].append( + { + "name": field.name, + "py_name": safe_snake_case(field.name), + "number": field.number, + "comment": get_comment(proto_file, path + [2, i]), + "proto_type": int(field.type), + "field_type": field_type, + "field_wraps": field_wraps, + "map_types": map_types, + "type": t, + "zero": zero, + "repeated": repeated, + "packed": packed, + "one_of": one_of, + } + ) + + template_data["messages"].append(item_data) + global_messages.append(item_data) + elif isinstance(item, EnumDescriptorProto): + item_data.update( + { + "type": "Enum", + "comment": get_comment(proto_file, path), + "entries": [ + { + "name": v.name, + "value": v.number, + "comment": get_comment(proto_file, path + [2, i]), + } + for i, v in enumerate(item.value) + ], + } + ) + template_data["enums"].append(item_data) + + +def read_service(i, package, proto_file, service, template_data, global_messages): + data = { + "name": service.name, + "py_name": stringcase.pascalcase(service.name), + "comment": get_comment(proto_file, [6, i]), + "methods": [], + } + for j, method in enumerate(service.method): + if method.client_streaming: + raise NotImplementedError("Client streaming not yet supported") + sys.stderr.write(f' add {method.name}\n') + + input_message = None + input_type = get_ref_type( + package, template_data["imports"], method.input_type + ).strip('"') + for msg in global_messages: + if msg["full_name"] == method.input_type.lstrip('.'): + input_message = msg + for field in msg["properties"]: + if field["zero"] == "None": + template_data["typing_imports"].add("Optional") + break + + if not input_message: + sys.stderr.write(f"no input message found for {input_type} / {method}\n") + sys.stderr.write(f" {global_messages}\n") + sys.stderr.write(f"\n") + + method_data = { + "name": method.name, + "py_name": stringcase.snakecase(method.name), + "comment": get_comment(proto_file, [6, i, 2, j], indent=8), + "route": f"/{package}.{service.name}/{method.name}", + "input": get_ref_type( + package, template_data["imports"], method.input_type + ).strip('"'), + "input_message": input_message, + "output": get_ref_type( + package, template_data["imports"], method.output_type, unwrap=False + ).strip('"'), + "client_streaming": method.client_streaming, + "server_streaming": method.server_streaming, + } + # sys.stderr.write(f'METHOD: {method_data}\n') + data["methods"].append(method_data) + + if method.server_streaming: + template_data["typing_imports"].add("AsyncGenerator") + template_data["services"].append(data) + + def main(): """The plugin's main entry point.""" # Read request message from stdin diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index 987f2d9b3..e0fe631e6 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -48,13 +48,19 @@ def ensure_ext(filename: str, ext: str) -> str: proto_files = get_files(".proto") json_files = get_files(".json") + if os.name == 'nt': + plugin_path = os.path.join('..', 'plugin.bat') + else: + plugin_path = os.path.join('..', 'plugin.py') + + for filename in proto_files: print(f"Generating code for {os.path.basename(filename)}") subprocess.run( f"protoc --python_out=. {os.path.basename(filename)}", shell=True ) subprocess.run( - f"protoc --plugin=protoc-gen-custom=../plugin.py --custom_out=. {os.path.basename(filename)}", + f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out=. {os.path.basename(filename)}", shell=True, ) diff --git a/betterproto/tests/googletypes_service.proto b/betterproto/tests/googletypes_service.proto new file mode 100644 index 000000000..4bdca6856 --- /dev/null +++ b/betterproto/tests/googletypes_service.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +import "google/protobuf/wrappers.proto"; + +service Test { + rpc GetInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetInt64 (Input) returns (google.protobuf.Int64Value); + rpc GetOutput (Input) returns (Output); +} + +message Input { + +} + +message Output { + google.protobuf.Int64Value int64 = 1; +} \ No newline at end of file