Skip to content

Commit e255b4c

Browse files
philnik777lanza
authored andcommitted
[CIR][CodeGen] Fix lowering for class types (#378)
1 parent 3943bf8 commit e255b4c

File tree

2 files changed

+103
-9
lines changed

2 files changed

+103
-9
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -337,15 +337,15 @@ static void lowerNestedYield(mlir::cir::YieldOpKind targetKind,
337337
[&](mlir::Operation *op) {
338338
if (!isNested(op))
339339
return mlir::WalkResult::advance();
340-
340+
341341
// don't process breaks/continues in nested loops and switches
342342
if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(*op))
343343
return mlir::WalkResult::skip();
344344

345345
auto yield = dyn_cast<mlir::cir::YieldOp>(*op);
346346
if (yield && yield.getKind() == targetKind) {
347347
rewriter.setInsertionPoint(op);
348-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, yield.getArgs(), dst);
348+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, yield.getArgs(), dst);
349349
}
350350

351351
return mlir::WalkResult::advance();
@@ -1386,11 +1386,11 @@ class CIRSwitchOpLowering
13861386
}
13871387

13881388
for (auto& blk : region.getBlocks()) {
1389-
if (blk.getNumSuccessors())
1389+
if (blk.getNumSuccessors())
13901390
continue;
13911391

13921392
// Handle switch-case yields.
1393-
auto *terminator = blk.getTerminator();
1393+
auto *terminator = blk.getTerminator();
13941394
if (auto yieldOp = dyn_cast<mlir::cir::YieldOp>(terminator)) {
13951395
// TODO(cir): Ensure every yield instead of dealing with optional
13961396
// values.
@@ -1414,7 +1414,7 @@ class CIRSwitchOpLowering
14141414
}
14151415
}
14161416

1417-
lowerNestedYield(mlir::cir::YieldOpKind::Break,
1417+
lowerNestedYield(mlir::cir::YieldOpKind::Break,
14181418
rewriter, region, exitBlock);
14191419

14201420
// Extract region contents before erasing the switch op.
@@ -1930,7 +1930,8 @@ class CIRGetMemberOpLowering
19301930
assert(structTy && "expected struct type");
19311931

19321932
switch (structTy.getKind()) {
1933-
case mlir::cir::StructType::Struct: {
1933+
case mlir::cir::StructType::Struct:
1934+
case mlir::cir::StructType::Class: {
19341935
// Since the base address is a pointer to an aggregate, the first offset
19351936
// is always zero. The second offset tell us which member it will access.
19361937
llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()};
@@ -1945,9 +1946,6 @@ class CIRGetMemberOpLowering
19451946
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
19461947
adaptor.getAddr());
19471948
return mlir::success();
1948-
default:
1949-
return op.emitError()
1950-
<< "struct kind '" << structTy.getKind() << "' is NYI";
19511949
}
19521950
}
19531951
};

clang/test/CIR/Lowering/class.cir

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
!u8i = !cir.int<u, 8>
6+
!u32i = !cir.int<u, 32>
7+
!ty_22S22 = !cir.struct<class "S" {!u8i, !s32i}>
8+
!ty_22S2A22 = !cir.struct<class "S2A" {!s32i} #cir.record.decl.ast>
9+
!ty_22S122 = !cir.struct<class "S1" {!s32i, f32, !cir.ptr<!s32i>} #cir.record.decl.ast>
10+
!ty_22S222 = !cir.struct<class "S2" {!ty_22S2A22} #cir.record.decl.ast>
11+
!ty_22S322 = !cir.struct<class "S3" {!s32i} #cir.record.decl.ast>
12+
13+
module {
14+
cir.func @test() {
15+
%1 = cir.alloca !ty_22S22, cir.ptr <!ty_22S22>, ["x"] {alignment = 4 : i64}
16+
// CHECK: %[[#ARRSIZE:]] = llvm.mlir.constant(1 : index) : i64
17+
// CHECK: %[[#CLASS:]] = llvm.alloca %[[#ARRSIZE]] x !llvm.struct<"class.S", (i8, i32)>
18+
%3 = cir.get_member %1[0] {name = "c"} : !cir.ptr<!ty_22S22> -> !cir.ptr<!u8i>
19+
// CHECK: = llvm.getelementptr %[[#CLASS]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.S", (i8, i32)>
20+
%5 = cir.get_member %1[1] {name = "i"} : !cir.ptr<!ty_22S22> -> !cir.ptr<!s32i>
21+
// CHECK: = llvm.getelementptr %[[#CLASS]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.S", (i8, i32)>
22+
cir.return
23+
}
24+
25+
cir.func @shouldConstInitLocalClassesWithConstStructAttr() {
26+
%0 = cir.alloca !ty_22S2A22, cir.ptr <!ty_22S2A22>, ["s"] {alignment = 4 : i64}
27+
%1 = cir.const(#cir.const_struct<{#cir.int<1> : !s32i}> : !ty_22S2A22) : !ty_22S2A22
28+
cir.store %1, %0 : !ty_22S2A22, cir.ptr <!ty_22S2A22>
29+
cir.return
30+
}
31+
// CHECK: llvm.func @shouldConstInitLocalClassesWithConstStructAttr()
32+
// CHECK: %0 = llvm.mlir.constant(1 : index) : i64
33+
// CHECK: %1 = llvm.alloca %0 x !llvm.struct<"class.S2A", (i32)> {alignment = 4 : i64} : (i64) -> !llvm.ptr
34+
// CHECK: %2 = llvm.mlir.undef : !llvm.struct<"class.S2A", (i32)>
35+
// CHECK: %3 = llvm.mlir.constant(1 : i32) : i32
36+
// CHECK: %4 = llvm.insertvalue %3, %2[0] : !llvm.struct<"class.S2A", (i32)>
37+
// CHECK: llvm.store %4, %1 : !llvm.struct<"class.S2A", (i32)>, !llvm.ptr
38+
// CHECK: llvm.return
39+
// CHECK: }
40+
41+
// Should lower basic #cir.const_struct initializer.
42+
cir.global external @s1 = #cir.const_struct<{#cir.int<1> : !s32i, 1.000000e-01 : f32, #cir.ptr<null> : !cir.ptr<!s32i>}> : !ty_22S122
43+
// CHECK: llvm.mlir.global external @s1() {addr_space = 0 : i32} : !llvm.struct<"class.S1", (i32, f32, ptr)> {
44+
// CHECK: %0 = llvm.mlir.undef : !llvm.struct<"class.S1", (i32, f32, ptr)>
45+
// CHECK: %1 = llvm.mlir.constant(1 : i32) : i32
46+
// CHECK: %2 = llvm.insertvalue %1, %0[0] : !llvm.struct<"class.S1", (i32, f32, ptr)>
47+
// CHECK: %3 = llvm.mlir.constant(1.000000e-01 : f32) : f32
48+
// CHECK: %4 = llvm.insertvalue %3, %2[1] : !llvm.struct<"class.S1", (i32, f32, ptr)>
49+
// CHECK: %5 = llvm.mlir.zero : !llvm.ptr
50+
// CHECK: %6 = llvm.insertvalue %5, %4[2] : !llvm.struct<"class.S1", (i32, f32, ptr)>
51+
// CHECK: llvm.return %6 : !llvm.struct<"class.S1", (i32, f32, ptr)>
52+
// CHECK: }
53+
54+
// Should lower nested #cir.const_struct initializer.
55+
cir.global external @s2 = #cir.const_struct<{#cir.const_struct<{#cir.int<1> : !s32i}> : !ty_22S2A22}> : !ty_22S222
56+
// CHECK: llvm.mlir.global external @s2() {addr_space = 0 : i32} : !llvm.struct<"class.S2", (struct<"class.S2A", (i32)>)> {
57+
// CHECK: %0 = llvm.mlir.undef : !llvm.struct<"class.S2", (struct<"class.S2A", (i32)>)>
58+
// CHECK: %1 = llvm.mlir.undef : !llvm.struct<"class.S2A", (i32)>
59+
// CHECK: %2 = llvm.mlir.constant(1 : i32) : i32
60+
// CHECK: %3 = llvm.insertvalue %2, %1[0] : !llvm.struct<"class.S2A", (i32)>
61+
// CHECK: %4 = llvm.insertvalue %3, %0[0] : !llvm.struct<"class.S2", (struct<"class.S2A", (i32)>)>
62+
// CHECK: llvm.return %4 : !llvm.struct<"class.S2", (struct<"class.S2A", (i32)>)>
63+
// CHECK: }
64+
65+
cir.global external @s3 = #cir.const_array<[#cir.const_struct<{#cir.int<1> : !s32i}> : !ty_22S322, #cir.const_struct<{#cir.int<2> : !s32i}> : !ty_22S322, #cir.const_struct<{#cir.int<3> : !s32i}> : !ty_22S322]> : !cir.array<!ty_22S322 x 3>
66+
// CHECK: llvm.mlir.global external @s3() {addr_space = 0 : i32} : !llvm.array<3 x struct<"class.S3", (i32)>> {
67+
// CHECK: %0 = llvm.mlir.undef : !llvm.array<3 x struct<"class.S3", (i32)>>
68+
// CHECK: %1 = llvm.mlir.undef : !llvm.struct<"class.S3", (i32)>
69+
// CHECK: %2 = llvm.mlir.constant(1 : i32) : i32
70+
// CHECK: %3 = llvm.insertvalue %2, %1[0] : !llvm.struct<"class.S3", (i32)>
71+
// CHECK: %4 = llvm.insertvalue %3, %0[0] : !llvm.array<3 x struct<"class.S3", (i32)>>
72+
// CHECK: %5 = llvm.mlir.undef : !llvm.struct<"class.S3", (i32)>
73+
// CHECK: %6 = llvm.mlir.constant(2 : i32) : i32
74+
// CHECK: %7 = llvm.insertvalue %6, %5[0] : !llvm.struct<"class.S3", (i32)>
75+
// CHECK: %8 = llvm.insertvalue %7, %4[1] : !llvm.array<3 x struct<"class.S3", (i32)>>
76+
// CHECK: %9 = llvm.mlir.undef : !llvm.struct<"class.S3", (i32)>
77+
// CHECK: %10 = llvm.mlir.constant(3 : i32) : i32
78+
// CHECK: %11 = llvm.insertvalue %10, %9[0] : !llvm.struct<"class.S3", (i32)>
79+
// CHECK: %12 = llvm.insertvalue %11, %8[2] : !llvm.array<3 x struct<"class.S3", (i32)>>
80+
// CHECK: llvm.return %12 : !llvm.array<3 x struct<"class.S3", (i32)>>
81+
// CHECK: }
82+
83+
cir.func @shouldLowerClassCopies() {
84+
// CHECK: llvm.func @shouldLowerClassCopies()
85+
%1 = cir.alloca !ty_22S22, cir.ptr <!ty_22S22>, ["a"] {alignment = 4 : i64}
86+
// CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : index) : i64
87+
// CHECK: %[[#SA:]] = llvm.alloca %[[#ONE]] x !llvm.struct<"class.S", (i8, i32)> {alignment = 4 : i64} : (i64) -> !llvm.ptr
88+
%2 = cir.alloca !ty_22S22, cir.ptr <!ty_22S22>, ["b", init] {alignment = 4 : i64}
89+
// CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : index) : i64
90+
// CHECK: %[[#SB:]] = llvm.alloca %[[#ONE]] x !llvm.struct<"class.S", (i8, i32)> {alignment = 4 : i64} : (i64) -> !llvm.ptr
91+
cir.copy %1 to %2 : !cir.ptr<!ty_22S22>
92+
// CHECK: %[[#SIZE:]] = llvm.mlir.constant(8 : i32) : i32
93+
// CHECK: "llvm.intr.memcpy"(%[[#SB]], %[[#SA]], %[[#SIZE]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
94+
cir.return
95+
}
96+
}

0 commit comments

Comments
 (0)