Skip to content

[CIR][CodeGen] Fix lowering for class types #378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,15 +337,15 @@ static void lowerNestedYield(mlir::cir::YieldOpKind targetKind,
[&](mlir::Operation *op) {
if (!isNested(op))
return mlir::WalkResult::advance();

// don't process breaks/continues in nested loops and switches
if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(*op))
return mlir::WalkResult::skip();

auto yield = dyn_cast<mlir::cir::YieldOp>(*op);
if (yield && yield.getKind() == targetKind) {
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, yield.getArgs(), dst);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, yield.getArgs(), dst);
}

return mlir::WalkResult::advance();
Expand Down Expand Up @@ -1364,11 +1364,11 @@ class CIRSwitchOpLowering
}

for (auto& blk : region.getBlocks()) {
if (blk.getNumSuccessors())
if (blk.getNumSuccessors())
continue;

// Handle switch-case yields.
auto *terminator = blk.getTerminator();
auto *terminator = blk.getTerminator();
if (auto yieldOp = dyn_cast<mlir::cir::YieldOp>(terminator)) {
// TODO(cir): Ensure every yield instead of dealing with optional
// values.
Expand All @@ -1392,7 +1392,7 @@ class CIRSwitchOpLowering
}
}

lowerNestedYield(mlir::cir::YieldOpKind::Break,
lowerNestedYield(mlir::cir::YieldOpKind::Break,
rewriter, region, exitBlock);

// Extract region contents before erasing the switch op.
Expand Down Expand Up @@ -1908,7 +1908,8 @@ class CIRGetMemberOpLowering
assert(structTy && "expected struct type");

switch (structTy.getKind()) {
case mlir::cir::StructType::Struct: {
case mlir::cir::StructType::Struct:
case mlir::cir::StructType::Class: {
// Since the base address is a pointer to an aggregate, the first offset
// is always zero. The second offset tell us which member it will access.
llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()};
Expand All @@ -1923,9 +1924,6 @@ class CIRGetMemberOpLowering
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
adaptor.getAddr());
return mlir::success();
default:
return op.emitError()
<< "struct kind '" << structTy.getKind() << "' is NYI";
}
}
};
Expand Down
96 changes: 96 additions & 0 deletions clang/test/CIR/Lowering/class.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

!s32i = !cir.int<s, 32>
!u8i = !cir.int<u, 8>
!u32i = !cir.int<u, 32>
!ty_22S22 = !cir.struct<class "S" {!u8i, !s32i}>
!ty_22S2A22 = !cir.struct<class "S2A" {!s32i} #cir.record.decl.ast>
!ty_22S122 = !cir.struct<class "S1" {!s32i, f32, !cir.ptr<!s32i>} #cir.record.decl.ast>
!ty_22S222 = !cir.struct<class "S2" {!ty_22S2A22} #cir.record.decl.ast>
!ty_22S322 = !cir.struct<class "S3" {!s32i} #cir.record.decl.ast>

module {
cir.func @test() {
%1 = cir.alloca !ty_22S22, cir.ptr <!ty_22S22>, ["x"] {alignment = 4 : i64}
// CHECK: %[[#ARRSIZE:]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[#CLASS:]] = llvm.alloca %[[#ARRSIZE]] x !llvm.struct<"class.S", (i8, i32)>
%3 = cir.get_member %1[0] {name = "c"} : !cir.ptr<!ty_22S22> -> !cir.ptr<!u8i>
// CHECK: = llvm.getelementptr %[[#CLASS]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.S", (i8, i32)>
%5 = cir.get_member %1[1] {name = "i"} : !cir.ptr<!ty_22S22> -> !cir.ptr<!s32i>
// CHECK: = llvm.getelementptr %[[#CLASS]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"class.S", (i8, i32)>
cir.return
}

cir.func @shouldConstInitLocalClassesWithConstStructAttr() {
%0 = cir.alloca !ty_22S2A22, cir.ptr <!ty_22S2A22>, ["s"] {alignment = 4 : i64}
%1 = cir.const(#cir.const_struct<{#cir.int<1> : !s32i}> : !ty_22S2A22) : !ty_22S2A22
cir.store %1, %0 : !ty_22S2A22, cir.ptr <!ty_22S2A22>
cir.return
}
// CHECK: llvm.func @shouldConstInitLocalClassesWithConstStructAttr()
// CHECK: %0 = llvm.mlir.constant(1 : index) : i64
// CHECK: %1 = llvm.alloca %0 x !llvm.struct<"class.S2A", (i32)> {alignment = 4 : i64} : (i64) -> !llvm.ptr
// CHECK: %2 = llvm.mlir.undef : !llvm.struct<"class.S2A", (i32)>
// CHECK: %3 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %4 = llvm.insertvalue %3, %2[0] : !llvm.struct<"class.S2A", (i32)>
// CHECK: llvm.store %4, %1 : !llvm.struct<"class.S2A", (i32)>, !llvm.ptr
// CHECK: llvm.return
// CHECK: }

// Should lower basic #cir.const_struct initializer.
cir.global external @s1 = #cir.const_struct<{#cir.int<1> : !s32i, 1.000000e-01 : f32, #cir.ptr<null> : !cir.ptr<!s32i>}> : !ty_22S122
// CHECK: llvm.mlir.global external @s1() {addr_space = 0 : i32} : !llvm.struct<"class.S1", (i32, f32, ptr)> {
// CHECK: %0 = llvm.mlir.undef : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %1 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %2 = llvm.insertvalue %1, %0[0] : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %3 = llvm.mlir.constant(1.000000e-01 : f32) : f32
// CHECK: %4 = llvm.insertvalue %3, %2[1] : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: %5 = llvm.mlir.zero : !llvm.ptr
// CHECK: %6 = llvm.insertvalue %5, %4[2] : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: llvm.return %6 : !llvm.struct<"class.S1", (i32, f32, ptr)>
// CHECK: }

// Should lower nested #cir.const_struct initializer.
cir.global external @s2 = #cir.const_struct<{#cir.const_struct<{#cir.int<1> : !s32i}> : !ty_22S2A22}> : !ty_22S222
// CHECK: llvm.mlir.global external @s2() {addr_space = 0 : i32} : !llvm.struct<"class.S2", (struct<"class.S2A", (i32)>)> {
// CHECK: %0 = llvm.mlir.undef : !llvm.struct<"class.S2", (struct<"class.S2A", (i32)>)>
// CHECK: %1 = llvm.mlir.undef : !llvm.struct<"class.S2A", (i32)>
// CHECK: %2 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %3 = llvm.insertvalue %2, %1[0] : !llvm.struct<"class.S2A", (i32)>
// CHECK: %4 = llvm.insertvalue %3, %0[0] : !llvm.struct<"class.S2", (struct<"class.S2A", (i32)>)>
// CHECK: llvm.return %4 : !llvm.struct<"class.S2", (struct<"class.S2A", (i32)>)>
// CHECK: }

cir.global external @s3 = #cir.const_array<[#cir.const_struct<{#cir.int<1> : !s32i}> : !ty_22S322, #cir.const_struct<{#cir.int<2> : !s32i}> : !ty_22S322, #cir.const_struct<{#cir.int<3> : !s32i}> : !ty_22S322]> : !cir.array<!ty_22S322 x 3>
// CHECK: llvm.mlir.global external @s3() {addr_space = 0 : i32} : !llvm.array<3 x struct<"class.S3", (i32)>> {
// CHECK: %0 = llvm.mlir.undef : !llvm.array<3 x struct<"class.S3", (i32)>>
// CHECK: %1 = llvm.mlir.undef : !llvm.struct<"class.S3", (i32)>
// CHECK: %2 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %3 = llvm.insertvalue %2, %1[0] : !llvm.struct<"class.S3", (i32)>
// CHECK: %4 = llvm.insertvalue %3, %0[0] : !llvm.array<3 x struct<"class.S3", (i32)>>
// CHECK: %5 = llvm.mlir.undef : !llvm.struct<"class.S3", (i32)>
// CHECK: %6 = llvm.mlir.constant(2 : i32) : i32
// CHECK: %7 = llvm.insertvalue %6, %5[0] : !llvm.struct<"class.S3", (i32)>
// CHECK: %8 = llvm.insertvalue %7, %4[1] : !llvm.array<3 x struct<"class.S3", (i32)>>
// CHECK: %9 = llvm.mlir.undef : !llvm.struct<"class.S3", (i32)>
// CHECK: %10 = llvm.mlir.constant(3 : i32) : i32
// CHECK: %11 = llvm.insertvalue %10, %9[0] : !llvm.struct<"class.S3", (i32)>
// CHECK: %12 = llvm.insertvalue %11, %8[2] : !llvm.array<3 x struct<"class.S3", (i32)>>
// CHECK: llvm.return %12 : !llvm.array<3 x struct<"class.S3", (i32)>>
// CHECK: }

cir.func @shouldLowerClassCopies() {
// CHECK: llvm.func @shouldLowerClassCopies()
%1 = cir.alloca !ty_22S22, cir.ptr <!ty_22S22>, ["a"] {alignment = 4 : i64}
// CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[#SA:]] = llvm.alloca %[[#ONE]] x !llvm.struct<"class.S", (i8, i32)> {alignment = 4 : i64} : (i64) -> !llvm.ptr
%2 = cir.alloca !ty_22S22, cir.ptr <!ty_22S22>, ["b", init] {alignment = 4 : i64}
// CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[#SB:]] = llvm.alloca %[[#ONE]] x !llvm.struct<"class.S", (i8, i32)> {alignment = 4 : i64} : (i64) -> !llvm.ptr
cir.copy %1 to %2 : !cir.ptr<!ty_22S22>
// CHECK: %[[#SIZE:]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: "llvm.intr.memcpy"(%[[#SB]], %[[#SA]], %[[#SIZE]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
cir.return
}
}