Skip to content

Commit 9f02fb4

Browse files
committed
[ValueTracking] Compute known FPClass from dominating condition
1 parent b0c9bb7 commit 9f02fb4

File tree

3 files changed

+97
-29
lines changed

3 files changed

+97
-29
lines changed

llvm/lib/Analysis/DomConditionCache.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ static void findAffectedValues(Value *Cond,
5151
// Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
5252
if (match(A, m_Add(m_Value(X), m_ConstantInt())))
5353
AddAffected(X);
54+
// Handle icmp slt/sgt (bitcast X to int) 0/-1
55+
if (match(A, m_BitCast(m_Value(X))))
56+
Affected.push_back(X);
5457
}
5558
}
5659
}

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 93 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4213,9 +4213,82 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
42134213
return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
42144214
}
42154215

4216-
static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
4217-
const SimplifyQuery &Q) {
4218-
FPClassTest KnownFromAssume = fcAllFlags;
4216+
static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
4217+
bool CondIsTrue,
4218+
const Instruction *CxtI,
4219+
KnownFPClass &KnownFromContext) {
4220+
CmpInst::Predicate Pred;
4221+
Value *LHS;
4222+
Value *RHS;
4223+
uint64_t ClassVal = 0;
4224+
if (match(Cond, m_Cmp(Pred, m_Value(LHS), m_Value(RHS)))) {
4225+
if (CmpInst::isIntPredicate(Pred)) {
4226+
if (!match(LHS, m_BitCast(m_Specific(V))))
4227+
return;
4228+
Type *SrcType = V->getType();
4229+
Type *DstType = LHS->getType();
4230+
4231+
// Make sure the bitcast doesn't change between scalar and vector and
4232+
// doesn't change the number of vector elements.
4233+
if (SrcType->isVectorTy() == DstType->isVectorTy() &&
4234+
SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) {
4235+
// TODO: move IsSignBitCheck to ValueTracking
4236+
bool TrueIfSigned;
4237+
if ((Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero())) ||
4238+
(Pred == ICmpInst::ICMP_SLE && match(RHS, m_AllOnes())))
4239+
TrueIfSigned = true;
4240+
else if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()) ||
4241+
(Pred == ICmpInst::ICMP_SGE && match(RHS, m_Zero())))
4242+
TrueIfSigned = false;
4243+
else
4244+
return;
4245+
if (TrueIfSigned == CondIsTrue)
4246+
KnownFromContext.signBitMustBeOne();
4247+
else
4248+
KnownFromContext.signBitMustBeZero();
4249+
}
4250+
} else {
4251+
const APFloat *CRHS;
4252+
if (match(RHS, m_APFloat(CRHS))) {
4253+
auto [CmpVal, MaskIfTrue, MaskIfFalse] = fcmpImpliesClass(
4254+
Pred, *CxtI->getParent()->getParent(), LHS, *CRHS, LHS != V);
4255+
if (CmpVal == V)
4256+
KnownFromContext.knownNot(~(CondIsTrue ? MaskIfTrue : MaskIfFalse));
4257+
}
4258+
}
4259+
} else if (match(Cond, m_Intrinsic<Intrinsic::is_fpclass>(
4260+
m_Value(LHS), m_ConstantInt(ClassVal)))) {
4261+
FPClassTest Mask = static_cast<FPClassTest>(ClassVal);
4262+
KnownFromContext.knownNot(CondIsTrue ? ~Mask : Mask);
4263+
}
4264+
}
4265+
4266+
static KnownFPClass computeKnownFPClassFromContext(const Value *V,
4267+
const SimplifyQuery &Q) {
4268+
KnownFPClass KnownFromContext;
4269+
4270+
if (!Q.CxtI)
4271+
return KnownFromContext;
4272+
4273+
if (Q.DC && Q.DT) {
4274+
// Handle dominating conditions.
4275+
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
4276+
Value *Cond = BI->getCondition();
4277+
4278+
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
4279+
if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
4280+
computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/true, Q.CxtI,
4281+
KnownFromContext);
4282+
4283+
BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
4284+
if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
4285+
computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/false, Q.CxtI,
4286+
KnownFromContext);
4287+
}
4288+
}
4289+
4290+
if (!Q.AC)
4291+
return KnownFromContext;
42194292

42204293
// Try to restrict the floating-point classes based on information from
42214294
// assumptions.
@@ -4233,25 +4306,11 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
42334306
if (!isValidAssumeForContext(I, Q.CxtI, Q.DT))
42344307
continue;
42354308

4236-
CmpInst::Predicate Pred;
4237-
Value *LHS, *RHS;
4238-
uint64_t ClassVal = 0;
4239-
if (match(I->getArgOperand(0), m_FCmp(Pred, m_Value(LHS), m_Value(RHS)))) {
4240-
const APFloat *CRHS;
4241-
if (match(RHS, m_APFloat(CRHS))) {
4242-
auto [CmpVal, MaskIfTrue, MaskIfFalse] =
4243-
fcmpImpliesClass(Pred, *F, LHS, *CRHS, LHS != V);
4244-
if (CmpVal == V)
4245-
KnownFromAssume &= MaskIfTrue;
4246-
}
4247-
} else if (match(I->getArgOperand(0),
4248-
m_Intrinsic<Intrinsic::is_fpclass>(
4249-
m_Value(LHS), m_ConstantInt(ClassVal)))) {
4250-
KnownFromAssume &= static_cast<FPClassTest>(ClassVal);
4251-
}
4309+
computeKnownFPClassFromCond(V, I->getArgOperand(0), /*CondIsTrue=*/true,
4310+
Q.CxtI, KnownFromContext);
42524311
}
42534312

4254-
return KnownFromAssume;
4313+
return KnownFromContext;
42554314
}
42564315

42574316
void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
@@ -4359,17 +4418,21 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
43594418
KnownNotFromFlags |= fcInf;
43604419
}
43614420

4362-
if (Q.AC) {
4363-
FPClassTest AssumedClasses = computeKnownFPClassFromAssumes(V, Q);
4364-
KnownNotFromFlags |= ~AssumedClasses;
4365-
}
4421+
KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
4422+
KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
43664423

43674424
// We no longer need to find out about these bits from inputs if we can
43684425
// assume this from flags/attributes.
43694426
InterestedClasses &= ~KnownNotFromFlags;
43704427

43714428
auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
43724429
Known.knownNot(KnownNotFromFlags);
4430+
if (!Known.SignBit && AssumedClasses.SignBit) {
4431+
if (*AssumedClasses.SignBit)
4432+
Known.signBitMustBeOne();
4433+
else
4434+
Known.signBitMustBeZero();
4435+
}
43734436
});
43744437

43754438
if (!Op)
@@ -5271,7 +5334,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
52715334

52725335
bool First = true;
52735336

5274-
for (Value *IncValue : P->incoming_values()) {
5337+
for (const Use &U : P->operands()) {
5338+
Value *IncValue = U.get();
52755339
// Skip direct self references.
52765340
if (IncValue == P)
52775341
continue;
@@ -5280,8 +5344,10 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
52805344
// Recurse, but cap the recursion to two levels, because we don't want
52815345
// to waste time spinning around in loops. We need at least depth 2 to
52825346
// detect known sign bits.
5283-
computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc,
5284-
PhiRecursionLimit, Q);
5347+
computeKnownFPClass(
5348+
IncValue, DemandedElts, InterestedClasses, KnownSrc,
5349+
PhiRecursionLimit,
5350+
Q.getWithInstruction(P->getIncomingBlock(U)->getTerminator()));
52855351

52865352
if (First) {
52875353
Known = KnownSrc;

llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ define float @test_signbit_check(float %x, i1 %cond) {
1616
; CHECK-NEXT: br label [[IF_END]]
1717
; CHECK: if.end:
1818
; CHECK-NEXT: [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[X]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
19-
; CHECK-NEXT: [[RET:%.*]] = call float @llvm.fabs.f32(float [[VALUE]])
20-
; CHECK-NEXT: ret float [[RET]]
19+
; CHECK-NEXT: ret float [[VALUE]]
2120
;
2221
%i32 = bitcast float %x to i32
2322
%cmp = icmp slt i32 %i32, 0

0 commit comments

Comments
 (0)