Skip to content

Commit fa9f293

Browse files
committed
[ValueTracking] Compute known FPClass from dominating condition
1 parent 4f984e6 commit fa9f293

File tree

3 files changed

+71
-15
lines changed

3 files changed

+71
-15
lines changed

llvm/lib/Analysis/DomConditionCache.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ static void findAffectedValues(Value *Cond,
5151
// Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
5252
if (match(A, m_Add(m_Value(X), m_ConstantInt())))
5353
AddAffected(X);
54+
// Handle icmp slt/sgt (bitcast X to int) 0/-1
55+
if (match(A, m_BitCast(m_Value(X))))
56+
Affected.push_back(X);
5457
}
5558
}
5659
}

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4213,9 +4213,56 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
42134213
return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
42144214
}
42154215

4216-
static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
4217-
const SimplifyQuery &Q) {
4218-
FPClassTest KnownFromAssume = fcAllFlags;
4216+
static KnownFPClass computeKnownFPClassFromContext(const Value *V,
4217+
const SimplifyQuery &Q) {
4218+
KnownFPClass KnownFromContext;
4219+
4220+
if (!Q.CxtI)
4221+
return KnownFromContext;
4222+
4223+
if (Q.DC && Q.DT) {
4224+
auto computeKnownFPClassFromCmp = [&](CmpInst::Predicate Pred, Value *LHS,
4225+
Value *RHS) {
4226+
if (match(LHS, m_BitCast(m_Specific(V)))) {
4227+
Type *SrcType = V->getType();
4228+
Type *DstType = LHS->getType();
4229+
4230+
// Make sure the bitcast doesn't change between scalar and vector and
4231+
// doesn't change the number of vector elements.
4232+
if (SrcType->isVectorTy() == DstType->isVectorTy() &&
4233+
SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) {
4234+
// TODO: move IsSignBitCheck to ValueTracking
4235+
if ((Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero())) ||
4236+
(Pred == ICmpInst::ICMP_SLE && match(RHS, m_AllOnes())))
4237+
KnownFromContext.signBitMustBeOne();
4238+
else if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()) ||
4239+
(Pred == ICmpInst::ICMP_SGE && match(RHS, m_Zero())))
4240+
KnownFromContext.signBitMustBeZero();
4241+
}
4242+
}
4243+
};
4244+
4245+
// Handle dominating conditions.
4246+
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
4247+
// TODO: handle fcmps
4248+
auto *Cmp = dyn_cast<ICmpInst>(BI->getCondition());
4249+
if (!Cmp)
4250+
continue;
4251+
4252+
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
4253+
if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
4254+
computeKnownFPClassFromCmp(Cmp->getPredicate(), Cmp->getOperand(0),
4255+
Cmp->getOperand(1));
4256+
4257+
BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
4258+
if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
4259+
computeKnownFPClassFromCmp(Cmp->getInversePredicate(),
4260+
Cmp->getOperand(0), Cmp->getOperand(1));
4261+
}
4262+
}
4263+
4264+
if (!Q.AC)
4265+
return KnownFromContext;
42194266

42204267
// Try to restrict the floating-point classes based on information from
42214268
// assumptions.
@@ -4242,16 +4289,16 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
42424289
auto [CmpVal, MaskIfTrue, MaskIfFalse] =
42434290
fcmpImpliesClass(Pred, *F, LHS, *CRHS, LHS != V);
42444291
if (CmpVal == V)
4245-
KnownFromAssume &= MaskIfTrue;
4292+
KnownFromContext.knownNot(~MaskIfTrue);
42464293
}
42474294
} else if (match(I->getArgOperand(0),
42484295
m_Intrinsic<Intrinsic::is_fpclass>(
42494296
m_Value(LHS), m_ConstantInt(ClassVal)))) {
4250-
KnownFromAssume &= static_cast<FPClassTest>(ClassVal);
4297+
KnownFromContext.knownNot(~static_cast<FPClassTest>(ClassVal));
42514298
}
42524299
}
42534300

4254-
return KnownFromAssume;
4301+
return KnownFromContext;
42554302
}
42564303

42574304
void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
@@ -4359,17 +4406,21 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
43594406
KnownNotFromFlags |= fcInf;
43604407
}
43614408

4362-
if (Q.AC) {
4363-
FPClassTest AssumedClasses = computeKnownFPClassFromAssumes(V, Q);
4364-
KnownNotFromFlags |= ~AssumedClasses;
4365-
}
4409+
KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
4410+
KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
43664411

43674412
// We no longer need to find out about these bits from inputs if we can
43684413
// assume this from flags/attributes.
43694414
InterestedClasses &= ~KnownNotFromFlags;
43704415

43714416
auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
43724417
Known.knownNot(KnownNotFromFlags);
4418+
if (!Known.SignBit && AssumedClasses.SignBit) {
4419+
if (*AssumedClasses.SignBit)
4420+
Known.signBitMustBeOne();
4421+
else
4422+
Known.signBitMustBeZero();
4423+
}
43734424
});
43744425

43754426
if (!Op)
@@ -5271,7 +5322,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
52715322

52725323
bool First = true;
52735324

5274-
for (Value *IncValue : P->incoming_values()) {
5325+
for (const Use &U : P->operands()) {
5326+
Value *IncValue = U.get();
52755327
// Skip direct self references.
52765328
if (IncValue == P)
52775329
continue;
@@ -5280,8 +5332,10 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
52805332
// Recurse, but cap the recursion to two levels, because we don't want
52815333
// to waste time spinning around in loops. We need at least depth 2 to
52825334
// detect known sign bits.
5283-
computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc,
5284-
PhiRecursionLimit, Q);
5335+
computeKnownFPClass(
5336+
IncValue, DemandedElts, InterestedClasses, KnownSrc,
5337+
PhiRecursionLimit,
5338+
Q.getWithInstruction(P->getIncomingBlock(U)->getTerminator()));
52855339

52865340
if (First) {
52875341
Known = KnownSrc;

llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ define float @test_signbit_check(float %x, i1 %cond) {
1616
; CHECK-NEXT: br label [[IF_END]]
1717
; CHECK: if.end:
1818
; CHECK-NEXT: [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[X]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
19-
; CHECK-NEXT: [[RET:%.*]] = call float @llvm.fabs.f32(float [[VALUE]])
20-
; CHECK-NEXT: ret float [[RET]]
19+
; CHECK-NEXT: ret float [[VALUE]]
2120
;
2221
%i32 = bitcast float %x to i32
2322
%cmp = icmp slt i32 %i32, 0

0 commit comments

Comments
 (0)