@@ -7275,11 +7275,26 @@ static EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT,
7275
7275
// / Return the EVT of the data associated to a memory operation in \p
7276
7276
// / Root. If such EVT cannot be retrived, it returns an invalid EVT.
7277
7277
static EVT getMemVTFromNode (LLVMContext &Ctx, SDNode *Root) {
7278
- if (isa<MemSDNode>(Root))
7279
- return cast<MemSDNode>(Root)->getMemoryVT ();
7278
+ if (auto *MemIntr = dyn_cast<MemIntrinsicSDNode>(Root))
7279
+ return MemIntr->getMemoryVT ();
7280
+
7281
+ if (isa<MemSDNode>(Root)) {
7282
+ EVT MemVT = cast<MemSDNode>(Root)->getMemoryVT ();
7283
+
7284
+ EVT DataVT;
7285
+ if (auto *Load = dyn_cast<LoadSDNode>(Root))
7286
+ DataVT = Load->getValueType (0 );
7287
+ else if (auto *Load = dyn_cast<MaskedLoadSDNode>(Root))
7288
+ DataVT = Load->getValueType (0 );
7289
+ else if (auto *Store = dyn_cast<StoreSDNode>(Root))
7290
+ DataVT = Store->getValue ().getValueType ();
7291
+ else if (auto *Store = dyn_cast<MaskedStoreSDNode>(Root))
7292
+ DataVT = Store->getValue ().getValueType ();
7293
+ else
7294
+ llvm_unreachable (" Unexpected MemSDNode!" );
7280
7295
7281
- if (isa<MemIntrinsicSDNode>(Root))
7282
- return cast<MemIntrinsicSDNode>(Root)-> getMemoryVT ();
7296
+ return DataVT. changeVectorElementType (MemVT. getVectorElementType ());
7297
+ }
7283
7298
7284
7299
const unsigned Opcode = Root->getOpcode ();
7285
7300
// For custom ISD nodes, we have to look at them individually to extract the
@@ -7380,12 +7395,23 @@ bool AArch64DAGToDAGISel::SelectAddrModeIndexedSVE(SDNode *Root, SDValue N,
7380
7395
return false ;
7381
7396
7382
7397
SDValue VScale = N.getOperand (1 );
7383
- if (VScale.getOpcode () != ISD::VSCALE)
7398
+ int64_t MulImm = std::numeric_limits<int64_t >::max ();
7399
+ if (VScale.getOpcode () == ISD::VSCALE) {
7400
+ MulImm = cast<ConstantSDNode>(VScale.getOperand (0 ))->getSExtValue ();
7401
+ } else if (auto C = dyn_cast<ConstantSDNode>(VScale)) {
7402
+ int64_t ByteOffset = C->getSExtValue ();
7403
+ const auto KnownVScale =
7404
+ Subtarget->getSVEVectorSizeInBits () / AArch64::SVEBitsPerBlock;
7405
+
7406
+ if (!KnownVScale || ByteOffset % KnownVScale != 0 )
7407
+ return false ;
7408
+
7409
+ MulImm = ByteOffset / KnownVScale;
7410
+ } else
7384
7411
return false ;
7385
7412
7386
7413
TypeSize TS = MemVT.getSizeInBits ();
7387
7414
int64_t MemWidthBytes = static_cast <int64_t >(TS.getKnownMinValue ()) / 8 ;
7388
- int64_t MulImm = cast<ConstantSDNode>(VScale.getOperand (0 ))->getSExtValue ();
7389
7415
7390
7416
if ((MulImm % MemWidthBytes) != 0 )
7391
7417
return false ;
0 commit comments