10
10
// are adapted to operate on the CIR dialect, however.
11
11
//
12
12
// ===----------------------------------------------------------------------===//
13
-
14
13
#include " LowerFunction.h"
15
14
#include " CIRToCIRArgMapping.h"
16
15
#include " LowerCall.h"
@@ -433,6 +432,23 @@ LowerFunction::buildFunctionProlog(const LowerFunctionInfo &FI, FuncOp Fn,
433
432
return success ();
434
433
}
435
434
435
+ mlir::cir::AllocaOp findAlloca (Operation *op) {
436
+ if (!op)
437
+ return {};
438
+
439
+ if (auto al = dyn_cast<mlir::cir::AllocaOp>(op)) {
440
+ return al;
441
+ } else if (auto ret = dyn_cast<mlir::cir::ReturnOp>(op)) {
442
+ auto vals = ret.getInput ();
443
+ if (vals.size () == 1 )
444
+ return findAlloca (vals[0 ].getDefiningOp ());
445
+ } else if (auto load = dyn_cast<mlir::cir::LoadOp>(op)) {
446
+ return findAlloca (load.getAddr ().getDefiningOp ());
447
+ }
448
+
449
+ return {};
450
+ }
451
+
436
452
LogicalResult LowerFunction::buildFunctionEpilog (const LowerFunctionInfo &FI) {
437
453
// NOTE(cir): no-return, naked, and no result functions should be handled in
438
454
// CIRGen.
@@ -446,6 +462,27 @@ LogicalResult LowerFunction::buildFunctionEpilog(const LowerFunctionInfo &FI) {
446
462
case ABIArgInfo::Ignore:
447
463
break ;
448
464
465
+ case ABIArgInfo::Indirect: {
466
+ Value RVAddr = {};
467
+ CIRToCIRArgMapping IRFunctionArgs (LM.getContext (), FI, true );
468
+ if (IRFunctionArgs.hasSRetArg ()) {
469
+ auto &entry = NewFn.getBody ().front ();
470
+ RVAddr = entry.getArgument (IRFunctionArgs.getSRetArgNo ());
471
+ }
472
+
473
+ if (RVAddr) {
474
+ mlir::PatternRewriter::InsertionGuard guard (rewriter);
475
+ NewFn->walk ([&](ReturnOp ret) {
476
+ if (auto al = findAlloca (ret)) {
477
+ rewriter.replaceAllUsesWith (al.getResult (), RVAddr);
478
+ rewriter.eraseOp (al);
479
+ rewriter.replaceOpWithNewOp <ReturnOp>(ret);
480
+ }
481
+ });
482
+ }
483
+ break ;
484
+ }
485
+
449
486
case ABIArgInfo::Extend:
450
487
case ABIArgInfo::Direct:
451
488
// FIXME(cir): Should we call ConvertType(RetTy) here?
@@ -517,6 +554,15 @@ LogicalResult LowerFunction::generateCode(FuncOp oldFn, FuncOp newFn,
517
554
Block *srcBlock = &oldFn.getBody ().front ();
518
555
Block *dstBlock = &newFn.getBody ().front ();
519
556
557
+ // Ensure both blocks have the same number of arguments in order to
558
+ // safely merge them.
559
+ CIRToCIRArgMapping IRFunctionArgs (LM.getContext (), FnInfo, true );
560
+ if (IRFunctionArgs.hasSRetArg ()) {
561
+ auto dstIndex = IRFunctionArgs.getSRetArgNo ();
562
+ auto retArg = dstBlock->getArguments ()[dstIndex];
563
+ srcBlock->insertArgument (dstIndex, retArg.getType (), retArg.getLoc ());
564
+ }
565
+
520
566
// Migrate function body to new ABI-aware function.
521
567
rewriter.inlineRegionBefore (oldFn.getBody (), newFn.getBody (),
522
568
newFn.getBody ().end ());
0 commit comments