
[GlobalIsel] combine extract vector element #91922


Closed
wants to merge 2 commits

Conversation

tschuett

scalarize compares

extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)

scalarize compares

extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
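
As an IR-level sketch of the transform (hypothetical value names; the combine itself matches G_ICMP/G_FCMP feeding G_EXTRACT_VECTOR_ELT on gMIR):

%cmp = icmp slt <4 x i32> %X, %Y
%elt = extractelement <4 x i1> %cmp, i64 0

; becomes:
%x0 = extractelement <4 x i32> %X, i64 0
%y0 = extractelement <4 x i32> %Y, i64 0
%elt = icmp slt i32 %x0, %y0
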
@llvmbot

llvmbot commented May 13, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Thorsten Schütt (tschuett)

Changes

scalarize compares

extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)


Full diff: https://github.com/llvm/llvm-project/pull/91922.diff

4 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h (+23)
  • (modified) llvm/include/llvm/Target/GlobalISel/Combine.td (+18-1)
  • (modified) llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp (+159)
  • (modified) llvm/test/CodeGen/AArch64/extract-vector-elt.ll (+128)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index ecaece8b68342..6edb3f9cd2e89 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -1,3 +1,4 @@
+
 //===-- llvm/CodeGen/GlobalISel/CombinerHelper.h --------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -866,6 +867,16 @@ class CombinerHelper {
   /// Combine insert vector element OOB.
   bool matchInsertVectorElementOOB(MachineInstr &MI, BuildFnTy &MatchInfo);
 
+  /// Combine extract vector element with a compare on the vector
+  /// register.
+  bool matchExtractVectorElementWithICmp(const MachineOperand &MO,
+                                         BuildFnTy &MatchInfo);
+
+  /// Combine extract vector element with a compare on the vector
+  /// register.
+  bool matchExtractVectorElementWithFCmp(const MachineOperand &MO,
+                                         BuildFnTy &MatchInfo);
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -981,6 +992,18 @@ class CombinerHelper {
 
   // Simplify (cmp cc0 x, y) (&& or ||) (cmp cc1 x, y) -> cmp cc2 x, y.
   bool tryFoldLogicOfFCmps(GLogicalBinOp *Logic, BuildFnTy &MatchInfo);
+
+  /// Return true if the register \p Src is cheaper to scalarize than it is to
+  /// leave as a vector operation. If the extract index \p Index is a constant
+  /// integer then some operations may be cheap to scalarize. The depth \p Depth
+  /// prevents arbitrary recursion.
+  bool isCheapToScalarize(Register Src, const std::optional<APInt> &Index,
+                          unsigned Depth = 0);
+
+  /// Return true if \p Src is def'd by a vector operation that is constant at
+  /// offset \p Index. \p Depth limits recursion when looking through vector
+  /// operations.
+  bool isConstantAtOffset(Register Src, const APInt &Index, unsigned Depth = 0);
 };
 } // namespace llvm
 
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 98d266c8c0b4f..3c71c2a25b2d9 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1591,6 +1591,20 @@ def insert_vector_elt_oob : GICombineRule<
          [{ return Helper.matchInsertVectorElementOOB(*${root}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
+def extract_vector_element_icmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+   (match (G_ICMP $src, $pred, $lhs, $rhs),
+          (G_EXTRACT_VECTOR_ELT $root, $src, $idx),
+   [{ return Helper.matchExtractVectorElementWithICmp(${root}, ${matchinfo}); }]),
+   (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+def extract_vector_element_fcmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_FCMP $fsrc, $fpred, $flhs, $frhs),
+         (G_EXTRACT_VECTOR_ELT $root, $fsrc, $fidx),
+  [{ return Helper.matchExtractVectorElementWithFCmp(${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
 // match_extract_of_element and insert_vector_elt_oob must be the first!
 def vector_ops_combines: GICombineGroup<[
 match_extract_of_element_undef_vector,
@@ -1624,6 +1638,8 @@ extract_vector_element_build_vector_trunc7,
 extract_vector_element_build_vector_trunc8,
 extract_vector_element_freeze,
 extract_vector_element_shuffle_vector,
+extract_vector_element_icmp,
+extract_vector_element_fcmp,
 insert_vector_element_extract_vector_element
 ]>;
 
@@ -1706,7 +1722,8 @@ def all_combines : GICombineGroup<[trivial_combines, vector_ops_combines,
     sub_add_reg, select_to_minmax, redundant_binop_in_equality,
     fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
     combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
-    sext_trunc, zext_trunc, combine_shuffle_concat]>;
+    sext_trunc, zext_trunc, combine_shuffle_concat
+]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
 // this group have been selected based on experiments to balance code size and
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
index 21b1eb2628174..64b39e3f82e65 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
@@ -453,3 +453,162 @@ bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
 
   return false;
 }
+
+bool CombinerHelper::isConstantAtOffset(Register Src, const APInt &Index,
+                                        unsigned Depth) {
+  assert(MRI.getType(Src).isVector() && "expected a vector as input");
+  if (Depth == 2)
+    return false;
+
+  // We use the look-through variant for a higher hit rate and to increase the
+  // likelihood of constant folding. The actual value is ignored; we only test
+  // *whether* there is a constant.
+
+  MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a build vector, then we check the constness at the offset.
+  if (auto *Build = dyn_cast<GBuildVector>(SrcMI))
+    return getAnyConstantVRegValWithLookThrough(
+               Build->getSourceReg(Index.getZExtValue()), MRI)
+        .has_value();
+
+  // For concat and shuffle vectors, we could recurse.
+  // FIXME concat vectors
+  // FIXME shuffle vectors
+  // FIXME unary ops
+  // FIXME insert vector element
+  // FIXME subvector
+
+  return false;
+}
+
+bool CombinerHelper::isCheapToScalarize(Register Src,
+                                        const std::optional<APInt> &Index,
+                                        unsigned Depth) {
+  assert(MRI.getType(Src).isVector() && "expected a vector as input");
+
+  if (Depth >= 2)
+    return false;
+
+  MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a binary operator,
+  // then scalarizing the op is cheap when one of its operands is cheap to
+  // scalarize.
+  if (auto *BinOp = dyn_cast<GBinOp>(SrcMI))
+    if (MRI.hasOneNonDBGUse(BinOp->getReg(0)))
+      if (isCheapToScalarize(BinOp->getLHSReg(), Index, Depth + 1) ||
+          isCheapToScalarize(BinOp->getRHSReg(), Index, Depth + 1))
+        return true;
+
+  // If Src is def'd by a compare,
+  // then scalarizing the cmp is cheap when one of its operands is cheap to
+  // scalarize.
+  if (auto *Cmp = dyn_cast<GAnyCmp>(SrcMI))
+    if (MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+      if (isCheapToScalarize(Cmp->getLHSReg(), Index, Depth + 1) ||
+          isCheapToScalarize(Cmp->getRHSReg(), Index, Depth + 1))
+        return true;
+
+  // FIXME: unary operator
+  // FIXME: casts
+  // FIXME: loads
+  // FIXME: subvector
+
+  if (Index)
+    // If Index is constant, then Src is cheap to scalarize when it is constant
+    // at offset Index.
+    return isConstantAtOffset(Src, *Index, Depth);
+
+  return false;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithICmp(const MachineOperand &MO,
+                                                       BuildFnTy &MatchInfo) {
+  GExtractVectorElement *Extract =
+      cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+
+  Register Vector = Extract->getVectorReg();
+
+  GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Vector));
+
+  std::optional<ValueAndVReg> MaybeIndex =
+      getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+  std::optional<APInt> IndexC = std::nullopt;
+
+  if (MaybeIndex)
+    IndexC = MaybeIndex->Value;
+
+  if (!isCheapToScalarize(Vector, IndexC))
+    return false;
+
+  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+    return false;
+
+  Register Dst = Extract->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Extract->getIndexReg());
+  LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+  LLT ExtractDstTy = VectorTy.getScalarType();
+
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_ICMP, {DstTy, ExtractDstTy}}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+                                 {ExtractDstTy, VectorTy, IdxTy}}))
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+                                           Extract->getIndexReg());
+    auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+                                           Extract->getIndexReg());
+    B.buildICmp(Cmp->getCond(), Dst, LHS, RHS);
+  };
+
+  return true;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithFCmp(const MachineOperand &MO,
+                                                       BuildFnTy &MatchInfo) {
+  GExtractVectorElement *Extract =
+      cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+
+  Register Vector = Extract->getVectorReg();
+
+  GFCmp *Cmp = cast<GFCmp>(MRI.getVRegDef(Vector));
+
+  std::optional<ValueAndVReg> MaybeIndex =
+      getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+  std::optional<APInt> IndexC = std::nullopt;
+
+  if (MaybeIndex)
+    IndexC = MaybeIndex->Value;
+
+  if (!isCheapToScalarize(Vector, IndexC))
+    return false;
+
+  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+    return false;
+
+  Register Dst = Extract->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Extract->getIndexReg());
+  LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+  LLT ExtractDstTy = VectorTy.getScalarType();
+
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_FCMP, {DstTy, ExtractDstTy}}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+                                 {ExtractDstTy, VectorTy, IdxTy}}))
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+                                           Extract->getIndexReg());
+    auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+                                           Extract->getIndexReg());
+    B.buildFCmp(Cmp->getCond(), Dst, LHS, RHS, Cmp->getFlags());
+  };
+
+  return true;
+}
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
index 0481d997d24fa..42fe5e82cb7de 100644
--- a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
+++ b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
@@ -1100,4 +1100,132 @@ ret:
   ret i32 %3
 }
 
+define i32 @extract_v4float_fcmp_const_no_zext(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT:    mvn v0.16b, v0.16b
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    umov w8, v0.h[1]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov s0, v0.s[1]
+; CHECK-GI-NEXT:    fmov s1, #1.00000000
+; CHECK-GI-NEXT:    fcmp s0, s1
+; CHECK-GI-NEXT:    cset w0, vs
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = fcmp uno <4 x float> %a, <float 1.0, float 1.0, float 1.0, float 1.0>
+  %d = extractelement <4 x i1> %vector, i32 1
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
 
+define i32 @extract_v4i32_icmp_const_no_zext(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    adrp x8, .LCPI43_0
+; CHECK-SD-NEXT:    ldr q1, [x8, :lo12:.LCPI43_0]
+; CHECK-SD-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    umov w8, v0.h[1]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov s0, v0.s[1]
+; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    cmp w8, #8
+; CHECK-GI-NEXT:    cset w0, le
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
+  %d = extractelement <4 x i1> %vector, i32 1
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4float_fcmp_const_no_zext_fail(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT:    add x8, sp, #8
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT:    bfi x8, x0, #1, #2
+; CHECK-SD-NEXT:    mvn v0.16b, v0.16b
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    str d0, [sp, #8]
+; CHECK-SD-NEXT:    ldrh w8, [x8]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    sub sp, sp, #16
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT:    fmov v1.4s, #1.00000000
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    mov x9, sp
+; CHECK-GI-NEXT:    and x8, x8, #0x3
+; CHECK-GI-NEXT:    fcmge v2.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT:    fcmgt v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    orr v0.16b, v0.16b, v2.16b
+; CHECK-GI-NEXT:    mvn v0.16b, v0.16b
+; CHECK-GI-NEXT:    str q0, [sp]
+; CHECK-GI-NEXT:    ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT:    and w0, w8, #0x1
+; CHECK-GI-NEXT:    add sp, sp, #16
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = fcmp uno <4 x float> %a, <float 1.0, float 1.0, float 1.0, float 1.0>
+  %d = extractelement <4 x i1> %vector, i32 %c
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4i32_icmp_const_no_zext_fail(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    adrp x8, .LCPI45_0
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT:    ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-SD-NEXT:    add x8, sp, #8
+; CHECK-SD-NEXT:    bfi x8, x0, #1, #2
+; CHECK-SD-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    str d0, [sp, #8]
+; CHECK-SD-NEXT:    ldrh w8, [x8]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    sub sp, sp, #16
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT:    adrp x8, .LCPI45_0
+; CHECK-GI-NEXT:    mov x9, sp
+; CHECK-GI-NEXT:    ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    and x8, x8, #0x3
+; CHECK-GI-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    str q0, [sp]
+; CHECK-GI-NEXT:    ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT:    and w0, w8, #0x1
+; CHECK-GI-NEXT:    add sp, sp, #16
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
+  %d = extractelement <4 x i1> %vector, i32 %c
+  %z = zext i1 %d to i32
+  ret i32 %z
+}

@llvmbot

llvmbot commented May 13, 2024

@llvm/pr-subscribers-llvm-globalisel

@tschuett tschuett requested review from aemerson and arsenm May 13, 2024 07:13
@arsenm arsenm left a comment

SelectionDAG seems to not do this combine, and InstCombine does. Is this pattern coming from somewhere in legalization?
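
For reference, a sketch of the shape InstCombine already scalarizes at the IR level (hypothetical value names, and assuming its cheapToScalarize heuristic fires because one cmp operand is a constant vector):

%v = icmp slt <4 x i32> %a, <i32 1, i32 2, i32 3, i32 4>
%d = extractelement <4 x i1> %v, i64 1

; folds to a compare of the extracted lane:
%a1 = extractelement <4 x i32> %a, i64 1
%d = icmp slt i32 %a1, 2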

Comment on lines +463 to +467
// We use the look-through variant for a higher hit rate and to increase the
// likelihood of constant folding. The actual value is ignored; we only test
// *whether* there is a constant.

MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);

But this should be taken care of by copy->copy folding? I think we should mostly be removing all the getDefIgnoringCopies calls.

@tschuett

It comes out of the middle end. It is a combine that frees up a vector register by scalarizing an instruction.

@tschuett

%vector = icmp sle <4 x i32> %a, <i32 42, i32 11, i32 17, i32 6>

The tricky case is when there is an ext in between the compare and the extract.
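
A sketch of that case, assuming legalization has widened the i1 result vector so that an extend separates the compare from the extract:

%vector = icmp sle <4 x i32> %a, <i32 42, i32 11, i32 17, i32 6>
%wide = sext <4 x i1> %vector to <4 x i32>
%elt = extractelement <4 x i32> %wide, i32 1
; the combine would have to look through the sext to reach the compare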

@tschuett

We should add a Combiner.rst:

  • Each combine is a precise pattern in Combine.td, which helps compile time and lets targets pick their choice of combines
  • There are no visit functions
  • Document the differences between the GISel Combiner and InstCombine, i.e., the kinds of combines and their aggressiveness
  • Is the GISel combiner purely a legalization-artefact combiner?

The DAGCombiner seems to do more than combine legalization artefacts.

@tschuett

From the Gisel documentation:

Combiner

Replaces patterns of instructions with a better alternative. Typically, this means improving run time performance by replacing instructions with faster alternatives but Combiners can also focus on code size or other metrics.

@aemerson

I think the key here is that combines have a cost, both in compile time (every failed match is time wasted) and in maintenance.

Using clang as the archetypal user of LLVM, and assuming that other front-ends build similar pass manager pipelines, we should see if a piece of IR after optimizations still has room for improvement during codegen. Sometimes combines only become visible to the optimizer after certain transformations, so they're phase-ordering sensitive. That's fine IMO. Likewise if certain transforms need more target-specific information that we gain during codegen (perhaps because legalization makes some costs clearer or expands MIR and thus exposes more opportunity).

But I think a combine that never fires, because in the vast majority of cases the pattern is already optimized away at the IR level, doesn't justify its existence. That said, these should only be guidelines IMO, and we don't have to be dogmatic about it.

I agree that documentation would be useful here.

@tschuett

That is fine by me, as long as we document the combiner and what kinds of combines it has.

@aemerson

Filed #92309 to remind us to do this.

@tschuett tschuett closed this May 16, 2024