Skip to content

Commit 7df617d

Browse files
committed
[WIP][AMDGPU] Split isInlinableLiteral16 into three and call the specific version if possible
The current implementation of `isInlinableLiteral16` assumes, a 16-bit inlinable literal is either an i16 or a fp16. This is not always true because of bf16. However, we can't tell fp16 and bf16 apart by just looking at the value. This patch tries to split `isInlinableLiteral16` into three versions, i16, fp16, bf16 respectively, and call the corresponding version. This patch is based on llvm#81282. The current status is, only two uses of original `isInlinableLiteral16` are still there. We need to add an extra argument to indicate the type of the operand the immediate corresponds to. This will also require the change of the function signature of the two callers.
1 parent e9a5322 commit 7df617d

12 files changed

+257
-115
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4106,7 +4106,7 @@ InstructionSelector::ComplexRendererFns
41064106
AMDGPUInstructionSelector::selectWMMAVISrc(MachineOperand &Root) const {
41074107
std::optional<FPValueAndVReg> FPValReg;
41084108
if (mi_match(Root.getReg(), *MRI, m_GFCstOrSplat(FPValReg))) {
4109-
if (TII.isInlineConstant(FPValReg->Value.bitcastToAPInt())) {
4109+
if (TII.isInlineConstant(FPValReg->Value)) {
41104110
return {{[=](MachineInstrBuilder &MIB) {
41114111
MIB.addImm(FPValReg->Value.bitcastToAPInt().getSExtValue());
41124112
}}};

llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,8 +1927,12 @@ static bool isInlineableLiteralOp16(int64_t Val, MVT VT, bool HasInv2Pi) {
19271927
return isInlinableIntLiteral(Val);
19281928
}
19291929

1930-
// f16/v2f16 operands work correctly for all values.
1931-
return AMDGPU::isInlinableLiteral16(Val, HasInv2Pi);
1930+
if (VT.getScalarType() == MVT::f16)
1931+
return AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi);
1932+
1933+
assert(VT.getScalarType() == MVT::bf16);
1934+
1935+
return AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi);
19321936
}
19331937

19341938
bool AMDGPUOperand::isInlinableImm(MVT type) const {
@@ -2277,15 +2281,26 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
22772281
return;
22782282

22792283
case AMDGPU::OPERAND_REG_IMM_INT16:
2280-
case AMDGPU::OPERAND_REG_IMM_FP16:
2281-
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
22822284
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
2283-
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
22842285
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
2286+
if (isSafeTruncation(Val, 16) &&
2287+
AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val))) {
2288+
Inst.addOperand(MCOperand::createImm(Val));
2289+
setImmKindConst();
2290+
return;
2291+
}
2292+
2293+
Inst.addOperand(MCOperand::createImm(Val & 0xffff));
2294+
setImmKindLiteral();
2295+
return;
2296+
2297+
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
2298+
case AMDGPU::OPERAND_REG_IMM_FP16:
2299+
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
22852300
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
22862301
if (isSafeTruncation(Val, 16) &&
2287-
AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
2288-
AsmParser->hasInv2PiInlineImm())) {
2302+
AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
2303+
AsmParser->hasInv2PiInlineImm())) {
22892304
Inst.addOperand(MCOperand::createImm(Val));
22902305
setImmKindConst();
22912306
return;
@@ -2296,12 +2311,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
22962311
return;
22972312

22982313
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
2314+
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: {
2315+
assert(isSafeTruncation(Val, 16));
2316+
assert(AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val)));
2317+
Inst.addOperand(MCOperand::createImm(Val));
2318+
return;
2319+
}
22992320
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
2300-
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
23012321
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
23022322
assert(isSafeTruncation(Val, 16));
2303-
assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
2304-
AsmParser->hasInv2PiInlineImm()));
2323+
assert(AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
2324+
AsmParser->hasInv2PiInlineImm()));
23052325

23062326
Inst.addOperand(MCOperand::createImm(Val));
23072327
return;
@@ -3429,7 +3449,13 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
34293449
OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16)
34303450
return AMDGPU::isInlinableLiteralV2F16(Val);
34313451

3432-
return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
3452+
if (OperandType == AMDGPU::OPERAND_REG_IMM_FP16 ||
3453+
OperandType == AMDGPU::OPERAND_REG_INLINE_C_FP16 ||
3454+
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_FP16 ||
3455+
OperandType == AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED)
3456+
return AMDGPU::isInlinableLiteralFP16(Val, hasInv2PiInlineImm());
3457+
3458+
llvm_unreachable("invalid operand type");
34333459
}
34343460
default:
34353461
llvm_unreachable("invalid operand size");

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,8 @@ void AMDGPUInstPrinter::printImmediateInt16(uint32_t Imm,
462462

463463
// This must accept a 32-bit immediate value to correctly handle packed 16-bit
464464
// operations.
465-
static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
466-
raw_ostream &O) {
465+
static bool printImmediateFP16(uint32_t Imm, const MCSubtargetInfo &STI,
466+
raw_ostream &O) {
467467
if (Imm == 0x3C00)
468468
O << "1.0";
469469
else if (Imm == 0xBC00)
@@ -488,7 +488,7 @@ static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
488488
return true;
489489
}
490490

491-
void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
491+
void AMDGPUInstPrinter::printImmediate16(uint32_t Imm, uint8_t OpType,
492492
const MCSubtargetInfo &STI,
493493
raw_ostream &O) {
494494
int16_t SImm = static_cast<int16_t>(Imm);
@@ -498,8 +498,17 @@ void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
498498
}
499499

500500
uint16_t HImm = static_cast<uint16_t>(Imm);
501-
if (printImmediateFloat16(HImm, STI, O))
502-
return;
501+
switch (OpType) {
502+
case AMDGPU::OPERAND_REG_IMM_FP16:
503+
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
504+
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
505+
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
506+
if (printImmediateFP16(HImm, STI, O))
507+
return;
508+
break;
509+
default:
510+
llvm_unreachable("bad operand type");
511+
}
503512

504513
uint64_t Imm16 = static_cast<uint16_t>(Imm);
505514
O << formatHex(Imm16);
@@ -525,7 +534,7 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
525534
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
526535
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
527536
if (isUInt<16>(Imm) &&
528-
printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
537+
printImmediateFP16(static_cast<uint16_t>(Imm), STI, O))
529538
return;
530539
break;
531540
default:
@@ -797,7 +806,7 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
797806
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
798807
case AMDGPU::OPERAND_REG_IMM_FP16:
799808
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
800-
printImmediate16(Op.getImm(), STI, O);
809+
printImmediate16(Op.getImm(), OpTy, STI, O);
801810
break;
802811
case AMDGPU::OPERAND_REG_IMM_V2INT16:
803812
case AMDGPU::OPERAND_REG_IMM_V2FP16:

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ class AMDGPUInstPrinter : public MCInstPrinter {
8686
raw_ostream &O);
8787
void printImmediateInt16(uint32_t Imm, const MCSubtargetInfo &STI,
8888
raw_ostream &O);
89-
void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
90-
raw_ostream &O);
89+
void printImmediate16(uint32_t Imm, uint8_t OpType,
90+
const MCSubtargetInfo &STI, raw_ostream &O);
9191
void printImmediateV216(uint32_t Imm, uint8_t OpType,
9292
const MCSubtargetInfo &STI, raw_ostream &O);
9393
bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12965,10 +12965,8 @@ SDValue SITargetLowering::performFPMed3ImmCombine(SelectionDAG &DAG,
1296512965

1296612966
const SIInstrInfo *TII = getSubtarget()->getInstrInfo();
1296712967

12968-
if ((!K0->hasOneUse() ||
12969-
TII->isInlineConstant(K0->getValueAPF().bitcastToAPInt())) &&
12970-
(!K1->hasOneUse() ||
12971-
TII->isInlineConstant(K1->getValueAPF().bitcastToAPInt()))) {
12968+
if ((!K0->hasOneUse() || TII->isInlineConstant(K0->getValueAPF())) &&
12969+
(!K1->hasOneUse() || TII->isInlineConstant(K1->getValueAPF()))) {
1297212970
return DAG.getNode(AMDGPUISD::FMED3, SL, K0->getValueType(0),
1297312971
Var, SDValue(K0, 0), SDValue(K1, 0));
1297412972
}
@@ -15391,16 +15389,22 @@ bool SITargetLowering::checkAsmConstraintVal(SDValue Op, StringRef Constraint,
1539115389
llvm_unreachable("Invalid asm constraint");
1539215390
}
1539315391

15394-
bool SITargetLowering::checkAsmConstraintValA(SDValue Op,
15395-
uint64_t Val,
15392+
bool SITargetLowering::checkAsmConstraintValA(SDValue Op, uint64_t Val,
1539615393
unsigned MaxSize) const {
1539715394
unsigned Size = std::min<unsigned>(Op.getScalarValueSizeInBits(), MaxSize);
1539815395
bool HasInv2Pi = Subtarget->hasInv2PiInlineImm();
15399-
if ((Size == 16 && AMDGPU::isInlinableLiteral16(Val, HasInv2Pi)) ||
15400-
(Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
15401-
(Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi))) {
15402-
return true;
15396+
if (Size == 16) {
15397+
MVT VT = Op.getSimpleValueType();
15398+
if (VT == MVT::i16 && AMDGPU::isInlinableLiteralI16(Val, HasInv2Pi))
15399+
return true;
15400+
if (VT == MVT::f16 && AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi))
15401+
return true;
15402+
if (VT == MVT::bf16 && AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi))
15403+
return true;
1540315404
}
15405+
if ((Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
15406+
(Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi)))
15407+
return true;
1540415408
return false;
1540515409
}
1540615410

llvm/lib/Target/AMDGPU/SIInstrInfo.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4121,8 +4121,27 @@ bool SIInstrInfo::isInlineConstant(const APInt &Imm) const {
41214121
ST.hasInv2PiInlineImm());
41224122
case 16:
41234123
return ST.has16BitInsts() &&
4124-
AMDGPU::isInlinableLiteral16(Imm.getSExtValue(),
4125-
ST.hasInv2PiInlineImm());
4124+
AMDGPU::isInlinableLiteralI16(Imm.getSExtValue(),
4125+
ST.hasInv2PiInlineImm());
4126+
default:
4127+
llvm_unreachable("invalid bitwidth");
4128+
}
4129+
}
4130+
4131+
bool SIInstrInfo::isInlineConstant(const APFloat &Imm) const {
4132+
APInt IntImm = Imm.bitcastToAPInt();
4133+
bool HasInv2Pi = ST.hasInv2PiInlineImm();
4134+
switch (IntImm.getBitWidth()) {
4135+
case 32:
4136+
case 64:
4137+
return isInlineConstant(IntImm);
4138+
case 16:
4139+
if (Imm.isIEEE())
4140+
return ST.has16BitInsts() &&
4141+
AMDGPU::isInlinableLiteralFP16(IntImm.getSExtValue(), HasInv2Pi);
4142+
else
4143+
return ST.has16BitInsts() &&
4144+
AMDGPU::isInlinableLiteralBF16(IntImm.getSExtValue(), HasInv2Pi);
41264145
default:
41274146
llvm_unreachable("invalid bitwidth");
41284147
}
@@ -4196,7 +4215,7 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO,
41964215
// constants in these cases
41974216
int16_t Trunc = static_cast<int16_t>(Imm);
41984217
return ST.has16BitInsts() &&
4199-
AMDGPU::isInlinableLiteral16(Trunc, ST.hasInv2PiInlineImm());
4218+
AMDGPU::isInlinableLiteralFP16(Trunc, ST.hasInv2PiInlineImm());
42004219
}
42014220

42024221
return false;

llvm/lib/Target/AMDGPU/SIInstrInfo.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -966,9 +966,7 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
966966

967967
bool isInlineConstant(const APInt &Imm) const;
968968

969-
bool isInlineConstant(const APFloat &Imm) const {
970-
return isInlineConstant(Imm.bitcastToAPInt());
971-
}
969+
bool isInlineConstant(const APFloat &Imm) const;
972970

973971
// Returns true if this non-register operand definitely does not need to be
974972
// encoded as a 32-bit literal. Note that this function handles all kinds of

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2652,13 +2652,28 @@ bool isInlinableLiteral32(int32_t Literal, bool HasInv2Pi) {
26522652
(Val == 0x3e22f983 && HasInv2Pi);
26532653
}
26542654

2655-
bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi) {
2655+
bool isInlinableLiteralI16(int16_t Literal, bool HasInv2Pi) {
2656+
if (!HasInv2Pi)
2657+
return false;
2658+
if (isInlinableIntLiteral(Literal))
2659+
return true;
2660+
return (Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(0.0f))) ||
2661+
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(1.0f))) ||
2662+
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(-1.0f))) ||
2663+
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(0.5f))) ||
2664+
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(-0.5f))) ||
2665+
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(2.0f))) ||
2666+
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(-2.0f))) ||
2667+
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(4.0f))) ||
2668+
(Literal == static_cast<int16_t>(llvm::bit_cast<uint32_t>(-4.0f))) ||
2669+
(Literal == static_cast<int16_t>(0x3e22f983));
2670+
}
2671+
2672+
bool isInlinableLiteralFP16(int16_t Literal, bool HasInv2Pi) {
26562673
if (!HasInv2Pi)
26572674
return false;
2658-
26592675
if (isInlinableIntLiteral(Literal))
26602676
return true;
2661-
26622677
uint16_t Val = static_cast<uint16_t>(Literal);
26632678
return Val == 0x3C00 || // 1.0
26642679
Val == 0xBC00 || // -1.0
@@ -2671,6 +2686,23 @@ bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi) {
26712686
Val == 0x3118; // 1/2pi
26722687
}
26732688

2689+
bool isInlinableLiteralBF16(int16_t Literal, bool HasInv2Pi) {
2690+
if (!HasInv2Pi)
2691+
return false;
2692+
if (isInlinableIntLiteral(Literal))
2693+
return true;
2694+
uint16_t Val = static_cast<uint16_t>(Literal);
2695+
return Val == 0x3F00 || // 0.5
2696+
Val == 0xBF00 || // -0.5
2697+
Val == 0x3F80 || // 1.0
2698+
Val == 0xBF80 || // -1.0
2699+
Val == 0x4000 || // 2.0
2700+
Val == 0xC000 || // -2.0
2701+
Val == 0x4080 || // 4.0
2702+
Val == 0xC080 || // -4.0
2703+
Val == 0x3E22; // 1.0 / (2.0 * pi)
2704+
}
2705+
26742706
std::optional<unsigned> getInlineEncodingV216(bool IsFloat, uint32_t Literal) {
26752707
// Unfortunately, the Instruction Set Architecture Reference Guide is
26762708
// misleading about how the inline operands work for (packed) 16-bit

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,13 @@ LLVM_READNONE
13741374
bool isInlinableLiteral32(int32_t Literal, bool HasInv2Pi);
13751375

13761376
LLVM_READNONE
1377-
bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi);
1377+
bool isInlinableLiteralFP16(int16_t Literal, bool HasInv2Pi);
1378+
1379+
LLVM_READNONE
1380+
bool isInlinableLiteralBF16(int16_t Literal, bool HasInv2Pi);
1381+
1382+
LLVM_READNONE
1383+
bool isInlinableLiteralI16(int16_t Literal, bool HasInv2Pi);
13781384

13791385
LLVM_READNONE
13801386
std::optional<unsigned> getInlineEncodingV2I16(uint32_t Literal);

0 commit comments

Comments
 (0)