diff options
| author | 2017-01-24 08:32:59 +0000 | |
|---|---|---|
| committer | 2017-01-24 08:32:59 +0000 | |
| commit | 53d771aafdbe5b919f264f53cba3788e2c4cffd2 (patch) | |
| tree | 7eca39498be0ff1e3a6daf583cd9ca5886bb2636 /gnu/llvm/lib/Analysis/InstructionSimplify.cpp | |
| parent | In preparation of compiling our kernels with -ffreestanding, explicitly map (diff) | |
| download | wireguard-openbsd-53d771aafdbe5b919f264f53cba3788e2c4cffd2.tar.xz wireguard-openbsd-53d771aafdbe5b919f264f53cba3788e2c4cffd2.zip | |
Import LLVM 4.0.0 rc1 including clang and lld to help the current
development effort on OpenBSD/arm64.
Diffstat (limited to 'gnu/llvm/lib/Analysis/InstructionSimplify.cpp')
| -rw-r--r-- | gnu/llvm/lib/Analysis/InstructionSimplify.cpp | 1666 |
1 files changed, 988 insertions, 678 deletions
diff --git a/gnu/llvm/lib/Analysis/InstructionSimplify.cpp b/gnu/llvm/lib/Analysis/InstructionSimplify.cpp index aeaf9388579..796e6e44498 100644 --- a/gnu/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/gnu/llvm/lib/Analysis/InstructionSimplify.cpp @@ -67,9 +67,12 @@ static Value *SimplifyFPBinOp(unsigned, Value *, Value *, const FastMathFlags &, const Query &, unsigned); static Value *SimplifyCmpInst(unsigned, Value *, Value *, const Query &, unsigned); +static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, + const Query &Q, unsigned MaxRecurse); static Value *SimplifyOrInst(Value *, Value *, const Query &, unsigned); static Value *SimplifyXorInst(Value *, Value *, const Query &, unsigned); -static Value *SimplifyTruncInst(Value *, Type *, const Query &, unsigned); +static Value *SimplifyCastInst(unsigned, Value *, Type *, + const Query &, unsigned); /// For a boolean type, or a vector of boolean type, return false, or /// a vector with every element false, as appropriate for the type. @@ -679,9 +682,26 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, if (Op0 == Op1) return Constant::getNullValue(Op0->getType()); - // 0 - X -> 0 if the sub is NUW. - if (isNUW && match(Op0, m_Zero())) - return Op0; + // Is this a negation? + if (match(Op0, m_Zero())) { + // 0 - X -> 0 if the sub is NUW. + if (isNUW) + return Op0; + + unsigned BitWidth = Op1->getType()->getScalarSizeInBits(); + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + computeKnownBits(Op1, KnownZero, KnownOne, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (KnownZero == ~APInt::getSignBit(BitWidth)) { + // Op1 is either 0 or the minimum signed value. If the sub is NSW, then + // Op1 must be 0 because negating the minimum signed value is undefined. + if (isNSW) + return Op0; + + // 0 - X -> X if X is 0 or the minimum signed value. + return Op1; + } + } // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies. // For example, (X + Y) - Y -> X; (Y + X) - Y -> X @@ -747,7 +767,8 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, // See if "V === X - Y" simplifies. if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) // It does! Now see if "trunc V" simplifies. - if (Value *W = SimplifyTruncInst(V, Op0->getType(), Q, MaxRecurse-1)) + if (Value *W = SimplifyCastInst(Instruction::Trunc, V, Op0->getType(), + Q, MaxRecurse - 1)) // It does, return the simplified "trunc V". return W; @@ -1085,6 +1106,16 @@ static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, if (Value *V = SimplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse)) return V; + // udiv %V, C -> 0 if %V < C + if (MaxRecurse) { + if (Constant *C = dyn_cast_or_null<Constant>(SimplifyICmpInst( + ICmpInst::ICMP_ULT, Op0, Op1, Q, MaxRecurse - 1))) { + if (C->isAllOnesValue()) { + return Constant::getNullValue(Op0->getType()); + } + } + } + return nullptr; } @@ -1106,6 +1137,10 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (match(Op1, m_Undef())) return Op1; + // X / 1.0 -> X + if (match(Op1, m_FPOne())) + return Op0; + // 0 / X -> 0 // Requires that NaNs are off (X could be zero) and signed zeroes are // ignored (X could be positive or negative, so the output sign is unknown). @@ -1222,6 +1257,16 @@ static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q, if (Value *V = SimplifyRem(Instruction::URem, Op0, Op1, Q, MaxRecurse)) return V; + // urem %V, C -> %V if %V < C + if (MaxRecurse) { + if (Constant *C = dyn_cast_or_null<Constant>(SimplifyICmpInst( + ICmpInst::ICMP_ULT, Op0, Op1, Q, MaxRecurse - 1))) { + if (C->isAllOnesValue()) { + return Op0; + } + } + } + return nullptr; } @@ -1497,17 +1542,45 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp, return nullptr; } -static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { - Type *ITy = Op0->getType(); +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { ICmpInst::Predicate Pred0, Pred1; - ConstantInt *CI1, *CI2; - Value *V; + Value *A ,*B; + if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || + !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) + return nullptr; + // We have (icmp Pred0, A, B) & (icmp Pred1, A, B). + // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we + // can eliminate Op1 from this 'and'. + if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) + return Op0; + + // Check for any combination of predicates that are guaranteed to be disjoint. + if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || + (Pred0 == ICmpInst::ICMP_EQ && ICmpInst::isFalseWhenEqual(Pred1)) || + (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT) || + (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT)) + return getFalse(Op0->getType()); + + return nullptr; +} + +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) return X; + if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1)) + return X; + // Look for this pattern: (icmp V, C0) & (icmp V, C1)). + Type *ITy = Op0->getType(); + ICmpInst::Predicate Pred0, Pred1; const APInt *C0, *C1; + Value *V; if (match(Op0, m_ICmp(Pred0, m_Value(V), m_APInt(C0))) && match(Op1, m_ICmp(Pred1, m_Specific(V), m_APInt(C1)))) { // Make a constant range that's the intersection of the two icmp ranges. @@ -1518,21 +1591,22 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { return getFalse(ITy); } - if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), - m_ConstantInt(CI2)))) + // (icmp (add V, C0), C1) & (icmp V, C0) + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) return nullptr; - if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) return nullptr; auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + if (AddInst->getOperand(1) != Op1->getOperand(1)) + return nullptr; + bool isNSW = AddInst->hasNoSignedWrap(); bool isNUW = AddInst->hasNoUnsignedWrap(); - const APInt &CI1V = CI1->getValue(); - const APInt &CI2V = CI2->getValue(); - const APInt Delta = CI2V - CI1V; - if (CI1V.isStrictlyPositive()) { + const APInt Delta = *C1 - *C0; + if (C0->isStrictlyPositive()) { if (Delta == 2) { if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_SGT) return getFalse(ITy); @@ -1546,7 +1620,7 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { return getFalse(ITy); } } - if (CI1V.getBoolValue() && isNUW) { + if (C0->getBoolValue() && isNUW) { if (Delta == 2) if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT) return getFalse(ITy); @@ -1680,33 +1754,61 @@ Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const DataLayout &DL, RecursionLimit); } -/// Simplify (or (icmp ...) (icmp ...)) to true when we can tell that the union -/// contains all possible values. -static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { ICmpInst::Predicate Pred0, Pred1; - ConstantInt *CI1, *CI2; - Value *V; + Value *A ,*B; + if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || + !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) + return nullptr; + + // We have (icmp Pred0, A, B) | (icmp Pred1, A, B). + // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we + // can eliminate Op0 from this 'or'. + if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) + return Op1; + + // Check for any combination of predicates that cover the entire range of + // possibilities. + if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || + (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) || + (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) || + (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE)) + return getTrue(Op0->getType()); + + return nullptr; +} +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false)) return X; - if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), - m_ConstantInt(CI2)))) - return nullptr; + if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1)) + return X; - if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) + // (icmp (add V, C0), C1) | (icmp V, C0) + ICmpInst::Predicate Pred0, Pred1; + const APInt *C0, *C1; + Value *V; + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) return nullptr; - Type *ITy = Op0->getType(); + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) + return nullptr; auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + if (AddInst->getOperand(1) != Op1->getOperand(1)) + return nullptr; + + Type *ITy = Op0->getType(); bool isNSW = AddInst->hasNoSignedWrap(); bool isNUW = AddInst->hasNoUnsignedWrap(); - const APInt &CI1V = CI1->getValue(); - const APInt &CI2V = CI2->getValue(); - const APInt Delta = CI2V - CI1V; - if (CI1V.isStrictlyPositive()) { + const APInt Delta = *C1 - *C0; + if (C0->isStrictlyPositive()) { if (Delta == 2) { if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) return getTrue(ITy); @@ -1720,7 +1822,7 @@ static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { return getTrue(ITy); } } - if (CI1V.getBoolValue() && isNUW) { + if (C0->getBoolValue() && isNUW) { if (Delta == 2) if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) return getTrue(ITy); @@ -2102,8 +2204,8 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, GetUnderlyingObjects(RHS, RHSUObjs, DL); // Is the set of underlying objects all noalias calls? - auto IsNAC = [](SmallVectorImpl<Value *> &Objects) { - return std::all_of(Objects.begin(), Objects.end(), isNoAliasCall); + auto IsNAC = [](ArrayRef<Value *> Objects) { + return all_of(Objects, isNoAliasCall); }; // Is the set of underlying objects all things which must be disjoint from @@ -2112,8 +2214,8 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, // live with the compared-to allocation). For globals, we exclude symbols // that might be resolve lazily to symbols in another dynamically-loaded // library (and, thus, could be malloc'ed by the implementation). - auto IsAllocDisjoint = [](SmallVectorImpl<Value *> &Objects) { - return std::all_of(Objects.begin(), Objects.end(), [](Value *V) { + auto IsAllocDisjoint = [](ArrayRef<Value *> Objects) { + return all_of(Objects, [](Value *V) { if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) return AI->getParent() && AI->getFunction() && AI->isStaticAlloca(); if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) @@ -2150,470 +2252,275 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, return nullptr; } -/// Given operands for an ICmpInst, see if we can fold the result. -/// If not, this returns null. -static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { - CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; - assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); - - if (Constant *CLHS = dyn_cast<Constant>(LHS)) { - if (Constant *CRHS = dyn_cast<Constant>(RHS)) - return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI); - - // If we have a constant, make sure it is on the RHS. - std::swap(LHS, RHS); - Pred = CmpInst::getSwappedPredicate(Pred); - } - +/// Fold an icmp when its operands have i1 scalar type. +static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const Query &Q) { Type *ITy = GetCompareTy(LHS); // The return type. Type *OpTy = LHS->getType(); // The operand type. + if (!OpTy->getScalarType()->isIntegerTy(1)) + return nullptr; - // icmp X, X -> true/false - // X icmp undef -> true/false. For example, icmp ugt %X, undef -> false - // because X could be 0. - if (LHS == RHS || isa<UndefValue>(RHS)) - return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); - - // Special case logic when the operands have i1 type. - if (OpTy->getScalarType()->isIntegerTy(1)) { - switch (Pred) { - default: break; - case ICmpInst::ICMP_EQ: - // X == 1 -> X - if (match(RHS, m_One())) - return LHS; - break; - case ICmpInst::ICMP_NE: - // X != 0 -> X - if (match(RHS, m_Zero())) - return LHS; - break; - case ICmpInst::ICMP_UGT: - // X >u 0 -> X - if (match(RHS, m_Zero())) - return LHS; - break; - case ICmpInst::ICMP_UGE: { - // X >=u 1 -> X - if (match(RHS, m_One())) - return LHS; - if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) - return getTrue(ITy); - break; - } - case ICmpInst::ICMP_SGE: { - /// For signed comparison, the values for an i1 are 0 and -1 - /// respectively. This maps into a truth table of: - /// LHS | RHS | LHS >=s RHS | LHS implies RHS - /// 0 | 0 | 1 (0 >= 0) | 1 - /// 0 | 1 | 1 (0 >= -1) | 1 - /// 1 | 0 | 0 (-1 >= 0) | 0 - /// 1 | 1 | 1 (-1 >= -1) | 1 - if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) - return getTrue(ITy); - break; - } - case ICmpInst::ICMP_SLT: - // X <s 0 -> X - if (match(RHS, m_Zero())) - return LHS; - break; - case ICmpInst::ICMP_SLE: - // X <=s -1 -> X - if (match(RHS, m_One())) - return LHS; - break; - case ICmpInst::ICMP_ULE: { - if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) - return getTrue(ITy); - break; - } - } - } - - // If we are comparing with zero then try hard since this is a common case. - if (match(RHS, m_Zero())) { - bool LHSKnownNonNegative, LHSKnownNegative; - switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); - case ICmpInst::ICMP_ULT: - return getFalse(ITy); - case ICmpInst::ICMP_UGE: + switch (Pred) { + default: + break; + case ICmpInst::ICMP_EQ: + // X == 1 -> X + if (match(RHS, m_One())) + return LHS; + break; + case ICmpInst::ICMP_NE: + // X != 0 -> X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_UGT: + // X >u 0 -> X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_UGE: + // X >=u 1 -> X + if (match(RHS, m_One())) + return LHS; + if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) return getTrue(ITy); - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_ULE: - if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return getFalse(ITy); - break; - case ICmpInst::ICMP_NE: - case ICmpInst::ICMP_UGT: - if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return getTrue(ITy); - break; - case ICmpInst::ICMP_SLT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) - return getTrue(ITy); - if (LHSKnownNonNegative) - return getFalse(ITy); - break; - case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) - return getTrue(ITy); - if (LHSKnownNonNegative && - isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return getFalse(ITy); - break; - case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) - return getFalse(ITy); - if (LHSKnownNonNegative) - return getTrue(ITy); - break; - case ICmpInst::ICMP_SGT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) - return getFalse(ITy); - if (LHSKnownNonNegative && - isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return getTrue(ITy); - break; - } + break; + case ICmpInst::ICMP_SGE: + /// For signed comparison, the values for an i1 are 0 and -1 + /// respectively. This maps into a truth table of: + /// LHS | RHS | LHS >=s RHS | LHS implies RHS + /// 0 | 0 | 1 (0 >= 0) | 1 + /// 0 | 1 | 1 (0 >= -1) | 1 + /// 1 | 0 | 0 (-1 >= 0) | 0 + /// 1 | 1 | 1 (-1 >= -1) | 1 + if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) + return getTrue(ITy); + break; + case ICmpInst::ICMP_SLT: + // X <s 0 -> X + if (match(RHS, m_Zero())) + return LHS; + break; + case ICmpInst::ICMP_SLE: + // X <=s -1 -> X + if (match(RHS, m_One())) + return LHS; + break; + case ICmpInst::ICMP_ULE: + if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) + return getTrue(ITy); + break; } - // See if we are doing a comparison with a constant integer. - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Rule out tautological comparisons (eg., ult 0 or uge 0). - ConstantRange RHS_CR = ICmpInst::makeConstantRange(Pred, CI->getValue()); - if (RHS_CR.isEmptySet()) - return ConstantInt::getFalse(CI->getContext()); - if (RHS_CR.isFullSet()) - return ConstantInt::getTrue(CI->getContext()); - - // Many binary operators with constant RHS have easy to compute constant - // range. Use them to check whether the comparison is a tautology. - unsigned Width = CI->getBitWidth(); - APInt Lower = APInt(Width, 0); - APInt Upper = APInt(Width, 0); - ConstantInt *CI2; - if (match(LHS, m_URem(m_Value(), m_ConstantInt(CI2)))) { - // 'urem x, CI2' produces [0, CI2). - Upper = CI2->getValue(); - } else if (match(LHS, m_SRem(m_Value(), m_ConstantInt(CI2)))) { - // 'srem x, CI2' produces (-|CI2|, |CI2|). - Upper = CI2->getValue().abs(); - Lower = (-Upper) + 1; - } else if (match(LHS, m_UDiv(m_ConstantInt(CI2), m_Value()))) { - // 'udiv CI2, x' produces [0, CI2]. - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_UDiv(m_Value(), m_ConstantInt(CI2)))) { - // 'udiv x, CI2' produces [0, UINT_MAX / CI2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (!CI2->isZero()) - Upper = NegOne.udiv(CI2->getValue()) + 1; - } else if (match(LHS, m_SDiv(m_ConstantInt(CI2), m_Value()))) { - if (CI2->isMinSignedValue()) { - // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. - Lower = CI2->getValue(); - Upper = Lower.lshr(1) + 1; - } else { - // 'sdiv CI2, x' produces [-|CI2|, |CI2|]. - Upper = CI2->getValue().abs() + 1; - Lower = (-Upper) + 1; - } - } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) { - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - const APInt &Val = CI2->getValue(); - if (Val.isAllOnesValue()) { - // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] - // where CI2 != -1 and CI2 != 0 and CI2 != 1 - Lower = IntMin + 1; - Upper = IntMax + 1; - } else if (Val.countLeadingZeros() < Width - 1) { - // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2] - // where CI2 != -1 and CI2 != 0 and CI2 != 1 - Lower = IntMin.sdiv(Val); - Upper = IntMax.sdiv(Val); - if (Lower.sgt(Upper)) - std::swap(Lower, Upper); - Upper = Upper + 1; - assert(Upper != Lower && "Upper part of range has wrapped!"); - } - } else if (match(LHS, m_NUWShl(m_ConstantInt(CI2), m_Value()))) { - // 'shl nuw CI2, x' produces [CI2, CI2 << CLZ(CI2)] - Lower = CI2->getValue(); - Upper = Lower.shl(Lower.countLeadingZeros()) + 1; - } else if (match(LHS, m_NSWShl(m_ConstantInt(CI2), m_Value()))) { - if (CI2->isNegative()) { - // 'shl nsw CI2, x' produces [CI2 << CLO(CI2)-1, CI2] - unsigned ShiftAmount = CI2->getValue().countLeadingOnes() - 1; - Lower = CI2->getValue().shl(ShiftAmount); - Upper = CI2->getValue() + 1; - } else { - // 'shl nsw CI2, x' produces [CI2, CI2 << CLZ(CI2)-1] - unsigned ShiftAmount = CI2->getValue().countLeadingZeros() - 1; - Lower = CI2->getValue(); - Upper = CI2->getValue().shl(ShiftAmount) + 1; - } - } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) { - // 'lshr x, CI2' produces [0, UINT_MAX >> CI2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (CI2->getValue().ult(Width)) - Upper = NegOne.lshr(CI2->getValue()) + 1; - } else if (match(LHS, m_LShr(m_ConstantInt(CI2), m_Value()))) { - // 'lshr CI2, x' produces [CI2 >> (Width-1), CI2]. - unsigned ShiftAmount = Width - 1; - if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = CI2->getValue().countTrailingZeros(); - Lower = CI2->getValue().lshr(ShiftAmount); - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_AShr(m_Value(), m_ConstantInt(CI2)))) { - // 'ashr x, CI2' produces [INT_MIN >> CI2, INT_MAX >> CI2]. - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (CI2->getValue().ult(Width)) { - Lower = IntMin.ashr(CI2->getValue()); - Upper = IntMax.ashr(CI2->getValue()) + 1; - } - } else if (match(LHS, m_AShr(m_ConstantInt(CI2), m_Value()))) { - unsigned ShiftAmount = Width - 1; - if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = CI2->getValue().countTrailingZeros(); - if (CI2->isNegative()) { - // 'ashr CI2, x' produces [CI2, CI2 >> (Width-1)] - Lower = CI2->getValue(); - Upper = CI2->getValue().ashr(ShiftAmount) + 1; - } else { - // 'ashr CI2, x' produces [CI2 >> (Width-1), CI2] - Lower = CI2->getValue().ashr(ShiftAmount); - Upper = CI2->getValue() + 1; - } - } else if (match(LHS, m_Or(m_Value(), m_ConstantInt(CI2)))) { - // 'or x, CI2' produces [CI2, UINT_MAX]. - Lower = CI2->getValue(); - } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) { - // 'and x, CI2' produces [0, CI2]. - Upper = CI2->getValue() + 1; - } else if (match(LHS, m_NUWAdd(m_Value(), m_ConstantInt(CI2)))) { - // 'add nuw x, CI2' produces [CI2, UINT_MAX]. - Lower = CI2->getValue(); - } - - ConstantRange LHS_CR = Lower != Upper ? ConstantRange(Lower, Upper) - : ConstantRange(Width, true); + return nullptr; +} - if (auto *I = dyn_cast<Instruction>(LHS)) - if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) - LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges)); +/// Try hard to fold icmp with zero RHS because this is a common case. +static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const Query &Q) { + if (!match(RHS, m_Zero())) + return nullptr; - if (!LHS_CR.isFullSet()) { - if (RHS_CR.contains(LHS_CR)) - return ConstantInt::getTrue(RHS->getContext()); - if (RHS_CR.inverse().contains(LHS_CR)) - return ConstantInt::getFalse(RHS->getContext()); - } + Type *ITy = GetCompareTy(LHS); // The return type. + bool LHSKnownNonNegative, LHSKnownNegative; + switch (Pred) { + default: + llvm_unreachable("Unknown ICmp predicate!"); + case ICmpInst::ICMP_ULT: + return getFalse(ITy); + case ICmpInst::ICMP_UGE: + return getTrue(ITy); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_ULE: + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getFalse(ITy); + break; + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_UGT: + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getTrue(ITy); + break; + case ICmpInst::ICMP_SLT: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNegative) + return getTrue(ITy); + if (LHSKnownNonNegative) + return getFalse(ITy); + break; + case ICmpInst::ICMP_SLE: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNegative) + return getTrue(ITy); + if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getFalse(ITy); + break; + case ICmpInst::ICMP_SGE: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNegative) + return getFalse(ITy); + if (LHSKnownNonNegative) + return getTrue(ITy); + break; + case ICmpInst::ICMP_SGT: + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNegative) + return getFalse(ITy); + if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return getTrue(ITy); + break; } - // If both operands have range metadata, use the metadata - // to simplify the comparison. - if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { - auto RHS_Instr = dyn_cast<Instruction>(RHS); - auto LHS_Instr = dyn_cast<Instruction>(LHS); - - if (RHS_Instr->getMetadata(LLVMContext::MD_range) && - LHS_Instr->getMetadata(LLVMContext::MD_range)) { - auto RHS_CR = getConstantRangeFromMetadata( - *RHS_Instr->getMetadata(LLVMContext::MD_range)); - auto LHS_CR = getConstantRangeFromMetadata( - *LHS_Instr->getMetadata(LLVMContext::MD_range)); + return nullptr; +} - auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR); - if (Satisfied_CR.contains(LHS_CR)) - return ConstantInt::getTrue(RHS->getContext()); +static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, + Value *RHS) { + const APInt *C; + if (!match(RHS, m_APInt(C))) + return nullptr; - auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( - CmpInst::getInversePredicate(Pred), RHS_CR); - if (InversedSatisfied_CR.contains(LHS_CR)) - return ConstantInt::getFalse(RHS->getContext()); + // Rule out tautological comparisons (eg., ult 0 or uge 0). + ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C); + if (RHS_CR.isEmptySet()) + return ConstantInt::getFalse(GetCompareTy(RHS)); + if (RHS_CR.isFullSet()) + return ConstantInt::getTrue(GetCompareTy(RHS)); + + // Many binary operators with constant RHS have easy to compute constant + // range. Use them to check whether the comparison is a tautology. + unsigned Width = C->getBitWidth(); + APInt Lower = APInt(Width, 0); + APInt Upper = APInt(Width, 0); + const APInt *C2; + if (match(LHS, m_URem(m_Value(), m_APInt(C2)))) { + // 'urem x, C2' produces [0, C2). + Upper = *C2; + } else if (match(LHS, m_SRem(m_Value(), m_APInt(C2)))) { + // 'srem x, C2' produces (-|C2|, |C2|). + Upper = C2->abs(); + Lower = (-Upper) + 1; + } else if (match(LHS, m_UDiv(m_APInt(C2), m_Value()))) { + // 'udiv C2, x' produces [0, C2]. + Upper = *C2 + 1; + } else if (match(LHS, m_UDiv(m_Value(), m_APInt(C2)))) { + // 'udiv x, C2' produces [0, UINT_MAX / C2]. + APInt NegOne = APInt::getAllOnesValue(Width); + if (*C2 != 0) + Upper = NegOne.udiv(*C2) + 1; + } else if (match(LHS, m_SDiv(m_APInt(C2), m_Value()))) { + if (C2->isMinSignedValue()) { + // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. + Lower = *C2; + Upper = Lower.lshr(1) + 1; + } else { + // 'sdiv C2, x' produces [-|C2|, |C2|]. + Upper = C2->abs() + 1; + Lower = (-Upper) + 1; } - } - - // Compare of cast, for example (zext X) != 0 -> X != 0 - if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) { - Instruction *LI = cast<CastInst>(LHS); - Value *SrcOp = LI->getOperand(0); - Type *SrcTy = SrcOp->getType(); - Type *DstTy = LI->getType(); - - // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input - // if the integer type is the same size as the pointer type. - if (MaxRecurse && isa<PtrToIntInst>(LI) && - Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) { - if (Constant *RHSC = dyn_cast<Constant>(RHS)) { - // Transfer the cast to the constant. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, - ConstantExpr::getIntToPtr(RHSC, SrcTy), - Q, MaxRecurse-1)) - return V; - } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) { - if (RI->getOperand(0)->getType() == SrcTy) - // Compare without the cast. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), - Q, MaxRecurse-1)) - return V; - } + } else if (match(LHS, m_SDiv(m_Value(), m_APInt(C2)))) { + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (C2->isAllOnesValue()) { + // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] + // where C2 != -1 and C2 != 0 and C2 != 1 + Lower = IntMin + 1; + Upper = IntMax + 1; + } else if (C2->countLeadingZeros() < Width - 1) { + // 'sdiv x, C2' produces [INT_MIN / C2, INT_MAX / C2] + // where C2 != -1 and C2 != 0 and C2 != 1 + Lower = IntMin.sdiv(*C2); + Upper = IntMax.sdiv(*C2); + if (Lower.sgt(Upper)) + std::swap(Lower, Upper); + Upper = Upper + 1; + assert(Upper != Lower && "Upper part of range has wrapped!"); } - - if (isa<ZExtInst>(LHS)) { - // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the - // same type. - if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) { - if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) - // Compare X and Y. Note that signed predicates become unsigned. - if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), - SrcOp, RI->getOperand(0), Q, - MaxRecurse-1)) - return V; - } - // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended - // too. If not, then try to deduce the result of the comparison. - else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Compute the constant that would happen if we truncated to SrcTy then - // reextended to DstTy. - Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); - Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy); - - // If the re-extended constant didn't change then this is effectively - // also a case of comparing two zero-extended values. - if (RExt == CI && MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), - SrcOp, Trunc, Q, MaxRecurse-1)) - return V; - - // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit - // there. Use this to work out the result of the comparison. - if (RExt != CI) { - switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); - // LHS <u RHS. - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - return ConstantInt::getFalse(CI->getContext()); - - case ICmpInst::ICMP_NE: - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - return ConstantInt::getTrue(CI->getContext()); - - // LHS is non-negative. If RHS is negative then LHS >s LHS. If RHS - // is non-negative then LHS <s RHS. - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() ? - ConstantInt::getTrue(CI->getContext()) : - ConstantInt::getFalse(CI->getContext()); - - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() ? - ConstantInt::getFalse(CI->getContext()) : - ConstantInt::getTrue(CI->getContext()); - } - } - } + } else if (match(LHS, m_NUWShl(m_APInt(C2), m_Value()))) { + // 'shl nuw C2, x' produces [C2, C2 << CLZ(C2)] + Lower = *C2; + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (match(LHS, m_NSWShl(m_APInt(C2), m_Value()))) { + if (C2->isNegative()) { + // 'shl nsw C2, x' produces [C2 << CLO(C2)-1, C2] + unsigned ShiftAmount = C2->countLeadingOnes() - 1; + Lower = C2->shl(ShiftAmount); + Upper = *C2 + 1; + } else { + // 'shl nsw C2, x' produces [C2, C2 << CLZ(C2)-1] + unsigned ShiftAmount = C2->countLeadingZeros() - 1; + Lower = *C2; + Upper = C2->shl(ShiftAmount) + 1; } + } else if (match(LHS, m_LShr(m_Value(), m_APInt(C2)))) { + // 'lshr x, C2' produces [0, UINT_MAX >> C2]. + APInt NegOne = APInt::getAllOnesValue(Width); + if (C2->ult(Width)) + Upper = NegOne.lshr(*C2) + 1; + } else if (match(LHS, m_LShr(m_APInt(C2), m_Value()))) { + // 'lshr C2, x' produces [C2 >> (Width-1), C2]. + unsigned ShiftAmount = Width - 1; + if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) + ShiftAmount = C2->countTrailingZeros(); + Lower = C2->lshr(ShiftAmount); + Upper = *C2 + 1; + } else if (match(LHS, m_AShr(m_Value(), m_APInt(C2)))) { + // 'ashr x, C2' produces [INT_MIN >> C2, INT_MAX >> C2]. + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (C2->ult(Width)) { + Lower = IntMin.ashr(*C2); + Upper = IntMax.ashr(*C2) + 1; + } + } else if (match(LHS, m_AShr(m_APInt(C2), m_Value()))) { + unsigned ShiftAmount = Width - 1; + if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) + ShiftAmount = C2->countTrailingZeros(); + if (C2->isNegative()) { + // 'ashr C2, x' produces [C2, C2 >> (Width-1)] + Lower = *C2; + Upper = C2->ashr(ShiftAmount) + 1; + } else { + // 'ashr C2, x' produces [C2 >> (Width-1), C2] + Lower = C2->ashr(ShiftAmount); + Upper = *C2 + 1; + } + } else if (match(LHS, m_Or(m_Value(), m_APInt(C2)))) { + // 'or x, C2' produces [C2, UINT_MAX]. + Lower = *C2; + } else if (match(LHS, m_And(m_Value(), m_APInt(C2)))) { + // 'and x, C2' produces [0, C2]. + Upper = *C2 + 1; + } else if (match(LHS, m_NUWAdd(m_Value(), m_APInt(C2)))) { + // 'add nuw x, C2' produces [C2, UINT_MAX]. + Lower = *C2; + } - if (isa<SExtInst>(LHS)) { - // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the - // same type. - if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) { - if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) - // Compare X and Y. Note that the predicate does not change. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), - Q, MaxRecurse-1)) - return V; - } - // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended - // too. If not, then try to deduce the result of the comparison. - else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Compute the constant that would happen if we truncated to SrcTy then - // reextended to DstTy. - Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); - Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy); - - // If the re-extended constant didn't change then this is effectively - // also a case of comparing two sign-extended values. - if (RExt == CI && MaxRecurse) - if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1)) - return V; - - // Otherwise the upper bits of LHS are all equal, while RHS has varying - // bits there. Use this to work out the result of the comparison. - if (RExt != CI) { - switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); - case ICmpInst::ICMP_EQ: - return ConstantInt::getFalse(CI->getContext()); - case ICmpInst::ICMP_NE: - return ConstantInt::getTrue(CI->getContext()); + ConstantRange LHS_CR = + Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true); - // If RHS is non-negative then LHS <s RHS. If RHS is negative then - // LHS >s RHS. - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() ? - ConstantInt::getTrue(CI->getContext()) : - ConstantInt::getFalse(CI->getContext()); - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() ? - ConstantInt::getFalse(CI->getContext()) : - ConstantInt::getTrue(CI->getContext()); + if (auto *I = dyn_cast<Instruction>(LHS)) + if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) + LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges)); - // If LHS is non-negative then LHS <u RHS. If LHS is negative then - // LHS >u RHS. - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - // Comparison is true iff the LHS <s 0. - if (MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp, - Constant::getNullValue(SrcTy), - Q, MaxRecurse-1)) - return V; - break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - // Comparison is true iff the LHS >=s 0. - if (MaxRecurse) - if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, - Constant::getNullValue(SrcTy), - Q, MaxRecurse-1)) - return V; - break; - } - } - } - } + if (!LHS_CR.isFullSet()) { + if (RHS_CR.contains(LHS_CR)) + return ConstantInt::getTrue(GetCompareTy(RHS)); + if (RHS_CR.inverse().contains(LHS_CR)) + return ConstantInt::getFalse(GetCompareTy(RHS)); } - // icmp eq|ne X, Y -> false|true if X != Y - if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && - isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) { - LLVMContext &Ctx = LHS->getType()->getContext(); - return Pred == ICmpInst::ICMP_NE ? - ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx); - } + return nullptr; +} + +static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const Query &Q, + unsigned MaxRecurse) { + Type *ITy = GetCompareTy(LHS); // The return type. - // Special logic for binary operators. BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS); BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS); if (MaxRecurse && (LBO || RBO)) { @@ -2622,35 +2529,39 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // LHS = A + B (or A and B are null); RHS = C + D (or C and D are null). bool NoLHSWrapProblem = false, NoRHSWrapProblem = false; if (LBO && LBO->getOpcode() == Instruction::Add) { - A = LBO->getOperand(0); B = LBO->getOperand(1); - NoLHSWrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap()); + A = LBO->getOperand(0); + B = LBO->getOperand(1); + NoLHSWrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap()); } if (RBO && RBO->getOpcode() == Instruction::Add) { - C = RBO->getOperand(0); D = RBO->getOperand(1); - NoRHSWrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap()); + C = RBO->getOperand(0); + D = RBO->getOperand(1); + NoRHSWrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap()); } // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. if ((A == RHS || B == RHS) && NoLHSWrapProblem) if (Value *V = SimplifyICmpInst(Pred, A == RHS ? B : A, - Constant::getNullValue(RHS->getType()), - Q, MaxRecurse-1)) + Constant::getNullValue(RHS->getType()), Q, + MaxRecurse - 1)) return V; // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. if ((C == LHS || D == LHS) && NoRHSWrapProblem) - if (Value *V = SimplifyICmpInst(Pred, - Constant::getNullValue(LHS->getType()), - C == LHS ? D : C, Q, MaxRecurse-1)) + if (Value *V = + SimplifyICmpInst(Pred, Constant::getNullValue(LHS->getType()), + C == LHS ? D : C, Q, MaxRecurse - 1)) return V; // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow. - if (A && C && (A == C || A == D || B == C || B == D) && - NoLHSWrapProblem && NoRHSWrapProblem) { + if (A && C && (A == C || A == D || B == C || B == D) && NoLHSWrapProblem && + NoRHSWrapProblem) { // Determine Y and Z in the form icmp (X+Y), (X+Z). Value *Y, *Z; if (A == C) { @@ -2671,7 +2582,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Y = A; Z = C; } - if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse - 1)) return V; } } @@ -2771,7 +2682,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.CxtI, Q.DT); if (!KnownNonNegative) break; - // fall-through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: @@ -2782,7 +2693,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.CxtI, Q.DT); if (!KnownNonNegative) break; - // fall-through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2802,7 +2713,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.CxtI, Q.DT); if (!KnownNonNegative) break; - // fall-through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: @@ -2813,7 +2724,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.CxtI, Q.DT); if (!KnownNonNegative) break; - // fall-through + LLVM_FALLTHROUGH; case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2832,6 +2743,17 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); } + // x >=u x >> y + // x >=u x udiv y. + if (RBO && (match(RBO, m_LShr(m_Specific(LHS), m_Value())) || + match(RBO, m_UDiv(m_Specific(LHS), m_Value())))) { + // icmp pred X, (X op Y) + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + } + // handle: // CI2 << X == CI // CI2 << X != CI @@ -2870,18 +2792,19 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && LBO->getOperand(1) == RBO->getOperand(1)) { switch (LBO->getOpcode()) { - default: break; + default: + break; case Instruction::UDiv: case Instruction::LShr: if (ICmpInst::isSigned(Pred)) break; - // fall-through + LLVM_FALLTHROUGH; case Instruction::SDiv: case Instruction::AShr: if (!LBO->isExact() || !RBO->isExact()) break; if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), - RBO->getOperand(0), Q, MaxRecurse-1)) + RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; case Instruction::Shl: { @@ -2892,40 +2815,51 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (!NSW && ICmpInst::isSigned(Pred)) break; if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), - RBO->getOperand(0), Q, MaxRecurse-1)) + RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; } } } + return nullptr; +} - // Simplify comparisons involving max/min. +/// Simplify integer comparisons where at least one operand of the compare +/// matches an integer min/max idiom. +static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, + Value *RHS, const Query &Q, + unsigned MaxRecurse) { + Type *ITy = GetCompareTy(LHS); // The return type. Value *A, *B; CmpInst::Predicate P = CmpInst::BAD_ICMP_PREDICATE; CmpInst::Predicate EqP; // Chosen so that "A == max/min(A,B)" iff "A EqP B". // Signed variants on "max(a,b)>=a -> true". if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { - if (A != RHS) std::swap(A, B); // smax(A, B) pred A. + if (A != RHS) + std::swap(A, B); // smax(A, B) pred A. EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". // We analyze this as smax(A, B) pred A. P = Pred; } else if (match(RHS, m_SMax(m_Value(A), m_Value(B))) && (A == LHS || B == LHS)) { - if (A != LHS) std::swap(A, B); // A pred smax(A, B). + if (A != LHS) + std::swap(A, B); // A pred smax(A, B). EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". // We analyze this as smax(A, B) swapped-pred A. P = CmpInst::getSwappedPredicate(Pred); } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { - if (A != RHS) std::swap(A, B); // smin(A, B) pred A. + if (A != RHS) + std::swap(A, B); // smin(A, B) pred A. EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". // We analyze this as smax(-A, -B) swapped-pred -A. // Note that we do not need to actually form -A or -B thanks to EqP. P = CmpInst::getSwappedPredicate(Pred); } else if (match(RHS, m_SMin(m_Value(A), m_Value(B))) && (A == LHS || B == LHS)) { - if (A != LHS) std::swap(A, B); // A pred smin(A, B). + if (A != LHS) + std::swap(A, B); // A pred smin(A, B). EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". // We analyze this as smax(-A, -B) pred -A. // Note that we do not need to actually form -A or -B thanks to EqP. @@ -2946,7 +2880,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return V; // Otherwise, see if "A EqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) return V; break; case CmpInst::ICMP_NE: @@ -2960,7 +2894,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return V; // Otherwise, see if "A InvEqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) return V; break; } @@ -2976,26 +2910,30 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // Unsigned variants on "max(a,b)>=a -> true". P = CmpInst::BAD_ICMP_PREDICATE; if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { - if (A != RHS) std::swap(A, B); // umax(A, B) pred A. + if (A != RHS) + std::swap(A, B); // umax(A, B) pred A. EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". // We analyze this as umax(A, B) pred A. P = Pred; } else if (match(RHS, m_UMax(m_Value(A), m_Value(B))) && (A == LHS || B == LHS)) { - if (A != LHS) std::swap(A, B); // A pred umax(A, B). + if (A != LHS) + std::swap(A, B); // A pred umax(A, B). EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". // We analyze this as umax(A, B) swapped-pred A. P = CmpInst::getSwappedPredicate(Pred); } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { - if (A != RHS) std::swap(A, B); // umin(A, B) pred A. + if (A != RHS) + std::swap(A, B); // umin(A, B) pred A. EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". // We analyze this as umax(-A, -B) swapped-pred -A. // Note that we do not need to actually form -A or -B thanks to EqP. P = CmpInst::getSwappedPredicate(Pred); } else if (match(RHS, m_UMin(m_Value(A), m_Value(B))) && (A == LHS || B == LHS)) { - if (A != LHS) std::swap(A, B); // A pred umin(A, B). + if (A != LHS) + std::swap(A, B); // A pred umin(A, B). EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". // We analyze this as umax(-A, -B) pred -A. // Note that we do not need to actually form -A or -B thanks to EqP. @@ -3016,7 +2954,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return V; // Otherwise, see if "A EqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) return V; break; case CmpInst::ICMP_NE: @@ -3030,7 +2968,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return V; // Otherwise, see if "A InvEqP B" simplifies. if (MaxRecurse) - if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) return V; break; } @@ -3087,11 +3025,254 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getFalse(ITy); } + return nullptr; +} + +/// Given operands for an ICmpInst, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, + const Query &Q, unsigned MaxRecurse) { + CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; + assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); + + if (Constant *CLHS = dyn_cast<Constant>(LHS)) { + if (Constant *CRHS = dyn_cast<Constant>(RHS)) + return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI); + + // If we have a constant, make sure it is on the RHS. + std::swap(LHS, RHS); + Pred = CmpInst::getSwappedPredicate(Pred); + } + + Type *ITy = GetCompareTy(LHS); // The return type. + + // icmp X, X -> true/false + // X icmp undef -> true/false. For example, icmp ugt %X, undef -> false + // because X could be 0. + if (LHS == RHS || isa<UndefValue>(RHS)) + return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); + + if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q)) + return V; + + if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q)) + return V; + + if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS)) + return V; + + // If both operands have range metadata, use the metadata + // to simplify the comparison. + if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { + auto RHS_Instr = dyn_cast<Instruction>(RHS); + auto LHS_Instr = dyn_cast<Instruction>(LHS); + + if (RHS_Instr->getMetadata(LLVMContext::MD_range) && + LHS_Instr->getMetadata(LLVMContext::MD_range)) { + auto RHS_CR = getConstantRangeFromMetadata( + *RHS_Instr->getMetadata(LLVMContext::MD_range)); + auto LHS_CR = getConstantRangeFromMetadata( + *LHS_Instr->getMetadata(LLVMContext::MD_range)); + + auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR); + if (Satisfied_CR.contains(LHS_CR)) + return ConstantInt::getTrue(RHS->getContext()); + + auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( + CmpInst::getInversePredicate(Pred), RHS_CR); + if (InversedSatisfied_CR.contains(LHS_CR)) + return ConstantInt::getFalse(RHS->getContext()); + } + } + + // Compare of cast, for example (zext X) != 0 -> X != 0 + if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) { + Instruction *LI = cast<CastInst>(LHS); + Value *SrcOp = LI->getOperand(0); + Type *SrcTy = SrcOp->getType(); + Type *DstTy = LI->getType(); + + // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input + // if the integer type is the same size as the pointer type. + if (MaxRecurse && isa<PtrToIntInst>(LI) && + Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) { + if (Constant *RHSC = dyn_cast<Constant>(RHS)) { + // Transfer the cast to the constant. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, + ConstantExpr::getIntToPtr(RHSC, SrcTy), + Q, MaxRecurse-1)) + return V; + } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) { + if (RI->getOperand(0)->getType() == SrcTy) + // Compare without the cast. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), + Q, MaxRecurse-1)) + return V; + } + } + + if (isa<ZExtInst>(LHS)) { + // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the + // same type. + if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) { + if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) + // Compare X and Y. Note that signed predicates become unsigned. + if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), + SrcOp, RI->getOperand(0), Q, + MaxRecurse-1)) + return V; + } + // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended + // too. If not, then try to deduce the result of the comparison. + else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DstTy. + Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); + Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy); + + // If the re-extended constant didn't change then this is effectively + // also a case of comparing two zero-extended values. + if (RExt == CI && MaxRecurse) + if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), + SrcOp, Trunc, Q, MaxRecurse-1)) + return V; + + // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit + // there. Use this to work out the result of the comparison. + if (RExt != CI) { + switch (Pred) { + default: llvm_unreachable("Unknown ICmp predicate!"); + // LHS <u RHS. + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + return ConstantInt::getFalse(CI->getContext()); + + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + return ConstantInt::getTrue(CI->getContext()); + + // LHS is non-negative. If RHS is negative then LHS >s LHS. If RHS + // is non-negative then LHS <s RHS. + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return CI->getValue().isNegative() ? + ConstantInt::getTrue(CI->getContext()) : + ConstantInt::getFalse(CI->getContext()); + + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return CI->getValue().isNegative() ? + ConstantInt::getFalse(CI->getContext()) : + ConstantInt::getTrue(CI->getContext()); + } + } + } + } + + if (isa<SExtInst>(LHS)) { + // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the + // same type. + if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) { + if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) + // Compare X and Y. Note that the predicate does not change. + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), + Q, MaxRecurse-1)) + return V; + } + // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended + // too. If not, then try to deduce the result of the comparison. + else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + // Compute the constant that would happen if we truncated to SrcTy then + // reextended to DstTy. + Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); + Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy); + + // If the re-extended constant didn't change then this is effectively + // also a case of comparing two sign-extended values. + if (RExt == CI && MaxRecurse) + if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1)) + return V; + + // Otherwise the upper bits of LHS are all equal, while RHS has varying + // bits there. Use this to work out the result of the comparison. + if (RExt != CI) { + switch (Pred) { + default: llvm_unreachable("Unknown ICmp predicate!"); + case ICmpInst::ICMP_EQ: + return ConstantInt::getFalse(CI->getContext()); + case ICmpInst::ICMP_NE: + return ConstantInt::getTrue(CI->getContext()); + + // If RHS is non-negative then LHS <s RHS. If RHS is negative then + // LHS >s RHS. + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return CI->getValue().isNegative() ? + ConstantInt::getTrue(CI->getContext()) : + ConstantInt::getFalse(CI->getContext()); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return CI->getValue().isNegative() ? + ConstantInt::getFalse(CI->getContext()) : + ConstantInt::getTrue(CI->getContext()); + + // If LHS is non-negative then LHS <u RHS. If LHS is negative then + // LHS >u RHS. + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + // Comparison is true iff the LHS <s 0. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp, + Constant::getNullValue(SrcTy), + Q, MaxRecurse-1)) + return V; + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + // Comparison is true iff the LHS >=s 0. + if (MaxRecurse) + if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, + Constant::getNullValue(SrcTy), + Q, MaxRecurse-1)) + return V; + break; + } + } + } + } + } + + // icmp eq|ne X, Y -> false|true if X != Y + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) { + LLVMContext &Ctx = LHS->getType()->getContext(); + return Pred == ICmpInst::ICMP_NE ? + ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx); + } + + if (Value *V = simplifyICmpWithBinOp(Pred, LHS, RHS, Q, MaxRecurse)) + return V; + + if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse)) + return V; + // Simplify comparisons of related pointers using a powerful, recursive // GEP-walk when we have target data available.. if (LHS->getType()->isPointerTy()) if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.CxtI, LHS, RHS)) return C; + if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS)) + if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS)) + if (Q.DL.getTypeSizeInBits(CLHS->getPointerOperandType()) == + Q.DL.getTypeSizeInBits(CLHS->getType()) && + Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) == + Q.DL.getTypeSizeInBits(CRHS->getType())) + if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.CxtI, + CLHS->getPointerOperand(), + CRHS->getPointerOperand())) + return C; if (GetElementPtrInst *GLHS = dyn_cast<GetElementPtrInst>(LHS)) { if (GEPOperator *GRHS = dyn_cast<GEPOperator>(RHS)) { @@ -3119,17 +3300,16 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If a bit is known to be zero for A and known to be one for B, // then A and B cannot be equal. if (ICmpInst::isEquality(Pred)) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - uint32_t BitWidth = CI->getBitWidth(); + const APInt *RHSVal; + if (match(RHS, m_APInt(RHSVal))) { + unsigned BitWidth = RHSVal->getBitWidth(); APInt LHSKnownZero(BitWidth, 0); APInt LHSKnownOne(BitWidth, 0); computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); - const APInt &RHSVal = CI->getValue(); - if (((LHSKnownZero & RHSVal) != 0) || ((LHSKnownOne & ~RHSVal) != 0)) - return Pred == ICmpInst::ICMP_EQ - ? ConstantInt::getFalse(CI->getContext()) - : ConstantInt::getTrue(CI->getContext()); + if (((LHSKnownZero & *RHSVal) != 0) || ((LHSKnownOne & ~(*RHSVal)) != 0)) + return Pred == ICmpInst::ICMP_EQ ? ConstantInt::getFalse(ITy) + : ConstantInt::getTrue(ITy); } } @@ -3175,17 +3355,18 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // Fold trivial predicates. + Type *RetTy = GetCompareTy(LHS); if (Pred == FCmpInst::FCMP_FALSE) - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); if (Pred == FCmpInst::FCMP_TRUE) - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); // UNO/ORD predicates can be trivially folded if NaNs are ignored. if (FMF.noNaNs()) { if (Pred == FCmpInst::FCMP_UNO) - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); if (Pred == FCmpInst::FCMP_ORD) - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); } // fcmp pred x, undef and fcmp pred undef, x @@ -3193,15 +3374,15 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) { // Choosing NaN for the undef will always make unordered comparison succeed // and ordered comparison fail. - return ConstantInt::get(GetCompareTy(LHS), CmpInst::isUnordered(Pred)); + return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); } // fcmp x,x -> true/false. Not all compares are foldable. if (LHS == RHS) { if (CmpInst::isTrueWhenEqual(Pred)) - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); if (CmpInst::isFalseWhenEqual(Pred)) - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); } // Handle fcmp with constant RHS @@ -3216,11 +3397,11 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If the constant is a nan, see if we can fold the comparison based on it. if (CFP->getValueAPF().isNaN()) { if (FCmpInst::isOrdered(Pred)) // True "if ordered and foo" - return ConstantInt::getFalse(CFP->getContext()); + return getFalse(RetTy); assert(FCmpInst::isUnordered(Pred) && "Comparison must be either ordered or unordered!"); // True if unordered. - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); } // Check whether the constant is an infinity. if (CFP->getValueAPF().isInfinity()) { @@ -3228,10 +3409,10 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, switch (Pred) { case FCmpInst::FCMP_OLT: // No value is ordered and less than negative infinity. - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); case FCmpInst::FCMP_UGE: // All values are unordered with or at least negative infinity. - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); default: break; } @@ -3239,10 +3420,10 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, switch (Pred) { case FCmpInst::FCMP_OGT: // No value is ordered and greater than infinity. - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); case FCmpInst::FCMP_ULE: // All values are unordered with and at most infinity. - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); default: break; } @@ -3252,12 +3433,12 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, switch (Pred) { case FCmpInst::FCMP_UGE: if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) - return ConstantInt::get(GetCompareTy(LHS), 1); + return getTrue(RetTy); break; case FCmpInst::FCMP_OLT: // X < 0 if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) - return ConstantInt::get(GetCompareTy(LHS), 0); + return getFalse(RetTy); break; default: break; @@ -3371,6 +3552,150 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, return nullptr; } +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison where one operand of the compare is a constant. +static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X, + const APInt *Y, bool TrueWhenUnset) { + const APInt *C; + + // (X & Y) == 0 ? X & ~Y : X --> X + // (X & Y) != 0 ? X & ~Y : X --> X & ~Y + if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + // (X & Y) == 0 ? X : X & ~Y --> X & ~Y + // (X & Y) != 0 ? X : X & ~Y --> X + if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + if (Y->isPowerOf2()) { + // (X & Y) == 0 ? X | Y : X --> X | Y + // (X & Y) != 0 ? X | Y : X --> X + if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + + // (X & Y) == 0 ? X : X | Y --> X + // (X & Y) != 0 ? X : X | Y --> X | Y + if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + } + + return nullptr; +} + +/// An alternative way to test if a bit is set or not uses sgt/slt instead of +/// eq/ne. +static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *TrueVal, + Value *FalseVal, + bool TrueWhenUnset) { + unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits(); + if (!BitWidth) + return nullptr; + + APInt MinSignedValue; + Value *X; + if (match(CmpLHS, m_Trunc(m_Value(X))) && (X == TrueVal || X == FalseVal)) { + // icmp slt (trunc X), 0 <--> icmp ne (and X, C), 0 + // icmp sgt (trunc X), -1 <--> icmp eq (and X, C), 0 + unsigned DestSize = CmpLHS->getType()->getScalarSizeInBits(); + MinSignedValue = APInt::getSignedMinValue(DestSize).zext(BitWidth); + } else { + // icmp slt X, 0 <--> icmp ne (and X, C), 0 + // icmp sgt X, -1 <--> icmp eq (and X, C), 0 + X = CmpLHS; + MinSignedValue = APInt::getSignedMinValue(BitWidth); + } + + if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, &MinSignedValue, + TrueWhenUnset)) + return V; + + return nullptr; +} + +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison. +static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, + Value *FalseVal, const Query &Q, + unsigned MaxRecurse) { + ICmpInst::Predicate Pred; + Value *CmpLHS, *CmpRHS; + if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) + return nullptr; + + // FIXME: This code is nearly duplicated in InstCombine. Using/refactoring + // decomposeBitTestICmp() might help. + if (ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero())) { + Value *X; + const APInt *Y; + if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y)))) + if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, Y, + Pred == ICmpInst::ICMP_EQ)) + return V; + } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { + // Comparing signed-less-than 0 checks if the sign bit is set. + if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, TrueVal, FalseVal, + false)) + return V; + } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { + // Comparing signed-greater-than -1 checks if the sign bit is not set. + if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, TrueVal, FalseVal, + true)) + return V; + } + + if (CondVal->hasOneUse()) { + const APInt *C; + if (match(CmpRHS, m_APInt(C))) { + // X < MIN ? T : F --> F + if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue()) + return FalseVal; + // X < MIN ? T : F --> F + if (Pred == ICmpInst::ICMP_ULT && C->isMinValue()) + return FalseVal; + // X > MAX ? T : F --> F + if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue()) + return FalseVal; + // X > MAX ? T : F --> F + if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue()) + return FalseVal; + } + } + + // If we have an equality comparison, then we know the value in one of the + // arms of the select. See if substituting this value into the arm and + // simplifying the result yields the same value as the other arm. + if (Pred == ICmpInst::ICMP_EQ) { + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + TrueVal) + return FalseVal; + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + FalseVal) + return FalseVal; + } else if (Pred == ICmpInst::ICMP_NE) { + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + FalseVal) + return TrueVal; + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + TrueVal) + return TrueVal; + } + + return nullptr; +} + /// Given operands for a SelectInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, @@ -3399,106 +3724,9 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, if (isa<UndefValue>(FalseVal)) // select C, X, undef -> X return TrueVal; - if (const auto *ICI = dyn_cast<ICmpInst>(CondVal)) { - // FIXME: This code is nearly duplicated in InstCombine. Using/refactoring - // decomposeBitTestICmp() might help. - unsigned BitWidth = - Q.DL.getTypeSizeInBits(TrueVal->getType()->getScalarType()); - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *CmpLHS = ICI->getOperand(0); - Value *CmpRHS = ICI->getOperand(1); - APInt MinSignedValue = APInt::getSignBit(BitWidth); - Value *X; - const APInt *Y; - bool TrueWhenUnset; - bool IsBitTest = false; - if (ICmpInst::isEquality(Pred) && - match(CmpLHS, m_And(m_Value(X), m_APInt(Y))) && - match(CmpRHS, m_Zero())) { - IsBitTest = true; - TrueWhenUnset = Pred == ICmpInst::ICMP_EQ; - } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { - X = CmpLHS; - Y = &MinSignedValue; - IsBitTest = true; - TrueWhenUnset = false; - } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { - X = CmpLHS; - Y = &MinSignedValue; - IsBitTest = true; - TrueWhenUnset = true; - } - if (IsBitTest) { - const APInt *C; - // (X & Y) == 0 ? X & ~Y : X --> X - // (X & Y) != 0 ? X & ~Y : X --> X & ~Y - if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && - *Y == ~*C) - return TrueWhenUnset ? FalseVal : TrueVal; - // (X & Y) == 0 ? X : X & ~Y --> X & ~Y - // (X & Y) != 0 ? X : X & ~Y --> X - if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && - *Y == ~*C) - return TrueWhenUnset ? FalseVal : TrueVal; - - if (Y->isPowerOf2()) { - // (X & Y) == 0 ? X | Y : X --> X | Y - // (X & Y) != 0 ? X | Y : X --> X - if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && - *Y == *C) - return TrueWhenUnset ? TrueVal : FalseVal; - // (X & Y) == 0 ? X : X | Y --> X - // (X & Y) != 0 ? X : X | Y --> X | Y - if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && - *Y == *C) - return TrueWhenUnset ? TrueVal : FalseVal; - } - } - if (ICI->hasOneUse()) { - const APInt *C; - if (match(CmpRHS, m_APInt(C))) { - // X < MIN ? T : F --> F - if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue()) - return FalseVal; - // X < MIN ? T : F --> F - if (Pred == ICmpInst::ICMP_ULT && C->isMinValue()) - return FalseVal; - // X > MAX ? T : F --> F - if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue()) - return FalseVal; - // X > MAX ? T : F --> F - if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue()) - return FalseVal; - } - } - - // If we have an equality comparison then we know the value in one of the - // arms of the select. See if substituting this value into the arm and - // simplifying the result yields the same value as the other arm. - if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - TrueVal) - return FalseVal; - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - FalseVal) - return FalseVal; - } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - FalseVal) - return TrueVal; - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - TrueVal) - return TrueVal; - } - } + if (Value *V = + simplifySelectWithICmpCond(CondVal, TrueVal, FalseVal, Q, MaxRecurse)) + return V; return nullptr; } @@ -3587,6 +3815,32 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, } } + if (Q.DL.getTypeAllocSize(LastType) == 1 && + all_of(Ops.slice(1).drop_back(1), + [](Value *Idx) { return match(Idx, m_Zero()); })) { + unsigned PtrWidth = + Q.DL.getPointerSizeInBits(Ops[0]->getType()->getPointerAddressSpace()); + if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == PtrWidth) { + APInt BasePtrOffset(PtrWidth, 0); + Value *StrippedBasePtr = + Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL, + BasePtrOffset); + + // gep (gep V, C), (sub 0, V) -> C + if (match(Ops.back(), + m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr))))) { + auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset); + return ConstantExpr::getIntToPtr(CI, GEPTy); + } + // gep (gep V, C), (xor V, -1) -> C-1 + if (match(Ops.back(), + m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes()))) { + auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset - 1); + return ConstantExpr::getIntToPtr(CI, GEPTy); + } + } + } + // Check to see if this is constant foldable. for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (!isa<Constant>(Ops[i])) @@ -3742,19 +3996,47 @@ static Value *SimplifyPHINode(PHINode *PN, const Query &Q) { return CommonValue; } -static Value *SimplifyTruncInst(Value *Op, Type *Ty, const Query &Q, unsigned) { - if (Constant *C = dyn_cast<Constant>(Op)) - return ConstantFoldCastOperand(Instruction::Trunc, C, Ty, Q.DL); +static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, + Type *Ty, const Query &Q, unsigned MaxRecurse) { + if (auto *C = dyn_cast<Constant>(Op)) + return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL); + + if (auto *CI = dyn_cast<CastInst>(Op)) { + auto *Src = CI->getOperand(0); + Type *SrcTy = Src->getType(); + Type *MidTy = CI->getType(); + Type *DstTy = Ty; + if (Src->getType() == Ty) { + auto FirstOp = static_cast<Instruction::CastOps>(CI->getOpcode()); + auto SecondOp = static_cast<Instruction::CastOps>(CastOpc); + Type *SrcIntPtrTy = + SrcTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(SrcTy) : nullptr; + Type *MidIntPtrTy = + MidTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(MidTy) : nullptr; + Type *DstIntPtrTy = + DstTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(DstTy) : nullptr; + if (CastInst::isEliminableCastPair(FirstOp, SecondOp, SrcTy, MidTy, DstTy, + SrcIntPtrTy, MidIntPtrTy, + DstIntPtrTy) == Instruction::BitCast) + return Src; + } + } + + // bitcast x -> x + if (CastOpc == Instruction::BitCast) + if (Op->getType() == Ty) + return Op; return nullptr; } -Value *llvm::SimplifyTruncInst(Value *Op, Type *Ty, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyTruncInst(Op, Ty, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, + const DataLayout &DL, + const TargetLibraryInfo *TLI, + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyCastInst(CastOpc, Op, Ty, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } //=== Helper functions for higher up the class hierarchy. @@ -3837,6 +4119,8 @@ static Value *SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, return SimplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); case Instruction::FMul: return SimplifyFMulInst(LHS, RHS, FMF, Q, MaxRecurse); + case Instruction::FDiv: + return SimplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); default: return SimplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); } @@ -3968,14 +4252,36 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, const Query &Q, unsigned MaxRecurse) { Intrinsic::ID IID = F->getIntrinsicID(); unsigned NumOperands = std::distance(ArgBegin, ArgEnd); - Type *ReturnType = F->getReturnType(); + + // Unary Ops + if (NumOperands == 1) { + // Perform idempotent optimizations + if (IsIdempotent(IID)) { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(*ArgBegin)) { + if (II->getIntrinsicID() == IID) + return II; + } + } + + switch (IID) { + case Intrinsic::fabs: { + if (SignBitMustBeZero(*ArgBegin, Q.TLI)) + return *ArgBegin; + } + default: + return nullptr; + } + } // Binary Ops if (NumOperands == 2) { Value *LHS = *ArgBegin; Value *RHS = *(ArgBegin + 1); - if (IID == Intrinsic::usub_with_overflow || - IID == Intrinsic::ssub_with_overflow) { + Type *ReturnType = F->getReturnType(); + + switch (IID) { + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: { // X - X -> { 0, false } if (LHS == RHS) return Constant::getNullValue(ReturnType); @@ -3984,17 +4290,19 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, // undef - X -> undef if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) return UndefValue::get(ReturnType); - } - if (IID == Intrinsic::uadd_with_overflow || - IID == Intrinsic::sadd_with_overflow) { + return nullptr; + } + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: { // X + undef -> undef if (isa<UndefValue>(RHS)) return UndefValue::get(ReturnType); - } - if (IID == Intrinsic::umul_with_overflow || - IID == Intrinsic::smul_with_overflow) { + return nullptr; + } + case Intrinsic::umul_with_overflow: + case Intrinsic::smul_with_overflow: { // X * 0 -> { 0, false } if (match(RHS, m_Zero())) return Constant::getNullValue(ReturnType); @@ -4002,34 +4310,34 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, // X * undef -> { 0, false } if (match(RHS, m_Undef())) return Constant::getNullValue(ReturnType); - } - if (IID == Intrinsic::load_relative && isa<Constant>(LHS) && - isa<Constant>(RHS)) - return SimplifyRelativeLoad(cast<Constant>(LHS), cast<Constant>(RHS), - Q.DL); + return nullptr; + } + case Intrinsic::load_relative: { + Constant *C0 = dyn_cast<Constant>(LHS); + Constant *C1 = dyn_cast<Constant>(RHS); + if (C0 && C1) + return SimplifyRelativeLoad(C0, C1, Q.DL); + return nullptr; + } + default: + return nullptr; + } } // Simplify calls to llvm.masked.load.* - if (IID == Intrinsic::masked_load) { + switch (IID) { + case Intrinsic::masked_load: { Value *MaskArg = ArgBegin[2]; Value *PassthruArg = ArgBegin[3]; // If the mask is all zeros or undef, the "passthru" argument is the result. if (maskIsAllZeroOrUndef(MaskArg)) return PassthruArg; + return nullptr; } - - // Perform idempotent optimizations - if (!IsIdempotent(IID)) + default: return nullptr; - - // Unary Ops - if (NumOperands == 1) - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(*ArgBegin)) - if (II->getIntrinsicID() == IID) - return II; - - return nullptr; + } } template <typename IterTy> @@ -4223,21 +4531,23 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, TLI, DT, AC, I); break; } - case Instruction::Trunc: - Result = - SimplifyTruncInst(I->getOperand(0), I->getType(), DL, TLI, DT, AC, I); +#define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: +#include "llvm/IR/Instruction.def" +#undef HANDLE_CAST_INST + Result = SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), + DL, TLI, DT, AC, I); break; } // In general, it is possible for computeKnownBits to determine all bits in a // value even when the operands are not all constants. - if (!Result && I->getType()->isIntegerTy()) { + if (!Result && I->getType()->isIntOrIntVectorTy()) { unsigned BitWidth = I->getType()->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); computeKnownBits(I, KnownZero, KnownOne, DL, /*Depth*/0, AC, I, DT); if ((KnownZero | KnownOne).isAllOnesValue()) - Result = ConstantInt::get(I->getContext(), KnownOne); + Result = ConstantInt::get(I->getType(), KnownOne); } /// If called on unreachable code, the above logic may report that the |
