
[GlobalIsel] combine extract vector element #91922

Closed
wants to merge 2 commits
22 changes: 22 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -866,6 +866,16 @@ class CombinerHelper {
/// Combine insert vector element OOB.
bool matchInsertVectorElementOOB(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Combine extract vector element with an integer compare on the vector
/// register.
bool matchExtractVectorElementWithICmp(const MachineOperand &MO,
BuildFnTy &MatchInfo);

/// Combine extract vector element with a floating-point compare on the
/// vector register.
bool matchExtractVectorElementWithFCmp(const MachineOperand &MO,
BuildFnTy &MatchInfo);

private:
/// Checks for legality of an indexed variant of \p LdSt.
bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -981,6 +991,18 @@ class CombinerHelper {

// Simplify (cmp cc0 x, y) (&& or ||) (cmp cc1 x, y) -> cmp cc2 x, y.
bool tryFoldLogicOfFCmps(GLogicalBinOp *Logic, BuildFnTy &MatchInfo);

/// Return true if the register \p Src is cheaper to scalarize than it is to
/// leave as a vector operation. If the extract index \p Index is a constant
/// integer, then some operations may be cheap to scalarize. The depth \p Depth
/// prevents arbitrary recursion.
bool isCheapToScalarize(Register Src, const std::optional<APInt> &Index,
unsigned Depth = 0);

/// Return true if \p Src is defined by a vector operation whose element at
/// offset \p Index is constant. \p Depth limits recursion when looking
/// through vector operations.
bool isConstantAtOffset(Register Src, const APInt &Index, unsigned Depth = 0);
};
} // namespace llvm

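As a rough illustration of the heuristic (IR adapted from the AArch64 tests below, not part of the patch): in the snippet that follows, the extract index is the constant 1 and the compare's second operand is a constant build vector, so isConstantAtOffset holds at offset 1 and isCheapToScalarize accepts the compare.

%vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
%d = extractelement <4 x i1> %vector, i32 1
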
16 changes: 16 additions & 0 deletions llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1591,6 +1591,20 @@ def insert_vector_elt_oob : GICombineRule<
[{ return Helper.matchInsertVectorElementOOB(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def extract_vector_element_icmp : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_ICMP $src, $pred, $lhs, $rhs),
(G_EXTRACT_VECTOR_ELT $root, $src, $idx),
[{ return Helper.matchExtractVectorElementWithICmp(${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;

def extract_vector_element_fcmp : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_FCMP $fsrc, $fpred, $flhs, $frhs),
(G_EXTRACT_VECTOR_ELT $root, $fsrc, $fidx),
[{ return Helper.matchExtractVectorElementWithFCmp(${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;

// match_extract_of_element and insert_vector_elt_oob must be the first!
def vector_ops_combines: GICombineGroup<[
match_extract_of_element_undef_vector,
@@ -1624,6 +1638,8 @@ extract_vector_element_build_vector_trunc7,
extract_vector_element_build_vector_trunc8,
extract_vector_element_freeze,
extract_vector_element_shuffle_vector,
extract_vector_element_icmp,
extract_vector_element_fcmp,
insert_vector_element_extract_vector_element
]>;

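In effect, the two new rules rewrite an extract of a vector compare into a scalar compare of two extracted elements, mirroring what the MatchInfo lambdas below build. A sketch in generic MIR, with illustrative register names and types:

; before
%c:_(<4 x s1>) = G_ICMP intpred(sle), %a(<4 x s32>), %b
%e:_(s1) = G_EXTRACT_VECTOR_ELT %c(<4 x s1>), %idx(s64)

; after
%l:_(s32) = G_EXTRACT_VECTOR_ELT %a(<4 x s32>), %idx(s64)
%r:_(s32) = G_EXTRACT_VECTOR_ELT %b(<4 x s32>), %idx(s64)
%e:_(s1) = G_ICMP intpred(sle), %l(s32), %r(s32)

The rewrite only fires when isCheapToScalarize approves the vector operand, the compare has a single non-debug use, and both the scalar compare and the element extracts are legal (or we are before the legalizer).
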
159 changes: 159 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
@@ -453,3 +453,162 @@ bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,

return false;
}

bool CombinerHelper::isConstantAtOffset(Register Src, const APInt &Index,
unsigned Depth) {
assert(MRI.getType(Src).isVector() && "expected a vector as input");
if (Depth >= 2)
return false;

// We use the look-through variant for a higher hit rate and to increase the
// likelihood of constant folding. The actual value is ignored; we only test
// *whether* there is a constant.

MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
[Contributor review comment on lines +463 to +467]
But this should be taken care of by copy->copy folding? I think we should mostly be removing all the getDefIgnoringCopies calls

// If Src is def'd by a build vector, then we check for a constant at the offset.
if (auto *Build = dyn_cast<GBuildVector>(SrcMI))
return getAnyConstantVRegValWithLookThrough(
Build->getSourceReg(Index.getZExtValue()), MRI)
.has_value();

// For concat and shuffle vectors, we could recurse.
// FIXME: concat vectors
// FIXME: shuffle vectors
// FIXME: unary ops
// FIXME: insert vector element
// FIXME: subvector

return false;
}

bool CombinerHelper::isCheapToScalarize(Register Src,
const std::optional<APInt> &Index,
unsigned Depth) {
assert(MRI.getType(Src).isVector() && "expected a vector as input");

if (Depth >= 2)
return false;

MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);

// If Src is def'd by a binary operator,
// then scalarizing the op is cheap when one of its operands is cheap to
// scalarize.
if (auto *BinOp = dyn_cast<GBinOp>(SrcMI))
if (MRI.hasOneNonDBGUse(BinOp->getReg(0)))
if (isCheapToScalarize(BinOp->getLHSReg(), Index, Depth + 1) ||
isCheapToScalarize(BinOp->getRHSReg(), Index, Depth + 1))
return true;

// If Src is def'd by a compare,
// then scalarizing the cmp is cheap when one of its operands is cheap to
// scalarize.
if (auto *Cmp = dyn_cast<GAnyCmp>(SrcMI))
if (MRI.hasOneNonDBGUse(Cmp->getReg(0)))
if (isCheapToScalarize(Cmp->getLHSReg(), Index, Depth + 1) ||
isCheapToScalarize(Cmp->getRHSReg(), Index, Depth + 1))
return true;

// FIXME: unary operator
// FIXME: casts
// FIXME: loads
// FIXME: subvector

// If Index is constant, then Src is cheap to scalarize when it is constant
// at offset Index.
if (Index)
return isConstantAtOffset(Src, *Index, Depth);

return false;
}

bool CombinerHelper::matchExtractVectorElementWithICmp(const MachineOperand &MO,
BuildFnTy &MatchInfo) {
GExtractVectorElement *Extract =
cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));

Register Vector = Extract->getVectorReg();

GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Vector));

std::optional<ValueAndVReg> MaybeIndex =
getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
std::optional<APInt> IndexC = std::nullopt;

if (MaybeIndex)
IndexC = MaybeIndex->Value;

if (!isCheapToScalarize(Vector, IndexC))
return false;

if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
return false;

Register Dst = Extract->getReg(0);
LLT DstTy = MRI.getType(Dst);
LLT IdxTy = MRI.getType(Extract->getIndexReg());
LLT VectorTy = MRI.getType(Cmp->getLHSReg());
LLT ExtractDstTy = VectorTy.getScalarType();

if (!isLegalOrBeforeLegalizer(
{TargetOpcode::G_ICMP, {DstTy, ExtractDstTy}}) ||
!isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
{ExtractDstTy, VectorTy, IdxTy}}))
return false;

MatchInfo = [=](MachineIRBuilder &B) {
auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
Extract->getIndexReg());
auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
Extract->getIndexReg());
B.buildICmp(Cmp->getCond(), Dst, LHS, RHS);
};

return true;
}

bool CombinerHelper::matchExtractVectorElementWithFCmp(const MachineOperand &MO,
BuildFnTy &MatchInfo) {
GExtractVectorElement *Extract =
cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));

Register Vector = Extract->getVectorReg();

GFCmp *Cmp = cast<GFCmp>(MRI.getVRegDef(Vector));

std::optional<ValueAndVReg> MaybeIndex =
getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
std::optional<APInt> IndexC = std::nullopt;

if (MaybeIndex)
IndexC = MaybeIndex->Value;

if (!isCheapToScalarize(Vector, IndexC))
return false;

if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
return false;

Register Dst = Extract->getReg(0);
LLT DstTy = MRI.getType(Dst);
LLT IdxTy = MRI.getType(Extract->getIndexReg());
LLT VectorTy = MRI.getType(Cmp->getLHSReg());
LLT ExtractDstTy = VectorTy.getScalarType();

if (!isLegalOrBeforeLegalizer(
{TargetOpcode::G_FCMP, {DstTy, ExtractDstTy}}) ||
!isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
{ExtractDstTy, VectorTy, IdxTy}}))
return false;

MatchInfo = [=](MachineIRBuilder &B) {
auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
Extract->getIndexReg());
auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
Extract->getIndexReg());
B.buildFCmp(Cmp->getCond(), Dst, LHS, RHS, Cmp->getFlags());
};

return true;
}
128 changes: 128 additions & 0 deletions llvm/test/CodeGen/AArch64/extract-vector-elt.ll
@@ -1100,4 +1100,132 @@ ret:
ret i32 %3
}

define i32 @extract_v4float_fcmp_const_no_zext(<4 x float> %a, <4 x float> %b, i32 %c) {
; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: fcmeq v0.4s, v0.4s, v0.4s
; CHECK-SD-NEXT: mvn v0.16b, v0.16b
; CHECK-SD-NEXT: xtn v0.4h, v0.4s
; CHECK-SD-NEXT: umov w8, v0.h[1]
; CHECK-SD-NEXT: and w0, w8, #0x1
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: mov s0, v0.s[1]
; CHECK-GI-NEXT: fmov s1, #1.00000000
; CHECK-GI-NEXT: fcmp s0, s1
; CHECK-GI-NEXT: cset w0, vs
; CHECK-GI-NEXT: ret
entry:
%vector = fcmp uno <4 x float> %a, <float 1.0, float 1.0, float 1.0, float 1.0>
%d = extractelement <4 x i1> %vector, i32 1
%z = zext i1 %d to i32
ret i32 %z
}

define i32 @extract_v4i32_icmp_const_no_zext(<4 x i32> %a, <4 x i32> %b, i32 %c) {
; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: adrp x8, .LCPI43_0
; CHECK-SD-NEXT: ldr q1, [x8, :lo12:.LCPI43_0]
; CHECK-SD-NEXT: cmge v0.4s, v1.4s, v0.4s
; CHECK-SD-NEXT: xtn v0.4h, v0.4s
; CHECK-SD-NEXT: umov w8, v0.h[1]
; CHECK-SD-NEXT: and w0, w8, #0x1
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: mov s0, v0.s[1]
; CHECK-GI-NEXT: fmov w8, s0
; CHECK-GI-NEXT: cmp w8, #8
; CHECK-GI-NEXT: cset w0, le
; CHECK-GI-NEXT: ret
entry:
%vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
%d = extractelement <4 x i1> %vector, i32 1
%z = zext i1 %d to i32
ret i32 %z
}

define i32 @extract_v4float_fcmp_const_no_zext_fail(<4 x float> %a, <4 x float> %b, i32 %c) {
; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext_fail:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: sub sp, sp, #16
; CHECK-SD-NEXT: .cfi_def_cfa_offset 16
; CHECK-SD-NEXT: fcmeq v0.4s, v0.4s, v0.4s
; CHECK-SD-NEXT: add x8, sp, #8
; CHECK-SD-NEXT: // kill: def $w0 killed $w0 def $x0
; CHECK-SD-NEXT: bfi x8, x0, #1, #2
; CHECK-SD-NEXT: mvn v0.16b, v0.16b
; CHECK-SD-NEXT: xtn v0.4h, v0.4s
; CHECK-SD-NEXT: str d0, [sp, #8]
; CHECK-SD-NEXT: ldrh w8, [x8]
; CHECK-SD-NEXT: and w0, w8, #0x1
; CHECK-SD-NEXT: add sp, sp, #16
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext_fail:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: sub sp, sp, #16
; CHECK-GI-NEXT: .cfi_def_cfa_offset 16
; CHECK-GI-NEXT: fmov v1.4s, #1.00000000
; CHECK-GI-NEXT: mov w8, w0
; CHECK-GI-NEXT: mov x9, sp
; CHECK-GI-NEXT: and x8, x8, #0x3
; CHECK-GI-NEXT: fcmge v2.4s, v0.4s, v1.4s
; CHECK-GI-NEXT: fcmgt v0.4s, v1.4s, v0.4s
; CHECK-GI-NEXT: orr v0.16b, v0.16b, v2.16b
; CHECK-GI-NEXT: mvn v0.16b, v0.16b
; CHECK-GI-NEXT: str q0, [sp]
; CHECK-GI-NEXT: ldr w8, [x9, x8, lsl #2]
; CHECK-GI-NEXT: and w0, w8, #0x1
; CHECK-GI-NEXT: add sp, sp, #16
; CHECK-GI-NEXT: ret
entry:
%vector = fcmp uno <4 x float> %a, <float 1.0, float 1.0, float 1.0, float 1.0>
%d = extractelement <4 x i1> %vector, i32 %c
%z = zext i1 %d to i32
ret i32 %z
}

define i32 @extract_v4i32_icmp_const_no_zext_fail(<4 x i32> %a, <4 x i32> %b, i32 %c) {
; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext_fail:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: sub sp, sp, #16
; CHECK-SD-NEXT: .cfi_def_cfa_offset 16
; CHECK-SD-NEXT: adrp x8, .LCPI45_0
; CHECK-SD-NEXT: // kill: def $w0 killed $w0 def $x0
; CHECK-SD-NEXT: ldr q1, [x8, :lo12:.LCPI45_0]
; CHECK-SD-NEXT: add x8, sp, #8
; CHECK-SD-NEXT: bfi x8, x0, #1, #2
; CHECK-SD-NEXT: cmge v0.4s, v1.4s, v0.4s
; CHECK-SD-NEXT: xtn v0.4h, v0.4s
; CHECK-SD-NEXT: str d0, [sp, #8]
; CHECK-SD-NEXT: ldrh w8, [x8]
; CHECK-SD-NEXT: and w0, w8, #0x1
; CHECK-SD-NEXT: add sp, sp, #16
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext_fail:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: sub sp, sp, #16
; CHECK-GI-NEXT: .cfi_def_cfa_offset 16
; CHECK-GI-NEXT: adrp x8, .LCPI45_0
; CHECK-GI-NEXT: mov x9, sp
; CHECK-GI-NEXT: ldr q1, [x8, :lo12:.LCPI45_0]
; CHECK-GI-NEXT: mov w8, w0
; CHECK-GI-NEXT: and x8, x8, #0x3
; CHECK-GI-NEXT: cmge v0.4s, v1.4s, v0.4s
; CHECK-GI-NEXT: str q0, [sp]
; CHECK-GI-NEXT: ldr w8, [x9, x8, lsl #2]
; CHECK-GI-NEXT: and w0, w8, #0x1
; CHECK-GI-NEXT: add sp, sp, #16
; CHECK-GI-NEXT: ret
entry:
%vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
%d = extractelement <4 x i1> %vector, i32 %c
%z = zext i1 %d to i32
ret i32 %z
}
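
The _fail variants above keep the vector compare: the extract index %c is not a constant, so isCheapToScalarize has no offset to hand to isConstantAtOffset and the constant vector operand cannot be proven cheap to scalarize. For example:

%vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
%d = extractelement <4 x i1> %vector, i32 %c ; variable index, combine does not fire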