@@ -796,16 +796,17 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
796
796
CapturedStmtInfo &&
797
797
"CapturedStmtInfo should be set when generating the captured function");
798
798
const CapturedDecl *CD = S.getCapturedDecl();
799
+
799
800
// Build the argument list.
800
- // AMDGCN does not generate wrapper kernels properly, fails to launch kernel.
801
- bool NeedWrapperFunction = !CGM.getTriple().isAMDGCN() &&
802
- (getDebugInfo() && CGM.getCodeGenOpts().hasReducedDebugInfo());
803
- FunctionArgList Args;
804
- llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs;
805
- llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes;
801
+ FunctionArgList Args, WrapperArgs;
802
+ llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs,
803
+ WrapperLocalAddrs;
804
+ llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes,
805
+ WrapperVLASizes;
806
806
SmallString<256> Buffer;
807
807
llvm::raw_svector_ostream Out(Buffer);
808
808
Out << CapturedStmtInfo->getHelperName();
809
+
809
810
bool isKernel = (Out.str().find("__omp_offloading_") != std::string::npos);
810
811
811
812
// For host codegen, we need to determine now whether Xteam reduction is used
@@ -834,22 +835,40 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
834
835
}
835
836
}
836
837
837
- if (NeedWrapperFunction)
838
+ // AMDGCN does not generate wrapper kernels properly, fails to launch kernel.
839
+ // Xteam reduction does not use wrapper kernels.
840
+ bool NeedWrapperFunction =
841
+ !CGM.getTriple().isAMDGCN() && !isXteamKernel &&
842
+ (getDebugInfo() && CGM.getCodeGenOpts().hasReducedDebugInfo());
843
+
844
+ CodeGenFunction WrapperCGF(CGM, /*suppressNewContext=*/true);
845
+ llvm::Function *WrapperF = nullptr;
846
+ if (NeedWrapperFunction) {
847
+ // Emit the final kernel early to allow attributes to be added by the
848
+ // OpenMPI-IR-Builder.
849
+ FunctionOptions WrapperFO(&S, /*UIntPtrCastRequired=*/true,
850
+ /*RegisterCastedArgsOnly=*/true,
851
+ CapturedStmtInfo->getHelperName(), Loc);
852
+ WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
853
+ WrapperF = emitOutlinedFunctionPrologue(WrapperCGF, D, Args, LocalAddrs,
854
+ VLASizes, WrapperCGF.CXXThisValue,
855
+ WrapperFO, isKernel, isXteamKernel);
838
856
Out << "_debug__";
857
+ }
839
858
FunctionOptions FO(&S, !NeedWrapperFunction, /*RegisterCastedArgsOnly=*/false,
840
859
Out.str(), Loc);
841
- llvm::Function *F =
842
- emitOutlinedFunctionPrologue( *this, D, Args, LocalAddrs, VLASizes ,
843
- CXXThisValue, FO, isKernel, isXteamKernel);
860
+ llvm::Function *F = emitOutlinedFunctionPrologue(
861
+ *this, D, WrapperArgs, WrapperLocalAddrs, WrapperVLASizes, CXXThisValue ,
862
+ FO, isKernel, isXteamKernel);
844
863
CodeGenFunction::OMPPrivateScope LocalScope(*this);
845
- for (const auto &LocalAddrPair : LocalAddrs ) {
864
+ for (const auto &LocalAddrPair : WrapperLocalAddrs ) {
846
865
if (LocalAddrPair.second.first) {
847
866
LocalScope.addPrivate(LocalAddrPair.second.first,
848
867
LocalAddrPair.second.second);
849
868
}
850
869
}
851
870
(void)LocalScope.Privatize();
852
- for (const auto &VLASizePair : VLASizes )
871
+ for (const auto &VLASizePair : WrapperVLASizes )
853
872
VLASizeMap[VLASizePair.second.first] = VLASizePair.second.second;
854
873
PGO.assignRegionCounters(GlobalDecl(CD), F);
855
874
@@ -861,16 +880,16 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
861
880
EmitOptKernel(
862
881
D, FStmt,
863
882
llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP, Loc,
864
- /*Args =*/nullptr);
883
+ /*WrapperArgs =*/nullptr);
865
884
else
866
885
EmitOptKernel(
867
886
D, FStmt,
868
887
llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD_BIG_JUMP_LOOP,
869
- Loc, /*Args =*/nullptr);
888
+ Loc, /*WrapperArgs =*/nullptr);
870
889
} else if (CGM.getLangOpts().OpenMPIsTargetDevice && isXteamKernel) {
871
890
EmitOptKernel(D, FStmt,
872
891
llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_XTEAM_RED,
873
- Loc, &Args );
892
+ Loc, &WrapperArgs );
874
893
} else {
875
894
CapturedStmtInfo->EmitBody(*this, CD->getBody());
876
895
}
@@ -880,22 +899,9 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
880
899
if (!NeedWrapperFunction)
881
900
return F;
882
901
883
- FunctionOptions WrapperFO(&S, /*UIntPtrCastRequired=*/true,
884
- /*RegisterCastedArgsOnly=*/true,
885
- CapturedStmtInfo->getHelperName(), Loc);
886
- CodeGenFunction WrapperCGF(CGM, /*suppressNewContext=*/true);
887
- WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
888
- Args.clear();
889
- LocalAddrs.clear();
890
- VLASizes.clear();
891
- SmallString<256> Buffer2;
892
- llvm::raw_svector_ostream Out2(Buffer2);
893
- Out2 << CapturedStmtInfo->getHelperName();
894
- isKernel = (Out2.str().find("__omp_offloading_") != std::string::npos);
895
-
896
- llvm::Function *WrapperF = emitOutlinedFunctionPrologue(
897
- WrapperCGF, D, Args, LocalAddrs, VLASizes, WrapperCGF.CXXThisValue,
898
- WrapperFO, isKernel, isXteamKernel);
902
+ // Reverse the order.
903
+ WrapperF->removeFromParent();
904
+ F->getParent()->getFunctionList().insertAfter(F->getIterator(), WrapperF);
899
905
900
906
llvm::SmallVector<llvm::Value *, 4> CallArgs;
901
907
auto *PI = F->arg_begin();
0 commit comments