Skip to content

Commit c8f189b

Browse files
authored
[CIR][ABI][AArch64] support for return struct types greater than 128 bits (#1027)
This PR adds a support for the return values of struct types > 128 bits in size. As usually, lot's of copy-pasted lines from the original codegen, except the `AllocaOp` replacement for the return value.
1 parent 2813bac commit c8f189b

File tree

10 files changed

+138
-5
lines changed

10 files changed

+138
-5
lines changed

clang/include/clang/CIR/ABIArgInfo.h

+40
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class ABIArgInfo {
103103
bool InReg : 1; // isDirect() || isExtend() || isIndirect()
104104
bool CanBeFlattened : 1; // isDirect()
105105
bool SignExt : 1; // isExtend()
106+
bool IndirectByVal : 1; // isIndirect()
107+
bool IndirectRealign : 1; // isIndirect()
108+
bool SRetAfterThis : 1; // isIndirect()
106109

107110
bool canHavePaddingType() const {
108111
return isDirect() || isExtend() || isIndirect() || isIndirectAliased() ||
@@ -195,6 +198,43 @@ class ABIArgInfo {
195198

196199
static ABIArgInfo getIgnore() { return ABIArgInfo(Ignore); }
197200

201+
static ABIArgInfo getIndirect(unsigned Alignment, bool ByVal = true,
202+
bool Realign = false,
203+
mlir::Type Padding = nullptr) {
204+
auto AI = ABIArgInfo(Indirect);
205+
AI.setIndirectAlign(Alignment);
206+
AI.setIndirectByVal(ByVal);
207+
AI.setIndirectRealign(Realign);
208+
AI.setSRetAfterThis(false);
209+
AI.setPaddingType(Padding);
210+
return AI;
211+
}
212+
213+
void setIndirectAlign(unsigned align) {
214+
assert((isIndirect() || isIndirectAliased()) && "Invalid kind!");
215+
IndirectAttr.Align = align;
216+
}
217+
218+
void setIndirectByVal(bool IBV) {
219+
assert(isIndirect() && "Invalid kind!");
220+
IndirectByVal = IBV;
221+
}
222+
223+
void setIndirectRealign(bool IR) {
224+
assert((isIndirect() || isIndirectAliased()) && "Invalid kind!");
225+
IndirectRealign = IR;
226+
}
227+
228+
void setSRetAfterThis(bool AfterThis) {
229+
assert(isIndirect() && "Invalid kind!");
230+
SRetAfterThis = AfterThis;
231+
}
232+
233+
bool isSRetAfterThis() const {
234+
assert(isIndirect() && "Invalid kind!");
235+
return SRetAfterThis;
236+
}
237+
198238
Kind getKind() const { return TheKind; }
199239
bool isDirect() const { return TheKind == Direct; }
200240
bool isInAlloca() const { return TheKind == InAlloca; }

clang/include/clang/CIR/MissingFeatures.h

+1
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ struct MissingFeatures {
268268
static bool ABIParameterCoercion() { return false; }
269269
static bool ABIPointerParameterAttrs() { return false; }
270270
static bool ABITransparentUnionHandling() { return false; }
271+
static bool ABIPotentialArgAccess() { return false; }
271272

272273
//-- Missing AST queries
273274

clang/lib/CIR/Dialect/Transforms/TargetLowering/ABIInfo.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,12 @@ bool ABIInfo::isPromotableIntegerTypeForABI(Type Ty) const {
4242
return false;
4343
}
4444

45+
::cir::ABIArgInfo ABIInfo::getNaturalAlignIndirect(mlir::Type Ty, bool ByVal,
46+
bool Realign,
47+
mlir::Type Padding) const {
48+
return ::cir::ABIArgInfo::getIndirect(getContext().getTypeAlign(Ty), ByVal,
49+
Realign, Padding);
50+
}
51+
4552
} // namespace cir
4653
} // namespace mlir

clang/lib/CIR/Dialect/Transforms/TargetLowering/ABIInfo.h

+4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class ABIInfo {
5050
// Implement the Type::IsPromotableIntegerType for ABI specific needs. The
5151
// only difference is that this considers bit-precise integer types as well.
5252
bool isPromotableIntegerTypeForABI(Type Ty) const;
53+
54+
::cir::ABIArgInfo getNaturalAlignIndirect(mlir::Type Ty, bool ByVal = true,
55+
bool Realign = false,
56+
mlir::Type Padding = {}) const;
5357
};
5458

5559
} // namespace cir

clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRToCIRArgMapping.h

+17-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace cir {
2929
/// LoweringFunctionInfo should be passed to actual CIR function.
3030
class CIRToCIRArgMapping {
3131
static const unsigned InvalidIndex = ~0U;
32+
unsigned SRetArgNo;
3233
unsigned TotalIRArgs;
3334

3435
/// Arguments of CIR function corresponding to single CIR argument.
@@ -51,7 +52,8 @@ class CIRToCIRArgMapping {
5152
public:
5253
CIRToCIRArgMapping(const CIRLowerContext &context,
5354
const LowerFunctionInfo &FI, bool onlyRequiredArgs = false)
54-
: ArgInfo(onlyRequiredArgs ? FI.getNumRequiredArgs() : FI.arg_size()) {
55+
: SRetArgNo(InvalidIndex),
56+
ArgInfo(onlyRequiredArgs ? FI.getNumRequiredArgs() : FI.arg_size()) {
5557
construct(context, FI, onlyRequiredArgs);
5658
};
5759

@@ -69,7 +71,8 @@ class CIRToCIRArgMapping {
6971
const ::cir::ABIArgInfo &RetAI = FI.getReturnInfo();
7072

7173
if (RetAI.getKind() == ::cir::ABIArgInfo::Indirect) {
72-
cir_cconv_unreachable("NYI");
74+
SwapThisWithSRet = RetAI.isSRetAfterThis();
75+
SRetArgNo = SwapThisWithSRet ? 1 : IRArgNo++;
7376
}
7477

7578
unsigned ArgNo = 0;
@@ -100,6 +103,11 @@ class CIRToCIRArgMapping {
100103
}
101104
break;
102105
}
106+
case ::cir::ABIArgInfo::Indirect:
107+
case ::cir::ABIArgInfo::IndirectAliased:
108+
IRArgs.NumberOfArgs = 1;
109+
break;
110+
103111
default:
104112
cir_cconv_unreachable("Missing ABIArgInfo::Kind");
105113
}
@@ -130,6 +138,13 @@ class CIRToCIRArgMapping {
130138
return std::make_pair(ArgInfo[ArgNo].FirstArgIndex,
131139
ArgInfo[ArgNo].NumberOfArgs);
132140
}
141+
142+
bool hasSRetArg() const { return SRetArgNo != InvalidIndex; }
143+
144+
unsigned getSRetArgNo() const {
145+
assert(hasSRetArg());
146+
return SRetArgNo;
147+
}
133148
};
134149

135150
} // namespace cir

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerCall.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ void LowerModule::constructAttributeList(StringRef Name,
157157
cir_cconv_assert(!::cir::MissingFeatures::noFPClass());
158158
break;
159159
case ABIArgInfo::Ignore:
160+
case ABIArgInfo::Indirect:
161+
cir_cconv_assert(!::cir::MissingFeatures::ABIPotentialArgAccess());
160162
break;
161163
default:
162164
cir_cconv_unreachable("Missing ABIArgInfo::Kind");

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

+47-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
// are adapted to operate on the CIR dialect, however.
1111
//
1212
//===----------------------------------------------------------------------===//
13-
1413
#include "LowerFunction.h"
1514
#include "CIRToCIRArgMapping.h"
1615
#include "LowerCall.h"
@@ -433,6 +432,23 @@ LowerFunction::buildFunctionProlog(const LowerFunctionInfo &FI, FuncOp Fn,
433432
return success();
434433
}
435434

435+
mlir::cir::AllocaOp findAlloca(Operation *op) {
436+
if (!op)
437+
return {};
438+
439+
if (auto al = dyn_cast<mlir::cir::AllocaOp>(op)) {
440+
return al;
441+
} else if (auto ret = dyn_cast<mlir::cir::ReturnOp>(op)) {
442+
auto vals = ret.getInput();
443+
if (vals.size() == 1)
444+
return findAlloca(vals[0].getDefiningOp());
445+
} else if (auto load = dyn_cast<mlir::cir::LoadOp>(op)) {
446+
return findAlloca(load.getAddr().getDefiningOp());
447+
}
448+
449+
return {};
450+
}
451+
436452
LogicalResult LowerFunction::buildFunctionEpilog(const LowerFunctionInfo &FI) {
437453
// NOTE(cir): no-return, naked, and no result functions should be handled in
438454
// CIRGen.
@@ -446,6 +462,27 @@ LogicalResult LowerFunction::buildFunctionEpilog(const LowerFunctionInfo &FI) {
446462
case ABIArgInfo::Ignore:
447463
break;
448464

465+
case ABIArgInfo::Indirect: {
466+
Value RVAddr = {};
467+
CIRToCIRArgMapping IRFunctionArgs(LM.getContext(), FI, true);
468+
if (IRFunctionArgs.hasSRetArg()) {
469+
auto &entry = NewFn.getBody().front();
470+
RVAddr = entry.getArgument(IRFunctionArgs.getSRetArgNo());
471+
}
472+
473+
if (RVAddr) {
474+
mlir::PatternRewriter::InsertionGuard guard(rewriter);
475+
NewFn->walk([&](ReturnOp ret) {
476+
if (auto al = findAlloca(ret)) {
477+
rewriter.replaceAllUsesWith(al.getResult(), RVAddr);
478+
rewriter.eraseOp(al);
479+
rewriter.replaceOpWithNewOp<ReturnOp>(ret);
480+
}
481+
});
482+
}
483+
break;
484+
}
485+
449486
case ABIArgInfo::Extend:
450487
case ABIArgInfo::Direct:
451488
// FIXME(cir): Should we call ConvertType(RetTy) here?
@@ -517,6 +554,15 @@ LogicalResult LowerFunction::generateCode(FuncOp oldFn, FuncOp newFn,
517554
Block *srcBlock = &oldFn.getBody().front();
518555
Block *dstBlock = &newFn.getBody().front();
519556

557+
// Ensure both blocks have the same number of arguments in order to
558+
// safely merge them.
559+
CIRToCIRArgMapping IRFunctionArgs(LM.getContext(), FnInfo, true);
560+
if (IRFunctionArgs.hasSRetArg()) {
561+
auto dstIndex = IRFunctionArgs.getSRetArgNo();
562+
auto retArg = dstBlock->getArguments()[dstIndex];
563+
srcBlock->insertArgument(dstIndex, retArg.getType(), retArg.getLoc());
564+
}
565+
520566
// Migrate function body to new ABI-aware function.
521567
rewriter.inlineRegionBefore(oldFn.getBody(), newFn.getBody(),
522568
newFn.getBody().end());

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerTypes.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ FuncType LowerTypes::getFunctionType(const LowerFunctionInfo &FI) {
5050
resultType = retAI.getCoerceToType();
5151
break;
5252
case ::cir::ABIArgInfo::Ignore:
53+
case ::cir::ABIArgInfo::Indirect:
5354
resultType = VoidType::get(getMLIRContext());
5455
break;
5556
default:
@@ -60,7 +61,11 @@ FuncType LowerTypes::getFunctionType(const LowerFunctionInfo &FI) {
6061
SmallVector<Type, 8> ArgTypes(IRFunctionArgs.totalIRArgs());
6162

6263
// Add type for sret argument.
63-
cir_cconv_assert(!::cir::MissingFeatures::sretArgs());
64+
if (IRFunctionArgs.hasSRetArg()) {
65+
mlir::Type ret = FI.getReturnType();
66+
ArgTypes[IRFunctionArgs.getSRetArgNo()] =
67+
mlir::cir::PointerType::get(getMLIRContext(), ret);
68+
}
6469

6570
// Add type for inalloca argument.
6671
cir_cconv_assert(!::cir::MissingFeatures::inallocaArgs());

clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/AArch64.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(Type RetTy,
137137
cir_cconv_unreachable("NYI");
138138
}
139139

140-
cir_cconv_unreachable("NYI");
140+
return getNaturalAlignIndirect(RetTy);
141141
}
142142

143143
ABIArgInfo

clang/test/CIR/CallConvLowering/AArch64/aarch64-cc-structs.c

+13
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ typedef struct {
2121
int64_t b;
2222
} EQ_128;
2323

24+
typedef struct {
25+
int64_t a;
26+
int64_t b;
27+
int64_t c;
28+
} GT_128;
29+
2430
// CHECK: cir.func {{.*@ret_lt_64}}() -> !u16i
2531
// CHECK: %[[#V0:]] = cir.alloca !ty_LT_64_, !cir.ptr<!ty_LT_64_>, ["__retval"]
2632
// CHECK: %[[#V1:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_LT_64_>), !cir.ptr<!u16i>
@@ -60,3 +66,10 @@ EQ_128 ret_eq_128() {
6066
EQ_128 x;
6167
return x;
6268
}
69+
70+
// CHECK: cir.func {{.*@ret_gt_128}}(%arg0: !cir.ptr<!ty_GT_128_>
71+
// CHECK-NOT: cir.return {{%.*}}
72+
GT_128 ret_gt_128() {
73+
GT_128 x;
74+
return x;
75+
}

0 commit comments

Comments
 (0)