@@ -129,17 +129,15 @@ ModulePass *llvm::createESIMDLowerVecArgPass() {
129
129
// Return ptr to first-class vector type if Value is a simd*, else return
130
130
// nullptr.
131
131
Type *ESIMDLowerVecArgPass::getSimdArgPtrTyOrNull (Value *arg) {
132
- if (auto ArgType = dyn_cast<PointerType>(arg->getType ())) {
133
- auto ContainedType = ArgType->getElementType ();
134
- if (ContainedType->isStructTy ()) {
135
- if (ContainedType->getStructNumElements () == 1 &&
136
- ContainedType->getStructElementType (0 )->isVectorTy ()) {
137
- return PointerType::get (ContainedType->getStructElementType (0 ),
138
- ArgType->getPointerAddressSpace ());
139
- }
140
- }
141
- }
142
- return nullptr ;
132
+ auto ArgType = dyn_cast<PointerType>(arg->getType ());
133
+ if (!ArgType || !ArgType->getElementType ()->isStructTy ())
134
+ return nullptr ;
135
+ auto ContainedType = ArgType->getElementType ();
136
+ if ((ContainedType->getStructNumElements () != 1 ) ||
137
+ !ContainedType->getStructElementType (0 )->isVectorTy ())
138
+ return nullptr ;
139
+ return PointerType::get (ContainedType->getStructElementType (0 ),
140
+ ArgType->getPointerAddressSpace ());
143
141
}
144
142
145
143
// F may have multiple arguments of type simd*. This
@@ -197,33 +195,27 @@ Function *ESIMDLowerVecArgPass::rewriteFunc(Function &F) {
197
195
for (auto &use : F.uses ()) {
198
196
// Use must be a call site
199
197
SmallVector<Value *, 10 > Params;
200
- auto User = use.getUser ();
201
- if (isa<CallInst>(User)) {
202
- auto Call = cast<CallInst>(User);
203
- // Variadic functions not supported
204
- assert (!Call->getFunction ()->isVarArg () &&
205
- " Variadic functions not supported" );
206
- for (unsigned int I = 0 ,
207
- NumOpnds = cast<CallInst>(Call)->getNumArgOperands ();
208
- I != NumOpnds; I++) {
209
- auto SrcOpnd = Call->getOperand (I);
210
- auto NewTy = getSimdArgPtrTyOrNull (SrcOpnd);
211
- if (NewTy) {
212
- auto BitCast = new BitCastInst (SrcOpnd, NewTy, " " , Call);
213
- Params.push_back (BitCast);
214
- } else {
215
- if (SrcOpnd != &F)
216
- Params.push_back (SrcOpnd);
217
- else
218
- Params.push_back (NF);
219
- }
198
+ auto Call = cast<CallInst>(use.getUser ());
199
+ // Variadic functions not supported
200
+ assert (!Call->getFunction ()->isVarArg () &&
201
+ " Variadic functions not supported" );
202
+ for (unsigned int I = 0 ; I < Call->getNumArgOperands (); I++) {
203
+ auto SrcOpnd = Call->getOperand (I);
204
+ auto NewTy = getSimdArgPtrTyOrNull (SrcOpnd);
205
+ if (NewTy) {
206
+ auto BitCast = new BitCastInst (SrcOpnd, NewTy, " " , Call);
207
+ Params.push_back (BitCast);
208
+ } else {
209
+ if (SrcOpnd != &F)
210
+ Params.push_back (SrcOpnd);
211
+ else
212
+ Params.push_back (NF);
220
213
}
221
-
222
- // create new call instruction
223
- auto NewCallInst = CallInst::Create (NFTy, NF, Params, " " );
224
- NewCallInst->setCallingConv (F.getCallingConv ());
225
- OldNewInst.push_back (std::make_pair (Call, NewCallInst));
226
214
}
215
+ // create new call instruction
216
+ auto NewCallInst = CallInst::Create (NFTy, NF, Params, " " );
217
+ NewCallInst->setCallingConv (F.getCallingConv ());
218
+ OldNewInst.push_back (std::make_pair (Call, NewCallInst));
227
219
}
228
220
229
221
for (auto InstPair : OldNewInst) {
0 commit comments