Skip to content

Commit 10c2d5f

Browse files
authored
[RISCV][GISel] RegBank select and instruction select for vector G_ADD, G_SUB (#74114)
RegisterBank Selection for scalable vector G_ADD and G_SUB by creating new mappings for different types of vector register banks. Then implement Instruction Selection for the same operations by choosing the correct RISC-V vector register class.
1 parent 41be541 commit 10c2d5f

File tree

7 files changed

+3043
-3
lines changed

7 files changed

+3043
-3
lines changed

llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ bool InstructionSelect::runOnMachineFunction(MachineFunction &MF) {
281281
}
282282

283283
const LLT Ty = MRI.getType(VReg);
284-
if (Ty.isValid() && Ty.getSizeInBits() > TRI.getRegSizeInBits(*RC)) {
284+
if (Ty.isValid() &&
285+
TypeSize::isKnownGT(Ty.getSizeInBits(), TRI.getRegSizeInBits(*RC))) {
285286
reportGISelFailure(
286287
MF, TPC, MORE, "gisel-select",
287288
"VReg's low-level type and register class have different sizes", *MI);

llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,20 @@ const TargetRegisterClass *RISCVInstructionSelector::getRegClassForTypeOnBank(
844844
return &RISCV::FPR64RegClass;
845845
}
846846

847-
// TODO: Non-GPR register classes.
847+
if (RB.getID() == RISCV::VRBRegBankID) {
848+
if (Ty.getSizeInBits().getKnownMinValue() <= 64)
849+
return &RISCV::VRRegClass;
850+
851+
if (Ty.getSizeInBits().getKnownMinValue() == 128)
852+
return &RISCV::VRM2RegClass;
853+
854+
if (Ty.getSizeInBits().getKnownMinValue() == 256)
855+
return &RISCV::VRM4RegClass;
856+
857+
if (Ty.getSizeInBits().getKnownMinValue() == 512)
858+
return &RISCV::VRM8RegClass;
859+
}
860+
848861
return nullptr;
849862
}
850863

llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,27 @@ namespace llvm {
2525
namespace RISCV {
2626

2727
const RegisterBankInfo::PartialMapping PartMappings[] = {
28+
// clang-format off
2829
{0, 32, GPRBRegBank},
2930
{0, 64, GPRBRegBank},
3031
{0, 32, FPRBRegBank},
3132
{0, 64, FPRBRegBank},
33+
{0, 64, VRBRegBank},
34+
{0, 128, VRBRegBank},
35+
{0, 256, VRBRegBank},
36+
{0, 512, VRBRegBank},
37+
// clang-format on
3238
};
3339

3440
enum PartialMappingIdx {
3541
PMI_GPRB32 = 0,
3642
PMI_GPRB64 = 1,
3743
PMI_FPRB32 = 2,
3844
PMI_FPRB64 = 3,
45+
PMI_VRB64 = 4,
46+
PMI_VRB128 = 5,
47+
PMI_VRB256 = 6,
48+
PMI_VRB512 = 7,
3949
};
4050

4151
const RegisterBankInfo::ValueMapping ValueMappings[] = {
@@ -57,6 +67,22 @@ const RegisterBankInfo::ValueMapping ValueMappings[] = {
5767
{&PartMappings[PMI_FPRB64], 1},
5868
{&PartMappings[PMI_FPRB64], 1},
5969
{&PartMappings[PMI_FPRB64], 1},
70+
// Maximum 3 VR LMUL={1, MF2, MF4, MF8} operands.
71+
{&PartMappings[PMI_VRB64], 1},
72+
{&PartMappings[PMI_VRB64], 1},
73+
{&PartMappings[PMI_VRB64], 1},
74+
// Maximum 3 VR LMUL=2 operands.
75+
{&PartMappings[PMI_VRB128], 1},
76+
{&PartMappings[PMI_VRB128], 1},
77+
{&PartMappings[PMI_VRB128], 1},
78+
// Maximum 3 VR LMUL=4 operands.
79+
{&PartMappings[PMI_VRB256], 1},
80+
{&PartMappings[PMI_VRB256], 1},
81+
{&PartMappings[PMI_VRB256], 1},
82+
// Maximum 3 VR LMUL=8 operands.
83+
{&PartMappings[PMI_VRB512], 1},
84+
{&PartMappings[PMI_VRB512], 1},
85+
{&PartMappings[PMI_VRB512], 1},
6086
};
6187

6288
enum ValueMappingIdx {
@@ -65,6 +91,10 @@ enum ValueMappingIdx {
6591
GPRB64Idx = 4,
6692
FPRB32Idx = 7,
6793
FPRB64Idx = 10,
94+
VRB64Idx = 13,
95+
VRB128Idx = 16,
96+
VRB256Idx = 19,
97+
VRB512Idx = 22,
6898
};
6999
} // namespace RISCV
70100
} // namespace llvm
@@ -215,6 +245,23 @@ bool RISCVRegisterBankInfo::anyUseOnlyUseFP(
215245
[&](const MachineInstr &UseMI) { return onlyUsesFP(UseMI, MRI, TRI); });
216246
}
217247

248+
static const RegisterBankInfo::ValueMapping *getVRBValueMapping(unsigned Size) {
249+
unsigned Idx;
250+
251+
if (Size <= 64)
252+
Idx = RISCV::VRB64Idx;
253+
else if (Size == 128)
254+
Idx = RISCV::VRB128Idx;
255+
else if (Size == 256)
256+
Idx = RISCV::VRB256Idx;
257+
else if (Size == 512)
258+
Idx = RISCV::VRB512Idx;
259+
else
260+
llvm::report_fatal_error("Invalid Size");
261+
262+
return &RISCV::ValueMappings[Idx];
263+
}
264+
218265
const RegisterBankInfo::InstructionMapping &
219266
RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
220267
const unsigned Opc = MI.getOpcode();
@@ -242,7 +289,16 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
242289

243290
switch (Opc) {
244291
case TargetOpcode::G_ADD:
245-
case TargetOpcode::G_SUB:
292+
case TargetOpcode::G_SUB: {
293+
if (MRI.getType(MI.getOperand(0).getReg()).isVector()) {
294+
LLT Ty = MRI.getType(MI.getOperand(0).getReg());
295+
return getInstructionMapping(
296+
DefaultMappingID, /*Cost=*/1,
297+
getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue()),
298+
NumOperands);
299+
}
300+
}
301+
LLVM_FALLTHROUGH;
246302
case TargetOpcode::G_SHL:
247303
case TargetOpcode::G_ASHR:
248304
case TargetOpcode::G_LSHR:

0 commit comments

Comments
 (0)