Skip to content

Commit f37d81f

Browse files
authored
[PatternMatch] Add a matching helper m_ElementWiseBitCast. NFC. (#80764)
This patch introduces a matching helper `m_ElementWiseBitCast`, which is used for matching element-wise int <-> fp casts. The motivation of this patch is to avoid duplicating checks in #80740 and #80414.
1 parent d109f94 commit f37d81f

File tree

9 files changed

+111
-51
lines changed

9 files changed

+111
-51
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,6 +1711,34 @@ m_BitCast(const OpTy &Op) {
17111711
return CastOperator_match<OpTy, Instruction::BitCast>(Op);
17121712
}
17131713

1714+
template <typename Op_t> struct ElementWiseBitCast_match {
1715+
Op_t Op;
1716+
1717+
ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
1718+
1719+
template <typename OpTy> bool match(OpTy *V) {
1720+
BitCastInst *I = dyn_cast<BitCastInst>(V);
1721+
if (!I)
1722+
return false;
1723+
Type *SrcType = I->getSrcTy();
1724+
Type *DstType = I->getType();
1725+
// Make sure the bitcast doesn't change between scalar and vector and
1726+
// doesn't change the number of vector elements.
1727+
if (SrcType->isVectorTy() != DstType->isVectorTy())
1728+
return false;
1729+
if (VectorType *SrcVecTy = dyn_cast<VectorType>(SrcType);
1730+
SrcVecTy && SrcVecTy->getElementCount() !=
1731+
cast<VectorType>(DstType)->getElementCount())
1732+
return false;
1733+
return Op.match(I->getOperand(0));
1734+
}
1735+
};
1736+
1737+
template <typename OpTy>
1738+
inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
1739+
return ElementWiseBitCast_match<OpTy>(Op);
1740+
}
1741+
17141742
/// Matches PtrToInt.
17151743
template <typename OpTy>
17161744
inline CastOperator_match<OpTy, Instruction::PtrToInt>

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
30343034
// floating-point casts:
30353035
// icmp slt (bitcast (uitofp X)), 0 --> false
30363036
// icmp sgt (bitcast (uitofp X)), -1 --> true
3037-
if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
3037+
if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) {
30383038
if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
30393039
return ConstantInt::getFalse(ITy);
30403040
if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2531,14 +2531,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
25312531
// Assumes any IEEE-represented type has the sign bit in the high bit.
25322532
// TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
25332533
Value *CastOp;
2534-
if (match(Op0, m_BitCast(m_Value(CastOp))) &&
2534+
if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
25352535
match(Op1, m_MaxSignedValue()) &&
25362536
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
2537-
Attribute::NoImplicitFloat)) {
2537+
Attribute::NoImplicitFloat)) {
25382538
Type *EltTy = CastOp->getType()->getScalarType();
2539-
if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
2540-
EltTy->getPrimitiveSizeInBits() ==
2541-
I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
2539+
if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
25422540
Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
25432541
return new BitCastInst(FAbs, I.getType());
25442542
}
@@ -3963,13 +3961,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
39633961
// This is a generous interpretation of noimplicitfloat; this is not a true
39643962
// floating-point operation.
39653963
Value *CastOp;
3966-
if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
3964+
if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
3965+
match(Op1, m_SignMask()) &&
39673966
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
39683967
Attribute::NoImplicitFloat)) {
39693968
Type *EltTy = CastOp->getType()->getScalarType();
3970-
if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
3971-
EltTy->getPrimitiveSizeInBits() ==
3972-
I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
3969+
if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
39733970
Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
39743971
Value *FNegFAbs = Builder.CreateFNeg(FAbs);
39753972
return new BitCastInst(FNegFAbs, I.getType());
@@ -4739,13 +4736,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
47394736
// Assumes any IEEE-represented type has the sign bit in the high bit.
47404737
// TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
47414738
Value *CastOp;
4742-
if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
4739+
if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
4740+
match(Op1, m_SignMask()) &&
47434741
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
47444742
Attribute::NoImplicitFloat)) {
47454743
Type *EltTy = CastOp->getType()->getScalarType();
4746-
if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
4747-
EltTy->getPrimitiveSizeInBits() ==
4748-
I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
4744+
if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
47494745
Value *FNeg = Builder.CreateFNeg(CastOp);
47504746
return new BitCastInst(FNeg, I.getType());
47514747
}

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,15 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
182182
if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() ||
183183
(CI.getOpcode() == Instruction::Trunc &&
184184
shouldChangeType(CI.getSrcTy(), CI.getType()))) {
185-
if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
186-
replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
187-
return NV;
185+
186+
// If it's a bitcast involving vectors, make sure it has the same number
187+
// of elements on both sides.
188+
if (CI.getOpcode() != Instruction::BitCast ||
189+
match(&CI, m_ElementWiseBitCast(m_Value()))) {
190+
if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
191+
replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
192+
return NV;
193+
}
188194
}
189195
}
190196
}

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,15 +1835,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
18351835
Value *V;
18361836
if (!Cmp.getParent()->getParent()->hasFnAttribute(
18371837
Attribute::NoImplicitFloat) &&
1838-
Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) {
1839-
Type *SrcType = V->getType();
1840-
Type *DstType = X->getType();
1841-
Type *FPType = SrcType->getScalarType();
1842-
// Make sure the bitcast doesn't change between scalar and vector and
1843-
// doesn't change the number of vector elements.
1844-
if (SrcType->isVectorTy() == DstType->isVectorTy() &&
1845-
SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() &&
1846-
FPType->isIEEELikeFPTy() && C1 == *C2) {
1838+
Cmp.isEquality() &&
1839+
match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) {
1840+
Type *FPType = V->getType()->getScalarType();
1841+
if (FPType->isIEEELikeFPTy() && C1 == *C2) {
18471842
APInt ExponentMask =
18481843
APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt();
18491844
if (C1 == ExponentMask) {
@@ -7755,9 +7750,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
77557750
// Ignore signbit of bitcasted int when comparing equality to FP 0.0:
77567751
// fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
77577752
if (match(Op1, m_PosZeroFP()) &&
7758-
match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
7759-
X->getType()->isVectorTy() == OpType->isVectorTy() &&
7760-
X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
7753+
match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) {
77617754
ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
77627755
if (Pred == FCmpInst::FCMP_OEQ)
77637756
IntPred = ICmpInst::ICMP_EQ;

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2365,9 +2365,6 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
23652365
Value *FVal = Sel.getFalseValue();
23662366
Type *SelType = Sel.getType();
23672367

2368-
if (ICmpInst::makeCmpResultType(TVal->getType()) != Cond->getType())
2369-
return nullptr;
2370-
23712368
// Match select ?, TC, FC where the constants are equal but negated.
23722369
// TODO: Generalize to handle a negated variable operand?
23732370
const APFloat *TC, *FC;
@@ -2382,7 +2379,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
23822379
const APInt *C;
23832380
bool IsTrueIfSignSet;
23842381
ICmpInst::Predicate Pred;
2385-
if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
2382+
if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
2383+
m_APInt(C)))) ||
23862384
!InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
23872385
X->getType() != SelType)
23882386
return nullptr;
@@ -2770,8 +2768,6 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
27702768

27712769
// Match select with (icmp slt (bitcast X to int), 0)
27722770
// or (icmp sgt (bitcast X to int), -1)
2773-
if (ICmpInst::makeCmpResultType(SI.getType()) != CondVal->getType())
2774-
return ChangedFMF ? &SI : nullptr;
27752771

27762772
for (bool Swap : {false, true}) {
27772773
Value *TrueVal = SI.getTrueValue();
@@ -2783,7 +2779,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
27832779
CmpInst::Predicate Pred;
27842780
const APInt *C;
27852781
bool TrueIfSigned;
2786-
if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) ||
2782+
if (!match(CondVal,
2783+
m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
27872784
!IC.isSignBitCheck(Pred, *C, TrueIfSigned))
27882785
continue;
27892786
if (!match(TrueVal, m_FNeg(m_Specific(X))))

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,21 +1474,6 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
14741474
if (SI->getType()->isIntOrIntVectorTy(1))
14751475
return nullptr;
14761476

1477-
// If it's a bitcast involving vectors, make sure it has the same number of
1478-
// elements on both sides.
1479-
if (auto *BC = dyn_cast<BitCastInst>(&Op)) {
1480-
VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy());
1481-
VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy());
1482-
1483-
// Verify that either both or neither are vectors.
1484-
if ((SrcTy == nullptr) != (DestTy == nullptr))
1485-
return nullptr;
1486-
1487-
// If vectors, verify that they have the same number of elements.
1488-
if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount())
1489-
return nullptr;
1490-
}
1491-
14921477
// Test if a FCmpInst instruction is used exclusively by a select as
14931478
// part of a minimum or maximum operation. If so, refrain from doing
14941479
// any other folding. This helps out other analyses which understand

llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ define <2 x i1> @i32_cast_cmp_sgt_int_m1_uitofp_float_vec(<2 x i32> %i) {
5757
ret <2 x i1> %cmp
5858
}
5959

60+
define i1 @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_mismatch(<2 x i32> %i) {
61+
; CHECK-LABEL: @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_mismatch(
62+
; CHECK-NEXT: [[F:%.*]] = uitofp <2 x i32> [[I:%.*]] to <2 x float>
63+
; CHECK-NEXT: [[B:%.*]] = bitcast <2 x float> [[F]] to i64
64+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i64 [[B]], -1
65+
; CHECK-NEXT: ret i1 [[CMP]]
66+
;
67+
%f = uitofp <2 x i32> %i to <2 x float>
68+
%b = bitcast <2 x float> %f to i64
69+
%cmp = icmp sgt i64 %b, -1
70+
ret i1 %cmp
71+
}
72+
6073
define <3 x i1> @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_undef(<3 x i32> %i) {
6174
; CHECK-LABEL: @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_undef(
6275
; CHECK-NEXT: ret <3 x i1> <i1 true, i1 true, i1 true>

llvm/unittests/IR/PatternMatch.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,48 @@ TEST_F(PatternMatchTest, ZExtSExtSelf) {
530530
EXPECT_TRUE(m_ZExtOrSExtOrSelf(m_One()).match(One64S));
531531
}
532532

533+
TEST_F(PatternMatchTest, BitCast) {
534+
Value *OneDouble = ConstantFP::get(IRB.getDoubleTy(), APFloat(1.0));
535+
Value *ScalableDouble = ConstantFP::get(
536+
VectorType::get(IRB.getDoubleTy(), 2, /*Scalable=*/true), APFloat(1.0));
537+
// scalar -> scalar
538+
Value *DoubleToI64 = IRB.CreateBitCast(OneDouble, IRB.getInt64Ty());
539+
// scalar -> vector
540+
Value *DoubleToV2I32 = IRB.CreateBitCast(
541+
OneDouble, VectorType::get(IRB.getInt32Ty(), 2, /*Scalable=*/false));
542+
// vector -> scalar
543+
Value *V2I32ToDouble = IRB.CreateBitCast(DoubleToV2I32, IRB.getDoubleTy());
544+
// vector -> vector (same count)
545+
Value *V2I32ToV2Float = IRB.CreateBitCast(
546+
DoubleToV2I32, VectorType::get(IRB.getFloatTy(), 2, /*Scalable=*/false));
547+
// vector -> vector (different count)
548+
Value *V2I32TOV4I16 = IRB.CreateBitCast(
549+
DoubleToV2I32, VectorType::get(IRB.getInt16Ty(), 4, /*Scalable=*/false));
550+
// scalable vector -> scalable vector (same count)
551+
Value *NXV2DoubleToNXV2I64 = IRB.CreateBitCast(
552+
ScalableDouble, VectorType::get(IRB.getInt64Ty(), 2, /*Scalable=*/true));
553+
// scalable vector -> scalable vector (different count)
554+
Value *NXV2I64ToNXV4I32 = IRB.CreateBitCast(
555+
NXV2DoubleToNXV2I64,
556+
VectorType::get(IRB.getInt32Ty(), 4, /*Scalable=*/true));
557+
558+
EXPECT_TRUE(m_BitCast(m_Value()).match(DoubleToI64));
559+
EXPECT_TRUE(m_BitCast(m_Value()).match(DoubleToV2I32));
560+
EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToDouble));
561+
EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToV2Float));
562+
EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32TOV4I16));
563+
EXPECT_TRUE(m_BitCast(m_Value()).match(NXV2DoubleToNXV2I64));
564+
EXPECT_TRUE(m_BitCast(m_Value()).match(NXV2I64ToNXV4I32));
565+
566+
EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(DoubleToI64));
567+
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(DoubleToV2I32));
568+
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(V2I32ToDouble));
569+
EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(V2I32ToV2Float));
570+
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(V2I32TOV4I16));
571+
EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(NXV2DoubleToNXV2I64));
572+
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
573+
}
574+
533575
TEST_F(PatternMatchTest, Power2) {
534576
Value *C128 = IRB.getInt32(128);
535577
Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));

0 commit comments

Comments
 (0)