Commit f47034c

changpeng and arsenm authored
AMDGPU: Add round-to-odd rounding during f64 to bf16 conversion (#133995)
f64 -> bf16 conversion can be lowered to f64 -> f32 followed by f32 -> bf16:

  v_cvt_f32_f64_e32 v0, v[0:1]
  v_cvt_pk_bf16_f32 v0, v0, s0

Both conversion instructions round to nearest even, so the two-step lowering double-rounds and can produce an incorrectly rounded result for some inputs. We need to add round-to-odd rounding during the f64 -> f32 step to avoid the double rounding.

NOTE: f64 -> f16 conversion has the same issue. Round-to-odd rounding for it will be added in a separate patch, which fixes SWDEV-523856.

---------

Co-authored-by: Matt Arsenault <[email protected]>
1 parent: a07b374
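To make the double-rounding hazard concrete: pick a double just above the midpoint between two adjacent bf16 values, so the correctly rounded f64 -> bf16 result is the upper neighbor (0x3F81), yet two successive round-to-nearest-even steps produce the lower one (0x3F80). A minimal standalone sketch (not from the patch; the sample value and the bit-level bf16 rounding helper are illustrative of what the old two-step lowering computed):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // 1.00390625 (f32 bits 0x3F808000) is exactly halfway between the
  // bf16 values 1.0 (0x3F80) and 1.0078125 (0x3F81). x sits just above
  // it, so a correctly rounded f64 -> bf16 conversion must round up.
  double x = 1.00390625 + 0x1p-30;

  // Step 1: f64 -> f32, round-to-nearest-even. The 2^-30 excess is less
  // than half an f32 ulp (2^-24 here), so this lands exactly on the tie
  // point 0x3F808000.
  float once = static_cast<float>(x);
  uint32_t fbits;
  std::memcpy(&fbits, &once, sizeof(fbits));

  // Step 2: f32 -> bf16, round-to-nearest-even on the top 16 bits (the
  // rounding v_cvt_pk_bf16_f32 performs).
  uint16_t bf =
      static_cast<uint16_t>((fbits + 0x7FFFu + ((fbits >> 16) & 1)) >> 16);

  std::printf("two-step result: 0x%04X, correct result: 0x3F81\n", bf);
  return 0;
}

The first narrowing lands exactly on the bf16 tie point and discards the fact that x was strictly above it; the second rounding then resolves the apparent tie to even, in the wrong direction, printing 0x3F80.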

3 files changed: +89 -24 lines


Diff for: llvm/lib/Target/AMDGPU/SIISelLowering.cpp

+23 -13

@@ -911,8 +911,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::MUL, MVT::i1, Promote);
 
   if (Subtarget->hasBF16ConversionInsts()) {
-    setOperationAction(ISD::FP_ROUND, MVT::v2bf16, Legal);
-    setOperationAction(ISD::FP_ROUND, MVT::bf16, Legal);
+    setOperationAction(ISD::FP_ROUND, {MVT::bf16, MVT::v2bf16}, Custom);
     setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Legal);
   }
 
@@ -6888,23 +6887,34 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op,
 }
 
 SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
-  assert(Op.getValueType() == MVT::f16 &&
-         "Do not know how to custom lower FP_ROUND for non-f16 type");
-
   SDValue Src = Op.getOperand(0);
   EVT SrcVT = Src.getValueType();
-  if (SrcVT != MVT::f64)
-    return Op;
-
-  // TODO: Handle strictfp
-  if (Op.getOpcode() != ISD::FP_ROUND)
+  if (SrcVT.getScalarType() != MVT::f64)
     return Op;
 
+  EVT DstVT = Op.getValueType();
   SDLoc DL(Op);
+  if (DstVT == MVT::f16) {
+    // TODO: Handle strictfp
+    if (Op.getOpcode() != ISD::FP_ROUND)
+      return Op;
+
+    SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
+    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
+    return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+  }
+
+  assert(DstVT.getScalarType() == MVT::bf16 &&
+         "custom lower FP_ROUND for f16 or bf16");
+  assert(Subtarget->hasBF16ConversionInsts() && "f32 -> bf16 is legal");
 
-  SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
-  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
-  return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+  // Round-inexact-to-odd f64 to f32, then do the final rounding using the
+  // hardware f32 -> bf16 instruction.
+  EVT F32VT = SrcVT.isVector() ? SrcVT.changeVectorElementType(MVT::f32) :
+                                 MVT::f32;
+  SDValue Rod = expandRoundInexactToOdd(F32VT, Src, DL, DAG);
+  return DAG.getNode(ISD::FP_ROUND, DL, DstVT, Rod,
+                     DAG.getTargetConstant(0, DL, MVT::i32));
 }
 
 SDValue SITargetLowering::lowerFMINNUM_FMAXNUM(SDValue Op,
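expandRoundInexactToOdd is the generic TargetLowering helper that emits the compare-and-nudge sequence visible in the updated tests below: narrow with the hardware round-to-nearest-even convert, widen back, and if the narrowing was inexact and landed on an even significand, step the bit pattern one ulp toward the source so the low bit ends up set. A rough scalar model (an illustrative sketch, not LLVM code; it assumes the magnitude/sign split used by the emitted sequence):

#include <cmath>
#include <cstdint>
#include <cstring>

// Illustrative scalar model of round-inexact-to-odd f64 -> f32 narrowing.
static float F64ToF32RoundOdd(double X) {
  double AbsX = std::fabs(X);          // operate on the magnitude
  float R = static_cast<float>(AbsX);  // hardware narrowing, nearest-even
  uint32_t Bits;
  std::memcpy(&Bits, &R, sizeof(Bits));
  // Keep R if the narrowing was exact, the input is NaN, or the result is
  // already odd. Otherwise nudge the bit pattern one ulp toward AbsX,
  // which forces the significand LSB to 1 (the "jam" of round-to-odd).
  if (static_cast<double>(R) != AbsX && AbsX == AbsX && (Bits & 1) == 0) {
    Bits += (static_cast<double>(R) < AbsX) ? 1u : -1u;
    std::memcpy(&R, &Bits, sizeof(R));
  }
  return std::copysign(R, static_cast<float>(X)); // reattach the sign
}

An inexact narrowing therefore always yields an odd f32 bit pattern, which can never coincide with a bf16 tie point (tie points have bit 15 set and all lower bits zero), so the final nearest-even v_cvt_pk_bf16_f32 step rounds exactly as a single correctly rounded f64 -> bf16 conversion would.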

Diff for: llvm/lib/Target/AMDGPU/VOP3Instructions.td

-5

@@ -1443,16 +1443,11 @@ let SubtargetPredicate = HasBF16ConversionInsts in {
 }
 def : GCNPat<(v2bf16 (bf16_fpround v2f32:$src)),
              (V_CVT_PK_BF16_F32_e64 0, (EXTRACT_SUBREG VReg_64:$src, sub0), 0, (EXTRACT_SUBREG VReg_64:$src, sub1))>;
-def : GCNPat<(v2bf16 (bf16_fpround v2f64:$src)),
-             (V_CVT_PK_BF16_F32_e64 0, (V_CVT_F32_F64_e64 0, (EXTRACT_SUBREG VReg_128:$src, sub0_sub1)),
-                                    0, (V_CVT_F32_F64_e64 0, (EXTRACT_SUBREG VReg_128:$src, sub2_sub3)))>;
 def : GCNPat<(v2bf16 (build_vector (bf16 (bf16_fpround (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
                                    (bf16 (bf16_fpround (f32 (VOP3Mods f32:$src1, i32:$src1_modifiers)))))),
              (V_CVT_PK_BF16_F32_e64 $src0_modifiers, $src0, $src1_modifiers, $src1)>;
 def : GCNPat<(bf16 (bf16_fpround (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
              (V_CVT_PK_BF16_F32_e64 $src0_modifiers, $src0, 0, (f32 (IMPLICIT_DEF)))>;
-def : GCNPat<(bf16 (bf16_fpround (f64 (VOP3Mods f64:$src0, i32:$src0_modifiers)))),
-             (V_CVT_PK_BF16_F32_e64 0, (f32 (V_CVT_F32_F64_e64 $src0_modifiers, $src0)), 0, (f32 (IMPLICIT_DEF)))>;
 }
 
 class Cvt_Scale_Sr_F32ToBF16F16_Pat<SDPatternOperator node, VOP3_Pseudo inst, ValueType DstTy> : GCNPat<

Diff for: llvm/test/CodeGen/AMDGPU/bf16-conversions.ll

+66 -6

@@ -153,9 +153,34 @@ define amdgpu_ps float @v_test_cvt_v2f64_v2bf16_v(<2 x double> %src) {
 ;
 ; GFX-950-LABEL: v_test_cvt_v2f64_v2bf16_v:
 ; GFX-950: ; %bb.0:
-; GFX-950-NEXT: v_cvt_f32_f64_e32 v2, v[2:3]
-; GFX-950-NEXT: v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, v2
+; GFX-950-NEXT: v_mov_b32_e32 v4, v3
+; GFX-950-NEXT: v_and_b32_e32 v3, 0x7fffffff, v4
+; GFX-950-NEXT: v_mov_b32_e32 v5, v1
+; GFX-950-NEXT: v_cvt_f32_f64_e32 v1, v[2:3]
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[6:7], v1
+; GFX-950-NEXT: v_and_b32_e32 v8, 1, v1
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], v[2:3], v[6:7]
+; GFX-950-NEXT: v_cmp_nlg_f64_e32 vcc, v[2:3], v[6:7]
+; GFX-950-NEXT: v_cmp_eq_u32_e64 s[0:1], 1, v8
+; GFX-950-NEXT: v_cndmask_b32_e64 v2, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v2, v1, v2
+; GFX-950-NEXT: s_or_b64 vcc, vcc, s[0:1]
+; GFX-950-NEXT: v_cndmask_b32_e32 v1, v2, v1, vcc
+; GFX-950-NEXT: s_brev_b32 s4, 1
+; GFX-950-NEXT: v_and_or_b32 v4, v4, s4, v1
+; GFX-950-NEXT: v_and_b32_e32 v1, 0x7fffffff, v5
+; GFX-950-NEXT: v_cvt_f32_f64_e32 v6, v[0:1]
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[2:3], v6
+; GFX-950-NEXT: v_and_b32_e32 v7, 1, v6
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], v[0:1], v[2:3]
+; GFX-950-NEXT: v_cmp_nlg_f64_e32 vcc, v[0:1], v[2:3]
+; GFX-950-NEXT: v_cmp_eq_u32_e64 s[0:1], 1, v7
+; GFX-950-NEXT: v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v0, v6, v0
+; GFX-950-NEXT: s_or_b64 vcc, vcc, s[0:1]
+; GFX-950-NEXT: v_cndmask_b32_e32 v0, v0, v6, vcc
+; GFX-950-NEXT: v_and_or_b32 v0, v5, s4, v0
+; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, v4
 ; GFX-950-NEXT: ; return to shader part epilog
   %res = fptrunc <2 x double> %src to <2 x bfloat>
   %cast = bitcast <2 x bfloat> %res to float

@@ -347,7 +372,18 @@ define amdgpu_ps void @fptrunc_f64_to_bf16(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16:
 ; GFX-950: ; %bb.0: ; %entry
-; GFX-950-NEXT: v_cvt_f32_f64_e32 v0, v[0:1]
+; GFX-950-NEXT: v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[4:5], v6
+; GFX-950-NEXT: v_and_b32_e32 v7, 1, v6
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_eq_u32_e32 vcc, 1, v7
+; GFX-950-NEXT: v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v0, v6, v0
+; GFX-950-NEXT: s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT: v_cndmask_b32_e32 v0, v0, v6, vcc
+; GFX-950-NEXT: s_brev_b32 s0, 1
+; GFX-950-NEXT: v_and_or_b32 v0, v1, s0, v0
 ; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT: flat_store_short v[2:3], v0
 ; GFX-950-NEXT: s_endpgm

@@ -385,7 +421,19 @@ define amdgpu_ps void @fptrunc_f64_to_bf16_neg(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16_neg:
 ; GFX-950: ; %bb.0: ; %entry
-; GFX-950-NEXT: v_cvt_f32_f64_e64 v0, -v[0:1]
+; GFX-950-NEXT: v_cvt_f32_f64_e64 v7, |v[0:1]|
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[4:5], v7
+; GFX-950-NEXT: v_and_b32_e32 v8, 1, v7
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_eq_u32_e32 vcc, 1, v8
+; GFX-950-NEXT: v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v0, v7, v0
+; GFX-950-NEXT: s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT: s_brev_b32 s4, 1
+; GFX-950-NEXT: v_xor_b32_e32 v6, 0x80000000, v1
+; GFX-950-NEXT: v_cndmask_b32_e32 v0, v0, v7, vcc
+; GFX-950-NEXT: v_and_or_b32 v0, v6, s4, v0
 ; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT: flat_store_short v[2:3], v0
 ; GFX-950-NEXT: s_endpgm

@@ -424,7 +472,19 @@ define amdgpu_ps void @fptrunc_f64_to_bf16_abs(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16_abs:
 ; GFX-950: ; %bb.0: ; %entry
-; GFX-950-NEXT: v_cvt_f32_f64_e64 v0, |v[0:1]|
+; GFX-950-NEXT: v_cvt_f32_f64_e64 v7, |v[0:1]|
+; GFX-950-NEXT: v_cvt_f64_f32_e32 v[4:5], v7
+; GFX-950-NEXT: v_and_b32_e32 v8, 1, v7
+; GFX-950-NEXT: v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT: v_cmp_eq_u32_e32 vcc, 1, v8
+; GFX-950-NEXT: v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT: v_add_u32_e32 v0, v7, v0
+; GFX-950-NEXT: s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT: v_and_b32_e32 v6, 0x7fffffff, v1
+; GFX-950-NEXT: v_cndmask_b32_e32 v0, v0, v7, vcc
+; GFX-950-NEXT: s_brev_b32 s0, 1
+; GFX-950-NEXT: v_and_or_b32 v0, v6, s0, v0
 ; GFX-950-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT: flat_store_short v[2:3], v0
 ; GFX-950-NEXT: s_endpgm
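Tying the tests back to the example from the commit message: for x = 1.00390625 + 2^-30, the round-to-odd narrowing is inexact, so the f32 LSB is jammed (0x3F808000 becomes 0x3F808001) and the intermediate is no longer mistaken for a tie. A quick self-contained check (reusing the same illustrative bit-level bf16 rounding as in the earlier sketch):

#include <cstdint>
#include <cstdio>

int main() {
  // f32 bit pattern after round-to-odd f64 -> f32 narrowing of
  // x = 1.00390625 + 2^-30 (inexact, so the LSB is jammed to 1).
  uint32_t fbits = 0x3F808001u;
  // Final f32 -> bf16 nearest-even rounding on the top 16 bits:
  uint16_t bf =
      static_cast<uint16_t>((fbits + 0x7FFFu + ((fbits >> 16) & 1)) >> 16);
  std::printf("bf16 bits: 0x%04X\n", bf); // 0x3F81, correctly rounded
  return 0;
}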
