
Commit ddd7433
[AMDGPU] Replace isInlinableLiteral16 with specific version
The current implementation of `isInlinableLiteral16` assumes that a 16-bit inlinable literal is either an `i16` or an `fp16`. This is not always true because of `bf16`, and we can't tell `fp16` and `bf16` apart just by looking at the value. This patch splits `isInlinableLiteral16` into three versions (`i16`, `fp16`, and `bf16`) and calls the corresponding version at each use site.
1 parent c54e052 commit ddd7433


47 files changed: +1180 -1044 lines
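Why the split is needed: the same 16-bit pattern decodes to different values under IEEE-half and bfloat16 semantics, so no single isInlinableLiteral16(Val) check can be right for both formats. The snippet below is a standalone illustration of that point using only APFloat; it is not part of this commit.

// Standalone illustration (not from this commit): the same 16-bit pattern
// means different things under fp16 and bf16 semantics.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

static double decode16(uint64_t Bits, const fltSemantics &Sem) {
  APFloat Val(Sem, APInt(16, Bits));
  bool LosesInfo = false;
  // Widen to double only for printing; both 16-bit formats convert losslessly.
  Val.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &LosesInfo);
  return Val.convertToDouble();
}

int main() {
  // 0x3F80 is 1.0 as bf16 but 1.875 as fp16.
  outs() << format("%g vs %g\n", decode16(0x3F80, APFloat::IEEEhalf()),
                   decode16(0x3F80, APFloat::BFloat()));
  // 0x3C00 is 1.0 as fp16 but 2^-7 (0.0078125) as bf16.
  outs() << format("%g vs %g\n", decode16(0x3C00, APFloat::IEEEhalf()),
                   decode16(0x3C00, APFloat::BFloat()));
  return 0;
}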

llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp

Lines changed: 29 additions & 23 deletions

@@ -3327,35 +3327,41 @@ bool AMDGPUDAGToDAGISel::SelectWMMAVISrc(SDValue In, SDValue &Src) const {
 
   // 16 bit splat
   SDValue SplatSrc32 = stripBitcast(In);
-  if (auto *SplatSrc32BV = dyn_cast<BuildVectorSDNode>(SplatSrc32)) {
+  if (auto *SplatSrc32BV = dyn_cast<BuildVectorSDNode>(SplatSrc32))
     if (SDValue Splat32 = SplatSrc32BV->getSplatValue()) {
       SDValue SplatSrc16 = stripBitcast(Splat32);
-      if (auto *SplatSrc16BV = dyn_cast<BuildVectorSDNode>(SplatSrc16)) {
+      if (auto *SplatSrc16BV = dyn_cast<BuildVectorSDNode>(SplatSrc16))
         if (SDValue Splat = SplatSrc16BV->getSplatValue()) {
-
-          // f16
-          if (isInlineImmediate(Splat.getNode())) {
-            const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat);
-            int64_t Imm = C->getValueAPF().bitcastToAPInt().getSExtValue();
-            Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i16);
-            return true;
-          }
-
-          // bf16
-          if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat)) {
-            const SIInstrInfo *TII = Subtarget->getInstrInfo();
-            APInt BF16Value = C->getAPIntValue();
-            APInt F32Value = BF16Value.zext(32).shl(16);
-            if (TII->isInlineConstant(F32Value)) {
-              int64_t Imm = F32Value.getSExtValue();
-              Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32);
-              return true;
-            }
+          const SIInstrInfo *TII = Subtarget->getInstrInfo();
+          std::optional<APInt> RawValue;
+          if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat))
+            RawValue = C->getValueAPF().bitcastToAPInt();
+          else if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat))
+            RawValue = C->getAPIntValue();
+
+          if (RawValue.has_value()) {
+            EVT VT = In.getValueType().getScalarType();
+            if (VT.getSimpleVT() == MVT::f16 || VT.getSimpleVT() == MVT::bf16) {
+              APFloat FloatVal(VT.getSimpleVT() == MVT::f16
+                                   ? APFloatBase::IEEEhalf()
+                                   : APFloatBase::BFloat(),
+                               RawValue.value());
+              if (TII->isInlineConstant(FloatVal)) {
+                Src = CurDAG->getTargetConstant(RawValue.value(), SDLoc(In),
+                                                MVT::i16);
+                return true;
+              }
+            } else if (VT.getSimpleVT() == MVT::i16) {
+              if (TII->isInlineConstant(RawValue.value())) {
+                Src = CurDAG->getTargetConstant(RawValue.value(), SDLoc(In),
+                                                MVT::i16);
+                return true;
+              }
+            } else
+              llvm_unreachable("unknown 16-bit type");
           }
         }
-      }
     }
-  }
 
   return false;
 }

llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp

Lines changed: 75 additions & 22 deletions

@@ -1926,6 +1926,11 @@ static const fltSemantics *getFltSemantics(MVT VT) {
 
 static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
   switch (OperandType) {
+  // When floating-point immediate is used as operand of type i16, the 32-bit
+  // representation of the constant truncated to the 16 LSBs should be used.
+  case AMDGPU::OPERAND_REG_IMM_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
   case AMDGPU::OPERAND_REG_IMM_INT32:
   case AMDGPU::OPERAND_REG_IMM_FP32:
   case AMDGPU::OPERAND_REG_IMM_FP32_DEFERRED:
@@ -1949,13 +1954,10 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
   case AMDGPU::OPERAND_REG_INLINE_C_FP64:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP64:
     return &APFloat::IEEEdouble();
-  case AMDGPU::OPERAND_REG_IMM_INT16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
-  case AMDGPU::OPERAND_REG_INLINE_C_INT16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
-  case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
@@ -2001,13 +2003,15 @@ static bool isSafeTruncation(int64_t Val, unsigned Size) {
 }
 
 static bool isInlineableLiteralOp16(int64_t Val, MVT VT, bool HasInv2Pi) {
-  if (VT.getScalarType() == MVT::i16) {
-    // FP immediate values are broken.
-    return isInlinableIntLiteral(Val);
-  }
+  if (VT.getScalarType() == MVT::i16)
+    return isInlinableLiteral32(Val, HasInv2Pi);
+
+  if (VT.getScalarType() == MVT::f16)
+    return AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi);
 
-  // f16/v2f16 operands work correctly for all values.
-  return AMDGPU::isInlinableLiteral16(Val, HasInv2Pi);
+  assert(VT.getScalarType() == MVT::bf16);
+
+  return AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi);
 }
 
 bool AMDGPUOperand::isInlinableImm(MVT type) const {
@@ -2041,9 +2045,30 @@ bool AMDGPUOperand::isInlinableImm(MVT type) const {
     return false;
 
   if (type.getScalarSizeInBits() == 16) {
-    return isInlineableLiteralOp16(
-        static_cast<int16_t>(FPLiteral.bitcastToAPInt().getZExtValue()),
-        type, AsmParser->hasInv2PiInlineImm());
+    bool Lost = false;
+    switch (type.getScalarType().SimpleTy) {
+    default:
+      llvm_unreachable("unknown 16-bit type");
+    case MVT::bf16:
+      FPLiteral.convert(APFloatBase::BFloat(), APFloat::rmNearestTiesToEven,
+                        &Lost);
+      break;
+    case MVT::f16:
+      FPLiteral.convert(APFloatBase::IEEEhalf(), APFloat::rmNearestTiesToEven,
+                        &Lost);
+      break;
+    case MVT::i16:
+      FPLiteral.convert(APFloatBase::IEEEsingle(),
+                        APFloat::rmNearestTiesToEven, &Lost);
+      break;
+    }
+    // We need to use 32-bit representation here because when a floating-point
+    // inline constant is used as an i16 operand, its 32-bit representation
+    // will be used. We will need the 32-bit value to check if it is FP inline
+    // constant.
+    uint32_t ImmVal = FPLiteral.bitcastToAPInt().getZExtValue();
+    return isInlineableLiteralOp16(ImmVal, type,
+                                   AsmParser->hasInv2PiInlineImm());
   }
 
   // Check if single precision literal is inlinable
@@ -2375,15 +2400,26 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_IMM_INT16:
-  case AMDGPU::OPERAND_REG_IMM_FP16:
-  case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
-  case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+    if (isSafeTruncation(Val, 16) &&
+        AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val))) {
+      Inst.addOperand(MCOperand::createImm(Val & 0xffffffff));
+      setImmKindConst();
+      return;
+    }
+
+    Inst.addOperand(MCOperand::createImm(Val & 0xffff));
+    setImmKindLiteral();
+    return;
+
+  case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     if (isSafeTruncation(Val, 16) &&
-        AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
-                                     AsmParser->hasInv2PiInlineImm())) {
+        AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
+                                       AsmParser->hasInv2PiInlineImm())) {
       Inst.addOperand(MCOperand::createImm(Val));
       setImmKindConst();
       return;
@@ -2410,12 +2446,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: {
+    assert(isSafeTruncation(Val, 16));
+    assert(AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val)));
+    Inst.addOperand(MCOperand::createImm(Val));
+    return;
+  }
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
-  case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
     assert(isSafeTruncation(Val, 16));
-    assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
-                                        AsmParser->hasInv2PiInlineImm()));
+    assert(AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
+                                          AsmParser->hasInv2PiInlineImm()));
 
     Inst.addOperand(MCOperand::createImm(Val));
     return;
@@ -3542,7 +3583,7 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
     if (OperandType == AMDGPU::OPERAND_REG_IMM_INT16 ||
        OperandType == AMDGPU::OPERAND_REG_INLINE_C_INT16 ||
        OperandType == AMDGPU::OPERAND_REG_INLINE_AC_INT16)
-      return AMDGPU::isInlinableIntLiteral(Val);
+      return AMDGPU::isInlinableLiteralI16(Val, hasInv2PiInlineImm());
 
     if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2INT16 ||
        OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2INT16 ||
@@ -3559,7 +3600,19 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
        OperandType == AMDGPU::OPERAND_REG_IMM_V2BF16)
       return AMDGPU::isInlinableLiteralV2BF16(Val);
 
-    return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
+    if (OperandType == AMDGPU::OPERAND_REG_IMM_FP16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_C_FP16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_AC_FP16 ||
+        OperandType == AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED)
+      return AMDGPU::isInlinableLiteralFP16(Val, hasInv2PiInlineImm());
+
+    if (OperandType == AMDGPU::OPERAND_REG_IMM_BF16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_C_BF16 ||
+        OperandType == AMDGPU::OPERAND_REG_INLINE_AC_BF16 ||
+        OperandType == AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED)
+      return AMDGPU::isInlinableLiteralBF16(Val, hasInv2PiInlineImm());
+
+    llvm_unreachable("invalid operand type");
   }
   default:
     llvm_unreachable("invalid operand size");
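As a companion to the parser changes above, here is a hedged sketch of the conversion step they implement: the parsed literal is rounded to the semantics implied by the operand type (bf16, fp16, or single precision for i16 operands, whose FP immediates use the 32-bit representation), and the resulting raw bits are what the per-type inlinability helpers inspect. The Op16Kind enum and the helper name below are illustrative assumptions, not AMDGPU APIs; the real logic lives in AMDGPUOperand::isInlinableImm and isInlineableLiteralOp16 above.

// Illustrative sketch only; not part of this commit.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

enum class Op16Kind { I16, FP16, BF16 }; // hypothetical stand-in for the MVT check

// Round the parsed double-precision literal to the semantics implied by the
// operand type, then return the raw bits handed to the per-type check.
static uint64_t literalBitsFor16BitOperand(APFloat FPLiteral, Op16Kind Kind) {
  bool Lost = false;
  const fltSemantics &Sem = Kind == Op16Kind::BF16   ? APFloat::BFloat()
                            : Kind == Op16Kind::FP16 ? APFloat::IEEEhalf()
                                                     : APFloat::IEEEsingle();
  FPLiteral.convert(Sem, APFloat::rmNearestTiesToEven, &Lost);
  return FPLiteral.bitcastToAPInt().getZExtValue();
}

int main() {
  APFloat One(1.0);
  // 1.0 becomes 0x3f800000 for an i16 operand, 0x3c00 for f16, 0x3f80 for bf16.
  outs().write_hex(literalBitsFor16BitOperand(One, Op16Kind::I16)) << " ";
  outs().write_hex(literalBitsFor16BitOperand(One, Op16Kind::FP16)) << " ";
  outs().write_hex(literalBitsFor16BitOperand(One, Op16Kind::BF16)) << "\n";
  return 0;
}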

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp

Lines changed: 15 additions & 14 deletions

@@ -451,19 +451,20 @@ void AMDGPUInstPrinter::printVINTRPDst(const MCInst *MI, unsigned OpNo,
 void AMDGPUInstPrinter::printImmediateInt16(uint32_t Imm,
                                             const MCSubtargetInfo &STI,
                                             raw_ostream &O) {
-  int16_t SImm = static_cast<int16_t>(Imm);
+  int32_t SImm = static_cast<int32_t>(Imm);
   if (isInlinableIntLiteral(SImm)) {
     O << SImm;
-  } else {
-    uint64_t Imm16 = static_cast<uint16_t>(Imm);
-    O << formatHex(Imm16);
+    return;
   }
+
+  if (printImmediateFloat32(Imm, STI, O))
+    return;
+
+  O << formatHex(static_cast<uint64_t>(Imm & 0xffff));
 }
 
-// This must accept a 32-bit immediate value to correctly handle packed 16-bit
-// operations.
-static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
-                                  raw_ostream &O) {
+static bool printImmediateFP16(uint32_t Imm, const MCSubtargetInfo &STI,
+                               raw_ostream &O) {
   if (Imm == 0x3C00)
     O << "1.0";
   else if (Imm == 0xBC00)
@@ -529,17 +530,17 @@ void AMDGPUInstPrinter::printImmediateBF16(uint32_t Imm,
   O << formatHex(static_cast<uint64_t>(Imm));
 }
 
-void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
-                                         const MCSubtargetInfo &STI,
-                                         raw_ostream &O) {
+void AMDGPUInstPrinter::printImmediateF16(uint32_t Imm,
+                                          const MCSubtargetInfo &STI,
+                                          raw_ostream &O) {
   int16_t SImm = static_cast<int16_t>(Imm);
   if (isInlinableIntLiteral(SImm)) {
     O << SImm;
     return;
   }
 
   uint16_t HImm = static_cast<uint16_t>(Imm);
-  if (printImmediateFloat16(HImm, STI, O))
+  if (printImmediateFP16(HImm, STI, O))
     return;
 
   uint64_t Imm16 = static_cast<uint16_t>(Imm);
@@ -566,7 +567,7 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     if (isUInt<16>(Imm) &&
-        printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
+        printImmediateFP16(static_cast<uint16_t>(Imm), STI, O))
       return;
     break;
   case AMDGPU::OPERAND_REG_IMM_V2BF16:
@@ -845,7 +846,7 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
-    printImmediate16(Op.getImm(), STI, O);
+    printImmediateF16(Op.getImm(), STI, O);
     break;
   case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_BF16:

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h

Lines changed: 2 additions & 2 deletions

@@ -86,10 +86,10 @@ class AMDGPUInstPrinter : public MCInstPrinter {
                          raw_ostream &O);
   void printImmediateInt16(uint32_t Imm, const MCSubtargetInfo &STI,
                            raw_ostream &O);
-  void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
-                        raw_ostream &O);
   void printImmediateBF16(uint32_t Imm, const MCSubtargetInfo &STI,
                           raw_ostream &O);
+  void printImmediateF16(uint32_t Imm, const MCSubtargetInfo &STI,
+                         raw_ostream &O);
   void printImmediateV216(uint32_t Imm, uint8_t OpType,
                           const MCSubtargetInfo &STI, raw_ostream &O);
   bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp

Lines changed: 5 additions & 6 deletions

@@ -116,11 +116,6 @@ static uint32_t getIntInlineImmEncoding(IntTy Imm) {
   return 0;
 }
 
-static uint32_t getLit16IntEncoding(uint16_t Val, const MCSubtargetInfo &STI) {
-  uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val));
-  return IntImm == 0 ? 255 : IntImm;
-}
-
 static uint32_t getLit16Encoding(uint16_t Val, const MCSubtargetInfo &STI) {
   uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val));
   if (IntImm != 0)
@@ -214,6 +209,10 @@ static uint32_t getLit32Encoding(uint32_t Val, const MCSubtargetInfo &STI) {
   return 255;
 }
 
+static uint32_t getLit16IntEncoding(uint32_t Val, const MCSubtargetInfo &STI) {
+  return getLit32Encoding(Val, STI);
+}
+
 static uint32_t getLit64Encoding(uint64_t Val, const MCSubtargetInfo &STI) {
   uint32_t IntImm = getIntInlineImmEncoding(static_cast<int64_t>(Val));
   if (IntImm != 0)
@@ -296,7 +295,7 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_IMM_INT16:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
-    return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
+    return getLit16IntEncoding(static_cast<uint32_t>(Imm), STI);
 
   case AMDGPU::OPERAND_REG_IMM_FP16:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 22 additions & 6 deletions

@@ -15495,16 +15495,32 @@ bool SITargetLowering::checkAsmConstraintVal(SDValue Op, StringRef Constraint,
   llvm_unreachable("Invalid asm constraint");
 }
 
-bool SITargetLowering::checkAsmConstraintValA(SDValue Op,
-                                              uint64_t Val,
+bool SITargetLowering::checkAsmConstraintValA(SDValue Op, uint64_t Val,
                                               unsigned MaxSize) const {
   unsigned Size = std::min<unsigned>(Op.getScalarValueSizeInBits(), MaxSize);
   bool HasInv2Pi = Subtarget->hasInv2PiInlineImm();
-  if ((Size == 16 && AMDGPU::isInlinableLiteral16(Val, HasInv2Pi)) ||
-      (Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
-      (Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi))) {
-    return true;
+  if (Size == 16) {
+    MVT VT = Op.getSimpleValueType();
+    switch (VT.SimpleTy) {
+    default:
+      return false;
+    case MVT::i16:
+      return AMDGPU::isInlinableLiteralI16(Val, HasInv2Pi);
+    case MVT::f16:
+      return AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi);
+    case MVT::bf16:
+      return AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi);
+    case MVT::v2i16:
+      return AMDGPU::getInlineEncodingV2I16(Val).has_value();
+    case MVT::v2f16:
+      return AMDGPU::getInlineEncodingV2F16(Val).has_value();
+    case MVT::v2bf16:
+      return AMDGPU::getInlineEncodingV2BF16(Val).has_value();
+    }
   }
+  if ((Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
+      (Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi)))
+    return true;
   return false;
 }