diff options
Diffstat (limited to 'gnu/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp')
| -rw-r--r-- | gnu/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp | 589 |
1 files changed, 356 insertions, 233 deletions
diff --git a/gnu/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/gnu/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index cf98088111b..c5ed6d5c1b8 100644 --- a/gnu/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/gnu/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -43,6 +43,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" @@ -52,6 +53,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -179,10 +181,7 @@ public: OS << " Step: "; Step->print(OS); OS << " End: "; - if (End) - End->print(OS); - else - OS << "(null)"; + End->print(OS); OS << "\n CheckUse: "; getCheckUse()->getUser()->print(OS); OS << " Operand: " << getCheckUse()->getOperandNo() << "\n"; @@ -196,7 +195,7 @@ public: Use *getCheckUse() const { return CheckUse; } /// Represents an signed integer range [Range.getBegin(), Range.getEnd()). If - /// R.getEnd() sle R.getBegin(), then R denotes the empty range. + /// R.getEnd() le R.getBegin(), then R denotes the empty range. class Range { const SCEV *Begin; @@ -238,17 +237,31 @@ public: /// checks, and hence don't end up in \p Checks. static void extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE, - BranchProbabilityInfo &BPI, + BranchProbabilityInfo *BPI, SmallVectorImpl<InductiveRangeCheck> &Checks); }; -class InductiveRangeCheckElimination : public LoopPass { +class InductiveRangeCheckElimination { + ScalarEvolution &SE; + BranchProbabilityInfo *BPI; + DominatorTree &DT; + LoopInfo &LI; + +public: + InductiveRangeCheckElimination(ScalarEvolution &SE, + BranchProbabilityInfo *BPI, DominatorTree &DT, + LoopInfo &LI) + : SE(SE), BPI(BPI), DT(DT), LI(LI) {} + + bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop); +}; + +class IRCELegacyPass : public LoopPass { public: static char ID; - InductiveRangeCheckElimination() : LoopPass(ID) { - initializeInductiveRangeCheckEliminationPass( - *PassRegistry::getPassRegistry()); + IRCELegacyPass() : LoopPass(ID) { + initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry()); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -261,14 +274,14 @@ public: } // end anonymous namespace -char InductiveRangeCheckElimination::ID = 0; +char IRCELegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce", +INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce", "Inductive range check elimination", false, false) INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_END(InductiveRangeCheckElimination, "irce", - "Inductive range check elimination", false, false) +INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination", + false, false) StringRef InductiveRangeCheck::rangeCheckKindToStr( InductiveRangeCheck::RangeCheckKind RCK) { @@ -299,13 +312,8 @@ InductiveRangeCheck::RangeCheckKind InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, Value *&Index, Value *&Length, bool &IsSigned) { - auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) { - const SCEV *S = SE.getSCEV(V); - if (isa<SCEVCouldNotCompute>(S)) - return false; - - return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant && - SE.isKnownNonNegative(S); + auto IsLoopInvariant = [&SE, L](Value *V) { + return SE.isLoopInvariant(SE.getSCEV(V), L); }; ICmpInst::Predicate Pred = ICI->getPredicate(); @@ -337,7 +345,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, return RANGE_CHECK_LOWER; } - if (IsNonNegativeAndNotLoopVarying(LHS)) { + if (IsLoopInvariant(LHS)) { Index = RHS; Length = LHS; return RANGE_CHECK_UPPER; @@ -349,7 +357,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, LLVM_FALLTHROUGH; case ICmpInst::ICMP_UGT: IsSigned = false; - if (IsNonNegativeAndNotLoopVarying(LHS)) { + if (IsLoopInvariant(LHS)) { Index = RHS; Length = LHS; return RANGE_CHECK_BOTH; @@ -394,8 +402,23 @@ void InductiveRangeCheck::extractRangeChecksFromCond( if (!IsAffineIndex) return; + const SCEV *End = nullptr; + // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". + // We can potentially do much better here. + if (Length) + End = SE.getSCEV(Length); + else { + assert(RCKind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); + // So far we can only reach this point for Signed range check. This may + // change in future. In this case we will need to pick Unsigned max for the + // unsigned range check. + unsigned BitWidth = cast<IntegerType>(IndexAddRec->getType())->getBitWidth(); + const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + End = SIntMax; + } + InductiveRangeCheck IRC; - IRC.End = Length ? SE.getSCEV(Length) : nullptr; + IRC.End = End; IRC.Begin = IndexAddRec->getStart(); IRC.Step = IndexAddRec->getStepRecurrence(SE); IRC.CheckUse = &ConditionUse; @@ -405,15 +428,15 @@ void InductiveRangeCheck::extractRangeChecksFromCond( } void InductiveRangeCheck::extractRangeChecksFromBranch( - BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo &BPI, + BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI, SmallVectorImpl<InductiveRangeCheck> &Checks) { if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()) return; BranchProbability LikelyTaken(15, 16); - if (!SkipProfitabilityChecks && - BPI.getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) + if (!SkipProfitabilityChecks && BPI && + BPI->getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) return; SmallPtrSet<Value *, 8> Visited; @@ -504,9 +527,8 @@ struct LoopStructure { } static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &, - BranchProbabilityInfo &BPI, - Loop &, - const char *&); + BranchProbabilityInfo *BPI, + Loop &, const char *&); }; /// This class is used to constrain loops to run within a given iteration space. @@ -573,7 +595,7 @@ class LoopConstrainer { // Create the appropriate loop structure needed to describe a cloned copy of // `Original`. The clone is described by `VM`. Loop *createClonedLoopStructure(Loop *Original, Loop *Parent, - ValueToValueMapTy &VM); + ValueToValueMapTy &VM, bool IsSubloop); // Rewrite the iteration space of the loop denoted by (LS, Preheader). The // iteration space of the rewritten loop ends at ExitLoopAt. The start of the @@ -625,8 +647,8 @@ class LoopConstrainer { LLVMContext &Ctx; ScalarEvolution &SE; DominatorTree &DT; - LPPassManager &LPM; LoopInfo &LI; + function_ref<void(Loop *, bool)> LPMAddNewLoop; // Information about the original loop we started out with. Loop &OriginalLoop; @@ -646,12 +668,13 @@ class LoopConstrainer { LoopStructure MainLoopStructure; public: - LoopConstrainer(Loop &L, LoopInfo &LI, LPPassManager &LPM, + LoopConstrainer(Loop &L, LoopInfo &LI, + function_ref<void(Loop *, bool)> LPMAddNewLoop, const LoopStructure &LS, ScalarEvolution &SE, DominatorTree &DT, InductiveRangeCheck::Range R) : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), - SE(SE), DT(DT), LPM(LPM), LI(LI), OriginalLoop(L), Range(R), - MainLoopStructure(LS) {} + SE(SE), DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), + Range(R), MainLoopStructure(LS) {} // Entry point for the algorithm. Returns true on success. bool run(); @@ -666,56 +689,141 @@ void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, PN->setIncomingBlock(i, ReplaceBy); } -static bool CanBeMax(ScalarEvolution &SE, const SCEV *S, bool Signed) { - APInt Max = Signed ? - APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()) : - APInt::getMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); - return SE.getSignedRange(S).contains(Max) && - SE.getUnsignedRange(S).contains(Max); +static bool CannotBeMaxInLoop(const SCEV *BoundSCEV, Loop *L, + ScalarEvolution &SE, bool Signed) { + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Max = Signed ? APInt::getSignedMaxValue(BitWidth) : + APInt::getMaxValue(BitWidth); + auto Predicate = Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; + return SE.isAvailableAtLoopEntry(BoundSCEV, L) && + SE.isLoopEntryGuardedByCond(L, Predicate, BoundSCEV, + SE.getConstant(Max)); +} + +/// Given a loop with an deccreasing induction variable, is it possible to +/// safely calculate the bounds of a new loop using the given Predicate. +static bool isSafeDecreasingBound(const SCEV *Start, + const SCEV *BoundSCEV, const SCEV *Step, + ICmpInst::Predicate Pred, + unsigned LatchBrExitIdx, + Loop *L, ScalarEvolution &SE) { + if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && + Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) + return false; + + if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) + return false; + + assert(SE.isKnownNegative(Step) && "expecting negative step"); + + LLVM_DEBUG(dbgs() << "irce: isSafeDecreasingBound with:\n"); + LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); + LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); + LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); + LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred) + << "\n"); + LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); + + bool IsSigned = ICmpInst::isSigned(Pred); + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + + if (LatchBrExitIdx == 1) + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); + + assert(LatchBrExitIdx == 0 && + "LatchBrExitIdx should be either 0 or 1"); + + const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) : + APInt::getMinValue(BitWidth); + const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne); + + const SCEV *MinusOne = + SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType())); + + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) && + SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit); + +} + +/// Given a loop with an increasing induction variable, is it possible to +/// safely calculate the bounds of a new loop using the given Predicate. +static bool isSafeIncreasingBound(const SCEV *Start, + const SCEV *BoundSCEV, const SCEV *Step, + ICmpInst::Predicate Pred, + unsigned LatchBrExitIdx, + Loop *L, ScalarEvolution &SE) { + if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && + Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) + return false; + + if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) + return false; + + LLVM_DEBUG(dbgs() << "irce: isSafeIncreasingBound with:\n"); + LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); + LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); + LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); + LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred) + << "\n"); + LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); + + bool IsSigned = ICmpInst::isSigned(Pred); + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; + + if (LatchBrExitIdx == 1) + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); + + assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1"); + + const SCEV *StepMinusOne = + SE.getMinusSCEV(Step, SE.getOne(Step->getType())); + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) : + APInt::getMaxValue(BitWidth); + const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne); + + return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start, + SE.getAddExpr(BoundSCEV, Step)) && + SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); } -static bool SumCanReachMax(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, - bool Signed) { - // S1 < INT_MAX - S2 ===> S1 + S2 < INT_MAX. - assert(SE.isKnownNonNegative(S2) && - "We expected the 2nd arg to be non-negative!"); - const SCEV *Max = SE.getConstant( - Signed ? APInt::getSignedMaxValue( - cast<IntegerType>(S1->getType())->getBitWidth()) - : APInt::getMaxValue( - cast<IntegerType>(S1->getType())->getBitWidth())); - const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); - return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, - S1, CapForS1); +static bool CannotBeMinInLoop(const SCEV *BoundSCEV, Loop *L, + ScalarEvolution &SE, bool Signed) { + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Min = Signed ? APInt::getSignedMinValue(BitWidth) : + APInt::getMinValue(BitWidth); + auto Predicate = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; + return SE.isAvailableAtLoopEntry(BoundSCEV, L) && + SE.isLoopEntryGuardedByCond(L, Predicate, BoundSCEV, + SE.getConstant(Min)); } -static bool CanBeMin(ScalarEvolution &SE, const SCEV *S, bool Signed) { - APInt Min = Signed ? - APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()) : - APInt::getMinValue(cast<IntegerType>(S->getType())->getBitWidth()); - return SE.getSignedRange(S).contains(Min) && - SE.getUnsignedRange(S).contains(Min); +static bool isKnownNonNegativeInLoop(const SCEV *BoundSCEV, const Loop *L, + ScalarEvolution &SE) { + const SCEV *Zero = SE.getZero(BoundSCEV->getType()); + return SE.isAvailableAtLoopEntry(BoundSCEV, L) && + SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGE, BoundSCEV, Zero); } -static bool SumCanReachMin(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, - bool Signed) { - // S1 > INT_MIN - S2 ===> S1 + S2 > INT_MIN. - assert(SE.isKnownNonPositive(S2) && - "We expected the 2nd arg to be non-positive!"); - const SCEV *Max = SE.getConstant( - Signed ? APInt::getSignedMinValue( - cast<IntegerType>(S1->getType())->getBitWidth()) - : APInt::getMinValue( - cast<IntegerType>(S1->getType())->getBitWidth())); - const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); - return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, - S1, CapForS1); +static bool isKnownNegativeInLoop(const SCEV *BoundSCEV, const Loop *L, + ScalarEvolution &SE) { + const SCEV *Zero = SE.getZero(BoundSCEV->getType()); + return SE.isAvailableAtLoopEntry(BoundSCEV, L) && + SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLT, BoundSCEV, Zero); } Optional<LoopStructure> LoopStructure::parseLoopStructure(ScalarEvolution &SE, - BranchProbabilityInfo &BPI, - Loop &L, const char *&FailureReason) { + BranchProbabilityInfo *BPI, Loop &L, + const char *&FailureReason) { if (!L.isLoopSimplifyForm()) { FailureReason = "loop not in LoopSimplify form"; return None; @@ -750,7 +858,8 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; BranchProbability ExitProbability = - BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx); + BPI ? BPI->getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx) + : BranchProbability::getZero(); if (!SkipProfitabilityChecks && ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { @@ -816,43 +925,29 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; }; - // Here we check whether the suggested AddRec is an induction variable that - // can be handled (i.e. with known constant step), and if yes, calculate its - // step and identify whether it is increasing or decreasing. - auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing, - ConstantInt *&StepCI) { - if (!AR->isAffine()) - return false; - - // Currently we only work with induction variables that have been proved to - // not wrap. This restriction can potentially be lifted in the future. - - if (!HasNoSignedWrap(AR)) - return false; - - if (const SCEVConstant *StepExpr = - dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) { - StepCI = StepExpr->getValue(); - assert(!StepCI->isZero() && "Zero step?"); - IsIncreasing = !StepCI->isNegative(); - return true; - } - - return false; - }; - // `ICI` is interpreted as taking the backedge if the *next* value of the // induction variable satisfies some constraint. const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); - bool IsIncreasing = false; - bool IsSignedPredicate = true; - ConstantInt *StepCI; - if (!IsInductionVar(IndVarBase, IsIncreasing, StepCI)) { + if (!IndVarBase->isAffine()) { + FailureReason = "LHS in icmp not induction variable"; + return None; + } + const SCEV* StepRec = IndVarBase->getStepRecurrence(SE); + if (!isa<SCEVConstant>(StepRec)) { FailureReason = "LHS in icmp not induction variable"; return None; } + ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue(); + + if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) { + FailureReason = "LHS in icmp needs nsw for equality predicates"; + return None; + } + assert(!StepCI->isZero() && "Zero step?"); + bool IsIncreasing = !StepCI->isNegative(); + bool IsSignedPredicate = ICmpInst::isSigned(Pred); const SCEV *StartNext = IndVarBase->getStart(); const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); @@ -870,22 +965,29 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, // If both parts are known non-negative, it is profitable to use // unsigned comparison in increasing loop. This allows us to make the // comparison check against "RightSCEV + 1" more optimistic. - if (SE.isKnownNonNegative(IndVarStart) && - SE.isKnownNonNegative(RightSCEV)) + if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) && + isKnownNonNegativeInLoop(RightSCEV, &L, SE)) Pred = ICmpInst::ICMP_ULT; else Pred = ICmpInst::ICMP_SLT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && - !CanBeMin(SE, RightSCEV, /* IsSignedPredicate */ true)) { + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { // while (true) { while (true) { // if (++i == len) ---> if (++i > len - 1) // break; break; // ... ... // } } - // TODO: Insert ICMP_UGT if both are non-negative? - Pred = ICmpInst::ICMP_SGT; - RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); - DecreasedRightValueByOne = true; + if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && + CannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) { + Pred = ICmpInst::ICMP_UGT; + RightSCEV = SE.getMinusSCEV(RightSCEV, + SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } else if (CannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) { + Pred = ICmpInst::ICMP_SGT; + RightSCEV = SE.getMinusSCEV(RightSCEV, + SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } } } @@ -899,36 +1001,18 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return None; } - IsSignedPredicate = - Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; - + IsSignedPredicate = ICmpInst::isSigned(Pred); if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { FailureReason = "unsigned latch conditions are explicitly prohibited"; return None; } - // The predicate that we need to check that the induction variable lies - // within bounds. - ICmpInst::Predicate BoundPred = - IsSignedPredicate ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; - + if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred, + LatchBrExitIdx, &L, SE)) { + FailureReason = "Unsafe loop bounds"; + return None; + } if (LatchBrExitIdx == 0) { - const SCEV *StepMinusOne = SE.getMinusSCEV(Step, - SE.getOne(Step->getType())); - if (SumCanReachMax(SE, RightSCEV, StepMinusOne, IsSignedPredicate)) { - // TODO: this restriction is easily removable -- we just have to - // remember that the icmp was an slt and not an sle. - FailureReason = "limit may overflow when coercing le to lt"; - return None; - } - - if (!SE.isLoopEntryGuardedByCond( - &L, BoundPred, IndVarStart, - SE.getAddExpr(RightSCEV, Step))) { - FailureReason = "Induction variable start not bounded by upper limit"; - return None; - } - // We need to increase the right value unless we have already decreased // it virtually when we replaced EQ with SGT. if (!DecreasedRightValueByOne) { @@ -936,10 +1020,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, RightValue = B.CreateAdd(RightValue, One); } } else { - if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { - FailureReason = "Induction variable start not bounded by upper limit"; - return None; - } assert(!DecreasedRightValueByOne && "Right value can be decreased only for LatchBrExitIdx == 0!"); } @@ -955,17 +1035,22 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, // that both operands are non-negative, because it will only pessimize // our check against "RightSCEV - 1". Pred = ICmpInst::ICMP_SGT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && - !CanBeMax(SE, RightSCEV, /* IsSignedPredicate */ true)) { + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { // while (true) { while (true) { // if (--i == len) ---> if (--i < len + 1) // break; break; // ... ... // } } - // TODO: Insert ICMP_ULT if both are non-negative? - Pred = ICmpInst::ICMP_SLT; - RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); - IncreasedRightValueByOne = true; + if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && + CannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) { + Pred = ICmpInst::ICMP_ULT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } else if (CannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) { + Pred = ICmpInst::ICMP_SLT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } } } @@ -988,27 +1073,13 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return None; } - // The predicate that we need to check that the induction variable lies - // within bounds. - ICmpInst::Predicate BoundPred = - IsSignedPredicate ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred, + LatchBrExitIdx, &L, SE)) { + FailureReason = "Unsafe bounds"; + return None; + } if (LatchBrExitIdx == 0) { - const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); - if (SumCanReachMin(SE, RightSCEV, StepPlusOne, IsSignedPredicate)) { - // TODO: this restriction is easily removable -- we just have to - // remember that the icmp was an sgt and not an sge. - FailureReason = "limit may overflow when coercing ge to gt"; - return None; - } - - if (!SE.isLoopEntryGuardedByCond( - &L, BoundPred, IndVarStart, - SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())))) { - FailureReason = "Induction variable start not bounded by lower limit"; - return None; - } - // We need to decrease the right value unless we have already increased // it virtually when we replaced EQ with SLT. if (!IncreasedRightValueByOne) { @@ -1016,10 +1087,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, RightValue = B.CreateSub(RightValue, One); } } else { - if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { - FailureReason = "Induction variable start not bounded by lower limit"; - return None; - } assert(!IncreasedRightValueByOne && "Right value can be increased only for LatchBrExitIdx == 0!"); } @@ -1381,13 +1448,14 @@ void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { } Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, - ValueToValueMapTy &VM) { + ValueToValueMapTy &VM, + bool IsSubloop) { Loop &New = *LI.AllocateLoop(); if (Parent) Parent->addChildLoop(&New); else LI.addTopLevelLoop(&New); - LPM.addLoop(New); + LPMAddNewLoop(&New, IsSubloop); // Add all of the blocks in Original to the new loop. for (auto *BB : Original->blocks()) @@ -1396,7 +1464,7 @@ Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, // Add all of the subloops to the new loop. for (Loop *SubLoop : *Original) - createClonedLoopStructure(SubLoop, &New, VM); + createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true); return &New; } @@ -1414,7 +1482,7 @@ bool LoopConstrainer::run() { bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); if (!MaybeSR.hasValue()) { - DEBUG(dbgs() << "irce: could not compute subranges\n"); + LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n"); return false; } @@ -1446,19 +1514,22 @@ bool LoopConstrainer::run() { if (Increasing) ExitPreLoopAtSCEV = *SR.LowLimit; else { - if (CanBeMin(SE, *SR.HighLimit, IsSignedPredicate)) { - DEBUG(dbgs() << "irce: could not prove no-overflow when computing " - << "preloop exit limit. HighLimit = " << *(*SR.HighLimit) - << "\n"); + if (CannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE, + IsSignedPredicate)) + ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); + else { + LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "preloop exit limit. HighLimit = " + << *(*SR.HighLimit) << "\n"); return false; } - ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); } if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) { - DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" - << " preloop exit limit " << *ExitPreLoopAtSCEV - << " at block " << InsertPt->getParent()->getName() << "\n"); + LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " preloop exit limit " << *ExitPreLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() + << "\n"); return false; } @@ -1472,19 +1543,22 @@ bool LoopConstrainer::run() { if (Increasing) ExitMainLoopAtSCEV = *SR.HighLimit; else { - if (CanBeMin(SE, *SR.LowLimit, IsSignedPredicate)) { - DEBUG(dbgs() << "irce: could not prove no-overflow when computing " - << "mainloop exit limit. LowLimit = " << *(*SR.LowLimit) - << "\n"); + if (CannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE, + IsSignedPredicate)) + ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); + else { + LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "mainloop exit limit. LowLimit = " + << *(*SR.LowLimit) << "\n"); return false; } - ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); } if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) { - DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" - << " main loop exit limit " << *ExitMainLoopAtSCEV - << " at block " << InsertPt->getParent()->getName() << "\n"); + LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " main loop exit limit " << *ExitMainLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() + << "\n"); return false; } @@ -1546,13 +1620,15 @@ bool LoopConstrainer::run() { // LI when LoopSimplifyForm is generated. Loop *PreL = nullptr, *PostL = nullptr; if (!PreLoop.Blocks.empty()) { - PreL = createClonedLoopStructure( - &OriginalLoop, OriginalLoop.getParentLoop(), PreLoop.Map); + PreL = createClonedLoopStructure(&OriginalLoop, + OriginalLoop.getParentLoop(), PreLoop.Map, + /* IsSubLoop */ false); } if (!PostLoop.Blocks.empty()) { - PostL = createClonedLoopStructure( - &OriginalLoop, OriginalLoop.getParentLoop(), PostLoop.Map); + PostL = + createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(), + PostLoop.Map, /* IsSubLoop */ false); } // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. @@ -1618,32 +1694,34 @@ InductiveRangeCheck::computeSafeIterationSpace( unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); - // Substract Y from X so that it does not go through border of the IV + // Subtract Y from X so that it does not go through border of the IV // iteration space. Mathematically, it is equivalent to: // - // ClampedSubstract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX). [1] + // ClampedSubtract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX). [1] // - // In [1], 'X - Y' is a mathematical substraction (result is not bounded to + // In [1], 'X - Y' is a mathematical subtraction (result is not bounded to // any width of bit grid). But after we take min/max, the result is // guaranteed to be within [INT_MIN, INT_MAX]. // // In [1], INT_MAX and INT_MIN are respectively signed and unsigned max/min // values, depending on type of latch condition that defines IV iteration // space. - auto ClampedSubstract = [&](const SCEV *X, const SCEV *Y) { - assert(SE.isKnownNonNegative(X) && - "We can only substract from values in [0; SINT_MAX]!"); + auto ClampedSubtract = [&](const SCEV *X, const SCEV *Y) { + // FIXME: The current implementation assumes that X is in [0, SINT_MAX]. + // This is required to ensure that SINT_MAX - X does not overflow signed and + // that X - Y does not overflow unsigned if Y is negative. Can we lift this + // restriction and make it work for negative X either? if (IsLatchSigned) { // X is a number from signed range, Y is interpreted as signed. // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only // thing we should care about is that we didn't cross SINT_MAX. - // So, if Y is positive, we substract Y safely. + // So, if Y is positive, we subtract Y safely. // Rule 1: Y > 0 ---> Y. - // If 0 <= -Y <= (SINT_MAX - X), we substract Y safely. + // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely. // Rule 2: Y >=s (X - SINT_MAX) ---> Y. - // If 0 <= (SINT_MAX - X) < -Y, we can only substract (X - SINT_MAX). + // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX). // Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX). - // It gives us smax(Y, X - SINT_MAX) to substract in all cases. + // It gives us smax(Y, X - SINT_MAX) to subtract in all cases. const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax); return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax), SCEV::FlagNSW); @@ -1651,29 +1729,45 @@ InductiveRangeCheck::computeSafeIterationSpace( // X is a number from unsigned range, Y is interpreted as signed. // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only // thing we should care about is that we didn't cross zero. - // So, if Y is negative, we substract Y safely. + // So, if Y is negative, we subtract Y safely. // Rule 1: Y <s 0 ---> Y. - // If 0 <= Y <= X, we substract Y safely. + // If 0 <= Y <= X, we subtract Y safely. // Rule 2: Y <=s X ---> Y. - // If 0 <= X < Y, we should stop at 0 and can only substract X. + // If 0 <= X < Y, we should stop at 0 and can only subtract X. // Rule 3: Y >s X ---> X. - // It gives us smin(X, Y) to substract in all cases. + // It gives us smin(X, Y) to subtract in all cases. return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW); }; const SCEV *M = SE.getMinusSCEV(C, A); const SCEV *Zero = SE.getZero(M->getType()); - const SCEV *Begin = ClampedSubstract(Zero, M); - const SCEV *L = nullptr; - // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". - // We can potentially do much better here. - if (const SCEV *EndLimit = getEnd()) - L = EndLimit; - else { - assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); - L = SIntMax; - } - const SCEV *End = ClampedSubstract(L, M); + // This function returns SCEV equal to 1 if X is non-negative 0 otherwise. + auto SCEVCheckNonNegative = [&](const SCEV *X) { + const Loop *L = IndVar->getLoop(); + const SCEV *One = SE.getOne(X->getType()); + // Can we trivially prove that X is a non-negative or negative value? + if (isKnownNonNegativeInLoop(X, L, SE)) + return One; + else if (isKnownNegativeInLoop(X, L, SE)) + return Zero; + // If not, we will have to figure it out during the execution. + // Function smax(smin(X, 0), -1) + 1 equals to 1 if X >= 0 and 0 if X < 0. + const SCEV *NegOne = SE.getNegativeSCEV(One); + return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One); + }; + // FIXME: Current implementation of ClampedSubtract implicitly assumes that + // X is non-negative (in sense of a signed value). We need to re-implement + // this function in a way that it will correctly handle negative X as well. + // We use it twice: for X = 0 everything is fine, but for X = getEnd() we can + // end up with a negative X and produce wrong results. So currently we ensure + // that if getEnd() is negative then both ends of the safe range are zero. + // Note that this may pessimize elimination of unsigned range checks against + // negative values. + const SCEV *REnd = getEnd(); + const SCEV *EndIsNonNegative = SCEVCheckNonNegative(REnd); + + const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), EndIsNonNegative); + const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), EndIsNonNegative); return InductiveRangeCheck::Range(Begin, End); } @@ -1735,26 +1829,56 @@ IntersectUnsignedRange(ScalarEvolution &SE, return Ret; } -bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { +PreservedAnalyses IRCEPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + Function *F = L.getHeader()->getParent(); + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); + InductiveRangeCheckElimination IRCE(AR.SE, BPI, AR.DT, AR.LI); + auto LPMAddNewLoop = [&U](Loop *NL, bool IsSubloop) { + if (!IsSubloop) + U.addSiblingLoops(NL); + }; + bool Changed = IRCE.run(&L, LPMAddNewLoop); + if (!Changed) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { if (skipLoop(L)) return false; + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + BranchProbabilityInfo &BPI = + getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI); + auto LPMAddNewLoop = [&LPM](Loop *NL, bool /* IsSubLoop */) { + LPM.addLoop(*NL); + }; + return IRCE.run(L, LPMAddNewLoop); +} + +bool InductiveRangeCheckElimination::run( + Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop) { if (L->getBlocks().size() >= LoopSizeCutoff) { - DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";); + LLVM_DEBUG(dbgs() << "irce: giving up constraining loop, too large\n"); return false; } BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { - DEBUG(dbgs() << "irce: loop has no preheader, leaving\n"); + LLVM_DEBUG(dbgs() << "irce: loop has no preheader, leaving\n"); return false; } LLVMContext &Context = Preheader->getContext(); SmallVector<InductiveRangeCheck, 16> RangeChecks; - ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - BranchProbabilityInfo &BPI = - getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); for (auto BBI : L->getBlocks()) if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) @@ -1772,7 +1896,7 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { IRC.print(OS); }; - DEBUG(PrintRecognizedRangeChecks(dbgs())); + LLVM_DEBUG(PrintRecognizedRangeChecks(dbgs())); if (PrintRangeChecks) PrintRecognizedRangeChecks(errs()); @@ -1781,8 +1905,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { Optional<LoopStructure> MaybeLoopStructure = LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason); if (!MaybeLoopStructure.hasValue()) { - DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason - << "\n";); + LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: " + << FailureReason << "\n";); return false; } LoopStructure LS = MaybeLoopStructure.getValue(); @@ -1820,9 +1944,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { if (!SafeIterRange.hasValue()) return false; - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LPM, - LS, SE, DT, SafeIterRange.getValue()); + LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, + SafeIterRange.getValue()); bool Changed = LC.run(); if (Changed) { @@ -1833,7 +1956,7 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { L->print(dbgs()); }; - DEBUG(PrintConstrainedLoopInfo()); + LLVM_DEBUG(PrintConstrainedLoopInfo()); if (PrintChangedLoops) PrintConstrainedLoopInfo(); @@ -1852,5 +1975,5 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { } Pass *llvm::createInductiveRangeCheckEliminationPass() { - return new InductiveRangeCheckElimination; + return new IRCELegacyPass(); } |
