Skip to content

Commit 134b96a

Browse files
committed
Reformat the code using black
1 parent 90ee58d commit 134b96a

File tree

7 files changed

+135
-105
lines changed

7 files changed

+135
-105
lines changed

kotlang/ast.py

Lines changed: 90 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,20 @@
44
from enum import Enum
55
from itertools import zip_longest
66
from typing import (
7-
Any, cast, Collection, Dict, Iterable, Iterator, List,
8-
Mapping, MutableMapping, Optional,
9-
Tuple, Type as TypingType, TypeVar, Union as TypingUnion,
7+
Any,
8+
cast,
9+
Collection,
10+
Dict,
11+
Iterable,
12+
Iterator,
13+
List,
14+
Mapping,
15+
MutableMapping,
16+
Optional,
17+
Tuple,
18+
Type as TypingType,
19+
TypeVar,
20+
Union as TypingUnion,
1021
)
1122

1223
from llvmlite import ir
@@ -99,15 +110,13 @@ def symbol_name(self, namespace: Namespace) -> str:
99110
return mangle([self.name] + type_values)
100111

101112
def get_type(self, namespace: Namespace) -> ts.FunctionType:
102-
ref = FunctionTypeReference([p.type_ for p in self.parameters], self.return_type, self.parameters.variadic)
113+
ref = FunctionTypeReference(
114+
[p.type_ for p in self.parameters], self.return_type, self.parameters.variadic
115+
)
103116
return ref.codegen(namespace)
104117

105118

106-
def get_or_create_llvm_function(
107-
module: ir.Module,
108-
namespace: Namespace,
109-
function: Function,
110-
) -> ir.Function:
119+
def get_or_create_llvm_function(module: ir.Module, namespace: Namespace, function: Function) -> ir.Function:
111120
symbol_name = function.symbol_name(namespace)
112121
try:
113122
llvm_function = module.globals[symbol_name]
@@ -130,7 +139,9 @@ def get_or_create_llvm_function(
130139
(parameter, parameter_type) = pt
131140
memory = builder.alloca(arg.type, name=parameter.name)
132141
builder.store(arg, memory)
133-
function_namespace.add_value(Variable(parameter.name or f'param{i + 1}', parameter_type, memory))
142+
function_namespace.add_value(
143+
Variable(parameter.name or f'param{i + 1}', parameter_type, memory)
144+
)
134145

135146
function.code_block.codegen(module, builder, function_namespace)
136147
if ft.return_type == ts.void:
@@ -151,12 +162,7 @@ class Module:
151162
includes: List[str]
152163
variables: List[VariableDeclaration]
153164

154-
def codegen(
155-
self,
156-
module: ir.Module,
157-
parent_namespaces: List[Namespace],
158-
module_name: str,
159-
) -> Namespace:
165+
def codegen(self, module: ir.Module, parent_namespaces: List[Namespace], module_name: str) -> Namespace:
160166
module_namespace = Namespace(parents=parent_namespaces)
161167

162168
definitions_types = [(td, td.get_dummy_type()) for td in self.types]
@@ -332,11 +338,7 @@ def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace
332338

333339

334340
def loop_helper(
335-
module: ir.Module,
336-
builder: ir.IRBuilder,
337-
namespace: Namespace,
338-
condition: Expression,
339-
body: Statement,
341+
module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, condition: Expression, body: Statement
340342
) -> None:
341343
assert isinstance(condition.type(namespace), ts.BoolType)
342344
condition_block = builder.append_basic_block('loop.condition')
@@ -402,20 +404,25 @@ def codegen_module_level(self, module: ir.Module, namespace: Namespace, module_n
402404

403405
def variable_type(self, namespace: Namespace) -> ts.Type:
404406
return (
405-
self.type_.codegen(namespace) if self.type_ is not None
407+
self.type_.codegen(namespace)
408+
if self.type_ is not None
406409
else cast(Expression, self.expression).type(namespace)
407410
)
408411

409412

410413
class Expression(Statement):
411-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
414+
def codegen(
415+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
416+
) -> ir.Value:
412417
raise NotImplementedError()
413418

414419
def type(self, namespace: Namespace) -> ts.Type:
415420
raise NotImplementedError(f'type() not implemented for {type(self)}')
416421

417422
def get_pointer(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace) -> ir.Value:
418-
raise AssertionError(f'{type(self).__name__} cannot be used as a l-value nor can you grab its address')
423+
raise AssertionError(
424+
f'{type(self).__name__} cannot be used as a l-value nor can you grab its address'
425+
)
419426

420427
def get_constant_time_value(self) -> Any:
421428
raise NotImplementedError(f'{0} is not a compile-time constant')
@@ -425,7 +432,9 @@ def get_constant_time_value(self) -> Any:
425432
class NegativeExpression(Expression):
426433
expression: Expression
427434

428-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
435+
def codegen(
436+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
437+
) -> ir.Value:
429438
value = self.expression.codegen(module, builder, namespace, name)
430439
value.constant = -value.constant
431440
return value
@@ -438,7 +447,9 @@ def type(self, namespace: Namespace) -> ts.Type:
438447
class BoolNegationExpression(Expression):
439448
expression: Expression
440449

441-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
450+
def codegen(
451+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
452+
) -> ir.Value:
442453
assert self.expression.type(namespace).name == 'bool', self.expression
443454

444455
value_to_negate = self.expression.codegen(module, builder, namespace)
@@ -455,7 +466,9 @@ class BinaryExpression(Expression):
455466
right_operand: Expression
456467
name: str = ''
457468

458-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
469+
def codegen(
470+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
471+
) -> ir.Value:
459472
left_value = self.left_operand.codegen(module, builder, namespace)
460473
right_value = self.right_operand.codegen(module, builder, namespace)
461474
# TODO stop hardcoding those
@@ -503,7 +516,9 @@ def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace
503516
else:
504517
right_value = builder.sext(right_value, extend_to)
505518
return builder.icmp_signed(self.operator, left_value, right_value, name=self.name)
506-
raise AssertionError(f'Invalid operand, operator, operand triple: ({left_value.type}, {right_value.type}, {self.operator})') # noqa
519+
raise AssertionError(
520+
f'Invalid operand, operator, operand triple: ({left_value.type}, {right_value.type}, {self.operator})'
521+
) # noqa
507522

508523
def type(self, namespace: Namespace) -> ts.Type:
509524
if self.operator in {'<', '>', '<=', '>=', '==', '!='}:
@@ -518,7 +533,9 @@ class FunctionCall(Expression):
518533
name: str
519534
parameters: List[Expression]
520535

521-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
536+
def codegen(
537+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
538+
) -> ir.Value:
522539
function: TypingUnion[Function, Variable]
523540
parameter_names: List[str]
524541
try:
@@ -543,9 +560,11 @@ def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace
543560
llvm_function = get_or_create_llvm_function(module, namespace, function)
544561

545562
# TODO: handle not enough parameters here
546-
assert len(self.parameters) == len(parameter_types) or \
547-
ft.variadic and len(self.parameters) > len(parameter_types), \
548-
(ft, self.parameters)
563+
assert (
564+
len(self.parameters) == len(parameter_types)
565+
or ft.variadic
566+
and len(self.parameters) > len(parameter_types)
567+
), (ft, self.parameters)
549568
parameter_values = [
550569
p.codegen(module, builder, namespace, f'{self.name}.{n}')
551570
for (p, n) in zip_longest(self.parameters, parameter_names, fillvalue='arg')
@@ -555,7 +574,7 @@ def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace
555574

556575
provided_parameter_types = [p.type(namespace) for p in self.parameters]
557576
for i, (value, from_type, to_type) in enumerate(
558-
zip(parameter_values, provided_parameter_types, parameter_types),
577+
zip(parameter_values, provided_parameter_types, parameter_types)
559578
):
560579
parameter_values[i] = to_type.adapt(builder, value, from_type)
561580

@@ -572,9 +591,7 @@ def whatever<T>(int a, T b) -> void ...
572591

573592

574593
def namespace_for_specialized_function(
575-
namespace: Namespace,
576-
function: Function,
577-
arguments: Collection[Expression],
594+
namespace: Namespace, function: Function, arguments: Collection[Expression]
578595
) -> Namespace:
579596
mapping: Dict[str, ts.Type] = {}
580597
for parameter, expression in zip(function.parameters, arguments):
@@ -596,7 +613,9 @@ class StructInstantiation(Expression):
596613
name: str
597614
parameters: List[Expression]
598615

599-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
616+
def codegen(
617+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
618+
) -> ir.Value:
600619
struct = namespace.get_type(self.name)
601620
assert isinstance(struct, ts.StructType)
602621
assert len(self.parameters) == len(struct.members)
@@ -621,7 +640,9 @@ class StringLiteral(Expression):
621640
def __post_init__(self) -> None:
622641
self.text = evaluate_escape_sequences(self.text)
623642

624-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
643+
def codegen(
644+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
645+
) -> ir.Value:
625646
return string_constant(module, builder, self.text[1:-1], namespace)
626647

627648
def type(self, namespace: Namespace) -> ts.Type:
@@ -636,7 +657,9 @@ def evaluate_escape_sequences(text: str) -> str:
636657
class IntegerLiteral(Expression):
637658
text: str
638659

639-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
660+
def codegen(
661+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
662+
) -> ir.Value:
640663
value = int(self.text)
641664
return namespace.get_type('i64').get_ir_type()(value)
642665

@@ -651,7 +674,9 @@ def get_constant_time_value(self) -> Any:
651674
class FloatLiteral(Expression):
652675
text: str
653676

654-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
677+
def codegen(
678+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
679+
) -> ir.Value:
655680
value = float(self.text)
656681
return namespace.get_type('f64').get_ir_type()(value)
657682

@@ -663,7 +688,9 @@ def type(self, namespace: Namespace) -> ts.Type:
663688
class BoolLiteral(Expression):
664689
value: bool
665690

666-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
691+
def codegen(
692+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
693+
) -> ir.Value:
667694
return namespace.get_type('bool').get_ir_type()(self.value)
668695

669696
def type(self, namespace: Namespace) -> ts.Type:
@@ -679,7 +706,9 @@ def get_pointer(self, module: ir.Module, builder: ir.IRBuilder, namespace: Names
679706
class VariableReference(MemoryReference):
680707
name: str
681708

682-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
709+
def codegen(
710+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
711+
) -> ir.Value:
683712
type_ = self.type(namespace)
684713
pointer = self.get_pointer(module, builder, namespace)
685714
# The first part of this condition makes sure we keep referring to functions by their pointers.
@@ -714,7 +743,9 @@ def type(self, namespace: Namespace) -> ts.Type:
714743
class AddressOf(MemoryReference):
715744
variable: VariableReference
716745

717-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
746+
def codegen(
747+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
748+
) -> ir.Value:
718749
return self.variable.get_pointer(module, builder, namespace)
719750

720751
def type(self, namespace: Namespace) -> ts.Type:
@@ -725,7 +756,9 @@ def type(self, namespace: Namespace) -> ts.Type:
725756
class ValueAt(MemoryReference):
726757
variable: VariableReference
727758

728-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
759+
def codegen(
760+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
761+
) -> ir.Value:
729762
pointer = self.variable.get_pointer(module, builder, namespace)
730763
pointer = builder.load(pointer) # self.variable.codegen
731764
pointer = builder.load(pointer)
@@ -743,7 +776,9 @@ class Assignment(Expression):
743776
target: Expression
744777
expression: Expression
745778

746-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
779+
def codegen(
780+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
781+
) -> ir.Value:
747782
pointer = self.target.get_pointer(module, builder, namespace)
748783
value = self.expression.codegen(module, builder, namespace)
749784
destination_type = self.target.type(namespace)
@@ -765,13 +800,15 @@ class ArrayLiteral(Expression):
765800
def __post_init__(self) -> None:
766801
assert len(self.initializers) > 0
767802

768-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
803+
def codegen(
804+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
805+
) -> ir.Value:
769806
type_ = self.type(namespace)
770807
memory = builder.alloca(type_.get_ir_type(), name=name)
771808
i64 = namespace.get_type('i64').get_ir_type()
772809

773810
for index, initializer in enumerate(self.initializers):
774-
indexed_memory = builder.gep(memory, (i64(0), i64(index),))
811+
indexed_memory = builder.gep(memory, (i64(0), i64(index)))
775812
value = initializer.codegen(module, builder, namespace)
776813
builder.store(value, indexed_memory)
777814
return builder.load(memory)
@@ -787,7 +824,9 @@ class DotAccess(MemoryReference):
787824
left_side: MemoryReference
788825
member: str
789826

790-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
827+
def codegen(
828+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
829+
) -> ir.Value:
791830
pointer = self.get_pointer(module, builder, namespace)
792831
return builder.load(pointer)
793832

@@ -808,7 +847,9 @@ class IndexAccess(MemoryReference):
808847
pointer: MemoryReference
809848
index: Expression
810849

811-
def codegen(self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = '') -> ir.Value:
850+
def codegen(
851+
self, module: ir.Module, builder: ir.IRBuilder, namespace: Namespace, name: str = ''
852+
) -> ir.Value:
812853
pointer = self.get_pointer(module, builder, namespace)
813854
return builder.load(pointer)
814855

kotlang/cimport.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@
3030
]
3131

3232
from kotlang.clang import cindex # noqa
33+
3334
cindex.Config.set_library_file(libclang_file)
3435
clang_index = cindex.Index.create()
3536

3637

3738
def read_header(header: str) -> HeaderContents:
38-
tu = clang_index.parse(find_header(header), options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD)
39+
tu = clang_index.parse(
40+
find_header(header), options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
41+
)
3942
types: List[ast.TypeDefinition] = []
4043
functions: List[ast.Function] = []
4144
variables: List[ast.VariableDeclaration] = []
@@ -48,15 +51,9 @@ def read_header(header: str) -> HeaderContents:
4851
# We skip macros here, we're only interested in literal defines
4952
if len(macro_tokens) == 1:
5053
defines[c.spelling] = macro_tokens[0]
51-
elif (
52-
(
53-
c.kind is cindex.CursorKind.STRUCT_DECL # type: ignore
54-
and c.is_definition()
55-
)
56-
or (
57-
c.kind is cindex.CursorKind.TYPE_REF # type: ignore
58-
and c.type.kind is cindex.TypeKind.RECORD # type: ignore
59-
)
54+
elif (c.kind is cindex.CursorKind.STRUCT_DECL and c.is_definition()) or ( # type: ignore
55+
c.kind is cindex.CursorKind.TYPE_REF # type: ignore
56+
and c.type.kind is cindex.TypeKind.RECORD # type: ignore
6057
):
6158
types.append(convert_c_record_definition(c))
6259
elif (
@@ -114,12 +111,10 @@ def convert_c_function_declaration(declaration: cindex.Cursor) -> ast.Function:
114111
return_type = convert_c_type_reference(declaration.type.get_result())
115112

116113
parameter_names_types = [
117-
(p.spelling or None, convert_c_type_reference(p.type))
118-
for p in declaration.get_arguments()
114+
(p.spelling or None, convert_c_type_reference(p.type)) for p in declaration.get_arguments()
119115
]
120116
parameters = ast.ParameterList(
121-
[ast.Parameter(n, t) for (n, t) in parameter_names_types],
122-
declaration.type.is_function_variadic(),
117+
[ast.Parameter(n, t) for (n, t) in parameter_names_types], declaration.type.is_function_variadic()
123118
)
124119
return ast.Function(declaration.spelling, parameters, return_type, [], None)
125120

0 commit comments

Comments
 (0)