Skip to content

Commit 3273ae4

Browse files
Merge pull request #100 from boukeversteegh/fix/circular-dependencies
Import bug - Circular Dependencies
2 parents 6fe6664 + 23dcbc2 commit 3273ae4

File tree

5 files changed

+59
-48
lines changed

5 files changed

+59
-48
lines changed

betterproto/compile/importing.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def reference_absolute(imports, py_package, py_type):
8686
string_import = ".".join(py_package)
8787
string_alias = safe_snake_case(string_import)
8888
imports.add(f"import {string_import} as {string_alias}")
89-
return f"{string_alias}.{py_type}"
89+
return f'"{string_alias}.{py_type}"'
9090

9191

9292
def reference_sibling(py_type: str) -> str:
@@ -109,10 +109,10 @@ def reference_descendent(
109109
if string_from:
110110
string_alias = "_".join(importing_descendent)
111111
imports.add(f"from .{string_from} import {string_import} as {string_alias}")
112-
return f"{string_alias}.{py_type}"
112+
return f'"{string_alias}.{py_type}"'
113113
else:
114114
imports.add(f"from . import {string_import}")
115-
return f"{string_import}.{py_type}"
115+
return f'"{string_import}.{py_type}"'
116116

117117

118118
def reference_ancestor(
@@ -130,11 +130,11 @@ def reference_ancestor(
130130
string_alias = f"_{'_' * distance_up}{string_import}__"
131131
string_from = f"..{'.' * distance_up}"
132132
imports.add(f"from {string_from} import {string_import} as {string_alias}")
133-
return f"{string_alias}.{py_type}"
133+
return f'"{string_alias}.{py_type}"'
134134
else:
135135
string_alias = f"{'_' * distance_up}{py_type}__"
136136
imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
137-
return string_alias
137+
return f'"{string_alias}"'
138138

139139

140140
def reference_cousin(
@@ -157,4 +157,4 @@ def reference_cousin(
157157
+ "__"
158158
)
159159
imports.add(f"from {string_from} import {string_import} as {string_alias}")
160-
return f"{string_alias}.{py_type}"
160+
return f'"{string_alias}.{py_type}"'

betterproto/plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def generate_code(request, response):
329329
output["imports"],
330330
method.output_type,
331331
unwrap=False,
332-
).strip('"'),
332+
),
333333
"client_streaming": method.client_streaming,
334334
"server_streaming": method.server_streaming,
335335
}

betterproto/templates/template.py.j2

+8-8
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ import betterproto
1616
import grpclib
1717
{% endif %}
1818

19-
{% for i in description.imports %}
20-
{{ i }}
21-
{% endfor %}
22-
2319

2420
{% if description.enums %}{% for enum in description.enums %}
2521
class {{ enum.py_name }}(betterproto.Enum):
@@ -102,14 +98,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
10298
"{{ method.route }}",
10399
request_iterator,
104100
{{ method.input }},
105-
{{ method.output }},
101+
{{ method.output.strip('"') }},
106102
):
107103
yield response
108104
{% else %}{# i.e. not client streaming #}
109105
async for response in self._unary_stream(
110106
"{{ method.route }}",
111107
request,
112-
{{ method.output }},
108+
{{ method.output.strip('"') }},
113109
):
114110
yield response
115111

@@ -120,16 +116,20 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
120116
"{{ method.route }}",
121117
request_iterator,
122118
{{ method.input }},
123-
{{ method.output }}
119+
{{ method.output.strip('"') }}
124120
)
125121
{% else %}{# i.e. not client streaming #}
126122
return await self._unary_unary(
127123
"{{ method.route }}",
128124
request,
129-
{{ method.output }}
125+
{{ method.output.strip('"') }}
130126
)
131127
{% endif %}{# client streaming #}
132128
{% endif %}
133129

134130
{% endfor %}
135131
{% endfor %}
132+
133+
{% for i in description.imports %}
134+
{{ i }}
135+
{% endfor %}

betterproto/tests/inputs/config.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
22
# Remove from list when fixed.
33
xfail = {
4-
"import_circular_dependency",
54
"oneof_enum", # 63
65
"namespace_keywords", # 70
76
"namespace_builtin_types", # 53

betterproto/tests/test_get_ref_type.py

+44-32
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,22 @@
88
[
99
(
1010
".google.protobuf.Empty",
11-
"betterproto_lib_google_protobuf.Empty",
11+
'"betterproto_lib_google_protobuf.Empty"',
1212
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
1313
),
1414
(
1515
".google.protobuf.Struct",
16-
"betterproto_lib_google_protobuf.Struct",
16+
'"betterproto_lib_google_protobuf.Struct"',
1717
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
1818
),
1919
(
2020
".google.protobuf.ListValue",
21-
"betterproto_lib_google_protobuf.ListValue",
21+
'"betterproto_lib_google_protobuf.ListValue"',
2222
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
2323
),
2424
(
2525
".google.protobuf.Value",
26-
"betterproto_lib_google_protobuf.Value",
26+
'"betterproto_lib_google_protobuf.Value"',
2727
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
2828
),
2929
],
@@ -67,15 +67,27 @@ def test_referenceing_google_wrappers_unwraps_them(
6767
@pytest.mark.parametrize(
6868
["google_type", "expected_name"],
6969
[
70-
(".google.protobuf.DoubleValue", "betterproto_lib_google_protobuf.DoubleValue"),
71-
(".google.protobuf.FloatValue", "betterproto_lib_google_protobuf.FloatValue"),
72-
(".google.protobuf.Int32Value", "betterproto_lib_google_protobuf.Int32Value"),
73-
(".google.protobuf.Int64Value", "betterproto_lib_google_protobuf.Int64Value"),
74-
(".google.protobuf.UInt32Value", "betterproto_lib_google_protobuf.UInt32Value"),
75-
(".google.protobuf.UInt64Value", "betterproto_lib_google_protobuf.UInt64Value"),
76-
(".google.protobuf.BoolValue", "betterproto_lib_google_protobuf.BoolValue"),
77-
(".google.protobuf.StringValue", "betterproto_lib_google_protobuf.StringValue"),
78-
(".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
70+
(
71+
".google.protobuf.DoubleValue",
72+
'"betterproto_lib_google_protobuf.DoubleValue"',
73+
),
74+
(".google.protobuf.FloatValue", '"betterproto_lib_google_protobuf.FloatValue"'),
75+
(".google.protobuf.Int32Value", '"betterproto_lib_google_protobuf.Int32Value"'),
76+
(".google.protobuf.Int64Value", '"betterproto_lib_google_protobuf.Int64Value"'),
77+
(
78+
".google.protobuf.UInt32Value",
79+
'"betterproto_lib_google_protobuf.UInt32Value"',
80+
),
81+
(
82+
".google.protobuf.UInt64Value",
83+
'"betterproto_lib_google_protobuf.UInt64Value"',
84+
),
85+
(".google.protobuf.BoolValue", '"betterproto_lib_google_protobuf.BoolValue"'),
86+
(
87+
".google.protobuf.StringValue",
88+
'"betterproto_lib_google_protobuf.StringValue"',
89+
),
90+
(".google.protobuf.BytesValue", '"betterproto_lib_google_protobuf.BytesValue"'),
7991
],
8092
)
8193
def test_referenceing_google_wrappers_without_unwrapping(
@@ -95,15 +107,15 @@ def test_reference_child_package_from_package():
95107
)
96108

97109
assert imports == {"from . import child"}
98-
assert name == "child.Message"
110+
assert name == '"child.Message"'
99111

100112

101113
def test_reference_child_package_from_root():
102114
imports = set()
103115
name = get_type_reference(package="", imports=imports, source_type="child.Message")
104116

105117
assert imports == {"from . import child"}
106-
assert name == "child.Message"
118+
assert name == '"child.Message"'
107119

108120

109121
def test_reference_camel_cased():
@@ -113,7 +125,7 @@ def test_reference_camel_cased():
113125
)
114126

115127
assert imports == {"from . import child_package"}
116-
assert name == "child_package.ExampleMessage"
128+
assert name == '"child_package.ExampleMessage"'
117129

118130

119131
def test_reference_nested_child_from_root():
@@ -123,7 +135,7 @@ def test_reference_nested_child_from_root():
123135
)
124136

125137
assert imports == {"from .nested import child as nested_child"}
126-
assert name == "nested_child.Message"
138+
assert name == '"nested_child.Message"'
127139

128140

129141
def test_reference_deeply_nested_child_from_root():
@@ -133,7 +145,7 @@ def test_reference_deeply_nested_child_from_root():
133145
)
134146

135147
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
136-
assert name == "deeply_nested_child.Message"
148+
assert name == '"deeply_nested_child.Message"'
137149

138150

139151
def test_reference_deeply_nested_child_from_package():
@@ -145,7 +157,7 @@ def test_reference_deeply_nested_child_from_package():
145157
)
146158

147159
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
148-
assert name == "deeply_nested_child.Message"
160+
assert name == '"deeply_nested_child.Message"'
149161

150162

151163
def test_reference_root_sibling():
@@ -181,7 +193,7 @@ def test_reference_parent_package_from_child():
181193
)
182194

183195
assert imports == {"from ... import package as __package__"}
184-
assert name == "__package__.Message"
196+
assert name == '"__package__.Message"'
185197

186198

187199
def test_reference_parent_package_from_deeply_nested_child():
@@ -193,7 +205,7 @@ def test_reference_parent_package_from_deeply_nested_child():
193205
)
194206

195207
assert imports == {"from ... import nested as __nested__"}
196-
assert name == "__nested__.Message"
208+
assert name == '"__nested__.Message"'
197209

198210

199211
def test_reference_ancestor_package_from_nested_child():
@@ -205,7 +217,7 @@ def test_reference_ancestor_package_from_nested_child():
205217
)
206218

207219
assert imports == {"from .... import ancestor as ___ancestor__"}
208-
assert name == "___ancestor__.Message"
220+
assert name == '"___ancestor__.Message"'
209221

210222

211223
def test_reference_root_package_from_child():
@@ -215,7 +227,7 @@ def test_reference_root_package_from_child():
215227
)
216228

217229
assert imports == {"from ... import Message as __Message__"}
218-
assert name == "__Message__"
230+
assert name == '"__Message__"'
219231

220232

221233
def test_reference_root_package_from_deeply_nested_child():
@@ -225,23 +237,23 @@ def test_reference_root_package_from_deeply_nested_child():
225237
)
226238

227239
assert imports == {"from ..... import Message as ____Message__"}
228-
assert name == "____Message__"
240+
assert name == '"____Message__"'
229241

230242

231243
def test_reference_unrelated_package():
232244
imports = set()
233245
name = get_type_reference(package="a", imports=imports, source_type="p.Message")
234246

235247
assert imports == {"from .. import p as _p__"}
236-
assert name == "_p__.Message"
248+
assert name == '"_p__.Message"'
237249

238250

239251
def test_reference_unrelated_nested_package():
240252
imports = set()
241253
name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
242254

243255
assert imports == {"from ...p import q as __p_q__"}
244-
assert name == "__p_q__.Message"
256+
assert name == '"__p_q__.Message"'
245257

246258

247259
def test_reference_unrelated_deeply_nested_package():
@@ -251,15 +263,15 @@ def test_reference_unrelated_deeply_nested_package():
251263
)
252264

253265
assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
254-
assert name == "____p_q_r_s__.Message"
266+
assert name == '"____p_q_r_s__.Message"'
255267

256268

257269
def test_reference_cousin_package():
258270
imports = set()
259271
name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
260272

261273
assert imports == {"from .. import y as _y__"}
262-
assert name == "_y__.Message"
274+
assert name == '"_y__.Message"'
263275

264276

265277
def test_reference_cousin_package_different_name():
@@ -269,7 +281,7 @@ def test_reference_cousin_package_different_name():
269281
)
270282

271283
assert imports == {"from ...cousin import package2 as __cousin_package2__"}
272-
assert name == "__cousin_package2__.Message"
284+
assert name == '"__cousin_package2__.Message"'
273285

274286

275287
def test_reference_cousin_package_same_name():
@@ -279,7 +291,7 @@ def test_reference_cousin_package_same_name():
279291
)
280292

281293
assert imports == {"from ...cousin import package as __cousin_package__"}
282-
assert name == "__cousin_package__.Message"
294+
assert name == '"__cousin_package__.Message"'
283295

284296

285297
def test_reference_far_cousin_package():
@@ -289,7 +301,7 @@ def test_reference_far_cousin_package():
289301
)
290302

291303
assert imports == {"from ...b import c as __b_c__"}
292-
assert name == "__b_c__.Message"
304+
assert name == '"__b_c__.Message"'
293305

294306

295307
def test_reference_far_far_cousin_package():
@@ -299,7 +311,7 @@ def test_reference_far_far_cousin_package():
299311
)
300312

301313
assert imports == {"from ....b.c import d as ___b_c_d__"}
302-
assert name == "___b_c_d__.Message"
314+
assert name == '"___b_c_d__.Message"'
303315

304316

305317
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)