-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[PatternMatch] Add a matching helper m_ElementWiseBitCast
. NFC.
#80764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Yingwei Zheng (dtcxzyw) ChangesThis patch introduces a matching helper Full diff: https://github.com/llvm/llvm-project/pull/80764.diff 5 Files Affected:
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 878079c4fe4e8e..7ce42e30639092 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1711,6 +1711,30 @@ m_BitCast(const OpTy &Op) {
return CastOperator_match<OpTy, Instruction::BitCast>(Op);
}
+template <typename Op_t> struct ElementWiseBitCast_match {
+ Op_t Op;
+
+ ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+ template <typename OpTy> bool match(OpTy *V) {
+ if (auto *I = dyn_cast<BitCastInst>(V)) {
+ Type *SrcType = I->getSrcTy();
+ Type *DstType = I->getType();
+ // Make sure the bitcast doesn't change between scalar and vector and
+ // doesn't change the number of vector elements.
+ if (SrcType->isVectorTy() == DstType->isVectorTy() &&
+ SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits())
+ return Op.match(I->getOperand(0));
+ }
+ return false;
+ }
+};
+
+template <typename OpTy>
+inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
+ return ElementWiseBitCast_match<OpTy>(Op);
+}
+
/// Matches PtrToInt.
template <typename OpTy>
inline CastOperator_match<OpTy, Instruction::PtrToInt>
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2793b798f35f36..01b017142cfcb0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
// floating-point casts:
// icmp slt (bitcast (uitofp X)), 0 --> false
// icmp sgt (bitcast (uitofp X)), -1 --> true
- if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+ if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) {
if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
return ConstantInt::getFalse(ITy);
if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6ca4d6d673068e..0409f33445f0af 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2493,14 +2493,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
// Assumes any IEEE-represented type has the sign bit in the high bit.
// TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
match(Op1, m_MaxSignedValue()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
- Attribute::NoImplicitFloat)) {
+ Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
return new BitCastInst(FAbs, I.getType());
}
@@ -3925,13 +3923,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
// This is generous interpretation of noimplicitfloat, this is not a true
// floating-point operation.
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+ match(Op1, m_SignMask()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
Value *FNegFAbs = Builder.CreateFNeg(FAbs);
return new BitCastInst(FNegFAbs, I.getType());
@@ -4701,13 +4698,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
// Assumes any IEEE-represented type has the sign bit in the high bit.
// TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+ match(Op1, m_SignMask()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FNeg = Builder.CreateFNeg(CastOp);
return new BitCastInst(FNeg, I.getType());
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 380cb3504209d3..cda1061fb35a8d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1834,15 +1834,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
Value *V;
if (!Cmp.getParent()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat) &&
- Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) {
- Type *SrcType = V->getType();
- Type *DstType = X->getType();
- Type *FPType = SrcType->getScalarType();
- // Make sure the bitcast doesn't change between scalar and vector and
- // doesn't change the number of vector elements.
- if (SrcType->isVectorTy() == DstType->isVectorTy() &&
- SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() &&
- FPType->isIEEELikeFPTy() && C1 == *C2) {
+ Cmp.isEquality() &&
+ match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) {
+ Type *FPType = V->getType()->getScalarType();
+ if (FPType->isIEEELikeFPTy() && C1 == *C2) {
APInt ExponentMask =
APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt();
if (C1 == ExponentMask) {
@@ -7754,9 +7749,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
// Ignore signbit of bitcasted int when comparing equality to FP 0.0:
// fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
if (match(Op1, m_PosZeroFP()) &&
- match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
- X->getType()->isVectorTy() == OpType->isVectorTy() &&
- X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
+ match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) {
ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
if (Pred == FCmpInst::FCMP_OEQ)
IntPred = ICmpInst::ICMP_EQ;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 2756f81ed9e620..d6f447125269c9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2382,7 +2382,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
const APInt *C;
bool IsTrueIfSignSet;
ICmpInst::Predicate Pred;
- if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
+ if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
+ m_APInt(C)))) ||
!InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
X->getType() != SelType)
return nullptr;
@@ -2783,7 +2784,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
CmpInst::Predicate Pred;
const APInt *C;
bool TrueIfSigned;
- if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) ||
+ if (!match(CondVal,
+ m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
!IC.isSignBitCheck(Pred, *C, TrueIfSigned))
continue;
if (!match(TrueVal, m_FNeg(m_Specific(X))))
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a unit test in unittests/IR/PatternMatch.cpp
47f5744
to
0bcdd86
Compare
0bcdd86
to
62894a2
Compare
62894a2
to
8ac6782
Compare
// If it's a bitcast involving vectors, make sure it has the same number | ||
// of elements on both sides. | ||
if (CI.getOpcode() != Instruction::BitCast || | ||
match(&CI, m_ElementWiseBitCast(m_Value()))) { | ||
if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { | ||
replaceAllDbgUsesWith(*Sel, *NV, CI, DT); | ||
return NV; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this another untested bug fix?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed the check in InstructionCombing.cpp
:
// If it's a bitcast involving vectors, make sure it has the same number of
// elements on both sides.
if (auto *BC = dyn_cast<BitCastInst>(&Op)) {
VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy());
VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy());
// Verify that either both or neither are vectors.
if ((SrcTy == nullptr) != (DestTy == nullptr))
return nullptr;
// If vectors, verify that they have the same number of elements.
if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount())
return nullptr;
}
This patch introduces a matching helper
m_ElementWiseBitCast
, which is used for matching element-wise int<-> fp casts.The motivation of this patch is to avoid duplicating checks in #80740 and #80414.