@@ -239,6 +239,7 @@ static bool isRegisterVectorType(LLT Ty) {
239
239
EltSize == 128 || EltSize == 256 ;
240
240
}
241
241
242
+ // TODO: replace all uses of isRegisterType with isRegisterClassType
242
243
static bool isRegisterType (LLT Ty) {
243
244
if (!isRegisterSize (Ty.getSizeInBits ()))
244
245
return false ;
@@ -276,6 +277,105 @@ static LegalityPredicate elementTypeIsLegal(unsigned TypeIdx) {
276
277
};
277
278
}
278
279
280
+ static const LLT S1 = LLT::scalar(1 );
281
+ static const LLT S8 = LLT::scalar(8 );
282
+ static const LLT S16 = LLT::scalar(16 );
283
+ static const LLT S32 = LLT::scalar(32 );
284
+ static const LLT S64 = LLT::scalar(64 );
285
+ static const LLT S96 = LLT::scalar(96 );
286
+ static const LLT S128 = LLT::scalar(128 );
287
+ static const LLT S160 = LLT::scalar(160 );
288
+ static const LLT S224 = LLT::scalar(224 );
289
+ static const LLT S256 = LLT::scalar(256 );
290
+ static const LLT S512 = LLT::scalar(512 );
291
+ static const LLT MaxScalar = LLT::scalar(MaxRegisterSize);
292
+
293
+ static const LLT V2S8 = LLT::fixed_vector(2 , 8 );
294
+ static const LLT V2S16 = LLT::fixed_vector(2 , 16 );
295
+ static const LLT V4S16 = LLT::fixed_vector(4 , 16 );
296
+ static const LLT V6S16 = LLT::fixed_vector(6 , 16 );
297
+ static const LLT V8S16 = LLT::fixed_vector(8 , 16 );
298
+ static const LLT V10S16 = LLT::fixed_vector(10 , 16 );
299
+ static const LLT V12S16 = LLT::fixed_vector(12 , 16 );
300
+ static const LLT V16S16 = LLT::fixed_vector(16 , 16 );
301
+
302
+ static const LLT V2S32 = LLT::fixed_vector(2 , 32 );
303
+ static const LLT V3S32 = LLT::fixed_vector(3 , 32 );
304
+ static const LLT V4S32 = LLT::fixed_vector(4 , 32 );
305
+ static const LLT V5S32 = LLT::fixed_vector(5 , 32 );
306
+ static const LLT V6S32 = LLT::fixed_vector(6 , 32 );
307
+ static const LLT V7S32 = LLT::fixed_vector(7 , 32 );
308
+ static const LLT V8S32 = LLT::fixed_vector(8 , 32 );
309
+ static const LLT V9S32 = LLT::fixed_vector(9 , 32 );
310
+ static const LLT V10S32 = LLT::fixed_vector(10 , 32 );
311
+ static const LLT V11S32 = LLT::fixed_vector(11 , 32 );
312
+ static const LLT V12S32 = LLT::fixed_vector(12 , 32 );
313
+ static const LLT V16S32 = LLT::fixed_vector(16 , 32 );
314
+ static const LLT V32S32 = LLT::fixed_vector(32 , 32 );
315
+
316
+ static const LLT V2S64 = LLT::fixed_vector(2 , 64 );
317
+ static const LLT V3S64 = LLT::fixed_vector(3 , 64 );
318
+ static const LLT V4S64 = LLT::fixed_vector(4 , 64 );
319
+ static const LLT V5S64 = LLT::fixed_vector(5 , 64 );
320
+ static const LLT V6S64 = LLT::fixed_vector(6 , 64 );
321
+ static const LLT V7S64 = LLT::fixed_vector(7 , 64 );
322
+ static const LLT V8S64 = LLT::fixed_vector(8 , 64 );
323
+ static const LLT V16S64 = LLT::fixed_vector(16 , 64 );
324
+
325
+ static const LLT V2S128 = LLT::fixed_vector(2 , 128 );
326
+ static const LLT V4S128 = LLT::fixed_vector(4 , 128 );
327
+
328
+ static std::initializer_list<LLT> AllScalarTypes = {S32, S64, S96, S128,
329
+ S160, S224, S256, S512};
330
+
331
+ static std::initializer_list<LLT> AllS16Vectors{
332
+ V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128};
333
+
334
+ static std::initializer_list<LLT> AllS32Vectors = {
335
+ V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
336
+ V9S32, V10S32, V11S32, V12S32, V16S32, V32S32};
337
+
338
+ static std::initializer_list<LLT> AllS64Vectors = {V2S64, V3S64, V4S64, V5S64,
339
+ V6S64, V7S64, V8S64, V16S64};
340
+
341
+ static bool typeInSet (LLT Ty, std::initializer_list<LLT> TypesInit) {
342
+ SmallVector<LLT, 4 > Types = TypesInit;
343
+ return llvm::is_contained (Types, Ty);
344
+ }
345
+
346
+ static LLT GetAddrSpacePtr (unsigned AS, const GCNTargetMachine &TM) {
347
+ return LLT::pointer (AS, TM.getPointerSizeInBits (AS));
348
+ }
349
+
350
+ // Checks whether a type is in the list of legal register types.
351
+ static bool isRegisterClassType (LLT Ty, const GCNTargetMachine &TM) {
352
+ const LLT GlobalPtr = GetAddrSpacePtr (AMDGPUAS::GLOBAL_ADDRESS, TM);
353
+ const LLT LocalPtr = GetAddrSpacePtr (AMDGPUAS::LOCAL_ADDRESS, TM);
354
+ const LLT FlatPtr = GetAddrSpacePtr (AMDGPUAS::FLAT_ADDRESS, TM);
355
+
356
+ // TODO: list all possible ptr vectors
357
+ const LLT V2FlatPtr = LLT::fixed_vector (2 , FlatPtr);
358
+ const LLT V3LocalPtr = LLT::fixed_vector (3 , LocalPtr);
359
+ const LLT V5LocalPtr = LLT::fixed_vector (5 , LocalPtr);
360
+ const LLT V16LocalPtr = LLT::fixed_vector (16 , LocalPtr);
361
+ const LLT V2GlobalPtr = LLT::fixed_vector (2 , GlobalPtr);
362
+ const LLT V4GlobalPtr = LLT::fixed_vector (4 , GlobalPtr);
363
+
364
+ std::initializer_list<LLT> AllPtrTypes{V2FlatPtr, V3LocalPtr, V5LocalPtr,
365
+ V16LocalPtr, V2GlobalPtr, V4GlobalPtr};
366
+
367
+ return typeInSet (Ty, AllS32Vectors) || typeInSet (Ty, AllS64Vectors) ||
368
+ typeInSet (Ty, AllScalarTypes) || typeInSet (Ty, AllS16Vectors) ||
369
+ typeInSet (Ty, AllPtrTypes) || Ty.isPointer ();
370
+ }
371
+
372
+ static LegalityPredicate isRegisterClassType (unsigned TypeIdx,
373
+ const GCNTargetMachine &TM) {
374
+ return [TypeIdx, &TM](const LegalityQuery &Query) {
375
+ return isRegisterClassType (Query.Types [TypeIdx], TM);
376
+ };
377
+ }
378
+
279
379
// If we have a truncating store or an extending load with a data size larger
280
380
// than 32-bits, we need to reduce to a 32-bit type.
281
381
static LegalityPredicate isWideScalarExtLoadTruncStore (unsigned TypeIdx) {
@@ -574,67 +674,18 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
574
674
: ST(ST_) {
575
675
using namespace TargetOpcode ;
576
676
577
- auto GetAddrSpacePtr = [&TM](unsigned AS) {
578
- return LLT::pointer (AS, TM.getPointerSizeInBits (AS));
579
- };
580
-
581
- const LLT S1 = LLT::scalar (1 );
582
- const LLT S8 = LLT::scalar (8 );
583
- const LLT S16 = LLT::scalar (16 );
584
- const LLT S32 = LLT::scalar (32 );
585
- const LLT S64 = LLT::scalar (64 );
586
- const LLT S128 = LLT::scalar (128 );
587
- const LLT S256 = LLT::scalar (256 );
588
- const LLT S512 = LLT::scalar (512 );
589
- const LLT MaxScalar = LLT::scalar (MaxRegisterSize);
590
-
591
- const LLT V2S8 = LLT::fixed_vector (2 , 8 );
592
- const LLT V2S16 = LLT::fixed_vector (2 , 16 );
593
- const LLT V4S16 = LLT::fixed_vector (4 , 16 );
594
-
595
- const LLT V2S32 = LLT::fixed_vector (2 , 32 );
596
- const LLT V3S32 = LLT::fixed_vector (3 , 32 );
597
- const LLT V4S32 = LLT::fixed_vector (4 , 32 );
598
- const LLT V5S32 = LLT::fixed_vector (5 , 32 );
599
- const LLT V6S32 = LLT::fixed_vector (6 , 32 );
600
- const LLT V7S32 = LLT::fixed_vector (7 , 32 );
601
- const LLT V8S32 = LLT::fixed_vector (8 , 32 );
602
- const LLT V9S32 = LLT::fixed_vector (9 , 32 );
603
- const LLT V10S32 = LLT::fixed_vector (10 , 32 );
604
- const LLT V11S32 = LLT::fixed_vector (11 , 32 );
605
- const LLT V12S32 = LLT::fixed_vector (12 , 32 );
606
- const LLT V13S32 = LLT::fixed_vector (13 , 32 );
607
- const LLT V14S32 = LLT::fixed_vector (14 , 32 );
608
- const LLT V15S32 = LLT::fixed_vector (15 , 32 );
609
- const LLT V16S32 = LLT::fixed_vector (16 , 32 );
610
- const LLT V32S32 = LLT::fixed_vector (32 , 32 );
611
-
612
- const LLT V2S64 = LLT::fixed_vector (2 , 64 );
613
- const LLT V3S64 = LLT::fixed_vector (3 , 64 );
614
- const LLT V4S64 = LLT::fixed_vector (4 , 64 );
615
- const LLT V5S64 = LLT::fixed_vector (5 , 64 );
616
- const LLT V6S64 = LLT::fixed_vector (6 , 64 );
617
- const LLT V7S64 = LLT::fixed_vector (7 , 64 );
618
- const LLT V8S64 = LLT::fixed_vector (8 , 64 );
619
- const LLT V16S64 = LLT::fixed_vector (16 , 64 );
620
-
621
- std::initializer_list<LLT> AllS32Vectors =
622
- {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
623
- V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32, V32S32};
624
- std::initializer_list<LLT> AllS64Vectors =
625
- {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64, V16S64};
626
-
627
- const LLT GlobalPtr = GetAddrSpacePtr (AMDGPUAS::GLOBAL_ADDRESS);
628
- const LLT ConstantPtr = GetAddrSpacePtr (AMDGPUAS::CONSTANT_ADDRESS);
629
- const LLT Constant32Ptr = GetAddrSpacePtr (AMDGPUAS::CONSTANT_ADDRESS_32BIT);
630
- const LLT LocalPtr = GetAddrSpacePtr (AMDGPUAS::LOCAL_ADDRESS);
631
- const LLT RegionPtr = GetAddrSpacePtr (AMDGPUAS::REGION_ADDRESS);
632
- const LLT FlatPtr = GetAddrSpacePtr (AMDGPUAS::FLAT_ADDRESS);
633
- const LLT PrivatePtr = GetAddrSpacePtr (AMDGPUAS::PRIVATE_ADDRESS);
634
- const LLT BufferFatPtr = GetAddrSpacePtr (AMDGPUAS::BUFFER_FAT_POINTER);
635
- const LLT RsrcPtr = GetAddrSpacePtr (AMDGPUAS::BUFFER_RESOURCE);
677
+ const LLT GlobalPtr = GetAddrSpacePtr (AMDGPUAS::GLOBAL_ADDRESS, TM);
678
+ const LLT ConstantPtr = GetAddrSpacePtr (AMDGPUAS::CONSTANT_ADDRESS, TM);
679
+ const LLT Constant32Ptr =
680
+ GetAddrSpacePtr (AMDGPUAS::CONSTANT_ADDRESS_32BIT, TM);
681
+ const LLT LocalPtr = GetAddrSpacePtr (AMDGPUAS::LOCAL_ADDRESS, TM);
682
+ const LLT RegionPtr = GetAddrSpacePtr (AMDGPUAS::REGION_ADDRESS, TM);
683
+ const LLT FlatPtr = GetAddrSpacePtr (AMDGPUAS::FLAT_ADDRESS, TM);
684
+ const LLT PrivatePtr = GetAddrSpacePtr (AMDGPUAS::PRIVATE_ADDRESS, TM);
685
+ const LLT BufferFatPtr = GetAddrSpacePtr (AMDGPUAS::BUFFER_FAT_POINTER, TM);
686
+ const LLT RsrcPtr = GetAddrSpacePtr (AMDGPUAS::BUFFER_RESOURCE, TM);
636
687
const LLT BufferStridedPtr =
637
- GetAddrSpacePtr (AMDGPUAS::BUFFER_STRIDED_POINTER);
688
+ GetAddrSpacePtr (AMDGPUAS::BUFFER_STRIDED_POINTER, TM );
638
689
639
690
const LLT CodePtr = FlatPtr;
640
691
@@ -836,10 +887,9 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
836
887
.scalarize (0 );
837
888
838
889
getActionDefinitionsBuilder (G_BITCAST)
839
- // Don't worry about the size constraint.
840
- .legalIf (all (isRegisterType (0 ), isRegisterType (1 )))
841
- .lower ();
842
-
890
+ // Don't worry about the size constraint.
891
+ .legalIf (all (isRegisterClassType (0 , TM), isRegisterClassType (1 , TM)))
892
+ .lower ();
843
893
844
894
getActionDefinitionsBuilder (G_CONSTANT)
845
895
.legalFor ({S1, S32, S64, S16, GlobalPtr,
0 commit comments