Commit 5373f42

[CIR][ABI][Lowering] Fixes calling convention (#1308)
This PR fixes two runtime bugs in the calling convention pass. Both were found with `csmith`.

Case #1. Return value from a function.

Before this PR the returned value was stored into a bitcasted memory location. For the following example that is not safe: the memory slot is smaller than the coerced return value, so the store causes a segfault!

```c
#pragma pack(push)
#pragma pack(1)
typedef struct {
  int f0 : 18;
  int f1 : 31;
  int f2 : 5;
  int f3 : 29;
  int f4 : 24;
} PackedS1;
#pragma pack(pop)
```

The CIR type for this struct is `!ty_PackedS1_ = !cir.struct<struct "PackedS1" {!cir.array<!u8i x 14>}>`, i.e. it occupies 14 bytes. Before this PR the following code

```c
PackedS1 foo(void) {
  PackedS1 s;
  return s;
}

void check(void) {
  PackedS1 y = foo();
}
```

produced this CIR:

```mlir
%0 = cir.alloca !ty_PackedS1_, !cir.ptr<!ty_PackedS1_>, ["y", init] {alignment = 1 : i64}
%1 = cir.call @foo() : () -> !cir.array<!u64i x 2>
%2 = cir.cast(bitcast, %0 : !cir.ptr<!ty_PackedS1_>), !cir.ptr<!cir.array<!u64i x 2>>
cir.store %1, %2 : !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>
```

As one can see, `%1` is an array of two 64-bit integers (16 bytes), while the memory was allocated for only 14 bytes (the size of the struct). Hence the segfault! This PR fixes such cases: the return value is now coerced through memory, matching what the original CodeGen (OG) does.

Case #2. Passing an argument obtained from a pointer dereference.

Previously, for struct types passed by value we tried to find an alloca instruction to use as the source of the memcpy. But if we have a pointer dereference (in other words, if the alloca result has type `!cir.ptr<!cir.ptr<...>>`), we don't need the address of the location where that pointer is stored; we are interested in the pointer itself. This generalizes to: instead of looking for an alloca, look for the first pointer-typed value along the use-def chain and use it as the memcpy source.

I combined these two cases into a single PR since there are only a few changes, but I can split it in two if you'd prefer.
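To make Case #1 concrete, here is a minimal standalone C sketch of the unsafe vs. safe lowering. It is not part of the patch; `coerced_ret` and `tmp` are illustrative names for what the compiler materializes when the 14-byte `PackedS1` is returned as `[2 x i64]`.

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#pragma pack(push, 1)
typedef struct {
  int f0 : 18;
  int f1 : 31;
  int f2 : 5;
  int f3 : 29;
  int f4 : 24;
} PackedS1; /* 107 bits of bit-fields -> 14 bytes when packed */
#pragma pack(pop)

int main(void) {
  uint64_t coerced_ret[2] = {0, 0}; /* what foo() hands back at the ABI level */
  PackedS1 y;                       /* the 14-byte destination slot */

  /* Old lowering (unsafe): bitcast &y to a [2 x i64] slot and store all
     16 bytes into the 14-byte object - a 2-byte out-of-bounds write.   */

  /* New lowering (safe): spill the coerced value into a correctly sized
     temporary, then copy only sizeof(PackedS1) bytes to the destination. */
  uint64_t tmp[2];
  memcpy(tmp, coerced_ret, sizeof(tmp));
  memcpy(&y, tmp, sizeof(PackedS1));

  printf("copied %zu of %zu coerced bytes, y.f0 = %d\n",
         sizeof(PackedS1), sizeof(tmp), y.f0);
  return 0;
}
```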
1 parent fee4bb6 commit 5373f42

File tree

2 files changed: +112 -19 lines

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

Lines changed: 25 additions & 12 deletions
@@ -232,6 +232,17 @@ cir::AllocaOp findAlloca(mlir::Operation *op) {
   return {};
 }
 
+mlir::Value findAddr(mlir::Value v) {
+  if (mlir::isa<cir::PointerType>(v.getType()))
+    return v;
+
+  auto op = v.getDefiningOp();
+  if (!op || !mlir::isa<cir::CastOp, cir::LoadOp, cir::ReturnOp>(op))
+    return {};
+
+  return findAddr(op->getOperand(0));
+}
+
 /// Create a store to \param Dst from \param Src where the source and
 /// destination may have different types.
 ///
@@ -338,10 +349,10 @@ mlir::Value createCoercedValue(mlir::Value Src, mlir::Type Ty,
     return CGF.buildAggregateBitcast(Src, Ty);
   }
 
-  if (auto alloca = findAlloca(Src.getDefiningOp())) {
-    auto tmpAlloca = createTmpAlloca(CGF, alloca.getLoc(), Ty);
-    createMemCpy(CGF, tmpAlloca, alloca, SrcSize.getFixedValue());
-    return CGF.getRewriter().create<LoadOp>(alloca.getLoc(),
+  if (mlir::Value addr = findAddr(Src)) {
+    auto tmpAlloca = createTmpAlloca(CGF, addr.getLoc(), Ty);
+    createMemCpy(CGF, tmpAlloca, addr, SrcSize.getFixedValue());
+    return CGF.getRewriter().create<LoadOp>(addr.getLoc(),
                                             tmpAlloca.getResult());
   }
 
@@ -371,7 +382,6 @@ mlir::Value createCoercedNonPrimitive(mlir::Value src, mlir::Type ty,
 
   auto tySize = LF.LM.getDataLayout().getTypeStoreSize(ty);
   createMemCpy(LF, alloca, addr, tySize.getFixedValue());
-
   auto newLoad = bld.create<LoadOp>(src.getLoc(), alloca.getResult());
   bld.replaceAllOpUsesWith(load, newLoad);
 
@@ -1265,6 +1275,14 @@ mlir::Value LowerFunction::rewriteCallOp(const LowerFunctionInfo &CallInfo,
 
   // FIXME(cir): Use return value slot here.
   mlir::Value RetVal = callOp.getResult();
+  mlir::Value dstPtr;
+  for (auto *user : Caller->getUsers()) {
+    if (auto storeOp = mlir::dyn_cast<StoreOp>(user)) {
+      assert(!dstPtr && "multiple destinations for the return value");
+      dstPtr = storeOp.getAddr();
+    }
+  }
+
   // TODO(cir): Check for volatile return values.
   cir_cconv_assert(!cir::MissingFeatures::volatileTypes());
 
@@ -1283,16 +1301,11 @@ mlir::Value LowerFunction::rewriteCallOp(const LowerFunctionInfo &CallInfo,
   if (mlir::dyn_cast<StructType>(RetTy) &&
       mlir::cast<StructType>(RetTy).getNumElements() != 0) {
     RetVal = newCallOp.getResult();
+    createCoercedStore(RetVal, dstPtr, false, *this);
 
-    llvm::SmallVector<StoreOp, 8> workList;
     for (auto *user : Caller->getUsers())
       if (auto storeOp = mlir::dyn_cast<StoreOp>(user))
-        workList.push_back(storeOp);
-    for (StoreOp storeOp : workList) {
-      auto destPtr =
-          createCoercedBitcast(storeOp.getAddr(), RetVal.getType(), *this);
-      rewriter.replaceOpWithNewOp<StoreOp>(storeOp, RetVal, destPtr);
-    }
+        rewriter.eraseOp(storeOp);
   }
 
   // NOTE(cir): No need to convert from a temp to an RValue. This is
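For Case #2, here is a small standalone C model of the `findAddr` walk added above (illustrative only: the toy `Value`/`OpKind` types stand in for MLIR values and CIR ops). Starting from the by-value struct argument, it steps through load/cast defining ops and returns the first pointer-typed value, i.e. the dereferenced pointer itself rather than the alloca that merely stores it.

```c
#include <stddef.h>
#include <stdio.h>

/* Toy op kinds standing in for cir::CastOp / cir::LoadOp / anything else. */
typedef enum { OP_NONE, OP_CAST, OP_LOAD, OP_OTHER } OpKind;

/* Toy SSA value: a type category plus the op (and first operand) defining it. */
typedef struct Value {
  int is_pointer;         /* does the value have a pointer type?              */
  OpKind defining_op;     /* OP_NONE when the defining op is irrelevant here  */
  struct Value *operand0; /* first operand of that op, if any                 */
} Value;

/* Mirrors the recursion in findAddr: return the first pointer-typed value,
   stepping only through cast/load ops; otherwise give up. */
static Value *find_addr(Value *v) {
  if (v == NULL)
    return NULL;
  if (v->is_pointer)
    return v;
  if (v->defining_op != OP_CAST && v->defining_op != OP_LOAD)
    return NULL;
  return find_addr(v->operand0);
}

int main(void) {
  /* Models the qux() test case below:
       %s1  = cir.alloca !cir.ptr<!cir.ptr<!ty_PackedS2_>>    ("s1")
       %ptr = cir.load deref %s1 -> !cir.ptr<!ty_PackedS2_>   (the pointer itself)
       %arg = cir.load %ptr      -> !ty_PackedS2_             (struct passed by value) */
  Value s1  = {1, OP_NONE, NULL};
  Value ptr = {1, OP_LOAD, &s1};
  Value arg = {0, OP_LOAD, &ptr};

  /* The memcpy source must be the pointer %ptr, not the alloca %s1 holding it. */
  puts(find_addr(&arg) == &ptr ? "memcpy source: %ptr" : "memcpy source: wrong");
  return 0;
}
```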

clang/test/CIR/CallConvLowering/AArch64/aarch64-cc-structs.c

Lines changed: 87 additions & 7 deletions
@@ -302,19 +302,18 @@ void pass_nested_u(NESTED_U a) {}
 
 // CHECK: cir.func no_proto @call_nested_u()
 // CHECK: %[[#V0:]] = cir.alloca !ty_NESTED_U, !cir.ptr<!ty_NESTED_U>
-// CHECK: %[[#V1:]] = cir.alloca !u64i, !cir.ptr<!u64i>, ["tmp"] {alignment = 8 : i64}
+// CHECK: %[[#V1:]] = cir.alloca !u64i, !cir.ptr<!u64i>, ["tmp"]
 // CHECK: %[[#V2:]] = cir.load %[[#V0]] : !cir.ptr<!ty_NESTED_U>, !ty_NESTED_U
-// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>)
-// CHECK: %[[#V4:]] = cir.load %[[#V3]]
-// CHECK: %[[#V5:]] = cir.cast(bitcast, %[[#V3]]
-// CHECK: %[[#V6:]] = cir.load %[[#V5]]
-// CHECK: %[[#V7:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>), !cir.ptr<!void>
+// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>), !cir.ptr<!ty_anon2E0_>
+// CHECK: %[[#V4:]] = cir.load %[[#V3]] : !cir.ptr<!ty_anon2E0_>, !ty_anon2E0_
+// CHECK: %[[#V5:]] = cir.cast(bitcast, %[[#V3]] : !cir.ptr<!ty_anon2E0_>), !cir.ptr<!ty_anon2E1_>
+// CHECK: %[[#V6:]] = cir.load %[[#V5]] : !cir.ptr<!ty_anon2E1_>, !ty_anon2E1_
+// CHECK: %[[#V7:]] = cir.cast(bitcast, %[[#V5]] : !cir.ptr<!ty_anon2E1_>), !cir.ptr<!void>
 // CHECK: %[[#V8:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!u64i>), !cir.ptr<!void>
 // CHECK: %[[#V9:]] = cir.const #cir.int<2> : !u64i
 // CHECK: cir.libc.memcpy %[[#V9]] bytes from %[[#V7]] to %[[#V8]] : !u64i, !cir.ptr<!void> -> !cir.ptr<!void>
 // CHECK: %[[#V10:]] = cir.load %[[#V1]] : !cir.ptr<!u64i>, !u64i
 // CHECK: cir.call @pass_nested_u(%[[#V10]]) : (!u64i) -> ()
-// CHECK: cir.return
 
 // LLVM: void @call_nested_u()
 // LLVM: %[[#V1:]] = alloca %struct.NESTED_U, i64 1, align 1
@@ -330,3 +329,84 @@ void call_nested_u() {
   NESTED_U a;
   pass_nested_u(a);
 }
+
+
+#pragma pack(push)
+#pragma pack(1)
+typedef struct {
+  int f0 : 18;
+  int f1 : 31;
+  int f2 : 5;
+  int f3 : 29;
+  int f4 : 24;
+} PackedS1;
+#pragma pack(pop)
+
+PackedS1 foo(void) {
+  PackedS1 s;
+  return s;
+}
+
+void bar(void) {
+  PackedS1 y = foo();
+}
+
+// CHECK: cir.func @bar
+// CHECK: %[[#V0:]] = cir.alloca !ty_PackedS1_, !cir.ptr<!ty_PackedS1_>, ["y", init]
+// CHECK: %[[#V1:]] = cir.alloca !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>, ["tmp"]
+// CHECK: %[[#V2:]] = cir.call @foo() : () -> !cir.array<!u64i x 2>
+// CHECK: cir.store %[[#V2]], %[[#V1]] : !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>
+// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!cir.array<!u64i x 2>>), !cir.ptr<!void>
+// CHECK: %[[#V4:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_PackedS1_>), !cir.ptr<!void>
+// CHECK: %[[#V5:]] = cir.const #cir.int<14> : !u64i
+// CHECK: cir.libc.memcpy %[[#V5]] bytes from %[[#V3]] to %[[#V4]] : !u64i, !cir.ptr<!void> -> !cir.ptr<!void>
+
+// LLVM: void @bar
+// LLVM: %[[#V1:]] = alloca %struct.PackedS1, i64 1, align 1
+// LLVM: %[[#V2:]] = alloca [2 x i64], i64 1, align 8
+// LLVM: %[[#V3:]] = call [2 x i64] @foo()
+// LLVM: store [2 x i64] %[[#V3]], ptr %[[#V2]], align 8
+// LLVM: call void @llvm.memcpy.p0.p0.i64(ptr %[[#V1]], ptr %[[#V2]], i64 14, i1 false)
+
+
+#pragma pack(push)
+#pragma pack(1)
+typedef struct {
+  short f0;
+  int f1;
+} PackedS2;
+#pragma pack(pop)
+
+PackedS2 g[3] = {{1,2},{3,4},{5,6}};
+
+void baz(PackedS2 a) {
+  short *x = &g[2].f0;
+  (*x) = a.f0;
+}
+
+void qux(void) {
+  const PackedS2 *s1 = &g[1];
+  baz(*s1);
+}
+
+// check source of memcpy
+// CHECK: cir.func @qux
+// CHECK: %[[#V0:]] = cir.alloca !cir.ptr<!ty_PackedS2_>, !cir.ptr<!cir.ptr<!ty_PackedS2_>>, ["s1", init]
+// CHECK: %[[#V1:]] = cir.alloca !u64i, !cir.ptr<!u64i>, ["tmp"]
+// CHECK: %[[#V2:]] = cir.get_global @g : !cir.ptr<!cir.array<!ty_PackedS2_ x 3>>
+// CHECK: %[[#V3:]] = cir.const #cir.int<1> : !s32i
+// CHECK: %[[#V4:]] = cir.cast(array_to_ptrdecay, %[[#V2]] : !cir.ptr<!cir.array<!ty_PackedS2_ x 3>>), !cir.ptr<!ty_PackedS2_>
+// CHECK: %[[#V5:]] = cir.ptr_stride(%[[#V4]] : !cir.ptr<!ty_PackedS2_>, %[[#V3]] : !s32i), !cir.ptr<!ty_PackedS2_>
+// CHECK: cir.store %[[#V5]], %[[#V0]] : !cir.ptr<!ty_PackedS2_>, !cir.ptr<!cir.ptr<!ty_PackedS2_>>
+// CHECK: %[[#V6:]] = cir.load deref %[[#V0]] : !cir.ptr<!cir.ptr<!ty_PackedS2_>>, !cir.ptr<!ty_PackedS2_>
+// CHECK: %[[#V7:]] = cir.cast(bitcast, %[[#V6]] : !cir.ptr<!ty_PackedS2_>), !cir.ptr<!void>
+// CHECK: %[[#V8:]] = cir.const #cir.int<6> : !u64i
+// CHECK: cir.libc.memcpy %[[#V8]] bytes from %[[#V7]]
+
+// LLVM: void @qux
+// LLVM: %[[#V1:]] = alloca ptr, i64 1, align 8
+// LLVM: %[[#V2:]] = alloca i64, i64 1, align 8
+// LLVM: store ptr getelementptr (%struct.PackedS2, ptr @g, i64 1), ptr %[[#V1]], align 8
+// LLVM: %[[#V3:]] = load ptr, ptr %[[#V1]], align 8
+// LLVM: %[[#V4:]] = load %struct.PackedS2, ptr %[[#V3]], align 1
+// LLVM: call void @llvm.memcpy.p0.p0.i64(ptr %[[#V2]], ptr %[[#V3]], i64 6, i1 false)
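The byte counts in the new CHECK/LLVM lines (14 for `PackedS1`, 6 for `PackedS2`) can be sanity-checked with a tiny standalone program (illustrative only, not part of the test):

```c
#include <stdio.h>

#pragma pack(push, 1)
typedef struct {
  int f0 : 18;
  int f1 : 31;
  int f2 : 5;
  int f3 : 29;
  int f4 : 24;
} PackedS1;

typedef struct {
  short f0;
  int f1;
} PackedS2;
#pragma pack(pop)

int main(void) {
  /* With Clang these print 14 and 6, matching the memcpy sizes above. */
  printf("sizeof(PackedS1) = %zu\n", sizeof(PackedS1));
  printf("sizeof(PackedS2) = %zu\n", sizeof(PackedS2));
  return 0;
}
```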
