30
30
"""
31
31
32
32
33
+ import builtins
33
34
import betterproto
34
35
from betterproto import which_one_of
35
36
from betterproto .casing import sanitize_name
@@ -237,6 +238,7 @@ class OutputTemplate:
237
238
imports : Set [str ] = field (default_factory = set )
238
239
datetime_imports : Set [str ] = field (default_factory = set )
239
240
typing_imports : Set [str ] = field (default_factory = set )
241
+ builtins_import : bool = False
240
242
messages : List ["MessageCompiler" ] = field (default_factory = list )
241
243
enums : List ["EnumDefinitionCompiler" ] = field (default_factory = list )
242
244
services : List ["ServiceCompiler" ] = field (default_factory = list )
@@ -268,6 +270,8 @@ def python_module_imports(self) -> Set[str]:
268
270
imports = set ()
269
271
if any (x for x in self .messages if any (x .deprecated_fields )):
270
272
imports .add ("warnings" )
273
+ if self .builtins_import :
274
+ imports .add ("builtins" )
271
275
return imports
272
276
273
277
@@ -283,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
283
287
default_factory = list
284
288
)
285
289
deprecated : bool = field (default = False , init = False )
290
+ builtins_types : Set [str ] = field (default_factory = set )
286
291
287
292
def __post_init__ (self ) -> None :
288
293
# Add message to output file
@@ -376,6 +381,8 @@ def get_field_string(self, indent: int = 4) -> str:
376
381
betterproto_field_type = (
377
382
f"betterproto.{ self .field_type } _field({ self .proto_obj .number } { field_args } )"
378
383
)
384
+ if self .py_name in dir (builtins ):
385
+ self .parent .builtins_types .add (self .py_name )
379
386
return f"{ name } { annotations } = { betterproto_field_type } "
380
387
381
388
@property
@@ -408,9 +415,16 @@ def typing_imports(self) -> Set[str]:
408
415
imports .add ("Dict" )
409
416
return imports
410
417
418
+ @property
419
+ def use_builtins (self ) -> bool :
420
+ return self .py_type in self .parent .builtins_types or (
421
+ self .py_type == self .py_name and self .py_name in dir (builtins )
422
+ )
423
+
411
424
def add_imports_to (self , output_file : OutputTemplate ) -> None :
412
425
output_file .datetime_imports .update (self .datetime_imports )
413
426
output_file .typing_imports .update (self .typing_imports )
427
+ output_file .builtins_import = output_file .builtins_import or self .use_builtins
414
428
415
429
@property
416
430
def field_wraps (self ) -> Optional [str ]:
@@ -504,9 +518,12 @@ def py_type(self) -> str:
504
518
505
519
@property
506
520
def annotation (self ) -> str :
521
+ py_type = self .py_type
522
+ if self .use_builtins :
523
+ py_type = f"builtins.{ py_type } "
507
524
if self .repeated :
508
- return f"List[{ self . py_type } ]"
509
- return self . py_type
525
+ return f"List[{ py_type } ]"
526
+ return py_type
510
527
511
528
512
529
@dataclass
0 commit comments