@@ -85,6 +85,20 @@ class SPIRVRegularizeLLVM : public ModulePass {
   /// @spirv.llvm_memset_* and replace it with @llvm.memset.
   void lowerMemset(MemSetInst *MSI);
 
+  /// No SPIR-V counterpart for @llvm.fshl.i* intrinsic. It will be lowered
+  /// to a newly generated @spirv.llvm_fshl_i* function.
+  /// Conceptually, FSHL:
+  /// 1. concatenates the ints, the first one being the more significant;
+  /// 2. performs a left shift-rotate on the resulting doubled-sized int;
+  /// 3. returns the most significant bits of the shift-rotate result,
+  ///    the number of bits being equal to the size of the original integers.
+  /// The actual implementation algorithm will be slightly different to speed
+  /// things up.
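+  /// For example, with i8 arguments, fshl(0b10000001, 0b10000000, 1) shifts
+  /// the 16-bit value 0b1000000110000000 left by one and keeps the 8 most
+  /// significant bits, yielding 0b00000011.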
+  void lowerFunnelShiftLeft(IntrinsicInst *FSHLIntrinsic);
+  void buildFunnelShiftLeftFunc(Function *FSHLFunc);
+
+  static std::string lowerLLVMIntrinsicName(IntrinsicInst *II);
+
   static char ID;
 
 private:
@@ -94,17 +108,22 @@ class SPIRVRegularizeLLVM : public ModulePass {
 
 char SPIRVRegularizeLLVM::ID = 0;
 
-void SPIRVRegularizeLLVM::lowerMemset(MemSetInst *MSI) {
-  if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
-    return; // To be handled in LLVMToSPIRV::transIntrinsicInst
-  Function *IntrinsicFunc = MSI->getCalledFunction();
+std::string SPIRVRegularizeLLVM::lowerLLVMIntrinsicName(IntrinsicInst *II) {
+  Function *IntrinsicFunc = II->getCalledFunction();
   assert(IntrinsicFunc && "Missing function");
   std::string FuncName = IntrinsicFunc->getName().str();
   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
   FuncName = "spirv." + FuncName;
+  return FuncName;
+}
+
+void SPIRVRegularizeLLVM::lowerMemset(MemSetInst *MSI) {
+  if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
+    return; // To be handled in LLVMToSPIRV::transIntrinsicInst
+
+  std::string FuncName = lowerLLVMIntrinsicName(MSI);
   if (MSI->isVolatile())
     FuncName += ".volatile";
-
   // Redirect @llvm.memset.* call to @spirv.llvm_memset_*
   Function *F = M->getFunction(FuncName);
   if (F) {
@@ -137,6 +156,75 @@ void SPIRVRegularizeLLVM::lowerMemset(MemSetInst *MSI) {
   return;
 }
 
+void SPIRVRegularizeLLVM::buildFunnelShiftLeftFunc(Function *FSHLFunc) {
+  if (!FSHLFunc->empty())
+    return;
+
+  auto *IntTy = dyn_cast<IntegerType>(FSHLFunc->getReturnType());
+  assert(IntTy && "llvm.fshl: expected an integer return type");
+  assert(FSHLFunc->arg_size() == 3 && "llvm.fshl: expected 3 arguments");
+  for (Argument &Arg : FSHLFunc->args())
+    assert(Arg.getType()->getTypeID() == IntTy->getTypeID() &&
+           "llvm.fshl: mismatched return type and argument types");
+
+  // Our function will require 3 basic blocks; the purpose of each will be
+  // clarified below.
+  auto *CondBB = BasicBlock::Create(M->getContext(), "cond", FSHLFunc);
+  auto *RotateBB =
+      BasicBlock::Create(M->getContext(), "rotate", FSHLFunc); // Main logic
+  auto *PhiBB = BasicBlock::Create(M->getContext(), "phi", FSHLFunc);
+
+  IRBuilder<> Builder(CondBB);
+  // If the number of bits to rotate by is divisible by the bitsize,
+  // the shift becomes useless, and we should bypass the main logic in that
+  // case.
+  unsigned BitWidth = IntTy->getIntegerBitWidth();
+  ConstantInt *BitWidthConstant = Builder.getInt({BitWidth, BitWidth});
+  auto *RotateModVal =
+      Builder.CreateURem(/*Rotate*/ FSHLFunc->getArg(2), BitWidthConstant);
+  ConstantInt *ZeroConstant = Builder.getInt({BitWidth, 0});
+  auto *CheckRotateModIfZero = Builder.CreateICmpEQ(RotateModVal, ZeroConstant);
+  Builder.CreateCondBr(CheckRotateModIfZero, /*True*/ PhiBB,
+                       /*False*/ RotateBB);
+
+  // Build the actual funnel shift rotate logic.
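+  // For a non-zero rotate amount R = Rotate % BitWidth, the result equals
+  // (Arg0 << R) | (Arg1 >> (BitWidth - R)); the steps below build exactly that.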
+  Builder.SetInsertPoint(RotateBB);
+  // Shift the more significant number left; the "rotate" number of bits
+  // will be 0-filled on the right as a result of this regular shift.
+  auto *ShiftLeft = Builder.CreateShl(FSHLFunc->getArg(0), RotateModVal);
+  // We want the "rotate" number of the second int's MSBs to occupy the
+  // rightmost "0 space" left by the previous operation. Therefore,
+  // subtract the "rotate" number from the integer bitsize...
+  auto *SubRotateVal = Builder.CreateSub(BitWidthConstant, RotateModVal);
+  // ...and right-shift the second int by this number, zero-filling the MSBs.
+  auto *ShiftRight = Builder.CreateLShr(FSHLFunc->getArg(1), SubRotateVal);
+  // A simple bitwise OR of the shifted ints yields the final result.
+  auto *FunnelShiftRes = Builder.CreateOr(ShiftLeft, ShiftRight);
+  Builder.CreateBr(PhiBB);
+
+  // PHI basic block. If no actual rotate was required, return the first, more
+  // significant int. E.g. for 32-bit integers, it's equivalent to concatenating
+  // the 2 ints and taking 32 MSBs.
+  Builder.SetInsertPoint(PhiBB);
+  PHINode *Phi = Builder.CreatePHI(IntTy, 0);
+  Phi->addIncoming(FunnelShiftRes, RotateBB);
+  Phi->addIncoming(FSHLFunc->getArg(0), CondBB);
+  Builder.CreateRet(Phi);
+}
+
+void SPIRVRegularizeLLVM::lowerFunnelShiftLeft(IntrinsicInst *FSHLIntrinsic) {
+  // Get a separate function - otherwise, we'd have to rework the CFG of the
+  // current one. Then simply replace the intrinsic uses with a call to the new
+  // function.
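+  // E.g. a call to @llvm.fshl.i32 is redirected to @spirv.llvm_fshl_i32,
+  // which keeps the intrinsic's original (iN, iN, iN) -> iN signature.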
+  FunctionType *FSHLFuncTy = FSHLIntrinsic->getFunctionType();
+  Type *FSHLRetTy = FSHLFuncTy->getReturnType();
+  const std::string FuncName = lowerLLVMIntrinsicName(FSHLIntrinsic);
+  Function *FSHLFunc =
+      getOrCreateFunction(M, FSHLRetTy, FSHLFuncTy->params(), FuncName);
+  buildFunnelShiftLeftFunc(FSHLFunc);
+  FSHLIntrinsic->setCalledFunction(FSHLFunc);
+}
+
 bool SPIRVRegularizeLLVM::runOnModule(Module &Module) {
   M = &Module;
   Ctx = &M->getContext();
@@ -170,8 +258,11 @@ bool SPIRVRegularizeLLVM::regularize() {
           Function *CF = Call->getCalledFunction();
           if (CF && CF->isIntrinsic()) {
             removeFnAttr(Call, Attribute::NoUnwind);
-            if (auto *MSI = dyn_cast<MemSetInst>(Call))
+            auto *II = cast<IntrinsicInst>(Call);
+            if (auto *MSI = dyn_cast<MemSetInst>(II))
               lowerMemset(MSI);
+            else if (II->getIntrinsicID() == Intrinsic::fshl)
+              lowerFunnelShiftLeft(II);
           }
         }
 
@@ -254,7 +345,7 @@ bool SPIRVRegularizeLLVM::regularize() {
       }
     }
     for (Instruction *V : ToErase) {
-      assert(V->user_empty());
+      assert(V->user_empty() && "User non-empty\n");
       V->eraseFromParent();
     }
   }