4
4
from enum import Enum
5
5
from itertools import zip_longest
6
6
from typing import (
7
- Any , cast , Collection , Dict , Iterable , Iterator , List ,
8
- Mapping , MutableMapping , Optional ,
9
- Tuple , Type as TypingType , TypeVar , Union as TypingUnion ,
7
+ Any ,
8
+ cast ,
9
+ Collection ,
10
+ Dict ,
11
+ Iterable ,
12
+ Iterator ,
13
+ List ,
14
+ Mapping ,
15
+ MutableMapping ,
16
+ Optional ,
17
+ Tuple ,
18
+ Type as TypingType ,
19
+ TypeVar ,
20
+ Union as TypingUnion ,
10
21
)
11
22
12
23
from llvmlite import ir
@@ -99,15 +110,13 @@ def symbol_name(self, namespace: Namespace) -> str:
99
110
return mangle ([self .name ] + type_values )
100
111
101
112
def get_type (self , namespace : Namespace ) -> ts .FunctionType :
102
- ref = FunctionTypeReference ([p .type_ for p in self .parameters ], self .return_type , self .parameters .variadic )
113
+ ref = FunctionTypeReference (
114
+ [p .type_ for p in self .parameters ], self .return_type , self .parameters .variadic
115
+ )
103
116
return ref .codegen (namespace )
104
117
105
118
106
- def get_or_create_llvm_function (
107
- module : ir .Module ,
108
- namespace : Namespace ,
109
- function : Function ,
110
- ) -> ir .Function :
119
+ def get_or_create_llvm_function (module : ir .Module , namespace : Namespace , function : Function ) -> ir .Function :
111
120
symbol_name = function .symbol_name (namespace )
112
121
try :
113
122
llvm_function = module .globals [symbol_name ]
@@ -130,7 +139,9 @@ def get_or_create_llvm_function(
130
139
(parameter , parameter_type ) = pt
131
140
memory = builder .alloca (arg .type , name = parameter .name )
132
141
builder .store (arg , memory )
133
- function_namespace .add_value (Variable (parameter .name or f'param{ i + 1 } ' , parameter_type , memory ))
142
+ function_namespace .add_value (
143
+ Variable (parameter .name or f'param{ i + 1 } ' , parameter_type , memory )
144
+ )
134
145
135
146
function .code_block .codegen (module , builder , function_namespace )
136
147
if ft .return_type == ts .void :
@@ -151,12 +162,7 @@ class Module:
151
162
includes : List [str ]
152
163
variables : List [VariableDeclaration ]
153
164
154
- def codegen (
155
- self ,
156
- module : ir .Module ,
157
- parent_namespaces : List [Namespace ],
158
- module_name : str ,
159
- ) -> Namespace :
165
+ def codegen (self , module : ir .Module , parent_namespaces : List [Namespace ], module_name : str ) -> Namespace :
160
166
module_namespace = Namespace (parents = parent_namespaces )
161
167
162
168
definitions_types = [(td , td .get_dummy_type ()) for td in self .types ]
@@ -332,11 +338,7 @@ def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace
332
338
333
339
334
340
def loop_helper (
335
- module : ir .Module ,
336
- builder : ir .IRBuilder ,
337
- namespace : Namespace ,
338
- condition : Expression ,
339
- body : Statement ,
341
+ module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , condition : Expression , body : Statement
340
342
) -> None :
341
343
assert isinstance (condition .type (namespace ), ts .BoolType )
342
344
condition_block = builder .append_basic_block ('loop.condition' )
@@ -402,20 +404,25 @@ def codegen_module_level(self, module: ir.Module, namespace: Namespace, module_n
402
404
403
405
def variable_type (self , namespace : Namespace ) -> ts .Type :
404
406
return (
405
- self .type_ .codegen (namespace ) if self .type_ is not None
407
+ self .type_ .codegen (namespace )
408
+ if self .type_ is not None
406
409
else cast (Expression , self .expression ).type (namespace )
407
410
)
408
411
409
412
410
413
class Expression (Statement ):
411
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
414
+ def codegen (
415
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
416
+ ) -> ir .Value :
412
417
raise NotImplementedError ()
413
418
414
419
def type (self , namespace : Namespace ) -> ts .Type :
415
420
raise NotImplementedError (f'type() not implemented for { type (self )} ' )
416
421
417
422
def get_pointer (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace ) -> ir .Value :
418
- raise AssertionError (f'{ type (self ).__name__ } cannot be used as a l-value nor can you grab its address' )
423
+ raise AssertionError (
424
+ f'{ type (self ).__name__ } cannot be used as a l-value nor can you grab its address'
425
+ )
419
426
420
427
def get_constant_time_value (self ) -> Any :
421
428
raise NotImplementedError (f'{ 0 } is not a compile-time constant' )
@@ -425,7 +432,9 @@ def get_constant_time_value(self) -> Any:
425
432
class NegativeExpression (Expression ):
426
433
expression : Expression
427
434
428
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
435
+ def codegen (
436
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
437
+ ) -> ir .Value :
429
438
value = self .expression .codegen (module , builder , namespace , name )
430
439
value .constant = - value .constant
431
440
return value
@@ -438,7 +447,9 @@ def type(self, namespace: Namespace) -> ts.Type:
438
447
class BoolNegationExpression (Expression ):
439
448
expression : Expression
440
449
441
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
450
+ def codegen (
451
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
452
+ ) -> ir .Value :
442
453
assert self .expression .type (namespace ).name == 'bool' , self .expression
443
454
444
455
value_to_negate = self .expression .codegen (module , builder , namespace )
@@ -455,7 +466,9 @@ class BinaryExpression(Expression):
455
466
right_operand : Expression
456
467
name : str = ''
457
468
458
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
469
+ def codegen (
470
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
471
+ ) -> ir .Value :
459
472
left_value = self .left_operand .codegen (module , builder , namespace )
460
473
right_value = self .right_operand .codegen (module , builder , namespace )
461
474
# TODO stop hardcoding those
@@ -503,7 +516,9 @@ def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace
503
516
else :
504
517
right_value = builder .sext (right_value , extend_to )
505
518
return builder .icmp_signed (self .operator , left_value , right_value , name = self .name )
506
- raise AssertionError (f'Invalid operand, operator, operand triple: ({ left_value .type } , { right_value .type } , { self .operator } )' ) # noqa
519
+ raise AssertionError (
520
+ f'Invalid operand, operator, operand triple: ({ left_value .type } , { right_value .type } , { self .operator } )'
521
+ ) # noqa
507
522
508
523
def type (self , namespace : Namespace ) -> ts .Type :
509
524
if self .operator in {'<' , '>' , '<=' , '>=' , '==' , '!=' }:
@@ -518,7 +533,9 @@ class FunctionCall(Expression):
518
533
name : str
519
534
parameters : List [Expression ]
520
535
521
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
536
+ def codegen (
537
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
538
+ ) -> ir .Value :
522
539
function : TypingUnion [Function , Variable ]
523
540
parameter_names : List [str ]
524
541
try :
@@ -543,9 +560,11 @@ def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace
543
560
llvm_function = get_or_create_llvm_function (module , namespace , function )
544
561
545
562
# TODO: handle not enough parameters here
546
- assert len (self .parameters ) == len (parameter_types ) or \
547
- ft .variadic and len (self .parameters ) > len (parameter_types ), \
548
- (ft , self .parameters )
563
+ assert (
564
+ len (self .parameters ) == len (parameter_types )
565
+ or ft .variadic
566
+ and len (self .parameters ) > len (parameter_types )
567
+ ), (ft , self .parameters )
549
568
parameter_values = [
550
569
p .codegen (module , builder , namespace , f'{ self .name } .{ n } ' )
551
570
for (p , n ) in zip_longest (self .parameters , parameter_names , fillvalue = 'arg' )
@@ -555,7 +574,7 @@ def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace
555
574
556
575
provided_parameter_types = [p .type (namespace ) for p in self .parameters ]
557
576
for i , (value , from_type , to_type ) in enumerate (
558
- zip (parameter_values , provided_parameter_types , parameter_types ),
577
+ zip (parameter_values , provided_parameter_types , parameter_types )
559
578
):
560
579
parameter_values [i ] = to_type .adapt (builder , value , from_type )
561
580
@@ -572,9 +591,7 @@ def whatever<T>(int a, T b) -> void ...
572
591
573
592
574
593
def namespace_for_specialized_function (
575
- namespace : Namespace ,
576
- function : Function ,
577
- arguments : Collection [Expression ],
594
+ namespace : Namespace , function : Function , arguments : Collection [Expression ]
578
595
) -> Namespace :
579
596
mapping : Dict [str , ts .Type ] = {}
580
597
for parameter , expression in zip (function .parameters , arguments ):
@@ -596,7 +613,9 @@ class StructInstantiation(Expression):
596
613
name : str
597
614
parameters : List [Expression ]
598
615
599
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
616
+ def codegen (
617
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
618
+ ) -> ir .Value :
600
619
struct = namespace .get_type (self .name )
601
620
assert isinstance (struct , ts .StructType )
602
621
assert len (self .parameters ) == len (struct .members )
@@ -621,7 +640,9 @@ class StringLiteral(Expression):
621
640
def __post_init__ (self ) -> None :
622
641
self .text = evaluate_escape_sequences (self .text )
623
642
624
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
643
+ def codegen (
644
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
645
+ ) -> ir .Value :
625
646
return string_constant (module , builder , self .text [1 :- 1 ], namespace )
626
647
627
648
def type (self , namespace : Namespace ) -> ts .Type :
@@ -636,7 +657,9 @@ def evaluate_escape_sequences(text: str) -> str:
636
657
class IntegerLiteral (Expression ):
637
658
text : str
638
659
639
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
660
+ def codegen (
661
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
662
+ ) -> ir .Value :
640
663
value = int (self .text )
641
664
return namespace .get_type ('i64' ).get_ir_type ()(value )
642
665
@@ -651,7 +674,9 @@ def get_constant_time_value(self) -> Any:
651
674
class FloatLiteral (Expression ):
652
675
text : str
653
676
654
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
677
+ def codegen (
678
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
679
+ ) -> ir .Value :
655
680
value = float (self .text )
656
681
return namespace .get_type ('f64' ).get_ir_type ()(value )
657
682
@@ -663,7 +688,9 @@ def type(self, namespace: Namespace) -> ts.Type:
663
688
class BoolLiteral (Expression ):
664
689
value : bool
665
690
666
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
691
+ def codegen (
692
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
693
+ ) -> ir .Value :
667
694
return namespace .get_type ('bool' ).get_ir_type ()(self .value )
668
695
669
696
def type (self , namespace : Namespace ) -> ts .Type :
@@ -679,7 +706,9 @@ def get_pointer(self, module: ir.Module, builder: ir.IRBuilder, namespace: Names
679
706
class VariableReference (MemoryReference ):
680
707
name : str
681
708
682
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
709
+ def codegen (
710
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
711
+ ) -> ir .Value :
683
712
type_ = self .type (namespace )
684
713
pointer = self .get_pointer (module , builder , namespace )
685
714
# The first part of this condition makes sure we keep referring to functions by their pointers.
@@ -714,7 +743,9 @@ def type(self, namespace: Namespace) -> ts.Type:
714
743
class AddressOf (MemoryReference ):
715
744
variable : VariableReference
716
745
717
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
746
+ def codegen (
747
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
748
+ ) -> ir .Value :
718
749
return self .variable .get_pointer (module , builder , namespace )
719
750
720
751
def type (self , namespace : Namespace ) -> ts .Type :
@@ -725,7 +756,9 @@ def type(self, namespace: Namespace) -> ts.Type:
725
756
class ValueAt (MemoryReference ):
726
757
variable : VariableReference
727
758
728
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
759
+ def codegen (
760
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
761
+ ) -> ir .Value :
729
762
pointer = self .variable .get_pointer (module , builder , namespace )
730
763
pointer = builder .load (pointer ) # self.variable.codegen
731
764
pointer = builder .load (pointer )
@@ -743,7 +776,9 @@ class Assignment(Expression):
743
776
target : Expression
744
777
expression : Expression
745
778
746
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
779
+ def codegen (
780
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
781
+ ) -> ir .Value :
747
782
pointer = self .target .get_pointer (module , builder , namespace )
748
783
value = self .expression .codegen (module , builder , namespace )
749
784
destination_type = self .target .type (namespace )
@@ -765,13 +800,15 @@ class ArrayLiteral(Expression):
765
800
def __post_init__ (self ) -> None :
766
801
assert len (self .initializers ) > 0
767
802
768
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
803
+ def codegen (
804
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
805
+ ) -> ir .Value :
769
806
type_ = self .type (namespace )
770
807
memory = builder .alloca (type_ .get_ir_type (), name = name )
771
808
i64 = namespace .get_type ('i64' ).get_ir_type ()
772
809
773
810
for index , initializer in enumerate (self .initializers ):
774
- indexed_memory = builder .gep (memory , (i64 (0 ), i64 (index ), ))
811
+ indexed_memory = builder .gep (memory , (i64 (0 ), i64 (index )))
775
812
value = initializer .codegen (module , builder , namespace )
776
813
builder .store (value , indexed_memory )
777
814
return builder .load (memory )
@@ -787,7 +824,9 @@ class DotAccess(MemoryReference):
787
824
left_side : MemoryReference
788
825
member : str
789
826
790
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
827
+ def codegen (
828
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
829
+ ) -> ir .Value :
791
830
pointer = self .get_pointer (module , builder , namespace )
792
831
return builder .load (pointer )
793
832
@@ -808,7 +847,9 @@ class IndexAccess(MemoryReference):
808
847
pointer : MemoryReference
809
848
index : Expression
810
849
811
- def codegen (self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = '' ) -> ir .Value :
850
+ def codegen (
851
+ self , module : ir .Module , builder : ir .IRBuilder , namespace : Namespace , name : str = ''
852
+ ) -> ir .Value :
812
853
pointer = self .get_pointer (module , builder , namespace )
813
854
return builder .load (pointer )
814
855
0 commit comments