Skip to content

Commit 3c93e21

Browse files
committed
[SelectionDAG] Add STRICT_BF16_TO_FP and STRICT_FP_TO_BF16
This patch adds support for `STRICT_BF16_TO_FP` and `STRICT_FP_TO_BF16`. Fixes #78540.
1 parent 2ba94bf commit 3c93e21

File tree

8 files changed

+99
-26
lines changed

8 files changed

+99
-26
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,8 @@ enum NodeType {
921921
/// has native conversions.
922922
BF16_TO_FP,
923923
FP_TO_BF16,
924+
STRICT_BF16_TO_FP,
925+
STRICT_FP_TO_BF16,
924926

925927
/// Perform various unary floating-point operations inspired by libm. For
926928
/// FPOWI, the result is undefined if the integer operand doesn't fit into

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,8 @@ END_TWO_BYTE_PACK()
698698
return false;
699699
case ISD::STRICT_FP16_TO_FP:
700700
case ISD::STRICT_FP_TO_FP16:
701+
case ISD::STRICT_BF16_TO_FP:
702+
case ISD::STRICT_FP_TO_BF16:
701703
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
702704
case ISD::STRICT_##DAGN:
703705
#include "llvm/IR/ConstrainedOps.def"

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
10331033
Node->getOperand(0).getValueType());
10341034
break;
10351035
case ISD::STRICT_FP_TO_FP16:
1036+
case ISD::STRICT_FP_TO_BF16:
10361037
case ISD::STRICT_SINT_TO_FP:
10371038
case ISD::STRICT_UINT_TO_FP:
10381039
case ISD::STRICT_LRINT:
@@ -3248,12 +3249,17 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32483249
Results.push_back(Tmp1);
32493250
break;
32503251
}
3252+
case ISD::STRICT_BF16_TO_FP:
3253+
// When it reaches here, the target chooses to do that for strictfp. Since
3254+
// we don't technically have a strict variant of the conversion, we fall back
3255+
// to the non-strict one.
3256+
LLVM_FALLTHROUGH;
32513257
case ISD::BF16_TO_FP: {
32523258
// Always expand bf16 to f32 casts, they lower to ext + shift.
32533259
//
32543260
// Note that the operand of this code can be bf16 or an integer type in case
32553261
// bf16 is not supported on the target and was softened.
3256-
SDValue Op = Node->getOperand(0);
3262+
SDValue Op = Node->getOperand(Node->getOpcode() == ISD::BF16_TO_FP ? 0 : 1);
32573263
if (Op.getValueType() == MVT::bf16) {
32583264
Op = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32,
32593265
DAG.getNode(ISD::BITCAST, dl, MVT::i16, Op));
@@ -3271,10 +3277,17 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32713277
Results.push_back(Op);
32723278
break;
32733279
}
3280+
case ISD::STRICT_FP_TO_BF16:
3281+
// When it reaches here, the target chooses to do that for strictfp. Since
3282+
// we don't technically have strict variant of the conversion, we falls back
3283+
// to the non-strict one.
3284+
LLVM_FALLTHROUGH;
32743285
case ISD::FP_TO_BF16: {
3275-
SDValue Op = Node->getOperand(0);
3286+
bool IsStrictFP = Node->getOpcode() == ISD::STRICT_FP_TO_BF16;
3287+
SDValue Op = Node->getOperand(IsStrictFP ? 1 : 0);
32763288
if (Op.getValueType() != MVT::f32)
3277-
Op = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, Op,
3289+
Op = DAG.getNode(IsStrictFP ? ISD::STRICT_FP_ROUND : ISD::FP_ROUND, dl,
3290+
MVT::f32, Op,
32783291
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
32793292
Op = DAG.getNode(
32803293
ISD::SRL, dl, MVT::i32, DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op),
@@ -4773,12 +4786,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
47734786
break;
47744787
}
47754788
case ISD::STRICT_FP_EXTEND:
4776-
case ISD::STRICT_FP_TO_FP16: {
4777-
RTLIB::Libcall LC =
4778-
Node->getOpcode() == ISD::STRICT_FP_TO_FP16
4779-
? RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16)
4780-
: RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
4781-
Node->getValueType(0));
4789+
case ISD::STRICT_FP_TO_FP16:
4790+
case ISD::STRICT_FP_TO_BF16: {
4791+
RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
4792+
if (Node->getOpcode() == ISD::STRICT_FP_TO_FP16)
4793+
LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16);
4794+
else if (Node->getOpcode() == ISD::STRICT_FP_TO_BF16)
4795+
LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::bf16);
4796+
else
4797+
LC = RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
4798+
Node->getValueType(0));
4799+
47824800
assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to legalize as libcall");
47834801

47844802
TargetLowering::MakeLibCallOptions CallOptions;

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
918918
case ISD::STRICT_FP_TO_FP16:
919919
case ISD::FP_TO_FP16: // Same as FP_ROUND for softening purposes
920920
case ISD::FP_TO_BF16:
921+
case ISD::STRICT_FP_TO_BF16:
921922
case ISD::STRICT_FP_ROUND:
922923
case ISD::FP_ROUND: Res = SoftenFloatOp_FP_ROUND(N); break;
923924
case ISD::STRICT_FP_TO_SINT:
@@ -2193,13 +2194,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
21932194
if (RetVT == MVT::f16)
21942195
return ISD::STRICT_FP_TO_FP16;
21952196

2196-
if (OpVT == MVT::bf16) {
2197-
// TODO: return ISD::STRICT_BF16_TO_FP;
2198-
}
2197+
if (OpVT == MVT::bf16)
2198+
return ISD::STRICT_BF16_TO_FP;
21992199

2200-
if (RetVT == MVT::bf16) {
2201-
// TODO: return ISD::STRICT_FP_TO_BF16;
2202-
}
2200+
if (RetVT == MVT::bf16)
2201+
return ISD::STRICT_FP_TO_BF16;
22032202

22042203
report_fatal_error("Attempt at an invalid promotion-related conversion");
22052204
}
@@ -2999,10 +2998,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
29992998
EVT SVT = N->getOperand(0).getValueType();
30002999

30013000
if (N->isStrictFPOpcode()) {
3002-
assert(RVT == MVT::f16);
3003-
SDValue Res =
3004-
DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
3005-
{N->getOperand(0), N->getOperand(1)});
3001+
// FIXME: assume we only have two f16 variants for now.
3002+
unsigned Opcode;
3003+
if (RVT == MVT::f16)
3004+
Opcode = ISD::STRICT_FP_TO_FP16;
3005+
else if (RVT == MVT::bf16)
3006+
Opcode = ISD::STRICT_FP_TO_BF16;
3007+
else
3008+
llvm_unreachable("unknown half type");
3009+
SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
3010+
{N->getOperand(0), N->getOperand(1)});
30063011
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
30073012
return Res;
30083013
}

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
165165
case ISD::FP_TO_FP16:
166166
Res = PromoteIntRes_FP_TO_FP16_BF16(N);
167167
break;
168+
case ISD::STRICT_FP_TO_BF16:
168169
case ISD::STRICT_FP_TO_FP16:
169170
Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
170171
break;

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
379379
case ISD::FP_TO_FP16: return "fp_to_fp16";
380380
case ISD::STRICT_FP_TO_FP16: return "strict_fp_to_fp16";
381381
case ISD::BF16_TO_FP: return "bf16_to_fp";
382+
case ISD::STRICT_BF16_TO_FP: return "strict_bf16_to_fp";
382383
case ISD::FP_TO_BF16: return "fp_to_bf16";
384+
case ISD::STRICT_FP_TO_BF16: return "strict_fp_to_bf16";
383385
case ISD::LROUND: return "lround";
384386
case ISD::STRICT_LROUND: return "strict_lround";
385387
case ISD::LLROUND: return "llround";

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
539539
setOperationAction({ISD::FSIN, ISD::FCOS, ISD::FDIV}, MVT::f32, Custom);
540540
setOperationAction(ISD::FDIV, MVT::f64, Custom);
541541

542-
setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
543-
setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
542+
setOperationAction({ISD::BF16_TO_FP, ISD::STRICT_BF16_TO_FP},
543+
{MVT::i16, MVT::f32, MVT::f64}, Expand);
544+
setOperationAction({ISD::FP_TO_BF16, ISD::STRICT_FP_TO_BF16},
545+
{MVT::i16, MVT::f32, MVT::f64}, Expand);
544546

545547
// Custom lower these because we can't specify a rule based on an illegal
546548
// source bf16.

llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,11 +1094,52 @@ define <4 x i1> @isnan_v4bf16(<4 x bfloat> %x) nounwind {
10941094
ret <4 x i1> %1
10951095
}
10961096

1097-
; FIXME: Broken for gfx6/7
1098-
; define i1 @isnan_bf16_strictfp(bfloat %x) strictfp nounwind {
1099-
; %1 = call i1 @llvm.is.fpclass.bf16(bfloat %x, i32 3) strictfp ; nan
1100-
; ret i1 %1
1101-
; }
1097+
define i1 @isnan_bf16_strictfp(bfloat %x) strictfp nounwind {
1098+
; GFX7CHECK-LABEL: isnan_bf16_strictfp:
1099+
; GFX7CHECK: ; %bb.0:
1100+
; GFX7CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1101+
; GFX7CHECK-NEXT: v_bfe_u32 v0, v0, 16, 15
1102+
; GFX7CHECK-NEXT: s_movk_i32 s4, 0x7f80
1103+
; GFX7CHECK-NEXT: v_cmp_lt_i32_e32 vcc, s4, v0
1104+
; GFX7CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc
1105+
; GFX7CHECK-NEXT: s_setpc_b64 s[30:31]
1106+
;
1107+
; GFX8CHECK-LABEL: isnan_bf16_strictfp:
1108+
; GFX8CHECK: ; %bb.0:
1109+
; GFX8CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1110+
; GFX8CHECK-NEXT: v_and_b32_e32 v0, 0x7fff, v0
1111+
; GFX8CHECK-NEXT: s_movk_i32 s4, 0x7f80
1112+
; GFX8CHECK-NEXT: v_cmp_lt_i16_e32 vcc, s4, v0
1113+
; GFX8CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc
1114+
; GFX8CHECK-NEXT: s_setpc_b64 s[30:31]
1115+
;
1116+
; GFX9CHECK-LABEL: isnan_bf16_strictfp:
1117+
; GFX9CHECK: ; %bb.0:
1118+
; GFX9CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1119+
; GFX9CHECK-NEXT: v_and_b32_e32 v0, 0x7fff, v0
1120+
; GFX9CHECK-NEXT: s_movk_i32 s4, 0x7f80
1121+
; GFX9CHECK-NEXT: v_cmp_lt_i16_e32 vcc, s4, v0
1122+
; GFX9CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc
1123+
; GFX9CHECK-NEXT: s_setpc_b64 s[30:31]
1124+
;
1125+
; GFX10CHECK-LABEL: isnan_bf16_strictfp:
1126+
; GFX10CHECK: ; %bb.0:
1127+
; GFX10CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1128+
; GFX10CHECK-NEXT: v_and_b32_e32 v0, 0x7fff, v0
1129+
; GFX10CHECK-NEXT: v_cmp_lt_i16_e32 vcc_lo, 0x7f80, v0
1130+
; GFX10CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc_lo
1131+
; GFX10CHECK-NEXT: s_setpc_b64 s[30:31]
1132+
;
1133+
; GFX11CHECK-LABEL: isnan_bf16_strictfp:
1134+
; GFX11CHECK: ; %bb.0:
1135+
; GFX11CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1136+
; GFX11CHECK-NEXT: v_and_b32_e32 v0, 0x7fff, v0
1137+
; GFX11CHECK-NEXT: v_cmp_lt_i16_e32 vcc_lo, 0x7f80, v0
1138+
; GFX11CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc_lo
1139+
; GFX11CHECK-NEXT: s_setpc_b64 s[30:31]
1140+
%1 = call i1 @llvm.is.fpclass.bf16(bfloat %x, i32 3) strictfp ; nan
1141+
ret i1 %1
1142+
}
11021143

11031144
define i1 @isinf_bf16(bfloat %x) nounwind {
11041145
; GFX7CHECK-LABEL: isinf_bf16:

0 commit comments

Comments
 (0)