@@ -72,17 +72,14 @@ void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
72
72
VRegToTypeMap[&MF][VReg] = SpirvType;
73
73
}
74
74
75
- static Register createTypeVReg (MachineIRBuilder &MIRBuilder) {
76
- auto &MRI = MIRBuilder.getMF ().getRegInfo ();
77
- auto Res = MRI.createGenericVirtualRegister (LLT::scalar (32 ));
75
+ static Register createTypeVReg (MachineRegisterInfo &MRI) {
76
+ auto Res = MRI.createGenericVirtualRegister (LLT::scalar (64 ));
78
77
MRI.setRegClass (Res, &SPIRV::TYPERegClass);
79
78
return Res;
80
79
}
81
80
82
- static Register createTypeVReg (MachineRegisterInfo &MRI) {
83
- auto Res = MRI.createGenericVirtualRegister (LLT::scalar (32 ));
84
- MRI.setRegClass (Res, &SPIRV::TYPERegClass);
85
- return Res;
81
+ inline Register createTypeVReg (MachineIRBuilder &MIRBuilder) {
82
+ return createTypeVReg (MIRBuilder.getMF ().getRegInfo ());
86
83
}
87
84
88
85
SPIRVType *SPIRVGlobalRegistry::getOpTypeBool (MachineIRBuilder &MIRBuilder) {
@@ -157,26 +154,24 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
157
154
return MIB;
158
155
}
159
156
160
- std::tuple<Register, ConstantInt *, bool >
157
+ std::tuple<Register, ConstantInt *, bool , unsigned >
161
158
SPIRVGlobalRegistry::getOrCreateConstIntReg (uint64_t Val, SPIRVType *SpvType,
162
159
MachineIRBuilder *MIRBuilder,
163
160
MachineInstr *I,
164
161
const SPIRVInstrInfo *TII) {
165
- const IntegerType *LLVMIntTy;
166
- if (SpvType)
167
- LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType (SpvType));
168
- else
169
- LLVMIntTy = IntegerType::getInt32Ty (CurMF->getFunction ().getContext ());
162
+ assert (SpvType);
163
+ const IntegerType *LLVMIntTy =
164
+ cast<IntegerType>(getTypeForSPIRVType (SpvType));
165
+ unsigned BitWidth = getScalarOrVectorBitWidth (SpvType);
170
166
bool NewInstr = false ;
171
167
// Find a constant in DT or build a new one.
172
168
ConstantInt *CI = ConstantInt::get (const_cast <IntegerType *>(LLVMIntTy), Val);
173
169
Register Res = DT.find (CI, CurMF);
174
170
if (!Res.isValid ()) {
175
- unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth (SpvType) : 32 ;
176
171
// TODO: handle cases where the type is not 32bit wide
177
172
// TODO: https://github.com/llvm/llvm-project/issues/88129
178
- LLT LLTy = LLT::scalar ( 32 );
179
- Res = CurMF->getRegInfo ().createGenericVirtualRegister (LLTy );
173
+ Res =
174
+ CurMF->getRegInfo ().createGenericVirtualRegister (LLT::scalar (BitWidth) );
180
175
CurMF->getRegInfo ().setRegClass (Res, &SPIRV::iIDRegClass);
181
176
if (MIRBuilder)
182
177
assignTypeToVReg (LLVMIntTy, Res, *MIRBuilder);
@@ -185,35 +180,27 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
185
180
DT.add (CI, CurMF, Res);
186
181
NewInstr = true ;
187
182
}
188
- return std::make_tuple (Res, CI, NewInstr);
183
+ return std::make_tuple (Res, CI, NewInstr, BitWidth );
189
184
}
190
185
191
186
std::tuple<Register, ConstantFP *, bool , unsigned >
192
187
SPIRVGlobalRegistry::getOrCreateConstFloatReg (APFloat Val, SPIRVType *SpvType,
193
188
MachineIRBuilder *MIRBuilder,
194
189
MachineInstr *I,
195
190
const SPIRVInstrInfo *TII) {
196
- const Type *LLVMFloatTy ;
191
+ assert (SpvType) ;
197
192
LLVMContext &Ctx = CurMF->getFunction ().getContext ();
198
- unsigned BitWidth = 32 ;
199
- if (SpvType)
200
- LLVMFloatTy = getTypeForSPIRVType (SpvType);
201
- else {
202
- LLVMFloatTy = Type::getFloatTy (Ctx);
203
- if (MIRBuilder)
204
- SpvType = getOrCreateSPIRVType (LLVMFloatTy, *MIRBuilder);
205
- }
193
+ const Type *LLVMFloatTy = getTypeForSPIRVType (SpvType);
194
+ unsigned BitWidth = getScalarOrVectorBitWidth (SpvType);
206
195
bool NewInstr = false ;
207
196
// Find a constant in DT or build a new one.
208
197
auto *const CI = ConstantFP::get (Ctx, Val);
209
198
Register Res = DT.find (CI, CurMF);
210
199
if (!Res.isValid ()) {
211
- if (SpvType)
212
- BitWidth = getScalarOrVectorBitWidth (SpvType);
213
200
// TODO: handle cases where the type is not 32bit wide
214
201
// TODO: https://github.com/llvm/llvm-project/issues/88129
215
- LLT LLTy = LLT::scalar ( 32 );
216
- Res = CurMF->getRegInfo ().createGenericVirtualRegister (LLTy );
202
+ Res =
203
+ CurMF->getRegInfo ().createGenericVirtualRegister (LLT::scalar (BitWidth) );
217
204
CurMF->getRegInfo ().setRegClass (Res, &SPIRV::fIDRegClass );
218
205
if (MIRBuilder)
219
206
assignTypeToVReg (LLVMFloatTy, Res, *MIRBuilder);
@@ -269,7 +256,8 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
269
256
ConstantInt *CI;
270
257
Register Res;
271
258
bool New;
272
- std::tie (Res, CI, New) =
259
+ unsigned BitWidth;
260
+ std::tie (Res, CI, New, BitWidth) =
273
261
getOrCreateConstIntReg (Val, SpvType, nullptr , &I, &TII);
274
262
// If we have found Res register which is defined by the passed G_CONSTANT
275
263
// machine instruction, a new constant instruction should be created.
@@ -281,7 +269,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
281
269
MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpConstantI))
282
270
.addDef (Res)
283
271
.addUse (getSPIRVTypeID (SpvType));
284
- addNumImm (APInt (getScalarOrVectorBitWidth (SpvType) , Val), MIB);
272
+ addNumImm (APInt (BitWidth , Val), MIB);
285
273
} else {
286
274
MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpConstantNull))
287
275
.addDef (Res)
@@ -297,19 +285,17 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
297
285
MachineIRBuilder &MIRBuilder,
298
286
SPIRVType *SpvType,
299
287
bool EmitIR) {
288
+ assert (SpvType);
300
289
auto &MF = MIRBuilder.getMF ();
301
- const IntegerType *LLVMIntTy;
302
- if (SpvType)
303
- LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType (SpvType));
304
- else
305
- LLVMIntTy = IntegerType::getInt32Ty (MF.getFunction ().getContext ());
290
+ const IntegerType *LLVMIntTy =
291
+ cast<IntegerType>(getTypeForSPIRVType (SpvType));
306
292
// Find a constant in DT or build a new one.
307
293
const auto ConstInt =
308
294
ConstantInt::get (const_cast <IntegerType *>(LLVMIntTy), Val);
309
295
Register Res = DT.find (ConstInt, &MF);
310
296
if (!Res.isValid ()) {
311
- unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth (SpvType) : 32 ;
312
- LLT LLTy = LLT::scalar (EmitIR ? BitWidth : 32 );
297
+ unsigned BitWidth = getScalarOrVectorBitWidth (SpvType);
298
+ LLT LLTy = LLT::scalar (BitWidth);
313
299
Res = MF.getRegInfo ().createGenericVirtualRegister (LLTy);
314
300
MF.getRegInfo ().setRegClass (Res, &SPIRV::iIDRegClass);
315
301
assignTypeToVReg (LLVMIntTy, Res, MIRBuilder,
@@ -318,18 +304,17 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
318
304
if (EmitIR) {
319
305
MIRBuilder.buildConstant (Res, *ConstInt);
320
306
} else {
321
- if (!SpvType)
322
- SpvType = getOrCreateSPIRVIntegerType (BitWidth, MIRBuilder);
307
+ Register SpvTypeReg = getSPIRVTypeID (SpvType);
323
308
MachineInstrBuilder MIB;
324
309
if (Val) {
325
310
MIB = MIRBuilder.buildInstr (SPIRV::OpConstantI)
326
311
.addDef (Res)
327
- .addUse (getSPIRVTypeID (SpvType) );
312
+ .addUse (SpvTypeReg );
328
313
addNumImm (APInt (BitWidth, Val), MIB);
329
314
} else {
330
315
MIB = MIRBuilder.buildInstr (SPIRV::OpConstantNull)
331
316
.addDef (Res)
332
- .addUse (getSPIRVTypeID (SpvType) );
317
+ .addUse (SpvTypeReg );
333
318
}
334
319
const auto &Subtarget = CurMF->getSubtarget ();
335
320
constrainSelectedInstRegOperands (*MIB, *Subtarget.getInstrInfo (),
@@ -353,7 +338,8 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
353
338
const auto ConstFP = ConstantFP::get (Ctx, Val);
354
339
Register Res = DT.find (ConstFP, &MF);
355
340
if (!Res.isValid ()) {
356
- Res = MF.getRegInfo ().createGenericVirtualRegister (LLT::scalar (32 ));
341
+ Res = MF.getRegInfo ().createGenericVirtualRegister (
342
+ LLT::scalar (getScalarOrVectorBitWidth (SpvType)));
357
343
MF.getRegInfo ().setRegClass (Res, &SPIRV::fIDRegClass );
358
344
assignSPIRVTypeToVReg (SpvType, Res, MF);
359
345
DT.add (ConstFP, &MF, Res);
@@ -407,7 +393,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
407
393
408
394
// TODO: handle cases where the type is not 32bit wide
409
395
// TODO: https://github.com/llvm/llvm-project/issues/88129
410
- LLT LLTy = LLT::scalar (32 );
396
+ LLT LLTy = LLT::scalar (64 );
411
397
Register SpvVecConst =
412
398
CurMF->getRegInfo ().createGenericVirtualRegister (LLTy);
413
399
CurMF->getRegInfo ().setRegClass (SpvVecConst, &SPIRV::iIDRegClass);
@@ -509,7 +495,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
509
495
getOrCreateSPIRVIntegerType (BitWidth, MIRBuilder);
510
496
SpvScalConst = buildConstantInt (Val, MIRBuilder, SpvBaseType, EmitIR);
511
497
}
512
- LLT LLTy = EmitIR ? LLT::fixed_vector (ElemCnt, BitWidth) : LLT::scalar (32 );
498
+ LLT LLTy = EmitIR ? LLT::fixed_vector (ElemCnt, BitWidth) : LLT::scalar (64 );
513
499
Register SpvVecConst =
514
500
CurMF->getRegInfo ().createGenericVirtualRegister (LLTy);
515
501
CurMF->getRegInfo ().setRegClass (SpvVecConst, &SPIRV::iIDRegClass);
@@ -650,7 +636,6 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
650
636
651
637
// Set to Reg the same type as ResVReg has.
652
638
auto MRI = MIRBuilder.getMRI ();
653
- assert (MRI->getType (ResVReg).isPointer () && " Pointer type is expected" );
654
639
if (Reg != ResVReg) {
655
640
LLT RegLLTy =
656
641
LLT::pointer (MRI->getType (ResVReg).getAddressSpace (), getPointerSize ());
@@ -706,8 +691,9 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
706
691
bool EmitIR) {
707
692
assert ((ElemType->getOpcode () != SPIRV::OpTypeVoid) &&
708
693
" Invalid array element type" );
694
+ SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType (32 , MIRBuilder);
709
695
Register NumElementsVReg =
710
- buildConstantInt (NumElems, MIRBuilder, nullptr , EmitIR);
696
+ buildConstantInt (NumElems, MIRBuilder, SpvTypeInt32 , EmitIR);
711
697
auto MIB = MIRBuilder.buildInstr (SPIRV::OpTypeArray)
712
698
.addDef (createTypeVReg (MIRBuilder))
713
699
.addUse (getSPIRVTypeID (ElemType))
@@ -1188,14 +1174,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
1188
1174
if (ResVReg.isValid ())
1189
1175
return MIRBuilder.getMF ().getRegInfo ().getUniqueVRegDef (ResVReg);
1190
1176
ResVReg = createTypeVReg (MIRBuilder);
1177
+ SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType (32 , MIRBuilder);
1191
1178
SPIRVType *SpirvTy =
1192
1179
MIRBuilder.buildInstr (SPIRV::OpTypeCooperativeMatrixKHR)
1193
1180
.addDef (ResVReg)
1194
1181
.addUse (getSPIRVTypeID (ElemType))
1195
- .addUse (buildConstantInt (Scope, MIRBuilder, nullptr , true ))
1196
- .addUse (buildConstantInt (Rows, MIRBuilder, nullptr , true ))
1197
- .addUse (buildConstantInt (Columns, MIRBuilder, nullptr , true ))
1198
- .addUse (buildConstantInt (Use, MIRBuilder, nullptr , true ));
1182
+ .addUse (buildConstantInt (Scope, MIRBuilder, SpvTypeInt32 , true ))
1183
+ .addUse (buildConstantInt (Rows, MIRBuilder, SpvTypeInt32 , true ))
1184
+ .addUse (buildConstantInt (Columns, MIRBuilder, SpvTypeInt32 , true ))
1185
+ .addUse (buildConstantInt (Use, MIRBuilder, SpvTypeInt32 , true ));
1199
1186
DT.add (ExtensionType, &MIRBuilder.getMF (), ResVReg);
1200
1187
return SpirvTy;
1201
1188
}
@@ -1386,8 +1373,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1386
1373
if (Reg.isValid ())
1387
1374
return getSPIRVTypeForVReg (Reg);
1388
1375
MachineBasicBlock &BB = *I.getParent ();
1389
- SPIRVType *SpirvType = getOrCreateSPIRVIntegerType (32 , I, TII);
1390
- Register Len = getOrCreateConstInt (NumElements, I, SpirvType , TII);
1376
+ SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType (32 , I, TII);
1377
+ Register Len = getOrCreateConstInt (NumElements, I, SpvTypeInt32 , TII);
1391
1378
auto MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpTypeArray))
1392
1379
.addDef (createTypeVReg (CurMF->getRegInfo ()))
1393
1380
.addUse (getSPIRVTypeID (BaseType))
@@ -1436,7 +1423,7 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1436
1423
Register Res = DT.find (UV, CurMF);
1437
1424
if (Res.isValid ())
1438
1425
return Res;
1439
- LLT LLTy = LLT::scalar (32 );
1426
+ LLT LLTy = LLT::scalar (64 );
1440
1427
Res = CurMF->getRegInfo ().createGenericVirtualRegister (LLTy);
1441
1428
CurMF->getRegInfo ().setRegClass (Res, &SPIRV::iIDRegClass);
1442
1429
assignSPIRVTypeToVReg (SpvType, Res, *CurMF);
@@ -1451,3 +1438,61 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1451
1438
*ST.getRegisterInfo (), *ST.getRegBankInfo ());
1452
1439
return Res;
1453
1440
}
1441
+
1442
+ const TargetRegisterClass *
1443
+ SPIRVGlobalRegistry::getRegClass (SPIRVType *SpvType) const {
1444
+ unsigned Opcode = SpvType->getOpcode ();
1445
+ switch (Opcode) {
1446
+ case SPIRV::OpTypeFloat:
1447
+ return &SPIRV::fIDRegClass ;
1448
+ case SPIRV::OpTypePointer:
1449
+ return &SPIRV::pIDRegClass;
1450
+ case SPIRV::OpTypeVector: {
1451
+ SPIRVType *ElemType = getSPIRVTypeForVReg (SpvType->getOperand (1 ).getReg ());
1452
+ unsigned ElemOpcode = ElemType ? ElemType->getOpcode () : 0 ;
1453
+ if (ElemOpcode == SPIRV::OpTypeFloat)
1454
+ return &SPIRV::vfIDRegClass;
1455
+ if (ElemOpcode == SPIRV::OpTypePointer)
1456
+ return &SPIRV::vpIDRegClass;
1457
+ return &SPIRV::vIDRegClass;
1458
+ }
1459
+ }
1460
+ return &SPIRV::iIDRegClass;
1461
+ }
1462
+
1463
+ inline unsigned getAS (SPIRVType *SpvType) {
1464
+ return storageClassToAddressSpace (
1465
+ static_cast <SPIRV::StorageClass::StorageClass>(
1466
+ SpvType->getOperand (1 ).getImm ()));
1467
+ }
1468
+
1469
+ LLT SPIRVGlobalRegistry::getRegType (SPIRVType *SpvType) const {
1470
+ unsigned Opcode = SpvType ? SpvType->getOpcode () : 0 ;
1471
+ switch (Opcode) {
1472
+ case SPIRV::OpTypeInt:
1473
+ case SPIRV::OpTypeFloat:
1474
+ case SPIRV::OpTypeBool:
1475
+ return LLT::scalar (getScalarOrVectorBitWidth (SpvType));
1476
+ case SPIRV::OpTypePointer:
1477
+ return LLT::pointer (getAS (SpvType), getPointerSize ());
1478
+ case SPIRV::OpTypeVector: {
1479
+ SPIRVType *ElemType = getSPIRVTypeForVReg (SpvType->getOperand (1 ).getReg ());
1480
+ LLT ET;
1481
+ switch (ElemType ? ElemType->getOpcode () : 0 ) {
1482
+ case SPIRV::OpTypePointer:
1483
+ ET = LLT::pointer (getAS (ElemType), getPointerSize ());
1484
+ break ;
1485
+ case SPIRV::OpTypeInt:
1486
+ case SPIRV::OpTypeFloat:
1487
+ case SPIRV::OpTypeBool:
1488
+ ET = LLT::scalar (getScalarOrVectorBitWidth (ElemType));
1489
+ break ;
1490
+ default :
1491
+ ET = LLT::scalar (64 );
1492
+ }
1493
+ return LLT::fixed_vector (
1494
+ static_cast <unsigned >(SpvType->getOperand (2 ).getImm ()), ET);
1495
+ }
1496
+ }
1497
+ return LLT::scalar (64 );
1498
+ }
0 commit comments