@@ -239,6 +239,7 @@ static bool isRegisterVectorType(LLT Ty) {
239
239
EltSize == 128 || EltSize == 256 ;
240
240
}
241
241
242
+ // TODO: replace all uses of isRegisterType with isRegisterClassType
242
243
static bool isRegisterType (LLT Ty) {
243
244
if (!isRegisterSize (Ty.getSizeInBits ()))
244
245
return false ;
@@ -258,6 +259,8 @@ static LegalityPredicate isRegisterType(unsigned TypeIdx) {
258
259
}
259
260
260
261
// RegisterType that doesn't have a corresponding RegClass.
262
+ // TODO: Once `isRegisterType` is replaced with `isRegisterClassType` this
263
+ // should be removed.
261
264
static LegalityPredicate isIllegalRegisterType (unsigned TypeIdx) {
262
265
return [=](const LegalityQuery &Query) {
263
266
LLT Ty = Query.Types [TypeIdx];
@@ -276,6 +279,95 @@ static LegalityPredicate elementTypeIsLegal(unsigned TypeIdx) {
276
279
};
277
280
}
278
281
282
+ static const LLT S1 = LLT::scalar(1 );
283
+ static const LLT S8 = LLT::scalar(8 );
284
+ static const LLT S16 = LLT::scalar(16 );
285
+ static const LLT S32 = LLT::scalar(32 );
286
+ static const LLT S64 = LLT::scalar(64 );
287
+ static const LLT S96 = LLT::scalar(96 );
288
+ static const LLT S128 = LLT::scalar(128 );
289
+ static const LLT S160 = LLT::scalar(160 );
290
+ static const LLT S224 = LLT::scalar(224 );
291
+ static const LLT S256 = LLT::scalar(256 );
292
+ static const LLT S512 = LLT::scalar(512 );
293
+ static const LLT MaxScalar = LLT::scalar(MaxRegisterSize);
294
+
295
+ static const LLT V2S8 = LLT::fixed_vector(2 , 8 );
296
+ static const LLT V2S16 = LLT::fixed_vector(2 , 16 );
297
+ static const LLT V4S16 = LLT::fixed_vector(4 , 16 );
298
+ static const LLT V6S16 = LLT::fixed_vector(6 , 16 );
299
+ static const LLT V8S16 = LLT::fixed_vector(8 , 16 );
300
+ static const LLT V10S16 = LLT::fixed_vector(10 , 16 );
301
+ static const LLT V12S16 = LLT::fixed_vector(12 , 16 );
302
+ static const LLT V16S16 = LLT::fixed_vector(16 , 16 );
303
+
304
+ static const LLT V2S32 = LLT::fixed_vector(2 , 32 );
305
+ static const LLT V3S32 = LLT::fixed_vector(3 , 32 );
306
+ static const LLT V4S32 = LLT::fixed_vector(4 , 32 );
307
+ static const LLT V5S32 = LLT::fixed_vector(5 , 32 );
308
+ static const LLT V6S32 = LLT::fixed_vector(6 , 32 );
309
+ static const LLT V7S32 = LLT::fixed_vector(7 , 32 );
310
+ static const LLT V8S32 = LLT::fixed_vector(8 , 32 );
311
+ static const LLT V9S32 = LLT::fixed_vector(9 , 32 );
312
+ static const LLT V10S32 = LLT::fixed_vector(10 , 32 );
313
+ static const LLT V11S32 = LLT::fixed_vector(11 , 32 );
314
+ static const LLT V12S32 = LLT::fixed_vector(12 , 32 );
315
+ static const LLT V16S32 = LLT::fixed_vector(16 , 32 );
316
+ static const LLT V32S32 = LLT::fixed_vector(32 , 32 );
317
+
318
+ static const LLT V2S64 = LLT::fixed_vector(2 , 64 );
319
+ static const LLT V3S64 = LLT::fixed_vector(3 , 64 );
320
+ static const LLT V4S64 = LLT::fixed_vector(4 , 64 );
321
+ static const LLT V5S64 = LLT::fixed_vector(5 , 64 );
322
+ static const LLT V6S64 = LLT::fixed_vector(6 , 64 );
323
+ static const LLT V7S64 = LLT::fixed_vector(7 , 64 );
324
+ static const LLT V8S64 = LLT::fixed_vector(8 , 64 );
325
+ static const LLT V16S64 = LLT::fixed_vector(16 , 64 );
326
+
327
+ static const LLT V2S128 = LLT::fixed_vector(2 , 128 );
328
+ static const LLT V4S128 = LLT::fixed_vector(4 , 128 );
329
+
330
+ static std::initializer_list<LLT> AllScalarTypes = {S32, S64, S96, S128,
331
+ S160, S224, S256, S512};
332
+
333
+ static std::initializer_list<LLT> AllS16Vectors{
334
+ V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128};
335
+
336
+ static std::initializer_list<LLT> AllS32Vectors = {
337
+ V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
338
+ V9S32, V10S32, V11S32, V12S32, V16S32, V32S32};
339
+
340
+ static std::initializer_list<LLT> AllS64Vectors = {V2S64, V3S64, V4S64, V5S64,
341
+ V6S64, V7S64, V8S64, V16S64};
342
+
343
+ static bool typeInSet (LLT Ty, std::initializer_list<LLT> TypesInit) {
344
+ SmallVector<LLT, 4 > Types = TypesInit;
345
+ return llvm::is_contained (Types, Ty);
346
+ }
347
+
348
+ static LLT GetAddrSpacePtr (unsigned AS, const GCNTargetMachine &TM) {
349
+ return LLT::pointer (AS, TM.getPointerSizeInBits (AS));
350
+ }
351
+
352
+ // Checks whether a type is in the list of legal register types.
353
+ static bool isRegisterClassType (LLT Ty) {
354
+ if (Ty.isVector () && Ty.getElementType ().isPointer ())
355
+ Ty = LLT::fixed_vector (Ty.getNumElements (),
356
+ LLT::scalar (Ty.getScalarSizeInBits ()));
357
+ else if (Ty.isPointer ())
358
+ Ty = LLT::scalar (Ty.getScalarSizeInBits ());
359
+
360
+ return typeInSet (Ty, AllS32Vectors) || typeInSet (Ty, AllS64Vectors) ||
361
+ typeInSet (Ty, AllScalarTypes) || typeInSet (Ty, AllS16Vectors) ||
362
+ Ty.isPointer ();
363
+ }
364
+
365
+ static LegalityPredicate isRegisterClassType (unsigned TypeIdx) {
366
+ return [TypeIdx](const LegalityQuery &Query) {
367
+ return isRegisterClassType (Query.Types [TypeIdx]);
368
+ };
369
+ }
370
+
279
371
// If we have a truncating store or an extending load with a data size larger
280
372
// than 32-bits, we need to reduce to a 32-bit type.
281
373
static LegalityPredicate isWideScalarExtLoadTruncStore (unsigned TypeIdx) {
@@ -574,67 +666,18 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
574
666
: ST(ST_) {
575
667
using namespace TargetOpcode ;
576
668
577
- auto GetAddrSpacePtr = [&TM](unsigned AS) {
578
- return LLT::pointer (AS, TM.getPointerSizeInBits (AS));
579
- };
580
-
581
- const LLT S1 = LLT::scalar (1 );
582
- const LLT S8 = LLT::scalar (8 );
583
- const LLT S16 = LLT::scalar (16 );
584
- const LLT S32 = LLT::scalar (32 );
585
- const LLT S64 = LLT::scalar (64 );
586
- const LLT S128 = LLT::scalar (128 );
587
- const LLT S256 = LLT::scalar (256 );
588
- const LLT S512 = LLT::scalar (512 );
589
- const LLT MaxScalar = LLT::scalar (MaxRegisterSize);
590
-
591
- const LLT V2S8 = LLT::fixed_vector (2 , 8 );
592
- const LLT V2S16 = LLT::fixed_vector (2 , 16 );
593
- const LLT V4S16 = LLT::fixed_vector (4 , 16 );
594
-
595
- const LLT V2S32 = LLT::fixed_vector (2 , 32 );
596
- const LLT V3S32 = LLT::fixed_vector (3 , 32 );
597
- const LLT V4S32 = LLT::fixed_vector (4 , 32 );
598
- const LLT V5S32 = LLT::fixed_vector (5 , 32 );
599
- const LLT V6S32 = LLT::fixed_vector (6 , 32 );
600
- const LLT V7S32 = LLT::fixed_vector (7 , 32 );
601
- const LLT V8S32 = LLT::fixed_vector (8 , 32 );
602
- const LLT V9S32 = LLT::fixed_vector (9 , 32 );
603
- const LLT V10S32 = LLT::fixed_vector (10 , 32 );
604
- const LLT V11S32 = LLT::fixed_vector (11 , 32 );
605
- const LLT V12S32 = LLT::fixed_vector (12 , 32 );
606
- const LLT V13S32 = LLT::fixed_vector (13 , 32 );
607
- const LLT V14S32 = LLT::fixed_vector (14 , 32 );
608
- const LLT V15S32 = LLT::fixed_vector (15 , 32 );
609
- const LLT V16S32 = LLT::fixed_vector (16 , 32 );
610
- const LLT V32S32 = LLT::fixed_vector (32 , 32 );
611
-
612
- const LLT V2S64 = LLT::fixed_vector (2 , 64 );
613
- const LLT V3S64 = LLT::fixed_vector (3 , 64 );
614
- const LLT V4S64 = LLT::fixed_vector (4 , 64 );
615
- const LLT V5S64 = LLT::fixed_vector (5 , 64 );
616
- const LLT V6S64 = LLT::fixed_vector (6 , 64 );
617
- const LLT V7S64 = LLT::fixed_vector (7 , 64 );
618
- const LLT V8S64 = LLT::fixed_vector (8 , 64 );
619
- const LLT V16S64 = LLT::fixed_vector (16 , 64 );
620
-
621
- std::initializer_list<LLT> AllS32Vectors =
622
- {V2S32, V3S32, V4S32, V5S32, V6S32, V7S32, V8S32,
623
- V9S32, V10S32, V11S32, V12S32, V13S32, V14S32, V15S32, V16S32, V32S32};
624
- std::initializer_list<LLT> AllS64Vectors =
625
- {V2S64, V3S64, V4S64, V5S64, V6S64, V7S64, V8S64, V16S64};
626
-
627
- const LLT GlobalPtr = GetAddrSpacePtr (AMDGPUAS::GLOBAL_ADDRESS);
628
- const LLT ConstantPtr = GetAddrSpacePtr (AMDGPUAS::CONSTANT_ADDRESS);
629
- const LLT Constant32Ptr = GetAddrSpacePtr (AMDGPUAS::CONSTANT_ADDRESS_32BIT);
630
- const LLT LocalPtr = GetAddrSpacePtr (AMDGPUAS::LOCAL_ADDRESS);
631
- const LLT RegionPtr = GetAddrSpacePtr (AMDGPUAS::REGION_ADDRESS);
632
- const LLT FlatPtr = GetAddrSpacePtr (AMDGPUAS::FLAT_ADDRESS);
633
- const LLT PrivatePtr = GetAddrSpacePtr (AMDGPUAS::PRIVATE_ADDRESS);
634
- const LLT BufferFatPtr = GetAddrSpacePtr (AMDGPUAS::BUFFER_FAT_POINTER);
635
- const LLT RsrcPtr = GetAddrSpacePtr (AMDGPUAS::BUFFER_RESOURCE);
669
+ const LLT GlobalPtr = GetAddrSpacePtr (AMDGPUAS::GLOBAL_ADDRESS, TM);
670
+ const LLT ConstantPtr = GetAddrSpacePtr (AMDGPUAS::CONSTANT_ADDRESS, TM);
671
+ const LLT Constant32Ptr =
672
+ GetAddrSpacePtr (AMDGPUAS::CONSTANT_ADDRESS_32BIT, TM);
673
+ const LLT LocalPtr = GetAddrSpacePtr (AMDGPUAS::LOCAL_ADDRESS, TM);
674
+ const LLT RegionPtr = GetAddrSpacePtr (AMDGPUAS::REGION_ADDRESS, TM);
675
+ const LLT FlatPtr = GetAddrSpacePtr (AMDGPUAS::FLAT_ADDRESS, TM);
676
+ const LLT PrivatePtr = GetAddrSpacePtr (AMDGPUAS::PRIVATE_ADDRESS, TM);
677
+ const LLT BufferFatPtr = GetAddrSpacePtr (AMDGPUAS::BUFFER_FAT_POINTER, TM);
678
+ const LLT RsrcPtr = GetAddrSpacePtr (AMDGPUAS::BUFFER_RESOURCE, TM);
636
679
const LLT BufferStridedPtr =
637
- GetAddrSpacePtr (AMDGPUAS::BUFFER_STRIDED_POINTER);
680
+ GetAddrSpacePtr (AMDGPUAS::BUFFER_STRIDED_POINTER, TM );
638
681
639
682
const LLT CodePtr = FlatPtr;
640
683
@@ -836,10 +879,9 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
836
879
.scalarize (0 );
837
880
838
881
getActionDefinitionsBuilder (G_BITCAST)
839
- // Don't worry about the size constraint.
840
- .legalIf (all (isRegisterType (0 ), isRegisterType (1 )))
841
- .lower ();
842
-
882
+ // Don't worry about the size constraint.
883
+ .legalIf (all (isRegisterClassType (0 ), isRegisterClassType (1 )))
884
+ .lower ();
843
885
844
886
getActionDefinitionsBuilder (G_CONSTANT)
845
887
.legalFor ({S1, S32, S64, S16, GlobalPtr,
0 commit comments