Skip to content

Commit 1e56e59

Browse files
john-brawn-arm authored and GeorgeARM committed
[AArch64] Use pattern to select bf16 fpextend (llvm#137212)
Currently, bf16 fpextend is lowered to a vector shift. Instead, leave it as fpextend and have an instruction selection pattern which selects it to a shift later. Doing this means that DAGCombiner patterns for fpextend will be applied, leading to better codegen. It also means that in some situations we use a mov instruction where we previously had a dup instruction, but I don't think this makes any difference.
1 parent 33d2a24 commit 1e56e59

8 files changed

+125
-202
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

+6-32
Original file line numberDiff line numberDiff line change
@@ -766,13 +766,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
766766
setOperationAction(Op, MVT::v8bf16, Expand);
767767
}
768768

769-
// For bf16, fpextend is custom lowered to be optionally expanded into shifts.
770-
setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
769+
// fpextend from f16 or bf16 to f32 is legal
770+
setOperationAction(ISD::FP_EXTEND, MVT::f32, Legal);
771+
setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Legal);
772+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal);
773+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Legal);
774+
// fpextend from bf16 to f64 needs to be split into two fpextends
771775
setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
772-
setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Custom);
773-
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
774776
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
775-
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Custom);
776777

777778
auto LegalizeNarrowFP = [this](MVT ScalarVT) {
778779
for (auto Op : {
@@ -4559,33 +4560,6 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
45594560
return SDValue();
45604561
}
45614562

4562-
if (VT.getScalarType() == MVT::f32) {
4563-
// FP16->FP32 extends are legal for v32 and v4f32.
4564-
if (Op0VT.getScalarType() == MVT::f16)
4565-
return Op;
4566-
if (Op0VT.getScalarType() == MVT::bf16) {
4567-
SDLoc DL(Op);
4568-
EVT IVT = VT.changeTypeToInteger();
4569-
if (!Op0VT.isVector()) {
4570-
Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4bf16, Op0);
4571-
IVT = MVT::v4i32;
4572-
}
4573-
4574-
EVT Op0IVT = Op0.getValueType().changeTypeToInteger();
4575-
SDValue Ext =
4576-
DAG.getNode(ISD::ANY_EXTEND, DL, IVT, DAG.getBitcast(Op0IVT, Op0));
4577-
SDValue Shift =
4578-
DAG.getNode(ISD::SHL, DL, IVT, Ext, DAG.getConstant(16, DL, IVT));
4579-
if (!Op0VT.isVector())
4580-
Shift = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Shift,
4581-
DAG.getConstant(0, DL, MVT::i64));
4582-
Shift = DAG.getBitcast(VT, Shift);
4583-
return IsStrict ? DAG.getMergeValues({Shift, Op.getOperand(0)}, DL)
4584-
: Shift;
4585-
}
4586-
return SDValue();
4587-
}
4588-
45894563
assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
45904564
return SDValue();
45914565
}

llvm/lib/Target/AArch64/AArch64InstrInfo.td

+20
Original file line numberDiff line numberDiff line change
@@ -8513,6 +8513,26 @@ def : InstAlias<"uxtl2 $dst.2d, $src1.4s",
85138513
(USHLLv4i32_shift V128:$dst, V128:$src1, 0)>;
85148514
}
85158515

8516+
// fpextend from bf16 to f32 is just a shift left by 16
8517+
let Predicates = [HasNEON] in {
8518+
def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
8519+
(f32 (EXTRACT_SUBREG
8520+
(v4i32 (SHLLv4i16 (v4i16 (SUBREG_TO_REG (i64 0), (bf16 FPR16:$Rn), hsub)))),
8521+
ssub))>;
8522+
def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))),
8523+
(SHLLv4i16 V64:$Rn)>;
8524+
def : Pat<(v4f32 (any_fpextend (extract_high_v8bf16 (v8bf16 V128:$Rn)))),
8525+
(SHLLv8i16 V128:$Rn)>;
8526+
}
8527+
// Fallback pattern for when we don't have NEON
8528+
def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
8529+
(f32 (COPY_TO_REGCLASS
8530+
(i32 (UBFMWri (COPY_TO_REGCLASS
8531+
(f32 (SUBREG_TO_REG (i32 0), (bf16 FPR16:$Rn), hsub)),
8532+
GPR32),
8533+
(i64 16), (i64 15))),
8534+
FPR32))>;
8535+
85168536
def abs_f16 :
85178537
OutPatFrag<(ops node:$Rn),
85188538
(EXTRACT_SUBREG (f32 (COPY_TO_REGCLASS

llvm/test/CodeGen/AArch64/arm64-fast-isel-conversion-fallback.ll

+2-6
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,7 @@ entry:
155155
define i32 @fptosi_bf(bfloat %a) nounwind ssp {
156156
; CHECK-LABEL: fptosi_bf:
157157
; CHECK: // %bb.0: // %entry
158-
; CHECK-NEXT: fmov s1, s0
159-
; CHECK-NEXT: // implicit-def: $d0
160-
; CHECK-NEXT: fmov s0, s1
158+
; CHECK-NEXT: // kill: def $d0 killed $h0
161159
; CHECK-NEXT: shll v0.4s, v0.4h, #16
162160
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $q0
163161
; CHECK-NEXT: fcvtzs w0, s0
@@ -171,9 +169,7 @@ entry:
171169
define i32 @fptoui_sbf(bfloat %a) nounwind ssp {
172170
; CHECK-LABEL: fptoui_sbf:
173171
; CHECK: // %bb.0: // %entry
174-
; CHECK-NEXT: fmov s1, s0
175-
; CHECK-NEXT: // implicit-def: $d0
176-
; CHECK-NEXT: fmov s0, s1
172+
; CHECK-NEXT: // kill: def $d0 killed $h0
177173
; CHECK-NEXT: shll v0.4s, v0.4h, #16
178174
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $q0
179175
; CHECK-NEXT: fcvtzu w0, s0

llvm/test/CodeGen/AArch64/atomicrmw-fmax.ll

+4-4
Original file line numberDiff line numberDiff line change
@@ -641,15 +641,15 @@ define <2 x bfloat> @test_atomicrmw_fmax_v2bf16_seq_cst_align4(ptr %ptr, <2 x bf
641641
; NOLSE-LABEL: test_atomicrmw_fmax_v2bf16_seq_cst_align4:
642642
; NOLSE: // %bb.0:
643643
; NOLSE-NEXT: // kill: def $d0 killed $d0 def $q0
644-
; NOLSE-NEXT: dup v1.4h, v0.h[1]
644+
; NOLSE-NEXT: mov h1, v0.h[1]
645645
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
646646
; NOLSE-NEXT: shll v0.4s, v0.4h, #16
647647
; NOLSE-NEXT: shll v1.4s, v1.4h, #16
648648
; NOLSE-NEXT: .LBB7_1: // %atomicrmw.start
649649
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
650650
; NOLSE-NEXT: ldaxr w9, [x0]
651651
; NOLSE-NEXT: fmov s2, w9
652-
; NOLSE-NEXT: dup v3.4h, v2.h[1]
652+
; NOLSE-NEXT: mov h3, v2.h[1]
653653
; NOLSE-NEXT: shll v2.4s, v2.4h, #16
654654
; NOLSE-NEXT: fmaxnm s2, s2, s0
655655
; NOLSE-NEXT: shll v3.4s, v3.4h, #16
@@ -677,14 +677,14 @@ define <2 x bfloat> @test_atomicrmw_fmax_v2bf16_seq_cst_align4(ptr %ptr, <2 x bf
677677
; LSE-LABEL: test_atomicrmw_fmax_v2bf16_seq_cst_align4:
678678
; LSE: // %bb.0:
679679
; LSE-NEXT: // kill: def $d0 killed $d0 def $q0
680-
; LSE-NEXT: dup v1.4h, v0.h[1]
680+
; LSE-NEXT: mov h1, v0.h[1]
681681
; LSE-NEXT: shll v2.4s, v0.4h, #16
682682
; LSE-NEXT: mov w8, #32767 // =0x7fff
683683
; LSE-NEXT: ldr s0, [x0]
684684
; LSE-NEXT: shll v1.4s, v1.4h, #16
685685
; LSE-NEXT: .LBB7_1: // %atomicrmw.start
686686
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
687-
; LSE-NEXT: dup v3.4h, v0.h[1]
687+
; LSE-NEXT: mov h3, v0.h[1]
688688
; LSE-NEXT: shll v4.4s, v0.4h, #16
689689
; LSE-NEXT: fmaxnm s4, s4, s2
690690
; LSE-NEXT: shll v3.4s, v3.4h, #16

llvm/test/CodeGen/AArch64/atomicrmw-fmin.ll

+4-4
Original file line numberDiff line numberDiff line change
@@ -641,15 +641,15 @@ define <2 x bfloat> @test_atomicrmw_fmin_v2bf16_seq_cst_align4(ptr %ptr, <2 x bf
641641
; NOLSE-LABEL: test_atomicrmw_fmin_v2bf16_seq_cst_align4:
642642
; NOLSE: // %bb.0:
643643
; NOLSE-NEXT: // kill: def $d0 killed $d0 def $q0
644-
; NOLSE-NEXT: dup v1.4h, v0.h[1]
644+
; NOLSE-NEXT: mov h1, v0.h[1]
645645
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
646646
; NOLSE-NEXT: shll v0.4s, v0.4h, #16
647647
; NOLSE-NEXT: shll v1.4s, v1.4h, #16
648648
; NOLSE-NEXT: .LBB7_1: // %atomicrmw.start
649649
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
650650
; NOLSE-NEXT: ldaxr w9, [x0]
651651
; NOLSE-NEXT: fmov s2, w9
652-
; NOLSE-NEXT: dup v3.4h, v2.h[1]
652+
; NOLSE-NEXT: mov h3, v2.h[1]
653653
; NOLSE-NEXT: shll v2.4s, v2.4h, #16
654654
; NOLSE-NEXT: fminnm s2, s2, s0
655655
; NOLSE-NEXT: shll v3.4s, v3.4h, #16
@@ -677,14 +677,14 @@ define <2 x bfloat> @test_atomicrmw_fmin_v2bf16_seq_cst_align4(ptr %ptr, <2 x bf
677677
; LSE-LABEL: test_atomicrmw_fmin_v2bf16_seq_cst_align4:
678678
; LSE: // %bb.0:
679679
; LSE-NEXT: // kill: def $d0 killed $d0 def $q0
680-
; LSE-NEXT: dup v1.4h, v0.h[1]
680+
; LSE-NEXT: mov h1, v0.h[1]
681681
; LSE-NEXT: shll v2.4s, v0.4h, #16
682682
; LSE-NEXT: mov w8, #32767 // =0x7fff
683683
; LSE-NEXT: ldr s0, [x0]
684684
; LSE-NEXT: shll v1.4s, v1.4h, #16
685685
; LSE-NEXT: .LBB7_1: // %atomicrmw.start
686686
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
687-
; LSE-NEXT: dup v3.4h, v0.h[1]
687+
; LSE-NEXT: mov h3, v0.h[1]
688688
; LSE-NEXT: shll v4.4s, v0.4h, #16
689689
; LSE-NEXT: fminnm s4, s4, s2
690690
; LSE-NEXT: shll v3.4s, v3.4h, #16

llvm/test/CodeGen/AArch64/bf16-instructions.ll

+9-18
Original file line numberDiff line numberDiff line change
@@ -202,16 +202,13 @@ define bfloat @test_fmadd(bfloat %a, bfloat %b, bfloat %c) #0 {
202202
;
203203
; CHECK-BF16-LABEL: test_fmadd:
204204
; CHECK-BF16: // %bb.0:
205+
; CHECK-BF16-NEXT: // kill: def $h2 killed $h2 def $d2
205206
; CHECK-BF16-NEXT: // kill: def $h1 killed $h1 def $d1
206207
; CHECK-BF16-NEXT: // kill: def $h0 killed $h0 def $d0
207-
; CHECK-BF16-NEXT: // kill: def $h2 killed $h2 def $d2
208208
; CHECK-BF16-NEXT: shll v1.4s, v1.4h, #16
209209
; CHECK-BF16-NEXT: shll v0.4s, v0.4h, #16
210-
; CHECK-BF16-NEXT: fmul s0, s0, s1
211-
; CHECK-BF16-NEXT: shll v1.4s, v2.4h, #16
212-
; CHECK-BF16-NEXT: bfcvt h0, s0
213-
; CHECK-BF16-NEXT: shll v0.4s, v0.4h, #16
214-
; CHECK-BF16-NEXT: fadd s0, s0, s1
210+
; CHECK-BF16-NEXT: shll v2.4s, v2.4h, #16
211+
; CHECK-BF16-NEXT: fmadd s0, s0, s1, s2
215212
; CHECK-BF16-NEXT: bfcvt h0, s0
216213
; CHECK-BF16-NEXT: ret
217214
%mul = fmul fast bfloat %a, %b
@@ -1996,13 +1993,11 @@ define bfloat @test_copysign_f64(bfloat %a, double %b) #0 {
19961993
define float @test_copysign_extended(bfloat %a, bfloat %b) #0 {
19971994
; CHECK-CVT-LABEL: test_copysign_extended:
19981995
; CHECK-CVT: // %bb.0:
1999-
; CHECK-CVT-NEXT: // kill: def $h0 killed $h0 def $d0
2000-
; CHECK-CVT-NEXT: movi v2.4s, #16
20011996
; CHECK-CVT-NEXT: // kill: def $h1 killed $h1 def $d1
2002-
; CHECK-CVT-NEXT: ushll v0.4s, v0.4h, #0
2003-
; CHECK-CVT-NEXT: shll v1.4s, v1.4h, #16
2004-
; CHECK-CVT-NEXT: ushl v0.4s, v0.4s, v2.4s
1997+
; CHECK-CVT-NEXT: // kill: def $h0 killed $h0 def $d0
20051998
; CHECK-CVT-NEXT: mvni v2.4s, #128, lsl #24
1999+
; CHECK-CVT-NEXT: shll v1.4s, v1.4h, #16
2000+
; CHECK-CVT-NEXT: shll v0.4s, v0.4h, #16
20062001
; CHECK-CVT-NEXT: bif v0.16b, v1.16b, v2.16b
20072002
; CHECK-CVT-NEXT: fmov w8, s0
20082003
; CHECK-CVT-NEXT: lsr w8, w8, #16
@@ -2013,16 +2008,12 @@ define float @test_copysign_extended(bfloat %a, bfloat %b) #0 {
20132008
;
20142009
; CHECK-SD-LABEL: test_copysign_extended:
20152010
; CHECK-SD: // %bb.0:
2016-
; CHECK-SD-NEXT: // kill: def $h0 killed $h0 def $d0
2017-
; CHECK-SD-NEXT: movi v2.4s, #16
20182011
; CHECK-SD-NEXT: // kill: def $h1 killed $h1 def $d1
2019-
; CHECK-SD-NEXT: ushll v0.4s, v0.4h, #0
2020-
; CHECK-SD-NEXT: shll v1.4s, v1.4h, #16
2021-
; CHECK-SD-NEXT: ushl v0.4s, v0.4s, v2.4s
2012+
; CHECK-SD-NEXT: // kill: def $h0 killed $h0 def $d0
20222013
; CHECK-SD-NEXT: mvni v2.4s, #128, lsl #24
2023-
; CHECK-SD-NEXT: bif v0.16b, v1.16b, v2.16b
2024-
; CHECK-SD-NEXT: bfcvt h0, s0
2014+
; CHECK-SD-NEXT: shll v1.4s, v1.4h, #16
20252015
; CHECK-SD-NEXT: shll v0.4s, v0.4h, #16
2016+
; CHECK-SD-NEXT: bif v0.16b, v1.16b, v2.16b
20262017
; CHECK-SD-NEXT: // kill: def $s0 killed $s0 killed $q0
20272018
; CHECK-SD-NEXT: ret
20282019
;

0 commit comments

Comments
 (0)