Skip to content

Commit d1ad076

Browse files
authored
[CIR][ABI][Lowering] Fixes calls with union type (#1119)
This PR handles calls with unions passed by value in the calling convention pass. #### Implementation As one may know, data layout for unions in CIR and in LLVM differ one from another. In CIR we track all the union members, while in LLVM IR only the largest one. And here we need to take this difference into account: we need to find a type of the largest member and treat it as the first (and only) union member in order to preserve all the logic from the original codegen. There is a method `StructType::getLargestMember` - but looks like it produces different results (with the one I implemented or better to say copy-pasted). Maybe it's done intentionally, I don't know. The LLVM IR produced has also some difference from the original one. In the original IR `gep` is emitted - and we can not do the same. If we create `getMemberOp` we may fail on type checking for unions - since the first member type may differ from the largest type. This is why we create `bitcast` instead. Relates to the issue #1061
1 parent 16a027a commit d1ad076

File tree

5 files changed

+68
-21
lines changed

5 files changed

+68
-21
lines changed

clang/include/clang/CIR/MissingFeatures.h

-1
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ struct MissingFeatures {
261261
static bool X86TypeClassification() { return false; }
262262

263263
static bool ABIClangTypeKind() { return false; }
264-
static bool ABIEnterStructForCoercedAccess() { return false; }
265264
static bool ABIFuncPtr() { return false; }
266265
static bool ABIInRegAttribute() { return false; }
267266
static bool ABINestedRecordLayout() { return false; }

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

+6-7
Original file line numberDiff line numberDiff line change
@@ -519,13 +519,12 @@ void StructType::computeSizeAndAlignment(
519519

520520
// Found a nested union: recurse into it to fetch its largest member.
521521
auto structMember = mlir::dyn_cast<StructType>(ty);
522-
if (structMember && structMember.isUnion()) {
523-
auto candidate = structMember.getLargestMember(dataLayout);
524-
if (dataLayout.getTypeSize(candidate) > largestMemberSize) {
525-
largestMember = candidate;
526-
largestMemberSize = dataLayout.getTypeSize(largestMember);
527-
}
528-
} else if (dataLayout.getTypeSize(ty) > largestMemberSize) {
522+
if (!largestMember ||
523+
dataLayout.getTypeABIAlignment(ty) >
524+
dataLayout.getTypeABIAlignment(largestMember) ||
525+
(dataLayout.getTypeABIAlignment(ty) ==
526+
dataLayout.getTypeABIAlignment(largestMember) &&
527+
dataLayout.getTypeSize(ty) > largestMemberSize)) {
529528
largestMember = ty;
530529
largestMemberSize = dataLayout.getTypeSize(largestMember);
531530
}

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

+30-11
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ mlir::Value createCoercedBitcast(mlir::Value Src, mlir::Type DestTy,
5353
CastKind::bitcast, Src);
5454
}
5555

56+
// FIXME(cir): Create a custom rewriter class to abstract this away.
57+
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
58+
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
59+
Src);
60+
}
61+
5662
/// Given a struct pointer that we are accessing some number of bytes out of it,
5763
/// try to gep into the struct to get at its inner goodness. Dive as deep as
5864
/// possible without entering an element with an in-memory size smaller than
@@ -67,6 +73,9 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr,
6773

6874
mlir::Type FirstElt = SrcSTy.getMembers()[0];
6975

76+
if (SrcSTy.isUnion())
77+
FirstElt = SrcSTy.getLargestMember(CGF.LM.getDataLayout().layout);
78+
7079
// If the first elt is at least as large as what we're looking for, or if the
7180
// first element is the same size as the whole struct, we can enter it. The
7281
// comparison must be made on the store size and not the alloca size. Using
@@ -76,10 +85,26 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr,
7685
FirstEltSize < CGF.LM.getDataLayout().getTypeStoreSize(SrcSTy))
7786
return SrcPtr;
7887

79-
cir_cconv_assert_or_abort(
80-
!cir::MissingFeatures::ABIEnterStructForCoercedAccess(), "NYI");
81-
return SrcPtr; // FIXME: This is a temporary workaround for the assertion
82-
// above.
88+
auto &rw = CGF.getRewriter();
89+
auto *ctxt = rw.getContext();
90+
auto ptrTy = PointerType::get(ctxt, FirstElt);
91+
if (mlir::isa<StructType>(SrcPtr.getType())) {
92+
auto addr = SrcPtr;
93+
if (auto load = mlir::dyn_cast<LoadOp>(SrcPtr.getDefiningOp()))
94+
addr = load.getAddr();
95+
cir_cconv_assert(mlir::isa<PointerType>(addr.getType()));
96+
// we can not use getMemberOp here since we need a pointer to the first
97+
// element. And in the case of unions we pick a type of the largest elt,
98+
// that may or may not be the first one. Thus, getMemberOp verification
99+
// may fail.
100+
auto cast = createBitcast(addr, ptrTy, CGF);
101+
SrcPtr = rw.create<LoadOp>(SrcPtr.getLoc(), cast);
102+
}
103+
104+
if (auto sty = mlir::dyn_cast<StructType>(SrcPtr.getType()))
105+
return enterStructPointerForCoercedAccess(SrcPtr, sty, DstSize, CGF);
106+
107+
return SrcPtr;
83108
}
84109

85110
/// Convert a value Val to the specific Ty where both
@@ -141,12 +166,6 @@ static mlir::Value coerceIntOrPtrToIntOrPtr(mlir::Value val, mlir::Type typ,
141166
return val;
142167
}
143168

144-
// FIXME(cir): Create a custom rewriter class to abstract this away.
145-
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
146-
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
147-
Src);
148-
}
149-
150169
AllocaOp createTmpAlloca(LowerFunction &LF, mlir::Location loc, mlir::Type ty) {
151170
auto &rw = LF.getRewriter();
152171
auto *ctxt = rw.getContext();
@@ -302,7 +321,7 @@ mlir::Value createCoercedValue(mlir::Value Src, mlir::Type Ty,
302321
// extension or truncation to the desired type.
303322
if ((mlir::isa<IntType>(Ty) || mlir::isa<PointerType>(Ty)) &&
304323
(mlir::isa<IntType>(SrcTy) || mlir::isa<PointerType>(SrcTy))) {
305-
cir_cconv_unreachable("NYI");
324+
return coerceIntOrPtrToIntOrPtr(Src, Ty, CGF);
306325
}
307326

308327
// If load is legal, just bitcast the src pointer.

clang/test/CIR/CallConvLowering/AArch64/union.c

+31-1
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,34 @@ void foo(U u) {}
3838
U init() {
3939
U u;
4040
return u;
41-
}
41+
}
42+
43+
typedef union {
44+
45+
struct {
46+
short a;
47+
char b;
48+
char c;
49+
};
50+
51+
int x;
52+
} A;
53+
54+
void passA(A x) {}
55+
56+
// CIR: cir.func {{.*@callA}}()
57+
// CIR: %[[#V0:]] = cir.alloca !ty_A, !cir.ptr<!ty_A>, ["x"] {alignment = 4 : i64}
58+
// CIR: %[[#V1:]] = cir.cast(bitcast, %[[#V0:]] : !cir.ptr<!ty_A>), !cir.ptr<!s32i>
59+
// CIR: %[[#V2:]] = cir.load %[[#V1]] : !cir.ptr<!s32i>, !s32i
60+
// CIR: %[[#V3:]] = cir.cast(integral, %[[#V2]] : !s32i), !u64i
61+
// CIR: cir.call @passA(%[[#V3]]) : (!u64i) -> ()
62+
63+
// LLVM: void @callA()
64+
// LLVM: %[[#V0:]] = alloca %union.A, i64 1, align 4
65+
// LLVM: %[[#V1:]] = load i32, ptr %[[#V0]], align 4
66+
// LLVM: %[[#V2:]] = sext i32 %[[#V1]] to i64
67+
// LLVM: call void @passA(i64 %[[#V2]])
68+
void callA() {
69+
A x;
70+
passA(x);
71+
}

clang/test/CIR/Lowering/unions.cir

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ module {
1616
cir.global external @u2 = #cir.zero : !ty_U2_
1717
cir.global external @u3 = #cir.zero : !ty_U3_
1818
// CHECK: llvm.mlir.global external @u2() {addr_space = 0 : i32} : !llvm.struct<"union.U2", (f64)>
19-
// CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (i32)>
19+
// CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (struct<"union.U1", (i32)>)>
2020

2121
// CHECK: llvm.func @test
2222
cir.func @test(%arg0: !cir.ptr<!ty_U1_>) {

0 commit comments

Comments
 (0)