@@ -1996,7 +1996,8 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1996
1996
UndefValue::get (Int8Ty), F->getName () + " .ID" );
1997
1997
1998
1998
for (Use *U : ToBeReplacedStateMachineUses)
1999
- U->set (ConstantExpr::getBitCast (ID, U->get ()->getType ()));
1999
+ U->set (ConstantExpr::getPointerBitCastOrAddrSpaceCast (
2000
+ ID, U->get ()->getType ()));
2000
2001
2001
2002
++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2002
2003
@@ -3183,10 +3184,14 @@ struct AAKernelInfoFunction : AAKernelInfo {
3183
3184
IsWorker->setDebugLoc (DLoc);
3184
3185
BranchInst::Create (StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3185
3186
3187
+ Module &M = *Kernel->getParent ();
3188
+
3186
3189
// Create local storage for the work function pointer.
3190
+ const DataLayout &DL = M.getDataLayout ();
3187
3191
Type *VoidPtrTy = Type::getInt8PtrTy (Ctx);
3188
- AllocaInst *WorkFnAI = new AllocaInst (VoidPtrTy, 0 , " worker.work_fn.addr" ,
3189
- &Kernel->getEntryBlock ().front ());
3192
+ Instruction *WorkFnAI =
3193
+ new AllocaInst (VoidPtrTy, DL.getAllocaAddrSpace (), nullptr ,
3194
+ " worker.work_fn.addr" , &Kernel->getEntryBlock ().front ());
3190
3195
WorkFnAI->setDebugLoc (DLoc);
3191
3196
3192
3197
auto &OMPInfoCache = static_cast <OMPInformationCache &>(A.getInfoCache ());
@@ -3199,13 +3204,23 @@ struct AAKernelInfoFunction : AAKernelInfo {
3199
3204
Value *Ident = KernelInitCB->getArgOperand (0 );
3200
3205
Value *GTid = KernelInitCB;
3201
3206
3202
- Module &M = *Kernel->getParent ();
3203
3207
FunctionCallee BarrierFn =
3204
3208
OMPInfoCache.OMPBuilder .getOrCreateRuntimeFunction (
3205
3209
M, OMPRTL___kmpc_barrier_simple_spmd);
3206
3210
CallInst::Create (BarrierFn, {Ident, GTid}, " " , StateMachineBeginBB)
3207
3211
->setDebugLoc (DLoc);
3208
3212
3213
+ if (WorkFnAI->getType ()->getPointerAddressSpace () !=
3214
+ (unsigned int )AddressSpace::Generic) {
3215
+ WorkFnAI = new AddrSpaceCastInst (
3216
+ WorkFnAI,
3217
+ PointerType::getWithSamePointeeType (
3218
+ cast<PointerType>(WorkFnAI->getType ()),
3219
+ (unsigned int )AddressSpace::Generic),
3220
+ WorkFnAI->getName () + " .generic" , StateMachineBeginBB);
3221
+ WorkFnAI->setDebugLoc (DLoc);
3222
+ }
3223
+
3209
3224
FunctionCallee KernelParallelFn =
3210
3225
OMPInfoCache.OMPBuilder .getOrCreateRuntimeFunction (
3211
3226
M, OMPRTL___kmpc_kernel_parallel);
0 commit comments