Commit e886bb3

[RISCV] Adjust select shuffle cost to reflect mask creation cost
This is inspired by llvm#77342 (review), and is split off from it with some differences in style.

A select is a vmerge.vvm with the additional cost of materializing the bitmask vector in a vreg. All masks fit within a single vector register (e8 + m8 is the worst case), and thus our worst case cost should be roughly 3 (2 scalar ops to produce the address, one vector load op). Given that most shuffles are small, and the mask will instead be produced by LUI/ADDI + vmv.s.x or ADDI + vmv.s.x, using 2 as the default seems quite reasonable. At worst, we're not going to be off by much.

The prior lowering scaled the cost of the bitmask with LMUL, which I don't understand. At m1 it did use the same base cost of 2.
1 parent 5ce067d commit e886bb3

3 files changed: +48 -47 lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 6 additions & 5 deletions
@@ -84,8 +84,6 @@ RISCVTTIImpl::getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
       Cost += VL;
       break;
     }
-    case RISCV::VMV_S_X:
-      // FIXME: VMV_S_X doesn't use LMUL, the cost should be 1
     default:
       Cost += LMULCost;
     }
@@ -443,10 +441,13 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
       // vsetivli zero, 8, e8, mf2, ta, ma (ignored)
       // vmv.s.x v0, a0
       // vmerge.vvm v8, v9, v8, v0
+      // We use 2 for the cost of the mask materialization as this is the true
+      // cost for small masks and most shuffles are small. At worst, this cost
+      // should be a very small constant for the constant pool load. As such,
+      // we may bias towards large selects slightly more than truely warranted.
       return LT.first *
-             (TLI->getLMULCost(LT.second) + // FIXME: should be 1 for li
-              getRISCVInstructionCost({RISCV::VMV_S_X, RISCV::VMERGE_VVM},
-                                      LT.second, CostKind));
+             (2 + getRISCVInstructionCost({RISCV::VMERGE_VVM},
+                                          LT.second, CostKind));
     }
     case TTI::SK_Broadcast: {
       bool HasScalar = (Args.size() > 0) && (Operator::getOpcode(Args[0]) ==
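
To make the effect of the change concrete, the two select-cost formulas from the hunk above can be compared in a small standalone sketch. This is not the LLVM implementation. The helper names `lmulCost`, `oldSelectCost` and `newSelectCost` are hypothetical, `NumParts` stands in for `LT.first`, and the sketch assumes the LMUL cost is simply equal to LMUL for m1 through m8, which the real `getLMULCost` does not guarantee. It relies on the removed `VMV_S_X` case having fallen through to `default`, so the old formula charged the LMUL cost three times (li, vmv.s.x, vmerge.vvm).

```cpp
// Simplified, hypothetical model of the select shuffle cost before and after
// this commit. Assumption: LMUL cost == LMUL for integer LMUL (m1..m8).
constexpr unsigned lmulCost(unsigned LMUL) { return LMUL; }

// Old formula: LT.first * (LMULCost [li] + LMULCost [vmv.s.x] + LMULCost [vmerge.vvm]).
constexpr unsigned oldSelectCost(unsigned NumParts, unsigned LMUL) {
  return NumParts * (lmulCost(LMUL) + (lmulCost(LMUL) + lmulCost(LMUL)));
}

// New formula: LT.first * (2 [mask materialization] + LMULCost [vmerge.vvm]).
constexpr unsigned newSelectCost(unsigned NumParts, unsigned LMUL) {
  return NumParts * (2 + lmulCost(LMUL));
}

// At m1 both formulas agree, matching the commit message's remark that the
// old code used the same base cost of 2 for the mask at m1.
static_assert(oldSelectCost(1, 1) == 3, "old cost at m1");
static_assert(newSelectCost(1, 1) == 3, "new cost at m1");

// At m8 the old formula scaled the mask cost with LMUL; the new one does not.
static_assert(oldSelectCost(1, 8) == 24, "old cost at m8");
static_assert(newSelectCost(1, 8) == 10, "new cost at m8");
```

Under these assumptions the two formulas only diverge above m1, which is where the old code scaled the scalar mask materialization with LMUL.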
