Diffstat (limited to 'gnu/llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--   gnu/llvm/lib/Target/X86/X86ISelLowering.cpp   7278
1 file changed, 4668 insertions, 2610 deletions
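Nearly all of the churn below is in the X86TargetLowering constructor, which records how each (operation, value type) pair should be legalized. As a reading aid, here is a minimal sketch of that registration pattern for a hypothetical target (not the X86 code itself); setOperationAction and setOperationPromotedToType are the TargetLoweringBase members used throughout the diff, called from a subclass constructor because they are protected.

#include "llvm/CodeGen/TargetLowering.h" // the header this change switches to
#include "llvm/Target/TargetMachine.h"

namespace {
// Hypothetical target, for illustration only.
class MyTargetLowering : public llvm::TargetLowering {
public:
  explicit MyTargetLowering(const llvm::TargetMachine &TM)
      : TargetLowering(TM) {
    // Legal: selectable as-is; Custom: handled by the target's
    // LowerOperation(); Expand: split into other nodes or a libcall.
    setOperationAction(llvm::ISD::ABS, llvm::MVT::i32, Custom);
    // Promote i8 popcount by widening to i32 first, as this diff now does
    // for CTPOP via setOperationPromotedToType.
    setOperationPromotedToType(llvm::ISD::CTPOP, llvm::MVT::i8, llvm::MVT::i32);
  }
};
} // namespace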
diff --git a/gnu/llvm/lib/Target/X86/X86ISelLowering.cpp b/gnu/llvm/lib/Target/X86/X86ISelLowering.cpp index 607bc4530ab..10e19f92b4a 100644 --- a/gnu/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/gnu/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35,6 +35,7 @@ #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/WinEHFuncInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/CallingConv.h" @@ -55,7 +56,6 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" -#include "llvm/Target/TargetLowering.h" #include "llvm/Target/TargetOptions.h" #include <algorithm> #include <bitset> @@ -94,7 +94,7 @@ static void errorUnsupported(SelectionDAG &DAG, const SDLoc &dl, const char *Msg) { MachineFunction &MF = DAG.getMachineFunction(); DAG.getContext()->diagnose( - DiagnosticInfoUnsupported(*MF.getFunction(), Msg, dl.getDebugLoc())); + DiagnosticInfoUnsupported(MF.getFunction(), Msg, dl.getDebugLoc())); } X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, @@ -188,6 +188,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setCondCodeAction(ISD::SETUNE, MVT::f64, Expand); setCondCodeAction(ISD::SETUNE, MVT::f80, Expand); + // Integer absolute. + if (Subtarget.hasCMov()) { + setOperationAction(ISD::ABS , MVT::i16 , Custom); + setOperationAction(ISD::ABS , MVT::i32 , Custom); + if (Subtarget.is64Bit()) + setOperationAction(ISD::ABS , MVT::i64 , Custom); + } + // Promote all UINT_TO_FP to larger SINT_TO_FP's, as X86 doesn't have this // operation. setOperationAction(ISD::UINT_TO_FP , MVT::i1 , Promote); @@ -372,8 +380,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // Special handling for half-precision floating point conversions. // If we don't have F16C support, then lower half float conversions // into library calls. - if (Subtarget.useSoftFloat() || - (!Subtarget.hasF16C() && !Subtarget.hasAVX512())) { + if (Subtarget.useSoftFloat() || !Subtarget.hasF16C()) { setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand); setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand); } @@ -392,7 +399,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setTruncStoreAction(MVT::f80, MVT::f16, Expand); if (Subtarget.hasPOPCNT()) { - setOperationAction(ISD::CTPOP , MVT::i8 , Promote); + setOperationPromotedToType(ISD::CTPOP, MVT::i8, MVT::i32); } else { setOperationAction(ISD::CTPOP , MVT::i8 , Expand); setOperationAction(ISD::CTPOP , MVT::i16 , Expand); @@ -425,12 +432,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SELECT_CC, MVT::x86mmx, Expand); setOperationAction(ISD::EH_RETURN , MVT::Other, Custom); - // NOTE: EH_SJLJ_SETJMP/_LONGJMP supported here is NOT intended to support - // SjLj exception handling but a light-weight setjmp/longjmp replacement to - // support continuation, user-level threading, and etc.. As a result, no - // other SjLj exception interfaces are implemented and please don't build - // your own exception handling based on them. - // LLVM/Clang supports zero-cost DWARF exception handling. + // NOTE: EH_SJLJ_SETJMP/_LONGJMP are not recommended, since + // LLVM/Clang supports zero-cost DWARF and SEH exception handling. 
setOperationAction(ISD::EH_SJLJ_SETJMP, MVT::i32, Custom); setOperationAction(ISD::EH_SJLJ_LONGJMP, MVT::Other, Custom); setOperationAction(ISD::EH_SJLJ_SETUP_DISPATCH, MVT::Other, Custom); @@ -458,7 +461,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SRL_PARTS, VT, Custom); } - if (Subtarget.hasSSE1()) + if (Subtarget.hasSSEPrefetch() || Subtarget.has3DNow()) setOperationAction(ISD::PREFETCH , MVT::Other, Legal); setOperationAction(ISD::ATOMIC_FENCE , MVT::Other, Custom); @@ -545,8 +548,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } else if (UseX87 && X86ScalarSSEf32) { // Use SSE for f32, x87 for f64. // Set up the FP register classes. - addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass - : &X86::FR32RegClass); + addRegisterClass(MVT::f32, &X86::FR32RegClass); addRegisterClass(MVT::f64, &X86::RFP64RegClass); // Use ANDPS to simulate FABS. @@ -573,11 +575,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, addLegalFPImmediate(APFloat(-0.0)); // FLD0/FCHS addLegalFPImmediate(APFloat(-1.0)); // FLD1/FCHS - if (!TM.Options.UnsafeFPMath) { - setOperationAction(ISD::FSIN , MVT::f64, Expand); - setOperationAction(ISD::FCOS , MVT::f64, Expand); - setOperationAction(ISD::FSINCOS, MVT::f64, Expand); - } + // Always expand sin/cos functions even though x87 has an instruction. + setOperationAction(ISD::FSIN , MVT::f64, Expand); + setOperationAction(ISD::FCOS , MVT::f64, Expand); + setOperationAction(ISD::FSINCOS, MVT::f64, Expand); } else if (UseX87) { // f32 and f64 in x87. // Set up the FP register classes. @@ -588,11 +589,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::UNDEF, VT, Expand); setOperationAction(ISD::FCOPYSIGN, VT, Expand); - if (!TM.Options.UnsafeFPMath) { - setOperationAction(ISD::FSIN , VT, Expand); - setOperationAction(ISD::FCOS , VT, Expand); - setOperationAction(ISD::FSINCOS, VT, Expand); - } + // Always expand sin/cos functions even though x87 has an instruction. + setOperationAction(ISD::FSIN , VT, Expand); + setOperationAction(ISD::FCOS , VT, Expand); + setOperationAction(ISD::FSINCOS, VT, Expand); } addLegalFPImmediate(APFloat(+0.0)); // FLD0 addLegalFPImmediate(APFloat(+1.0)); // FLD1 @@ -636,11 +636,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, addLegalFPImmediate(TmpFlt2); // FLD1/FCHS } - if (!TM.Options.UnsafeFPMath) { - setOperationAction(ISD::FSIN , MVT::f80, Expand); - setOperationAction(ISD::FCOS , MVT::f80, Expand); - setOperationAction(ISD::FSINCOS, MVT::f80, Expand); - } + // Always expand sin/cos functions even though x87 has an instruction. + setOperationAction(ISD::FSIN , MVT::f80, Expand); + setOperationAction(ISD::FCOS , MVT::f80, Expand); + setOperationAction(ISD::FSINCOS, MVT::f80, Expand); setOperationAction(ISD::FFLOOR, MVT::f80, Expand); setOperationAction(ISD::FCEIL, MVT::f80, Expand); @@ -861,8 +860,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Legal); setOperationAction(ISD::SINT_TO_FP, MVT::v2i32, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v4i8, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v4i16, Custom); setOperationAction(ISD::UINT_TO_FP, MVT::v2i32, Custom); // Fast v2f32 UINT_TO_FP( v2i32 ) custom conversion. 
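Among the additions above, ISD::ABS is marked Custom for i16/i32/i64 when the subtarget has CMOV. As a rough scalar model (illustrative only, not the DAG lowering itself), the branch-free shape such a lowering typically targets looks like this:

#include <cstdint>

// Compute the negation, then conditionally select the non-negative value;
// on CMOV-capable x86 this usually becomes NEG + CMOV instead of a branch.
// Note: as with std::abs, the absolute value of INT32_MIN is not representable.
int32_t absSelect(int32_t x) {
  int32_t neg = 0 - x;
  return x < 0 ? neg : x;
}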
@@ -944,6 +941,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, for (auto LoadExtOp : { ISD::SEXTLOAD, ISD::ZEXTLOAD }) { setLoadExtAction(LoadExtOp, MVT::v8i16, MVT::v8i8, Legal); setLoadExtAction(LoadExtOp, MVT::v4i32, MVT::v4i8, Legal); + setLoadExtAction(LoadExtOp, MVT::v2i32, MVT::v2i8, Legal); setLoadExtAction(LoadExtOp, MVT::v2i64, MVT::v2i8, Legal); setLoadExtAction(LoadExtOp, MVT::v4i32, MVT::v4i16, Legal); setLoadExtAction(LoadExtOp, MVT::v2i64, MVT::v2i16, Legal); @@ -998,17 +996,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // (fp_to_int:v8i16 (v8f32 ..)) requires the result type to be promoted // even though v8i16 is a legal type. - setOperationAction(ISD::FP_TO_SINT, MVT::v8i16, Promote); - setOperationAction(ISD::FP_TO_UINT, MVT::v8i16, Promote); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i16, MVT::v8i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v8i16, MVT::v8i32); setOperationAction(ISD::FP_TO_SINT, MVT::v8i32, Legal); - setOperationAction(ISD::SINT_TO_FP, MVT::v8i16, Promote); setOperationAction(ISD::SINT_TO_FP, MVT::v8i32, Legal); setOperationAction(ISD::FP_ROUND, MVT::v4f32, Legal); - setOperationAction(ISD::UINT_TO_FP, MVT::v8i8, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v8i16, Custom); - for (MVT VT : MVT::fp_vector_valuetypes()) setLoadExtAction(ISD::EXTLOAD, VT, MVT::v4f32, Legal); @@ -1104,7 +1098,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, // (result) is 128-bit but the source is 256-bit wide. for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64, MVT::v4f32, MVT::v2f64 }) { - setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal); } // Custom lower several nodes for 256-bit types. @@ -1131,6 +1125,16 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationPromotedToType(ISD::LOAD, VT, MVT::v4i64); setOperationPromotedToType(ISD::SELECT, VT, MVT::v4i64); } + + if (HasInt256) { + // Custom legalize 2x32 to get a little better code. 
+ setOperationAction(ISD::MGATHER, MVT::v2f32, Custom); + setOperationAction(ISD::MGATHER, MVT::v2i32, Custom); + + for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64, + MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) + setOperationAction(ISD::MGATHER, VT, Custom); + } } if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) { @@ -1143,13 +1147,65 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, addRegisterClass(MVT::v8i1, &X86::VK8RegClass); addRegisterClass(MVT::v16i1, &X86::VK16RegClass); + setOperationAction(ISD::SELECT, MVT::v1i1, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v1i1, Custom); + setOperationAction(ISD::BUILD_VECTOR, MVT::v1i1, Custom); + + setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v16i1, MVT::v16i32); + setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v16i1, MVT::v16i32); + setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i1, MVT::v8i32); + setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i1, MVT::v8i32); + setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i1, MVT::v4i32); + setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i1, MVT::v4i32); + setOperationAction(ISD::SINT_TO_FP, MVT::v2i1, Custom); + setOperationAction(ISD::UINT_TO_FP, MVT::v2i1, Custom); + + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i1, MVT::v16i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i1, MVT::v16i32); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i1, MVT::v8i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v8i1, MVT::v8i32); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v4i1, MVT::v4i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v4i1, MVT::v4i32); + if (Subtarget.hasVLX()) { + setOperationAction(ISD::FP_TO_SINT, MVT::v2i1, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v2i1, Custom); + } + + // Extends of v16i1/v8i1 to 128-bit vectors. 
+ setOperationAction(ISD::SIGN_EXTEND, MVT::v16i8, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v16i8, Custom); + setOperationAction(ISD::ANY_EXTEND, MVT::v16i8, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v8i16, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v8i16, Custom); + setOperationAction(ISD::ANY_EXTEND, MVT::v8i16, Custom); + + for (auto VT : { MVT::v8i1, MVT::v16i1 }) { + setOperationAction(ISD::ADD, VT, Custom); + setOperationAction(ISD::SUB, VT, Custom); + setOperationAction(ISD::MUL, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::SELECT, VT, Custom); + setOperationAction(ISD::TRUNCATE, VT, Custom); + + setOperationAction(ISD::BUILD_VECTOR, VT, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); + setOperationAction(ISD::VSELECT, VT, Expand); + } + + setOperationAction(ISD::CONCAT_VECTORS, MVT::v16i1, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v8i1, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v16i1, Custom); + for (auto VT : { MVT::v1i1, MVT::v8i1 }) + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + for (MVT VT : MVT::fp_vector_valuetypes()) setLoadExtAction(ISD::EXTLOAD, VT, MVT::v8f32, Legal); - for (auto ExtType : {ISD::ZEXTLOAD, ISD::SEXTLOAD, ISD::EXTLOAD}) { + for (auto ExtType : {ISD::ZEXTLOAD, ISD::SEXTLOAD}) { setLoadExtAction(ExtType, MVT::v16i32, MVT::v16i8, Legal); setLoadExtAction(ExtType, MVT::v16i32, MVT::v16i16, Legal); - setLoadExtAction(ExtType, MVT::v32i16, MVT::v32i8, Legal); setLoadExtAction(ExtType, MVT::v8i64, MVT::v8i8, Legal); setLoadExtAction(ExtType, MVT::v8i64, MVT::v8i16, Legal); setLoadExtAction(ExtType, MVT::v8i64, MVT::v8i32, Legal); @@ -1173,98 +1229,32 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } setOperationAction(ISD::FP_TO_SINT, MVT::v16i32, Legal); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i16, MVT::v16i32); + setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i8, MVT::v16i32); setOperationAction(ISD::FP_TO_UINT, MVT::v16i32, Legal); - setOperationAction(ISD::FP_TO_UINT, MVT::v8i32, Legal); - setOperationAction(ISD::FP_TO_UINT, MVT::v4i32, Legal); - setOperationAction(ISD::FP_TO_UINT, MVT::v2i32, Custom); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i8, MVT::v16i32); + setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i16, MVT::v16i32); setOperationAction(ISD::SINT_TO_FP, MVT::v16i32, Legal); - setOperationAction(ISD::SINT_TO_FP, MVT::v8i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v16i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v16i8, Promote); - setOperationAction(ISD::SINT_TO_FP, MVT::v16i16, Promote); setOperationAction(ISD::UINT_TO_FP, MVT::v16i32, Legal); - setOperationAction(ISD::UINT_TO_FP, MVT::v8i32, Legal); - setOperationAction(ISD::UINT_TO_FP, MVT::v4i32, Legal); - setOperationAction(ISD::UINT_TO_FP, MVT::v16i8, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v16i16, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v16i1, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v16i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v8i1, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v8i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v4i1, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v4i1, Custom); - setOperationAction(ISD::SINT_TO_FP, MVT::v2i1, Custom); - setOperationAction(ISD::UINT_TO_FP, MVT::v2i1, 
Custom); - setOperationAction(ISD::FP_ROUND, MVT::v8f32, Legal); - setOperationAction(ISD::FP_EXTEND, MVT::v8f32, Legal); setTruncStoreAction(MVT::v8i64, MVT::v8i8, Legal); setTruncStoreAction(MVT::v8i64, MVT::v8i16, Legal); setTruncStoreAction(MVT::v8i64, MVT::v8i32, Legal); setTruncStoreAction(MVT::v16i32, MVT::v16i8, Legal); setTruncStoreAction(MVT::v16i32, MVT::v16i16, Legal); - if (Subtarget.hasVLX()){ - setTruncStoreAction(MVT::v4i64, MVT::v4i8, Legal); - setTruncStoreAction(MVT::v4i64, MVT::v4i16, Legal); - setTruncStoreAction(MVT::v4i64, MVT::v4i32, Legal); - setTruncStoreAction(MVT::v8i32, MVT::v8i8, Legal); - setTruncStoreAction(MVT::v8i32, MVT::v8i16, Legal); - - setTruncStoreAction(MVT::v2i64, MVT::v2i8, Legal); - setTruncStoreAction(MVT::v2i64, MVT::v2i16, Legal); - setTruncStoreAction(MVT::v2i64, MVT::v2i32, Legal); - setTruncStoreAction(MVT::v4i32, MVT::v4i8, Legal); - setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal); - } else { + + if (!Subtarget.hasVLX()) { + // With 512-bit vectors and no VLX, we prefer to widen MLOAD/MSTORE + // to 512-bit rather than use the AVX2 instructions so that we can use + // k-masks. for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64, MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) { setOperationAction(ISD::MLOAD, VT, Custom); setOperationAction(ISD::MSTORE, VT, Custom); } } - setOperationAction(ISD::TRUNCATE, MVT::v16i8, Custom); - setOperationAction(ISD::TRUNCATE, MVT::v8i32, Custom); - - if (Subtarget.hasDQI()) { - for (auto VT : { MVT::v2i64, MVT::v4i64, MVT::v8i64 }) { - setOperationAction(ISD::SINT_TO_FP, VT, Legal); - setOperationAction(ISD::UINT_TO_FP, VT, Legal); - setOperationAction(ISD::FP_TO_SINT, VT, Legal); - setOperationAction(ISD::FP_TO_UINT, VT, Legal); - } - if (Subtarget.hasVLX()) { - // Fast v2f32 SINT_TO_FP( v2i32 ) custom conversion. - setOperationAction(ISD::SINT_TO_FP, MVT::v2f32, Custom); - setOperationAction(ISD::FP_TO_SINT, MVT::v2f32, Custom); - setOperationAction(ISD::FP_TO_UINT, MVT::v2f32, Custom); - } - } - if (Subtarget.hasVLX()) { - setOperationAction(ISD::SINT_TO_FP, MVT::v8i32, Legal); - setOperationAction(ISD::UINT_TO_FP, MVT::v8i32, Legal); - setOperationAction(ISD::FP_TO_SINT, MVT::v8i32, Legal); - setOperationAction(ISD::FP_TO_UINT, MVT::v8i32, Legal); - setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Legal); - setOperationAction(ISD::FP_TO_SINT, MVT::v4i32, Legal); - setOperationAction(ISD::FP_TO_UINT, MVT::v4i32, Legal); - setOperationAction(ISD::ZERO_EXTEND, MVT::v4i32, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v2i64, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v4i32, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v2i64, Custom); - - // FIXME. This commands are available on SSE/AVX2, add relevant patterns. 
- setLoadExtAction(ISD::EXTLOAD, MVT::v8i32, MVT::v8i8, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v8i32, MVT::v8i16, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i16, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i64, MVT::v4i8, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i64, MVT::v4i16, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v4i64, MVT::v4i32, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i8, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i16, Legal); - setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i32, Legal); - } + setOperationAction(ISD::TRUNCATE, MVT::v8i32, Custom); setOperationAction(ISD::TRUNCATE, MVT::v16i16, Custom); setOperationAction(ISD::ZERO_EXTEND, MVT::v16i32, Custom); setOperationAction(ISD::ZERO_EXTEND, MVT::v8i64, Custom); @@ -1272,9 +1262,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::ANY_EXTEND, MVT::v8i64, Custom); setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom); setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v16i8, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v8i16, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v16i16, Custom); for (auto VT : { MVT::v16f32, MVT::v8f64 }) { setOperationAction(ISD::FFLOOR, VT, Legal); @@ -1295,38 +1282,17 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i64, Custom); setOperationAction(ISD::CONCAT_VECTORS, MVT::v16f32, Custom); setOperationAction(ISD::CONCAT_VECTORS, MVT::v16i32, Custom); - setOperationAction(ISD::CONCAT_VECTORS, MVT::v16i1, Custom); setOperationAction(ISD::MUL, MVT::v8i64, Custom); + setOperationAction(ISD::MUL, MVT::v16i32, Legal); + + setOperationAction(ISD::UMUL_LOHI, MVT::v16i32, Custom); + setOperationAction(ISD::SMUL_LOHI, MVT::v16i32, Custom); - setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v1i1, Custom); - setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v16i1, Custom); - setOperationAction(ISD::BUILD_VECTOR, MVT::v1i1, Custom); setOperationAction(ISD::SELECT, MVT::v8f64, Custom); setOperationAction(ISD::SELECT, MVT::v8i64, Custom); setOperationAction(ISD::SELECT, MVT::v16f32, Custom); - setOperationAction(ISD::MUL, MVT::v16i32, Legal); - - // NonVLX sub-targets extend 128/256 vectors to use the 512 version. 
- setOperationAction(ISD::ABS, MVT::v4i64, Legal); - setOperationAction(ISD::ABS, MVT::v2i64, Legal); - - for (auto VT : { MVT::v8i1, MVT::v16i1 }) { - setOperationAction(ISD::ADD, VT, Custom); - setOperationAction(ISD::SUB, VT, Custom); - setOperationAction(ISD::MUL, VT, Custom); - setOperationAction(ISD::SETCC, VT, Custom); - setOperationAction(ISD::SELECT, VT, Custom); - setOperationAction(ISD::TRUNCATE, VT, Custom); - - setOperationAction(ISD::BUILD_VECTOR, VT, Custom); - setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); - setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); - setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); - setOperationAction(ISD::VSELECT, VT, Expand); - } - for (auto VT : { MVT::v16i32, MVT::v8i64 }) { setOperationAction(ISD::SMAX, VT, Legal); setOperationAction(ISD::UMAX, VT, Legal); @@ -1338,11 +1304,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::SRA, VT, Custom); setOperationAction(ISD::CTPOP, VT, Custom); setOperationAction(ISD::CTTZ, VT, Custom); - } - - // NonVLX sub-targets extend 128/256 vectors to use the 512 version. - for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v16i32, MVT::v2i64, MVT::v4i64, - MVT::v8i64}) { setOperationAction(ISD::ROTL, VT, Custom); setOperationAction(ISD::ROTR, VT, Custom); } @@ -1354,44 +1315,31 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationPromotedToType(ISD::OR, MVT::v16i32, MVT::v8i64); setOperationPromotedToType(ISD::XOR, MVT::v16i32, MVT::v8i64); + if (Subtarget.hasDQI()) { + setOperationAction(ISD::SINT_TO_FP, MVT::v8i64, Legal); + setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal); + setOperationAction(ISD::FP_TO_SINT, MVT::v8i64, Legal); + setOperationAction(ISD::FP_TO_UINT, MVT::v8i64, Legal); + } + if (Subtarget.hasCDI()) { // NonVLX sub-targets extend 128/256 vectors to use the 512 version. - for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v16i32, MVT::v2i64, - MVT::v4i64, MVT::v8i64}) { + for (auto VT : { MVT::v16i32, MVT::v8i64} ) { setOperationAction(ISD::CTLZ, VT, Legal); setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Custom); } } // Subtarget.hasCDI() - if (Subtarget.hasDQI()) { - // NonVLX sub-targets extend 128/256 vectors to use the 512 version. - setOperationAction(ISD::MUL, MVT::v2i64, Legal); - setOperationAction(ISD::MUL, MVT::v4i64, Legal); - setOperationAction(ISD::MUL, MVT::v8i64, Legal); - } - if (Subtarget.hasVPOPCNTDQ()) { - // VPOPCNTDQ sub-targets extend 128/256 vectors to use the avx512 - // version of popcntd/q. - for (auto VT : {MVT::v16i32, MVT::v8i64, MVT::v8i32, MVT::v4i64, - MVT::v4i32, MVT::v2i64}) + for (auto VT : { MVT::v16i32, MVT::v8i64 }) setOperationAction(ISD::CTPOP, VT, Legal); } - // Custom lower several nodes. - for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64, - MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) { - setOperationAction(ISD::MGATHER, VT, Custom); - setOperationAction(ISD::MSCATTER, VT, Custom); - } // Extract subvector is special because the value type // (result) is 256-bit but the source is 512-bit wide. - // 128-bit was made Custom under AVX1. + // 128-bit was made Legal under AVX1. 
for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64, - MVT::v8f32, MVT::v4f64, MVT::v1i1 }) - setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); - for (auto VT : { MVT::v2i1, MVT::v4i1, MVT::v8i1, - MVT::v16i1, MVT::v32i1, MVT::v64i1 }) + MVT::v8f32, MVT::v4f64 }) setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal); for (auto VT : { MVT::v16i32, MVT::v8i64, MVT::v16f32, MVT::v8f64 }) { @@ -1404,7 +1352,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::INSERT_SUBVECTOR, VT, Legal); setOperationAction(ISD::MLOAD, VT, Legal); setOperationAction(ISD::MSTORE, VT, Legal); - setOperationAction(ISD::MGATHER, VT, Legal); + setOperationAction(ISD::MGATHER, VT, Custom); setOperationAction(ISD::MSCATTER, VT, Custom); } for (auto VT : { MVT::v64i8, MVT::v32i16, MVT::v16i32 }) { @@ -1413,6 +1361,57 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, } }// has AVX-512 + if (!Subtarget.useSoftFloat() && + (Subtarget.hasAVX512() || Subtarget.hasVLX())) { + // These operations are handled on non-VLX by artificially widening in + // isel patterns. + // TODO: Custom widen in lowering on non-VLX and drop the isel patterns? + + setOperationAction(ISD::FP_TO_UINT, MVT::v8i32, Legal); + setOperationAction(ISD::FP_TO_UINT, MVT::v4i32, Legal); + setOperationAction(ISD::FP_TO_UINT, MVT::v2i32, Custom); + setOperationAction(ISD::UINT_TO_FP, MVT::v8i32, Legal); + setOperationAction(ISD::UINT_TO_FP, MVT::v4i32, Legal); + + for (auto VT : { MVT::v2i64, MVT::v4i64 }) { + setOperationAction(ISD::SMAX, VT, Legal); + setOperationAction(ISD::UMAX, VT, Legal); + setOperationAction(ISD::SMIN, VT, Legal); + setOperationAction(ISD::UMIN, VT, Legal); + setOperationAction(ISD::ABS, VT, Legal); + } + + for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 }) { + setOperationAction(ISD::ROTL, VT, Custom); + setOperationAction(ISD::ROTR, VT, Custom); + } + + for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64, + MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) + setOperationAction(ISD::MSCATTER, VT, Custom); + + if (Subtarget.hasDQI()) { + for (auto VT : { MVT::v2i64, MVT::v4i64 }) { + setOperationAction(ISD::SINT_TO_FP, VT, Legal); + setOperationAction(ISD::UINT_TO_FP, VT, Legal); + setOperationAction(ISD::FP_TO_SINT, VT, Legal); + setOperationAction(ISD::FP_TO_UINT, VT, Legal); + } + } + + if (Subtarget.hasCDI()) { + for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 }) { + setOperationAction(ISD::CTLZ, VT, Legal); + setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Custom); + } + } // Subtarget.hasCDI() + + if (Subtarget.hasVPOPCNTDQ()) { + for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64 }) + setOperationAction(ISD::CTPOP, VT, Legal); + } + } + if (!Subtarget.useSoftFloat() && Subtarget.hasBWI()) { addRegisterClass(MVT::v32i16, &X86::VR512RegClass); addRegisterClass(MVT::v64i8, &X86::VR512RegClass); @@ -1420,77 +1419,64 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, addRegisterClass(MVT::v32i1, &X86::VK32RegClass); addRegisterClass(MVT::v64i1, &X86::VK64RegClass); - setOperationAction(ISD::ADD, MVT::v32i1, Custom); - setOperationAction(ISD::ADD, MVT::v64i1, Custom); - setOperationAction(ISD::SUB, MVT::v32i1, Custom); - setOperationAction(ISD::SUB, MVT::v64i1, Custom); - setOperationAction(ISD::MUL, MVT::v32i1, Custom); - setOperationAction(ISD::MUL, MVT::v64i1, Custom); + for (auto VT : { MVT::v32i1, MVT::v64i1 }) { + setOperationAction(ISD::ADD, VT, Custom); + 
setOperationAction(ISD::SUB, VT, Custom); + setOperationAction(ISD::MUL, VT, Custom); + setOperationAction(ISD::VSELECT, VT, Expand); + + setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::SELECT, VT, Custom); + setOperationAction(ISD::BUILD_VECTOR, VT, Custom); + setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); + } + + setOperationAction(ISD::CONCAT_VECTORS, MVT::v32i1, Custom); + setOperationAction(ISD::CONCAT_VECTORS, MVT::v64i1, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v32i1, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64i1, Custom); + for (auto VT : { MVT::v16i1, MVT::v32i1 }) + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + + // Extends from v32i1 masks to 256-bit vectors. + setOperationAction(ISD::SIGN_EXTEND, MVT::v32i8, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v32i8, Custom); + setOperationAction(ISD::ANY_EXTEND, MVT::v32i8, Custom); + // Extends from v64i1 masks to 512-bit vectors. + setOperationAction(ISD::SIGN_EXTEND, MVT::v64i8, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v64i8, Custom); + setOperationAction(ISD::ANY_EXTEND, MVT::v64i8, Custom); - setOperationAction(ISD::SETCC, MVT::v32i1, Custom); - setOperationAction(ISD::SETCC, MVT::v64i1, Custom); setOperationAction(ISD::MUL, MVT::v32i16, Legal); setOperationAction(ISD::MUL, MVT::v64i8, Custom); setOperationAction(ISD::MULHS, MVT::v32i16, Legal); setOperationAction(ISD::MULHU, MVT::v32i16, Legal); - setOperationAction(ISD::CONCAT_VECTORS, MVT::v32i1, Custom); - setOperationAction(ISD::CONCAT_VECTORS, MVT::v64i1, Custom); + setOperationAction(ISD::MULHS, MVT::v64i8, Custom); + setOperationAction(ISD::MULHU, MVT::v64i8, Custom); setOperationAction(ISD::CONCAT_VECTORS, MVT::v32i16, Custom); setOperationAction(ISD::CONCAT_VECTORS, MVT::v64i8, Custom); - setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v32i1, Custom); - setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v32i16, Legal); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64i8, Legal); setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v32i16, Custom); setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v64i8, Custom); - setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v32i1, Custom); - setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v64i1, Custom); setOperationAction(ISD::SCALAR_TO_VECTOR, MVT::v32i16, Custom); setOperationAction(ISD::SCALAR_TO_VECTOR, MVT::v64i8, Custom); - setOperationAction(ISD::SELECT, MVT::v32i1, Custom); - setOperationAction(ISD::SELECT, MVT::v64i1, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v32i8, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v32i8, Custom); setOperationAction(ISD::SIGN_EXTEND, MVT::v32i16, Custom); setOperationAction(ISD::ZERO_EXTEND, MVT::v32i16, Custom); setOperationAction(ISD::ANY_EXTEND, MVT::v32i16, Custom); setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32i16, Custom); setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v64i8, Custom); - setOperationAction(ISD::SIGN_EXTEND, MVT::v64i8, Custom); - setOperationAction(ISD::ZERO_EXTEND, MVT::v64i8, Custom); - setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v32i1, Custom); - setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v64i1, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v32i16, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v64i8, Custom); 
- setOperationAction(ISD::TRUNCATE, MVT::v32i1, Custom); - setOperationAction(ISD::TRUNCATE, MVT::v64i1, Custom); setOperationAction(ISD::TRUNCATE, MVT::v32i8, Custom); - setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32i1, Custom); - setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v64i1, Custom); - setOperationAction(ISD::BUILD_VECTOR, MVT::v32i1, Custom); - setOperationAction(ISD::BUILD_VECTOR, MVT::v64i1, Custom); - setOperationAction(ISD::VSELECT, MVT::v32i1, Expand); - setOperationAction(ISD::VSELECT, MVT::v64i1, Expand); setOperationAction(ISD::BITREVERSE, MVT::v64i8, Custom); setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, MVT::v32i16, Custom); setTruncStoreAction(MVT::v32i16, MVT::v32i8, Legal); - if (Subtarget.hasVLX()) { - setTruncStoreAction(MVT::v16i16, MVT::v16i8, Legal); - setTruncStoreAction(MVT::v8i16, MVT::v8i8, Legal); - } - - LegalizeAction Action = Subtarget.hasVLX() ? Legal : Custom; - for (auto VT : { MVT::v32i8, MVT::v16i8, MVT::v16i16, MVT::v8i16 }) { - setOperationAction(ISD::MLOAD, VT, Action); - setOperationAction(ISD::MSTORE, VT, Action); - } - - if (Subtarget.hasCDI()) { - setOperationAction(ISD::CTLZ, MVT::v32i16, Custom); - setOperationAction(ISD::CTLZ, MVT::v64i8, Custom); - } for (auto VT : { MVT::v64i8, MVT::v32i16 }) { setOperationAction(ISD::BUILD_VECTOR, VT, Custom); @@ -1503,6 +1489,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::MSTORE, VT, Legal); setOperationAction(ISD::CTPOP, VT, Custom); setOperationAction(ISD::CTTZ, VT, Custom); + setOperationAction(ISD::CTLZ, VT, Custom); setOperationAction(ISD::SMAX, VT, Legal); setOperationAction(ISD::UMAX, VT, Legal); setOperationAction(ISD::SMIN, VT, Legal); @@ -1513,13 +1500,30 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationPromotedToType(ISD::XOR, VT, MVT::v8i64); } - for (auto ExtType : {ISD::ZEXTLOAD, ISD::SEXTLOAD, ISD::EXTLOAD}) { + for (auto ExtType : {ISD::ZEXTLOAD, ISD::SEXTLOAD}) { setLoadExtAction(ExtType, MVT::v32i16, MVT::v32i8, Legal); - if (Subtarget.hasVLX()) { - // FIXME. This commands are available on SSE/AVX2, add relevant patterns. - setLoadExtAction(ExtType, MVT::v16i16, MVT::v16i8, Legal); - setLoadExtAction(ExtType, MVT::v8i16, MVT::v8i8, Legal); - } + } + + if (Subtarget.hasBITALG()) { + for (auto VT : { MVT::v64i8, MVT::v32i16 }) + setOperationAction(ISD::CTPOP, VT, Legal); + } + } + + if (!Subtarget.useSoftFloat() && Subtarget.hasBWI() && + (Subtarget.hasAVX512() || Subtarget.hasVLX())) { + for (auto VT : { MVT::v32i8, MVT::v16i8, MVT::v16i16, MVT::v8i16 }) { + setOperationAction(ISD::MLOAD, VT, Subtarget.hasVLX() ? Legal : Custom); + setOperationAction(ISD::MSTORE, VT, Subtarget.hasVLX() ? Legal : Custom); + } + + // These operations are handled on non-VLX by artificially widening in + // isel patterns. + // TODO: Custom widen in lowering on non-VLX and drop the isel patterns? + + if (Subtarget.hasBITALG()) { + for (auto VT : { MVT::v16i8, MVT::v32i8, MVT::v8i16, MVT::v16i16 }) + setOperationAction(ISD::CTPOP, VT, Legal); } } @@ -1542,16 +1546,49 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); } + // TODO: v8i1 concat should be legal without VLX to support concats of + // v1i1, but we won't legalize it correctly currently without introducing + // a v4i1 concat in the middle. 
setOperationAction(ISD::CONCAT_VECTORS, MVT::v8i1, Custom); setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i1, Custom); - setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v8i1, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v4i1, Custom); + for (auto VT : { MVT::v2i1, MVT::v4i1 }) + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); - for (auto VT : { MVT::v2i64, MVT::v4i64 }) { - setOperationAction(ISD::SMAX, VT, Legal); - setOperationAction(ISD::UMAX, VT, Legal); - setOperationAction(ISD::SMIN, VT, Legal); - setOperationAction(ISD::UMIN, VT, Legal); + // Extends from v2i1/v4i1 masks to 128-bit vectors. + setOperationAction(ISD::ZERO_EXTEND, MVT::v4i32, Custom); + setOperationAction(ISD::ZERO_EXTEND, MVT::v2i64, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v4i32, Custom); + setOperationAction(ISD::SIGN_EXTEND, MVT::v2i64, Custom); + setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Custom); + setOperationAction(ISD::ANY_EXTEND, MVT::v2i64, Custom); + + setTruncStoreAction(MVT::v4i64, MVT::v4i8, Legal); + setTruncStoreAction(MVT::v4i64, MVT::v4i16, Legal); + setTruncStoreAction(MVT::v4i64, MVT::v4i32, Legal); + setTruncStoreAction(MVT::v8i32, MVT::v8i8, Legal); + setTruncStoreAction(MVT::v8i32, MVT::v8i16, Legal); + + setTruncStoreAction(MVT::v2i64, MVT::v2i8, Legal); + setTruncStoreAction(MVT::v2i64, MVT::v2i16, Legal); + setTruncStoreAction(MVT::v2i64, MVT::v2i32, Legal); + setTruncStoreAction(MVT::v4i32, MVT::v4i8, Legal); + setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal); + + if (Subtarget.hasDQI()) { + // Fast v2f32 SINT_TO_FP( v2i64 ) custom conversion. + // v2f32 UINT_TO_FP is already custom under SSE2. + setOperationAction(ISD::SINT_TO_FP, MVT::v2f32, Custom); + assert(isOperationCustom(ISD::UINT_TO_FP, MVT::v2f32) && + "Unexpected operation action!"); + // v2i64 FP_TO_S/UINT(v2f32) custom conversion. + setOperationAction(ISD::FP_TO_SINT, MVT::v2f32, Custom); + setOperationAction(ISD::FP_TO_UINT, MVT::v2f32, Custom); + } + + if (Subtarget.hasBWI()) { + setTruncStoreAction(MVT::v16i16, MVT::v16i8, Legal); + setTruncStoreAction(MVT::v8i16, MVT::v8i8, Legal); } } @@ -1592,18 +1629,14 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setLibcallName(RTLIB::SHL_I128, nullptr); setLibcallName(RTLIB::SRL_I128, nullptr); setLibcallName(RTLIB::SRA_I128, nullptr); + setLibcallName(RTLIB::MUL_I128, nullptr); } - // Combine sin / cos into one node or libcall if possible. - if (Subtarget.hasSinCos()) { - setLibcallName(RTLIB::SINCOS_F32, "sincosf"); - setLibcallName(RTLIB::SINCOS_F64, "sincos"); - if (Subtarget.isTargetDarwin()) { - // For MacOSX, we don't want the normal expansion of a libcall to sincos. - // We want to issue a libcall to __sincos_stret to avoid memory traffic. - setOperationAction(ISD::FSINCOS, MVT::f64, Custom); - setOperationAction(ISD::FSINCOS, MVT::f32, Custom); - } + // Combine sin / cos into _sincos_stret if it is available. 
+ if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr && + getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) { + setOperationAction(ISD::FSINCOS, MVT::f64, Custom); + setOperationAction(ISD::FSINCOS, MVT::f32, Custom); } if (Subtarget.isTargetWin64()) { @@ -1631,6 +1664,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setTargetDAGCombine(ISD::VECTOR_SHUFFLE); setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT); setTargetDAGCombine(ISD::INSERT_SUBVECTOR); + setTargetDAGCombine(ISD::EXTRACT_SUBVECTOR); setTargetDAGCombine(ISD::BITCAST); setTargetDAGCombine(ISD::VSELECT); setTargetDAGCombine(ISD::SELECT); @@ -1698,6 +1732,19 @@ bool X86TargetLowering::useLoadStackGuardNode() const { return Subtarget.isTargetMachO() && Subtarget.is64Bit(); } +bool X86TargetLowering::useStackGuardXorFP() const { + // Currently only MSVC CRTs XOR the frame pointer into the stack guard value. + return Subtarget.getTargetTriple().isOSMSVCRT(); +} + +SDValue X86TargetLowering::emitStackGuardXorFP(SelectionDAG &DAG, SDValue Val, + const SDLoc &DL) const { + EVT PtrTy = getPointerTy(DAG.getDataLayout()); + unsigned XorOp = Subtarget.is64Bit() ? X86::XOR64_FP : X86::XOR32_FP; + MachineSDNode *Node = DAG.getMachineNode(XorOp, DL, PtrTy, Val); + return SDValue(Node, 0); +} + TargetLoweringBase::LegalizeTypeAction X86TargetLowering::getPreferredVectorAction(EVT VT) const { if (ExperimentalVectorWideningLegalization && @@ -1714,40 +1761,26 @@ EVT X86TargetLowering::getSetCCResultType(const DataLayout &DL, if (!VT.isVector()) return MVT::i8; - if (VT.isSimple()) { - MVT VVT = VT.getSimpleVT(); - const unsigned NumElts = VVT.getVectorNumElements(); - MVT EltVT = VVT.getVectorElementType(); - if (VVT.is512BitVector()) { - if (Subtarget.hasAVX512()) - if (EltVT == MVT::i32 || EltVT == MVT::i64 || - EltVT == MVT::f32 || EltVT == MVT::f64) - switch(NumElts) { - case 8: return MVT::v8i1; - case 16: return MVT::v16i1; - } - if (Subtarget.hasBWI()) - if (EltVT == MVT::i8 || EltVT == MVT::i16) - switch(NumElts) { - case 32: return MVT::v32i1; - case 64: return MVT::v64i1; - } - } + if (Subtarget.hasAVX512()) { + const unsigned NumElts = VT.getVectorNumElements(); - if (Subtarget.hasBWI() && Subtarget.hasVLX()) - return MVT::getVectorVT(MVT::i1, NumElts); + // Figure out what this type will be legalized to. + EVT LegalVT = VT; + while (getTypeAction(Context, LegalVT) != TypeLegal) + LegalVT = getTypeToTransformTo(Context, LegalVT); - if (!isTypeLegal(VT) && getTypeAction(Context, VT) == TypePromoteInteger) { - EVT LegalVT = getTypeToTransformTo(Context, VT); - EltVT = LegalVT.getVectorElementType().getSimpleVT(); - } + // If we got a 512-bit vector then we'll definitely have a vXi1 compare. + if (LegalVT.getSimpleVT().is512BitVector()) + return EVT::getVectorVT(Context, MVT::i1, NumElts); - if (Subtarget.hasVLX() && EltVT.getSizeInBits() >= 32) - switch(NumElts) { - case 2: return MVT::v2i1; - case 4: return MVT::v4i1; - case 8: return MVT::v8i1; - } + if (LegalVT.getSimpleVT().isVector() && Subtarget.hasVLX()) { + // If we legalized to less than a 512-bit vector, then we will use a vXi1 + // compare for vXi32/vXi64 for sure. If we have BWI we will also support + // vXi16/vXi8. 
+ MVT EltVT = LegalVT.getSimpleVT().getVectorElementType(); + if (Subtarget.hasBWI() || EltVT.getSizeInBits() >= 32) + return EVT::getVectorVT(Context, MVT::i1, NumElts); + } } return VT.changeVectorElementTypeToInteger(); @@ -1815,8 +1848,8 @@ X86TargetLowering::getOptimalMemOpType(uint64_t Size, bool IsMemset, bool ZeroMemset, bool MemcpyStrSrc, MachineFunction &MF) const { - const Function *F = MF.getFunction(); - if (!F->hasFnAttribute(Attribute::NoImplicitFloat)) { + const Function &F = MF.getFunction(); + if (!F.hasFnAttribute(Attribute::NoImplicitFloat)) { if (Size >= 16 && (!Subtarget.isUnalignedMem16Slow() || ((DstAlign == 0 || DstAlign >= 16) && @@ -1912,7 +1945,7 @@ void X86TargetLowering::markLibCallAttributes(MachineFunction *MF, unsigned CC, if (CC != CallingConv::C && CC != CallingConv::X86_StdCall) return; unsigned ParamRegs = 0; - if (auto *M = MF->getFunction()->getParent()) + if (auto *M = MF->getFunction().getParent()) ParamRegs = M->getNumberRegisterParameters(); // Mark the first N int arguments as having reg @@ -2017,7 +2050,7 @@ Value *X86TargetLowering::getIRStackGuard(IRBuilder<> &IRB) const { // sysdeps/{i386,x86_64}/nptl/tls.h) if (hasStackGuardSlotTLS(Subtarget.getTargetTriple())) { if (Subtarget.isTargetFuchsia()) { - // <magenta/tls.h> defines MX_TLS_STACK_GUARD_OFFSET with this value. + // <zircon/tls.h> defines ZX_TLS_STACK_GUARD_OFFSET with this value. return SegmentOffset(IRB, 0x10, getAddressSpace()); } else { // %fs:0x28, unless we're using a Kernel code model, in which case @@ -2082,7 +2115,7 @@ Value *X86TargetLowering::getSafeStackPointerLocation(IRBuilder<> &IRB) const { // Fuchsia is similar. if (Subtarget.isTargetFuchsia()) { - // <magenta/tls.h> defines MX_TLS_UNSAFE_SP_OFFSET with this value. + // <zircon/tls.h> defines ZX_TLS_UNSAFE_SP_OFFSET with this value. return SegmentOffset(IRB, 0x18, getAddressSpace()); } @@ -2121,6 +2154,10 @@ static SDValue lowerMasksToReg(const SDValue &ValArg, const EVT &ValLoc, const SDLoc &Dl, SelectionDAG &DAG) { EVT ValVT = ValArg.getValueType(); + if (ValVT == MVT::v1i1) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, Dl, ValLoc, ValArg, + DAG.getIntPtrConstant(0, Dl)); + if ((ValVT == MVT::v8i1 && (ValLoc == MVT::i8 || ValLoc == MVT::i32)) || (ValVT == MVT::v16i1 && (ValLoc == MVT::i16 || ValLoc == MVT::i32))) { // Two stage lowering might be required @@ -2145,8 +2182,7 @@ static void Passv64i1ArgInRegs( const SDLoc &Dl, SelectionDAG &DAG, SDValue Chain, SDValue &Arg, SmallVector<std::pair<unsigned, SDValue>, 8> &RegsToPass, CCValAssign &VA, CCValAssign &NextVA, const X86Subtarget &Subtarget) { - assert((Subtarget.hasBWI() || Subtarget.hasBMI()) && - "Expected AVX512BW or AVX512BMI target!"); + assert(Subtarget.hasBWI() && "Expected AVX512BW target!"); assert(Subtarget.is32Bit() && "Expecting 32 bit target"); assert(Arg.getValueType() == MVT::i64 && "Expecting 64 bit value"); assert(VA.isRegLoc() && NextVA.isRegLoc() && @@ -2180,7 +2216,7 @@ X86TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, // For example, when they are used for argument passing. 
bool ShouldDisableCalleeSavedRegister = CallConv == CallingConv::X86_RegCall || - MF.getFunction()->hasFnAttribute("no_caller_saved_registers"); + MF.getFunction().hasFnAttribute("no_caller_saved_registers"); if (CallConv == CallingConv::X86_INTR && !Outs.empty()) report_fatal_error("X86 interrupts may not return any value"); @@ -2862,8 +2898,8 @@ static ArrayRef<MCPhysReg> get64BitArgumentXMMs(MachineFunction &MF, return None; } - const Function *Fn = MF.getFunction(); - bool NoImplicitFloatOps = Fn->hasFnAttribute(Attribute::NoImplicitFloat); + const Function &F = MF.getFunction(); + bool NoImplicitFloatOps = F.hasFnAttribute(Attribute::NoImplicitFloat); bool isSoftFloat = Subtarget.useSoftFloat(); assert(!(isSoftFloat && NoImplicitFloatOps) && "SSE register cannot be used when SSE is disabled!"); @@ -2896,10 +2932,9 @@ SDValue X86TargetLowering::LowerFormalArguments( X86MachineFunctionInfo *FuncInfo = MF.getInfo<X86MachineFunctionInfo>(); const TargetFrameLowering &TFI = *Subtarget.getFrameLowering(); - const Function *Fn = MF.getFunction(); - if (Fn->hasExternalLinkage() && - Subtarget.isTargetCygMing() && - Fn->getName() == "main") + const Function &F = MF.getFunction(); + if (F.hasExternalLinkage() && Subtarget.isTargetCygMing() && + F.getName() == "main") FuncInfo->setForceFramePointer(true); MachineFrameInfo &MFI = MF.getFrameInfo(); @@ -3074,7 +3109,7 @@ SDValue X86TargetLowering::LowerFormalArguments( // Figure out if XMM registers are in use. assert(!(Subtarget.useSoftFloat() && - Fn->hasFnAttribute(Attribute::NoImplicitFloat)) && + F.hasFnAttribute(Attribute::NoImplicitFloat)) && "SSE register cannot be used when SSE is disabled!"); // 64-bit calling conventions support varargs and register parameters, so we @@ -3231,7 +3266,7 @@ SDValue X86TargetLowering::LowerFormalArguments( FuncInfo->setArgumentStackSize(StackSize); if (WinEHFuncInfo *EHInfo = MF.getWinEHFuncInfo()) { - EHPersonality Personality = classifyEHPersonality(Fn->getPersonalityFn()); + EHPersonality Personality = classifyEHPersonality(F.getPersonalityFn()); if (Personality == EHPersonality::CoreCLR) { assert(Is64Bit); // TODO: Add a mechanism to frame lowering that will allow us to indicate @@ -3248,10 +3283,10 @@ SDValue X86TargetLowering::LowerFormalArguments( } if (CallConv == CallingConv::X86_RegCall || - Fn->hasFnAttribute("no_caller_saved_registers")) { - const MachineRegisterInfo &MRI = MF.getRegInfo(); - for (const auto &Pair : make_range(MRI.livein_begin(), MRI.livein_end())) - MF.getRegInfo().disableCalleeSavedRegister(Pair.first); + F.hasFnAttribute("no_caller_saved_registers")) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + for (std::pair<unsigned, unsigned> Pair : MRI.liveins()) + MRI.disableCalleeSavedRegister(Pair.first); } return Chain; @@ -3339,9 +3374,8 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, StructReturnType SR = callIsStructReturn(Outs, Subtarget.isTargetMCU()); bool IsSibcall = false; X86MachineFunctionInfo *X86Info = MF.getInfo<X86MachineFunctionInfo>(); - auto Attr = MF.getFunction()->getFnAttribute("disable-tail-calls"); - const CallInst *CI = - CLI.CS ? dyn_cast<CallInst>(CLI.CS->getInstruction()) : nullptr; + auto Attr = MF.getFunction().getFnAttribute("disable-tail-calls"); + const auto *CI = dyn_cast_or_null<CallInst>(CLI.CS.getInstruction()); const Function *Fn = CI ? 
CI->getCalledFunction() : nullptr; bool HasNCSR = (CI && CI->hasFnAttr("no_caller_saved_registers")) || (Fn && Fn->hasFnAttribute("no_caller_saved_registers")); @@ -3365,7 +3399,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, isTailCall = false; } - bool IsMustTail = CLI.CS && CLI.CS->isMustTailCall(); + bool IsMustTail = CLI.CS && CLI.CS.isMustTailCall(); if (IsMustTail) { // Force this to be a tail call. The verifier rules are enough to ensure // that we can lower this successfully without moving the return address @@ -3375,7 +3409,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Check if it's really possible to do a tail call. isTailCall = IsEligibleForTailCallOptimization(Callee, CallConv, isVarArg, SR != NotStructReturn, - MF.getFunction()->hasStructRetAttr(), CLI.RetTy, + MF.getFunction().hasStructRetAttr(), CLI.RetTy, Outs, OutVals, Ins, DAG); // Sibcalls are automatically detected tailcalls which do not require @@ -3721,7 +3755,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, } } } else if (ExternalSymbolSDNode *S = dyn_cast<ExternalSymbolSDNode>(Callee)) { - const Module *Mod = DAG.getMachineFunction().getFunction()->getParent(); + const Module *Mod = DAG.getMachineFunction().getFunction().getParent(); unsigned char OpFlags = Subtarget.classifyGlobalFunctionReference(nullptr, *Mod); @@ -3769,11 +3803,11 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // is thrown, the runtime will not restore CSRs. // FIXME: Model this more precisely so that we can register allocate across // the normal edge and spill and fill across the exceptional edge. - if (!Is64Bit && CLI.CS && CLI.CS->isInvoke()) { - const Function *CallerFn = MF.getFunction(); + if (!Is64Bit && CLI.CS && CLI.CS.isInvoke()) { + const Function &CallerFn = MF.getFunction(); EHPersonality Pers = - CallerFn->hasPersonalityFn() - ? classifyEHPersonality(CallerFn->getPersonalityFn()) + CallerFn.hasPersonalityFn() + ? classifyEHPersonality(CallerFn.getPersonalityFn()) : EHPersonality::Unknown; if (isFuncletEHPersonality(Pers)) Mask = RegInfo->getNoPreservedMask(); @@ -4021,15 +4055,15 @@ bool X86TargetLowering::IsEligibleForTailCallOptimization( // If -tailcallopt is specified, make fastcc functions tail-callable. MachineFunction &MF = DAG.getMachineFunction(); - const Function *CallerF = MF.getFunction(); + const Function &CallerF = MF.getFunction(); // If the function return type is x86_fp80 and the callee return type is not, // then the FP_EXTEND of the call result is not a nop. It's not safe to // perform a tailcall optimization here. 
- if (CallerF->getReturnType()->isX86_FP80Ty() && !RetTy->isX86_FP80Ty()) + if (CallerF.getReturnType()->isX86_FP80Ty() && !RetTy->isX86_FP80Ty()) return false; - CallingConv::ID CallerCC = CallerF->getCallingConv(); + CallingConv::ID CallerCC = CallerF.getCallingConv(); bool CCMatch = CallerCC == CalleeCC; bool IsCalleeWin64 = Subtarget.isCallingConvWin64(CalleeCC); bool IsCallerWin64 = Subtarget.isCallingConvWin64(CallerCC); @@ -4243,7 +4277,6 @@ static bool isTargetShuffle(unsigned Opcode) { case X86ISD::VSHLDQ: case X86ISD::VSRLDQ: case X86ISD::MOVLHPS: - case X86ISD::MOVLHPD: case X86ISD::MOVHLPS: case X86ISD::MOVLPS: case X86ISD::MOVLPD: @@ -4491,6 +4524,7 @@ static bool hasFPCMov(unsigned X86CC) { bool X86TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, + MachineFunction &MF, unsigned Intrinsic) const { const IntrinsicData* IntrData = getIntrinsicWithChain(Intrinsic); @@ -4498,9 +4532,7 @@ bool X86TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, return false; Info.opc = ISD::INTRINSIC_W_CHAIN; - Info.readMem = false; - Info.writeMem = false; - Info.vol = false; + Info.flags = MachineMemOperand::MONone; Info.offset = 0; switch (IntrData->Type) { @@ -4508,14 +4540,14 @@ bool X86TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.ptrVal = I.getArgOperand(0); Info.memVT = MVT::getVT(I.getType()); Info.align = 1; - Info.readMem = true; + Info.flags |= MachineMemOperand::MOLoad; break; } case COMPRESS_TO_MEM: { Info.ptrVal = I.getArgOperand(0); Info.memVT = MVT::getVT(I.getArgOperand(1)->getType()); Info.align = 1; - Info.writeMem = true; + Info.flags |= MachineMemOperand::MOStore; break; } case TRUNCATE_TO_MEM_VI8: @@ -4533,7 +4565,7 @@ bool X86TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.memVT = MVT::getVectorVT(ScalarVT, VT.getVectorNumElements()); Info.align = 1; - Info.writeMem = true; + Info.flags |= MachineMemOperand::MOStore; break; } default: @@ -4578,12 +4610,27 @@ bool X86TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm, return true; } -bool X86TargetLowering::isExtractSubvectorCheap(EVT ResVT, +bool X86TargetLowering::convertSelectOfConstantsToMath(EVT VT) const { + // TODO: It might be a win to ease or lift this restriction, but the generic + // folds in DAGCombiner conflict with vector folds for an AVX512 target. + if (VT.isVector() && Subtarget.hasAVX512()) + return false; + + return true; +} + +bool X86TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, unsigned Index) const { if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT)) return false; - return (Index == 0 || Index == ResVT.getVectorNumElements()); + // Mask vectors support all subregister combinations and operations that + // extract half of vector. 
+ if (ResVT.getVectorElementType() == MVT::i1) + return Index == 0 || ((ResVT.getSizeInBits() == SrcVT.getSizeInBits()*2) && + (Index == ResVT.getVectorNumElements())); + + return (Index % ResVT.getVectorNumElements()) == 0; } bool X86TargetLowering::isCheapToSpeculateCttz() const { @@ -4596,6 +4643,28 @@ bool X86TargetLowering::isCheapToSpeculateCtlz() const { return Subtarget.hasLZCNT(); } +bool X86TargetLowering::isLoadBitCastBeneficial(EVT LoadVT, + EVT BitcastVT) const { + if (!Subtarget.hasDQI() && BitcastVT == MVT::v8i1) + return false; + + return TargetLowering::isLoadBitCastBeneficial(LoadVT, BitcastVT); +} + +bool X86TargetLowering::canMergeStoresTo(unsigned AddressSpace, EVT MemVT, + const SelectionDAG &DAG) const { + // Do not merge to float value size (128 bytes) if no implicit + // float attribute is set. + bool NoFloat = DAG.getMachineFunction().getFunction().hasFnAttribute( + Attribute::NoImplicitFloat); + + if (NoFloat) { + unsigned MaxIntSize = Subtarget.is64Bit() ? 64 : 32; + return (MemVT.getSizeInBits() <= MaxIntSize); + } + return true; +} + bool X86TargetLowering::isCtlzFast() const { return Subtarget.hasFastLZCNT(); } @@ -4778,123 +4847,6 @@ static bool canWidenShuffleElements(ArrayRef<int> Mask, return true; } -/// Helper function to scale a shuffle or target shuffle mask, replacing each -/// mask index with the scaled sequential indices for an equivalent narrowed -/// mask. This is the reverse process to canWidenShuffleElements, but can always -/// succeed. -static void scaleShuffleMask(int Scale, ArrayRef<int> Mask, - SmallVectorImpl<int> &ScaledMask) { - assert(0 < Scale && "Unexpected scaling factor"); - int NumElts = Mask.size(); - ScaledMask.assign(static_cast<size_t>(NumElts * Scale), -1); - - for (int i = 0; i != NumElts; ++i) { - int M = Mask[i]; - - // Repeat sentinel values in every mask element. - if (M < 0) { - for (int s = 0; s != Scale; ++s) - ScaledMask[(Scale * i) + s] = M; - continue; - } - - // Scale mask element and increment across each mask element. - for (int s = 0; s != Scale; ++s) - ScaledMask[(Scale * i) + s] = (Scale * M) + s; - } -} - -/// Return true if the specified EXTRACT_SUBVECTOR operand specifies a vector -/// extract that is suitable for instruction that extract 128 or 256 bit vectors -static bool isVEXTRACTIndex(SDNode *N, unsigned vecWidth) { - assert((vecWidth == 128 || vecWidth == 256) && "Unexpected vector width"); - if (!isa<ConstantSDNode>(N->getOperand(1).getNode())) - return false; - - // The index should be aligned on a vecWidth-bit boundary. - uint64_t Index = N->getConstantOperandVal(1); - MVT VT = N->getSimpleValueType(0); - unsigned ElSize = VT.getScalarSizeInBits(); - return (Index * ElSize) % vecWidth == 0; -} - -/// Return true if the specified INSERT_SUBVECTOR -/// operand specifies a subvector insert that is suitable for input to -/// insertion of 128 or 256-bit subvectors -static bool isVINSERTIndex(SDNode *N, unsigned vecWidth) { - assert((vecWidth == 128 || vecWidth == 256) && "Unexpected vector width"); - if (!isa<ConstantSDNode>(N->getOperand(2).getNode())) - return false; - - // The index should be aligned on a vecWidth-bit boundary. 
- uint64_t Index = N->getConstantOperandVal(2); - MVT VT = N->getSimpleValueType(0); - unsigned ElSize = VT.getScalarSizeInBits(); - return (Index * ElSize) % vecWidth == 0; -} - -bool X86::isVINSERT128Index(SDNode *N) { - return isVINSERTIndex(N, 128); -} - -bool X86::isVINSERT256Index(SDNode *N) { - return isVINSERTIndex(N, 256); -} - -bool X86::isVEXTRACT128Index(SDNode *N) { - return isVEXTRACTIndex(N, 128); -} - -bool X86::isVEXTRACT256Index(SDNode *N) { - return isVEXTRACTIndex(N, 256); -} - -static unsigned getExtractVEXTRACTImmediate(SDNode *N, unsigned vecWidth) { - assert((vecWidth == 128 || vecWidth == 256) && "Unsupported vector width"); - assert(isa<ConstantSDNode>(N->getOperand(1).getNode()) && - "Illegal extract subvector for VEXTRACT"); - - uint64_t Index = N->getConstantOperandVal(1); - MVT VecVT = N->getOperand(0).getSimpleValueType(); - unsigned NumElemsPerChunk = vecWidth / VecVT.getScalarSizeInBits(); - return Index / NumElemsPerChunk; -} - -static unsigned getInsertVINSERTImmediate(SDNode *N, unsigned vecWidth) { - assert((vecWidth == 128 || vecWidth == 256) && "Unsupported vector width"); - assert(isa<ConstantSDNode>(N->getOperand(2).getNode()) && - "Illegal insert subvector for VINSERT"); - - uint64_t Index = N->getConstantOperandVal(2); - MVT VecVT = N->getSimpleValueType(0); - unsigned NumElemsPerChunk = vecWidth / VecVT.getScalarSizeInBits(); - return Index / NumElemsPerChunk; -} - -/// Return the appropriate immediate to extract the specified -/// EXTRACT_SUBVECTOR index with VEXTRACTF128 and VINSERTI128 instructions. -unsigned X86::getExtractVEXTRACT128Immediate(SDNode *N) { - return getExtractVEXTRACTImmediate(N, 128); -} - -/// Return the appropriate immediate to extract the specified -/// EXTRACT_SUBVECTOR index with VEXTRACTF64x4 and VINSERTI64x4 instructions. -unsigned X86::getExtractVEXTRACT256Immediate(SDNode *N) { - return getExtractVEXTRACTImmediate(N, 256); -} - -/// Return the appropriate immediate to insert at the specified -/// INSERT_SUBVECTOR index with VINSERTF128 and VINSERTI128 instructions. -unsigned X86::getInsertVINSERT128Immediate(SDNode *N) { - return getInsertVINSERTImmediate(N, 128); -} - -/// Return the appropriate immediate to insert at the specified -/// INSERT_SUBVECTOR index with VINSERTF46x4 and VINSERTI64x4 instructions. -unsigned X86::getInsertVINSERT256Immediate(SDNode *N) { - return getInsertVINSERTImmediate(N, 256); -} - /// Returns true if Elt is a constant zero or a floating point constant +0.0. bool X86::isZeroNode(SDValue Elt) { return isNullConstant(Elt) || isNullFPConstant(Elt); @@ -5018,8 +4970,8 @@ static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG, // If the input is a buildvector just emit a smaller one. 
if (Vec.getOpcode() == ISD::BUILD_VECTOR) - return DAG.getBuildVector( - ResultVT, dl, makeArrayRef(Vec->op_begin() + IdxVal, ElemsPerChunk)); + return DAG.getBuildVector(ResultVT, dl, + Vec->ops().slice(IdxVal, ElemsPerChunk)); SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, dl); return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResultVT, Vec, VecIdx); @@ -5093,10 +5045,13 @@ static bool isMaskedZeroUpperBitsvXi1(unsigned int Opcode) { switch (Opcode) { default: return false; + case X86ISD::TESTM: + case X86ISD::TESTNM: case X86ISD::PCMPEQM: case X86ISD::PCMPGTM: case X86ISD::CMPM: case X86ISD::CMPMU: + case X86ISD::CMPM_RND: return true; } } @@ -5113,113 +5068,128 @@ static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG, if (!isa<ConstantSDNode>(Idx)) return SDValue(); + // Inserting undef is a nop. We can just return the original vector. + if (SubVec.isUndef()) + return Vec; + unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - if (IdxVal == 0 && Vec.isUndef()) // the operation is legal + if (IdxVal == 0 && Vec.isUndef()) // the operation is legal return Op; MVT OpVT = Op.getSimpleValueType(); - MVT SubVecVT = SubVec.getSimpleValueType(); unsigned NumElems = OpVT.getVectorNumElements(); + + SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl); + + // Extend to natively supported kshift. + MVT WideOpVT = OpVT; + if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) + WideOpVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1; + + // Inserting into the lsbs of a zero vector is legal. ISel will insert shifts + // if necessary. + if (IdxVal == 0 && ISD::isBuildVectorAllZeros(Vec.getNode())) { + // May need to promote to a legal type. + Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, + getZeroVector(WideOpVT, Subtarget, DAG, dl), + SubVec, Idx); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx); + } + + MVT SubVecVT = SubVec.getSimpleValueType(); unsigned SubVecNumElems = SubVecVT.getVectorNumElements(); assert(IdxVal + SubVecNumElems <= NumElems && IdxVal % SubVecVT.getSizeInBits() == 0 && "Unexpected index value in INSERT_SUBVECTOR"); - // There are 3 possible cases: - // 1. Subvector should be inserted in the lower part (IdxVal == 0) - // 2. Subvector should be inserted in the upper part - // (IdxVal + SubVecNumElems == NumElems) - // 3. Subvector should be inserted in the middle (for example v2i1 - // to v16i1, index 2) - - // If this node widens - by concatenating zeroes - the type of the result - // of a node with instruction that zeroes all upper (irrelevant) bits of the - // output register, mark this node as legal to enable replacing them with - // the v8i1 version of the previous instruction during instruction selection. - // For example, VPCMPEQDZ128rr instruction stores its v4i1 result in a k-reg, - // while zeroing all the upper remaining 60 bits of the register. if the - // result of such instruction is inserted into an allZeroVector, then we can - // safely remove insert_vector (in instruction selection) as the cmp instr - // already zeroed the rest of the register. - if (ISD::isBuildVectorAllZeros(Vec.getNode()) && IdxVal == 0 && - (isMaskedZeroUpperBitsvXi1(SubVec.getOpcode()) || - (SubVec.getOpcode() == ISD::AND && - (isMaskedZeroUpperBitsvXi1(SubVec.getOperand(0).getOpcode()) || - isMaskedZeroUpperBitsvXi1(SubVec.getOperand(1).getOpcode()))))) - return Op; - - // extend to natively supported kshift - MVT MinVT = Subtarget.hasDQI() ? 
MVT::v8i1 : MVT::v16i1; - MVT WideOpVT = OpVT; - if (OpVT.getSizeInBits() < MinVT.getStoreSizeInBits()) - WideOpVT = MinVT; - - SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl); SDValue Undef = DAG.getUNDEF(WideOpVT); - SDValue WideSubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, - Undef, SubVec, ZeroIdx); - // Extract sub-vector if require. - auto ExtractSubVec = [&](SDValue V) { - return (WideOpVT == OpVT) ? V : DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, - OpVT, V, ZeroIdx); - }; + if (IdxVal == 0) { + // Zero lower bits of the Vec + SDValue ShiftBits = DAG.getConstant(SubVecNumElems, dl, MVT::i8); + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, Undef, Vec, + ZeroIdx); + Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec, ShiftBits); + Vec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, Vec, ShiftBits); + // Merge them together, SubVec should be zero extended. + SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, + getZeroVector(WideOpVT, Subtarget, DAG, dl), + SubVec, ZeroIdx); + Op = DAG.getNode(ISD::OR, dl, WideOpVT, Vec, SubVec); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx); + } + + SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, + Undef, SubVec, ZeroIdx); if (Vec.isUndef()) { - if (IdxVal != 0) { - SDValue ShiftBits = DAG.getConstant(IdxVal, dl, MVT::i8); - WideSubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, WideSubVec, - ShiftBits); - } - return ExtractSubVec(WideSubVec); + assert(IdxVal != 0 && "Unexpected index"); + SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec, + DAG.getConstant(IdxVal, dl, MVT::i8)); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, SubVec, ZeroIdx); } if (ISD::isBuildVectorAllZeros(Vec.getNode())) { + assert(IdxVal != 0 && "Unexpected index"); NumElems = WideOpVT.getVectorNumElements(); unsigned ShiftLeft = NumElems - SubVecNumElems; unsigned ShiftRight = NumElems - SubVecNumElems - IdxVal; - Vec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, WideSubVec, - DAG.getConstant(ShiftLeft, dl, MVT::i8)); - Vec = ShiftRight ? DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec, - DAG.getConstant(ShiftRight, dl, MVT::i8)) : Vec; - return ExtractSubVec(Vec); - } - - if (IdxVal == 0) { - // Zero lower bits of the Vec - SDValue ShiftBits = DAG.getConstant(SubVecNumElems, dl, MVT::i8); - Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, Undef, Vec, ZeroIdx); - Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec, ShiftBits); - Vec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, Vec, ShiftBits); - // Merge them together, SubVec should be zero extended. 
- WideSubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, - getZeroVector(WideOpVT, Subtarget, DAG, dl), - SubVec, ZeroIdx); - Vec = DAG.getNode(ISD::OR, dl, WideOpVT, Vec, WideSubVec); - return ExtractSubVec(Vec); + SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec, + DAG.getConstant(ShiftLeft, dl, MVT::i8)); + if (ShiftRight != 0) + SubVec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, SubVec, + DAG.getConstant(ShiftRight, dl, MVT::i8)); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, SubVec, ZeroIdx); } // Simple case when we put subvector in the upper part if (IdxVal + SubVecNumElems == NumElems) { - // Zero upper bits of the Vec - WideSubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, WideSubVec, - DAG.getConstant(IdxVal, dl, MVT::i8)); - SDValue ShiftBits = DAG.getConstant(SubVecNumElems, dl, MVT::i8); - Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, Undef, Vec, ZeroIdx); - Vec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, Vec, ShiftBits); - Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec, ShiftBits); - Vec = DAG.getNode(ISD::OR, dl, WideOpVT, Vec, WideSubVec); - return ExtractSubVec(Vec); - } - // Subvector should be inserted in the middle - use shuffle - WideSubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, Undef, - SubVec, ZeroIdx); - SmallVector<int, 64> Mask; - for (unsigned i = 0; i < NumElems; ++i) - Mask.push_back(i >= IdxVal && i < IdxVal + SubVecNumElems ? - i : i + NumElems); - return DAG.getVectorShuffle(OpVT, dl, WideSubVec, Vec, Mask); + SubVec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, SubVec, + DAG.getConstant(IdxVal, dl, MVT::i8)); + if (SubVecNumElems * 2 == NumElems) { + // Special case, use legal zero extending insert_subvector. This allows + // isel to opimitize when bits are known zero. + Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVecVT, Vec, ZeroIdx); + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, + getZeroVector(WideOpVT, Subtarget, DAG, dl), + Vec, ZeroIdx); + } else { + // Otherwise use explicit shifts to zero the bits. + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, + Undef, Vec, ZeroIdx); + NumElems = WideOpVT.getVectorNumElements(); + SDValue ShiftBits = DAG.getConstant(NumElems - IdxVal, dl, MVT::i8); + Vec = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, Vec, ShiftBits); + Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec, ShiftBits); + } + Op = DAG.getNode(ISD::OR, dl, WideOpVT, Vec, SubVec); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx); + } + + // Inserting into the middle is more complicated. + + NumElems = WideOpVT.getVectorNumElements(); + + // Widen the vector if needed. + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideOpVT, Undef, Vec, ZeroIdx); + // Move the current value of the bit to be replace to the lsbs. + Op = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Vec, + DAG.getConstant(IdxVal, dl, MVT::i8)); + // Xor with the new bit. + Op = DAG.getNode(ISD::XOR, dl, WideOpVT, Op, SubVec); + // Shift to MSB, filling bottom bits with 0. + unsigned ShiftLeft = NumElems - SubVecNumElems; + Op = DAG.getNode(X86ISD::KSHIFTL, dl, WideOpVT, Op, + DAG.getConstant(ShiftLeft, dl, MVT::i8)); + // Shift to the final position, filling upper bits with 0. + unsigned ShiftRight = NumElems - SubVecNumElems - IdxVal; + Op = DAG.getNode(X86ISD::KSHIFTR, dl, WideOpVT, Op, + DAG.getConstant(ShiftRight, dl, MVT::i8)); + // Xor with original vector leaving the new value. + Op = DAG.getNode(ISD::XOR, dl, WideOpVT, Vec, Op); + // Reduce to original width if needed. 
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, OpVT, Op, ZeroIdx); } /// Concat two 128-bit vectors into a 256 bit vector using VINSERTF128 @@ -5273,22 +5243,6 @@ static SDValue getExtendInVec(unsigned Opc, const SDLoc &DL, EVT VT, SDValue In, return DAG.getNode(Opc, DL, VT, In); } -/// Generate unpacklo/unpackhi shuffle mask. -static void createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, bool Lo, - bool Unary) { - assert(Mask.empty() && "Expected an empty shuffle mask vector"); - int NumElts = VT.getVectorNumElements(); - int NumEltsInLane = 128 / VT.getScalarSizeInBits(); - - for (int i = 0; i < NumElts; ++i) { - unsigned LaneStart = (i / NumEltsInLane) * NumEltsInLane; - int Pos = (i % NumEltsInLane) / 2 + LaneStart; - Pos += (Unary ? 0 : NumElts * (i % 2)); - Pos += (Lo ? 0 : NumEltsInLane / 2); - Mask.push_back(Pos); - } -} - /// Returns a vector_shuffle node for an unpackl operation. static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT, SDValue V1, SDValue V2) { @@ -5448,6 +5402,20 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits, return false; }; + // Handle UNDEFs. + if (Op.isUndef()) { + APInt UndefSrcElts = APInt::getAllOnesValue(NumElts); + SmallVector<APInt, 64> SrcEltBits(NumElts, APInt(EltSizeInBits, 0)); + return CastBitData(UndefSrcElts, SrcEltBits); + } + + // Extract scalar constant bits. + if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) { + APInt UndefSrcElts = APInt::getNullValue(1); + SmallVector<APInt, 64> SrcEltBits(1, Cst->getAPIntValue()); + return CastBitData(UndefSrcElts, SrcEltBits); + } + // Extract constant bits from build vector. if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) { unsigned SrcEltSizeInBits = VT.getScalarSizeInBits(); @@ -5542,6 +5510,24 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode, return true; } +/// Create a shuffle mask that matches the PACKSS/PACKUS truncation. +/// Note: This ignores saturation, so inputs must be checked first. +static void createPackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, + bool Unary) { + assert(Mask.empty() && "Expected an empty shuffle mask vector"); + unsigned NumElts = VT.getVectorNumElements(); + unsigned NumLanes = VT.getSizeInBits() / 128; + unsigned NumEltsPerLane = 128 / VT.getScalarSizeInBits(); + unsigned Offset = Unary ? 0 : NumElts; + + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + for (unsigned Elt = 0; Elt != NumEltsPerLane; Elt += 2) + Mask.push_back(Elt + (Lane * NumEltsPerLane)); + for (unsigned Elt = 0; Elt != NumEltsPerLane; Elt += 2) + Mask.push_back(Elt + (Lane * NumEltsPerLane) + Offset); + } +} + /// Calculates the shuffle mask corresponding to the target-specific opcode. /// If the mask could be calculated, returns it in \p Mask, returns the shuffle /// operands in \p Ops, and returns true. 
@@ -5562,21 +5548,28 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, bool IsFakeUnary = false; switch(N->getOpcode()) { case X86ISD::BLENDI: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodeBLENDMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::SHUFP: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodeSHUFPMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::INSERTPS: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodeINSERTPSMask(cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::EXTRQI: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); if (isa<ConstantSDNode>(N->getOperand(1)) && isa<ConstantSDNode>(N->getOperand(2))) { int BitLen = N->getConstantOperandVal(1); @@ -5586,6 +5579,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, } break; case X86ISD::INSERTQI: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); if (isa<ConstantSDNode>(N->getOperand(2)) && isa<ConstantSDNode>(N->getOperand(3))) { int BitLen = N->getConstantOperandVal(2); @@ -5595,23 +5590,33 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, } break; case X86ISD::UNPCKH: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); DecodeUNPCKHMask(VT, Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::UNPCKL: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); DecodeUNPCKLMask(VT, Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::MOVHLPS: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); DecodeMOVHLPSMask(NumElems, Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::MOVLHPS: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); DecodeMOVLHPSMask(NumElems, Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::PALIGNR: assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodePALIGNRMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); @@ -5620,33 +5625,39 @@ static bool getTargetShuffleMask(SDNode *N, 
MVT VT, bool AllowSentinelZero, break; case X86ISD::VSHLDQ: assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands() - 1); DecodePSLLDQMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::VSRLDQ: assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands() - 1); DecodePSRLDQMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::PSHUFD: case X86ISD::VPERMILPI: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodePSHUFMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::PSHUFHW: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodePSHUFHWMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::PSHUFLW: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodePSHUFLWMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::VZEXT_MOVL: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); DecodeZeroMoveLowMask(VT, Mask); IsUnary = true; break; @@ -5670,6 +5681,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, return false; } case X86ISD::VPERMILPV: { + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); IsUnary = true; SDValue MaskNode = N->getOperand(1); unsigned MaskEltSize = VT.getScalarSizeInBits(); @@ -5685,6 +5697,9 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, return false; } case X86ISD::PSHUFB: { + assert(VT.getScalarType() == MVT::i8 && "Byte vector expected"); + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); IsUnary = true; SDValue MaskNode = N->getOperand(1); SmallVector<uint64_t, 32> RawMask; @@ -5699,37 +5714,46 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, return false; } case X86ISD::VPERMI: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodeVPERMMask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = true; break; case X86ISD::MOVSS: case X86ISD::MOVSD: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); DecodeScalarMoveMask(VT, /* IsLoad */ false, Mask); break; case X86ISD::VPERM2X128: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); ImmN = N->getOperand(N->getNumOperands()-1); DecodeVPERM2X128Mask(VT, cast<ConstantSDNode>(ImmN)->getZExtValue(), Mask); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); break; case X86ISD::MOVSLDUP: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); DecodeMOVSLDUPMask(VT, Mask); IsUnary = true; break; case X86ISD::MOVSHDUP: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); DecodeMOVSHDUPMask(VT, Mask); 
IsUnary = true; break; case X86ISD::MOVDDUP: + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); DecodeMOVDDUPMask(VT, Mask); IsUnary = true; break; - case X86ISD::MOVLHPD: case X86ISD::MOVLPD: case X86ISD::MOVLPS: // Not yet implemented return false; case X86ISD::VPERMIL2: { + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); unsigned MaskEltSize = VT.getScalarSizeInBits(); SDValue MaskNode = N->getOperand(2); @@ -5749,6 +5773,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, return false; } case X86ISD::VPPERM: { + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(1); SDValue MaskNode = N->getOperand(2); SmallVector<uint64_t, 32> RawMask; @@ -5763,6 +5789,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, return false; } case X86ISD::VPERMV: { + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); IsUnary = true; // Unlike most shuffle nodes, VPERMV's mask operand is operand 0. Ops.push_back(N->getOperand(1)); @@ -5780,6 +5807,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, return false; } case X86ISD::VPERMV3: { + assert(N->getOperand(0).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(2).getValueType() == VT && "Unexpected value type"); IsUnary = IsFakeUnary = N->getOperand(0) == N->getOperand(2); // Unlike most shuffle nodes, VPERMV3's mask operand is the middle one. Ops.push_back(N->getOperand(0)); @@ -5793,6 +5822,8 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, return false; } case X86ISD::VPERMIV3: { + assert(N->getOperand(1).getValueType() == VT && "Unexpected value type"); + assert(N->getOperand(2).getValueType() == VT && "Unexpected value type"); IsUnary = IsFakeUnary = N->getOperand(1) == N->getOperand(2); // Unlike most shuffle nodes, VPERMIV3's mask operand is the first one. 
Ops.push_back(N->getOperand(1)); @@ -5965,19 +5996,13 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, SDValue N0 = N.getOperand(0); SDValue SrcExtract; - if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT && - N0.getOperand(0).getValueType() == VT) { + if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + N0.getOperand(0).getValueType() == VT) || + (N0.getOpcode() == X86ISD::PEXTRW && + N0.getOperand(0).getValueType() == MVT::v8i16) || + (N0.getOpcode() == X86ISD::PEXTRB && + N0.getOperand(0).getValueType() == MVT::v16i8)) { SrcExtract = N0; - } else if (N0.getOpcode() == ISD::AssertZext && - N0.getOperand(0).getOpcode() == X86ISD::PEXTRW && - cast<VTSDNode>(N0.getOperand(1))->getVT() == MVT::i16) { - SrcExtract = N0.getOperand(0); - assert(SrcExtract.getOperand(0).getValueType() == MVT::v8i16); - } else if (N0.getOpcode() == ISD::AssertZext && - N0.getOperand(0).getOpcode() == X86ISD::PEXTRB && - cast<VTSDNode>(N0.getOperand(1))->getVT() == MVT::i8) { - SrcExtract = N0.getOperand(0); - assert(SrcExtract.getOperand(0).getValueType() == MVT::v16i8); } if (!SrcExtract || !isa<ConstantSDNode>(SrcExtract.getOperand(1))) @@ -6013,16 +6038,15 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, return true; } - // Attempt to recognise a PINSR*(ASSERTZEXT(PEXTR*)) shuffle pattern. + // Attempt to recognise a PINSR*(PEXTR*) shuffle pattern. // TODO: Expand this to support INSERT_VECTOR_ELT/etc. unsigned ExOp = (X86ISD::PINSRB == Opcode ? X86ISD::PEXTRB : X86ISD::PEXTRW); - if (InScl.getOpcode() != ISD::AssertZext || - InScl.getOperand(0).getOpcode() != ExOp) + if (InScl.getOpcode() != ExOp) return false; - SDValue ExVec = InScl.getOperand(0).getOperand(0); - uint64_t ExIdx = InScl.getOperand(0).getConstantOperandVal(1); + SDValue ExVec = InScl.getOperand(0); + uint64_t ExIdx = InScl.getConstantOperandVal(1); assert(ExIdx < NumElts && "Illegal extraction index"); Ops.push_back(InVec); Ops.push_back(ExVec); @@ -6030,17 +6054,34 @@ static bool getFauxShuffleMask(SDValue N, SmallVectorImpl<int> &Mask, Mask.push_back(i == InIdx ? NumElts + ExIdx : i); return true; } - case X86ISD::PACKSS: { + case X86ISD::PACKSS: + case X86ISD::PACKUS: { + SDValue N0 = N.getOperand(0); + SDValue N1 = N.getOperand(1); + assert(N0.getValueType().getVectorNumElements() == (NumElts / 2) && + N1.getValueType().getVectorNumElements() == (NumElts / 2) && + "Unexpected input value type"); + // If we know input saturation won't happen we can treat this // as a truncation shuffle. 
- if (DAG.ComputeNumSignBits(N.getOperand(0)) <= NumBitsPerElt || - DAG.ComputeNumSignBits(N.getOperand(1)) <= NumBitsPerElt) - return false; + if (Opcode == X86ISD::PACKSS) { + if ((!N0.isUndef() && DAG.ComputeNumSignBits(N0) <= NumBitsPerElt) || + (!N1.isUndef() && DAG.ComputeNumSignBits(N1) <= NumBitsPerElt)) + return false; + } else { + APInt ZeroMask = APInt::getHighBitsSet(2 * NumBitsPerElt, NumBitsPerElt); + if ((!N0.isUndef() && !DAG.MaskedValueIsZero(N0, ZeroMask)) || + (!N1.isUndef() && !DAG.MaskedValueIsZero(N1, ZeroMask))) + return false; + } - Ops.push_back(N.getOperand(0)); - Ops.push_back(N.getOperand(1)); - for (unsigned i = 0; i != NumElts; ++i) - Mask.push_back(i * 2); + bool IsUnary = (N0 == N1); + + Ops.push_back(N0); + if (!IsUnary) + Ops.push_back(N1); + + createPackShuffleMask(VT, Mask, IsUnary); return true; } case X86ISD::VSHLI: @@ -6099,6 +6140,14 @@ static void resolveTargetShuffleInputsAndMask(SmallVectorImpl<SDValue> &Inputs, for (int i = 0, e = Inputs.size(); i < e; ++i) { int lo = UsedInputs.size() * MaskWidth; int hi = lo + MaskWidth; + + // Strip UNDEF input usage. + if (Inputs[i].isUndef()) + for (int &M : Mask) + if ((lo <= M) && (M < hi)) + M = SM_SentinelUndef; + + // Check for unused inputs. if (any_of(Mask, [lo, hi](int i) { return (lo <= i) && (i < hi); })) { UsedInputs.push_back(Inputs[i]); continue; @@ -6196,6 +6245,49 @@ static SDValue getShuffleScalarElt(SDNode *N, unsigned Index, SelectionDAG &DAG, return SDValue(); } +// Use PINSRB/PINSRW/PINSRD to create a build vector. +static SDValue LowerBuildVectorAsInsert(SDValue Op, unsigned NonZeros, + unsigned NumNonZero, unsigned NumZero, + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + MVT VT = Op.getSimpleValueType(); + unsigned NumElts = VT.getVectorNumElements(); + assert(((VT == MVT::v8i16 && Subtarget.hasSSE2()) || + ((VT == MVT::v16i8 || VT == MVT::v4i32) && Subtarget.hasSSE41())) && + "Illegal vector insertion"); + + SDLoc dl(Op); + SDValue V; + bool First = true; + + for (unsigned i = 0; i < NumElts; ++i) { + bool IsNonZero = (NonZeros & (1 << i)) != 0; + if (!IsNonZero) + continue; + + // If the build vector contains zeros or our first insertion is not the + // first index then insert into zero vector to break any register + // dependency else use SCALAR_TO_VECTOR/VZEXT_MOVL. + if (First) { + First = false; + if (NumZero || 0 != i) + V = getZeroVector(VT, Subtarget, DAG, dl); + else { + assert(0 == i && "Expected insertion into zero-index"); + V = DAG.getAnyExtOrTrunc(Op.getOperand(i), dl, MVT::i32); + V = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, V); + V = DAG.getNode(X86ISD::VZEXT_MOVL, dl, MVT::v4i32, V); + V = DAG.getBitcast(VT, V); + continue; + } + } + V = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, V, Op.getOperand(i), + DAG.getIntPtrConstant(i, dl)); + } + + return V; +} + /// Custom lower build_vector of v16i8. static SDValue LowerBuildVectorv16i8(SDValue Op, unsigned NonZeros, unsigned NumNonZero, unsigned NumZero, @@ -6204,39 +6296,15 @@ static SDValue LowerBuildVectorv16i8(SDValue Op, unsigned NonZeros, if (NumNonZero > 8 && !Subtarget.hasSSE41()) return SDValue(); + // SSE4.1 - use PINSRB to insert each byte directly. + if (Subtarget.hasSSE41()) + return LowerBuildVectorAsInsert(Op, NonZeros, NumNonZero, NumZero, DAG, + Subtarget); + SDLoc dl(Op); SDValue V; bool First = true; - // SSE4.1 - use PINSRB to insert each byte directly. 
- if (Subtarget.hasSSE41()) { - for (unsigned i = 0; i < 16; ++i) { - bool IsNonZero = (NonZeros & (1 << i)) != 0; - if (IsNonZero) { - // If the build vector contains zeros or our first insertion is not the - // first index then insert into zero vector to break any register - // dependency else use SCALAR_TO_VECTOR/VZEXT_MOVL. - if (First) { - First = false; - if (NumZero || 0 != i) - V = getZeroVector(MVT::v16i8, Subtarget, DAG, dl); - else { - assert(0 == i && "Expected insertion into zero-index"); - V = DAG.getAnyExtOrTrunc(Op.getOperand(i), dl, MVT::i32); - V = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, V); - V = DAG.getNode(X86ISD::VZEXT_MOVL, dl, MVT::v4i32, V); - V = DAG.getBitcast(MVT::v16i8, V); - continue; - } - } - V = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v16i8, V, - Op.getOperand(i), DAG.getIntPtrConstant(i, dl)); - } - } - - return V; - } - // Pre-SSE4.1 - merge byte pairs and insert with PINSRW. for (unsigned i = 0; i < 16; ++i) { bool ThisIsNonZero = (NonZeros & (1 << i)) != 0; @@ -6292,34 +6360,9 @@ static SDValue LowerBuildVectorv8i16(SDValue Op, unsigned NonZeros, if (NumNonZero > 4 && !Subtarget.hasSSE41()) return SDValue(); - SDLoc dl(Op); - SDValue V; - bool First = true; - for (unsigned i = 0; i < 8; ++i) { - bool IsNonZero = (NonZeros & (1 << i)) != 0; - if (IsNonZero) { - // If the build vector contains zeros or our first insertion is not the - // first index then insert into zero vector to break any register - // dependency else use SCALAR_TO_VECTOR/VZEXT_MOVL. - if (First) { - First = false; - if (NumZero || 0 != i) - V = getZeroVector(MVT::v8i16, Subtarget, DAG, dl); - else { - assert(0 == i && "Expected insertion into zero-index"); - V = DAG.getAnyExtOrTrunc(Op.getOperand(i), dl, MVT::i32); - V = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, V); - V = DAG.getNode(X86ISD::VZEXT_MOVL, dl, MVT::v4i32, V); - V = DAG.getBitcast(MVT::v8i16, V); - continue; - } - } - V = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v8i16, V, - Op.getOperand(i), DAG.getIntPtrConstant(i, dl)); - } - } - - return V; + // Use PINSRW to insert each byte directly. + return LowerBuildVectorAsInsert(Op, NonZeros, NumNonZero, NumZero, DAG, + Subtarget); } /// Custom lower build_vector of v4i32 or v4f32. 
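
The hunks above fold the old v16i8 and v8i16 paths into one LowerBuildVectorAsInsert helper that walks a NonZeros bitmask and emits one PINSRB/PINSRW-style insert per non-zero element. A hedged standalone model of that control flow (the bitmask value, zero count, and printed actions are invented for illustration, not taken from the patch):

#include <cstdio>

int main() {
  const unsigned NumElts = 8;     // e.g. a v8i16 build_vector
  const unsigned NonZeros = 0xB2; // bits 1, 4, 5, 7 set => those elements are non-zero
  const unsigned NumZero = 3;     // number of explicitly-zero elements

  bool First = true;
  for (unsigned i = 0; i < NumElts; ++i) {
    if (!(NonZeros & (1u << i)))
      continue;
    if (First) {
      First = false;
      if (NumZero || i != 0) {
        // Seed a zeroed register to break the dependency, then insert.
        std::printf("start from an all-zero vector, insert element %u\n", i);
      } else {
        // First non-zero element is element 0 and nothing is zero:
        // move it into the low lane directly (SCALAR_TO_VECTOR + VZEXT_MOVL).
        std::printf("move element 0 into the low lane\n");
        continue;
      }
    } else {
      std::printf("insert element %u\n", i);
    }
  }
}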
@@ -6589,14 +6632,20 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts, } } - auto CreateLoad = [&DAG, &DL](EVT VT, LoadSDNode *LDBase) { + SmallVector<LoadSDNode *, 8> Loads; + for (int i = FirstLoadedElt; i <= LastLoadedElt; ++i) + if (LoadMask[i]) + Loads.push_back(cast<LoadSDNode>(peekThroughBitcasts(Elts[i]))); + + auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) { auto MMOFlags = LDBase->getMemOperand()->getFlags(); assert(!(MMOFlags & MachineMemOperand::MOVolatile) && "Cannot merge volatile loads."); SDValue NewLd = DAG.getLoad(VT, DL, LDBase->getChain(), LDBase->getBasePtr(), LDBase->getPointerInfo(), LDBase->getAlignment(), MMOFlags); - DAG.makeEquivalentMemoryOrdering(LDBase, NewLd); + for (auto *LD : Loads) + DAG.makeEquivalentMemoryOrdering(LD, NewLd); return NewLd; }; @@ -6659,9 +6708,9 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts, DAG.getMemIntrinsicNode(X86ISD::VZEXT_LOAD, DL, Tys, Ops, VecSVT, LDBase->getPointerInfo(), LDBase->getAlignment(), - false/*isVolatile*/, true/*ReadMem*/, - false/*WriteMem*/); - DAG.makeEquivalentMemoryOrdering(LDBase, ResNode); + MachineMemOperand::MOLoad); + for (auto *LD : Loads) + DAG.makeEquivalentMemoryOrdering(LD, ResNode); return DAG.getBitcast(VT, ResNode); } } @@ -6702,6 +6751,43 @@ static bool isUseOfShuffle(SDNode *N) { return false; } +// Check if the current node of build vector is a zero extended vector. +// // If so, return the value extended. +// // For example: (0,0,0,a,0,0,0,a,0,0,0,a,0,0,0,a) returns a. +// // NumElt - return the number of zero extended identical values. +// // EltType - return the type of the value include the zero extend. +static SDValue isSplatZeroExtended(const BuildVectorSDNode *Op, + unsigned &NumElt, MVT &EltType) { + SDValue ExtValue = Op->getOperand(0); + unsigned NumElts = Op->getNumOperands(); + unsigned Delta = NumElts; + + for (unsigned i = 1; i < NumElts; i++) { + if (Op->getOperand(i) == ExtValue) { + Delta = i; + break; + } + if (!(Op->getOperand(i).isUndef() || isNullConstant(Op->getOperand(i)))) + return SDValue(); + } + if (!isPowerOf2_32(Delta) || Delta == 1) + return SDValue(); + + for (unsigned i = Delta; i < NumElts; i++) { + if (i % Delta == 0) { + if (Op->getOperand(i) != ExtValue) + return SDValue(); + } else if (!(isNullConstant(Op->getOperand(i)) || + Op->getOperand(i).isUndef())) + return SDValue(); + } + unsigned EltSize = Op->getSimpleValueType(0).getScalarSizeInBits(); + unsigned ExtVTSize = EltSize * Delta; + EltType = MVT::getIntegerVT(ExtVTSize); + NumElt = NumElts / Delta; + return ExtValue; +} + /// Attempt to use the vbroadcast instruction to generate a splat value /// from a splat BUILD_VECTOR which uses: /// a. A single scalar load, or a constant. @@ -6727,6 +6813,39 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, BitVector UndefElements; SDValue Ld = BVOp->getSplatValue(&UndefElements); + // Attempt to use VBROADCASTM + // From this paterrn: + // a. t0 = (zext_i64 (bitcast_i8 v2i1 X)) + // b. 
t1 = (build_vector t0 t0) + // + // Create (VBROADCASTM v2i1 X) + if (Subtarget.hasCDI() && (VT.is512BitVector() || Subtarget.hasVLX())) { + MVT EltType = VT.getScalarType(); + unsigned NumElts = VT.getVectorNumElements(); + SDValue BOperand; + SDValue ZeroExtended = isSplatZeroExtended(BVOp, NumElts, EltType); + if ((ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST) || + (Ld && Ld.getOpcode() == ISD::ZERO_EXTEND && + Ld.getOperand(0).getOpcode() == ISD::BITCAST)) { + if (ZeroExtended) + BOperand = ZeroExtended.getOperand(0); + else + BOperand = Ld.getOperand(0).getOperand(0); + if (BOperand.getValueType().isVector() && + BOperand.getSimpleValueType().getVectorElementType() == MVT::i1) { + if ((EltType == MVT::i64 && (VT.getVectorElementType() == MVT::i8 || + NumElts == 8)) || // for broadcastmb2q + (EltType == MVT::i32 && (VT.getVectorElementType() == MVT::i16 || + NumElts == 16))) { // for broadcastmw2d + SDValue Brdcst = + DAG.getNode(X86ISD::VBROADCASTM, dl, + MVT::getVectorVT(EltType, NumElts), BOperand); + return DAG.getBitcast(VT, Brdcst); + } + } + } + } + // We need a splat of a single value to use broadcast, and it doesn't // make any sense if the value is only in one element of the vector. if (!Ld || (VT.getVectorNumElements() - UndefElements.count()) <= 1) { @@ -6824,7 +6943,7 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp, // TODO: If multiple splats are generated to load the same constant, // it may be detrimental to overall size. There needs to be a way to detect // that condition to know if this is truly a size win. - bool OptForSize = DAG.getMachineFunction().getFunction()->optForSize(); + bool OptForSize = DAG.getMachineFunction().getFunction().optForSize(); // Handle broadcasting a single constant scalar from the constant pool // into a vector. @@ -6902,10 +7021,10 @@ static int getUnderlyingExtractedFromVec(SDValue &ExtractedFromVec, // For 256-bit vectors, LowerEXTRACT_VECTOR_ELT_SSE4 may have already // lowered this: - // (extract_vector_elt (v8f32 %vreg1), Constant<6>) + // (extract_vector_elt (v8f32 %1), Constant<6>) // to: // (extract_vector_elt (vector_shuffle<2,u,u,u> - // (extract_subvector (v8f32 %vreg0), Constant<4>), + // (extract_subvector (v8f32 %0), Constant<4>), // undef) // Constant<0>) // In this case the vector is the extract_subvector expression and the index @@ -7011,8 +7130,8 @@ static SDValue ConvertI1VectorToInteger(SDValue Op, SelectionDAG &DAG) { return DAG.getConstant(Immediate, dl, VT); } // Lower BUILD_VECTOR operation for v8i1 and v16i1 types. -SDValue -X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const { +static SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); assert((VT.getVectorElementType() == MVT::i1) && @@ -7020,10 +7139,10 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); if (ISD::isBuildVectorAllZeros(Op.getNode())) - return DAG.getTargetConstant(0, dl, VT); + return Op; if (ISD::isBuildVectorAllOnes(Op.getNode())) - return DAG.getTargetConstant(1, dl, VT); + return Op; if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) { if (VT == MVT::v64i1 && !Subtarget.is64Bit()) { @@ -7034,8 +7153,8 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const { DAG.getBuildVector(MVT::v32i1, dl, Op.getNode()->ops().slice(32, 32)); // We have to manually lower both halves so getNode doesn't try to // reassemble the build_vector. 
- Lower = LowerBUILD_VECTORvXi1(Lower, DAG); - Upper = LowerBUILD_VECTORvXi1(Upper, DAG); + Lower = LowerBUILD_VECTORvXi1(Lower, DAG, Subtarget); + Upper = LowerBUILD_VECTORvXi1(Upper, DAG, Subtarget); return DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v64i1, Lower, Upper); } SDValue Imm = ConvertI1VectorToInteger(Op, DAG); @@ -7272,7 +7391,8 @@ static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1, /// are written to the parameters \p Opnd0 and \p Opnd1. static bool isAddSub(const BuildVectorSDNode *BV, const X86Subtarget &Subtarget, SelectionDAG &DAG, - SDValue &Opnd0, SDValue &Opnd1) { + SDValue &Opnd0, SDValue &Opnd1, + unsigned &NumExtracts) { MVT VT = BV->getSimpleValueType(0); if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) && @@ -7284,6 +7404,8 @@ static bool isAddSub(const BuildVectorSDNode *BV, SDValue InVec0 = DAG.getUNDEF(VT); SDValue InVec1 = DAG.getUNDEF(VT); + NumExtracts = 0; + // Odd-numbered elements in the input build vector are obtained from // adding two integer/float elements. // Even-numbered elements in the input build vector are obtained from @@ -7360,6 +7482,9 @@ static bool isAddSub(const BuildVectorSDNode *BV, // Update the pair of expected opcodes. std::swap(ExpectedOpcode, NextExpectedOpcode); + + // Increment the number of extractions done. + ++NumExtracts; } // Don't try to fold this build_vector into an ADDSUB if the inputs are undef. @@ -7372,9 +7497,9 @@ static bool isAddSub(const BuildVectorSDNode *BV, } /// Returns true if is possible to fold MUL and an idiom that has already been -/// recognized as ADDSUB(\p Opnd0, \p Opnd1) into FMADDSUB(x, y, \p Opnd1). -/// If (and only if) true is returned, the operands of FMADDSUB are written to -/// parameters \p Opnd0, \p Opnd1, \p Opnd2. +/// recognized as ADDSUB/SUBADD(\p Opnd0, \p Opnd1) into +/// FMADDSUB/FMSUBADD(x, y, \p Opnd1). If (and only if) true is returned, the +/// operands of FMADDSUB/FMSUBADD are written to parameters \p Opnd0, \p Opnd1, \p Opnd2. /// /// Prior to calling this function it should be known that there is some /// SDNode that potentially can be replaced with an X86ISD::ADDSUB operation @@ -7397,10 +7522,12 @@ static bool isAddSub(const BuildVectorSDNode *BV, /// recognized ADDSUB idiom with ADDSUB operation is that such replacement /// is illegal sometimes. E.g. 512-bit ADDSUB is not available, while 512-bit /// FMADDSUB is. -static bool isFMAddSub(const X86Subtarget &Subtarget, SelectionDAG &DAG, - SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2) { - if (Opnd0.getOpcode() != ISD::FMUL || Opnd0->use_size() != 2 || - !Subtarget.hasAnyFMA()) +static bool isFMAddSubOrFMSubAdd(const X86Subtarget &Subtarget, + SelectionDAG &DAG, + SDValue &Opnd0, SDValue &Opnd1, SDValue &Opnd2, + unsigned ExpectedUses) { + if (Opnd0.getOpcode() != ISD::FMUL || + !Opnd0->hasNUsesOfValue(ExpectedUses, 0) || !Subtarget.hasAnyFMA()) return false; // FIXME: These checks must match the similar ones in @@ -7426,7 +7553,8 @@ static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV, const X86Subtarget &Subtarget, SelectionDAG &DAG) { SDValue Opnd0, Opnd1; - if (!isAddSub(BV, Subtarget, DAG, Opnd0, Opnd1)) + unsigned NumExtracts; + if (!isAddSub(BV, Subtarget, DAG, Opnd0, Opnd1, NumExtracts)) return SDValue(); MVT VT = BV->getSimpleValueType(0); @@ -7434,7 +7562,9 @@ static SDValue lowerToAddSubOrFMAddSub(const BuildVectorSDNode *BV, // Try to generate X86ISD::FMADDSUB node here. 
SDValue Opnd2; - if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2)) + // TODO: According to coverage reports, the FMADDSUB transform is not + // triggered by any tests. + if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, NumExtracts)) return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2); // Do not generate X86ISD::ADDSUB node for 512-bit types even though @@ -7604,6 +7734,10 @@ static SDValue lowerBuildVectorToBitOp(BuildVectorSDNode *Op, case ISD::AND: case ISD::XOR: case ISD::OR: + // Don't do this if the buildvector is a splat - we'd replace one + // constant with an entire vector. + if (Op->getSplatValue()) + return SDValue(); if (!TLI.isOperationLegalOrPromote(Opcode, VT)) return SDValue(); break; @@ -7658,6 +7792,117 @@ static SDValue materializeVectorConstant(SDValue Op, SelectionDAG &DAG, return SDValue(); } +// Tries to lower a BUILD_VECTOR composed of extract-extract chains that can be +// reasoned to be a permutation of a vector by indices in a non-constant vector. +// (build_vector (extract_elt V, (extract_elt I, 0)), +// (extract_elt V, (extract_elt I, 1)), +// ... +// -> +// (vpermv I, V) +// +// TODO: Handle undefs +// TODO: Utilize pshufb and zero mask blending to support more efficient +// construction of vectors with constant-0 elements. +// TODO: Use smaller-element vectors of same width, and "interpolate" the indices, +// when no native operation available. +static SDValue +LowerBUILD_VECTORAsVariablePermute(SDValue V, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // Look for VPERMV and PSHUFB opportunities. + MVT VT = V.getSimpleValueType(); + switch (VT.SimpleTy) { + default: + return SDValue(); + case MVT::v16i8: + if (!Subtarget.hasSSE3()) + return SDValue(); + break; + case MVT::v8f32: + case MVT::v8i32: + if (!Subtarget.hasAVX2()) + return SDValue(); + break; + case MVT::v4i64: + case MVT::v4f64: + if (!Subtarget.hasVLX()) + return SDValue(); + break; + case MVT::v16f32: + case MVT::v8f64: + case MVT::v16i32: + case MVT::v8i64: + if (!Subtarget.hasAVX512()) + return SDValue(); + break; + case MVT::v32i16: + if (!Subtarget.hasBWI()) + return SDValue(); + break; + case MVT::v8i16: + case MVT::v16i16: + if (!Subtarget.hasVLX() || !Subtarget.hasBWI()) + return SDValue(); + break; + case MVT::v64i8: + if (!Subtarget.hasVBMI()) + return SDValue(); + break; + case MVT::v32i8: + if (!Subtarget.hasVLX() || !Subtarget.hasVBMI()) + return SDValue(); + break; + } + SDValue SrcVec, IndicesVec; + // Check for a match of the permute source vector and permute index elements. + // This is done by checking that the i-th build_vector operand is of the form: + // (extract_elt SrcVec, (extract_elt IndicesVec, i)). + for (unsigned Idx = 0, E = V.getNumOperands(); Idx != E; ++Idx) { + SDValue Op = V.getOperand(Idx); + if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + + // If this is the first extract encountered in V, set the source vector, + // otherwise verify the extract is from the previously defined source + // vector. + if (!SrcVec) + SrcVec = Op.getOperand(0); + else if (SrcVec != Op.getOperand(0)) + return SDValue(); + SDValue ExtractedIndex = Op->getOperand(1); + // Peek through extends. 
+ if (ExtractedIndex.getOpcode() == ISD::ZERO_EXTEND || + ExtractedIndex.getOpcode() == ISD::SIGN_EXTEND) + ExtractedIndex = ExtractedIndex.getOperand(0); + if (ExtractedIndex.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + + // If this is the first extract from the index vector candidate, set the + // indices vector, otherwise verify the extract is from the previously + // defined indices vector. + if (!IndicesVec) + IndicesVec = ExtractedIndex.getOperand(0); + else if (IndicesVec != ExtractedIndex.getOperand(0)) + return SDValue(); + + auto *PermIdx = dyn_cast<ConstantSDNode>(ExtractedIndex.getOperand(1)); + if (!PermIdx || PermIdx->getZExtValue() != Idx) + return SDValue(); + } + MVT IndicesVT = VT; + if (VT.isFloatingPoint()) + IndicesVT = MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits()), + VT.getVectorNumElements()); + IndicesVec = DAG.getZExtOrTrunc(IndicesVec, SDLoc(IndicesVec), IndicesVT); + if (SrcVec.getValueSizeInBits() < IndicesVT.getSizeInBits()) { + SrcVec = + DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(SrcVec), VT, DAG.getUNDEF(VT), + SrcVec, DAG.getIntPtrConstant(0, SDLoc(SrcVec))); + } + if (VT == MVT::v16i8) + return DAG.getNode(X86ISD::PSHUFB, SDLoc(V), VT, SrcVec, IndicesVec); + return DAG.getNode(X86ISD::VPERMV, SDLoc(V), VT, IndicesVec, SrcVec); +} + SDValue X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); @@ -7668,12 +7913,14 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { // Generate vectors for predicate vectors. if (VT.getVectorElementType() == MVT::i1 && Subtarget.hasAVX512()) - return LowerBUILD_VECTORvXi1(Op, DAG); + return LowerBUILD_VECTORvXi1(Op, DAG, Subtarget); if (SDValue VectorConstant = materializeVectorConstant(Op, DAG, Subtarget)) return VectorConstant; BuildVectorSDNode *BV = cast<BuildVectorSDNode>(Op.getNode()); + // TODO: Support FMSUBADD here if we ever get tests for the FMADDSUB + // transform here. if (SDValue AddSub = lowerToAddSubOrFMAddSub(BV, Subtarget, DAG)) return AddSub; if (SDValue HorizontalOp = LowerToHorizontalOp(BV, Subtarget, DAG)) @@ -7690,14 +7937,16 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { uint64_t NonZeros = 0; bool IsAllConstants = true; SmallSet<SDValue, 8> Values; + unsigned NumConstants = NumElems; for (unsigned i = 0; i < NumElems; ++i) { SDValue Elt = Op.getOperand(i); if (Elt.isUndef()) continue; Values.insert(Elt); - if (Elt.getOpcode() != ISD::Constant && - Elt.getOpcode() != ISD::ConstantFP) + if (!isa<ConstantSDNode>(Elt) && !isa<ConstantFPSDNode>(Elt)) { IsAllConstants = false; + NumConstants--; + } if (X86::isZeroNode(Elt)) NumZero++; else { @@ -7711,6 +7960,52 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { if (NumNonZero == 0) return DAG.getUNDEF(VT); + // If we are inserting one variable into a vector of non-zero constants, try + // to avoid loading each constant element as a scalar. Load the constants as a + // vector and then insert the variable scalar element. If insertion is not + // supported, we assume that we will fall back to a shuffle to get the scalar + // blended with the constants. Insertion into a zero vector is handled as a + // special-case somewhere below here. + LLVMContext &Context = *DAG.getContext(); + if (NumConstants == NumElems - 1 && NumNonZero != 1 && + (isOperationLegalOrCustom(ISD::INSERT_VECTOR_ELT, VT) || + isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT))) { + // Create an all-constant vector. 
The variable element in the old + // build vector is replaced by undef in the constant vector. Save the + // variable scalar element and its index for use in the insertelement. + Type *EltType = Op.getValueType().getScalarType().getTypeForEVT(Context); + SmallVector<Constant *, 16> ConstVecOps(NumElems, UndefValue::get(EltType)); + SDValue VarElt; + SDValue InsIndex; + for (unsigned i = 0; i != NumElems; ++i) { + SDValue Elt = Op.getOperand(i); + if (auto *C = dyn_cast<ConstantSDNode>(Elt)) + ConstVecOps[i] = ConstantInt::get(Context, C->getAPIntValue()); + else if (auto *C = dyn_cast<ConstantFPSDNode>(Elt)) + ConstVecOps[i] = ConstantFP::get(Context, C->getValueAPF()); + else if (!Elt.isUndef()) { + assert(!VarElt.getNode() && !InsIndex.getNode() && + "Expected one variable element in this vector"); + VarElt = Elt; + InsIndex = DAG.getConstant(i, dl, getVectorIdxTy(DAG.getDataLayout())); + } + } + Constant *CV = ConstantVector::get(ConstVecOps); + SDValue DAGConstVec = DAG.getConstantPool(CV, VT); + + // The constants we just created may not be legal (eg, floating point). We + // must lower the vector right here because we can not guarantee that we'll + // legalize it before loading it. This is also why we could not just create + // a new build vector here. If the build vector contains illegal constants, + // it could get split back up into a series of insert elements. + // TODO: Improve this by using shorter loads with broadcast/VZEXT_LOAD. + SDValue LegalDAGConstVec = LowerConstantPool(DAGConstVec, DAG); + MachineFunction &MF = DAG.getMachineFunction(); + MachinePointerInfo MPI = MachinePointerInfo::getConstantPool(MF); + SDValue Ld = DAG.getLoad(VT, dl, DAG.getEntryNode(), LegalDAGConstVec, MPI); + return DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Ld, VarElt, InsIndex); + } + // Special case for single non-zero, non-undef, element. if (NumNonZero == 1) { unsigned Idx = countTrailingZeros(NonZeros); @@ -7825,6 +8120,9 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { if (IsAllConstants) return SDValue(); + if (SDValue V = LowerBUILD_VECTORAsVariablePermute(Op, DAG, Subtarget)) + return V; + // See if we can use a vector load to get all of the elements. if (VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) { SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems); @@ -7836,15 +8134,13 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { // For AVX-length vectors, build the individual 128-bit pieces and use // shuffles to put them in place. if (VT.is256BitVector() || VT.is512BitVector()) { - SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems); - - EVT HVT = EVT::getVectorVT(*DAG.getContext(), ExtVT, NumElems/2); + EVT HVT = EVT::getVectorVT(Context, ExtVT, NumElems/2); // Build both the lower and upper subvector. SDValue Lower = - DAG.getBuildVector(HVT, dl, makeArrayRef(&Ops[0], NumElems / 2)); + DAG.getBuildVector(HVT, dl, Op->ops().slice(0, NumElems / 2)); SDValue Upper = DAG.getBuildVector( - HVT, dl, makeArrayRef(&Ops[NumElems / 2], NumElems / 2)); + HVT, dl, Op->ops().slice(NumElems / 2, NumElems /2)); // Recreate the wider vector with the lower and upper part. 
if (VT.is256BitVector()) @@ -7892,8 +8188,8 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { } for (unsigned i = 0; i < 2; ++i) { - switch ((NonZeros & (0x3 << i*2)) >> (i*2)) { - default: break; + switch ((NonZeros >> (i*2)) & 0x3) { + default: llvm_unreachable("Unexpected NonZero count"); case 0: Ops[i] = Ops[i*2]; // Must be a zero vector. break; @@ -7920,57 +8216,56 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { return DAG.getVectorShuffle(VT, dl, Ops[0], Ops[1], MaskVec); } - if (Values.size() > 1 && VT.is128BitVector()) { - // Check for a build vector from mostly shuffle plus few inserting. - if (SDValue Sh = buildFromShuffleMostly(Op, DAG)) - return Sh; + assert(Values.size() > 1 && "Expected non-undef and non-splat vector"); - // For SSE 4.1, use insertps to put the high elements into the low element. - if (Subtarget.hasSSE41()) { - SDValue Result; - if (!Op.getOperand(0).isUndef()) - Result = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op.getOperand(0)); - else - Result = DAG.getUNDEF(VT); + // Check for a build vector from mostly shuffle plus few inserting. + if (SDValue Sh = buildFromShuffleMostly(Op, DAG)) + return Sh; - for (unsigned i = 1; i < NumElems; ++i) { - if (Op.getOperand(i).isUndef()) continue; - Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Result, - Op.getOperand(i), DAG.getIntPtrConstant(i, dl)); - } - return Result; - } + // For SSE 4.1, use insertps to put the high elements into the low element. + if (Subtarget.hasSSE41()) { + SDValue Result; + if (!Op.getOperand(0).isUndef()) + Result = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op.getOperand(0)); + else + Result = DAG.getUNDEF(VT); - // Otherwise, expand into a number of unpckl*, start by extending each of - // our (non-undef) elements to the full vector width with the element in the - // bottom slot of the vector (which generates no code for SSE). - SmallVector<SDValue, 8> Ops(NumElems); - for (unsigned i = 0; i < NumElems; ++i) { - if (!Op.getOperand(i).isUndef()) - Ops[i] = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op.getOperand(i)); - else - Ops[i] = DAG.getUNDEF(VT); + for (unsigned i = 1; i < NumElems; ++i) { + if (Op.getOperand(i).isUndef()) continue; + Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Result, + Op.getOperand(i), DAG.getIntPtrConstant(i, dl)); } + return Result; + } + + // Otherwise, expand into a number of unpckl*, start by extending each of + // our (non-undef) elements to the full vector width with the element in the + // bottom slot of the vector (which generates no code for SSE). + SmallVector<SDValue, 8> Ops(NumElems); + for (unsigned i = 0; i < NumElems; ++i) { + if (!Op.getOperand(i).isUndef()) + Ops[i] = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op.getOperand(i)); + else + Ops[i] = DAG.getUNDEF(VT); + } - // Next, we iteratively mix elements, e.g. for v4f32: - // Step 1: unpcklps 0, 1 ==> X: <?, ?, 1, 0> - // : unpcklps 2, 3 ==> Y: <?, ?, 3, 2> - // Step 2: unpcklpd X, Y ==> <3, 2, 1, 0> - for (unsigned Scale = 1; Scale < NumElems; Scale *= 2) { - // Generate scaled UNPCKL shuffle mask. - SmallVector<int, 16> Mask; - for(unsigned i = 0; i != Scale; ++i) - Mask.push_back(i); - for (unsigned i = 0; i != Scale; ++i) - Mask.push_back(NumElems+i); - Mask.append(NumElems - Mask.size(), SM_SentinelUndef); + // Next, we iteratively mix elements, e.g. 
for v4f32: + // Step 1: unpcklps 0, 1 ==> X: <?, ?, 1, 0> + // : unpcklps 2, 3 ==> Y: <?, ?, 3, 2> + // Step 2: unpcklpd X, Y ==> <3, 2, 1, 0> + for (unsigned Scale = 1; Scale < NumElems; Scale *= 2) { + // Generate scaled UNPCKL shuffle mask. + SmallVector<int, 16> Mask; + for(unsigned i = 0; i != Scale; ++i) + Mask.push_back(i); + for (unsigned i = 0; i != Scale; ++i) + Mask.push_back(NumElems+i); + Mask.append(NumElems - Mask.size(), SM_SentinelUndef); - for (unsigned i = 0, e = NumElems / (2 * Scale); i != e; ++i) - Ops[i] = DAG.getVectorShuffle(VT, dl, Ops[2*i], Ops[(2*i)+1], Mask); - } - return Ops[0]; + for (unsigned i = 0, e = NumElems / (2 * Scale); i != e; ++i) + Ops[i] = DAG.getVectorShuffle(VT, dl, Ops[2*i], Ops[(2*i)+1], Mask); } - return SDValue(); + return Ops[0]; } // 256-bit AVX can use the vinsertf128 instruction @@ -8060,87 +8355,74 @@ static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op, SelectionDAG & DAG) { SDLoc dl(Op); MVT ResVT = Op.getSimpleValueType(); - unsigned NumOfOperands = Op.getNumOperands(); + unsigned NumOperands = Op.getNumOperands(); - assert(isPowerOf2_32(NumOfOperands) && + assert(NumOperands > 1 && isPowerOf2_32(NumOperands) && "Unexpected number of operands in CONCAT_VECTORS"); // If this node promotes - by concatenating zeroes - the type of the result // of a node with instruction that zeroes all upper (irrelevant) bits of the // output register, mark it as legal and catch the pattern in instruction - // selection to avoid emitting extra insturctions (for zeroing upper bits). + // selection to avoid emitting extra instructions (for zeroing upper bits). if (SDValue Promoted = isTypePromotionOfi1ZeroUpBits(Op)) { - SDValue ZeroC = DAG.getConstant(0, dl, MVT::i64); - SDValue AllZeros = DAG.getSplatBuildVector(ResVT, dl, ZeroC); + SDValue ZeroC = DAG.getIntPtrConstant(0, dl); + SDValue AllZeros = getZeroVector(ResVT, Subtarget, DAG, dl); return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, AllZeros, Promoted, ZeroC); } - SDValue Undef = DAG.getUNDEF(ResVT); - if (NumOfOperands > 2) { - // Specialize the cases when all, or all but one, of the operands are undef. - unsigned NumOfDefinedOps = 0; - unsigned OpIdx = 0; - for (unsigned i = 0; i < NumOfOperands; i++) - if (!Op.getOperand(i).isUndef()) { - NumOfDefinedOps++; - OpIdx = i; - } - if (NumOfDefinedOps == 0) - return Undef; - if (NumOfDefinedOps == 1) { - unsigned SubVecNumElts = - Op.getOperand(OpIdx).getValueType().getVectorNumElements(); - SDValue IdxVal = DAG.getIntPtrConstant(SubVecNumElts * OpIdx, dl); - return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Undef, - Op.getOperand(OpIdx), IdxVal); + unsigned NumZero = 0; + unsigned NumNonZero = 0; + uint64_t NonZeros = 0; + for (unsigned i = 0; i != NumOperands; ++i) { + SDValue SubVec = Op.getOperand(i); + if (SubVec.isUndef()) + continue; + if (ISD::isBuildVectorAllZeros(SubVec.getNode())) + ++NumZero; + else { + assert(i < sizeof(NonZeros) * CHAR_BIT); // Ensure the shift is in range. + NonZeros |= (uint64_t)1 << i; + ++NumNonZero; } + } + + // If there are zero or one non-zeros we can handle this very simply. + if (NumNonZero <= 1) { + SDValue Vec = NumZero ? 
getZeroVector(ResVT, Subtarget, DAG, dl) + : DAG.getUNDEF(ResVT); + if (!NumNonZero) + return Vec; + unsigned Idx = countTrailingZeros(NonZeros); + SDValue SubVec = Op.getOperand(Idx); + unsigned SubVecNumElts = SubVec.getSimpleValueType().getVectorNumElements(); + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Vec, SubVec, + DAG.getIntPtrConstant(Idx * SubVecNumElts, dl)); + } + + if (NumOperands > 2) { MVT HalfVT = MVT::getVectorVT(ResVT.getVectorElementType(), ResVT.getVectorNumElements()/2); - SmallVector<SDValue, 2> Ops; - for (unsigned i = 0; i < NumOfOperands/2; i++) - Ops.push_back(Op.getOperand(i)); - SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT, Ops); - Ops.clear(); - for (unsigned i = NumOfOperands/2; i < NumOfOperands; i++) - Ops.push_back(Op.getOperand(i)); - SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT, Ops); + ArrayRef<SDUse> Ops = Op->ops(); + SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT, + Ops.slice(0, NumOperands/2)); + SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, dl, HalfVT, + Ops.slice(NumOperands/2)); return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi); } - // 2 operands - SDValue V1 = Op.getOperand(0); - SDValue V2 = Op.getOperand(1); - unsigned NumElems = ResVT.getVectorNumElements(); - assert(V1.getValueType() == V2.getValueType() && - V1.getValueType().getVectorNumElements() == NumElems/2 && - "Unexpected operands in CONCAT_VECTORS"); + assert(NumNonZero == 2 && "Simple cases not handled?"); - if (ResVT.getSizeInBits() >= 16) + if (ResVT.getVectorNumElements() >= 16) return Op; // The operation is legal with KUNPCK - bool IsZeroV1 = ISD::isBuildVectorAllZeros(V1.getNode()); - bool IsZeroV2 = ISD::isBuildVectorAllZeros(V2.getNode()); - SDValue ZeroVec = getZeroVector(ResVT, Subtarget, DAG, dl); - if (IsZeroV1 && IsZeroV2) - return ZeroVec; - - SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl); - if (V2.isUndef()) - return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Undef, V1, ZeroIdx); - if (IsZeroV2) - return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, ZeroVec, V1, ZeroIdx); - - SDValue IdxVal = DAG.getIntPtrConstant(NumElems/2, dl); - if (V1.isUndef()) - return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Undef, V2, IdxVal); - - if (IsZeroV1) - return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, ZeroVec, V2, IdxVal); - - V1 = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Undef, V1, ZeroIdx); - return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, V1, V2, IdxVal); + SDValue Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, + DAG.getUNDEF(ResVT), Op.getOperand(0), + DAG.getIntPtrConstant(0, dl)); + unsigned NumElems = ResVT.getVectorNumElements(); + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ResVT, Vec, Op.getOperand(1), + DAG.getIntPtrConstant(NumElems/2, dl)); } static SDValue LowerCONCAT_VECTORS(SDValue Op, @@ -8723,6 +9005,76 @@ static SDValue lowerVectorShuffleWithUNPCK(const SDLoc &DL, MVT VT, return SDValue(); } +// X86 has dedicated pack instructions that can handle specific truncation +// operations: PACKSS and PACKUS. 
+static bool matchVectorShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1, + SDValue &V2, unsigned &PackOpcode, + ArrayRef<int> TargetMask, + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + unsigned NumElts = VT.getVectorNumElements(); + unsigned BitSize = VT.getScalarSizeInBits(); + MVT PackSVT = MVT::getIntegerVT(BitSize * 2); + MVT PackVT = MVT::getVectorVT(PackSVT, NumElts / 2); + + auto MatchPACK = [&](SDValue N1, SDValue N2) { + SDValue VV1 = DAG.getBitcast(PackVT, N1); + SDValue VV2 = DAG.getBitcast(PackVT, N2); + if ((N1.isUndef() || DAG.ComputeNumSignBits(VV1) > BitSize) && + (N2.isUndef() || DAG.ComputeNumSignBits(VV2) > BitSize)) { + V1 = VV1; + V2 = VV2; + SrcVT = PackVT; + PackOpcode = X86ISD::PACKSS; + return true; + } + + if (Subtarget.hasSSE41() || PackSVT == MVT::i16) { + APInt ZeroMask = APInt::getHighBitsSet(BitSize * 2, BitSize); + if ((N1.isUndef() || DAG.MaskedValueIsZero(VV1, ZeroMask)) && + (N2.isUndef() || DAG.MaskedValueIsZero(VV2, ZeroMask))) { + V1 = VV1; + V2 = VV2; + SrcVT = PackVT; + PackOpcode = X86ISD::PACKUS; + return true; + } + } + + return false; + }; + + // Try binary shuffle. + SmallVector<int, 32> BinaryMask; + createPackShuffleMask(VT, BinaryMask, false); + if (isTargetShuffleEquivalent(TargetMask, BinaryMask)) + if (MatchPACK(V1, V2)) + return true; + + // Try unary shuffle. + SmallVector<int, 32> UnaryMask; + createPackShuffleMask(VT, UnaryMask, true); + if (isTargetShuffleEquivalent(TargetMask, UnaryMask)) + if (MatchPACK(V1, V1)) + return true; + + return false; +} + +static SDValue lowerVectorShuffleWithPACK(const SDLoc &DL, MVT VT, + ArrayRef<int> Mask, SDValue V1, + SDValue V2, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + MVT PackVT; + unsigned PackOpcode; + if (matchVectorShuffleWithPACK(VT, PackVT, V1, V2, PackOpcode, Mask, DAG, + Subtarget)) + return DAG.getNode(PackOpcode, DL, VT, DAG.getBitcast(PackVT, V1), + DAG.getBitcast(PackVT, V2)); + + return SDValue(); +} + /// \brief Try to emit a bitmask instruction for a shuffle. /// /// This handles cases where we can model a blend exactly as a bitmask due to @@ -8834,7 +9186,8 @@ static bool matchVectorShuffleAsBlend(SDValue V1, SDValue V2, return true; } -uint64_t scaleVectorShuffleBlendMask(uint64_t BlendMask, int Size, int Scale) { +static uint64_t scaleVectorShuffleBlendMask(uint64_t BlendMask, int Size, + int Scale) { uint64_t ScaledMask = 0; for (int i = 0; i != Size; ++i) if (BlendMask & (1ull << i)) @@ -9869,7 +10222,7 @@ static SDValue lowerVectorShuffleAsElementInsertion( return SDValue(); // Zero-extend directly to i32. - ExtVT = MVT::v4i32; + ExtVT = MVT::getVectorVT(MVT::i32, ExtVT.getSizeInBits() / 32); V2S = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, V2S); } V2 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ExtVT, V2S); @@ -9891,10 +10244,7 @@ static SDValue lowerVectorShuffleAsElementInsertion( V1Mask[V2Index] = -1; if (!isNoopShuffleMask(V1Mask)) return SDValue(); - // This is essentially a special case blend operation, but if we have - // general purpose blend operations, they are always faster. Bail and let - // the rest of the lowering handle these as blends. - if (Subtarget.hasSSE41()) + if (!VT.is128BitVector()) return SDValue(); // Otherwise, use MOVSD or MOVSS. @@ -10005,7 +10355,9 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT, // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise // we can only broadcast from a register with AVX2. unsigned NumElts = Mask.size(); - unsigned Opcode = VT == MVT::v2f64 ? 
X86ISD::MOVDDUP : X86ISD::VBROADCAST; + unsigned Opcode = (VT == MVT::v2f64 && !Subtarget.hasAVX2()) + ? X86ISD::MOVDDUP + : X86ISD::VBROADCAST; bool BroadcastFromReg = (Opcode == X86ISD::MOVDDUP) || Subtarget.hasAVX2(); // Check that the mask is a broadcast. @@ -10030,9 +10382,16 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT, for (;;) { switch (V.getOpcode()) { case ISD::BITCAST: { + // Peek through bitcasts as long as BroadcastIdx can be adjusted. SDValue VSrc = V.getOperand(0); - MVT SrcVT = VSrc.getSimpleValueType(); - if (VT.getScalarSizeInBits() != SrcVT.getScalarSizeInBits()) + unsigned NumEltBits = V.getScalarValueSizeInBits(); + unsigned NumSrcBits = VSrc.getScalarValueSizeInBits(); + if ((NumEltBits % NumSrcBits) == 0) + BroadcastIdx *= (NumEltBits / NumSrcBits); + else if ((NumSrcBits % NumEltBits) == 0 && + (BroadcastIdx % (NumSrcBits / NumEltBits)) == 0) + BroadcastIdx /= (NumSrcBits / NumEltBits); + else break; V = VSrc; continue; @@ -10064,6 +10423,23 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT, break; } + // Ensure the source vector and BroadcastIdx are for a suitable type. + if (VT.getScalarSizeInBits() != V.getScalarValueSizeInBits()) { + unsigned NumEltBits = VT.getScalarSizeInBits(); + unsigned NumSrcBits = V.getScalarValueSizeInBits(); + if ((NumSrcBits % NumEltBits) == 0) + BroadcastIdx *= (NumSrcBits / NumEltBits); + else if ((NumEltBits % NumSrcBits) == 0 && + (BroadcastIdx % (NumEltBits / NumSrcBits)) == 0) + BroadcastIdx /= (NumEltBits / NumSrcBits); + else + return SDValue(); + + unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits; + MVT SrcVT = MVT::getVectorVT(VT.getScalarType(), NumSrcElts); + V = DAG.getBitcast(SrcVT, V); + } + // Check if this is a broadcast of a scalar. We special case lowering // for scalars so that we can more effectively fold with loads. // First, look through bitcast: if the original value has a larger element @@ -10091,7 +10467,9 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT, // 32-bit targets need to load i64 as a f64 and then bitcast the result. if (!Subtarget.is64Bit() && VT.getScalarType() == MVT::i64) { BroadcastVT = MVT::getVectorVT(MVT::f64, VT.getVectorNumElements()); - Opcode = (BroadcastVT.is128BitVector() ? X86ISD::MOVDDUP : Opcode); + Opcode = (BroadcastVT.is128BitVector() && !Subtarget.hasAVX2()) + ? X86ISD::MOVDDUP + : Opcode; } // If we are broadcasting a load that is only used by the shuffle @@ -10127,15 +10505,11 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT, // The shuffle input might have been a bitcast we looked through; look at // the original input vector. Emit an EXTRACT_SUBVECTOR of that type; we'll // later bitcast it to BroadcastVT. 
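// Illustration, not part of the patch: the broadcast forms chosen above, as
// intrinsics (assumes <immintrin.h>; MOVDDUP is SSE3, register-source
// VBROADCASTSD needs AVX2).
#include <immintrin.h>

// A v2f64 splat of lane 0 is a single MOVDDUP, which can also fold a load.
static __m128d splat_lane0_v2f64(__m128d v) {
  return _mm_movedup_pd(v);                  // MOVDDUP
}

// With AVX2 the splat can come straight from a register via VBROADCASTSD,
// which is why the code above prefers X86ISD::VBROADCAST in that case.
static __m256d splat_lane0_v4f64(__m128d v) {
  return _mm256_broadcastsd_pd(v);           // VBROADCASTSD ymm, xmm
}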
- MVT SrcVT = V.getSimpleValueType(); - assert(SrcVT.getScalarSizeInBits() == BroadcastVT.getScalarSizeInBits() && + assert(V.getScalarValueSizeInBits() == BroadcastVT.getScalarSizeInBits() && "Unexpected vector element size"); - assert((SrcVT.is256BitVector() || SrcVT.is512BitVector()) && + assert((V.getValueSizeInBits() == 256 || V.getValueSizeInBits() == 512) && "Unexpected vector size"); - - MVT ExtVT = MVT::getVectorVT(SrcVT.getScalarType(), 128 / EltSize); - V = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, V, - DAG.getIntPtrConstant(BroadcastIdx, DL)); + V = extract128BitVector(V, BroadcastIdx, DAG, DL); } if (Opcode == X86ISD::MOVDDUP && !V.getValueType().isVector()) @@ -10165,9 +10539,13 @@ static SDValue lowerVectorShuffleAsBroadcast(const SDLoc &DL, MVT VT, // We only support broadcasting from 128-bit vectors to minimize the // number of patterns we need to deal with in isel. So extract down to - // 128-bits. - if (SrcVT.getSizeInBits() > 128) - V = extract128BitVector(V, 0, DAG, DL); + // 128-bits, removing as many bitcasts as possible. + if (SrcVT.getSizeInBits() > 128) { + MVT ExtVT = MVT::getVectorVT(SrcVT.getScalarType(), + 128 / SrcVT.getScalarSizeInBits()); + V = extract128BitVector(peekThroughBitcasts(V), 0, DAG, DL); + V = DAG.getBitcast(ExtVT, V); + } return DAG.getBitcast(VT, DAG.getNode(Opcode, DL, BroadcastVT, V)); } @@ -10517,26 +10895,6 @@ static SDValue lowerV2I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, assert(Mask[0] < 2 && "We sort V1 to be the first input."); assert(Mask[1] >= 2 && "We sort V2 to be the second input."); - // If we have a blend of two same-type PACKUS operations and the blend aligns - // with the low and high halves, we can just merge the PACKUS operations. - // This is particularly important as it lets us merge shuffles that this - // routine itself creates. - auto GetPackNode = [](SDValue V) { - V = peekThroughBitcasts(V); - return V.getOpcode() == X86ISD::PACKUS ? V : SDValue(); - }; - if (SDValue V1Pack = GetPackNode(V1)) - if (SDValue V2Pack = GetPackNode(V2)) { - EVT PackVT = V1Pack.getValueType(); - if (PackVT == V2Pack.getValueType()) - return DAG.getBitcast(MVT::v2i64, - DAG.getNode(X86ISD::PACKUS, DL, PackVT, - Mask[0] == 0 ? V1Pack.getOperand(0) - : V1Pack.getOperand(1), - Mask[1] == 2 ? V2Pack.getOperand(0) - : V2Pack.getOperand(1))); - } - // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v2i64, V1, V2, Mask, Zeroable, Subtarget, DAG)) @@ -10569,10 +10927,16 @@ static SDValue lowerV2I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use byte rotation instructions. // Its more profitable for pre-SSSE3 to use shuffles/unpacks. - if (Subtarget.hasSSSE3()) + if (Subtarget.hasSSSE3()) { + if (Subtarget.hasVLX()) + if (SDValue Rotate = lowerVectorShuffleAsRotate(DL, MVT::v2i64, V1, V2, + Mask, Subtarget, DAG)) + return Rotate; + if (SDValue Rotate = lowerVectorShuffleAsByteRotate( DL, MVT::v2i64, V1, V2, Mask, Subtarget, DAG)) return Rotate; + } // If we have direct support for blends, we should lower by decomposing into // a permute. That will be faster than the domain cross. @@ -10736,6 +11100,15 @@ static SDValue lowerV4F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, getV4X86ShuffleImm8ForMask(Mask, DL, DAG)); } + // Use MOVLHPS/MOVHLPS to simulate unary shuffles. These are only valid + // in SSE1 because otherwise they are widened to v2f64 and never get here. 
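// Illustration, not part of the patch: the SSE1-only v4f32 patterns described
// in the comment above, plus the two-input forms handled further down, written
// with <xmmintrin.h> intrinsics.
#include <xmmintrin.h>

// Unary forms: duplicate the low or the high pair of elements.
static __m128 dup_low_pair(__m128 v)  { return _mm_movelh_ps(v, v); } // {0,1,0,1}
static __m128 dup_high_pair(__m128 v) { return _mm_movehl_ps(v, v); } // {2,3,2,3}

// Binary forms for the masks {0,1,4,5} and {2,3,6,7}.
static __m128 lo_lo(__m128 a, __m128 b) { return _mm_movelh_ps(a, b); } // MOVLHPS
static __m128 hi_hi(__m128 a, __m128 b) { return _mm_movehl_ps(b, a); } // MOVHLPS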
+ if (!Subtarget.hasSSE2()) { + if (isShuffleEquivalent(V1, V2, Mask, {0, 1, 0, 1})) + return DAG.getNode(X86ISD::MOVLHPS, DL, MVT::v4f32, V1, V1); + if (isShuffleEquivalent(V1, V2, Mask, {2, 3, 2, 3})) + return DAG.getNode(X86ISD::MOVHLPS, DL, MVT::v4f32, V1, V1); + } + // Otherwise, use a straight shuffle of a single input vector. We pass the // input vector to both operands to simulate this with a SHUFPS. return DAG.getNode(X86ISD::SHUFP, DL, MVT::v4f32, V1, V1, @@ -10768,11 +11141,14 @@ static SDValue lowerV4F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, return BlendPerm; } - // Use low/high mov instructions. - if (isShuffleEquivalent(V1, V2, Mask, {0, 1, 4, 5})) - return DAG.getNode(X86ISD::MOVLHPS, DL, MVT::v4f32, V1, V2); - if (isShuffleEquivalent(V1, V2, Mask, {2, 3, 6, 7})) - return DAG.getNode(X86ISD::MOVHLPS, DL, MVT::v4f32, V2, V1); + // Use low/high mov instructions. These are only valid in SSE1 because + // otherwise they are widened to v2f64 and never get here. + if (!Subtarget.hasSSE2()) { + if (isShuffleEquivalent(V1, V2, Mask, {0, 1, 4, 5})) + return DAG.getNode(X86ISD::MOVLHPS, DL, MVT::v4f32, V1, V2); + if (isShuffleEquivalent(V1, V2, Mask, {2, 3, 6, 7})) + return DAG.getNode(X86ISD::MOVHLPS, DL, MVT::v4f32, V2, V1); + } // Use dedicated unpack instructions for masks that match their pattern. if (SDValue V = @@ -10857,10 +11233,16 @@ static SDValue lowerV4I32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Try to use byte rotation instructions. // Its more profitable for pre-SSSE3 to use shuffles/unpacks. - if (Subtarget.hasSSSE3()) + if (Subtarget.hasSSSE3()) { + if (Subtarget.hasVLX()) + if (SDValue Rotate = lowerVectorShuffleAsRotate(DL, MVT::v4i32, V1, V2, + Mask, Subtarget, DAG)) + return Rotate; + if (SDValue Rotate = lowerVectorShuffleAsByteRotate( DL, MVT::v4i32, V1, V2, Mask, Subtarget, DAG)) return Rotate; + } // Assume that a single SHUFPS is faster than an alternative sequence of // multiple instructions (even if the CPU has a domain penalty). @@ -10915,6 +11297,20 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle( MutableArrayRef<int> LoMask = Mask.slice(0, 4); MutableArrayRef<int> HiMask = Mask.slice(4, 4); + // Attempt to directly match PSHUFLW or PSHUFHW. + if (isUndefOrInRange(LoMask, 0, 4) && + isSequentialOrUndefInRange(HiMask, 0, 4, 4)) { + return DAG.getNode(X86ISD::PSHUFLW, DL, VT, V, + getV4X86ShuffleImm8ForMask(LoMask, DL, DAG)); + } + if (isUndefOrInRange(HiMask, 4, 8) && + isSequentialOrUndefInRange(LoMask, 0, 4, 0)) { + for (int i = 0; i != 4; ++i) + HiMask[i] = (HiMask[i] < 0 ? HiMask[i] : (HiMask[i] - 4)); + return DAG.getNode(X86ISD::PSHUFHW, DL, VT, V, + getV4X86ShuffleImm8ForMask(HiMask, DL, DAG)); + } + SmallVector<int, 4> LoInputs; copy_if(LoMask, std::back_inserter(LoInputs), [](int M) { return M >= 0; }); std::sort(LoInputs.begin(), LoInputs.end()); @@ -10934,13 +11330,11 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle( MutableArrayRef<int> HToLInputs(LoInputs.data() + NumLToL, NumHToL); MutableArrayRef<int> HToHInputs(HiInputs.data() + NumLToH, NumHToH); - // If we are splatting two values from one half - one to each half, then - // we can shuffle that half so each is splatted to a dword, then splat those - // to their respective halves. 
- auto SplatHalfs = [&](int LoInput, int HiInput, unsigned ShufWOp, - int DOffset) { - int PSHUFHalfMask[] = {LoInput % 4, LoInput % 4, HiInput % 4, HiInput % 4}; - int PSHUFDMask[] = {DOffset + 0, DOffset + 0, DOffset + 1, DOffset + 1}; + // If we are shuffling values from one half - check how many different DWORD + // pairs we need to create. If only 1 or 2 then we can perform this as a + // PSHUFLW/PSHUFHW + PSHUFD instead of the PSHUFD+PSHUFLW+PSHUFHW chain below. + auto ShuffleDWordPairs = [&](ArrayRef<int> PSHUFHalfMask, + ArrayRef<int> PSHUFDMask, unsigned ShufWOp) { V = DAG.getNode(ShufWOp, DL, VT, V, getV4X86ShuffleImm8ForMask(PSHUFHalfMask, DL, DAG)); V = DAG.getBitcast(PSHUFDVT, V); @@ -10949,10 +11343,48 @@ static SDValue lowerV8I16GeneralSingleInputVectorShuffle( return DAG.getBitcast(VT, V); }; - if (NumLToL == 1 && NumLToH == 1 && (NumHToL + NumHToH) == 0) - return SplatHalfs(LToLInputs[0], LToHInputs[0], X86ISD::PSHUFLW, 0); - if (NumHToL == 1 && NumHToH == 1 && (NumLToL + NumLToH) == 0) - return SplatHalfs(HToLInputs[0], HToHInputs[0], X86ISD::PSHUFHW, 2); + if ((NumHToL + NumHToH) == 0 || (NumLToL + NumLToH) == 0) { + int PSHUFDMask[4] = { -1, -1, -1, -1 }; + SmallVector<std::pair<int, int>, 4> DWordPairs; + int DOffset = ((NumHToL + NumHToH) == 0 ? 0 : 2); + + // Collect the different DWORD pairs. + for (int DWord = 0; DWord != 4; ++DWord) { + int M0 = Mask[2 * DWord + 0]; + int M1 = Mask[2 * DWord + 1]; + M0 = (M0 >= 0 ? M0 % 4 : M0); + M1 = (M1 >= 0 ? M1 % 4 : M1); + if (M0 < 0 && M1 < 0) + continue; + + bool Match = false; + for (int j = 0, e = DWordPairs.size(); j < e; ++j) { + auto &DWordPair = DWordPairs[j]; + if ((M0 < 0 || isUndefOrEqual(DWordPair.first, M0)) && + (M1 < 0 || isUndefOrEqual(DWordPair.second, M1))) { + DWordPair.first = (M0 >= 0 ? M0 : DWordPair.first); + DWordPair.second = (M1 >= 0 ? M1 : DWordPair.second); + PSHUFDMask[DWord] = DOffset + j; + Match = true; + break; + } + } + if (!Match) { + PSHUFDMask[DWord] = DOffset + DWordPairs.size(); + DWordPairs.push_back(std::make_pair(M0, M1)); + } + } + + if (DWordPairs.size() <= 2) { + DWordPairs.resize(2, std::make_pair(-1, -1)); + int PSHUFHalfMask[4] = {DWordPairs[0].first, DWordPairs[0].second, + DWordPairs[1].first, DWordPairs[1].second}; + if ((NumHToL + NumHToH) == 0) + return ShuffleDWordPairs(PSHUFHalfMask, PSHUFDMask, X86ISD::PSHUFLW); + if ((NumLToL + NumLToH) == 0) + return ShuffleDWordPairs(PSHUFHalfMask, PSHUFDMask, X86ISD::PSHUFHW); + } + } // Simplify the 1-into-3 and 3-into-1 cases with a single pshufd. For all // such inputs we can swap two of the dwords across the half mark and end up @@ -11449,6 +11881,11 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, lowerVectorShuffleWithUNPCK(DL, MVT::v8i16, Mask, V1, V2, DAG)) return V; + // Use dedicated pack instructions for masks that match their pattern. + if (SDValue V = lowerVectorShuffleWithPACK(DL, MVT::v8i16, Mask, V1, V2, + DAG, Subtarget)) + return V; + // Try to use byte rotation instructions. if (SDValue Rotate = lowerVectorShuffleAsByteRotate(DL, MVT::v8i16, V1, V1, Mask, Subtarget, DAG)) @@ -11499,6 +11936,11 @@ static SDValue lowerV8I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, lowerVectorShuffleWithUNPCK(DL, MVT::v8i16, Mask, V1, V2, DAG)) return V; + // Use dedicated pack instructions for masks that match their pattern. + if (SDValue V = lowerVectorShuffleWithPACK(DL, MVT::v8i16, Mask, V1, V2, DAG, + Subtarget)) + return V; + // Try to use byte rotation instructions. 
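// Illustration, not part of the patch: a concrete case of the DWORD-pair
// rewrite added above (assumes <emmintrin.h>). The single-input v8i16 mask
// <0,3,0,3,1,2,1,2> only needs the two word pairs (0,3) and (1,2), so it
// becomes one PSHUFLW followed by one PSHUFD.
#include <emmintrin.h>

static __m128i two_pair_shuffle(__m128i v) {
  // Build both pairs in the low half: words {0,3,1,2}.
  __m128i pairs = _mm_shufflelo_epi16(v, _MM_SHUFFLE(2, 1, 3, 0));
  // Replicate dword 0 and dword 1 into the dwords that need them.
  return _mm_shuffle_epi32(pairs, _MM_SHUFFLE(1, 1, 0, 0));
}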
if (SDValue Rotate = lowerVectorShuffleAsByteRotate( DL, MVT::v8i16, V1, V2, Mask, Subtarget, DAG)) @@ -11593,6 +12035,19 @@ static int canLowerByDroppingEvenElements(ArrayRef<int> Mask, return 0; } +static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT, + ArrayRef<int> Mask, SDValue V1, + SDValue V2, SelectionDAG &DAG) { + MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits()); + MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements()); + + SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true); + if (V2.isUndef()) + return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1); + + return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2); +} + /// \brief Generic lowering of v16i8 shuffles. /// /// This is a hybrid strategy to lower v16i8 vectors. It first attempts to @@ -11619,6 +12074,11 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DL, MVT::v16i8, V1, V2, Mask, Subtarget, DAG)) return Rotate; + // Use dedicated pack instructions for masks that match their pattern. + if (SDValue V = lowerVectorShuffleWithPACK(DL, MVT::v16i8, Mask, V1, V2, DAG, + Subtarget)) + return V; + // Try to use a zext lowering. if (SDValue ZExt = lowerVectorShuffleAsZeroOrAnyExtend( DL, MVT::v16i8, V1, V2, Mask, Zeroable, Subtarget, DAG)) @@ -11778,6 +12238,10 @@ static SDValue lowerV16I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, if (SDValue Unpack = lowerVectorShuffleAsPermuteAndUnpack( DL, MVT::v16i8, V1, V2, Mask, DAG)) return Unpack; + + // If we have VBMI we can use one VPERM instead of multiple PSHUFBs. + if (Subtarget.hasVBMI() && Subtarget.hasVLX()) + return lowerVectorShuffleWithPERMV(DL, MVT::v16i8, Mask, V1, V2, DAG); } return PSHUFB; @@ -12105,7 +12569,8 @@ static SDValue lowerVectorShuffleAsSplitOrBlend(const SDLoc &DL, MVT VT, static SDValue lowerVectorShuffleAsLanePermuteAndBlend(const SDLoc &DL, MVT VT, SDValue V1, SDValue V2, ArrayRef<int> Mask, - SelectionDAG &DAG) { + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { // FIXME: This should probably be generalized for 512-bit vectors as well. assert(VT.is256BitVector() && "Only for 256-bit vector shuffles!"); int Size = Mask.size(); @@ -12114,12 +12579,21 @@ static SDValue lowerVectorShuffleAsLanePermuteAndBlend(const SDLoc &DL, MVT VT, // If there are only inputs from one 128-bit lane, splitting will in fact be // less expensive. The flags track whether the given lane contains an element // that crosses to another lane. 
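// Illustration, not part of the patch: the two instruction forms emitted by
// lowerVectorShuffleWithPERMV above, shown as AVX-512F intrinsics for the
// dword case (the helper itself also covers the word and byte variants).
#include <immintrin.h>

// One source: a full variable permute driven by an index vector (VPERMV).
static __m512i permv(__m512i idx, __m512i v) {
  return _mm512_permutexvar_epi32(idx, v);        // VPERMD
}

// Two sources: indices select from the concatenation of a and b (VPERMV3).
static __m512i permv3(__m512i a, __m512i idx, __m512i b) {
  return _mm512_permutex2var_epi32(a, idx, b);    // VPERMT2D
}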
- bool LaneCrossing[2] = {false, false}; - for (int i = 0; i < Size; ++i) - if (Mask[i] >= 0 && (Mask[i] % Size) / LaneSize != i / LaneSize) - LaneCrossing[(Mask[i] % Size) / LaneSize] = true; - if (!LaneCrossing[0] || !LaneCrossing[1]) - return splitAndLowerVectorShuffle(DL, VT, V1, V2, Mask, DAG); + if (!Subtarget.hasAVX2()) { + bool LaneCrossing[2] = {false, false}; + for (int i = 0; i < Size; ++i) + if (Mask[i] >= 0 && (Mask[i] % Size) / LaneSize != i / LaneSize) + LaneCrossing[(Mask[i] % Size) / LaneSize] = true; + if (!LaneCrossing[0] || !LaneCrossing[1]) + return splitAndLowerVectorShuffle(DL, VT, V1, V2, Mask, DAG); + } else { + bool LaneUsed[2] = {false, false}; + for (int i = 0; i < Size; ++i) + if (Mask[i] >= 0) + LaneUsed[(Mask[i] / LaneSize)] = true; + if (!LaneUsed[0] || !LaneUsed[1]) + return splitAndLowerVectorShuffle(DL, VT, V1, V2, Mask, DAG); + } assert(V2.isUndef() && "This last part of this routine only works on single input shuffles"); @@ -12132,14 +12606,12 @@ static SDValue lowerVectorShuffleAsLanePermuteAndBlend(const SDLoc &DL, MVT VT, : Mask[i] % LaneSize + (i / LaneSize) * LaneSize + Size); - // Flip the vector, and blend the results which should now be in-lane. The - // VPERM2X128 mask uses the low 2 bits for the low source and bits 4 and - // 5 for the high source. The value 3 selects the high half of source 2 and - // the value 2 selects the low half of source 2. We only use source 2 to - // allow folding it into a memory operand. - unsigned PERMMask = 3 | 2 << 4; - SDValue Flipped = DAG.getNode(X86ISD::VPERM2X128, DL, VT, DAG.getUNDEF(VT), - V1, DAG.getConstant(PERMMask, DL, MVT::i8)); + // Flip the vector, and blend the results which should now be in-lane. + MVT PVT = VT.isFloatingPoint() ? MVT::v4f64 : MVT::v4i64; + SDValue Flipped = DAG.getBitcast(PVT, V1); + Flipped = DAG.getVectorShuffle(PVT, DL, Flipped, DAG.getUNDEF(PVT), + { 2, 3, 0, 1 }); + Flipped = DAG.getBitcast(VT, Flipped); return DAG.getVectorShuffle(VT, DL, V1, Flipped, FlippedBlendMask); } @@ -12149,6 +12621,10 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, const APInt &Zeroable, const X86Subtarget &Subtarget, SelectionDAG &DAG) { + // With AVX2, use VPERMQ/VPERMPD for unary shuffles to allow memory folding. + if (Subtarget.hasAVX2() && V2.isUndef()) + return SDValue(); + SmallVector<int, 4> WidenedMask; if (!canWidenShuffleElements(Mask, WidenedMask)) return SDValue(); @@ -12162,19 +12638,16 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, Zeroable, Subtarget, DAG)) return Blend; - bool IsV1Zero = ISD::isBuildVectorAllZeros(V1.getNode()); - bool IsV2Zero = ISD::isBuildVectorAllZeros(V2.getNode()); + bool IsLowZero = (Zeroable & 0x3) == 0x3; + bool IsHighZero = (Zeroable & 0xc) == 0xc; // If either input operand is a zero vector, use VPERM2X128 because its mask // allows us to replace the zero input with an implicit zero. - if (!IsV1Zero && !IsV2Zero) { + if (!IsLowZero && !IsHighZero) { // Check for patterns which can be matched with a single insert of a 128-bit // subvector. bool OnlyUsesV1 = isShuffleEquivalent(V1, V2, Mask, {0, 1, 0, 1}); if (OnlyUsesV1 || isShuffleEquivalent(V1, V2, Mask, {0, 1, 4, 5})) { - // With AVX2, use VPERMQ/VPERMPD to allow memory folding. - if (Subtarget.hasAVX2() && V2.isUndef()) - return SDValue(); // With AVX1, use vperm2f128 (below) to allow load folding. Otherwise, // this will likely become vinsertf128 which can't fold a 256-bit memop. 
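// Illustration, not part of the patch: the "flip the halves, then blend
// in-lane" strategy used above, for one v8f32 case (assumes <immintrin.h>).
// The patch now expresses the flip as a <2,3,0,1> v4f64/v4i64 shuffle rather
// than a hard-coded VPERM2X128, but the effect is the same half swap.
#include <immintrin.h>

// Single-input mask <4,5,2,3, 0,1,6,7>: every element is in-lane once the
// two 128-bit halves have been swapped.
static __m256 flip_then_blend(__m256 v) {
  __m256 flipped = _mm256_permute2f128_ps(v, v, 0x01); // {v4..v7, v0..v3}
  return _mm256_blend_ps(v, flipped, 0x33);            // take flipped in 0,1,4,5
}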
@@ -12189,6 +12662,16 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LoV, HiV); } } + + // Try to use SHUF128 if possible. + if (Subtarget.hasVLX()) { + if (WidenedMask[0] < 2 && WidenedMask[1] >= 2) { + unsigned PermMask = ((WidenedMask[0] % 2) << 0) | + ((WidenedMask[1] % 2) << 1); + return DAG.getNode(X86ISD::SHUF128, DL, VT, V1, V2, + DAG.getConstant(PermMask, DL, MVT::i8)); + } + } } // Otherwise form a 128-bit permutation. After accounting for undefs, @@ -12204,30 +12687,17 @@ static SDValue lowerV2X128VectorShuffle(const SDLoc &DL, MVT VT, SDValue V1, // [6] - ignore // [7] - zero high half of destination - int MaskLO = WidenedMask[0] < 0 ? 0 : WidenedMask[0]; - int MaskHI = WidenedMask[1] < 0 ? 0 : WidenedMask[1]; + assert(WidenedMask[0] >= 0 && WidenedMask[1] >= 0 && "Undef half?"); - unsigned PermMask = MaskLO | (MaskHI << 4); + unsigned PermMask = 0; + PermMask |= IsLowZero ? 0x08 : (WidenedMask[0] << 0); + PermMask |= IsHighZero ? 0x80 : (WidenedMask[1] << 4); - // If either input is a zero vector, replace it with an undef input. - // Shuffle mask values < 4 are selecting elements of V1. - // Shuffle mask values >= 4 are selecting elements of V2. - // Adjust each half of the permute mask by clearing the half that was - // selecting the zero vector and setting the zero mask bit. - if (IsV1Zero) { + // Check the immediate mask and replace unused sources with undef. + if ((PermMask & 0x0a) != 0x00 && (PermMask & 0xa0) != 0x00) V1 = DAG.getUNDEF(VT); - if (MaskLO < 2) - PermMask = (PermMask & 0xf0) | 0x08; - if (MaskHI < 2) - PermMask = (PermMask & 0x0f) | 0x80; - } - if (IsV2Zero) { + if ((PermMask & 0x0a) != 0x02 && (PermMask & 0xa0) != 0x20) V2 = DAG.getUNDEF(VT); - if (MaskLO >= 2) - PermMask = (PermMask & 0xf0) | 0x08; - if (MaskHI >= 2) - PermMask = (PermMask & 0x0f) | 0x80; - } return DAG.getNode(X86ISD::VPERM2X128, DL, VT, V1, V2, DAG.getConstant(PermMask, DL, MVT::i8)); @@ -12311,7 +12781,7 @@ static SDValue lowerVectorShuffleByMerging128BitLanes( return DAG.getVectorShuffle(VT, DL, LaneShuffle, DAG.getUNDEF(VT), NewMask); } -/// Lower shuffles where an entire half of a 256-bit vector is UNDEF. +/// Lower shuffles where an entire half of a 256 or 512-bit vector is UNDEF. /// This allows for fast cases such as subvector extraction/insertion /// or shuffling smaller vector types which can lower more efficiently. static SDValue lowerVectorShuffleWithUndefHalf(const SDLoc &DL, MVT VT, @@ -12319,7 +12789,8 @@ static SDValue lowerVectorShuffleWithUndefHalf(const SDLoc &DL, MVT VT, ArrayRef<int> Mask, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - assert(VT.is256BitVector() && "Expected 256-bit vector"); + assert((VT.is256BitVector() || VT.is512BitVector()) && + "Expected 256-bit or 512-bit vector"); unsigned NumElts = VT.getVectorNumElements(); unsigned HalfNumElts = NumElts / 2; @@ -12415,6 +12886,10 @@ static SDValue lowerVectorShuffleWithUndefHalf(const SDLoc &DL, MVT VT, } } + // AVX512 - XXXXuuuu - always extract lowers. 
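// Illustration, not part of the patch: the VPERM2X128 immediate encoding used
// above (assumes <immintrin.h>). Bits [1:0] pick the source half for the low
// 128 bits of the result (0/1 = V1.lo/hi, 2/3 = V2.lo/hi) and bit 3 zeroes it
// instead; bits [5:4] and bit 7 do the same for the high half - hence the
// 0x08 and 0x80 values folded into PermMask.
#include <immintrin.h>

static __m256d v1lo_then_zero(__m256d v1, __m256d v2) {
  return _mm256_permute2f128_pd(v1, v2, 0x80);   // { v1.lo, zero }
}
static __m256d zero_then_v2hi(__m256d v1, __m256d v2) {
  return _mm256_permute2f128_pd(v1, v2, 0x38);   // { zero, v2.hi }
}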
+ if (VT.is512BitVector() && !(UndefUpper && NumUpperHalves == 0)) + return SDValue(); + auto GetHalfVector = [&](int HalfIdx) { if (HalfIdx < 0) return DAG.getUNDEF(HalfVT); @@ -12667,19 +13142,6 @@ static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT, DAG.getConstant(Immediate, DL, MVT::i8)); } -static SDValue lowerVectorShuffleWithPERMV(const SDLoc &DL, MVT VT, - ArrayRef<int> Mask, SDValue V1, - SDValue V2, SelectionDAG &DAG) { - MVT MaskEltVT = MVT::getIntegerVT(VT.getScalarSizeInBits()); - MVT MaskVecVT = MVT::getVectorVT(MaskEltVT, VT.getVectorNumElements()); - - SDValue MaskNode = getConstVector(Mask, MaskVecVT, DAG, DL, true); - if (V2.isUndef()) - return DAG.getNode(X86ISD::VPERMV, DL, VT, MaskNode, V1); - - return DAG.getNode(X86ISD::VPERMV3, DL, VT, V1, MaskNode, V2); -} - /// \brief Handle lowering of 4-lane 64-bit floating point shuffles. /// /// Also ends up handling lowering of 4-lane 64-bit integer shuffles when AVX2 @@ -12729,7 +13191,7 @@ static SDValue lowerV4F64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Otherwise, fall back. return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v4f64, V1, V2, Mask, - DAG); + DAG, Subtarget); } // Use dedicated unpack instructions for masks that match their pattern. @@ -12810,7 +13272,7 @@ static SDValue lowerV4I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, SmallVector<int, 2> RepeatedMask; if (is128BitLaneRepeatedShuffleMask(MVT::v4i64, Mask, RepeatedMask)) { SmallVector<int, 4> PSHUFDMask; - scaleShuffleMask(2, RepeatedMask, PSHUFDMask); + scaleShuffleMask<int>(2, RepeatedMask, PSHUFDMask); return DAG.getBitcast( MVT::v4i64, DAG.getNode(X86ISD::PSHUFD, DL, MVT::v8i32, @@ -12932,7 +13394,7 @@ static SDValue lowerV8F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Otherwise, fall back. return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v8f32, V1, V2, Mask, - DAG); + DAG, Subtarget); } // Try to simplify this by merging 128-bit lanes to enable a lane-based @@ -13112,6 +13574,11 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, lowerVectorShuffleWithUNPCK(DL, MVT::v16i16, Mask, V1, V2, DAG)) return V; + // Use dedicated pack instructions for masks that match their pattern. + if (SDValue V = lowerVectorShuffleWithPACK(DL, MVT::v16i16, Mask, V1, V2, DAG, + Subtarget)) + return V; + // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v16i16, V1, V2, Mask, Zeroable, Subtarget, DAG)) @@ -13133,7 +13600,7 @@ static SDValue lowerV16I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // element types. if (is128BitLaneCrossingShuffleMask(MVT::v16i16, Mask)) return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v16i16, V1, V2, - Mask, DAG); + Mask, DAG, Subtarget); SmallVector<int, 8> RepeatedMask; if (is128BitLaneRepeatedShuffleMask(MVT::v16i16, Mask, RepeatedMask)) { @@ -13198,6 +13665,11 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, lowerVectorShuffleWithUNPCK(DL, MVT::v32i8, Mask, V1, V2, DAG)) return V; + // Use dedicated pack instructions for masks that match their pattern. + if (SDValue V = lowerVectorShuffleWithPACK(DL, MVT::v32i8, Mask, V1, V2, DAG, + Subtarget)) + return V; + // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v32i8, V1, V2, Mask, Zeroable, Subtarget, DAG)) @@ -13218,12 +13690,16 @@ static SDValue lowerV32I8VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // element types. 
if (V2.isUndef() && is128BitLaneCrossingShuffleMask(MVT::v32i8, Mask)) return lowerVectorShuffleAsLanePermuteAndBlend(DL, MVT::v32i8, V1, V2, Mask, - DAG); + DAG, Subtarget); if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB( DL, MVT::v32i8, Mask, V1, V2, Zeroable, Subtarget, DAG)) return PSHUFB; + // AVX512VBMIVL can lower to VPERMB. + if (Subtarget.hasVBMI() && Subtarget.hasVLX()) + return lowerVectorShuffleWithPERMV(DL, MVT::v32i8, Mask, V1, V2, DAG); + // Try to simplify this by merging 128-bit lanes to enable a lane-based // shuffle. if (SDValue Result = lowerVectorShuffleByMerging128BitLanes( @@ -13485,6 +13961,15 @@ static SDValue lowerV16F32VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, // Otherwise, fall back to a SHUFPS sequence. return lowerVectorShuffleWithSHUFPS(DL, MVT::v16f32, RepeatedMask, V1, V2, DAG); } + + // If we have a single input shuffle with different shuffle patterns in the + // 128-bit lanes and don't lane cross, use variable mask VPERMILPS. + if (V2.isUndef() && + !is128BitLaneCrossingShuffleMask(MVT::v16f32, Mask)) { + SDValue VPermMask = getConstVector(Mask, MVT::v16i32, DAG, DL, true); + return DAG.getNode(X86ISD::VPERMILPV, DL, MVT::v16f32, V1, VPermMask); + } + // If we have AVX512F support, we can use VEXPAND. if (SDValue V = lowerVectorShuffleToEXPAND(DL, MVT::v16f32, Zeroable, Mask, V1, V2, DAG, Subtarget)) @@ -13503,10 +13988,6 @@ static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, assert(V2.getSimpleValueType() == MVT::v8i64 && "Bad operand type!"); assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!"); - if (SDValue Shuf128 = - lowerV4X128VectorShuffle(DL, MVT::v8i64, Mask, V1, V2, DAG)) - return Shuf128; - if (V2.isUndef()) { // When the shuffle is mirrored between the 128-bit lanes of the unit, we // can use lower latency instructions that will operate on all four @@ -13514,7 +13995,7 @@ static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, SmallVector<int, 2> Repeated128Mask; if (is128BitLaneRepeatedShuffleMask(MVT::v8i64, Mask, Repeated128Mask)) { SmallVector<int, 4> PSHUFDMask; - scaleShuffleMask(2, Repeated128Mask, PSHUFDMask); + scaleShuffleMask<int>(2, Repeated128Mask, PSHUFDMask); return DAG.getBitcast( MVT::v8i64, DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32, @@ -13528,6 +14009,10 @@ static SDValue lowerV8I64VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, getV4X86ShuffleImm8ForMask(Repeated256Mask, DL, DAG)); } + if (SDValue Shuf128 = + lowerV4X128VectorShuffle(DL, MVT::v8i64, Mask, V1, V2, DAG)) + return Shuf128; + // Try to use shift instructions. if (SDValue Shift = lowerVectorShuffleAsShift(DL, MVT::v8i64, V1, V2, Mask, Zeroable, Subtarget, DAG)) @@ -13677,6 +14162,10 @@ static SDValue lowerV32I16VectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, Zeroable, Subtarget, DAG)) return Blend; + if (SDValue PSHUFB = lowerVectorShuffleWithPSHUFB( + DL, MVT::v32i16, Mask, V1, V2, Zeroable, Subtarget, DAG)) + return PSHUFB; + return lowerVectorShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, DAG); } @@ -13758,6 +14247,11 @@ static SDValue lower512BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, DL, VT, V1, V2, Mask, Zeroable, Subtarget, DAG)) return Insertion; + // Handle special cases where the lower or upper half is UNDEF. + if (SDValue V = + lowerVectorShuffleWithUndefHalf(DL, VT, V1, V2, Mask, Subtarget, DAG)) + return V; + // Check for being able to broadcast a single element. 
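// Illustration, not part of the patch: the AVX512VBMI path added above for
// v32i8 (assumes <immintrin.h> and VBMI + VL). Any byte shuffle, including a
// lane-crossing one, becomes a single variable byte permute instead of a
// PSHUFB-and-blend sequence.
#include <immintrin.h>

static __m256i permute_bytes(__m256i idx, __m256i v) {
  return _mm256_permutexvar_epi8(idx, v);        // VPERMB
}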
if (SDValue Broadcast = lowerVectorShuffleAsBroadcast(DL, VT, V1, V2, Mask, Subtarget, DAG)) @@ -13807,7 +14301,9 @@ static SDValue lower1BitVectorShuffle(const SDLoc &DL, ArrayRef<int> Mask, ExtVT = MVT::v4i32; break; case MVT::v8i1: - ExtVT = MVT::v8i64; // Take 512-bit type, more shuffles on KNL + // Take 512-bit type, more shuffles on KNL. If we have VLX use a 256-bit + // shuffle. + ExtVT = Subtarget.hasVLX() ? MVT::v8i32 : MVT::v8i64; break; case MVT::v16i1: ExtVT = MVT::v16i32; @@ -14046,16 +14542,16 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const { ISD::isBuildVectorOfConstantSDNodes(Op.getOperand(2).getNode())) return SDValue(); - // If this VSELECT has a vector if i1 as a mask, it will be directly matched - // with patterns on the mask registers on AVX-512. - if (Op->getOperand(0).getValueType().getScalarSizeInBits() == 1) - return Op; - // Try to lower this to a blend-style vector shuffle. This can handle all // constant condition cases. if (SDValue BlendOp = lowerVSELECTtoVectorShuffle(Op, Subtarget, DAG)) return BlendOp; + // If this VSELECT has a vector if i1 as a mask, it will be directly matched + // with patterns on the mask registers on AVX-512. + if (Op->getOperand(0).getValueType().getScalarSizeInBits() == 1) + return Op; + // Variable blends are only legal from SSE4.1 onward. if (!Subtarget.hasSSE41()) return SDValue(); @@ -14097,10 +14593,6 @@ SDValue X86TargetLowering::LowerVSELECT(SDValue Op, SelectionDAG &DAG) const { case MVT::v8i16: case MVT::v16i16: - // AVX-512 BWI and VLX features support VSELECT with i16 elements. - if (Subtarget.hasBWI() && Subtarget.hasVLX()) - return Op; - // FIXME: We should custom lower this by fixing the condition and using i8 // blends. return SDValue(); @@ -14117,9 +14609,7 @@ static SDValue LowerEXTRACT_VECTOR_ELT_SSE4(SDValue Op, SelectionDAG &DAG) { if (VT.getSizeInBits() == 8) { SDValue Extract = DAG.getNode(X86ISD::PEXTRB, dl, MVT::i32, Op.getOperand(0), Op.getOperand(1)); - SDValue Assert = DAG.getNode(ISD::AssertZext, dl, MVT::i32, Extract, - DAG.getValueType(VT)); - return DAG.getNode(ISD::TRUNCATE, dl, VT, Assert); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Extract); } if (VT == MVT::f32) { @@ -14153,8 +14643,8 @@ static SDValue LowerEXTRACT_VECTOR_ELT_SSE4(SDValue Op, SelectionDAG &DAG) { /// Extract one bit from mask vector, like v16i1 or v8i1. /// AVX-512 feature. -SDValue -X86TargetLowering::ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG) const { +static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDValue Vec = Op.getOperand(0); SDLoc dl(Vec); MVT VecVT = Vec.getSimpleValueType(); @@ -14170,31 +14660,42 @@ X86TargetLowering::ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG) const unsigned NumElts = VecVT.getVectorNumElements(); // Extending v8i1/v16i1 to 512-bit get better performance on KNL // than extending to 128/256bit. - unsigned VecSize = (NumElts <= 4 ? 128 : 512); - MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(VecSize/NumElts), NumElts); - SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVT, Vec); - SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - ExtVT.getVectorElementType(), Ext, Idx); + MVT ExtEltVT = (NumElts <= 8) ? 
MVT::getIntegerVT(128 / NumElts) : MVT::i8; + MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts); + SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec); + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ExtEltVT, Ext, Idx); return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); } + // Canonicalize result type to MVT::i32. + if (EltVT != MVT::i32) { + SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, + Vec, Idx); + return DAG.getAnyExtOrTrunc(Extract, dl, EltVT); + } + unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); + + // Extracts from element 0 are always allowed. + if (IdxVal == 0) + return Op; + + // If the kshift instructions of the correct width aren't natively supported + // then we need to promote the vector to the native size to get the correct + // zeroing behavior. if ((!Subtarget.hasDQI() && (VecVT.getVectorNumElements() == 8)) || (VecVT.getVectorNumElements() < 8)) { - // Use kshiftlw/rw instruction. VecVT = MVT::v16i1; Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, VecVT, DAG.getUNDEF(VecVT), Vec, DAG.getIntPtrConstant(0, dl)); } - unsigned MaxSift = VecVT.getVectorNumElements() - 1; - if (MaxSift - IdxVal) - Vec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Vec, - DAG.getConstant(MaxSift - IdxVal, dl, MVT::i8)); + + // Use kshiftr instruction to move to the lower element. Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, - DAG.getConstant(MaxSift, dl, MVT::i8)); - return DAG.getNode(X86ISD::VEXTRACT, dl, Op.getSimpleValueType(), Vec, + DAG.getConstant(IdxVal, dl, MVT::i8)); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, Vec, DAG.getIntPtrConstant(0, dl)); } @@ -14207,7 +14708,7 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, SDValue Idx = Op.getOperand(1); if (VecVT.getVectorElementType() == MVT::i1) - return ExtractBitFromMaskVector(Op, DAG); + return ExtractBitFromMaskVector(Op, DAG, Subtarget); if (!isa<ConstantSDNode>(Idx)) { // Its more profitable to go through memory (1 cycles throughput) @@ -14278,9 +14779,7 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, // Transform it so it match pextrw which produces a 32-bit result. SDValue Extract = DAG.getNode(X86ISD::PEXTRW, dl, MVT::i32, Op.getOperand(0), Op.getOperand(1)); - SDValue Assert = DAG.getNode(ISD::AssertZext, dl, MVT::i32, Extract, - DAG.getValueType(VT)); - return DAG.getNode(ISD::TRUNCATE, dl, VT, Assert); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Extract); } if (Subtarget.hasSSE41()) @@ -14347,8 +14846,8 @@ X86TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, /// Insert one bit to mask vector, like v16i1 or v8i1. /// AVX-512 feature. -SDValue -X86TargetLowering::InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG) const { +static SDValue InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDLoc dl(Op); SDValue Vec = Op.getOperand(0); SDValue Elt = Op.getOperand(1); @@ -14358,19 +14857,34 @@ X86TargetLowering::InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG) const { if (!isa<ConstantSDNode>(Idx)) { // Non constant index. Extend source and destination, // insert element and then truncate the result. - MVT ExtVecVT = (VecVT == MVT::v8i1 ? MVT::v8i64 : MVT::v16i32); - MVT ExtEltVT = (VecVT == MVT::v8i1 ? MVT::i64 : MVT::i32); + unsigned NumElts = VecVT.getVectorNumElements(); + MVT ExtEltVT = (NumElts <= 8) ? 
MVT::getIntegerVT(128 / NumElts) : MVT::i8; + MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts); SDValue ExtOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtVecVT, - DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVecVT, Vec), - DAG.getNode(ISD::ZERO_EXTEND, dl, ExtEltVT, Elt), Idx); + DAG.getNode(ISD::SIGN_EXTEND, dl, ExtVecVT, Vec), + DAG.getNode(ISD::SIGN_EXTEND, dl, ExtEltVT, Elt), Idx); return DAG.getNode(ISD::TRUNCATE, dl, VecVT, ExtOp); } unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - SDValue EltInVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, Elt); unsigned NumElems = VecVT.getVectorNumElements(); - if(Vec.isUndef()) { + // If the kshift instructions of the correct width aren't natively supported + // then we need to promote the vector to the native size to get the correct + // zeroing behavior. + if ((!Subtarget.hasDQI() && NumElems == 8) || (NumElems < 8)) { + // Need to promote to v16i1, do the insert, then extract back. + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1, + DAG.getUNDEF(MVT::v16i1), Vec, + DAG.getIntPtrConstant(0, dl)); + Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, MVT::v16i1, Vec, Elt, Idx); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VecVT, Op, + DAG.getIntPtrConstant(0, dl)); + } + + SDValue EltInVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VecVT, Elt); + + if (Vec.isUndef()) { if (IdxVal) EltInVec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, EltInVec, DAG.getConstant(IdxVal, dl, MVT::i8)); @@ -14393,25 +14907,33 @@ X86TargetLowering::InsertBitToMaskVector(SDValue Op, SelectionDAG &DAG) const { return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec); } // Insertion of one bit into last position - if (IdxVal == NumElems -1) { + if (IdxVal == NumElems - 1) { // Move the bit to the last position inside the vector. EltInVec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, EltInVec, DAG.getConstant(IdxVal, dl, MVT::i8)); // Clean the last bit in the source vector. Vec = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Vec, - DAG.getConstant(1, dl, MVT::i8)); + DAG.getConstant(1, dl, MVT::i8)); Vec = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, - DAG.getConstant(1 , dl, MVT::i8)); + DAG.getConstant(1 , dl, MVT::i8)); return DAG.getNode(ISD::OR, dl, VecVT, Vec, EltInVec); } - // Use shuffle to insert element. - SmallVector<int, 64> MaskVec(NumElems); - for (unsigned i = 0; i != NumElems; ++i) - MaskVec[i] = (i == IdxVal) ? NumElems : i; - - return DAG.getVectorShuffle(VecVT, dl, Vec, EltInVec, MaskVec); + // Move the current value of the bit to be replace to bit 0. + SDValue Merged = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Vec, + DAG.getConstant(IdxVal, dl, MVT::i8)); + // Xor with the new bit. + Merged = DAG.getNode(ISD::XOR, dl, VecVT, Merged, EltInVec); + // Shift to MSB, filling bottom bits with 0. + Merged = DAG.getNode(X86ISD::KSHIFTL, dl, VecVT, Merged, + DAG.getConstant(NumElems - 1, dl, MVT::i8)); + // Shift to the final position, filling upper bits with 0. + Merged = DAG.getNode(X86ISD::KSHIFTR, dl, VecVT, Merged, + DAG.getConstant(NumElems - 1 - IdxVal, dl, MVT::i8)); + // Xor with original vector to cancel out the original bit value that's still + // present. 
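// Illustration, not part of the patch: a scalar model of the KSHIFT/KXOR
// bit-insertion sequence constructed above (the final XOR with the original
// vector follows just below), for a 16-bit mask register. Assumes <cstdint>.
#include <cstdint>

static uint16_t insert_mask_bit(uint16_t vec, unsigned idx, bool bit) {
  uint16_t m = (uint16_t)((vec >> idx) ^ (bit ? 1 : 0)); // bit 0 = old ^ new
  m = (uint16_t)(m << 15);                               // isolate it in the MSB
  m = (uint16_t)(m >> (15 - idx));                       // move it to position idx
  return (uint16_t)(m ^ vec);        // old ^ (old ^ new) == new, rest unchanged
}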
+ return DAG.getNode(ISD::XOR, dl, VecVT, Merged, Vec); } SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, @@ -14421,7 +14943,7 @@ SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, unsigned NumElts = VT.getVectorNumElements(); if (EltVT == MVT::i1) - return InsertBitToMaskVector(Op, DAG); + return InsertBitToMaskVector(Op, DAG, Subtarget); SDLoc dl(Op); SDValue N0 = Op.getOperand(0); @@ -14444,7 +14966,7 @@ SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, for (unsigned i = 0; i != NumElts; ++i) BlendMask.push_back(i == IdxVal ? i + NumElts : i); SDValue CstVector = IsZeroElt ? getZeroVector(VT, Subtarget, DAG, dl) - : DAG.getConstant(-1, dl, VT); + : getOnesVector(VT, DAG, dl); return DAG.getVectorShuffle(VT, dl, N0, CstVector, BlendMask); } @@ -14513,7 +15035,7 @@ SDValue X86TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, // Bits [3:0] of the constant are the zero mask. The DAG Combiner may // combine either bitwise AND or insert of float 0.0 to set these bits. - bool MinSize = DAG.getMachineFunction().getFunction()->optForMinSize(); + bool MinSize = DAG.getMachineFunction().getFunction().optForMinSize(); if (IdxVal == 0 && (!MinSize || !MayFoldLoad(N1))) { // If this is an insertion of 32-bits into the low 32-bits of // a vector, we prefer to generate a blend with immediate rather @@ -14574,56 +15096,50 @@ static SDValue LowerSCALAR_TO_VECTOR(SDValue Op, const X86Subtarget &Subtarget, OpVT, DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v4i32, AnyExt)); } -// Lower a node with an EXTRACT_SUBVECTOR opcode. This may result in -// a simple subregister reference or explicit instructions to grab -// upper bits of a vector. +// Lower a node with an INSERT_SUBVECTOR opcode. This may result in a +// simple superregister reference or explicit instructions to insert +// the upper bits of a vector. +static SDValue LowerINSERT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + assert(Op.getSimpleValueType().getVectorElementType() == MVT::i1); + + return insert1BitVector(Op, DAG, Subtarget); +} + static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - assert(Subtarget.hasAVX() && "EXTRACT_SUBVECTOR requires AVX"); + assert(Op.getSimpleValueType().getVectorElementType() == MVT::i1 && + "Only vXi1 extract_subvectors need custom lowering"); SDLoc dl(Op); - SDValue In = Op.getOperand(0); + SDValue Vec = Op.getOperand(0); SDValue Idx = Op.getOperand(1); - unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - MVT ResVT = Op.getSimpleValueType(); - // When v1i1 is legal a scalarization of a vselect with a vXi1 Cond - // would result with: v1i1 = extract_subvector(vXi1, idx). - // Lower these into extract_vector_elt which is already selectable. 
- if (ResVT == MVT::v1i1) { - assert(Subtarget.hasAVX512() && - "Boolean EXTRACT_SUBVECTOR requires AVX512"); - - MVT EltVT = ResVT.getVectorElementType(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - MVT LegalVT = - (TLI.getTypeToTransformTo(*DAG.getContext(), EltVT)).getSimpleVT(); - SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LegalVT, In, Idx); - return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, ResVT, Res); - } + if (!isa<ConstantSDNode>(Idx)) + return SDValue(); - assert((In.getSimpleValueType().is256BitVector() || - In.getSimpleValueType().is512BitVector()) && - "Can only extract from 256-bit or 512-bit vectors"); + unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); + if (IdxVal == 0) // the operation is legal + return Op; - // If the input is a buildvector just emit a smaller one. - unsigned ElemsPerChunk = ResVT.getVectorNumElements(); - if (In.getOpcode() == ISD::BUILD_VECTOR) - return DAG.getBuildVector( - ResVT, dl, makeArrayRef(In->op_begin() + IdxVal, ElemsPerChunk)); + MVT VecVT = Vec.getSimpleValueType(); + unsigned NumElems = VecVT.getVectorNumElements(); - // Everything else is legal. - return Op; -} + // Extend to natively supported kshift. + MVT WideVecVT = VecVT; + if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) { + WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1; + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT, + DAG.getUNDEF(WideVecVT), Vec, + DAG.getIntPtrConstant(0, dl)); + } -// Lower a node with an INSERT_SUBVECTOR opcode. This may result in a -// simple superregister reference or explicit instructions to insert -// the upper bits of a vector. -static SDValue LowerINSERT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - assert(Op.getSimpleValueType().getVectorElementType() == MVT::i1); + // Shift to the LSB. + Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec, + DAG.getConstant(IdxVal, dl, MVT::i8)); - return insert1BitVector(Op, DAG, Subtarget); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op.getValueType(), Vec, + DAG.getIntPtrConstant(0, dl)); } // Returns the appropriate wrapper opcode for a global reference. @@ -14696,7 +15212,7 @@ X86TargetLowering::LowerExternalSymbol(SDValue Op, SelectionDAG &DAG) const { // In PIC mode (unless we're in RIPRel PIC mode) we add an offset to the // global base reg. - const Module *Mod = DAG.getMachineFunction().getFunction()->getParent(); + const Module *Mod = DAG.getMachineFunction().getFunction().getParent(); unsigned char OpFlag = Subtarget.classifyGlobalReference(nullptr, *Mod); auto PtrVT = getPointerTy(DAG.getDataLayout()); @@ -15145,20 +15661,19 @@ SDValue X86TargetLowering::LowerSINT_TO_FP(SDValue Op, MVT VT = Op.getSimpleValueType(); SDLoc dl(Op); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (SrcVT.isVector()) { if (SrcVT == MVT::v2i32 && VT == MVT::v2f64) { return DAG.getNode(X86ISD::CVTSI2P, dl, VT, DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, Src, DAG.getUNDEF(SrcVT))); } - if (SrcVT.getVectorElementType() == MVT::i1) { - if (SrcVT == MVT::v2i1 && TLI.isTypeLegal(SrcVT)) - return DAG.getNode(ISD::SINT_TO_FP, dl, Op.getValueType(), - DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v2i64, Src)); - MVT IntegerVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements()); - return DAG.getNode(ISD::SINT_TO_FP, dl, Op.getValueType(), - DAG.getNode(ISD::SIGN_EXTEND, dl, IntegerVT, Src)); + if (SrcVT == MVT::v2i1) { + // For v2i1, we need to widen to v4i1 first. 
+ assert(VT == MVT::v2f64 && "Unexpected type"); + Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Src, + DAG.getUNDEF(MVT::v2i1)); + return DAG.getNode(X86ISD::CVTSI2P, dl, Op.getValueType(), + DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Src)); } return SDValue(); } @@ -15255,8 +15770,8 @@ SDValue X86TargetLowering::BuildFILD(SDValue Op, EVT SrcVT, SDValue Chain, } /// 64-bit unsigned integer to double expansion. -SDValue X86TargetLowering::LowerUINT_TO_FP_i64(SDValue Op, - SelectionDAG &DAG) const { +static SDValue LowerUINT_TO_FP_i64(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { // This algorithm is not obvious. Here it is what we're trying to output: /* movq %rax, %xmm0 @@ -15276,7 +15791,7 @@ SDValue X86TargetLowering::LowerUINT_TO_FP_i64(SDValue Op, // Build some magic constants. static const uint32_t CV0[] = { 0x43300000, 0x45300000, 0, 0 }; Constant *C0 = ConstantDataVector::get(*Context, CV0); - auto PtrVT = getPointerTy(DAG.getDataLayout()); + auto PtrVT = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout()); SDValue CPIdx0 = DAG.getConstantPool(C0, PtrVT, 16); SmallVector<Constant*,2> CV1; @@ -15323,8 +15838,8 @@ SDValue X86TargetLowering::LowerUINT_TO_FP_i64(SDValue Op, } /// 32-bit unsigned integer to float expansion. -SDValue X86TargetLowering::LowerUINT_TO_FP_i32(SDValue Op, - SelectionDAG &DAG) const { +static SDValue LowerUINT_TO_FP_i32(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDLoc dl(Op); // FP constant to bias correct the final result. SDValue Bias = DAG.getConstantFP(BitsToDouble(0x4330000000000000ULL), dl, @@ -15357,16 +15872,7 @@ SDValue X86TargetLowering::LowerUINT_TO_FP_i32(SDValue Op, SDValue Sub = DAG.getNode(ISD::FSUB, dl, MVT::f64, Or, Bias); // Handle final rounding. - MVT DestVT = Op.getSimpleValueType(); - - if (DestVT.bitsLT(MVT::f64)) - return DAG.getNode(ISD::FP_ROUND, dl, DestVT, Sub, - DAG.getIntPtrConstant(0, dl)); - if (DestVT.bitsGT(MVT::f64)) - return DAG.getNode(ISD::FP_EXTEND, dl, DestVT, Sub); - - // Handle final rounding. - return Sub; + return DAG.getFPExtendOrRound(Sub, dl, Op.getSimpleValueType()); } static SDValue lowerUINT_TO_FP_v2i32(SDValue Op, SelectionDAG &DAG, @@ -15498,42 +16004,30 @@ static SDValue lowerUINT_TO_FP_vXi32(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::FADD, DL, VecFloatVT, LowBitcast, FHigh); } -SDValue X86TargetLowering::lowerUINT_TO_FP_vec(SDValue Op, - SelectionDAG &DAG) const { +static SDValue lowerUINT_TO_FP_vec(SDValue Op, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDValue N0 = Op.getOperand(0); MVT SrcVT = N0.getSimpleValueType(); SDLoc dl(Op); - if (SrcVT.getVectorElementType() == MVT::i1) { - if (SrcVT == MVT::v2i1) - return DAG.getNode(ISD::UINT_TO_FP, dl, Op.getValueType(), - DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v2i64, N0)); - MVT IntegerVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements()); - return DAG.getNode(ISD::UINT_TO_FP, dl, Op.getValueType(), - DAG.getNode(ISD::ZERO_EXTEND, dl, IntegerVT, N0)); + if (SrcVT == MVT::v2i1) { + // For v2i1, we need to widen to v4i1 first. 
+ assert(Op.getValueType() == MVT::v2f64 && "Unexpected type"); + N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, N0, + DAG.getUNDEF(MVT::v2i1)); + return DAG.getNode(X86ISD::CVTUI2P, dl, MVT::v2f64, + DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v4i32, N0)); } switch (SrcVT.SimpleTy) { default: llvm_unreachable("Custom UINT_TO_FP is not supported!"); - case MVT::v4i8: - case MVT::v4i16: - case MVT::v8i8: - case MVT::v8i16: { - MVT NVT = MVT::getVectorVT(MVT::i32, SrcVT.getVectorNumElements()); - return DAG.getNode(ISD::SINT_TO_FP, dl, Op.getValueType(), - DAG.getNode(ISD::ZERO_EXTEND, dl, NVT, N0)); - } case MVT::v2i32: return lowerUINT_TO_FP_v2i32(Op, DAG, Subtarget, dl); case MVT::v4i32: case MVT::v8i32: + assert(!Subtarget.hasAVX512()); return lowerUINT_TO_FP_vXi32(Op, DAG, Subtarget); - case MVT::v16i8: - case MVT::v16i16: - assert(Subtarget.hasAVX512()); - return DAG.getNode(ISD::UINT_TO_FP, dl, Op.getValueType(), - DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v16i32, N0)); } } @@ -15543,14 +16037,8 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op, SDLoc dl(Op); auto PtrVT = getPointerTy(DAG.getDataLayout()); - // Since UINT_TO_FP is legal (it's marked custom), dag combiner won't - // optimize it to a SINT_TO_FP when the sign bit is known zero. Perform - // the optimization here. - if (DAG.SignBitIsZero(N0)) - return DAG.getNode(ISD::SINT_TO_FP, dl, Op.getValueType(), N0); - if (Op.getSimpleValueType().isVector()) - return lowerUINT_TO_FP_vec(Op, DAG); + return lowerUINT_TO_FP_vec(Op, DAG, Subtarget); MVT SrcVT = N0.getSimpleValueType(); MVT DstVT = Op.getSimpleValueType(); @@ -15563,9 +16051,9 @@ SDValue X86TargetLowering::LowerUINT_TO_FP(SDValue Op, } if (SrcVT == MVT::i64 && DstVT == MVT::f64 && X86ScalarSSEf64) - return LowerUINT_TO_FP_i64(Op, DAG); + return LowerUINT_TO_FP_i64(Op, DAG, Subtarget); if (SrcVT == MVT::i32 && X86ScalarSSEf64) - return LowerUINT_TO_FP_i32(Op, DAG); + return LowerUINT_TO_FP_i32(Op, DAG, Subtarget); if (Subtarget.is64Bit() && SrcVT == MVT::i64 && DstVT == MVT::f32) return SDValue(); @@ -15827,8 +16315,18 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG, MVT InVT = In.getSimpleValueType(); SDLoc dl(Op); - if (VT.is512BitVector() || InVT.getVectorElementType() == MVT::i1) - return DAG.getNode(ISD::ZERO_EXTEND, dl, VT, In); + if ((VT != MVT::v4i64 || InVT != MVT::v4i32) && + (VT != MVT::v8i32 || InVT != MVT::v8i16) && + (VT != MVT::v16i16 || InVT != MVT::v16i8) && + (VT != MVT::v8i64 || InVT != MVT::v8i32) && + (VT != MVT::v8i64 || InVT != MVT::v8i16) && + (VT != MVT::v16i32 || InVT != MVT::v16i16) && + (VT != MVT::v16i32 || InVT != MVT::v16i8) && + (VT != MVT::v32i16 || InVT != MVT::v32i8)) + return SDValue(); + + if (Subtarget.hasInt256()) + return DAG.getNode(X86ISD::VZEXT, dl, VT, In); // Optimize vectors in AVX mode: // @@ -15843,14 +16341,6 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG, // Concat upper and lower parts. 
// - if (((VT != MVT::v16i16) || (InVT != MVT::v16i8)) && - ((VT != MVT::v8i32) || (InVT != MVT::v8i16)) && - ((VT != MVT::v4i64) || (InVT != MVT::v4i32))) - return SDValue(); - - if (Subtarget.hasInt256()) - return DAG.getNode(X86ISD::VZEXT, dl, VT, In); - SDValue ZeroVec = getZeroVector(InVT, Subtarget, DAG, dl); SDValue Undef = DAG.getUNDEF(InVT); bool NeedZero = Op.getOpcode() == ISD::ZERO_EXTEND; @@ -15866,74 +16356,81 @@ static SDValue LowerAVXExtend(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi); } -static SDValue LowerZERO_EXTEND_AVX512(SDValue Op, - const X86Subtarget &Subtarget, SelectionDAG &DAG) { +static SDValue LowerZERO_EXTEND_Mask(SDValue Op, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { MVT VT = Op->getSimpleValueType(0); SDValue In = Op->getOperand(0); MVT InVT = In.getSimpleValueType(); + assert(InVT.getVectorElementType() == MVT::i1 && "Unexpected input type!"); SDLoc DL(Op); unsigned NumElts = VT.getVectorNumElements(); - if (VT.is512BitVector() && InVT.getVectorElementType() != MVT::i1 && - (NumElts == 8 || NumElts == 16 || Subtarget.hasBWI())) - return DAG.getNode(X86ISD::VZEXT, DL, VT, In); + // Extend VT if the scalar type is v8/v16 and BWI is not supported. + MVT ExtVT = VT; + if (!Subtarget.hasBWI() && + (VT.getVectorElementType().getSizeInBits() <= 16)) + ExtVT = MVT::getVectorVT(MVT::i32, NumElts); - if (InVT.getVectorElementType() != MVT::i1) - return SDValue(); + // Widen to 512-bits if VLX is not supported. + MVT WideVT = ExtVT; + if (!ExtVT.is512BitVector() && !Subtarget.hasVLX()) { + NumElts *= 512 / ExtVT.getSizeInBits(); + InVT = MVT::getVectorVT(MVT::i1, NumElts); + In = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT, DAG.getUNDEF(InVT), + In, DAG.getIntPtrConstant(0, DL)); + WideVT = MVT::getVectorVT(ExtVT.getVectorElementType(), + NumElts); + } - // Extend VT if the target is 256 or 128bit vector and VLX is not supported. - MVT ExtVT = VT; - if (!VT.is512BitVector() && !Subtarget.hasVLX()) - ExtVT = MVT::getVectorVT(MVT::getIntegerVT(512/NumElts), NumElts); + SDValue One = DAG.getConstant(1, DL, WideVT); + SDValue Zero = getZeroVector(WideVT, Subtarget, DAG, DL); - SDValue One = - DAG.getConstant(APInt(ExtVT.getScalarSizeInBits(), 1), DL, ExtVT); - SDValue Zero = - DAG.getConstant(APInt::getNullValue(ExtVT.getScalarSizeInBits()), DL, ExtVT); + SDValue SelectedVal = DAG.getSelect(DL, WideVT, In, One, Zero); - SDValue SelectedVal = DAG.getSelect(DL, ExtVT, In, One, Zero); - if (VT == ExtVT) - return SelectedVal; - return DAG.getNode(X86ISD::VTRUNC, DL, VT, SelectedVal); -} + // Truncate if we had to extend i16/i8 above. + if (VT != ExtVT) { + WideVT = MVT::getVectorVT(VT.getVectorElementType(), NumElts); + SelectedVal = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SelectedVal); + } -static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - if (Subtarget.hasFp256()) - if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget)) - return Res; + // Extract back to 128/256-bit if we widened. 
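// Illustration, not part of the patch: the select(mask, 1, 0) built above is,
// in hardware terms, a zero-masked broadcast of the constant 1 (assumes
// <immintrin.h> and AVX512F).
#include <immintrin.h>

static __m512i zext_mask16_to_v16i32(__mmask16 k) {
  return _mm512_maskz_set1_epi32(k, 1);          // VPBROADCASTD zmm {k}{z}
}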
+ if (WideVT != VT) + SelectedVal = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, SelectedVal, + DAG.getIntPtrConstant(0, DL)); - return SDValue(); + return SelectedVal; } static SDValue LowerZERO_EXTEND(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - SDLoc DL(Op); - MVT VT = Op.getSimpleValueType(); SDValue In = Op.getOperand(0); MVT SVT = In.getSimpleValueType(); - if (VT.is512BitVector() || SVT.getVectorElementType() == MVT::i1) - return LowerZERO_EXTEND_AVX512(Op, Subtarget, DAG); + if (SVT.getVectorElementType() == MVT::i1) + return LowerZERO_EXTEND_Mask(Op, Subtarget, DAG); if (Subtarget.hasFp256()) if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget)) return Res; - assert(!VT.is256BitVector() || !SVT.is128BitVector() || - VT.getVectorNumElements() != SVT.getVectorNumElements()); + assert(!Op.getSimpleValueType().is256BitVector() || !SVT.is128BitVector() || + Op.getSimpleValueType().getVectorNumElements() != + SVT.getVectorNumElements()); return SDValue(); } -/// Helper to recursively truncate vector elements in half with PACKSS. -/// It makes use of the fact that vector comparison results will be all-zeros -/// or all-ones to use (vXi8 PACKSS(vYi16, vYi16)) instead of matching types. -/// AVX2 (Int256) sub-targets require extra shuffling as the PACKSS operates +/// Helper to recursively truncate vector elements in half with PACKSS/PACKUS. +/// It makes use of the fact that vectors with enough leading sign/zero bits +/// prevent the PACKSS/PACKUS from saturating the results. +/// AVX2 (Int256) sub-targets require extra shuffling as the PACK*S operates /// within each 128-bit lane. -static SDValue truncateVectorCompareWithPACKSS(EVT DstVT, SDValue In, - const SDLoc &DL, - SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In, + const SDLoc &DL, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + assert((Opcode == X86ISD::PACKSS || Opcode == X86ISD::PACKUS) && + "Unexpected PACK opcode"); + // Requires SSE2 but AVX512 has fast truncate. if (!Subtarget.hasSSE2() || Subtarget.hasAVX512()) return SDValue(); @@ -15946,40 +16443,52 @@ static SDValue truncateVectorCompareWithPACKSS(EVT DstVT, SDValue In, // We only support vector truncation to 128bits or greater from a // 256bits or greater source. - if ((DstVT.getSizeInBits() % 128) != 0) - return SDValue(); - if ((SrcVT.getSizeInBits() % 256) != 0) + unsigned DstSizeInBits = DstVT.getSizeInBits(); + unsigned SrcSizeInBits = SrcVT.getSizeInBits(); + if ((DstSizeInBits % 128) != 0 || (SrcSizeInBits % 256) != 0) return SDValue(); + LLVMContext &Ctx = *DAG.getContext(); unsigned NumElems = SrcVT.getVectorNumElements(); assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation"); - assert(SrcVT.getSizeInBits() > DstVT.getSizeInBits() && "Illegal truncation"); + assert(SrcSizeInBits > DstSizeInBits && "Illegal truncation"); - EVT PackedSVT = - EVT::getIntegerVT(*DAG.getContext(), SrcVT.getScalarSizeInBits() / 2); + EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2); // Extract lower/upper subvectors. unsigned NumSubElts = NumElems / 2; - unsigned SrcSizeInBits = SrcVT.getSizeInBits(); SDValue Lo = extractSubVector(In, 0 * NumSubElts, DAG, DL, SrcSizeInBits / 2); SDValue Hi = extractSubVector(In, 1 * NumSubElts, DAG, DL, SrcSizeInBits / 2); - // 256bit -> 128bit truncate - PACKSS lower/upper 128-bit subvectors. + // Pack to the largest type possible: + // vXi64/vXi32 -> PACK*SDW and vXi16 -> PACK*SWB. 
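  // Note: PACKSSDW/PACKSSWB/PACKUSWB are all SSE2, but PACKUSDW only arrived
  // with SSE4.1, which is why the dword->word step below is only taken for
  // PACKUS when hasSSE41() holds.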
+ EVT InVT = MVT::i16, OutVT = MVT::i8; + if (DstVT.getScalarSizeInBits() > 8 && + (Opcode == X86ISD::PACKSS || Subtarget.hasSSE41())) { + InVT = MVT::i32; + OutVT = MVT::i16; + } + + unsigned SubSizeInBits = SrcSizeInBits / 2; + InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits()); + OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits()); + + // 256bit -> 128bit truncate - PACK lower/upper 128-bit subvectors. if (SrcVT.is256BitVector()) { - Lo = DAG.getBitcast(MVT::v8i16, Lo); - Hi = DAG.getBitcast(MVT::v8i16, Hi); - SDValue Res = DAG.getNode(X86ISD::PACKSS, DL, MVT::v16i8, Lo, Hi); + Lo = DAG.getBitcast(InVT, Lo); + Hi = DAG.getBitcast(InVT, Hi); + SDValue Res = DAG.getNode(Opcode, DL, OutVT, Lo, Hi); return DAG.getBitcast(DstVT, Res); } - // AVX2: 512bit -> 256bit truncate - PACKSS lower/upper 256-bit subvectors. - // AVX2: 512bit -> 128bit truncate - PACKSS(PACKSS, PACKSS). + // AVX2: 512bit -> 256bit truncate - PACK lower/upper 256-bit subvectors. + // AVX2: 512bit -> 128bit truncate - PACK(PACK, PACK). if (SrcVT.is512BitVector() && Subtarget.hasInt256()) { - Lo = DAG.getBitcast(MVT::v16i16, Lo); - Hi = DAG.getBitcast(MVT::v16i16, Hi); - SDValue Res = DAG.getNode(X86ISD::PACKSS, DL, MVT::v32i8, Lo, Hi); + Lo = DAG.getBitcast(InVT, Lo); + Hi = DAG.getBitcast(InVT, Hi); + SDValue Res = DAG.getNode(Opcode, DL, OutVT, Lo, Hi); - // 256-bit PACKSS(ARG0, ARG1) leaves us with ((LO0,LO1),(HI0,HI1)), + // 256-bit PACK(ARG0, ARG1) leaves us with ((LO0,LO1),(HI0,HI1)), // so we need to shuffle to get ((LO0,HI0),(LO1,HI1)). Res = DAG.getBitcast(MVT::v4i64, Res); Res = DAG.getVectorShuffle(MVT::v4i64, DL, Res, Res, {0, 2, 1, 3}); @@ -15988,20 +16497,20 @@ static SDValue truncateVectorCompareWithPACKSS(EVT DstVT, SDValue In, return DAG.getBitcast(DstVT, Res); // If 512bit -> 128bit truncate another stage. - EVT PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems); + EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems); Res = DAG.getBitcast(PackedVT, Res); - return truncateVectorCompareWithPACKSS(DstVT, Res, DL, DAG, Subtarget); + return truncateVectorWithPACK(Opcode, DstVT, Res, DL, DAG, Subtarget); } // Recursively pack lower/upper subvectors, concat result and pack again. 
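  // e.g. with PACKSSDW/PACKSSWB on plain SSE2, v16i32 -> v16i8 packs each
  // v8i32 half down to v8i16, concatenates them to v16i16 and packs again.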
- assert(SrcVT.getSizeInBits() >= 512 && "Expected 512-bit vector or greater"); - EVT PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems / 2); - Lo = truncateVectorCompareWithPACKSS(PackedVT, Lo, DL, DAG, Subtarget); - Hi = truncateVectorCompareWithPACKSS(PackedVT, Hi, DL, DAG, Subtarget); + assert(SrcSizeInBits >= 512 && "Expected 512-bit vector or greater"); + EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumSubElts); + Lo = truncateVectorWithPACK(Opcode, PackedVT, Lo, DL, DAG, Subtarget); + Hi = truncateVectorWithPACK(Opcode, PackedVT, Hi, DL, DAG, Subtarget); - PackedVT = EVT::getVectorVT(*DAG.getContext(), PackedSVT, NumElems); + PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems); SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi); - return truncateVectorCompareWithPACKSS(DstVT, Res, DL, DAG, Subtarget); + return truncateVectorWithPACK(Opcode, DstVT, Res, DL, DAG, Subtarget); } static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG, @@ -16019,27 +16528,34 @@ static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG, if (InVT.getScalarSizeInBits() <= 16) { if (Subtarget.hasBWI()) { // legal, will go to VPMOVB2M, VPMOVW2M - // Shift packed bytes not supported natively, bitcast to word - MVT ExtVT = MVT::getVectorVT(MVT::i16, InVT.getSizeInBits()/16); - SDValue ShiftNode = DAG.getNode(ISD::SHL, DL, ExtVT, - DAG.getBitcast(ExtVT, In), - DAG.getConstant(ShiftInx, DL, ExtVT)); - ShiftNode = DAG.getBitcast(InVT, ShiftNode); - return DAG.getNode(X86ISD::CVT2MASK, DL, VT, ShiftNode); + if (DAG.ComputeNumSignBits(In) < InVT.getScalarSizeInBits()) { + // We need to shift to get the lsb into sign position. + // Shift packed bytes not supported natively, bitcast to word + MVT ExtVT = MVT::getVectorVT(MVT::i16, InVT.getSizeInBits()/16); + In = DAG.getNode(ISD::SHL, DL, ExtVT, + DAG.getBitcast(ExtVT, In), + DAG.getConstant(ShiftInx, DL, ExtVT)); + In = DAG.getBitcast(InVT, In); + } + return DAG.getNode(X86ISD::CVT2MASK, DL, VT, In); } // Use TESTD/Q, extended vector to packed dword/qword. assert((InVT.is256BitVector() || InVT.is128BitVector()) && "Unexpected vector type."); unsigned NumElts = InVT.getVectorNumElements(); - MVT ExtVT = MVT::getVectorVT(MVT::getIntegerVT(512/NumElts), NumElts); + MVT EltVT = Subtarget.hasVLX() ? MVT::i32 : MVT::getIntegerVT(512/NumElts); + MVT ExtVT = MVT::getVectorVT(EltVT, NumElts); In = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, In); InVT = ExtVT; ShiftInx = InVT.getScalarSizeInBits() - 1; } - SDValue ShiftNode = DAG.getNode(ISD::SHL, DL, InVT, In, - DAG.getConstant(ShiftInx, DL, InVT)); - return DAG.getNode(X86ISD::TESTM, DL, VT, ShiftNode, ShiftNode); + if (DAG.ComputeNumSignBits(In) < InVT.getScalarSizeInBits()) { + // We need to shift to get the lsb into sign position. 
+ In = DAG.getNode(ISD::SHL, DL, InVT, In, + DAG.getConstant(ShiftInx, DL, InVT)); + } + return DAG.getNode(X86ISD::TESTM, DL, VT, In, In); } SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { @@ -16047,15 +16563,8 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { MVT VT = Op.getSimpleValueType(); SDValue In = Op.getOperand(0); MVT InVT = In.getSimpleValueType(); + unsigned InNumEltBits = InVT.getScalarSizeInBits(); - if (VT == MVT::i1) { - assert((InVT.isInteger() && (InVT.getSizeInBits() <= 64)) && - "Invalid scalar TRUNCATE operation"); - if (InVT.getSizeInBits() >= 32) - return SDValue(); - In = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, In); - return DAG.getNode(ISD::TRUNCATE, DL, VT, In); - } assert(VT.getVectorNumElements() == InVT.getVectorNumElements() && "Invalid TRUNCATE operation"); @@ -16071,9 +16580,23 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { return DAG.getNode(X86ISD::VTRUNC, DL, VT, In); } - // Truncate with PACKSS if we are truncating a vector zero/all-bits result. - if (InVT.getScalarSizeInBits() == DAG.ComputeNumSignBits(In)) - if (SDValue V = truncateVectorCompareWithPACKSS(VT, In, DL, DAG, Subtarget)) + // Truncate with PACKSS if we are truncating a vector with sign-bits that + // extend all the way to the packed/truncated value. + unsigned NumPackedBits = std::min<unsigned>(VT.getScalarSizeInBits(), 16); + if ((InNumEltBits - NumPackedBits) < DAG.ComputeNumSignBits(In)) + if (SDValue V = + truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget)) + return V; + + // Truncate with PACKUS if we are truncating a vector with leading zero bits + // that extend all the way to the packed/truncated value. + // Pre-SSE41 we can only use PACKUSWB. + KnownBits Known; + DAG.computeKnownBits(In, Known); + NumPackedBits = Subtarget.hasSSE41() ? NumPackedBits : 8; + if ((InNumEltBits - NumPackedBits) <= Known.countMinLeadingZeros()) + if (SDValue V = + truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget)) return V; if ((VT == MVT::v4i32) && (InVT == MVT::v4i64)) { @@ -16165,9 +16688,29 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const { MVT VT = Op.getSimpleValueType(); if (VT.isVector()) { - assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL!"); SDValue Src = Op.getOperand(0); SDLoc dl(Op); + + if (VT == MVT::v2i1 && Src.getSimpleValueType() == MVT::v2f64) { + MVT ResVT = MVT::v4i32; + MVT TruncVT = MVT::v4i1; + unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI; + if (!IsSigned && !Subtarget.hasVLX()) { + // Widen to 512-bits. + ResVT = MVT::v8i32; + TruncVT = MVT::v8i1; + Opc = ISD::FP_TO_UINT; + Src = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8f64, + DAG.getUNDEF(MVT::v8f64), + Src, DAG.getIntPtrConstant(0, dl)); + } + SDValue Res = DAG.getNode(Opc, dl, ResVT, Src); + Res = DAG.getNode(ISD::TRUNCATE, dl, TruncVT, Res); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i1, Res, + DAG.getIntPtrConstant(0, dl)); + } + + assert(Subtarget.hasDQI() && Subtarget.hasVLX() && "Requires AVX512DQVL!"); if (VT == MVT::v2i64 && Src.getSimpleValueType() == MVT::v2f32) { return DAG.getNode(IsSigned ? 
X86ISD::CVTTP2SI : X86ISD::CVTTP2UI, dl, VT, DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, Src, @@ -16474,24 +17017,6 @@ static bool hasNonFlagsUse(SDValue Op) { return false; } -// Emit KTEST instruction for bit vectors on AVX-512 -static SDValue EmitKTEST(SDValue Op, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - if (Op.getOpcode() == ISD::BITCAST) { - auto hasKTEST = [&](MVT VT) { - unsigned SizeInBits = VT.getSizeInBits(); - return (Subtarget.hasDQI() && (SizeInBits == 8 || SizeInBits == 16)) || - (Subtarget.hasBWI() && (SizeInBits == 32 || SizeInBits == 64)); - }; - SDValue Op0 = Op.getOperand(0); - MVT Op0VT = Op0.getValueType().getSimpleVT(); - if (Op0VT.isVector() && Op0VT.getVectorElementType() == MVT::i1 && - hasKTEST(Op0VT)) - return DAG.getNode(X86ISD::KTEST, SDLoc(Op), Op0VT, Op0, Op0); - } - return SDValue(); -} - /// Emit nodes that will be selected as "test Op0,Op0", or something /// equivalent. SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, @@ -16536,9 +17061,6 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, // doing a separate TEST. TEST always sets OF and CF to 0, so unless // we prove that the arithmetic won't overflow, we can't use OF or CF. if (Op.getResNo() != 0 || NeedOF || NeedCF) { - // Emit KTEST for bit vectors - if (auto Node = EmitKTEST(Op, DAG, Subtarget)) - return Node; // Emit a CMP with 0, which is the TEST pattern. return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op, DAG.getConstant(0, dl, Op.getValueType())); @@ -16579,16 +17101,11 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, // non-casted variable when we check for possible users. switch (ArithOp.getOpcode()) { case ISD::ADD: - // Due to an isel shortcoming, be conservative if this add is likely to be - // selected as part of a load-modify-store instruction. When the root node - // in a match is a store, isel doesn't know how to remap non-chain non-flag - // uses of other nodes in the match, such as the ADD in this case. This - // leads to the ADD being left around and reselected, with the result being - // two adds in the output. Alas, even if none our users are stores, that - // doesn't prove we're O.K. Ergo, if we have any parents that aren't - // CopyToReg or SETCC, eschew INC/DEC. A better fix seems to require - // climbing the DAG back to the root, and it doesn't seem to be worth the - // effort. + // We only want to rewrite this as a target-specific node with attached + // flags if there is a reasonable chance of either using that to do custom + // instructions selection that can fold some of the memory operands, or if + // only the flags are used. If there are other uses, leave the node alone + // and emit a test instruction. for (SDNode::use_iterator UI = Op.getNode()->use_begin(), UE = Op.getNode()->use_end(); UI != UE; ++UI) if (UI->getOpcode() != ISD::CopyToReg && @@ -16596,17 +17113,20 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, UI->getOpcode() != ISD::STORE) goto default_case; - if (ConstantSDNode *C = - dyn_cast<ConstantSDNode>(ArithOp.getOperand(1))) { + if (auto *C = dyn_cast<ConstantSDNode>(ArithOp.getOperand(1))) { // An add of one will be selected as an INC. - if (C->isOne() && !Subtarget.slowIncDec()) { + if (C->isOne() && + (!Subtarget.slowIncDec() || + DAG.getMachineFunction().getFunction().optForSize())) { Opcode = X86ISD::INC; NumOperands = 1; break; } // An add of negative one (subtract of one) will be selected as a DEC. 
- if (C->isAllOnesValue() && !Subtarget.slowIncDec()) { + if (C->isAllOnesValue() && + (!Subtarget.slowIncDec() || + DAG.getMachineFunction().getFunction().optForSize())) { Opcode = X86ISD::DEC; NumOperands = 1; break; @@ -16699,11 +17219,13 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, case ISD::SUB: case ISD::OR: case ISD::XOR: - // Due to the ISEL shortcoming noted above, be conservative if this op is - // likely to be selected as part of a load-modify-store instruction. + // Similar to ISD::ADD above, check if the uses will preclude useful + // lowering of the target-specific node. for (SDNode::use_iterator UI = Op.getNode()->use_begin(), UE = Op.getNode()->use_end(); UI != UE; ++UI) - if (UI->getOpcode() == ISD::STORE) + if (UI->getOpcode() != ISD::CopyToReg && + UI->getOpcode() != ISD::SETCC && + UI->getOpcode() != ISD::STORE) goto default_case; // Otherwise use a regular EFLAGS-setting instruction. @@ -16767,10 +17289,6 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, const SDLoc &dl, } if (Opcode == 0) { - // Emit KTEST for bit vectors - if (auto Node = EmitKTEST(Op, DAG, Subtarget)) - return Node; - // Emit a CMP with 0, which is the TEST pattern. return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op, DAG.getConstant(0, dl, Op.getValueType())); @@ -16799,7 +17317,7 @@ SDValue X86TargetLowering::EmitCmp(SDValue Op0, SDValue Op1, unsigned X86CC, // with an immediate. 16 bit immediates are to be avoided. if ((Op0.getValueType() == MVT::i16 && (isa<ConstantSDNode>(Op0) || isa<ConstantSDNode>(Op1))) && - !DAG.getMachineFunction().getFunction()->optForMinSize() && + !DAG.getMachineFunction().getFunction().optForMinSize() && !Subtarget.isAtom()) { unsigned ExtendOp = isX86CCUnsigned(X86CC) ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND; @@ -16808,8 +17326,7 @@ SDValue X86TargetLowering::EmitCmp(SDValue Op0, SDValue Op1, unsigned X86CC, } // Use SUB instead of CMP to enable CSE between SUB and CMP. SDVTList VTs = DAG.getVTList(Op0.getValueType(), MVT::i32); - SDValue Sub = DAG.getNode(X86ISD::SUB, dl, VTs, - Op0, Op1); + SDValue Sub = DAG.getNode(X86ISD::SUB, dl, VTs, Op0, Op1); return SDValue(Sub.getNode(), 1); } return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op0, Op1); @@ -16871,8 +17388,11 @@ SDValue X86TargetLowering::getSqrtEstimate(SDValue Op, // instructions: convert to single, rsqrtss, convert back to double, refine // (3 steps = at least 13 insts). If an 'rsqrtsd' variant was added to the ISA // along with FMA, this could be a throughput win. + // TODO: SQRT requires SSE2 to prevent the introduction of an illegal v4i32 + // after legalize types. if ((VT == MVT::f32 && Subtarget.hasSSE1()) || - (VT == MVT::v4f32 && Subtarget.hasSSE1()) || + (VT == MVT::v4f32 && Subtarget.hasSSE1() && Reciprocal) || + (VT == MVT::v4f32 && Subtarget.hasSSE2() && !Reciprocal) || (VT == MVT::v8f32 && Subtarget.hasAVX())) { if (RefinementSteps == ReciprocalEstimate::Unspecified) RefinementSteps = 1; @@ -16965,6 +17485,7 @@ static SDValue getBitTestCondition(SDValue Src, SDValue BitNo, ISD::CondCode CC, /// Result of 'and' is compared against zero. Change to a BT node if possible. 
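/// For example, (X & (1 << N)) == 0 and ((X >> N) & 1) != 0 both become
/// BT(X, N), with the selected bit landing in the carry flag for the
/// following SETCC or branch.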
static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, const SDLoc &dl, SelectionDAG &DAG) { + assert(And.getOpcode() == ISD::AND && "Expected AND node!"); SDValue Op0 = And.getOperand(0); SDValue Op1 = And.getOperand(1); if (Op0.getOpcode() == ISD::TRUNCATE) @@ -17013,36 +17534,10 @@ static SDValue LowerAndToBT(SDValue And, ISD::CondCode CC, return SDValue(); } -// Convert (truncate (srl X, N) to i1) to (bt X, N) -static SDValue LowerTruncateToBT(SDValue Op, ISD::CondCode CC, - const SDLoc &dl, SelectionDAG &DAG) { - - assert(Op.getOpcode() == ISD::TRUNCATE && Op.getValueType() == MVT::i1 && - "Expected TRUNCATE to i1 node"); - - if (Op.getOperand(0).getOpcode() != ISD::SRL) - return SDValue(); - - SDValue ShiftRight = Op.getOperand(0); - return getBitTestCondition(ShiftRight.getOperand(0), ShiftRight.getOperand(1), - CC, dl, DAG); -} - -/// Result of 'and' or 'trunc to i1' is compared against zero. -/// Change to a BT node if possible. -SDValue X86TargetLowering::LowerToBT(SDValue Op, ISD::CondCode CC, - const SDLoc &dl, SelectionDAG &DAG) const { - if (Op.getOpcode() == ISD::AND) - return LowerAndToBT(Op, CC, dl, DAG); - if (Op.getOpcode() == ISD::TRUNCATE && Op.getValueType() == MVT::i1) - return LowerTruncateToBT(Op, CC, dl, DAG); - return SDValue(); -} - /// Turns an ISD::CondCode into a value suitable for SSE floating-point mask /// CMPs. -static int translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0, - SDValue &Op1) { +static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0, + SDValue &Op1) { unsigned SSECC; bool Swap = false; @@ -17075,8 +17570,8 @@ static int translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0, case ISD::SETULT: Swap = true; LLVM_FALLTHROUGH; case ISD::SETUGT: SSECC = 6; break; case ISD::SETO: SSECC = 7; break; - case ISD::SETUEQ: - case ISD::SETONE: SSECC = 8; break; + case ISD::SETUEQ: SSECC = 8; break; + case ISD::SETONE: SSECC = 12; break; } if (Swap) std::swap(Op0, Op1); @@ -17189,6 +17684,20 @@ static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) { if (Swap) std::swap(Op0, Op1); + + // See if it is the case of CMP(EQ|NEQ,AND(A,B),ZERO) and change it to TESTM|NM. + if ((!Opc && SSECC == 4) || Opc == X86ISD::PCMPEQM) { + SDValue A = peekThroughBitcasts(Op0); + if ((A.getOpcode() == ISD::AND || A.getOpcode() == X86ISD::FAND) && + ISD::isBuildVectorAllZeros(Op1.getNode())) { + MVT VT0 = Op0.getSimpleValueType(); + SDValue RHS = DAG.getBitcast(VT0, A.getOperand(0)); + SDValue LHS = DAG.getBitcast(VT0, A.getOperand(1)); + return DAG.getNode(Opc == X86ISD::PCMPEQM ? X86ISD::TESTNM : X86ISD::TESTM, + dl, VT, RHS, LHS); + } + } + if (Opc) return DAG.getNode(Opc, dl, VT, Op0, Op1); Opc = Unsigned ? X86ISD::CMPMU: X86ISD::CMPM; @@ -17256,25 +17765,21 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, // In the two cases not handled by SSE compare predicates (SETUEQ/SETONE), // emit two comparisons and a logic op to tie them together. - // TODO: This can be avoided if Intel (and only Intel as of 2016) AVX is - // available. SDValue Cmp; unsigned SSECC = translateX86FSETCC(Cond, Op0, Op1); - if (SSECC == 8) { + if (SSECC >= 8 && !Subtarget.hasAVX()) { // LLVM predicate is SETUEQ or SETONE. unsigned CC0, CC1; unsigned CombineOpc; if (Cond == ISD::SETUEQ) { CC0 = 3; // UNORD CC1 = 0; // EQ - CombineOpc = Opc == X86ISD::CMPP ? 
static_cast<unsigned>(X86ISD::FOR) : - static_cast<unsigned>(ISD::OR); + CombineOpc = X86ISD::FOR; } else { assert(Cond == ISD::SETONE); CC0 = 7; // ORD CC1 = 4; // NEQ - CombineOpc = Opc == X86ISD::CMPP ? static_cast<unsigned>(X86ISD::FAND) : - static_cast<unsigned>(ISD::AND); + CombineOpc = X86ISD::FAND; } SDValue Cmp0 = DAG.getNode(Opc, dl, VT, Op0, Op1, @@ -17379,6 +17884,24 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, DAG.getConstant(CmpMode, dl, MVT::i8)); } + // (X & Y) != 0 --> (X & Y) == Y iff Y is power-of-2. + // Revert part of the simplifySetCCWithAnd combine, to avoid an invert. + if (Cond == ISD::SETNE && ISD::isBuildVectorAllZeros(Op1.getNode())) { + SDValue BC0 = peekThroughBitcasts(Op0); + if (BC0.getOpcode() == ISD::AND) { + APInt UndefElts; + SmallVector<APInt, 64> EltBits; + if (getTargetConstantBitsFromNode(BC0.getOperand(1), + VT.getScalarSizeInBits(), UndefElts, + EltBits, false, false)) { + if (llvm::all_of(EltBits, [](APInt &V) { return V.isPowerOf2(); })) { + Cond = ISD::SETEQ; + Op1 = DAG.getBitcast(VT, BC0.getOperand(1)); + } + } + } + } + // We are handling one of the integer comparisons here. Since SSE only has // GT and EQ comparisons for integer, swapping operands and multiple // operations may be required for some comparisons. @@ -17399,7 +17922,8 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, // Special case: Use min/max operations for SETULE/SETUGE MVT VET = VT.getVectorElementType(); bool HasMinMax = - (Subtarget.hasSSE41() && (VET >= MVT::i8 && VET <= MVT::i32)) || + (Subtarget.hasAVX512() && VET == MVT::i64) || + (Subtarget.hasSSE41() && (VET == MVT::i16 || VET == MVT::i32)) || (Subtarget.hasSSE2() && (VET == MVT::i8)); bool MinMax = false; if (HasMinMax) { @@ -17544,6 +18068,34 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget, return Result; } +// Try to select this as a KTEST+SETCC if possible. +static SDValue EmitKTEST(SDValue Op0, SDValue Op1, ISD::CondCode CC, + const SDLoc &dl, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // Only support equality comparisons. + if (CC != ISD::SETEQ && CC != ISD::SETNE) + return SDValue(); + + // Must be a bitcast from vXi1. + if (Op0.getOpcode() != ISD::BITCAST) + return SDValue(); + + Op0 = Op0.getOperand(0); + MVT VT = Op0.getSimpleValueType(); + if (!(Subtarget.hasDQI() && (VT == MVT::v8i1 || VT == MVT::v16i1)) && + !(Subtarget.hasBWI() && (VT == MVT::v32i1 || VT == MVT::v64i1))) + return SDValue(); + + X86::CondCode X86CC; + if (isNullConstant(Op1)) { + X86CC = CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE; + } else + return SDValue(); + + SDValue KTEST = DAG.getNode(X86ISD::KTEST, dl, MVT::i32, Op0, Op0); + return getSETCC(X86CC, KTEST, dl, DAG); +} + SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { MVT VT = Op.getSimpleValueType(); @@ -17560,16 +18112,16 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { // Lower (X & (1 << N)) == 0 to BT(X, N). // Lower ((X >>u N) & 1) != 0 to BT(X, N). // Lower ((X >>s N) & 1) != 0 to BT(X, N). - // Lower (trunc (X >> N) to i1) to BT(X, N). 
- if (Op0.hasOneUse() && isNullConstant(Op1) && + if (Op0.getOpcode() == ISD::AND && Op0.hasOneUse() && isNullConstant(Op1) && (CC == ISD::SETEQ || CC == ISD::SETNE)) { - if (SDValue NewSetCC = LowerToBT(Op0, CC, dl, DAG)) { - if (VT == MVT::i1) - return DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, NewSetCC); + if (SDValue NewSetCC = LowerAndToBT(Op0, CC, dl, DAG)) return NewSetCC; - } } + // Try to lower using KTEST. + if (SDValue NewSetCC = EmitKTEST(Op0, Op1, CC, dl, DAG, Subtarget)) + return NewSetCC; + // Look for X == 0, X == 1, X != 0, or X != 1. We can simplify some forms of // these. if ((isOneConstant(Op1) || isNullConstant(Op1)) && @@ -17584,20 +18136,7 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { return Op0; CCode = X86::GetOppositeBranchCondition(CCode); - SDValue SetCC = getSETCC(CCode, Op0.getOperand(1), dl, DAG); - if (VT == MVT::i1) - return DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, SetCC); - return SetCC; - } - } - if (Op0.getValueType() == MVT::i1 && (CC == ISD::SETEQ || CC == ISD::SETNE)) { - if (isOneConstant(Op1)) { - ISD::CondCode NewCC = ISD::getSetCCInverse(CC, true); - return DAG.getSetCC(dl, VT, Op0, DAG.getConstant(0, dl, MVT::i1), NewCC); - } - if (!isNullConstant(Op1)) { - SDValue Xor = DAG.getNode(ISD::XOR, dl, MVT::i1, Op0, Op1); - return DAG.getSetCC(dl, VT, Xor, DAG.getConstant(0, dl, MVT::i1), CC); + return getSETCC(CCode, Op0.getOperand(1), dl, DAG); } } @@ -17608,10 +18147,7 @@ SDValue X86TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { SDValue EFLAGS = EmitCmp(Op0, Op1, X86CC, dl, DAG); EFLAGS = ConvertCmpIfNecessary(EFLAGS, DAG); - SDValue SetCC = getSETCC(X86CC, EFLAGS, dl, DAG); - if (VT == MVT::i1) - return DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, SetCC); - return SetCC; + return getSETCC(X86CC, EFLAGS, dl, DAG); } SDValue X86TargetLowering::LowerSETCCCARRY(SDValue Op, SelectionDAG &DAG) const { @@ -17632,10 +18168,7 @@ SDValue X86TargetLowering::LowerSETCCCARRY(SDValue Op, SelectionDAG &DAG) const SDVTList VTs = DAG.getVTList(LHS.getValueType(), MVT::i32); SDValue Cmp = DAG.getNode(X86ISD::SBB, DL, VTs, LHS, RHS, Carry.getValue(1)); - SDValue SetCC = getSETCC(CC, Cmp.getValue(1), DL, DAG); - if (Op.getSimpleValueType() == MVT::i1) - return DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, SetCC); - return SetCC; + return getSETCC(CC, Cmp.getValue(1), DL, DAG); } /// Return true if opcode is a X86 logical comparison. @@ -17646,7 +18179,7 @@ static bool isX86LogicalCmp(SDValue Op) { return true; if (Op.getResNo() == 1 && (Opc == X86ISD::ADD || Opc == X86ISD::SUB || Opc == X86ISD::ADC || - Opc == X86ISD::SBB || Opc == X86ISD::SMUL || Opc == X86ISD::UMUL || + Opc == X86ISD::SBB || Opc == X86ISD::SMUL || Opc == X86ISD::INC || Opc == X86ISD::DEC || Opc == X86ISD::OR || Opc == X86ISD::XOR || Opc == X86ISD::AND)) return true; @@ -17684,17 +18217,17 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { (Subtarget.hasSSE1() && VT == MVT::f32)) && VT == Cond.getOperand(0).getSimpleValueType() && Cond->hasOneUse()) { SDValue CondOp0 = Cond.getOperand(0), CondOp1 = Cond.getOperand(1); - int SSECC = translateX86FSETCC( + unsigned SSECC = translateX86FSETCC( cast<CondCodeSDNode>(Cond.getOperand(2))->get(), CondOp0, CondOp1); - if (SSECC != 8) { - if (Subtarget.hasAVX512()) { - SDValue Cmp = DAG.getNode(X86ISD::FSETCCM, DL, MVT::v1i1, CondOp0, - CondOp1, DAG.getConstant(SSECC, DL, MVT::i8)); - return DAG.getNode(VT.isVector() ? 
X86ISD::SELECT : X86ISD::SELECTS, - DL, VT, Cmp, Op1, Op2); - } + if (Subtarget.hasAVX512()) { + SDValue Cmp = DAG.getNode(X86ISD::FSETCCM, DL, MVT::v1i1, CondOp0, + CondOp1, DAG.getConstant(SSECC, DL, MVT::i8)); + assert(!VT.isVector() && "Not a scalar type?"); + return DAG.getNode(X86ISD::SELECTS, DL, VT, Cmp, Op1, Op2); + } + if (SSECC < 8 || Subtarget.hasAVX()) { SDValue Cmp = DAG.getNode(X86ISD::FSETCC, DL, VT, CondOp0, CondOp1, DAG.getConstant(SSECC, DL, MVT::i8)); @@ -17742,6 +18275,18 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { return DAG.getNode(X86ISD::SELECTS, DL, VT, Cmp, Op1, Op2); } + // For v64i1 without 64-bit support we need to split and rejoin. + if (VT == MVT::v64i1 && !Subtarget.is64Bit()) { + assert(Subtarget.hasBWI() && "Expected BWI to be legal"); + SDValue Op1Lo = extractSubVector(Op1, 0, DAG, DL, 32); + SDValue Op2Lo = extractSubVector(Op2, 0, DAG, DL, 32); + SDValue Op1Hi = extractSubVector(Op1, 32, DAG, DL, 32); + SDValue Op2Hi = extractSubVector(Op2, 32, DAG, DL, 32); + SDValue Lo = DAG.getSelect(DL, MVT::v32i1, Cond, Op1Lo, Op2Lo); + SDValue Hi = DAG.getSelect(DL, MVT::v32i1, Cond, Op1Hi, Op2Hi); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi); + } + if (VT.isVector() && VT.getVectorElementType() == MVT::i1) { SDValue Op1Scalar; if (ISD::isBuildVectorOfConstantSDNodes(Op1.getNode())) @@ -17941,7 +18486,7 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { // We know the result of AND is compared against zero. Try to match // it to BT. if (Cond.getOpcode() == ISD::AND && Cond.hasOneUse()) { - if (SDValue NewSetCC = LowerToBT(Cond, ISD::SETNE, DL, DAG)) { + if (SDValue NewSetCC = LowerAndToBT(Cond, ISD::SETNE, DL, DAG)) { CC = NewSetCC.getOperand(0); Cond = NewSetCC.getOperand(1); AddTest = false; @@ -17983,66 +18528,83 @@ SDValue X86TargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const { if (T1.getValueType() == T2.getValueType() && // Blacklist CopyFromReg to avoid partial register stalls. T1.getOpcode() != ISD::CopyFromReg && T2.getOpcode()!=ISD::CopyFromReg){ - SDVTList VTs = DAG.getVTList(T1.getValueType(), MVT::Glue); - SDValue Cmov = DAG.getNode(X86ISD::CMOV, DL, VTs, T2, T1, CC, Cond); + SDValue Cmov = DAG.getNode(X86ISD::CMOV, DL, T1.getValueType(), T2, T1, + CC, Cond); return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Cmov); } } // X86ISD::CMOV means set the result (which is operand 1) to the RHS if // condition is true. - SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::Glue); SDValue Ops[] = { Op2, Op1, CC, Cond }; - return DAG.getNode(X86ISD::CMOV, DL, VTs, Ops); + return DAG.getNode(X86ISD::CMOV, DL, Op.getValueType(), Ops); } -static SDValue LowerSIGN_EXTEND_AVX512(SDValue Op, - const X86Subtarget &Subtarget, - SelectionDAG &DAG) { +static SDValue LowerSIGN_EXTEND_Mask(SDValue Op, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { MVT VT = Op->getSimpleValueType(0); SDValue In = Op->getOperand(0); MVT InVT = In.getSimpleValueType(); + assert(InVT.getVectorElementType() == MVT::i1 && "Unexpected input type!"); MVT VTElt = VT.getVectorElementType(); - MVT InVTElt = InVT.getVectorElementType(); SDLoc dl(Op); - // SKX processor - if ((InVTElt == MVT::i1) && - (((Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16)) || + unsigned NumElts = VT.getVectorNumElements(); - ((Subtarget.hasDQI() && VTElt.getSizeInBits() >= 32)))) + // Extend VT if the scalar type is v8/v16 and BWI is not supported. 
+ MVT ExtVT = VT; + if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16) + ExtVT = MVT::getVectorVT(MVT::i32, NumElts); - return DAG.getNode(X86ISD::VSEXT, dl, VT, In); + // Widen to 512-bits if VLX is not supported. + MVT WideVT = ExtVT; + if (!ExtVT.is512BitVector() && !Subtarget.hasVLX()) { + NumElts *= 512 / ExtVT.getSizeInBits(); + InVT = MVT::getVectorVT(MVT::i1, NumElts); + In = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, InVT, DAG.getUNDEF(InVT), + In, DAG.getIntPtrConstant(0, dl)); + WideVT = MVT::getVectorVT(ExtVT.getVectorElementType(), NumElts); + } - unsigned NumElts = VT.getVectorNumElements(); + SDValue V; + MVT WideEltVT = WideVT.getVectorElementType(); + if ((Subtarget.hasDQI() && WideEltVT.getSizeInBits() >= 32) || + (Subtarget.hasBWI() && WideEltVT.getSizeInBits() <= 16)) { + V = getExtendInVec(X86ISD::VSEXT, dl, WideVT, In, DAG); + } else { + SDValue NegOne = getOnesVector(WideVT, DAG, dl); + SDValue Zero = getZeroVector(WideVT, Subtarget, DAG, dl); + V = DAG.getSelect(dl, WideVT, In, NegOne, Zero); + } - if (VT.is512BitVector() && InVTElt != MVT::i1 && - (NumElts == 8 || NumElts == 16 || Subtarget.hasBWI())) { - if (In.getOpcode() == X86ISD::VSEXT || In.getOpcode() == X86ISD::VZEXT) - return getExtendInVec(In.getOpcode(), dl, VT, In.getOperand(0), DAG); - return getExtendInVec(X86ISD::VSEXT, dl, VT, In, DAG); + // Truncate if we had to extend i16/i8 above. + if (VT != ExtVT) { + WideVT = MVT::getVectorVT(VTElt, NumElts); + V = DAG.getNode(ISD::TRUNCATE, dl, WideVT, V); } - if (InVTElt != MVT::i1) - return SDValue(); + // Extract back to 128/256-bit if we widened. + if (WideVT != VT) + V = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, V, + DAG.getIntPtrConstant(0, dl)); - MVT ExtVT = VT; - if (!VT.is512BitVector() && !Subtarget.hasVLX()) - ExtVT = MVT::getVectorVT(MVT::getIntegerVT(512/NumElts), NumElts); + return V; +} - SDValue V; - if (Subtarget.hasDQI()) { - V = getExtendInVec(X86ISD::VSEXT, dl, ExtVT, In, DAG); - assert(!VT.is512BitVector() && "Unexpected vector type"); - } else { - SDValue NegOne = getOnesVector(ExtVT, DAG, dl); - SDValue Zero = getZeroVector(ExtVT, Subtarget, DAG, dl); - V = DAG.getSelect(dl, ExtVT, In, NegOne, Zero); - if (ExtVT == VT) - return V; - } +static SDValue LowerANY_EXTEND(SDValue Op, const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + SDValue In = Op->getOperand(0); + MVT InVT = In.getSimpleValueType(); + + if (InVT.getVectorElementType() == MVT::i1) + return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG); + + if (Subtarget.hasFp256()) + if (SDValue Res = LowerAVXExtend(Op, DAG, Subtarget)) + return Res; - return DAG.getNode(X86ISD::VTRUNC, dl, VT, V); + return SDValue(); } // Lowering for SIGN_EXTEND_VECTOR_INREG and ZERO_EXTEND_VECTOR_INREG. 
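For illustration, on a plain AVX512F target (no VLX, BWI or DQI) the mask path above lowers a SIGN_EXTEND of v8i1 to v8i16 roughly as the following node sequence (a minimal sketch under those subtarget assumptions):
  // ExtVT  = v8i32                        // i16 needs widening without BWI
  // WideVT = v16i32                       // widen 256 -> 512 bits without VLX
  // In     = insert_subvector undef:v16i1, In, 0
  // V      = vselect In, splat(-1):v16i32, splat(0):v16i32   // no DQI VSEXT
  // V      = truncate V to v16i16
  // result = extract_subvector V, 0       // back to v8i16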
@@ -18139,12 +18701,17 @@ static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget, MVT InVT = In.getSimpleValueType(); SDLoc dl(Op); - if (VT.is512BitVector() || InVT.getVectorElementType() == MVT::i1) - return LowerSIGN_EXTEND_AVX512(Op, Subtarget, DAG); + if (InVT.getVectorElementType() == MVT::i1) + return LowerSIGN_EXTEND_Mask(Op, Subtarget, DAG); - if ((VT != MVT::v4i64 || InVT != MVT::v4i32) && - (VT != MVT::v8i32 || InVT != MVT::v8i16) && - (VT != MVT::v16i16 || InVT != MVT::v16i8)) + if ((VT != MVT::v4i64 || InVT != MVT::v4i32) && + (VT != MVT::v8i32 || InVT != MVT::v8i16) && + (VT != MVT::v16i16 || InVT != MVT::v16i8) && + (VT != MVT::v8i64 || InVT != MVT::v8i32) && + (VT != MVT::v8i64 || InVT != MVT::v8i16) && + (VT != MVT::v16i32 || InVT != MVT::v16i16) && + (VT != MVT::v16i32 || InVT != MVT::v16i8) && + (VT != MVT::v32i16 || InVT != MVT::v32i8)) return SDValue(); if (Subtarget.hasInt256()) @@ -18217,6 +18784,7 @@ static SDValue LowerTruncatingStore(SDValue StOp, const X86Subtarget &Subtarget, DAG.getUNDEF(ExtVT), Op, DAG.getIntPtrConstant(0, dl)); } Op = DAG.getNode(ISD::TRUNCATE, dl, MVT::v8i1, Op); + Op = DAG.getBitcast(MVT::i8, Op); return DAG.getStore(St->getChain(), dl, Op, St->getBasePtr(), St->getMemOperand()); } @@ -18233,12 +18801,12 @@ static SDValue LowerTruncatingStore(SDValue StOp, const X86Subtarget &Subtarget, DAG.getIntPtrConstant(16, dl)); Hi = DAG.getNode(ISD::TRUNCATE, dl, MVT::v16i1, Hi); - SDValue BasePtrHi = - DAG.getNode(ISD::ADD, dl, BasePtr.getValueType(), BasePtr, - DAG.getConstant(2, dl, BasePtr.getValueType())); + SDValue BasePtrHi = DAG.getMemBasePlusOffset(BasePtr, 2, dl); SDValue StHi = DAG.getStore(St->getChain(), dl, Hi, - BasePtrHi, St->getMemOperand()); + BasePtrHi, St->getPointerInfo().getWithOffset(2), + MinAlign(St->getAlignment(), 2U), + St->getMemOperand()->getFlags()); return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, StLo, StHi); } @@ -18267,6 +18835,14 @@ static SDValue LowerExtended1BitVectorLoad(SDValue Op, // Replace chain users with the new chain. assert(Load->getNumValues() == 2 && "Loads must carry a chain!"); DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1)); + if (Subtarget.hasVLX()) { + // Extract to v4i1/v2i1. + SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MemVT, Load, + DAG.getIntPtrConstant(0, dl)); + // Finally, do a normal sign-extend to the desired register. + return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Extract); + } + MVT ExtVT = MVT::getVectorVT(VT.getScalarType(), 8); SDValue ExtVec = DAG.getNode(ExtOpcode, dl, ExtVT, Load); @@ -18286,22 +18862,25 @@ static SDValue LowerExtended1BitVectorLoad(SDValue Op, if (NumElts <= 8) { // A subset, assume that we have only AVX-512F - unsigned NumBitsToLoad = 8; - MVT TypeToLoad = MVT::getIntegerVT(NumBitsToLoad); - SDValue Load = DAG.getLoad(TypeToLoad, dl, Ld->getChain(), + SDValue Load = DAG.getLoad(MVT::i8, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getMemOperand()); // Replace chain users with the new chain. assert(Load->getNumValues() == 2 && "Loads must carry a chain!"); DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Load.getValue(1)); - MVT MaskVT = MVT::getVectorVT(MVT::i1, NumBitsToLoad); - SDValue BitVec = DAG.getBitcast(MaskVT, Load); + SDValue BitVec = DAG.getBitcast(MVT::v8i1, Load); if (NumElts == 8) return DAG.getNode(ExtOpcode, dl, VT, BitVec); - // we should take care to v4i1 and v2i1 + if (Subtarget.hasVLX()) { + // Extract to v4i1/v2i1. 
+ SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MemVT, BitVec, + DAG.getIntPtrConstant(0, dl)); + // Finally, do a normal sign-extend to the desired register. + return DAG.getNode(ExtOpcode, dl, Op.getValueType(), Extract); + } MVT ExtVT = MVT::getVectorVT(VT.getScalarType(), 8); SDValue ExtVec = DAG.getNode(ExtOpcode, dl, ExtVT, BitVec); @@ -18311,23 +18890,20 @@ static SDValue LowerExtended1BitVectorLoad(SDValue Op, assert(VT == MVT::v32i8 && "Unexpected extload type"); - SmallVector<SDValue, 2> Chains; - SDValue BasePtr = Ld->getBasePtr(); SDValue LoadLo = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getMemOperand()); - Chains.push_back(LoadLo.getValue(1)); - SDValue BasePtrHi = - DAG.getNode(ISD::ADD, dl, BasePtr.getValueType(), BasePtr, - DAG.getConstant(2, dl, BasePtr.getValueType())); + SDValue BasePtrHi = DAG.getMemBasePlusOffset(BasePtr, 2, dl); - SDValue LoadHi = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(), - BasePtrHi, - Ld->getMemOperand()); - Chains.push_back(LoadHi.getValue(1)); - SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Chains); + SDValue LoadHi = DAG.getLoad(MVT::v16i1, dl, Ld->getChain(), BasePtrHi, + Ld->getPointerInfo().getWithOffset(2), + MinAlign(Ld->getAlignment(), 2U), + Ld->getMemOperand()->getFlags()); + + SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, + LoadLo.getValue(1), LoadHi.getValue(1)); DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewChain); SDValue Lo = DAG.getNode(ExtOpcode, dl, MVT::v16i8, LoadLo); @@ -18443,6 +19019,12 @@ static SDValue LowerExtendedLoad(SDValue Op, const X86Subtarget &Subtarget, if (Ext == ISD::SEXTLOAD && RegSz >= 256) loadRegZize = 128; + // If we don't have BWI we won't be able to create the shuffle needed for + // v8i8->v8i64. + if (Ext == ISD::EXTLOAD && !Subtarget.hasBWI() && RegVT == MVT::v8i64 && + MemVT == MVT::v8i8) + loadRegZize = 128; + // Represent our vector as a sequence of elements which are the // largest scalar that we can load. EVT LoadUnitVecVT = EVT::getVectorVT( @@ -18509,6 +19091,13 @@ static SDValue LowerExtendedLoad(SDValue Op, const X86Subtarget &Subtarget, return Shuff; } + if (Ext == ISD::EXTLOAD && !Subtarget.hasBWI() && RegVT == MVT::v8i64 && + MemVT == MVT::v8i8) { + SDValue Sext = getExtendInVec(X86ISD::VZEXT, dl, RegVT, SlicedVec, DAG); + DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), TF); + return Sext; + } + // Redistribute the loaded elements into the different locations. SmallVector<int, 16> ShuffleVec(NumElems * SizeRatio, -1); for (unsigned i = 0; i != NumElems; ++i) @@ -18796,9 +19385,10 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const { if (isTruncWithZeroHighBitsInput(Cond, DAG)) Cond = Cond.getOperand(0); - // We know the result is compared against zero. Try to match it to BT. - if (Cond.hasOneUse()) { - if (SDValue NewSetCC = LowerToBT(Cond, ISD::SETNE, dl, DAG)) { + // We know the result of AND is compared against zero. Try to match + // it to BT. + if (Cond.getOpcode() == ISD::AND && Cond.hasOneUse()) { + if (SDValue NewSetCC = LowerAndToBT(Cond, ISD::SETNE, dl, DAG)) { CC = NewSetCC.getOperand(0); Cond = NewSetCC.getOperand(1); addTest = false; @@ -18867,8 +19457,8 @@ X86TargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op, if (Is64Bit) { // The 64 bit implementation of segmented stacks needs to clobber both r10 // r11. This makes it impossible to use it along with nested parameters. 
- const Function *F = MF.getFunction(); - for (const auto &A : F->args()) { + const Function &F = MF.getFunction(); + for (const auto &A : F.args()) { if (A.hasNestAttr()) report_fatal_error("Cannot use segmented stacks with functions that " "have nested arguments."); @@ -18915,7 +19505,7 @@ SDValue X86TargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); if (!Subtarget.is64Bit() || - Subtarget.isCallingConvWin64(MF.getFunction()->getCallingConv())) { + Subtarget.isCallingConvWin64(MF.getFunction().getCallingConv())) { // vastart just stores the address of the VarArgsFrameIndex slot into the // memory location argument. SDValue FR = DAG.getFrameIndex(FuncInfo->getVarArgsFrameIndex(), PtrVT); @@ -18969,7 +19559,7 @@ SDValue X86TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const { assert(Op.getNumOperands() == 4); MachineFunction &MF = DAG.getMachineFunction(); - if (Subtarget.isCallingConvWin64(MF.getFunction()->getCallingConv())) + if (Subtarget.isCallingConvWin64(MF.getFunction().getCallingConv())) // The Win64 ABI uses char* instead of a structure. return DAG.expandVAArg(Op.getNode()); @@ -19000,7 +19590,7 @@ SDValue X86TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const { if (ArgMode == 2) { // Sanity Check: Make sure using fp_offset makes sense. assert(!Subtarget.useSoftFloat() && - !(MF.getFunction()->hasFnAttribute(Attribute::NoImplicitFloat)) && + !(MF.getFunction().hasFnAttribute(Attribute::NoImplicitFloat)) && Subtarget.hasSSE1()); } @@ -19010,13 +19600,12 @@ SDValue X86TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const { DAG.getConstant(ArgMode, dl, MVT::i8), DAG.getConstant(Align, dl, MVT::i32)}; SDVTList VTs = DAG.getVTList(getPointerTy(DAG.getDataLayout()), MVT::Other); - SDValue VAARG = DAG.getMemIntrinsicNode(X86ISD::VAARG_64, dl, - VTs, InstOps, MVT::i64, - MachinePointerInfo(SV), - /*Align=*/0, - /*Volatile=*/false, - /*ReadMem=*/true, - /*WriteMem=*/true); + SDValue VAARG = DAG.getMemIntrinsicNode( + X86ISD::VAARG_64, dl, + VTs, InstOps, MVT::i64, + MachinePointerInfo(SV), + /*Align=*/0, + MachineMemOperand::MOLoad | MachineMemOperand::MOStore); Chain = VAARG.getValue(1); // Load the next argument and return it @@ -19029,7 +19618,7 @@ static SDValue LowerVACOPY(SDValue Op, const X86Subtarget &Subtarget, // where a va_list is still an i8*. assert(Subtarget.is64Bit() && "This code only handles 64-bit va_copy!"); if (Subtarget.isCallingConvWin64( - DAG.getMachineFunction().getFunction()->getCallingConv())) + DAG.getMachineFunction().getFunction().getCallingConv())) // Probably a Win64 va_copy. 
return DAG.expandVACopy(Op.getNode()); @@ -19172,8 +19761,8 @@ static SDValue getTargetVShiftNode(unsigned Opc, const SDLoc &dl, MVT VT, ShAmt = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(ShAmt), MVT::v4i32, ShAmt); ShAmt = DAG.getZeroExtendVectorInReg(ShAmt, SDLoc(ShAmt), MVT::v2i64); } else { - SmallVector<SDValue, 4> ShOps = {ShAmt, DAG.getConstant(0, dl, SVT), - DAG.getUNDEF(SVT), DAG.getUNDEF(SVT)}; + SDValue ShOps[4] = {ShAmt, DAG.getConstant(0, dl, SVT), + DAG.getUNDEF(SVT), DAG.getUNDEF(SVT)}; ShAmt = DAG.getBuildVector(MVT::v4i32, dl, ShOps); } @@ -19193,9 +19782,9 @@ static SDValue getMaskNode(SDValue Mask, MVT MaskVT, const SDLoc &dl) { if (isAllOnesConstant(Mask)) - return DAG.getTargetConstant(1, dl, MaskVT); + return DAG.getConstant(1, dl, MaskVT); if (X86::isZeroNode(Mask)) - return DAG.getTargetConstant(0, dl, MaskVT); + return DAG.getConstant(0, dl, MaskVT); if (MaskVT.bitsGT(Mask.getSimpleValueType())) { // Mask should be extended @@ -19255,13 +19844,12 @@ static SDValue getVectorMaskingNode(SDValue Op, SDValue Mask, switch (Op.getOpcode()) { default: break; - case X86ISD::PCMPEQM: - case X86ISD::PCMPGTM: case X86ISD::CMPM: + case X86ISD::CMPM_RND: case X86ISD::CMPMU: + case X86ISD::VPSHUFBITQMB: return DAG.getNode(ISD::AND, dl, VT, Op, VMask); case X86ISD::VFPCLASS: - case X86ISD::VFPCLASSS: return DAG.getNode(ISD::OR, dl, VT, Op, VMask); case X86ISD::VTRUNC: case X86ISD::VTRUNCS: @@ -19370,8 +19958,8 @@ static SDValue recoverFramePointer(SelectionDAG &DAG, const Function *Fn, return DAG.getNode(ISD::SUB, dl, PtrVT, RegNodeBase, ParentFrameOffset); } -static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget, - SelectionDAG &DAG) { +SDValue X86TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, + SelectionDAG &DAG) const { // Helper to detect if the operand is CUR_DIRECTION rounding mode. auto isRoundModeCurDirection = [](SDValue Rnd) { if (!isa<ConstantSDNode>(Rnd)) @@ -19442,14 +20030,36 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SDValue passThru = Op.getOperand(3); SDValue Mask = Op.getOperand(4); unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; - if (IntrWithRoundingModeOpcode != 0) { - SDValue Rnd = Op.getOperand(5); - if (!isRoundModeCurDirection(Rnd)) + // There are 2 kinds of intrinsics in this group: + // (1) With suppress-all-exceptions (sae) or rounding mode- 6 operands + // (2) With rounding mode and sae - 7 operands. 
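      // i.e. (id, src1, src2, passthru, mask, rnd-or-sae) vs.
      //      (id, src1, src2, passthru, mask, rounding-mode, sae).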
+ bool HasRounding = IntrWithRoundingModeOpcode != 0; + if (Op.getNumOperands() == (5U + HasRounding)) { + if (HasRounding) { + SDValue Rnd = Op.getOperand(5); + if (!isRoundModeCurDirection(Rnd)) + return getScalarMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, + dl, VT, Src1, Src2, Rnd), + Mask, passThru, Subtarget, DAG); + } + return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src1, + Src2), + Mask, passThru, Subtarget, DAG); + } + + assert(Op.getNumOperands() == (6U + HasRounding) && + "Unexpected intrinsic form"); + SDValue RoundingMode = Op.getOperand(5); + if (HasRounding) { + SDValue Sae = Op.getOperand(6); + if (!isRoundModeCurDirection(Sae)) return getScalarMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, - dl, VT, Src1, Src2, Rnd), + dl, VT, Src1, Src2, + RoundingMode, Sae), Mask, passThru, Subtarget, DAG); } - return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src1, Src2), + return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src1, + Src2, RoundingMode), Mask, passThru, Subtarget, DAG); } case INTR_TYPE_SCALAR_MASK_RM: { @@ -19518,16 +20128,23 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget Src1, Src2, Rnd), Mask, PassThru, Subtarget, DAG); } - case INTR_TYPE_3OP_SCALAR_MASK_RM: { + case INTR_TYPE_3OP_SCALAR_MASK: { SDValue Src1 = Op.getOperand(1); SDValue Src2 = Op.getOperand(2); SDValue Src3 = Op.getOperand(3); SDValue PassThru = Op.getOperand(4); SDValue Mask = Op.getOperand(5); - SDValue Sae = Op.getOperand(6); + unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; + if (IntrWithRoundingModeOpcode != 0) { + SDValue Rnd = Op.getOperand(6); + if (!isRoundModeCurDirection(Rnd)) + return getScalarMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, + dl, VT, Src1, Src2, Src3, Rnd), + Mask, PassThru, Subtarget, DAG); + } return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src1, - Src2, Src3, Sae), + Src2, Src3), Mask, PassThru, Subtarget, DAG); } case INTR_TYPE_3OP_MASK_RM: { @@ -19664,10 +20281,39 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget else PassThru = Src1; - SDValue Rnd = Op.getOperand(5); + unsigned IntrWithRoundingModeOpcode = IntrData->Opc1; + if (IntrWithRoundingModeOpcode != 0) { + SDValue Rnd = Op.getOperand(5); + if (!isRoundModeCurDirection(Rnd)) + return getScalarMaskingNode(DAG.getNode(IntrWithRoundingModeOpcode, dl, + Op.getValueType(), Src1, Src2, + Src3, Rnd), + Mask, PassThru, Subtarget, DAG); + } + return getScalarMaskingNode(DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), Src1, Src2, - Src3, Rnd), + Src3), + Mask, PassThru, Subtarget, DAG); + } + case IFMA_OP_MASKZ: + case IFMA_OP_MASK: { + SDValue Src1 = Op.getOperand(1); + SDValue Src2 = Op.getOperand(2); + SDValue Src3 = Op.getOperand(3); + SDValue Mask = Op.getOperand(4); + MVT VT = Op.getSimpleValueType(); + SDValue PassThru = Src1; + + // set PassThru element + if (IntrData->Type == IFMA_OP_MASKZ) + PassThru = getZeroVector(VT, Subtarget, DAG, dl); + + // Node we need to swizzle the operands to pass the multiply operands + // first. 
+ return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, + dl, Op.getValueType(), + Src2, Src3, Src1), Mask, PassThru, Subtarget, DAG); } case TERLOG_OP_MASK: @@ -19726,9 +20372,8 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget MVT BitcastVT = MVT::getVectorVT(MVT::i1, Mask.getSimpleValueType().getSizeInBits()); SDValue FPclass = DAG.getNode(IntrData->Opc0, dl, MaskVT, Src1, Imm); - SDValue FPclassMask = getVectorMaskingNode(FPclass, Mask, - DAG.getTargetConstant(0, dl, MaskVT), - Subtarget, DAG); + SDValue FPclassMask = getVectorMaskingNode(FPclass, Mask, SDValue(), + Subtarget, DAG); SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT, DAG.getUNDEF(BitcastVT), FPclassMask, DAG.getIntPtrConstant(0, dl)); @@ -19739,9 +20384,9 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget SDValue Imm = Op.getOperand(2); SDValue Mask = Op.getOperand(3); SDValue FPclass = DAG.getNode(IntrData->Opc0, dl, MVT::v1i1, Src1, Imm); - SDValue FPclassMask = getScalarMaskingNode(FPclass, Mask, - DAG.getTargetConstant(0, dl, MVT::i1), Subtarget, DAG); - return DAG.getNode(X86ISD::VEXTRACT, dl, MVT::i8, FPclassMask, + SDValue FPclassMask = getScalarMaskingNode(FPclass, Mask, SDValue(), + Subtarget, DAG); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, FPclassMask, DAG.getIntPtrConstant(0, dl)); } case CMP_MASK: @@ -19783,9 +20428,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget Cmp = DAG.getNode(IntrData->Opc0, dl, MaskVT, Op.getOperand(1), Op.getOperand(2)); } - SDValue CmpMask = getVectorMaskingNode(Cmp, Mask, - DAG.getTargetConstant(0, dl, - MaskVT), + SDValue CmpMask = getVectorMaskingNode(Cmp, Mask, SDValue(), Subtarget, DAG); SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, BitcastVT, DAG.getUNDEF(BitcastVT), CmpMask, @@ -19808,11 +20451,9 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget if(!Cmp.getNode()) Cmp = DAG.getNode(IntrData->Opc0, dl, MVT::v1i1, Src1, Src2, CC); - SDValue CmpMask = getScalarMaskingNode(Cmp, Mask, - DAG.getTargetConstant(0, dl, - MVT::i1), + SDValue CmpMask = getScalarMaskingNode(Cmp, Mask, SDValue(), Subtarget, DAG); - return DAG.getNode(X86ISD::VEXTRACT, dl, MVT::i8, CmpMask, + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, CmpMask, DAG.getIntPtrConstant(0, dl)); } case COMI: { // Comparison intrinsics @@ -19866,7 +20507,7 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget else FCmp = DAG.getNode(X86ISD::FSETCCM_RND, dl, MVT::v1i1, LHS, RHS, DAG.getConstant(CondVal, dl, MVT::i8), Sae); - return DAG.getNode(X86ISD::VEXTRACT, dl, MVT::i32, FCmp, + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, FCmp, DAG.getIntPtrConstant(0, dl)); } case VSHIFT: @@ -19953,37 +20594,25 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget DAG.getIntPtrConstant(0, dl)); return DAG.getBitcast(Op.getValueType(), Res); } - case BRCST_SUBVEC_TO_VEC: { - SDValue Src = Op.getOperand(1); - SDValue Passthru = Op.getOperand(2); - SDValue Mask = Op.getOperand(3); - EVT resVT = Passthru.getValueType(); - SDValue subVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, resVT, - DAG.getUNDEF(resVT), Src, - DAG.getIntPtrConstant(0, dl)); - SDValue immVal; - if (Src.getSimpleValueType().is256BitVector() && resVT.is512BitVector()) - immVal = DAG.getConstant(0x44, dl, MVT::i8); - else - immVal = DAG.getConstant(0, dl, MVT::i8); - return 
getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, - subVec, subVec, immVal), - Mask, Passthru, Subtarget, DAG); - } - case BRCST32x2_TO_VEC: { - SDValue Src = Op.getOperand(1); - SDValue PassThru = Op.getOperand(2); - SDValue Mask = Op.getOperand(3); - - assert((VT.getScalarType() == MVT::i32 || - VT.getScalarType() == MVT::f32) && "Unexpected type!"); - //bitcast Src to packed 64 - MVT ScalarVT = VT.getScalarType() == MVT::i32 ? MVT::i64 : MVT::f64; - MVT BitcastVT = MVT::getVectorVT(ScalarVT, Src.getValueSizeInBits()/64); - Src = DAG.getBitcast(BitcastVT, Src); - - return getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, Src), - Mask, PassThru, Subtarget, DAG); + case ROUNDP: { + assert(IntrData->Opc0 == X86ISD::VRNDSCALE && "Unexpected opcode"); + // Clear the upper bits of the rounding immediate so that the legacy + // intrinsic can't trigger the scaling behavior of VRNDSCALE. + SDValue RoundingMode = DAG.getNode(ISD::AND, dl, MVT::i32, + Op.getOperand(2), + DAG.getConstant(0xf, dl, MVT::i32)); + return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), + Op.getOperand(1), RoundingMode); + } + case ROUNDS: { + assert(IntrData->Opc0 == X86ISD::VRNDSCALES && "Unexpected opcode"); + // Clear the upper bits of the rounding immediate so that the legacy + // intrinsic can't trigger the scaling behavior of VRNDSCALE. + SDValue RoundingMode = DAG.getNode(ISD::AND, dl, MVT::i32, + Op.getOperand(3), + DAG.getConstant(0xf, dl, MVT::i32)); + return DAG.getNode(IntrData->Opc0, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), RoundingMode); } default: break; @@ -20187,7 +20816,8 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget auto &Context = MF.getMMI().getContext(); MCSymbol *S = Context.getOrCreateSymbol(Twine("GCC_except_table") + Twine(MF.getFunctionNumber())); - return DAG.getNode(X86ISD::Wrapper, dl, VT, DAG.getMCSymbol(S, PtrVT)); + return DAG.getNode(getGlobalWrapperKind(), dl, VT, + DAG.getMCSymbol(S, PtrVT)); } case Intrinsic::x86_seh_lsda: { @@ -20589,18 +21219,16 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, case RDSEED: case RDRAND: { // Emit the node with the right value type. - SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::Glue, MVT::Other); + SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::i32, MVT::Other); SDValue Result = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(0)); // If the value returned by RDRAND/RDSEED was valid (CF=1), return 1. // Otherwise return the value from Rand, which is always 0, casted to i32. SDValue Ops[] = { DAG.getZExtOrTrunc(Result, dl, Op->getValueType(1)), DAG.getConstant(1, dl, Op->getValueType(1)), - DAG.getConstant(X86::COND_B, dl, MVT::i32), + DAG.getConstant(X86::COND_B, dl, MVT::i8), SDValue(Result.getNode(), 1) }; - SDValue isValid = DAG.getNode(X86ISD::CMOV, dl, - DAG.getVTList(Op->getValueType(1), MVT::Glue), - Ops); + SDValue isValid = DAG.getNode(X86ISD::CMOV, dl, Op->getValueType(1), Ops); // Return { result, isValid, chain }. 
return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), Result, isValid, @@ -20684,7 +21312,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget, // ADC/ADCX/SBB case ADX: { SDVTList CFVTs = DAG.getVTList(Op->getValueType(0), MVT::i32); - SDVTList VTs = DAG.getVTList(Op.getOperand(3)->getValueType(0), MVT::i32); + SDVTList VTs = DAG.getVTList(Op.getOperand(3).getValueType(), MVT::i32); SDValue GenCF = DAG.getNode(X86ISD::ADD, dl, CFVTs, Op.getOperand(2), DAG.getConstant(-1, dl, MVT::i8)); SDValue Res = DAG.getNode(IntrData->Opc0, dl, VTs, Op.getOperand(3), @@ -21292,7 +21920,14 @@ static SDValue LowerVectorCTLZInRegLUT(SDValue Op, const SDLoc &DL, SDValue NibbleShift = DAG.getConstant(0x4, DL, CurrVT); SDValue Lo = DAG.getNode(ISD::AND, DL, CurrVT, Op0, NibbleMask); SDValue Hi = DAG.getNode(ISD::SRL, DL, CurrVT, Op0, NibbleShift); - SDValue HiZ = DAG.getSetCC(DL, CurrVT, Hi, Zero, ISD::SETEQ); + SDValue HiZ; + if (CurrVT.is512BitVector()) { + MVT MaskVT = MVT::getVectorVT(MVT::i1, CurrVT.getVectorNumElements()); + HiZ = DAG.getSetCC(DL, MaskVT, Hi, Zero, ISD::SETEQ); + HiZ = DAG.getNode(ISD::SIGN_EXTEND, DL, CurrVT, HiZ); + } else { + HiZ = DAG.getSetCC(DL, CurrVT, Hi, Zero, ISD::SETEQ); + } Lo = DAG.getNode(X86ISD::PSHUFB, DL, CurrVT, InRegLUT, Lo); Hi = DAG.getNode(X86ISD::PSHUFB, DL, CurrVT, InRegLUT, Hi); @@ -21312,8 +21947,15 @@ static SDValue LowerVectorCTLZInRegLUT(SDValue Op, const SDLoc &DL, SDValue Shift = DAG.getConstant(CurrScalarSizeInBits, DL, NextVT); // Check if the upper half of the input element is zero. - SDValue HiZ = DAG.getSetCC(DL, CurrVT, DAG.getBitcast(CurrVT, Op0), - DAG.getBitcast(CurrVT, Zero), ISD::SETEQ); + if (CurrVT.is512BitVector()) { + MVT MaskVT = MVT::getVectorVT(MVT::i1, CurrVT.getVectorNumElements()); + HiZ = DAG.getSetCC(DL, MaskVT, DAG.getBitcast(CurrVT, Op0), + DAG.getBitcast(CurrVT, Zero), ISD::SETEQ); + HiZ = DAG.getNode(ISD::SIGN_EXTEND, DL, CurrVT, HiZ); + } else { + HiZ = DAG.getSetCC(DL, CurrVT, DAG.getBitcast(CurrVT, Op0), + DAG.getBitcast(CurrVT, Zero), ISD::SETEQ); + } HiZ = DAG.getBitcast(NextVT, HiZ); // Move the upper/lower halves to the lower bits as we'll be extending to @@ -21505,6 +22147,19 @@ static SDValue LowerADD_SUB(SDValue Op, SelectionDAG &DAG) { } static SDValue LowerABS(SDValue Op, SelectionDAG &DAG) { + MVT VT = Op.getSimpleValueType(); + if (VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) { + // Since X86 does not have CMOV for 8-bit integer, we don't convert + // 8-bit integer abs to NEG and CMOV. + SDLoc DL(Op); + SDValue N0 = Op.getOperand(0); + SDValue Neg = DAG.getNode(X86ISD::SUB, DL, DAG.getVTList(VT, MVT::i32), + DAG.getConstant(0, DL, VT), N0); + SDValue Ops[] = {N0, Neg, DAG.getConstant(X86::COND_GE, DL, MVT::i8), + SDValue(Neg.getNode(), 1)}; + return DAG.getNode(X86ISD::CMOV, DL, VT, Ops); + } + assert(Op.getSimpleValueType().is256BitVector() && Op.getSimpleValueType().isInteger() && "Only handle AVX 256-bit vector integer operation"); @@ -21607,7 +22262,14 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, // Lower v4i32 mul as 2x shuffle, 2x pmuludq, 2x shuffle. if (VT == MVT::v4i32) { assert(Subtarget.hasSSE2() && !Subtarget.hasSSE41() && - "Should not custom lower when pmuldq is available!"); + "Should not custom lower when pmulld is available!"); + + // If the upper 17 bits of each element are zero then we can use PMADD. 
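    // PMADDWD multiplies the 16-bit lanes and sums adjacent pairs into 32-bit
    // results; with bits 15..31 known zero the odd lanes contribute nothing
    // and each product fits in 30 bits, so the pair-wise sum is the exact
    // i32 product.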
+ APInt Mask17 = APInt::getHighBitsSet(32, 17); + if (DAG.MaskedValueIsZero(A, Mask17) && DAG.MaskedValueIsZero(B, Mask17)) + return DAG.getNode(X86ISD::VPMADDWD, dl, VT, + DAG.getBitcast(MVT::v8i16, A), + DAG.getBitcast(MVT::v8i16, B)); // Extract the odd parts. static const int UnpackMask[] = { 1, -1, 3, -1 }; @@ -21659,6 +22321,11 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget, bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask); bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask); + // If DQI is supported we can use MULLQ, but MULUDQ is still better if the + // the high bits are known to be zero. + if (Subtarget.hasDQI() && (!AHiIsZero || !BHiIsZero)) + return Op; + // Bit cast to 32-bit vectors for MULUDQ. SDValue Alo = DAG.getBitcast(MulVT, A); SDValue Blo = DAG.getBitcast(MulVT, B); @@ -21700,7 +22367,8 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, return Lower256IntArith(Op, DAG); // Only i8 vectors should need custom lowering after this. - assert((VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256())) && + assert((VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) || + (VT == MVT::v64i8 && Subtarget.hasBWI())) && "Unsupported vector type"); // Lower v16i8/v32i8 as extension to v8i16/v16i16 vector pairs, multiply, @@ -21712,22 +22380,36 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, // and then ashr/lshr the upper bits down to the lower bits before multiply. unsigned Opcode = Op.getOpcode(); unsigned ExShift = (ISD::MULHU == Opcode ? ISD::SRL : ISD::SRA); - unsigned ExSSE41 = (ISD::MULHU == Opcode ? X86ISD::VZEXT : X86ISD::VSEXT); + unsigned ExAVX = (ISD::MULHU == Opcode ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND); + + // For 512-bit vectors, split into 256-bit vectors to allow the + // sign-extension to occur. + if (VT == MVT::v64i8) + return Lower512IntArith(Op, DAG); // AVX2 implementations - extend xmm subvectors to ymm. 
if (Subtarget.hasInt256()) { + unsigned NumElems = VT.getVectorNumElements(); SDValue Lo = DAG.getIntPtrConstant(0, dl); - SDValue Hi = DAG.getIntPtrConstant(VT.getVectorNumElements() / 2, dl); + SDValue Hi = DAG.getIntPtrConstant(NumElems / 2, dl); if (VT == MVT::v32i8) { - SDValue ALo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v16i8, A, Lo); - SDValue BLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v16i8, B, Lo); - SDValue AHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v16i8, A, Hi); - SDValue BHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v16i8, B, Hi); - ALo = DAG.getNode(ExSSE41, dl, MVT::v16i16, ALo); - BLo = DAG.getNode(ExSSE41, dl, MVT::v16i16, BLo); - AHi = DAG.getNode(ExSSE41, dl, MVT::v16i16, AHi); - BHi = DAG.getNode(ExSSE41, dl, MVT::v16i16, BHi); + if (Subtarget.hasBWI()) { + SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v32i16, A); + SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v32i16, B); + SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v32i16, ExA, ExB); + Mul = DAG.getNode(ISD::SRL, dl, MVT::v32i16, Mul, + DAG.getConstant(8, dl, MVT::v32i16)); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); + } + SDValue ALo = extract128BitVector(A, 0, DAG, dl); + SDValue BLo = extract128BitVector(B, 0, DAG, dl); + SDValue AHi = extract128BitVector(A, NumElems / 2, DAG, dl); + SDValue BHi = extract128BitVector(B, NumElems / 2, DAG, dl); + ALo = DAG.getNode(ExAVX, dl, MVT::v16i16, ALo); + BLo = DAG.getNode(ExAVX, dl, MVT::v16i16, BLo); + AHi = DAG.getNode(ExAVX, dl, MVT::v16i16, AHi); + BHi = DAG.getNode(ExAVX, dl, MVT::v16i16, BHi); Lo = DAG.getNode(ISD::SRL, dl, MVT::v16i16, DAG.getNode(ISD::MUL, dl, MVT::v16i16, ALo, BLo), DAG.getConstant(8, dl, MVT::v16i16)); @@ -21745,19 +22427,25 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, DAG.getVectorShuffle(MVT::v16i16, dl, Lo, Hi, HiMask)); } - SDValue ExA = getExtendInVec(ExSSE41, dl, MVT::v16i16, A, DAG); - SDValue ExB = getExtendInVec(ExSSE41, dl, MVT::v16i16, B, DAG); + assert(VT == MVT::v16i8 && "Unexpected VT"); + + SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v16i16, A); + SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v16i16, B); SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v16i16, ExA, ExB); - SDValue MulH = DAG.getNode(ISD::SRL, dl, MVT::v16i16, Mul, - DAG.getConstant(8, dl, MVT::v16i16)); - Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i16, MulH, Lo); - Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i16, MulH, Hi); + Mul = DAG.getNode(ISD::SRL, dl, MVT::v16i16, Mul, + DAG.getConstant(8, dl, MVT::v16i16)); + // If we have BWI we can use truncate instruction. + if (Subtarget.hasBWI()) + return DAG.getNode(ISD::TRUNCATE, dl, VT, Mul); + Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i16, Mul, Lo); + Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i16, Mul, Hi); return DAG.getNode(X86ISD::PACKUS, dl, VT, Lo, Hi); } assert(VT == MVT::v16i8 && "Pre-AVX2 support only supports v16i8 multiplication"); MVT ExVT = MVT::v8i16; + unsigned ExSSE41 = (ISD::MULHU == Opcode ? X86ISD::VZEXT : X86ISD::VSEXT); // Extract the lo parts and zero/sign extend to i16. 
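Per element, the vXi8 MULHU/MULHS strategy above amounts to the following scalar computation (a sketch with invented helper names): widen to 16 bits, multiply, take the high byte, then truncate or pack back to 8-bit lanes.

#include <cstdint>

// Scalar sketch of one lane of the widening high-multiply used above.
static uint8_t MulHiU8(uint8_t A, uint8_t B) {
  return uint8_t((unsigned(A) * unsigned(B)) >> 8);  // MULHU: high byte of the 16-bit product
}
static int8_t MulHiS8(int8_t A, int8_t B) {
  return int8_t(int16_t(int(A) * int(B)) >> 8);      // MULHS: same, with sign extension
}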
SDValue ALo, BLo; @@ -21885,7 +22573,10 @@ static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget &Subtarget, } assert((VT == MVT::v4i32 && Subtarget.hasSSE2()) || - (VT == MVT::v8i32 && Subtarget.hasInt256())); + (VT == MVT::v8i32 && Subtarget.hasInt256()) || + (VT == MVT::v16i32 && Subtarget.hasAVX512())); + + int NumElts = VT.getVectorNumElements(); // PMULxD operations multiply each even value (starting at 0) of LHS with // the related value of RHS and produce a widen result. @@ -21899,17 +22590,17 @@ static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget &Subtarget, // // Place the odd value at an even position (basically, shift all values 1 // step to the left): - const int Mask[] = {1, -1, 3, -1, 5, -1, 7, -1}; + const int Mask[] = {1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1}; // <a|b|c|d> => <b|undef|d|undef> SDValue Odd0 = DAG.getVectorShuffle(VT, dl, Op0, Op0, - makeArrayRef(&Mask[0], VT.getVectorNumElements())); + makeArrayRef(&Mask[0], NumElts)); // <e|f|g|h> => <f|undef|h|undef> SDValue Odd1 = DAG.getVectorShuffle(VT, dl, Op1, Op1, - makeArrayRef(&Mask[0], VT.getVectorNumElements())); + makeArrayRef(&Mask[0], NumElts)); // Emit two multiplies, one for the lower 2 ints and one for the higher 2 // ints. - MVT MulVT = VT == MVT::v4i32 ? MVT::v2i64 : MVT::v4i64; + MVT MulVT = MVT::getVectorVT(MVT::i64, NumElts / 2); bool IsSigned = Op->getOpcode() == ISD::SMUL_LOHI; unsigned Opcode = (!IsSigned || !Subtarget.hasSSE41()) ? X86ISD::PMULUDQ : X86ISD::PMULDQ; @@ -21921,19 +22612,16 @@ static SDValue LowerMUL_LOHI(SDValue Op, const X86Subtarget &Subtarget, SDValue Mul2 = DAG.getBitcast(VT, DAG.getNode(Opcode, dl, MulVT, Odd0, Odd1)); // Shuffle it back into the right order. - SDValue Highs, Lows; - if (VT == MVT::v8i32) { - const int HighMask[] = {1, 9, 3, 11, 5, 13, 7, 15}; - Highs = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, HighMask); - const int LowMask[] = {0, 8, 2, 10, 4, 12, 6, 14}; - Lows = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, LowMask); - } else { - const int HighMask[] = {1, 5, 3, 7}; - Highs = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, HighMask); - const int LowMask[] = {0, 4, 2, 6}; - Lows = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, LowMask); + SmallVector<int, 16> HighMask(NumElts); + SmallVector<int, 16> LowMask(NumElts); + for (int i = 0; i != NumElts; ++i) { + HighMask[i] = (i / 2) * 2 + ((i % 2) * NumElts) + 1; + LowMask[i] = (i / 2) * 2 + ((i % 2) * NumElts); } + SDValue Highs = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, HighMask); + SDValue Lows = DAG.getVectorShuffle(VT, dl, Mul1, Mul2, LowMask); + // If we have a signed multiply but no PMULDQ fix up the high parts of a // unsigned multiply. if (IsSigned && !Subtarget.hasSSE41()) { @@ -22123,9 +22811,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG, } } - // Special case in 32-bit mode, where i64 is expanded into high and low parts. + // Check cases (mainly 32-bit) where i64 is expanded into high and low parts. // TODO: Replace constant extraction with getTargetConstantBitsFromNode. - if (!Subtarget.is64Bit() && !Subtarget.hasXOP() && + if (!Subtarget.hasXOP() && (VT == MVT::v2i64 || (Subtarget.hasInt256() && VT == MVT::v4i64) || (Subtarget.hasAVX512() && VT == MVT::v8i64))) { @@ -22252,9 +22940,8 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG, } } - // Special case in 32-bit mode, where i64 is expanded into high and low parts. 
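The generalized HighMask/LowMask formulas above can be sanity-checked against the hard-coded masks they replace; a small sketch (illustration only):

#include <cstdio>

// For NumElts = 4 this prints the old v4i32 masks {1,5,3,7} (highs) and
// {0,4,2,6} (lows); NumElts = 8 reproduces the old v8i32 masks.
int main() {
  for (int NumElts : {4, 8}) {
    std::printf("NumElts=%d high:", NumElts);
    for (int i = 0; i != NumElts; ++i)
      std::printf(" %d", (i / 2) * 2 + (i % 2) * NumElts + 1);
    std::printf(" low:");
    for (int i = 0; i != NumElts; ++i)
      std::printf(" %d", (i / 2) * 2 + (i % 2) * NumElts);
    std::printf("\n");
  }
}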
- if (!Subtarget.is64Bit() && VT == MVT::v2i64 && - Amt.getOpcode() == ISD::BITCAST && + // Check cases (mainly 32-bit) where i64 is expanded into high and low parts. + if (VT == MVT::v2i64 && Amt.getOpcode() == ISD::BITCAST && Amt.getOperand(0).getOpcode() == ISD::BUILD_VECTOR) { Amt = Amt.getOperand(0); unsigned Ratio = Amt.getSimpleValueType().getVectorNumElements() / @@ -22389,7 +23076,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // the vector shift into four scalar shifts plus four pairs of vector // insert/extract. if (ConstantAmt && (VT == MVT::v8i16 || VT == MVT::v4i32)) { - unsigned TargetOpcode = X86ISD::MOVSS; + bool UseMOVSD = false; bool CanBeSimplified; // The splat value for the first packed shift (the 'X' from the example). SDValue Amt1 = Amt->getOperand(0); @@ -22406,7 +23093,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, // Otherwise, check if we can still simplify this node using a MOVSD. CanBeSimplified = Amt1 == Amt->getOperand(1) && Amt->getOperand(2) == Amt->getOperand(3); - TargetOpcode = X86ISD::MOVSD; + UseMOVSD = true; Amt2 = Amt->getOperand(2); } } else { @@ -22417,7 +23104,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, CanBeSimplified = Amt2 == Amt->getOperand(i); if (!CanBeSimplified) { - TargetOpcode = X86ISD::MOVSD; + UseMOVSD = true; CanBeSimplified = true; Amt2 = Amt->getOperand(4); for (unsigned i=0; i != 4 && CanBeSimplified; ++i) @@ -22430,19 +23117,18 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, if (CanBeSimplified && isa<ConstantSDNode>(Amt1) && isa<ConstantSDNode>(Amt2)) { // Replace this node with two shifts followed by a MOVSS/MOVSD/PBLEND. - MVT CastVT = MVT::v4i32; SDValue Splat1 = DAG.getConstant(cast<ConstantSDNode>(Amt1)->getAPIntValue(), dl, VT); SDValue Shift1 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat1); SDValue Splat2 = DAG.getConstant(cast<ConstantSDNode>(Amt2)->getAPIntValue(), dl, VT); SDValue Shift2 = DAG.getNode(Op->getOpcode(), dl, VT, R, Splat2); - SDValue BitCast1 = DAG.getBitcast(CastVT, Shift1); - SDValue BitCast2 = DAG.getBitcast(CastVT, Shift2); - if (TargetOpcode == X86ISD::MOVSD) - return DAG.getBitcast(VT, DAG.getVectorShuffle(CastVT, dl, BitCast1, + SDValue BitCast1 = DAG.getBitcast(MVT::v4i32, Shift1); + SDValue BitCast2 = DAG.getBitcast(MVT::v4i32, Shift2); + if (UseMOVSD) + return DAG.getBitcast(VT, DAG.getVectorShuffle(MVT::v4i32, dl, BitCast1, BitCast2, {0, 1, 6, 7})); - return DAG.getBitcast(VT, DAG.getVectorShuffle(CastVT, dl, BitCast1, + return DAG.getBitcast(VT, DAG.getVectorShuffle(MVT::v4i32, dl, BitCast1, BitCast2, {0, 5, 6, 7})); } } @@ -22501,12 +23187,14 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget, (Subtarget.hasAVX512() && VT == MVT::v16i16) || (Subtarget.hasAVX512() && VT == MVT::v16i8) || (Subtarget.hasBWI() && VT == MVT::v32i8)) { - MVT EvtSVT = (VT == MVT::v32i8 ? MVT::i16 : MVT::i32); + assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) && + "Unexpected vector type"); + MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32; MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements()); unsigned ExtOpc = Op.getOpcode() == ISD::SRA ? 
ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; R = DAG.getNode(ExtOpc, dl, ExtVT, R); - Amt = DAG.getNode(ISD::ANY_EXTEND, dl, ExtVT, Amt); + Amt = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Amt); return DAG.getNode(ISD::TRUNCATE, dl, VT, DAG.getNode(Op.getOpcode(), dl, ExtVT, R, Amt)); } @@ -22752,7 +23440,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, assert((Opcode == ISD::ROTL) && "Only ROTL supported"); // XOP has 128-bit vector variable + immediate rotates. - // +ve/-ve Amt = rotate left/right. + // +ve/-ve Amt = rotate left/right - just need to handle ISD::ROTL. // Split 256-bit integers. if (VT.is256BitVector()) @@ -22765,13 +23453,13 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget, if (auto *RotateConst = BVAmt->getConstantSplatNode()) { uint64_t RotateAmt = RotateConst->getAPIntValue().getZExtValue(); assert(RotateAmt < EltSizeInBits && "Rotation out of range"); - return DAG.getNode(X86ISD::VPROTI, DL, VT, R, + return DAG.getNode(X86ISD::VROTLI, DL, VT, R, DAG.getConstant(RotateAmt, DL, MVT::i8)); } } // Use general rotate by variable (per-element). - return DAG.getNode(X86ISD::VPROT, DL, VT, R, Amt); + return Op; } static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) { @@ -23319,15 +24007,14 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget, // TRUNC(CTPOP(ZEXT(X))) to make use of vXi32/vXi64 VPOPCNT instructions. if (Subtarget.hasVPOPCNTDQ()) { - if (VT == MVT::v8i16) { - Op = DAG.getNode(X86ISD::VZEXT, DL, MVT::v8i64, Op0); - Op = DAG.getNode(ISD::CTPOP, DL, MVT::v8i64, Op); - return DAG.getNode(X86ISD::VTRUNC, DL, VT, Op); - } - if (VT == MVT::v16i8 || VT == MVT::v16i16) { - Op = DAG.getNode(X86ISD::VZEXT, DL, MVT::v16i32, Op0); - Op = DAG.getNode(ISD::CTPOP, DL, MVT::v16i32, Op); - return DAG.getNode(X86ISD::VTRUNC, DL, VT, Op); + unsigned NumElems = VT.getVectorNumElements(); + assert((VT.getVectorElementType() == MVT::i8 || + VT.getVectorElementType() == MVT::i16) && "Unexpected type"); + if (NumElems <= 16) { + MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems); + Op = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, Op0); + Op = DAG.getNode(ISD::CTPOP, DL, NewVT, Op); + return DAG.getNode(ISD::TRUNCATE, DL, VT, Op); } } @@ -23402,12 +24089,13 @@ static SDValue LowerBITREVERSE_XOP(SDValue Op, SelectionDAG &DAG) { static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - if (Subtarget.hasXOP()) + MVT VT = Op.getSimpleValueType(); + + if (Subtarget.hasXOP() && !VT.is512BitVector()) return LowerBITREVERSE_XOP(Op, DAG); assert(Subtarget.hasSSSE3() && "SSSE3 required for BITREVERSE"); - MVT VT = Op.getSimpleValueType(); SDValue In = Op.getOperand(0); SDLoc DL(Op); @@ -23450,7 +24138,9 @@ static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget, return DAG.getNode(ISD::OR, DL, VT, Lo, Hi); } -static SDValue lowerAtomicArithWithLOCK(SDValue N, SelectionDAG &DAG) { +static SDValue lowerAtomicArithWithLOCK(SDValue N, SelectionDAG &DAG, + const X86Subtarget &Subtarget, + bool AllowIncDec = true) { unsigned NewOpc = 0; switch (N->getOpcode()) { case ISD::ATOMIC_LOAD_ADD: @@ -23473,6 +24163,26 @@ static SDValue lowerAtomicArithWithLOCK(SDValue N, SelectionDAG &DAG) { } MachineMemOperand *MMO = cast<MemSDNode>(N)->getMemOperand(); + + if (auto *C = dyn_cast<ConstantSDNode>(N->getOperand(2))) { + // Convert to inc/dec if they aren't slow or we are optimizing for size. 
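At the source level, the candidates for the inc/dec conversion noted here are plain locked plus/minus-one updates whose value result is unused; a minimal sketch (function name invented for the example):

#include <atomic>

// A locked RMW of +1/-1 with an unused result can select LINC/LDEC
// (lock inc/dec) instead of lock add/sub with an immediate, unless the
// subtarget reports slow INC/DEC and the function is not optimized for size.
static void Bump(std::atomic<int> &Counter) {
  Counter.fetch_add(1, std::memory_order_relaxed);   // candidate for LINC
  Counter.fetch_sub(1, std::memory_order_relaxed);   // candidate for LDEC
}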
+ if (AllowIncDec && (!Subtarget.slowIncDec() || + DAG.getMachineFunction().getFunction().optForSize())) { + if ((NewOpc == X86ISD::LADD && C->isOne()) || + (NewOpc == X86ISD::LSUB && C->isAllOnesValue())) + return DAG.getMemIntrinsicNode(X86ISD::LINC, SDLoc(N), + DAG.getVTList(MVT::i32, MVT::Other), + {N->getOperand(0), N->getOperand(1)}, + /*MemVT=*/N->getSimpleValueType(0), MMO); + if ((NewOpc == X86ISD::LSUB && C->isOne()) || + (NewOpc == X86ISD::LADD && C->isAllOnesValue())) + return DAG.getMemIntrinsicNode(X86ISD::LDEC, SDLoc(N), + DAG.getVTList(MVT::i32, MVT::Other), + {N->getOperand(0), N->getOperand(1)}, + /*MemVT=*/N->getSimpleValueType(0), MMO); + } + } + return DAG.getMemIntrinsicNode( NewOpc, SDLoc(N), DAG.getVTList(MVT::i32, MVT::Other), {N->getOperand(0), N->getOperand(1), N->getOperand(2)}, @@ -23506,7 +24216,7 @@ static SDValue lowerAtomicArith(SDValue N, SelectionDAG &DAG, return N; } - SDValue LockOp = lowerAtomicArithWithLOCK(N, DAG); + SDValue LockOp = lowerAtomicArithWithLOCK(N, DAG, Subtarget); // RAUW the chain, but don't worry about the result, as it's unused. assert(!N->hasAnyUseOfValue(0)); DAG.ReplaceAllUsesOfValueWith(N.getValue(1), LockOp.getValue(1)); @@ -23591,8 +24301,9 @@ static SDValue LowerFSINCOS(SDValue Op, const X86Subtarget &Subtarget, // Only optimize x86_64 for now. i386 is a bit messy. For f32, // the small struct {f32, f32} is returned in (eax, edx). For f64, // the results are returned via SRet in memory. - const char *LibcallName = isF64 ? "__sincos_stret" : "__sincosf_stret"; const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + RTLIB::Libcall LC = isF64 ? RTLIB::SINCOS_STRET_F64 : RTLIB::SINCOS_STRET_F32; + const char *LibcallName = TLI.getLibcallName(LC); SDValue Callee = DAG.getExternalSymbol(LibcallName, TLI.getPointerTy(DAG.getDataLayout())); @@ -23675,19 +24386,12 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget, assert(Subtarget.hasAVX512() && "MGATHER/MSCATTER are supported on AVX-512 arch only"); - // X86 scatter kills mask register, so its type should be added to - // the list of return values. - // If the "scatter" has 2 return values, it is already handled. 
- if (Op.getNode()->getNumValues() == 2) - return Op; - MaskedScatterSDNode *N = cast<MaskedScatterSDNode>(Op.getNode()); SDValue Src = N->getValue(); MVT VT = Src.getSimpleValueType(); assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op"); SDLoc dl(Op); - SDValue NewScatter; SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); SDValue Chain = N->getChain(); @@ -23758,8 +24462,8 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget, // The mask is killed by scatter, add it to the values SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other); SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index}; - NewScatter = DAG.getMaskedScatter(VTs, N->getMemoryVT(), dl, Ops, - N->getMemOperand()); + SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>( + VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); return SDValue(NewScatter.getNode(), 1); } @@ -23874,8 +24578,8 @@ static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget, static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - assert(Subtarget.hasAVX512() && - "MGATHER/MSCATTER are supported on AVX-512 arch only"); + assert(Subtarget.hasAVX2() && + "MGATHER/MSCATTER are supported on AVX-512/AVX-2 arch only"); MaskedGatherSDNode *N = cast<MaskedGatherSDNode>(Op.getNode()); SDLoc dl(Op); @@ -23889,17 +24593,22 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget, unsigned NumElts = VT.getVectorNumElements(); assert(VT.getScalarSizeInBits() >= 32 && "Unsupported gather op"); - if (!Subtarget.hasVLX() && !VT.is512BitVector() && + // If the index is v2i32, we're being called by type legalization. + if (IndexVT == MVT::v2i32) + return SDValue(); + + if (Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() && !Index.getSimpleValueType().is512BitVector()) { // AVX512F supports only 512-bit vectors. Or data or index should // be 512 bit wide. If now the both index and data are 256-bit, but // the vector contains 8 elements, we just sign-extend the index if (NumElts == 8) { Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - SDValue Ops[] = { N->getOperand(0), N->getOperand(1), N->getOperand(2), - N->getOperand(3), Index }; - DAG.UpdateNodeOperands(N, Ops); - return Op; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( + DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), + N->getMemOperand()); + return DAG.getMergeValues({NewGather, NewGather.getValue(2)}, dl); } // Minimal number of elements in Gather @@ -23923,67 +24632,21 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget, Src0 = ExtendToType(Src0, NewVT, DAG); SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; - SDValue NewGather = DAG.getMaskedGather(DAG.getVTList(NewVT, MVT::Other), - N->getMemoryVT(), dl, Ops, - N->getMemOperand()); - SDValue Exract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, - NewGather.getValue(0), - DAG.getIntPtrConstant(0, dl)); - SDValue RetOps[] = {Exract, NewGather.getValue(1)}; - return DAG.getMergeValues(RetOps, dl); - } - if (N->getMemoryVT() == MVT::v2i32 && Subtarget.hasVLX()) { - // There is a special case when the return type is v2i32 is illegal and - // the type legaizer extended it to v2i64. Without this conversion we end up - // with VPGATHERQQ (reading q-words from the memory) instead of VPGATHERQD. 
- // In order to avoid this situation, we'll build an X86 specific Gather node - // with index v2i64 and value type v4i32. - assert(VT == MVT::v2i64 && Src0.getValueType() == MVT::v2i64 && - "Unexpected type in masked gather"); - Src0 = DAG.getVectorShuffle(MVT::v4i32, dl, - DAG.getBitcast(MVT::v4i32, Src0), - DAG.getUNDEF(MVT::v4i32), { 0, 2, -1, -1 }); - // The mask should match the destination type. Extending mask with zeroes - // is not necessary since instruction itself reads only two values from - // memory. - Mask = ExtendToType(Mask, MVT::v4i1, DAG, false); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( - DAG.getVTList(MVT::v4i32, MVT::Other), Ops, dl, N->getMemoryVT(), - N->getMemOperand()); - - SDValue Sext = getExtendInVec(X86ISD::VSEXT, dl, MVT::v2i64, - NewGather.getValue(0), DAG); - SDValue RetOps[] = { Sext, NewGather.getValue(1) }; + DAG.getVTList(NewVT, MaskBitVT, MVT::Other), Ops, dl, N->getMemoryVT(), + N->getMemOperand()); + SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, + NewGather.getValue(0), + DAG.getIntPtrConstant(0, dl)); + SDValue RetOps[] = {Extract, NewGather.getValue(2)}; return DAG.getMergeValues(RetOps, dl); } - if (N->getMemoryVT() == MVT::v2f32 && Subtarget.hasVLX()) { - // This transformation is for optimization only. - // The type legalizer extended mask and index to 4 elements vector - // in order to match requirements of the common gather node - same - // vector width of index and value. X86 Gather node allows mismatch - // of vector width in order to select more optimal instruction at the - // end. - assert(VT == MVT::v4f32 && Src0.getValueType() == MVT::v4f32 && - "Unexpected type in masked gather"); - if (Mask.getOpcode() == ISD::CONCAT_VECTORS && - ISD::isBuildVectorAllZeros(Mask.getOperand(1).getNode()) && - Index.getOpcode() == ISD::CONCAT_VECTORS && - Index.getOperand(1).isUndef()) { - Mask = ExtendToType(Mask.getOperand(0), MVT::v4i1, DAG, false); - Index = Index.getOperand(0); - } else - return Op; - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; - SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( - DAG.getVTList(MVT::v4f32, MVT::Other), Ops, dl, N->getMemoryVT(), - N->getMemOperand()); - SDValue RetOps[] = { NewGather.getValue(0), NewGather.getValue(1) }; - return DAG.getMergeValues(RetOps, dl); - - } - return Op; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( + DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), + N->getMemOperand()); + return DAG.getMergeValues({NewGather, NewGather.getValue(2)}, dl); } SDValue X86TargetLowering::LowerGC_TRANSITION_START(SDValue Op, @@ -24049,8 +24712,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::VSELECT: return LowerVSELECT(Op, DAG); case ISD::EXTRACT_VECTOR_ELT: return LowerEXTRACT_VECTOR_ELT(Op, DAG); case ISD::INSERT_VECTOR_ELT: return LowerINSERT_VECTOR_ELT(Op, DAG); - case ISD::EXTRACT_SUBVECTOR: return LowerEXTRACT_SUBVECTOR(Op,Subtarget,DAG); case ISD::INSERT_SUBVECTOR: return LowerINSERT_SUBVECTOR(Op, Subtarget,DAG); + case ISD::EXTRACT_SUBVECTOR: return LowerEXTRACT_SUBVECTOR(Op,Subtarget,DAG); case ISD::SCALAR_TO_VECTOR: return LowerSCALAR_TO_VECTOR(Op, Subtarget,DAG); case ISD::ConstantPool: return LowerConstantPool(Op, DAG); case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG); @@ 
-24085,7 +24748,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::VASTART: return LowerVASTART(Op, DAG); case ISD::VAARG: return LowerVAARG(Op, DAG); case ISD::VACOPY: return LowerVACOPY(Op, Subtarget, DAG); - case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, Subtarget, DAG); + case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG); case ISD::INTRINSIC_VOID: case ISD::INTRINSIC_W_CHAIN: return LowerINTRINSIC_W_CHAIN(Op, Subtarget, DAG); case ISD::RETURNADDR: return LowerRETURNADDR(Op, DAG); @@ -24203,8 +24866,10 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, SDValue InVec1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, RegVT, Ops); SDValue Res = DAG.getNode(X86ISD::AVG, dl, RegVT, InVec0, InVec1); - Results.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, InVT, Res, - DAG.getIntPtrConstant(0, dl))); + if (!ExperimentalVectorWideningLegalization) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, InVT, Res, + DAG.getIntPtrConstant(0, dl)); + Results.push_back(Res); return; } // We might have generated v2f32 FMIN/FMAX operations. Widen them to v4f32. @@ -24242,11 +24907,21 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); SDValue Src = N->getOperand(0); if (Src.getValueType() == MVT::v2f64) { - SDValue Idx = DAG.getIntPtrConstant(0, dl); - SDValue Res = DAG.getNode(IsSigned ? X86ISD::CVTTP2SI - : X86ISD::CVTTP2UI, - dl, MVT::v4i32, Src); - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx); + MVT ResVT = MVT::v4i32; + unsigned Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI; + if (!IsSigned && !Subtarget.hasVLX()) { + // Widen to 512-bits. + ResVT = MVT::v8i32; + Opc = ISD::FP_TO_UINT; + Src = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8f64, + DAG.getUNDEF(MVT::v8f64), + Src, DAG.getIntPtrConstant(0, dl)); + } + SDValue Res = DAG.getNode(Opc, dl, ResVT, Src); + ResVT = ExperimentalVectorWideningLegalization ? MVT::v4i32 + : MVT::v2i32; + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Res, + DAG.getIntPtrConstant(0, dl)); Results.push_back(Res); return; } @@ -24256,7 +24931,8 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, DAG.getUNDEF(MVT::v2f32)); Res = DAG.getNode(IsSigned ? 
ISD::FP_TO_SINT : ISD::FP_TO_UINT, dl, MVT::v4i32, Res); - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx); + if (!ExperimentalVectorWideningLegalization) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Idx); Results.push_back(Res); return; } @@ -24345,7 +25021,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, } } case ISD::INTRINSIC_WO_CHAIN: { - if (SDValue V = LowerINTRINSIC_WO_CHAIN(SDValue(N, 0), Subtarget, DAG)) + if (SDValue V = LowerINTRINSIC_WO_CHAIN(SDValue(N, 0), DAG)) Results.push_back(V); return; } @@ -24454,7 +25130,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, case ISD::BITCAST: { assert(Subtarget.hasSSE2() && "Requires at least SSE2!"); EVT DstVT = N->getValueType(0); - EVT SrcVT = N->getOperand(0)->getValueType(0); + EVT SrcVT = N->getOperand(0).getValueType(); if (SrcVT != MVT::f64 || (DstVT != MVT::v2i32 && DstVT != MVT::v4i16 && DstVT != MVT::v8i8)) @@ -24480,6 +25156,89 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N, ToVecInt, DAG.getIntPtrConstant(i, dl))); Results.push_back(DAG.getBuildVector(DstVT, dl, Elts)); + return; + } + case ISD::MGATHER: { + EVT VT = N->getValueType(0); + if (VT == MVT::v2f32 && (Subtarget.hasVLX() || !Subtarget.hasAVX512())) { + auto *Gather = cast<MaskedGatherSDNode>(N); + SDValue Index = Gather->getIndex(); + if (Index.getValueType() != MVT::v2i64) + return; + SDValue Mask = Gather->getMask(); + assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type"); + SDValue Src0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f32, + Gather->getValue(), + DAG.getUNDEF(MVT::v2f32)); + if (!Subtarget.hasVLX()) { + // We need to widen the mask, but the instruction will only use 2 + // of its elements. So we can use undef. + Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, + DAG.getUNDEF(MVT::v2i1)); + Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask); + } + SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), + Index }; + SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( + DAG.getVTList(MVT::v4f32, Mask.getValueType(), MVT::Other), Ops, dl, + Gather->getMemoryVT(), Gather->getMemOperand()); + Results.push_back(Res); + Results.push_back(Res.getValue(2)); + return; + } + if (VT == MVT::v2i32) { + auto *Gather = cast<MaskedGatherSDNode>(N); + SDValue Index = Gather->getIndex(); + SDValue Mask = Gather->getMask(); + assert(Mask.getValueType() == MVT::v2i1 && "Unexpected mask type"); + SDValue Src0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i32, + Gather->getValue(), + DAG.getUNDEF(MVT::v2i32)); + // If the index is v2i64 we can use it directly. + if (Index.getValueType() == MVT::v2i64 && + (Subtarget.hasVLX() || !Subtarget.hasAVX512())) { + if (!Subtarget.hasVLX()) { + // We need to widen the mask, but the instruction will only use 2 + // of its elements. So we can use undef. 
+ Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, + DAG.getUNDEF(MVT::v2i1)); + Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask); + } + SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), + Index }; + SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>( + DAG.getVTList(MVT::v4i32, Mask.getValueType(), MVT::Other), Ops, dl, + Gather->getMemoryVT(), Gather->getMemOperand()); + SDValue Chain = Res.getValue(2); + if (!ExperimentalVectorWideningLegalization) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, + DAG.getIntPtrConstant(0, dl)); + Results.push_back(Res); + Results.push_back(Chain); + return; + } + EVT IndexVT = Index.getValueType(); + EVT NewIndexVT = EVT::getVectorVT(*DAG.getContext(), + IndexVT.getScalarType(), 4); + // Otherwise we need to custom widen everything to avoid promotion. + Index = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewIndexVT, Index, + DAG.getUNDEF(IndexVT)); + Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, + DAG.getConstant(0, dl, MVT::v2i1)); + SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), + Index }; + SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other), + Gather->getMemoryVT(), dl, Ops, + Gather->getMemOperand()); + SDValue Chain = Res.getValue(1); + if (!ExperimentalVectorWideningLegalization) + Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, + DAG.getIntPtrConstant(0, dl)); + Results.push_back(Res); + Results.push_back(Chain); + return; + } + break; } } } @@ -24557,9 +25316,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::FMAXC: return "X86ISD::FMAXC"; case X86ISD::FMINC: return "X86ISD::FMINC"; case X86ISD::FRSQRT: return "X86ISD::FRSQRT"; - case X86ISD::FRSQRTS: return "X86ISD::FRSQRTS"; case X86ISD::FRCP: return "X86ISD::FRCP"; - case X86ISD::FRCPS: return "X86ISD::FRCPS"; case X86ISD::EXTRQI: return "X86ISD::EXTRQI"; case X86ISD::INSERTQI: return "X86ISD::INSERTQI"; case X86ISD::TLSADDR: return "X86ISD::TLSADDR"; @@ -24585,6 +25342,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::LOR: return "X86ISD::LOR"; case X86ISD::LXOR: return "X86ISD::LXOR"; case X86ISD::LAND: return "X86ISD::LAND"; + case X86ISD::LINC: return "X86ISD::LINC"; + case X86ISD::LDEC: return "X86ISD::LDEC"; case X86ISD::VZEXT_MOVL: return "X86ISD::VZEXT_MOVL"; case X86ISD::VZEXT_LOAD: return "X86ISD::VZEXT_LOAD"; case X86ISD::VZEXT: return "X86ISD::VZEXT"; @@ -24620,6 +25379,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::PCMPGT: return "X86ISD::PCMPGT"; case X86ISD::PCMPEQM: return "X86ISD::PCMPEQM"; case X86ISD::PCMPGTM: return "X86ISD::PCMPGTM"; + case X86ISD::PHMINPOS: return "X86ISD::PHMINPOS"; case X86ISD::ADD: return "X86ISD::ADD"; case X86ISD::SUB: return "X86ISD::SUB"; case X86ISD::ADC: return "X86ISD::ADC"; @@ -24635,7 +25395,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::OR: return "X86ISD::OR"; case X86ISD::XOR: return "X86ISD::XOR"; case X86ISD::AND: return "X86ISD::AND"; - case X86ISD::BEXTR: return "X86ISD::BEXTR"; case X86ISD::MUL_IMM: return "X86ISD::MUL_IMM"; case X86ISD::MOVMSK: return "X86ISD::MOVMSK"; case X86ISD::PTEST: return "X86ISD::PTEST"; @@ -24650,13 +25409,16 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::PACKUS: return "X86ISD::PACKUS"; case X86ISD::PALIGNR: return "X86ISD::PALIGNR"; case X86ISD::VALIGN: 
return "X86ISD::VALIGN"; + case X86ISD::VSHLD: return "X86ISD::VSHLD"; + case X86ISD::VSHRD: return "X86ISD::VSHRD"; + case X86ISD::VSHLDV: return "X86ISD::VSHLDV"; + case X86ISD::VSHRDV: return "X86ISD::VSHRDV"; case X86ISD::PSHUFD: return "X86ISD::PSHUFD"; case X86ISD::PSHUFHW: return "X86ISD::PSHUFHW"; case X86ISD::PSHUFLW: return "X86ISD::PSHUFLW"; case X86ISD::SHUFP: return "X86ISD::SHUFP"; case X86ISD::SHUF128: return "X86ISD::SHUF128"; case X86ISD::MOVLHPS: return "X86ISD::MOVLHPS"; - case X86ISD::MOVLHPD: return "X86ISD::MOVLHPD"; case X86ISD::MOVHLPS: return "X86ISD::MOVHLPS"; case X86ISD::MOVLPS: return "X86ISD::MOVLPS"; case X86ISD::MOVLPD: return "X86ISD::MOVLPD"; @@ -24670,7 +25432,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::VBROADCAST: return "X86ISD::VBROADCAST"; case X86ISD::VBROADCASTM: return "X86ISD::VBROADCASTM"; case X86ISD::SUBV_BROADCAST: return "X86ISD::SUBV_BROADCAST"; - case X86ISD::VEXTRACT: return "X86ISD::VEXTRACT"; case X86ISD::VPERMILPV: return "X86ISD::VPERMILPV"; case X86ISD::VPERMILPI: return "X86ISD::VPERMILPI"; case X86ISD::VPERM2X128: return "X86ISD::VPERM2X128"; @@ -24680,8 +25441,11 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::VPERMI: return "X86ISD::VPERMI"; case X86ISD::VPTERNLOG: return "X86ISD::VPTERNLOG"; case X86ISD::VFIXUPIMM: return "X86ISD::VFIXUPIMM"; - case X86ISD::VFIXUPIMMS: return "X86ISD::VFIXUPIMMS"; + case X86ISD::VFIXUPIMMS: return "X86ISD::VFIXUPIMMS"; case X86ISD::VRANGE: return "X86ISD::VRANGE"; + case X86ISD::VRANGE_RND: return "X86ISD::VRANGE_RND"; + case X86ISD::VRANGES: return "X86ISD::VRANGES"; + case X86ISD::VRANGES_RND: return "X86ISD::VRANGES_RND"; case X86ISD::PMULUDQ: return "X86ISD::PMULUDQ"; case X86ISD::PMULDQ: return "X86ISD::PMULDQ"; case X86ISD::PSADBW: return "X86ISD::PSADBW"; @@ -24697,14 +25461,11 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::RDSEED: return "X86ISD::RDSEED"; case X86ISD::VPMADDUBSW: return "X86ISD::VPMADDUBSW"; case X86ISD::VPMADDWD: return "X86ISD::VPMADDWD"; - case X86ISD::VPROT: return "X86ISD::VPROT"; - case X86ISD::VPROTI: return "X86ISD::VPROTI"; case X86ISD::VPSHA: return "X86ISD::VPSHA"; case X86ISD::VPSHL: return "X86ISD::VPSHL"; case X86ISD::VPCOM: return "X86ISD::VPCOM"; case X86ISD::VPCOMU: return "X86ISD::VPCOMU"; case X86ISD::VPERMIL2: return "X86ISD::VPERMIL2"; - case X86ISD::FMADD: return "X86ISD::FMADD"; case X86ISD::FMSUB: return "X86ISD::FMSUB"; case X86ISD::FNMADD: return "X86ISD::FNMADD"; case X86ISD::FNMSUB: return "X86ISD::FNMSUB"; @@ -24716,22 +25477,40 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::FNMSUB_RND: return "X86ISD::FNMSUB_RND"; case X86ISD::FMADDSUB_RND: return "X86ISD::FMADDSUB_RND"; case X86ISD::FMSUBADD_RND: return "X86ISD::FMSUBADD_RND"; + case X86ISD::FMADDS1: return "X86ISD::FMADDS1"; + case X86ISD::FNMADDS1: return "X86ISD::FNMADDS1"; + case X86ISD::FMSUBS1: return "X86ISD::FMSUBS1"; + case X86ISD::FNMSUBS1: return "X86ISD::FNMSUBS1"; case X86ISD::FMADDS1_RND: return "X86ISD::FMADDS1_RND"; case X86ISD::FNMADDS1_RND: return "X86ISD::FNMADDS1_RND"; case X86ISD::FMSUBS1_RND: return "X86ISD::FMSUBS1_RND"; case X86ISD::FNMSUBS1_RND: return "X86ISD::FNMSUBS1_RND"; + case X86ISD::FMADDS3: return "X86ISD::FMADDS3"; + case X86ISD::FNMADDS3: return "X86ISD::FNMADDS3"; + case X86ISD::FMSUBS3: return "X86ISD::FMSUBS3"; + case X86ISD::FNMSUBS3: return "X86ISD::FNMSUBS3"; case 
X86ISD::FMADDS3_RND: return "X86ISD::FMADDS3_RND"; case X86ISD::FNMADDS3_RND: return "X86ISD::FNMADDS3_RND"; case X86ISD::FMSUBS3_RND: return "X86ISD::FMSUBS3_RND"; case X86ISD::FNMSUBS3_RND: return "X86ISD::FNMSUBS3_RND"; + case X86ISD::FMADD4S: return "X86ISD::FMADD4S"; + case X86ISD::FNMADD4S: return "X86ISD::FNMADD4S"; + case X86ISD::FMSUB4S: return "X86ISD::FMSUB4S"; + case X86ISD::FNMSUB4S: return "X86ISD::FNMSUB4S"; case X86ISD::VPMADD52H: return "X86ISD::VPMADD52H"; case X86ISD::VPMADD52L: return "X86ISD::VPMADD52L"; case X86ISD::VRNDSCALE: return "X86ISD::VRNDSCALE"; + case X86ISD::VRNDSCALE_RND: return "X86ISD::VRNDSCALE_RND"; case X86ISD::VRNDSCALES: return "X86ISD::VRNDSCALES"; + case X86ISD::VRNDSCALES_RND: return "X86ISD::VRNDSCALES_RND"; case X86ISD::VREDUCE: return "X86ISD::VREDUCE"; + case X86ISD::VREDUCE_RND: return "X86ISD::VREDUCE_RND"; case X86ISD::VREDUCES: return "X86ISD::VREDUCES"; + case X86ISD::VREDUCES_RND: return "X86ISD::VREDUCES_RND"; case X86ISD::VGETMANT: return "X86ISD::VGETMANT"; + case X86ISD::VGETMANT_RND: return "X86ISD::VGETMANT_RND"; case X86ISD::VGETMANTS: return "X86ISD::VGETMANTS"; + case X86ISD::VGETMANTS_RND: return "X86ISD::VGETMANTS_RND"; case X86ISD::PCMPESTRI: return "X86ISD::PCMPESTRI"; case X86ISD::PCMPISTRI: return "X86ISD::PCMPISTRI"; case X86ISD::XTEST: return "X86ISD::XTEST"; @@ -24740,9 +25519,13 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::SELECT: return "X86ISD::SELECT"; case X86ISD::SELECTS: return "X86ISD::SELECTS"; case X86ISD::ADDSUB: return "X86ISD::ADDSUB"; + case X86ISD::RCP14: return "X86ISD::RCP14"; + case X86ISD::RCP14S: return "X86ISD::RCP14S"; case X86ISD::RCP28: return "X86ISD::RCP28"; case X86ISD::RCP28S: return "X86ISD::RCP28S"; case X86ISD::EXP2: return "X86ISD::EXP2"; + case X86ISD::RSQRT14: return "X86ISD::RSQRT14"; + case X86ISD::RSQRT14S: return "X86ISD::RSQRT14S"; case X86ISD::RSQRT28: return "X86ISD::RSQRT28"; case X86ISD::RSQRT28S: return "X86ISD::RSQRT28S"; case X86ISD::FADD_RND: return "X86ISD::FADD_RND"; @@ -24780,6 +25563,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::SCALAR_UINT_TO_FP_RND: return "X86ISD::SCALAR_UINT_TO_FP_RND"; case X86ISD::CVTPS2PH: return "X86ISD::CVTPS2PH"; case X86ISD::CVTPH2PS: return "X86ISD::CVTPH2PS"; + case X86ISD::CVTPH2PS_RND: return "X86ISD::CVTPH2PS_RND"; case X86ISD::CVTP2SI: return "X86ISD::CVTP2SI"; case X86ISD::CVTP2UI: return "X86ISD::CVTP2UI"; case X86ISD::CVTP2SI_RND: return "X86ISD::CVTP2SI_RND"; @@ -24788,6 +25572,15 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { case X86ISD::CVTS2UI_RND: return "X86ISD::CVTS2UI_RND"; case X86ISD::LWPINS: return "X86ISD::LWPINS"; case X86ISD::MGATHER: return "X86ISD::MGATHER"; + case X86ISD::MSCATTER: return "X86ISD::MSCATTER"; + case X86ISD::VPDPBUSD: return "X86ISD::VPDPBUSD"; + case X86ISD::VPDPBUSDS: return "X86ISD::VPDPBUSDS"; + case X86ISD::VPDPWSSD: return "X86ISD::VPDPWSSD"; + case X86ISD::VPDPWSSDS: return "X86ISD::VPDPWSSDS"; + case X86ISD::VPSHUFBITQMB: return "X86ISD::VPSHUFBITQMB"; + case X86ISD::GF2P8MULB: return "X86ISD::GF2P8MULB"; + case X86ISD::GF2P8AFFINEQB: return "X86ISD::GF2P8AFFINEQB"; + case X86ISD::GF2P8AFFINEINVQB: return "X86ISD::GF2P8AFFINEINVQB"; } return nullptr; } @@ -24796,7 +25589,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const { /// target, for a load/store of the specified type. 
bool X86TargetLowering::isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty, - unsigned AS) const { + unsigned AS, + Instruction *I) const { // X86 supports extremely general addressing modes. CodeModel::Model M = getTargetMachine().getCodeModel(); @@ -24853,9 +25647,9 @@ bool X86TargetLowering::isVectorShiftByScalarCheap(Type *Ty) const { if (Bits == 8) return false; - // On AVX2 there are new vpsllv[dq] instructions (and other shifts), that make - // variable shifts just as cheap as scalar ones. - if (Subtarget.hasInt256() && (Bits == 32 || Bits == 64)) + // AVX2 has vpsllv[dq] instructions (and other shifts) that make variable + // shifts just as cheap as scalar ones. + if (Subtarget.hasAVX2() && (Bits == 32 || Bits == 64)) return false; // Otherwise, it's significantly cheaper to shift by a scalar amount than by a @@ -24968,9 +25762,7 @@ bool X86TargetLowering::isNarrowingProfitable(EVT VT1, EVT VT2) const { /// VECTOR_SHUFFLE operations, those with specific masks. /// By default, if a target supports the VECTOR_SHUFFLE node, all mask values /// are assumed to be legal. -bool -X86TargetLowering::isShuffleMaskLegal(const SmallVectorImpl<int> &M, - EVT VT) const { +bool X86TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { if (!VT.isSimple()) return false; @@ -24994,6 +25786,15 @@ X86TargetLowering::isVectorClearMaskLegal(const SmallVectorImpl<int> &Mask, return isShuffleMaskLegal(Mask, VT); } +bool X86TargetLowering::areJTsAllowed(const Function *Fn) const { + // If the subtarget is using retpolines, we need to not generate jump tables. + if (Subtarget.useRetpoline()) + return false; + + // Otherwise, fallback on the generic logic. + return TargetLowering::areJTsAllowed(Fn); +} + //===----------------------------------------------------------------------===// // X86 Scheduler Hooks //===----------------------------------------------------------------------===// @@ -25522,7 +26323,7 @@ MachineBasicBlock *X86TargetLowering::EmitVAStartSaveXMMRegsWithCustomInserter( int64_t RegSaveFrameIndex = MI.getOperand(1).getImm(); int64_t VarArgsFPOffset = MI.getOperand(2).getImm(); - if (!Subtarget.isCallingConvWin64(F->getFunction()->getCallingConv())) { + if (!Subtarget.isCallingConvWin64(F->getFunction().getCallingConv())) { // If %al is 0, branch around the XMM save block. BuildMI(MBB, DL, TII->get(X86::TEST8rr)).addReg(CountReg).addReg(CountReg); BuildMI(MBB, DL, TII->get(X86::JE_1)).addMBB(EndMBB); @@ -25627,65 +26428,76 @@ static bool isCMOVPseudo(MachineInstr &MI) { } } -MachineBasicBlock * -X86TargetLowering::EmitLoweredSelect(MachineInstr &MI, - MachineBasicBlock *BB) const { - const TargetInstrInfo *TII = Subtarget.getInstrInfo(); - DebugLoc DL = MI.getDebugLoc(); +// Helper function, which inserts PHI functions into SinkMBB: +// %Result(i) = phi [ %FalseValue(i), FalseMBB ], [ %TrueValue(i), TrueMBB ], +// where %FalseValue(i) and %TrueValue(i) are taken from the consequent CMOVs +// in [MIItBegin, MIItEnd) range. It returns the last MachineInstrBuilder for +// the last PHI function inserted. 
+static MachineInstrBuilder createPHIsForCMOVsInSinkBB( + MachineBasicBlock::iterator MIItBegin, MachineBasicBlock::iterator MIItEnd, + MachineBasicBlock *TrueMBB, MachineBasicBlock *FalseMBB, + MachineBasicBlock *SinkMBB) { + MachineFunction *MF = TrueMBB->getParent(); + const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo(); + DebugLoc DL = MIItBegin->getDebugLoc(); - // To "insert" a SELECT_CC instruction, we actually have to insert the - // diamond control-flow pattern. The incoming instruction knows the - // destination vreg to set, the condition code register to branch on, the - // true/false values to select between, and a branch opcode to use. - const BasicBlock *LLVM_BB = BB->getBasicBlock(); - MachineFunction::iterator It = ++BB->getIterator(); + X86::CondCode CC = X86::CondCode(MIItBegin->getOperand(3).getImm()); + X86::CondCode OppCC = X86::GetOppositeBranchCondition(CC); - // thisMBB: - // ... - // TrueVal = ... - // cmpTY ccX, r1, r2 - // bCC copy1MBB - // fallthrough --> copy0MBB - MachineBasicBlock *thisMBB = BB; - MachineFunction *F = BB->getParent(); + MachineBasicBlock::iterator SinkInsertionPoint = SinkMBB->begin(); - // This code lowers all pseudo-CMOV instructions. Generally it lowers these - // as described above, by inserting a BB, and then making a PHI at the join - // point to select the true and false operands of the CMOV in the PHI. - // - // The code also handles two different cases of multiple CMOV opcodes - // in a row. - // - // Case 1: - // In this case, there are multiple CMOVs in a row, all which are based on - // the same condition setting (or the exact opposite condition setting). - // In this case we can lower all the CMOVs using a single inserted BB, and - // then make a number of PHIs at the join point to model the CMOVs. The only - // trickiness here, is that in a case like: - // - // t2 = CMOV cond1 t1, f1 - // t3 = CMOV cond1 t2, f2 - // - // when rewriting this into PHIs, we have to perform some renaming on the - // temps since you cannot have a PHI operand refer to a PHI result earlier - // in the same block. The "simple" but wrong lowering would be: - // - // t2 = PHI t1(BB1), f1(BB2) - // t3 = PHI t2(BB1), f2(BB2) - // - // but clearly t2 is not defined in BB1, so that is incorrect. The proper - // renaming is to note that on the path through BB1, t2 is really just a - // copy of t1, and do that renaming, properly generating: - // - // t2 = PHI t1(BB1), f1(BB2) - // t3 = PHI t1(BB1), f2(BB2) - // - // Case 2, we lower cascaded CMOVs such as + // As we are creating the PHIs, we have to be careful if there is more than + // one. Later CMOVs may reference the results of earlier CMOVs, but later + // PHIs have to reference the individual true/false inputs from earlier PHIs. + // That also means that PHI construction must work forward from earlier to + // later, and that the code must maintain a mapping from earlier PHI's + // destination registers, and the registers that went into the PHI. + DenseMap<unsigned, std::pair<unsigned, unsigned>> RegRewriteTable; + MachineInstrBuilder MIB; + + for (MachineBasicBlock::iterator MIIt = MIItBegin; MIIt != MIItEnd; ++MIIt) { + unsigned DestReg = MIIt->getOperand(0).getReg(); + unsigned Op1Reg = MIIt->getOperand(1).getReg(); + unsigned Op2Reg = MIIt->getOperand(2).getReg(); + + // If this CMOV we are generating is the opposite condition from + // the jump we generated, then we have to swap the operands for the + // PHI that is going to be generated. 
+ if (MIIt->getOperand(3).getImm() == OppCC) + std::swap(Op1Reg, Op2Reg); + + if (RegRewriteTable.find(Op1Reg) != RegRewriteTable.end()) + Op1Reg = RegRewriteTable[Op1Reg].first; + + if (RegRewriteTable.find(Op2Reg) != RegRewriteTable.end()) + Op2Reg = RegRewriteTable[Op2Reg].second; + + MIB = BuildMI(*SinkMBB, SinkInsertionPoint, DL, TII->get(X86::PHI), DestReg) + .addReg(Op1Reg) + .addMBB(FalseMBB) + .addReg(Op2Reg) + .addMBB(TrueMBB); + + // Add this PHI to the rewrite table. + RegRewriteTable[DestReg] = std::make_pair(Op1Reg, Op2Reg); + } + + return MIB; +} + +// Lower cascaded selects in form of (SecondCmov (FirstCMOV F, T, cc1), T, cc2). +MachineBasicBlock * +X86TargetLowering::EmitLoweredCascadedSelect(MachineInstr &FirstCMOV, + MachineInstr &SecondCascadedCMOV, + MachineBasicBlock *ThisMBB) const { + const TargetInstrInfo *TII = Subtarget.getInstrInfo(); + DebugLoc DL = FirstCMOV.getDebugLoc(); + + // We lower cascaded CMOVs such as // - // (CMOV (CMOV F, T, cc1), T, cc2) + // (SecondCascadedCMOV (FirstCMOV F, T, cc1), T, cc2) // - // to two successive branches. For that, we look for another CMOV as the - // following instruction. + // to two successive branches. // // Without this, we would add a PHI between the two jumps, which ends up // creating a few copies all around. For instance, for @@ -25749,10 +26561,145 @@ X86TargetLowering::EmitLoweredSelect(MachineInstr &MI, // .LBB5_4: // retq // - MachineInstr *CascadedCMOV = nullptr; - MachineInstr *LastCMOV = &MI; + + // We lower cascaded CMOV into two successive branches to the same block. + // EFLAGS is used by both, so mark it as live in the second. + const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock(); + MachineFunction *F = ThisMBB->getParent(); + MachineBasicBlock *FirstInsertedMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineBasicBlock *SecondInsertedMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB); + + MachineFunction::iterator It = ++ThisMBB->getIterator(); + F->insert(It, FirstInsertedMBB); + F->insert(It, SecondInsertedMBB); + F->insert(It, SinkMBB); + + // For a cascaded CMOV, we lower it to two successive branches to + // the same block (SinkMBB). EFLAGS is used by both, so mark it as live in + // the FirstInsertedMBB. + FirstInsertedMBB->addLiveIn(X86::EFLAGS); + + // If the EFLAGS register isn't dead in the terminator, then claim that it's + // live into the sink and copy blocks. + const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); + if (!SecondCascadedCMOV.killsRegister(X86::EFLAGS) && + !checkAndUpdateEFLAGSKill(SecondCascadedCMOV, ThisMBB, TRI)) { + SecondInsertedMBB->addLiveIn(X86::EFLAGS); + SinkMBB->addLiveIn(X86::EFLAGS); + } + + // Transfer the remainder of ThisMBB and its successor edges to SinkMBB. + SinkMBB->splice(SinkMBB->begin(), ThisMBB, + std::next(MachineBasicBlock::iterator(FirstCMOV)), + ThisMBB->end()); + SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB); + + // Fallthrough block for ThisMBB. + ThisMBB->addSuccessor(FirstInsertedMBB); + // The true block target of the first branch is always SinkMBB. + ThisMBB->addSuccessor(SinkMBB); + // Fallthrough block for FirstInsertedMBB. + FirstInsertedMBB->addSuccessor(SecondInsertedMBB); + // The true block for the branch of FirstInsertedMBB. + FirstInsertedMBB->addSuccessor(SinkMBB); + // This is fallthrough. + SecondInsertedMBB->addSuccessor(SinkMBB); + + // Create the conditional branch instructions. 
+ X86::CondCode FirstCC = X86::CondCode(FirstCMOV.getOperand(3).getImm()); + unsigned Opc = X86::GetCondBranchFromCond(FirstCC); + BuildMI(ThisMBB, DL, TII->get(Opc)).addMBB(SinkMBB); + + X86::CondCode SecondCC = + X86::CondCode(SecondCascadedCMOV.getOperand(3).getImm()); + unsigned Opc2 = X86::GetCondBranchFromCond(SecondCC); + BuildMI(FirstInsertedMBB, DL, TII->get(Opc2)).addMBB(SinkMBB); + + // SinkMBB: + // %Result = phi [ %FalseValue, SecondInsertedMBB ], [ %TrueValue, ThisMBB ] + unsigned DestReg = FirstCMOV.getOperand(0).getReg(); + unsigned Op1Reg = FirstCMOV.getOperand(1).getReg(); + unsigned Op2Reg = FirstCMOV.getOperand(2).getReg(); + MachineInstrBuilder MIB = + BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII->get(X86::PHI), DestReg) + .addReg(Op1Reg) + .addMBB(SecondInsertedMBB) + .addReg(Op2Reg) + .addMBB(ThisMBB); + + // The second SecondInsertedMBB provides the same incoming value as the + // FirstInsertedMBB (the True operand of the SELECT_CC/CMOV nodes). + MIB.addReg(FirstCMOV.getOperand(2).getReg()).addMBB(FirstInsertedMBB); + // Copy the PHI result to the register defined by the second CMOV. + BuildMI(*SinkMBB, std::next(MachineBasicBlock::iterator(MIB.getInstr())), DL, + TII->get(TargetOpcode::COPY), + SecondCascadedCMOV.getOperand(0).getReg()) + .addReg(FirstCMOV.getOperand(0).getReg()); + + // Now remove the CMOVs. + FirstCMOV.eraseFromParent(); + SecondCascadedCMOV.eraseFromParent(); + + return SinkMBB; +} + +MachineBasicBlock * +X86TargetLowering::EmitLoweredSelect(MachineInstr &MI, + MachineBasicBlock *ThisMBB) const { + const TargetInstrInfo *TII = Subtarget.getInstrInfo(); + DebugLoc DL = MI.getDebugLoc(); + + // To "insert" a SELECT_CC instruction, we actually have to insert the + // diamond control-flow pattern. The incoming instruction knows the + // destination vreg to set, the condition code register to branch on, the + // true/false values to select between and a branch opcode to use. + + // ThisMBB: + // ... + // TrueVal = ... + // cmpTY ccX, r1, r2 + // bCC copy1MBB + // fallthrough --> FalseMBB + + // This code lowers all pseudo-CMOV instructions. Generally it lowers these + // as described above, by inserting a BB, and then making a PHI at the join + // point to select the true and false operands of the CMOV in the PHI. + // + // The code also handles two different cases of multiple CMOV opcodes + // in a row. + // + // Case 1: + // In this case, there are multiple CMOVs in a row, all which are based on + // the same condition setting (or the exact opposite condition setting). + // In this case we can lower all the CMOVs using a single inserted BB, and + // then make a number of PHIs at the join point to model the CMOVs. The only + // trickiness here, is that in a case like: + // + // t2 = CMOV cond1 t1, f1 + // t3 = CMOV cond1 t2, f2 + // + // when rewriting this into PHIs, we have to perform some renaming on the + // temps since you cannot have a PHI operand refer to a PHI result earlier + // in the same block. The "simple" but wrong lowering would be: + // + // t2 = PHI t1(BB1), f1(BB2) + // t3 = PHI t2(BB1), f2(BB2) + // + // but clearly t2 is not defined in BB1, so that is incorrect. The proper + // renaming is to note that on the path through BB1, t2 is really just a + // copy of t1, and do that renaming, properly generating: + // + // t2 = PHI t1(BB1), f1(BB2) + // t3 = PHI t1(BB1), f2(BB2) + // + // Case 2: + // CMOV ((CMOV F, T, cc1), T, cc2) is checked here and handled by a separate + // function - EmitLoweredCascadedSelect. 
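A source-level sketch of the Case 2 shape (illustrative only; names invented): the selected value T wins if either of two EFLAGS tests succeeds, as happens for unordered FP compares or an explicit disjunction.

// Whether this exact form reaches the pseudo-CMOV path depends on the rest of
// isel, but the nested shape it can produce is
// (CMOV (CMOV F, T, cc1), T, cc2), i.e. T when cc1 || cc2 holds.
static int SelectOnEitherCondition(bool C1, bool C2, int T, int F) {
  return (C1 || C2) ? T : F;
}
// EmitLoweredCascadedSelect then emits two conditional branches into a single
// sink block with one PHI, instead of chaining conditional moves and copies.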
+ X86::CondCode CC = X86::CondCode(MI.getOperand(3).getImm()); X86::CondCode OppCC = X86::GetOppositeBranchCondition(CC); + MachineInstr *LastCMOV = &MI; MachineBasicBlock::iterator NextMIIt = std::next(MachineBasicBlock::iterator(MI)); @@ -25762,7 +26709,7 @@ X86TargetLowering::EmitLoweredSelect(MachineInstr &MI, if (isCMOVPseudo(MI)) { // See if we have a string of CMOVS with the same condition. - while (NextMIIt != BB->end() && isCMOVPseudo(*NextMIIt) && + while (NextMIIt != ThisMBB->end() && isCMOVPseudo(*NextMIIt) && (NextMIIt->getOperand(3).getImm() == CC || NextMIIt->getOperand(3).getImm() == OppCC)) { LastCMOV = &*NextMIIt; @@ -25772,136 +26719,61 @@ X86TargetLowering::EmitLoweredSelect(MachineInstr &MI, // This checks for case 2, but only do this if we didn't already find // case 1, as indicated by LastCMOV == MI. - if (LastCMOV == &MI && NextMIIt != BB->end() && + if (LastCMOV == &MI && NextMIIt != ThisMBB->end() && NextMIIt->getOpcode() == MI.getOpcode() && NextMIIt->getOperand(2).getReg() == MI.getOperand(2).getReg() && NextMIIt->getOperand(1).getReg() == MI.getOperand(0).getReg() && NextMIIt->getOperand(1).isKill()) { - CascadedCMOV = &*NextMIIt; + return EmitLoweredCascadedSelect(MI, *NextMIIt, ThisMBB); } - MachineBasicBlock *jcc1MBB = nullptr; - - // If we have a cascaded CMOV, we lower it to two successive branches to - // the same block. EFLAGS is used by both, so mark it as live in the second. - if (CascadedCMOV) { - jcc1MBB = F->CreateMachineBasicBlock(LLVM_BB); - F->insert(It, jcc1MBB); - jcc1MBB->addLiveIn(X86::EFLAGS); - } + const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock(); + MachineFunction *F = ThisMBB->getParent(); + MachineBasicBlock *FalseMBB = F->CreateMachineBasicBlock(LLVM_BB); + MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB); - MachineBasicBlock *copy0MBB = F->CreateMachineBasicBlock(LLVM_BB); - MachineBasicBlock *sinkMBB = F->CreateMachineBasicBlock(LLVM_BB); - F->insert(It, copy0MBB); - F->insert(It, sinkMBB); + MachineFunction::iterator It = ++ThisMBB->getIterator(); + F->insert(It, FalseMBB); + F->insert(It, SinkMBB); // If the EFLAGS register isn't dead in the terminator, then claim that it's // live into the sink and copy blocks. const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo(); - - MachineInstr *LastEFLAGSUser = CascadedCMOV ? CascadedCMOV : LastCMOV; - if (!LastEFLAGSUser->killsRegister(X86::EFLAGS) && - !checkAndUpdateEFLAGSKill(LastEFLAGSUser, BB, TRI)) { - copy0MBB->addLiveIn(X86::EFLAGS); - sinkMBB->addLiveIn(X86::EFLAGS); + if (!LastCMOV->killsRegister(X86::EFLAGS) && + !checkAndUpdateEFLAGSKill(LastCMOV, ThisMBB, TRI)) { + FalseMBB->addLiveIn(X86::EFLAGS); + SinkMBB->addLiveIn(X86::EFLAGS); } - // Transfer the remainder of BB and its successor edges to sinkMBB. - sinkMBB->splice(sinkMBB->begin(), BB, - std::next(MachineBasicBlock::iterator(LastCMOV)), BB->end()); - sinkMBB->transferSuccessorsAndUpdatePHIs(BB); - - // Add the true and fallthrough blocks as its successors. - if (CascadedCMOV) { - // The fallthrough block may be jcc1MBB, if we have a cascaded CMOV. - BB->addSuccessor(jcc1MBB); - - // In that case, jcc1MBB will itself fallthrough the copy0MBB, and - // jump to the sinkMBB. - jcc1MBB->addSuccessor(copy0MBB); - jcc1MBB->addSuccessor(sinkMBB); - } else { - BB->addSuccessor(copy0MBB); - } + // Transfer the remainder of ThisMBB and its successor edges to SinkMBB. 
+ SinkMBB->splice(SinkMBB->begin(), ThisMBB, + std::next(MachineBasicBlock::iterator(LastCMOV)), + ThisMBB->end()); + SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB); - // The true block target of the first (or only) branch is always sinkMBB. - BB->addSuccessor(sinkMBB); + // Fallthrough block for ThisMBB. + ThisMBB->addSuccessor(FalseMBB); + // The true block target of the first (or only) branch is always a SinkMBB. + ThisMBB->addSuccessor(SinkMBB); + // Fallthrough block for FalseMBB. + FalseMBB->addSuccessor(SinkMBB); // Create the conditional branch instruction. unsigned Opc = X86::GetCondBranchFromCond(CC); - BuildMI(BB, DL, TII->get(Opc)).addMBB(sinkMBB); - - if (CascadedCMOV) { - unsigned Opc2 = X86::GetCondBranchFromCond( - (X86::CondCode)CascadedCMOV->getOperand(3).getImm()); - BuildMI(jcc1MBB, DL, TII->get(Opc2)).addMBB(sinkMBB); - } - - // copy0MBB: - // %FalseValue = ... - // # fallthrough to sinkMBB - copy0MBB->addSuccessor(sinkMBB); + BuildMI(ThisMBB, DL, TII->get(Opc)).addMBB(SinkMBB); - // sinkMBB: - // %Result = phi [ %FalseValue, copy0MBB ], [ %TrueValue, thisMBB ] + // SinkMBB: + // %Result = phi [ %FalseValue, FalseMBB ], [ %TrueValue, ThisMBB ] // ... MachineBasicBlock::iterator MIItBegin = MachineBasicBlock::iterator(MI); MachineBasicBlock::iterator MIItEnd = - std::next(MachineBasicBlock::iterator(LastCMOV)); - MachineBasicBlock::iterator SinkInsertionPoint = sinkMBB->begin(); - DenseMap<unsigned, std::pair<unsigned, unsigned>> RegRewriteTable; - MachineInstrBuilder MIB; - - // As we are creating the PHIs, we have to be careful if there is more than - // one. Later CMOVs may reference the results of earlier CMOVs, but later - // PHIs have to reference the individual true/false inputs from earlier PHIs. - // That also means that PHI construction must work forward from earlier to - // later, and that the code must maintain a mapping from earlier PHI's - // destination registers, and the registers that went into the PHI. - - for (MachineBasicBlock::iterator MIIt = MIItBegin; MIIt != MIItEnd; ++MIIt) { - unsigned DestReg = MIIt->getOperand(0).getReg(); - unsigned Op1Reg = MIIt->getOperand(1).getReg(); - unsigned Op2Reg = MIIt->getOperand(2).getReg(); - - // If this CMOV we are generating is the opposite condition from - // the jump we generated, then we have to swap the operands for the - // PHI that is going to be generated. - if (MIIt->getOperand(3).getImm() == OppCC) - std::swap(Op1Reg, Op2Reg); - - if (RegRewriteTable.find(Op1Reg) != RegRewriteTable.end()) - Op1Reg = RegRewriteTable[Op1Reg].first; - - if (RegRewriteTable.find(Op2Reg) != RegRewriteTable.end()) - Op2Reg = RegRewriteTable[Op2Reg].second; - - MIB = BuildMI(*sinkMBB, SinkInsertionPoint, DL, - TII->get(X86::PHI), DestReg) - .addReg(Op1Reg).addMBB(copy0MBB) - .addReg(Op2Reg).addMBB(thisMBB); - - // Add this PHI to the rewrite table. - RegRewriteTable[DestReg] = std::make_pair(Op1Reg, Op2Reg); - } - - // If we have a cascaded CMOV, the second Jcc provides the same incoming - // value as the first Jcc (the True operand of the SELECT_CC/CMOV nodes). - if (CascadedCMOV) { - MIB.addReg(MI.getOperand(2).getReg()).addMBB(jcc1MBB); - // Copy the PHI result to the register defined by the second CMOV. 
- BuildMI(*sinkMBB, std::next(MachineBasicBlock::iterator(MIB.getInstr())), - DL, TII->get(TargetOpcode::COPY), - CascadedCMOV->getOperand(0).getReg()) - .addReg(MI.getOperand(0).getReg()); - CascadedCMOV->eraseFromParent(); - } + std::next(MachineBasicBlock::iterator(LastCMOV)); + createPHIsForCMOVsInSinkBB(MIItBegin, MIItEnd, ThisMBB, FalseMBB, SinkMBB); // Now remove the CMOV(s). - for (MachineBasicBlock::iterator MIIt = MIItBegin; MIIt != MIItEnd; ) - (MIIt++)->eraseFromParent(); + ThisMBB->erase(MIItBegin, MIItEnd); - return sinkMBB; + return SinkMBB; } MachineBasicBlock * @@ -26094,7 +26966,7 @@ X86TargetLowering::EmitLoweredCatchRet(MachineInstr &MI, DebugLoc DL = MI.getDebugLoc(); assert(!isAsynchronousEHPersonality( - classifyEHPersonality(MF->getFunction()->getPersonalityFn())) && + classifyEHPersonality(MF->getFunction().getPersonalityFn())) && "SEH does not use catchret!"); // Only 32-bit EH needs to worry about manually restoring stack pointers. @@ -26121,7 +26993,7 @@ MachineBasicBlock * X86TargetLowering::EmitLoweredCatchPad(MachineInstr &MI, MachineBasicBlock *BB) const { MachineFunction *MF = BB->getParent(); - const Constant *PerFn = MF->getFunction()->getPersonalityFn(); + const Constant *PerFn = MF->getFunction().getPersonalityFn(); bool IsSEH = isAsynchronousEHPersonality(classifyEHPersonality(PerFn)); // Only 32-bit SEH requires special handling for catchpad. if (IsSEH && Subtarget.is32Bit()) { @@ -26225,6 +27097,130 @@ X86TargetLowering::EmitLoweredTLSCall(MachineInstr &MI, return BB; } +static unsigned getOpcodeForRetpoline(unsigned RPOpc) { + switch (RPOpc) { + case X86::RETPOLINE_CALL32: + return X86::CALLpcrel32; + case X86::RETPOLINE_CALL64: + return X86::CALL64pcrel32; + case X86::RETPOLINE_TCRETURN32: + return X86::TCRETURNdi; + case X86::RETPOLINE_TCRETURN64: + return X86::TCRETURNdi64; + } + llvm_unreachable("not retpoline opcode"); +} + +static const char *getRetpolineSymbol(const X86Subtarget &Subtarget, + unsigned Reg) { + if (Subtarget.useRetpolineExternalThunk()) { + // When using an external thunk for retpolines, we pick names that match the + // names GCC happens to use as well. This helps simplify the implementation + // of the thunks for kernels where they have no easy ability to create + // aliases and are doing non-trivial configuration of the thunk's body. For + // example, the Linux kernel will do boot-time hot patching of the thunk + // bodies and cannot easily export aliases of these to loaded modules. + // + // Note that at any point in the future, we may need to change the semantics + // of how we implement retpolines and at that time will likely change the + // name of the called thunk. Essentially, there is no hard guarantee that + // LLVM will generate calls to specific thunks, we merely make a best-effort + // attempt to help out kernels and other systems where duplicating the + // thunks is costly. 
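For orientation, and as an assumed illustration only: once the RETPOLINE_CALL64 pseudo handled below has been expanded by EmitLoweredRetpoline, an indirect call no longer branches through the function pointer directly; the pointer is first copied into the scratch register and a direct call to the per-register thunk is emitted. The function call_indirect is made up; the thunk names come from the switch that follows.

// Illustrative only: expected shape of a lowered 64-bit indirect call. On
// 64-bit targets the scratch register is always R11, and the default COMDAT
// thunk is __llvm_retpoline_r11 (__x86_indirect_thunk_r11 with an external
// thunk).
void call_indirect(void (*Fn)(void)) {
  Fn(); // roughly:  movq %rdi, %r11
        //           callq __llvm_retpoline_r11     # rather than callq *%rdi
}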
+ switch (Reg) { + case X86::EAX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__x86_indirect_thunk_eax"; + case X86::ECX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__x86_indirect_thunk_ecx"; + case X86::EDX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__x86_indirect_thunk_edx"; + case X86::EDI: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__x86_indirect_thunk_edi"; + case X86::R11: + assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!"); + return "__x86_indirect_thunk_r11"; + } + llvm_unreachable("unexpected reg for retpoline"); + } + + // When targeting an internal COMDAT thunk use an LLVM-specific name. + switch (Reg) { + case X86::EAX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__llvm_retpoline_eax"; + case X86::ECX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__llvm_retpoline_ecx"; + case X86::EDX: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__llvm_retpoline_edx"; + case X86::EDI: + assert(!Subtarget.is64Bit() && "Should not be using a 32-bit thunk!"); + return "__llvm_retpoline_edi"; + case X86::R11: + assert(Subtarget.is64Bit() && "Should not be using a 64-bit thunk!"); + return "__llvm_retpoline_r11"; + } + llvm_unreachable("unexpected reg for retpoline"); +} + +MachineBasicBlock * +X86TargetLowering::EmitLoweredRetpoline(MachineInstr &MI, + MachineBasicBlock *BB) const { + // Copy the virtual register into the R11 physical register and + // call the retpoline thunk. + DebugLoc DL = MI.getDebugLoc(); + const X86InstrInfo *TII = Subtarget.getInstrInfo(); + unsigned CalleeVReg = MI.getOperand(0).getReg(); + unsigned Opc = getOpcodeForRetpoline(MI.getOpcode()); + + // Find an available scratch register to hold the callee. On 64-bit, we can + // just use R11, but we scan for uses anyway to ensure we don't generate + // incorrect code. On 32-bit, we use one of EAX, ECX, or EDX that isn't + // already a register use operand to the call to hold the callee. If none + // are available, use EDI instead. EDI is chosen because EBX is the PIC base + // register and ESI is the base pointer to realigned stack frames with VLAs. + SmallVector<unsigned, 3> AvailableRegs; + if (Subtarget.is64Bit()) + AvailableRegs.push_back(X86::R11); + else + AvailableRegs.append({X86::EAX, X86::ECX, X86::EDX, X86::EDI}); + + // Zero out any registers that are already used. + for (const auto &MO : MI.operands()) { + if (MO.isReg() && MO.isUse()) + for (unsigned &Reg : AvailableRegs) + if (Reg == MO.getReg()) + Reg = 0; + } + + // Choose the first remaining non-zero available register. 
+ unsigned AvailableReg = 0; + for (unsigned MaybeReg : AvailableRegs) { + if (MaybeReg) { + AvailableReg = MaybeReg; + break; + } + } + if (!AvailableReg) + report_fatal_error("calling convention incompatible with retpoline, no " + "available registers"); + + const char *Symbol = getRetpolineSymbol(Subtarget, AvailableReg); + + BuildMI(*BB, MI, DL, TII->get(TargetOpcode::COPY), AvailableReg) + .addReg(CalleeVReg); + MI.getOperand(0).ChangeToES(Symbol); + MI.setDesc(TII->get(Opc)); + MachineInstrBuilder(*BB->getParent(), &MI) + .addReg(AvailableReg, RegState::Implicit | RegState::Kill); + return BB; +} + MachineBasicBlock * X86TargetLowering::emitEHSjLjSetJmp(MachineInstr &MI, MachineBasicBlock *MBB) const { @@ -26480,7 +27476,7 @@ void X86TargetLowering::SetupEntryBlockForSjLj(MachineInstr &MI, } MachineInstrBuilder MIB = BuildMI(*MBB, MI, DL, TII->get(Op)); - addFrameReference(MIB, FI, 36); + addFrameReference(MIB, FI, Subtarget.is64Bit() ? 56 : 36); if (UseImmLabel) MIB.addMBB(DispatchBB); else @@ -26562,8 +27558,8 @@ X86TargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, SetupEntryBlockForSjLj(MI, BB, DispatchBB, FI); // Create the jump table and associated information - MachineJumpTableInfo *JTI = - MF->getOrCreateJumpTableInfo(getJumpTableEncoding()); + unsigned JTE = getJumpTableEncoding(); + MachineJumpTableInfo *JTI = MF->getOrCreateJumpTableInfo(JTE); unsigned MJTI = JTI->createJumpTableIndex(LPadList); const X86RegisterInfo &RI = TII->getRegisterInfo(); @@ -26586,25 +27582,76 @@ X86TargetLowering::EmitSjLjDispatchBlock(MachineInstr &MI, .addRegMask(RI.getNoPreservedMask()); } - unsigned IReg = MRI->createVirtualRegister(&X86::GR32RegClass); + // IReg is used as an index in a memory operand and therefore can't be SP + unsigned IReg = MRI->createVirtualRegister(&X86::GR32_NOSPRegClass); addFrameReference(BuildMI(DispatchBB, DL, TII->get(X86::MOV32rm), IReg), FI, - 4); + Subtarget.is64Bit() ? 8 : 4); BuildMI(DispatchBB, DL, TII->get(X86::CMP32ri)) .addReg(IReg) .addImm(LPadList.size()); - BuildMI(DispatchBB, DL, TII->get(X86::JA_1)).addMBB(TrapBB); + BuildMI(DispatchBB, DL, TII->get(X86::JAE_1)).addMBB(TrapBB); - unsigned JReg = MRI->createVirtualRegister(&X86::GR32RegClass); - BuildMI(DispContBB, DL, TII->get(X86::SUB32ri), JReg) - .addReg(IReg) - .addImm(1); - BuildMI(DispContBB, DL, - TII->get(Subtarget.is64Bit() ? X86::JMP64m : X86::JMP32m)) - .addReg(0) - .addImm(Subtarget.is64Bit() ? 
8 : 4) - .addReg(JReg) - .addJumpTableIndex(MJTI) - .addReg(0); + if (Subtarget.is64Bit()) { + unsigned BReg = MRI->createVirtualRegister(&X86::GR64RegClass); + unsigned IReg64 = MRI->createVirtualRegister(&X86::GR64_NOSPRegClass); + + // leaq .LJTI0_0(%rip), BReg + BuildMI(DispContBB, DL, TII->get(X86::LEA64r), BReg) + .addReg(X86::RIP) + .addImm(1) + .addReg(0) + .addJumpTableIndex(MJTI) + .addReg(0); + // movzx IReg64, IReg + BuildMI(DispContBB, DL, TII->get(TargetOpcode::SUBREG_TO_REG), IReg64) + .addImm(0) + .addReg(IReg) + .addImm(X86::sub_32bit); + + switch (JTE) { + case MachineJumpTableInfo::EK_BlockAddress: + // jmpq *(BReg,IReg64,8) + BuildMI(DispContBB, DL, TII->get(X86::JMP64m)) + .addReg(BReg) + .addImm(8) + .addReg(IReg64) + .addImm(0) + .addReg(0); + break; + case MachineJumpTableInfo::EK_LabelDifference32: { + unsigned OReg = MRI->createVirtualRegister(&X86::GR32RegClass); + unsigned OReg64 = MRI->createVirtualRegister(&X86::GR64RegClass); + unsigned TReg = MRI->createVirtualRegister(&X86::GR64RegClass); + + // movl (BReg,IReg64,4), OReg + BuildMI(DispContBB, DL, TII->get(X86::MOV32rm), OReg) + .addReg(BReg) + .addImm(4) + .addReg(IReg64) + .addImm(0) + .addReg(0); + // movsx OReg64, OReg + BuildMI(DispContBB, DL, TII->get(X86::MOVSX64rr32), OReg64).addReg(OReg); + // addq BReg, OReg64, TReg + BuildMI(DispContBB, DL, TII->get(X86::ADD64rr), TReg) + .addReg(OReg64) + .addReg(BReg); + // jmpq *TReg + BuildMI(DispContBB, DL, TII->get(X86::JMP64r)).addReg(TReg); + break; + } + default: + llvm_unreachable("Unexpected jump table encoding"); + } + } else { + // jmpl *.LJTI0_0(,IReg,4) + BuildMI(DispContBB, DL, TII->get(X86::JMP32m)) + .addReg(0) + .addImm(4) + .addReg(IReg) + .addJumpTableIndex(MJTI) + .addReg(0); + } // Add the jump table entries as successors to the MBB. SmallPtrSet<MachineBasicBlock *, 8> SeenMBBs; @@ -26689,6 +27736,11 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, case X86::TLS_base_addr32: case X86::TLS_base_addr64: return EmitLoweredTLSAddr(MI, BB); + case X86::RETPOLINE_CALL32: + case X86::RETPOLINE_CALL64: + case X86::RETPOLINE_TCRETURN32: + case X86::RETPOLINE_TCRETURN64: + return EmitLoweredRetpoline(MI, BB); case X86::CATCHRET: return EmitLoweredCatchRet(MI, BB); case X86::CATCHPAD: @@ -26975,21 +28027,6 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, Known.resetAll(); switch (Opc) { default: break; - case X86ISD::ADD: - case X86ISD::SUB: - case X86ISD::ADC: - case X86ISD::SBB: - case X86ISD::SMUL: - case X86ISD::UMUL: - case X86ISD::INC: - case X86ISD::DEC: - case X86ISD::OR: - case X86ISD::XOR: - case X86ISD::AND: - // These nodes' second result is a boolean. 
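An aside on the 64-bit SjLj dispatch sequence shown a little earlier (the computeKnownBits hunk resumes below): with the EK_LabelDifference32 encoding, each jump-table entry is a signed 32-bit offset from the table base, and the leaq/movl/movslq/addq/jmpq sequence reassembles the absolute target. A hypothetical scalar model with made-up names:

#include <cstdint>

// Scalar equivalent of the EK_LabelDifference32 dispatch above (illustrative).
uintptr_t dispatchTarget(const int32_t *JTBase, uint64_t CallSiteIndex) {
  int64_t Offset = JTBase[CallSiteIndex];              // movl (base,index,4) + movslq
  return reinterpret_cast<uintptr_t>(JTBase) + Offset; // addq base; the jmpq goes here
}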
- if (Op.getResNo() == 0) - break; - LLVM_FALLTHROUGH; case X86ISD::SETCC: Known.Zero.setBitsFrom(1); break; @@ -26998,6 +28035,17 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, Known.Zero.setBitsFrom(NumLoBits); break; } + case X86ISD::PEXTRB: + case X86ISD::PEXTRW: { + SDValue Src = Op.getOperand(0); + EVT SrcVT = Src.getValueType(); + APInt DemandedElt = APInt::getOneBitSet(SrcVT.getVectorNumElements(), + Op.getConstantOperandVal(1)); + DAG.computeKnownBits(Src, Known, DemandedElt, Depth + 1); + Known = Known.zextOrTrunc(BitWidth); + Known.Zero.setBitsFrom(SrcVT.getScalarSizeInBits()); + break; + } case X86ISD::VSHLI: case X86ISD::VSRLI: { if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op.getOperand(1))) { @@ -27006,7 +28054,7 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, break; } - DAG.computeKnownBits(Op.getOperand(0), Known, Depth + 1); + DAG.computeKnownBits(Op.getOperand(0), Known, DemandedElts, Depth + 1); unsigned ShAmt = ShiftImm->getZExtValue(); if (Opc == X86ISD::VSHLI) { Known.Zero <<= ShAmt; @@ -27023,6 +28071,7 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, break; } case X86ISD::VZEXT: { + // TODO: Add DemandedElts support. SDValue N0 = Op.getOperand(0); unsigned NumElts = VT.getVectorNumElements(); @@ -27038,6 +28087,26 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op, Known.Zero.setBitsFrom(InBitWidth); break; } + case X86ISD::CMOV: { + DAG.computeKnownBits(Op.getOperand(1), Known, Depth+1); + // If we don't know any bits, early out. + if (Known.isUnknown()) + break; + KnownBits Known2; + DAG.computeKnownBits(Op.getOperand(0), Known2, Depth+1); + + // Only known if known in both the LHS and RHS. + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; + break; + } + case X86ISD::UDIVREM8_ZEXT_HREG: + // TODO: Support more than just the zero extended bits? + if (Op.getResNo() != 1) + break; + // The remainder is zero extended. + Known.Zero.setBitsFrom(8); + break; } } @@ -27052,18 +28121,42 @@ unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( return VTBits; case X86ISD::VSEXT: { + // TODO: Add DemandedElts support. SDValue Src = Op.getOperand(0); unsigned Tmp = DAG.ComputeNumSignBits(Src, Depth + 1); Tmp += VTBits - Src.getScalarValueSizeInBits(); return Tmp; } - case X86ISD::VSHLI: { + case X86ISD::VTRUNC: { + // TODO: Add DemandedElts support. SDValue Src = Op.getOperand(0); + unsigned NumSrcBits = Src.getScalarValueSizeInBits(); + assert(VTBits < NumSrcBits && "Illegal truncation input type"); unsigned Tmp = DAG.ComputeNumSignBits(Src, Depth + 1); + if (Tmp > (NumSrcBits - VTBits)) + return Tmp - (NumSrcBits - VTBits); + return 1; + } + + case X86ISD::PACKSS: { + // PACKSS is just a truncation if the sign bits extend to the packed size. + // TODO: Add DemandedElts support. + unsigned SrcBits = Op.getOperand(0).getScalarValueSizeInBits(); + unsigned Tmp0 = DAG.ComputeNumSignBits(Op.getOperand(0), Depth + 1); + unsigned Tmp1 = DAG.ComputeNumSignBits(Op.getOperand(1), Depth + 1); + unsigned Tmp = std::min(Tmp0, Tmp1); + if (Tmp > (SrcBits - VTBits)) + return Tmp - (SrcBits - VTBits); + return 1; + } + + case X86ISD::VSHLI: { + SDValue Src = Op.getOperand(0); APInt ShiftVal = cast<ConstantSDNode>(Op.getOperand(1))->getAPIntValue(); if (ShiftVal.uge(VTBits)) return VTBits; // Shifted all bits out --> zero. + unsigned Tmp = DAG.ComputeNumSignBits(Src, DemandedElts, Depth + 1); if (ShiftVal.uge(Tmp)) return 1; // Shifted all sign bits out --> unknown. 
return Tmp - ShiftVal.getZExtValue(); @@ -27071,8 +28164,10 @@ unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( case X86ISD::VSRAI: { SDValue Src = Op.getOperand(0); - unsigned Tmp = DAG.ComputeNumSignBits(Src, Depth + 1); APInt ShiftVal = cast<ConstantSDNode>(Op.getOperand(1))->getAPIntValue(); + if (ShiftVal.uge(VTBits - 1)) + return VTBits; // Sign splat. + unsigned Tmp = DAG.ComputeNumSignBits(Src, DemandedElts, Depth + 1); ShiftVal += Tmp; return ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue(); } @@ -27084,12 +28179,31 @@ unsigned X86TargetLowering::ComputeNumSignBitsForTargetNode( case X86ISD::VPCOMU: // Vector compares return zero/all-bits result values. return VTBits; + + case X86ISD::CMOV: { + unsigned Tmp0 = DAG.ComputeNumSignBits(Op.getOperand(0), Depth+1); + if (Tmp0 == 1) return 1; // Early out. + unsigned Tmp1 = DAG.ComputeNumSignBits(Op.getOperand(1), Depth+1); + return std::min(Tmp0, Tmp1); + } + case X86ISD::SDIVREM8_SEXT_HREG: + // TODO: Support more than just the sign extended bits? + if (Op.getResNo() != 1) + break; + // The remainder is sign extended. + return VTBits - 7; } // Fallback case. return 1; } +SDValue X86TargetLowering::unwrapAddress(SDValue N) const { + if (N->getOpcode() == X86ISD::Wrapper || N->getOpcode() == X86ISD::WrapperRIP) + return N->getOperand(0); + return N; +} + /// Returns true (and the GlobalValue and the offset) if the node is a /// GlobalAddress + offset. bool X86TargetLowering::isGAPlusOffset(SDNode *N, @@ -27130,13 +28244,18 @@ static bool matchUnaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, } if (Match) { unsigned SrcSize = std::max(128u, NumDstElts * MaskEltSize); - SrcVT = MVT::getVectorVT(MaskVT.getScalarType(), SrcSize / MaskEltSize); - if (SrcVT != MaskVT) + MVT ScalarTy = MaskVT.isInteger() ? MaskVT.getScalarType() : + MVT::getIntegerVT(MaskEltSize); + SrcVT = MVT::getVectorVT(ScalarTy, SrcSize / MaskEltSize); + + if (SrcVT.getSizeInBits() != MaskVT.getSizeInBits()) { V1 = extractSubVector(V1, 0, DAG, DL, SrcSize); + Shuffle = unsigned(X86ISD::VZEXT); + } else + Shuffle = unsigned(ISD::ZERO_EXTEND_VECTOR_INREG); + DstVT = MVT::getIntegerVT(Scale * MaskEltSize); DstVT = MVT::getVectorVT(DstVT, NumDstElts); - Shuffle = SrcVT != MaskVT ? unsigned(X86ISD::VZEXT) - : unsigned(ISD::ZERO_EXTEND_VECTOR_INREG); return true; } } @@ -27155,7 +28274,7 @@ static bool matchUnaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, // instructions are no slower than UNPCKLPD but has the option to // fold the input operand into even an unaligned memory load. if (MaskVT.is128BitVector() && Subtarget.hasSSE3() && AllowFloatDomain) { - if (isTargetShuffleEquivalent(Mask, {0, 0})) { + if (!Subtarget.hasAVX2() && isTargetShuffleEquivalent(Mask, {0, 0})) { Shuffle = X86ISD::MOVDDUP; SrcVT = DstVT = MVT::v2f64; return true; @@ -27290,7 +28409,7 @@ static bool matchUnaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, // Narrow the repeated mask to create 32-bit element permutes. SmallVector<int, 4> WordMask = RepeatedMask; if (MaskScalarSizeInBits == 64) - scaleShuffleMask(2, RepeatedMask, WordMask); + scaleShuffleMask<int>(2, RepeatedMask, WordMask); Shuffle = (AllowIntDomain ? X86ISD::PSHUFD : X86ISD::VPERMILPI); ShuffleVT = (AllowIntDomain ? 
MVT::i32 : MVT::f32); @@ -27356,7 +28475,7 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, SDValue &V1, SDValue &V2, SDLoc &DL, SelectionDAG &DAG, const X86Subtarget &Subtarget, - unsigned &Shuffle, MVT &ShuffleVT, + unsigned &Shuffle, MVT &SrcVT, MVT &DstVT, bool IsUnary) { unsigned EltSizeInBits = MaskVT.getScalarSizeInBits(); @@ -27364,26 +28483,36 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, if (isTargetShuffleEquivalent(Mask, {0, 0}) && AllowFloatDomain) { V2 = V1; Shuffle = X86ISD::MOVLHPS; - ShuffleVT = MVT::v4f32; + SrcVT = DstVT = MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {1, 1}) && AllowFloatDomain) { V2 = V1; Shuffle = X86ISD::MOVHLPS; - ShuffleVT = MVT::v4f32; + SrcVT = DstVT = MVT::v4f32; return true; } if (isTargetShuffleEquivalent(Mask, {0, 3}) && Subtarget.hasSSE2() && (AllowFloatDomain || !Subtarget.hasSSE41())) { std::swap(V1, V2); Shuffle = X86ISD::MOVSD; - ShuffleVT = MaskVT; + SrcVT = DstVT = MaskVT; return true; } if (isTargetShuffleEquivalent(Mask, {4, 1, 2, 3}) && (AllowFloatDomain || !Subtarget.hasSSE41())) { Shuffle = X86ISD::MOVSS; - ShuffleVT = MaskVT; + SrcVT = DstVT = MaskVT; + return true; + } + } + + // Attempt to match against either a unary or binary PACKSS/PACKUS shuffle. + // TODO add support for 256/512-bit types. + if ((MaskVT == MVT::v8i16 || MaskVT == MVT::v16i8) && Subtarget.hasSSE2()) { + if (matchVectorShuffleWithPACK(MaskVT, SrcVT, V1, V2, Shuffle, Mask, DAG, + Subtarget)) { + DstVT = MaskVT; return true; } } @@ -27396,9 +28525,9 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, (MaskVT.is512BitVector() && Subtarget.hasAVX512())) { if (matchVectorShuffleWithUNPCK(MaskVT, V1, V2, Shuffle, IsUnary, Mask, DL, DAG, Subtarget)) { - ShuffleVT = MaskVT; - if (ShuffleVT.is256BitVector() && !Subtarget.hasAVX2()) - ShuffleVT = (32 == EltSizeInBits ? MVT::v8f32 : MVT::v4f64); + SrcVT = DstVT = MaskVT; + if (MaskVT.is256BitVector() && !Subtarget.hasAVX2()) + SrcVT = DstVT = (32 == EltSizeInBits ? MVT::v8f32 : MVT::v4f64); return true; } } @@ -27572,11 +28701,11 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask, /// into either a single instruction if there is a special purpose instruction /// for this operation, or into a PSHUFB instruction which is a fully general /// instruction but should only be used to replace chains over a certain depth. 
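As background for combineX86ShuffleChain below and the recursive combine that drives it: "combining a chain" is mask composition, i.e. a shuffle of a shuffle is a single shuffle whose mask indexes through both. A standalone sketch of that identity (compose is a made-up helper, not LLVM code):

#include <array>
#include <cassert>

// Applying mask B to the result of applying mask A equals one shuffle with
// C[i] = A[B[i]]; two PSHUFD <2,3,0,1> in a row therefore compose to the
// identity mask and the whole chain collapses to its input.
std::array<int, 4> compose(const std::array<int, 4> &A,
                           const std::array<int, 4> &B) {
  std::array<int, 4> C{};
  for (int i = 0; i < 4; ++i)
    C[i] = A[B[i]];
  return C;
}

int main() {
  assert((compose({2, 3, 0, 1}, {2, 3, 0, 1}) == std::array<int, 4>{0, 1, 2, 3}));
}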
-static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, - ArrayRef<int> BaseMask, int Depth, - bool HasVariableMask, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { +static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, + ArrayRef<int> BaseMask, int Depth, + bool HasVariableMask, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { assert(!BaseMask.empty() && "Cannot combine an empty shuffle mask!"); assert((Inputs.size() == 1 || Inputs.size() == 2) && "Unexpected number of shuffle inputs!"); @@ -27601,9 +28730,7 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, unsigned NumBaseMaskElts = BaseMask.size(); if (NumBaseMaskElts == 1) { assert(BaseMask[0] == 0 && "Invalid shuffle index found!"); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, V1), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, V1); } unsigned RootSizeInBits = RootVT.getSizeInBits(); @@ -27620,17 +28747,18 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, // TODO - attempt to narrow Mask back to writemask size. bool IsEVEXShuffle = RootSizeInBits == 512 || (Subtarget.hasVLX() && RootSizeInBits >= 128); - if (IsEVEXShuffle && (RootVT.getScalarSizeInBits() != BaseMaskEltSizeInBits)) - return false; // TODO - handle 128/256-bit lane shuffles of 512-bit vectors. // Handle 128-bit lane shuffles of 256-bit vectors. + // If we have AVX2, prefer to use VPERMQ/VPERMPD for unary shuffles unless + // we need to use the zeroing feature. // TODO - this should support binary shuffles. if (UnaryShuffle && RootVT.is256BitVector() && NumBaseMaskElts == 2 && + !(Subtarget.hasAVX2() && BaseMask[0] >= -1 && BaseMask[1] >= -1) && !isSequentialOrUndefOrZeroInRange(BaseMask, 0, 2, 0)) { if (Depth == 1 && Root.getOpcode() == X86ISD::VPERM2X128) - return false; // Nothing to do! + return SDValue(); // Nothing to do! MVT ShuffleVT = (FloatDomain ? MVT::v4f64 : MVT::v4i64); unsigned PermMask = 0; PermMask |= ((BaseMask[0] < 0 ? 0x8 : (BaseMask[0] & 1)) << 0); @@ -27642,9 +28770,7 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, DAG.getUNDEF(ShuffleVT), DAG.getConstant(PermMask, DL, MVT::i8)); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // For masks that have been widened to 128-bit elements or more, @@ -27653,7 +28779,7 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, if (BaseMaskEltSizeInBits > 64) { assert((BaseMaskEltSizeInBits % 64) == 0 && "Illegal mask size"); int MaskScale = BaseMaskEltSizeInBits / 64; - scaleShuffleMask(MaskScale, BaseMask, Mask); + scaleShuffleMask<int>(MaskScale, BaseMask, Mask); } else { Mask = SmallVector<int, 64>(BaseMask.begin(), BaseMask.end()); } @@ -27669,7 +28795,7 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, // Only allow legal mask types. if (!DAG.getTargetLoweringInfo().isTypeLegal(MaskVT)) - return false; + return SDValue(); // Attempt to match the mask against known shuffle patterns. MVT ShuffleSrcVT, ShuffleVT; @@ -27678,7 +28804,7 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, // Which shuffle domains are permitted? // Permit domain crossing at higher combine depths. 
bool AllowFloatDomain = FloatDomain || (Depth > 3); - bool AllowIntDomain = (!FloatDomain || (Depth > 3)) && + bool AllowIntDomain = (!FloatDomain || (Depth > 3)) && Subtarget.hasSSE2() && (!MaskVT.is256BitVector() || Subtarget.hasAVX2()); // Determine zeroable mask elements. @@ -27697,82 +28823,72 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> HiMask(Mask.data() + Scale, NumMaskElts - Scale); if (isSequentialOrUndefInRange(Mask, 0, Scale, 0) && isUndefOrZeroOrInRange(HiMask, Scale, NumMaskElts)) { - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, V1), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, V1); } } + SDValue NewV1 = V1; // Save operand in case early exit happens. if (matchUnaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain, - V1, DL, DAG, Subtarget, Shuffle, ShuffleSrcVT, - ShuffleVT)) { + NewV1, DL, DAG, Subtarget, Shuffle, + ShuffleSrcVT, ShuffleVT) && + (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) - return false; // AVX512 Writemask clash. - Res = DAG.getBitcast(ShuffleSrcVT, V1); + return SDValue(); // Nothing to do! + Res = DAG.getBitcast(ShuffleSrcVT, NewV1); DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } if (matchUnaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain, AllowIntDomain, Subtarget, Shuffle, - ShuffleVT, PermuteImm)) { + ShuffleVT, PermuteImm) && + (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) - return false; // AVX512 Writemask clash. + return SDValue(); // Nothing to do! Res = DAG.getBitcast(ShuffleVT, V1); DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, DAG.getConstant(PermuteImm, DL, MVT::i8)); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } } + SDValue NewV1 = V1; // Save operands in case early exit happens. + SDValue NewV2 = V2; if (matchBinaryVectorShuffle(MaskVT, Mask, AllowFloatDomain, AllowIntDomain, - V1, V2, DL, DAG, Subtarget, Shuffle, ShuffleVT, - UnaryShuffle)) { + NewV1, NewV2, DL, DAG, Subtarget, Shuffle, + ShuffleSrcVT, ShuffleVT, UnaryShuffle) && + (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) - return false; // AVX512 Writemask clash. - V1 = DAG.getBitcast(ShuffleVT, V1); - DCI.AddToWorklist(V1.getNode()); - V2 = DAG.getBitcast(ShuffleVT, V2); - DCI.AddToWorklist(V2.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2); + return SDValue(); // Nothing to do! 
+ NewV1 = DAG.getBitcast(ShuffleSrcVT, NewV1); + DCI.AddToWorklist(NewV1.getNode()); + NewV2 = DAG.getBitcast(ShuffleSrcVT, NewV2); + DCI.AddToWorklist(NewV2.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } - if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, Zeroable, AllowFloatDomain, - AllowIntDomain, V1, V2, DL, DAG, - Subtarget, Shuffle, ShuffleVT, - PermuteImm)) { + NewV1 = V1; // Save operands in case early exit happens. + NewV2 = V2; + if (matchBinaryPermuteVectorShuffle( + MaskVT, Mask, Zeroable, AllowFloatDomain, AllowIntDomain, NewV1, + NewV2, DL, DAG, Subtarget, Shuffle, ShuffleVT, PermuteImm) && + (!IsEVEXShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 1 && Root.getOpcode() == Shuffle) - return false; // Nothing to do! - if (IsEVEXShuffle && (NumRootElts != ShuffleVT.getVectorNumElements())) - return false; // AVX512 Writemask clash. - V1 = DAG.getBitcast(ShuffleVT, V1); - DCI.AddToWorklist(V1.getNode()); - V2 = DAG.getBitcast(ShuffleVT, V2); - DCI.AddToWorklist(V2.getNode()); - Res = DAG.getNode(Shuffle, DL, ShuffleVT, V1, V2, + return SDValue(); // Nothing to do! + NewV1 = DAG.getBitcast(ShuffleVT, NewV1); + DCI.AddToWorklist(NewV1.getNode()); + NewV2 = DAG.getBitcast(ShuffleVT, NewV2); + DCI.AddToWorklist(NewV2.getNode()); + Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2, DAG.getConstant(PermuteImm, DL, MVT::i8)); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // Typically from here on, we need an integer version of MaskVT. @@ -27785,21 +28901,19 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, if (matchVectorShuffleAsEXTRQ(IntMaskVT, V1, V2, Mask, BitLen, BitIdx, Zeroable)) { if (Depth == 1 && Root.getOpcode() == X86ISD::EXTRQI) - return false; // Nothing to do! + return SDValue(); // Nothing to do! V1 = DAG.getBitcast(IntMaskVT, V1); DCI.AddToWorklist(V1.getNode()); Res = DAG.getNode(X86ISD::EXTRQI, DL, IntMaskVT, V1, DAG.getConstant(BitLen, DL, MVT::i8), DAG.getConstant(BitIdx, DL, MVT::i8)); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } if (matchVectorShuffleAsINSERTQ(IntMaskVT, V1, V2, Mask, BitLen, BitIdx)) { if (Depth == 1 && Root.getOpcode() == X86ISD::INSERTQI) - return false; // Nothing to do! + return SDValue(); // Nothing to do! V1 = DAG.getBitcast(IntMaskVT, V1); DCI.AddToWorklist(V1.getNode()); V2 = DAG.getBitcast(IntMaskVT, V2); @@ -27808,23 +28922,25 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, DAG.getConstant(BitLen, DL, MVT::i8), DAG.getConstant(BitIdx, DL, MVT::i8)); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } } // Don't try to re-form single instruction chains under any circumstances now // that we've done encoding canonicalization for them. if (Depth < 2) - return false; + return SDValue(); + + // Depth threshold above which we can efficiently use variable mask shuffles. + int VariableShuffleDepth = Subtarget.hasFastVariableShuffle() ? 
2 : 3; + bool AllowVariableMask = (Depth >= VariableShuffleDepth) || HasVariableMask; bool MaskContainsZeros = any_of(Mask, [](int M) { return M == SM_SentinelZero; }); if (is128BitLaneCrossingShuffleMask(MaskVT, Mask)) { // If we have a single input lane-crossing shuffle then lower to VPERMV. - if (UnaryShuffle && (Depth >= 3 || HasVariableMask) && !MaskContainsZeros && + if (UnaryShuffle && AllowVariableMask && !MaskContainsZeros && ((Subtarget.hasAVX2() && (MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) || (Subtarget.hasAVX512() && @@ -27840,14 +28956,12 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERMV, DL, MaskVT, VPermMask, Res); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // Lower a unary+zero lane-crossing shuffle as VPERMV3 with a zero // vector as the second source. - if (UnaryShuffle && (Depth >= 3 || HasVariableMask) && + if (UnaryShuffle && AllowVariableMask && ((Subtarget.hasAVX512() && (MaskVT == MVT::v8f64 || MaskVT == MVT::v8i64 || MaskVT == MVT::v16f32 || MaskVT == MVT::v16i32)) || @@ -27871,13 +28985,11 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, DCI.AddToWorklist(Zero.getNode()); Res = DAG.getNode(X86ISD::VPERMV3, DL, MaskVT, Res, VPermMask, Zero); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // If we have a dual input lane-crossing shuffle then lower to VPERMV3. - if ((Depth >= 3 || HasVariableMask) && !MaskContainsZeros && + if (AllowVariableMask && !MaskContainsZeros && ((Subtarget.hasAVX512() && (MaskVT == MVT::v8f64 || MaskVT == MVT::v8i64 || MaskVT == MVT::v16f32 || MaskVT == MVT::v16i32)) || @@ -27896,16 +29008,14 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, DCI.AddToWorklist(V2.getNode()); Res = DAG.getNode(X86ISD::VPERMV3, DL, MaskVT, V1, VPermMask, V2); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } - return false; + return SDValue(); } // See if we can combine a single input shuffle with zeros to a bit-mask, // which is much simpler than any shuffle. - if (UnaryShuffle && MaskContainsZeros && (Depth >= 3 || HasVariableMask) && + if (UnaryShuffle && MaskContainsZeros && AllowVariableMask && isSequentialOrUndefOrZeroInRange(Mask, 0, NumMaskElts, 0) && DAG.getTargetLoweringInfo().isTypeLegal(MaskVT)) { APInt Zero = APInt::getNullValue(MaskEltSizeInBits); @@ -27930,15 +29040,13 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, FloatDomain ? unsigned(X86ISD::FAND) : unsigned(ISD::AND); Res = DAG.getNode(AndOpcode, DL, MaskVT, Res, BitMask); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // If we have a single input shuffle with different shuffle patterns in the // the 128-bit lanes use the variable mask to VPERMILPS. // TODO Combine other mask types at higher depths. 
- if (UnaryShuffle && HasVariableMask && !MaskContainsZeros && + if (UnaryShuffle && AllowVariableMask && !MaskContainsZeros && ((MaskVT == MVT::v8f32 && Subtarget.hasAVX()) || (MaskVT == MVT::v16f32 && Subtarget.hasAVX512()))) { SmallVector<SDValue, 16> VPermIdx; @@ -27953,14 +29061,12 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, DCI.AddToWorklist(Res.getNode()); Res = DAG.getNode(X86ISD::VPERMILPV, DL, MaskVT, Res, VPermMask); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // With XOP, binary shuffles of 128/256-bit floating point vectors can combine // to VPERMIL2PD/VPERMIL2PS. - if ((Depth >= 3 || HasVariableMask) && Subtarget.hasXOP() && + if (AllowVariableMask && Subtarget.hasXOP() && (MaskVT == MVT::v2f64 || MaskVT == MVT::v4f64 || MaskVT == MVT::v4f32 || MaskVT == MVT::v8f32)) { // VPERMIL2 Operation. @@ -27994,9 +29100,7 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, Res = DAG.getNode(X86ISD::VPERMIL2, DL, MaskVT, V1, V2, VPerm2MaskOp, DAG.getConstant(M2ZImm, DL, MVT::i8)); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // If we have 3 or more shuffle instructions or a chain involving a variable @@ -28004,7 +29108,7 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, // Intel's manuals suggest only using PSHUFB if doing so replacing 5 // instructions, but in practice PSHUFB tends to be *very* fast so we're // more aggressive. - if (UnaryShuffle && (Depth >= 3 || HasVariableMask) && + if (UnaryShuffle && AllowVariableMask && ((RootVT.is128BitVector() && Subtarget.hasSSSE3()) || (RootVT.is256BitVector() && Subtarget.hasAVX2()) || (RootVT.is512BitVector() && Subtarget.hasBWI()))) { @@ -28022,7 +29126,7 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, continue; } M = Ratio * M + i % Ratio; - assert ((M / 16) == (i / 16) && "Lane crossing detected"); + assert((M / 16) == (i / 16) && "Lane crossing detected"); PSHUFBMask.push_back(DAG.getConstant(M, DL, MVT::i8)); } MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes); @@ -28032,16 +29136,13 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, DCI.AddToWorklist(PSHUFBMaskOp.getNode()); Res = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, Res, PSHUFBMaskOp); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // With XOP, if we have a 128-bit binary input shuffle we can always combine // to VPPERM. We match the depth requirement of PSHUFB - VPPERM is never // slower than PSHUFB on targets that support both. 
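Before the VPPERM case that follows, a standalone illustration (made-up values, not LLVM code) of the byte-mask expansion the PSHUFB combine above performs with M = Ratio * M + i % Ratio:

#include <cstdio>

int main() {
  const int NumBytes = 16, NumMaskElts = 4; // e.g. a v4i32 mask lowered to a v16i8 PSHUFB
  const int Ratio = NumBytes / NumMaskElts; // 4 bytes per 32-bit mask element
  const int Mask[NumMaskElts] = {2, 0, 3, 1};
  for (int i = 0; i < NumBytes; ++i) {
    int M = Mask[i / Ratio];
    std::printf("%d ", Ratio * M + i % Ratio); // 8 9 10 11 0 1 2 3 12 13 14 15 4 5 6 7
  }
  std::printf("\n");
}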
- if ((Depth >= 3 || HasVariableMask) && RootVT.is128BitVector() && - Subtarget.hasXOP()) { + if (AllowVariableMask && RootVT.is128BitVector() && Subtarget.hasXOP()) { // VPPERM Mask Operation // Bits[4:0] - Byte Index (0 - 31) // Bits[7:5] - Permute Operation (0 - Source byte, 4 - ZERO) @@ -28070,23 +29171,22 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root, DCI.AddToWorklist(VPPERMMaskOp.getNode()); Res = DAG.getNode(X86ISD::VPPERM, DL, ByteVT, V1, V2, VPPERMMaskOp); DCI.AddToWorklist(Res.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(RootVT, Res), - /*AddTo*/ true); - return true; + return DAG.getBitcast(RootVT, Res); } // Failed to find any combines. - return false; + return SDValue(); } // Attempt to constant fold all of the constant source ops. // Returns true if the entire shuffle is folded to a constant. // TODO: Extend this to merge multiple constant Ops and update the mask. -static bool combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, - ArrayRef<int> Mask, SDValue Root, - bool HasVariableMask, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { +static SDValue combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, + ArrayRef<int> Mask, SDValue Root, + bool HasVariableMask, + SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { MVT VT = Root.getSimpleValueType(); unsigned SizeInBits = VT.getSizeInBits(); @@ -28103,14 +29203,14 @@ static bool combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, OneUseConstantOp |= SrcOp.hasOneUse(); if (!getTargetConstantBitsFromNode(SrcOp, MaskSizeInBits, UndefEltsOps[i], RawBitsOps[i])) - return false; + return SDValue(); } // Only fold if at least one of the constants is only used once or // the combined shuffle has included a variable mask shuffle, this // is to avoid constant pool bloat. if (!OneUseConstantOp && !HasVariableMask) - return false; + return SDValue(); // Shuffle the constant bits according to the mask. APInt UndefElts(NumMaskElts, 0); @@ -28162,8 +29262,7 @@ static bool combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, SDLoc DL(Root); SDValue CstOp = getConstVector(ConstantBitData, UndefElts, MaskVT, DAG, DL); DCI.AddToWorklist(CstOp.getNode()); - DCI.CombineTo(Root.getNode(), DAG.getBitcast(VT, CstOp)); - return true; + return DAG.getBitcast(VT, CstOp); } /// \brief Fully generic combining of x86 shuffle instructions. @@ -28195,18 +29294,15 @@ static bool combineX86ShufflesConstants(const SmallVectorImpl<SDValue> &Ops, /// would simplify under the threshold for PSHUFB formation because of /// combine-ordering. To fix this, we should do the redundant instruction /// combining in this recursive walk. -static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, - int SrcOpIndex, SDValue Root, - ArrayRef<int> RootMask, - ArrayRef<const SDNode*> SrcNodes, - int Depth, bool HasVariableMask, - SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { +static SDValue combineX86ShufflesRecursively( + ArrayRef<SDValue> SrcOps, int SrcOpIndex, SDValue Root, + ArrayRef<int> RootMask, ArrayRef<const SDNode *> SrcNodes, int Depth, + bool HasVariableMask, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { // Bound the depth of our recursive combine because this is ultimately // quadratic in nature. 
if (Depth > 8) - return false; + return SDValue(); // Directly rip through bitcasts to find the underlying operand. SDValue Op = SrcOps[SrcOpIndex]; @@ -28214,7 +29310,7 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, MVT VT = Op.getSimpleValueType(); if (!VT.isVector()) - return false; // Bail if we hit a non-vector. + return SDValue(); // Bail if we hit a non-vector. assert(Root.getSimpleValueType().isVector() && "Shuffles operate on vector types!"); @@ -28225,7 +29321,7 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, SmallVector<int, 64> OpMask; SmallVector<SDValue, 2> OpInputs; if (!resolveTargetShuffleInputs(Op, OpInputs, OpMask, DAG)) - return false; + return SDValue(); assert(OpInputs.size() <= 2 && "Too many shuffle inputs"); SDValue Input0 = (OpInputs.size() > 0 ? OpInputs[0] : SDValue()); @@ -28334,18 +29430,15 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, } // Handle the all undef/zero cases early. - if (all_of(Mask, [](int Idx) { return Idx == SM_SentinelUndef; })) { - DCI.CombineTo(Root.getNode(), DAG.getUNDEF(Root.getValueType())); - return true; - } - if (all_of(Mask, [](int Idx) { return Idx < 0; })) { - // TODO - should we handle the mixed zero/undef case as well? Just returning - // a zero mask will lose information on undef elements possibly reducing - // future combine possibilities. - DCI.CombineTo(Root.getNode(), getZeroVector(Root.getSimpleValueType(), - Subtarget, DAG, SDLoc(Root))); - return true; - } + if (all_of(Mask, [](int Idx) { return Idx == SM_SentinelUndef; })) + return DAG.getUNDEF(Root.getValueType()); + + // TODO - should we handle the mixed zero/undef case as well? Just returning + // a zero mask will lose information on undef elements possibly reducing + // future combine possibilities. + if (all_of(Mask, [](int Idx) { return Idx < 0; })) + return getZeroVector(Root.getSimpleValueType(), Subtarget, DAG, + SDLoc(Root)); // Remove unused shuffle source ops. resolveTargetShuffleInputsAndMask(Ops, Mask); @@ -28364,19 +29457,19 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, for (int i = 0, e = Ops.size(); i < e; ++i) if (Ops[i].getNode()->hasOneUse() || SDNode::areOnlyUsersOf(CombinedNodes, Ops[i].getNode())) - if (combineX86ShufflesRecursively(Ops, i, Root, Mask, CombinedNodes, - Depth + 1, HasVariableMask, DAG, DCI, - Subtarget)) - return true; + if (SDValue Res = combineX86ShufflesRecursively( + Ops, i, Root, Mask, CombinedNodes, Depth + 1, HasVariableMask, + DAG, DCI, Subtarget)) + return Res; // Attempt to constant fold all of the constant source ops. - if (combineX86ShufflesConstants(Ops, Mask, Root, HasVariableMask, DAG, DCI, - Subtarget)) - return true; + if (SDValue Cst = combineX86ShufflesConstants( + Ops, Mask, Root, HasVariableMask, DAG, DCI, Subtarget)) + return Cst; // We can only combine unary and binary shuffle mask cases. if (Ops.size() > 2) - return false; + return SDValue(); // Minor canonicalization of the accumulated shuffle mask to make it easier // to match below. All this does is detect masks with sequential pairs of @@ -28395,6 +29488,7 @@ static bool combineX86ShufflesRecursively(ArrayRef<SDValue> SrcOps, std::swap(Ops[0], Ops[1]); } + // Finally, try to combine into a single shuffle instruction. 
return combineX86ShuffleChain(Ops, Root, Mask, Depth, HasVariableMask, DAG, DCI, Subtarget); } @@ -28650,8 +29744,37 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, SDLoc DL(N); MVT VT = N.getSimpleValueType(); SmallVector<int, 4> Mask; - unsigned Opcode = N.getOpcode(); + + // Combine binary shuffle of 2 similar 'Horizontal' instructions into a + // single instruction. + if (VT.getScalarSizeInBits() == 64 && + (Opcode == X86ISD::MOVSD || Opcode == X86ISD::UNPCKH || + Opcode == X86ISD::UNPCKL)) { + auto BC0 = peekThroughBitcasts(N.getOperand(0)); + auto BC1 = peekThroughBitcasts(N.getOperand(1)); + EVT VT0 = BC0.getValueType(); + EVT VT1 = BC1.getValueType(); + unsigned Opcode0 = BC0.getOpcode(); + unsigned Opcode1 = BC1.getOpcode(); + if (Opcode0 == Opcode1 && VT0 == VT1 && + (Opcode0 == X86ISD::FHADD || Opcode0 == X86ISD::HADD || + Opcode0 == X86ISD::FHSUB || Opcode0 == X86ISD::HSUB || + Opcode0 == X86ISD::PACKSS || Opcode0 == X86ISD::PACKUS)) { + SDValue Lo, Hi; + if (Opcode == X86ISD::MOVSD) { + Lo = BC1.getOperand(0); + Hi = BC0.getOperand(1); + } else { + Lo = BC0.getOperand(Opcode == X86ISD::UNPCKH ? 1 : 0); + Hi = BC1.getOperand(Opcode == X86ISD::UNPCKH ? 1 : 0); + } + SDValue Horiz = DAG.getNode(Opcode0, DL, VT0, Lo, Hi); + DCI.AddToWorklist(Horiz.getNode()); + return DAG.getBitcast(VT, Horiz); + } + } + switch (Opcode) { case X86ISD::PSHUFD: case X86ISD::PSHUFLW: @@ -28660,17 +29783,6 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, assert(Mask.size() == 4); break; case X86ISD::UNPCKL: { - auto Op0 = N.getOperand(0); - auto Op1 = N.getOperand(1); - unsigned Opcode0 = Op0.getOpcode(); - unsigned Opcode1 = Op1.getOpcode(); - - // Combine X86ISD::UNPCKL with 2 X86ISD::FHADD inputs into a single - // X86ISD::FHADD. This is generated by UINT_TO_FP v2f64 scalarization. - // TODO: Add other horizontal operations as required. - if (VT == MVT::v2f64 && Opcode0 == Opcode1 && Opcode0 == X86ISD::FHADD) - return DAG.getNode(Opcode0, DL, VT, Op0.getOperand(0), Op1.getOperand(0)); - // Combine X86ISD::UNPCKL and ISD::VECTOR_SHUFFLE into X86ISD::UNPCKH, in // which X86ISD::UNPCKL has a ISD::UNDEF operand, and ISD::VECTOR_SHUFFLE // moves upper half elements into the lower half part. For example: @@ -28688,7 +29800,9 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, if (!VT.is128BitVector()) return SDValue(); - if (Op0.isUndef() && Opcode1 == ISD::VECTOR_SHUFFLE) { + auto Op0 = N.getOperand(0); + auto Op1 = N.getOperand(1); + if (Op0.isUndef() && Op1.getOpcode() == ISD::VECTOR_SHUFFLE) { ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op1.getNode())->getMask(); unsigned NumElts = VT.getVectorNumElements(); @@ -28922,17 +30036,18 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG, return SDValue(); } -/// Returns true iff the shuffle node \p N can be replaced with ADDSUB -/// operation. If true is returned then the operands of ADDSUB operation +/// Returns true iff the shuffle node \p N can be replaced with ADDSUB(SUBADD) +/// operation. If true is returned then the operands of ADDSUB(SUBADD) operation /// are written to the parameters \p Opnd0 and \p Opnd1. /// -/// We combine shuffle to ADDSUB directly on the abstract vector shuffle nodes +/// We combine shuffle to ADDSUB(SUBADD) directly on the abstract vector shuffle nodes /// so it is easier to generically match. 
We also insert dummy vector shuffle /// nodes for the operands which explicitly discard the lanes which are unused /// by this operation to try to flow through the rest of the combiner /// the fact that they're unused. -static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget, - SDValue &Opnd0, SDValue &Opnd1) { +static bool isAddSubOrSubAdd(SDNode *N, const X86Subtarget &Subtarget, + SDValue &Opnd0, SDValue &Opnd1, + bool matchSubAdd = false) { EVT VT = N->getValueType(0); if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) && @@ -28952,12 +30067,15 @@ static bool isAddSub(SDNode *N, const X86Subtarget &Subtarget, SDValue V1 = N->getOperand(0); SDValue V2 = N->getOperand(1); - // We require the first shuffle operand to be the FSUB node, and the second to - // be the FADD node. - if (V1.getOpcode() == ISD::FADD && V2.getOpcode() == ISD::FSUB) { + unsigned ExpectedOpcode = matchSubAdd ? ISD::FADD : ISD::FSUB; + unsigned NextExpectedOpcode = matchSubAdd ? ISD::FSUB : ISD::FADD; + + // We require the first shuffle operand to be the ExpectedOpcode node, + // and the second to be the NextExpectedOpcode node. + if (V1.getOpcode() == NextExpectedOpcode && V2.getOpcode() == ExpectedOpcode) { ShuffleVectorSDNode::commuteMask(Mask); std::swap(V1, V2); - } else if (V1.getOpcode() != ISD::FSUB || V2.getOpcode() != ISD::FADD) + } else if (V1.getOpcode() != ExpectedOpcode || V2.getOpcode() != NextExpectedOpcode) return false; // If there are other uses of these operations we can't fold them. @@ -28991,7 +30109,7 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, const X86Subtarget &Subtarget, SelectionDAG &DAG) { SDValue Opnd0, Opnd1; - if (!isAddSub(N, Subtarget, Opnd0, Opnd1)) + if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1)) return SDValue(); EVT VT = N->getValueType(0); @@ -28999,7 +30117,7 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, // Try to generate X86ISD::FMADDSUB node here. SDValue Opnd2; - if (isFMAddSub(Subtarget, DAG, Opnd0, Opnd1, Opnd2)) + if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) return DAG.getNode(X86ISD::FMADDSUB, DL, VT, Opnd0, Opnd1, Opnd2); // Do not generate X86ISD::ADDSUB node for 512-bit types even though @@ -29011,6 +30129,26 @@ static SDValue combineShuffleToAddSubOrFMAddSub(SDNode *N, return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1); } +/// \brief Try to combine a shuffle into a target-specific +/// mul-sub-add node. +static SDValue combineShuffleToFMSubAdd(SDNode *N, + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { + SDValue Opnd0, Opnd1; + if (!isAddSubOrSubAdd(N, Subtarget, Opnd0, Opnd1, true)) + return SDValue(); + + EVT VT = N->getValueType(0); + SDLoc DL(N); + + // Try to generate X86ISD::FMSUBADD node here. + SDValue Opnd2; + if (isFMAddSubOrFMSubAdd(Subtarget, DAG, Opnd0, Opnd1, Opnd2, 2)) + return DAG.getNode(X86ISD::FMSUBADD, DL, VT, Opnd0, Opnd1, Opnd2); + + return SDValue(); +} + // We are looking for a shuffle where both sources are concatenated with undef // and have a width that is half of the output's width. AVX2 has VPERMD/Q, so // if we can express this as a single-source shuffle, that's preferable. @@ -29056,6 +30194,40 @@ static SDValue combineShuffleOfConcatUndef(SDNode *N, SelectionDAG &DAG, return DAG.getVectorShuffle(VT, DL, Concat, DAG.getUNDEF(VT), Mask); } +/// Eliminate a redundant shuffle of a horizontal math op. 
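Stepping back to the ADDSUB/SUBADD matcher above (the horizontal-op fold named by the doc comment just above begins below): the blend being recognized takes even lanes from the FSUB and odd lanes from the FADD, e.g. shuffle (fsub A,B), (fadd A,B), <0,5,2,7>; SUBADD is the mirror image. A small scalar model with made-up values, assumed rather than taken from the patch:

#include <array>
#include <cassert>

int main() {
  std::array<double, 4> A{1, 2, 3, 4}, B{10, 20, 30, 40}, R{};
  for (int i = 0; i < 4; ++i)
    R[i] = (i % 2) ? A[i] + B[i]   // odd lanes come from the FADD
                   : A[i] - B[i];  // even lanes come from the FSUB
  assert(R[0] == -9 && R[1] == 22 && R[2] == -27 && R[3] == 44); // ADDSUB-style result
}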
+static SDValue foldShuffleOfHorizOp(SDNode *N) { + if (N->getOpcode() != ISD::VECTOR_SHUFFLE || !N->getOperand(1).isUndef()) + return SDValue(); + + SDValue HOp = N->getOperand(0); + if (HOp.getOpcode() != X86ISD::HADD && HOp.getOpcode() != X86ISD::FHADD && + HOp.getOpcode() != X86ISD::HSUB && HOp.getOpcode() != X86ISD::FHSUB) + return SDValue(); + + // 128-bit horizontal math instructions are defined to operate on adjacent + // lanes of each operand as: + // v4X32: A[0] + A[1] , A[2] + A[3] , B[0] + B[1] , B[2] + B[3] + // ...similarly for v2f64 and v8i16. + // TODO: 256-bit is not the same because...x86. + if (HOp.getOperand(0) != HOp.getOperand(1) || HOp.getValueSizeInBits() != 128) + return SDValue(); + + // When the operands of a horizontal math op are identical, the low half of + // the result is the same as the high half. If the shuffle is also replicating + // low and high halves, we don't need the shuffle. + // shuffle (hadd X, X), undef, [low half...high half] --> hadd X, X + ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask(); + // TODO: Other mask possibilities like {1,1} and {1,0} could be added here, + // but this should be tied to whatever horizontal op matching and shuffle + // canonicalization are producing. + if (isTargetShuffleEquivalent(Mask, { 0, 0 }) || + isTargetShuffleEquivalent(Mask, { 0, 1, 0, 1 }) || + isTargetShuffleEquivalent(Mask, { 0, 1, 2, 3, 0, 1, 2, 3 })) + return HOp; + + return SDValue(); +} + static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -29063,11 +30235,18 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); // If we have legalized the vector types, look for blends of FADD and FSUB - // nodes that we can fuse into an ADDSUB node. - if (TLI.isTypeLegal(VT)) + // nodes that we can fuse into an ADDSUB, FMADDSUB, or FMSUBADD node. + if (TLI.isTypeLegal(VT)) { if (SDValue AddSub = combineShuffleToAddSubOrFMAddSub(N, Subtarget, DAG)) return AddSub; + if (SDValue FMSubAdd = combineShuffleToFMSubAdd(N, Subtarget, DAG)) + return FMSubAdd; + + if (SDValue HAddSub = foldShuffleOfHorizOp(N)) + return HAddSub; + } + // During Type Legalization, when promoting illegal vector types, // the backend might introduce new shuffle dag nodes and bitcasts. // @@ -29165,12 +30344,12 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG, // specific PSHUF instruction sequences into their minimal form so that we // can evaluate how many specialized shuffle instructions are involved in // a particular chain. - SmallVector<int, 1> NonceMask; // Just a placeholder. - NonceMask.push_back(0); - if (combineX86ShufflesRecursively({Op}, 0, Op, NonceMask, {}, - /*Depth*/ 1, /*HasVarMask*/ false, DAG, - DCI, Subtarget)) - return SDValue(); // This routine will use CombineTo to replace N. + if (SDValue Res = combineX86ShufflesRecursively( + {Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, DCI, Subtarget)) { + DCI.CombineTo(N, Res); + return SDValue(); + } } return SDValue(); @@ -29300,7 +30479,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, // v8i16 and v16i16. // For these two cases, we can shuffle the upper element bytes to a // consecutive sequence at the start of the vector and treat the results as - // v16i8 or v32i8, and for v61i8 this is the preferable solution. However, + // v16i8 or v32i8, and for v16i8 this is the preferable solution. 
However, // for v16i16 this is not the case, because the shuffle is expensive, so we // avoid sign-extending to this type entirely. // For example, t0 := (v8i16 sext(v8i1 x)) needs to be shuffled as: @@ -29319,9 +30498,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, FPCastVT = MVT::v4f32; // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2)) // sign-extend to a 256-bit operation to avoid truncation. - if (N0->getOpcode() == ISD::SETCC && - N0->getOperand(0)->getValueType(0).is256BitVector() && - Subtarget.hasInt256()) { + if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() && + N0->getOperand(0).getValueType().is256BitVector()) { SExtVT = MVT::v4i64; FPCastVT = MVT::v4f64; } @@ -29333,9 +30511,9 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, // If the setcc operand is 128-bit, prefer sign-extending to 128-bit over // 256-bit because the shuffle is cheaper than sign extending the result of // the compare. - if (N0->getOpcode() == ISD::SETCC && - N0->getOperand(0)->getValueType(0).is256BitVector() && - Subtarget.hasInt256()) { + if (N0->getOpcode() == ISD::SETCC && Subtarget.hasAVX() && + (N0->getOperand(0).getValueType().is256BitVector() || + N0->getOperand(0).getValueType().is512BitVector())) { SExtVT = MVT::v8i32; FPCastVT = MVT::v8f32; } @@ -29348,23 +30526,34 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, SDValue BitCast, // truncating the result of the compare to 128-bits. break; case MVT::v32i1: - // TODO: Handle pre-AVX2 cases by splitting to two v16i1's. - if (!Subtarget.hasInt256()) - return SDValue(); SExtVT = MVT::v32i8; break; }; SDLoc DL(BitCast); SDValue V = DAG.getSExtOrTrunc(N0, DL, SExtVT); + + if (SExtVT == MVT::v32i8 && !Subtarget.hasInt256()) { + // Handle pre-AVX2 cases by splitting to two v16i1's. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + MVT ShiftTy = TLI.getScalarShiftAmountTy(DAG.getDataLayout(), MVT::i32); + SDValue Lo = extract128BitVector(V, 0, DAG, DL); + SDValue Hi = extract128BitVector(V, 16, DAG, DL); + Lo = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Lo); + Hi = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Hi); + Hi = DAG.getNode(ISD::SHL, DL, MVT::i32, Hi, + DAG.getConstant(16, DL, ShiftTy)); + V = DAG.getNode(ISD::OR, DL, MVT::i32, Lo, Hi); + return DAG.getZExtOrTrunc(V, DL, VT); + } + if (SExtVT == MVT::v8i16) { - V = DAG.getBitcast(MVT::v16i8, V); - V = DAG.getVectorShuffle( - MVT::v16i8, DL, V, DAG.getUNDEF(MVT::v16i8), - {0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1}); + assert(16 == DAG.ComputeNumSignBits(V) && "Expected all/none bit vector"); + V = DAG.getNode(X86ISD::PACKSS, DL, MVT::v16i8, V, + DAG.getUNDEF(MVT::v8i16)); } else assert(SExtVT.getScalarType() != MVT::i16 && - "Vectors of i16 must be shuffled"); + "Vectors of i16 must be packed"); if (FPCastVT != MVT::INVALID_SIMPLE_VALUE_TYPE) V = DAG.getBitcast(FPCastVT, V); V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V); @@ -29384,9 +30573,35 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, // (i16 movmsk (16i8 sext (v16i1 x))) // before the setcc result is scalarized on subtargets that don't have legal // vxi1 types. - if (DCI.isBeforeLegalize()) + if (DCI.isBeforeLegalize()) { if (SDValue V = combineBitcastvxi1(DAG, SDValue(N, 0), Subtarget)) return V; + + // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer + // type, widen both sides to avoid a trip through memory. 
+ if ((VT == MVT::v4i1 || VT == MVT::v2i1) && SrcVT.isScalarInteger() && + Subtarget.hasVLX()) { + SDLoc dl(N); + N0 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i8, N0); + N0 = DAG.getBitcast(MVT::v8i1, N0); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, N0, + DAG.getIntPtrConstant(0, dl)); + } + + // If this is a bitcast between a MVT::v4i1/v2i1 and an illegal integer + // type, widen both sides to avoid a trip through memory. + if ((SrcVT == MVT::v4i1 || SrcVT == MVT::v2i1) && VT.isScalarInteger() && + Subtarget.hasVLX()) { + SDLoc dl(N); + unsigned NumConcats = 8 / SrcVT.getVectorNumElements(); + SmallVector<SDValue, 4> Ops(NumConcats, DAG.getUNDEF(SrcVT)); + Ops[0] = N0; + N0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8i1, Ops); + N0 = DAG.getBitcast(MVT::i8, N0); + return DAG.getNode(ISD::TRUNCATE, dl, VT, N0); + } + } + // Since MMX types are special and don't usually play with other vector types, // it's better to handle them early to be sure we emit efficient code by // avoiding store-load conversions. @@ -29463,16 +30678,22 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG, // the elements of a vector. // Returns the vector that is being reduced on, or SDValue() if a reduction // was not matched. -static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { +static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp, + ArrayRef<ISD::NodeType> CandidateBinOps) { // The pattern must end in an extract from index 0. if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) || !isNullConstant(Extract->getOperand(1))) return SDValue(); - unsigned Stages = - Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements()); - SDValue Op = Extract->getOperand(0); + unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements()); + + // Match against one of the candidate binary ops. + if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) { + return Op.getOpcode() == unsigned(BinOp); + })) + return SDValue(); + // At each stage, we're looking for something that looks like: // %s = shufflevector <8 x i32> %op, <8 x i32> undef, // <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, @@ -29483,8 +30704,9 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { // <4,5,6,7,u,u,u,u> // <2,3,u,u,u,u,u,u> // <1,u,u,u,u,u,u,u> + unsigned CandidateBinOp = Op.getOpcode(); for (unsigned i = 0; i < Stages; ++i) { - if (Op.getOpcode() != BinOp) + if (Op.getOpcode() != CandidateBinOp) return SDValue(); ShuffleVectorSDNode *Shuffle = @@ -29497,8 +30719,8 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { } // The first operand of the shuffle should be the same as the other operand - // of the add. - if (!Shuffle || (Shuffle->getOperand(0) != Op)) + // of the binop. + if (!Shuffle || Shuffle->getOperand(0) != Op) return SDValue(); // Verify the shuffle has the expected (at this stage of the pyramid) mask. @@ -29507,6 +30729,7 @@ static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) { return SDValue(); } + BinOp = CandidateBinOp; return Op; } @@ -29552,8 +30775,7 @@ static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, // In SetLT case, The second operand of the comparison can be either 1 or 0. 
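// [Editor's sketch, not part of the patch] The shuffle+binop "pyramid" that
// the generalized matchBinOpReduction looks for, in scalar form: each stage
// combines the vector with a copy of itself shifted down by half the
// remaining width, and the final value is read from element 0. Example for
// an ADD reduction of 8 elements (log2(8) = 3 stages):
#include <array>
#include <cassert>
int main() {
  std::array<int, 8> V{1, 2, 3, 4, 5, 6, 7, 8};
  for (unsigned Width = 4; Width >= 1; Width /= 2) // masks <4,5,6,7,..>, <2,3,..>, <1,..>
    for (unsigned i = 0; i != Width; ++i)
      V[i] = V[i] + V[i + Width];                  // binop(Op, shuffle(Op))
  assert(V[0] == 36);                              // extractelement ..., 0
  return 0;
}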
APInt SplatVal; if ((CC == ISD::SETLT) && - !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal, - /*AllowShrink*/false) && + !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal) && SplatVal.isOneValue()) || (ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode())))) return false; @@ -29606,6 +30828,84 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0, return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1); } +// Attempt to replace an min/max v8i16/v16i8 horizontal reduction with +// PHMINPOSUW. +static SDValue combineHorizontalMinMaxResult(SDNode *Extract, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + // Bail without SSE41. + if (!Subtarget.hasSSE41()) + return SDValue(); + + EVT ExtractVT = Extract->getValueType(0); + if (ExtractVT != MVT::i16 && ExtractVT != MVT::i8) + return SDValue(); + + // Check for SMAX/SMIN/UMAX/UMIN horizontal reduction patterns. + unsigned BinOp; + SDValue Src = matchBinOpReduction( + Extract, BinOp, {ISD::SMAX, ISD::SMIN, ISD::UMAX, ISD::UMIN}); + if (!Src) + return SDValue(); + + EVT SrcVT = Src.getValueType(); + EVT SrcSVT = SrcVT.getScalarType(); + if (SrcSVT != ExtractVT || (SrcVT.getSizeInBits() % 128) != 0) + return SDValue(); + + SDLoc DL(Extract); + SDValue MinPos = Src; + + // First, reduce the source down to 128-bit, applying BinOp to lo/hi. + while (SrcVT.getSizeInBits() > 128) { + unsigned NumElts = SrcVT.getVectorNumElements(); + unsigned NumSubElts = NumElts / 2; + SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcSVT, NumSubElts); + unsigned SubSizeInBits = SrcVT.getSizeInBits(); + SDValue Lo = extractSubVector(MinPos, 0, DAG, DL, SubSizeInBits); + SDValue Hi = extractSubVector(MinPos, NumSubElts, DAG, DL, SubSizeInBits); + MinPos = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi); + } + assert(((SrcVT == MVT::v8i16 && ExtractVT == MVT::i16) || + (SrcVT == MVT::v16i8 && ExtractVT == MVT::i8)) && + "Unexpected value type"); + + // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask + // to flip the value accordingly. + SDValue Mask; + unsigned MaskEltsBits = ExtractVT.getSizeInBits(); + if (BinOp == ISD::SMAX) + Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT); + else if (BinOp == ISD::SMIN) + Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT); + else if (BinOp == ISD::UMAX) + Mask = DAG.getConstant(APInt::getAllOnesValue(MaskEltsBits), DL, SrcVT); + + if (Mask) + MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos); + + // For v16i8 cases we need to perform UMIN on pairs of byte elements, + // shuffling each upper element down and insert zeros. This means that the + // v16i8 UMIN will leave the upper element as zero, performing zero-extension + // ready for the PHMINPOS. + if (ExtractVT == MVT::i8) { + SDValue Upper = DAG.getVectorShuffle( + SrcVT, DL, MinPos, getZeroVector(MVT::v16i8, Subtarget, DAG, DL), + {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16}); + MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper); + } + + // Perform the PHMINPOS on a v8i16 vector, + MinPos = DAG.getBitcast(MVT::v8i16, MinPos); + MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, MinPos); + MinPos = DAG.getBitcast(SrcVT, MinPos); + + if (Mask) + MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, MinPos, + DAG.getIntPtrConstant(0, DL)); +} + // Attempt to replace an all_of/any_of style horizontal reduction with a MOVMSK. 
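// [Editor's sketch, not part of the patch] Why the XOR masks in
// combineHorizontalMinMaxResult are enough: PHMINPOSUW only computes an
// unsigned minimum, but XOR-ing every element with a fixed constant maps the
// other orderings onto unsigned-min and back. For a signed max over 16-bit
// elements the mask is 0x7fff:
#include <algorithm>
#include <cassert>
#include <cstdint>
int main() {
  int16_t V[4] = {-7, 3, 25, -1000};
  uint16_t M = 0xFFFF;                               // running unsigned min
  for (int16_t E : V)
    M = std::min<uint16_t>(M, uint16_t(E) ^ 0x7fff); // flip, then UMIN
  int16_t SMax = int16_t(M ^ 0x7fff);                // flip back
  assert(SMax == *std::max_element(V, V + 4));       // == 25
  return 0;
}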
static SDValue combineHorizontalPredicateResult(SDNode *Extract, SelectionDAG &DAG, @@ -29621,66 +30921,63 @@ static SDValue combineHorizontalPredicateResult(SDNode *Extract, return SDValue(); // Check for OR(any_of) and AND(all_of) horizontal reduction patterns. - for (ISD::NodeType Op : {ISD::OR, ISD::AND}) { - SDValue Match = matchBinOpReduction(Extract, Op); - if (!Match) - continue; - - // EXTRACT_VECTOR_ELT can require implicit extension of the vector element - // which we can't support here for now. - if (Match.getScalarValueSizeInBits() != BitWidth) - continue; + unsigned BinOp = 0; + SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND}); + if (!Match) + return SDValue(); - // We require AVX2 for PMOVMSKB for v16i16/v32i8; - unsigned MatchSizeInBits = Match.getValueSizeInBits(); - if (!(MatchSizeInBits == 128 || - (MatchSizeInBits == 256 && - ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2())))) - return SDValue(); + // EXTRACT_VECTOR_ELT can require implicit extension of the vector element + // which we can't support here for now. + if (Match.getScalarValueSizeInBits() != BitWidth) + return SDValue(); - // Don't bother performing this for 2-element vectors. - if (Match.getValueType().getVectorNumElements() <= 2) - return SDValue(); + // We require AVX2 for PMOVMSKB for v16i16/v32i8; + unsigned MatchSizeInBits = Match.getValueSizeInBits(); + if (!(MatchSizeInBits == 128 || + (MatchSizeInBits == 256 && + ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2())))) + return SDValue(); - // Check that we are extracting a reduction of all sign bits. - if (DAG.ComputeNumSignBits(Match) != BitWidth) - return SDValue(); + // Don't bother performing this for 2-element vectors. + if (Match.getValueType().getVectorNumElements() <= 2) + return SDValue(); - // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. - MVT MaskVT; - if (64 == BitWidth || 32 == BitWidth) - MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), - MatchSizeInBits / BitWidth); - else - MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); - - APInt CompareBits; - ISD::CondCode CondCode; - if (Op == ISD::OR) { - // any_of -> MOVMSK != 0 - CompareBits = APInt::getNullValue(32); - CondCode = ISD::CondCode::SETNE; - } else { - // all_of -> MOVMSK == ((1 << NumElts) - 1) - CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements()); - CondCode = ISD::CondCode::SETEQ; - } + // Check that we are extracting a reduction of all sign bits. + if (DAG.ComputeNumSignBits(Match) != BitWidth) + return SDValue(); - // Perform the select as i32/i64 and then truncate to avoid partial register - // stalls. - unsigned ResWidth = std::max(BitWidth, 32u); - EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth); - SDLoc DL(Extract); - SDValue Zero = DAG.getConstant(0, DL, ResVT); - SDValue Ones = DAG.getAllOnesConstant(DL, ResVT); - SDValue Res = DAG.getBitcast(MaskVT, Match); - Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res); - Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32), - Ones, Zero, CondCode); - return DAG.getSExtOrTrunc(Res, DL, ExtractVT); + // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB. 
+ MVT MaskVT; + if (64 == BitWidth || 32 == BitWidth) + MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth), + MatchSizeInBits / BitWidth); + else + MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8); + + APInt CompareBits; + ISD::CondCode CondCode; + if (BinOp == ISD::OR) { + // any_of -> MOVMSK != 0 + CompareBits = APInt::getNullValue(32); + CondCode = ISD::CondCode::SETNE; + } else { + // all_of -> MOVMSK == ((1 << NumElts) - 1) + CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements()); + CondCode = ISD::CondCode::SETEQ; } - return SDValue(); + // Perform the select as i32/i64 and then truncate to avoid partial register + // stalls. + unsigned ResWidth = std::max(BitWidth, 32u); + EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth); + SDLoc DL(Extract); + SDValue Zero = DAG.getConstant(0, DL, ResVT); + SDValue Ones = DAG.getAllOnesConstant(DL, ResVT); + SDValue Res = DAG.getBitcast(MaskVT, Match); + Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res); + Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32), + Ones, Zero, CondCode); + return DAG.getSExtOrTrunc(Res, DL, ExtractVT); } static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, @@ -29707,7 +31004,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, return SDValue(); // Match shuffle + add pyramid. - SDValue Root = matchBinOpReduction(Extract, ISD::ADD); + unsigned BinOp = 0; + SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD}); // The operand is expected to be zero extended from i8 // (verified in detectZextAbsDiff). @@ -29758,7 +31056,7 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG, unsigned TypeSizeInBits = Type.getSizeInBits(); // Return the lowest TypeSizeInBits bits. MVT ResVT = MVT::getVectorVT(Type, SadVT.getSizeInBits() / TypeSizeInBits); - SAD = DAG.getNode(ISD::BITCAST, DL, ResVT, SAD); + SAD = DAG.getBitcast(ResVT, SAD); return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Type, SAD, Extract->getOperand(1)); } @@ -29783,6 +31081,11 @@ static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG, if (SrcSVT == MVT::i1 || !isa<ConstantSDNode>(Idx)) return SDValue(); + // Handle extract(broadcast(scalar_value)), it doesn't matter what index is. + if (X86ISD::VBROADCAST == Src.getOpcode() && + Src.getOperand(0).getValueType() == VT) + return Src.getOperand(0); + // Resolve the target shuffle inputs and mask. SmallVector<int, 16> Mask; SmallVector<SDValue, 2> Ops; @@ -29794,7 +31097,7 @@ static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG, if ((NumSrcElts % Mask.size()) == 0) { SmallVector<int, 16> ScaledMask; int Scale = NumSrcElts / Mask.size(); - scaleShuffleMask(Scale, Mask, ScaledMask); + scaleShuffleMask<int>(Scale, Mask, ScaledMask); Mask = std::move(ScaledMask); } else if ((Mask.size() % NumSrcElts) == 0) { SmallVector<int, 16> WidenedMask; @@ -29843,9 +31146,7 @@ static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG, unsigned OpCode = (SrcVT == MVT::v8i16 ? 
X86ISD::PEXTRW : X86ISD::PEXTRB); SDValue ExtOp = DAG.getNode(OpCode, dl, MVT::i32, SrcOp, DAG.getIntPtrConstant(SrcIdx, dl)); - SDValue Assert = DAG.getNode(ISD::AssertZext, dl, MVT::i32, ExtOp, - DAG.getValueType(SrcSVT)); - return DAG.getZExtOrTrunc(Assert, dl, VT); + return DAG.getZExtOrTrunc(ExtOp, dl, VT); } return SDValue(); @@ -29858,10 +31159,17 @@ static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG, static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { - if (SDValue NewOp = XFormVExtractWithShuffleIntoLoad(N, DAG, DCI)) + if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget)) return NewOp; - if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget)) + // TODO - Remove this once we can handle the implicit zero-extension of + // X86ISD::PEXTRW/X86ISD::PEXTRB in: + // XFormVExtractWithShuffleIntoLoad, combineHorizontalPredicateResult and + // combineBasicSADPattern. + if (N->getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + + if (SDValue NewOp = XFormVExtractWithShuffleIntoLoad(N, DAG, DCI)) return NewOp; SDValue InputVector = N->getOperand(0); @@ -29910,6 +31218,10 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, if (SDValue Cmp = combineHorizontalPredicateResult(N, DAG, Subtarget)) return Cmp; + // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW. + if (SDValue MinMax = combineHorizontalMinMaxResult(N, DAG, Subtarget)) + return MinMax; + // Only operate on vectors of 4 elements, where the alternative shuffling // gets to be more expensive. if (SrcVT != MVT::v4i32) @@ -30004,22 +31316,13 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, DAG.ReplaceAllUsesOfValueWith(SDValue(Extract, 0), Vals[IdxVal]); } - // The replacement was made in place; don't return anything. - return SDValue(); -} - -// TODO - merge with combineExtractVectorElt once it can handle the implicit -// zero-extension of X86ISD::PINSRW/X86ISD::PINSRB in: -// XFormVExtractWithShuffleIntoLoad, combineHorizontalPredicateResult and -// combineBasicSADPattern. -static SDValue combineExtractVectorElt_SSE(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const X86Subtarget &Subtarget) { - return combineExtractWithShuffle(N, DAG, DCI, Subtarget); + // The replacement was made in place; return N so it won't be revisited. + return SDValue(N, 0); } /// If a vector select has an operand that is -1 or 0, try to simplify the /// select to a bitwise logic operation. +/// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()? static SDValue combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, @@ -30037,10 +31340,10 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, assert(CondVT.isVector() && "Vector select expects a vector selector!"); - bool FValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode()); + bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode()); // Check if the first operand is all zeros and Cond type is vXi1. // This situation only applies to avx512. 
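// [Editor's sketch, not part of the patch] The MOVMSK lowering produced by
// combineHorizontalPredicateResult above, in scalar form: the sign bits of
// the all-ones/all-zeros lanes are gathered into an integer mask, then
// any_of becomes "mask != 0" and all_of becomes "mask == (1 << NumElts) - 1".
#include <cassert>
#include <cstdint>
int main() {
  int32_t Lanes[4] = {-1, 0, -1, -1};        // setcc results: all-ones/zeros
  uint32_t Msk = 0;
  for (unsigned i = 0; i != 4; ++i)
    Msk |= (uint32_t(Lanes[i]) >> 31) << i;  // MOVMSKPS gathers sign bits
  bool AnyOf = (Msk != 0);                   // OR reduction
  bool AllOf = (Msk == (1u << 4) - 1);       // AND reduction
  assert(AnyOf && !AllOf);
  return 0;
}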
- if (FValIsAllZeros && Subtarget.hasAVX512() && Cond.hasOneUse() && + if (TValIsAllZeros && Subtarget.hasAVX512() && Cond.hasOneUse() && CondVT.getVectorElementType() == MVT::i1) { // Invert the cond to not(cond) : xor(op,allones)=not(op) SDValue CondNew = DAG.getNode(ISD::XOR, DL, CondVT, Cond, @@ -30058,7 +31361,7 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, return SDValue(); bool TValIsAllOnes = ISD::isBuildVectorAllOnes(LHS.getNode()); - FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode()); + bool FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode()); // Try to invert the condition if true value is not all 1s and false value is // not all 0s. @@ -30068,7 +31371,6 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, // Check if SETCC has already been promoted. TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) == CondVT) { - bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode()); bool FValIsAllOnes = ISD::isBuildVectorAllOnes(RHS.getNode()); if (TValIsAllZeros || FValIsAllOnes) { @@ -30084,6 +31386,10 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, } } + // Cond value must be 'sign splat' to be converted to a logical op. + if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits()) + return SDValue(); + // vselect Cond, 111..., 000... -> Cond if (TValIsAllOnes && FValIsAllZeros) return DAG.getBitcast(VT, Cond); @@ -30105,6 +31411,15 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, return DAG.getBitcast(VT, And); } + // vselect Cond, 000..., X -> andn Cond, X + if (TValIsAllZeros) { + MVT AndNVT = MVT::getVectorVT(MVT::i64, CondVT.getSizeInBits() / 64); + SDValue CastCond = DAG.getBitcast(AndNVT, Cond); + SDValue CastRHS = DAG.getBitcast(AndNVT, RHS); + SDValue AndN = DAG.getNode(X86ISD::ANDNP, DL, AndNVT, CastCond, CastRHS); + return DAG.getBitcast(VT, AndN); + } + return SDValue(); } @@ -30120,78 +31435,52 @@ static SDValue combineSelectOfTwoConstants(SDNode *N, SelectionDAG &DAG) { return SDValue(); // Don't do this for crazy integer types. - if (!DAG.getTargetLoweringInfo().isTypeLegal(LHS.getValueType())) + EVT VT = N->getValueType(0); + if (!DAG.getTargetLoweringInfo().isTypeLegal(VT)) return SDValue(); - // If this is efficiently invertible, canonicalize the LHSC/RHSC values - // so that TrueC (the true value) is larger than FalseC. - bool NeedsCondInvert = false; - if (TrueC->getAPIntValue().ult(FalseC->getAPIntValue()) && - // Efficiently invertible. - (Cond.getOpcode() == ISD::SETCC || // setcc -> invertible. - (Cond.getOpcode() == ISD::XOR && // xor(X, C) -> invertible. - isa<ConstantSDNode>(Cond.getOperand(1))))) { - NeedsCondInvert = true; - std::swap(TrueC, FalseC); - } - - // Optimize C ? 8 : 0 -> zext(C) << 3. Likewise for any pow2/0. - if (FalseC->getAPIntValue() == 0 && TrueC->getAPIntValue().isPowerOf2()) { - if (NeedsCondInvert) // Invert the condition if needed. - Cond = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, - DAG.getConstant(1, DL, Cond.getValueType())); - - // Zero extend the condition if needed. - Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, LHS.getValueType(), Cond); + // We're going to use the condition bit in math or logic ops. We could allow + // this with a wider condition value (post-legalization it becomes an i8), + // but if nothing is creating selects that late, it doesn't matter. 
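// [Editor's sketch, not part of the patch] Per-lane view of the
// combineVSelectWithAllOnesOrZeros folds, including the newly added
// "vselect Cond, 000..., X -> andn Cond, X" case. They are only sound
// because the condition lanes are known to be all-ones or all-zeros (the
// sign-splat check added above):
#include <cassert>
#include <cstdint>
int main() {
  const uint32_t X = 0x12345678u;
  const uint32_t Conds[2] = {0x00000000u, 0xFFFFFFFFu};
  for (uint32_t Cond : Conds) {
    uint32_t Sel = Cond ? X : 0u;            // vselect Cond, X, 000...
    assert((Cond & X) == Sel);               // -> and Cond, X
    uint32_t SelN = Cond ? 0u : X;           // vselect Cond, 000..., X
    assert((~Cond & X) == SelN);             // -> andn Cond, X
  }
  return 0;
}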
+ if (Cond.getValueType() != MVT::i1) + return SDValue(); - unsigned ShAmt = TrueC->getAPIntValue().logBase2(); - return DAG.getNode(ISD::SHL, DL, LHS.getValueType(), Cond, - DAG.getConstant(ShAmt, DL, MVT::i8)); - } + // A power-of-2 multiply is just a shift. LEA also cheaply handles multiply by + // 3, 5, or 9 with i32/i64, so those get transformed too. + // TODO: For constants that overflow or do not differ by power-of-2 or small + // multiplier, convert to 'and' + 'add'. + const APInt &TrueVal = TrueC->getAPIntValue(); + const APInt &FalseVal = FalseC->getAPIntValue(); + bool OV; + APInt Diff = TrueVal.ssub_ov(FalseVal, OV); + if (OV) + return SDValue(); - // Optimize cases that will turn into an LEA instruction. This requires - // an i32 or i64 and an efficient multiplier (1, 2, 3, 4, 5, 8, 9). - if (N->getValueType(0) == MVT::i32 || N->getValueType(0) == MVT::i64) { - uint64_t Diff = TrueC->getZExtValue() - FalseC->getZExtValue(); - if (N->getValueType(0) == MVT::i32) - Diff = (unsigned)Diff; + APInt AbsDiff = Diff.abs(); + if (AbsDiff.isPowerOf2() || + ((VT == MVT::i32 || VT == MVT::i64) && + (AbsDiff == 3 || AbsDiff == 5 || AbsDiff == 9))) { - bool IsFastMultiplier = false; - if (Diff < 10) { - switch ((unsigned char)Diff) { - default: - break; - case 1: // result = add base, cond - case 2: // result = lea base( , cond*2) - case 3: // result = lea base(cond, cond*2) - case 4: // result = lea base( , cond*4) - case 5: // result = lea base(cond, cond*4) - case 8: // result = lea base( , cond*8) - case 9: // result = lea base(cond, cond*8) - IsFastMultiplier = true; - break; - } + // We need a positive multiplier constant for shift/LEA codegen. The 'not' + // of the condition can usually be folded into a compare predicate, but even + // without that, the sequence should be cheaper than a CMOV alternative. + if (TrueVal.slt(FalseVal)) { + Cond = DAG.getNOT(DL, Cond, MVT::i1); + std::swap(TrueC, FalseC); } - if (IsFastMultiplier) { - APInt Diff = TrueC->getAPIntValue() - FalseC->getAPIntValue(); - if (NeedsCondInvert) // Invert the condition if needed. - Cond = DAG.getNode(ISD::XOR, DL, Cond.getValueType(), Cond, - DAG.getConstant(1, DL, Cond.getValueType())); + // select Cond, TC, FC --> (zext(Cond) * (TC - FC)) + FC + SDValue R = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cond); - // Zero extend the condition if needed. - Cond = DAG.getNode(ISD::ZERO_EXTEND, DL, FalseC->getValueType(0), Cond); - // Scale the condition by the difference. - if (Diff != 1) - Cond = DAG.getNode(ISD::MUL, DL, Cond.getValueType(), Cond, - DAG.getConstant(Diff, DL, Cond.getValueType())); + // Multiply condition by the difference if non-one. + if (!AbsDiff.isOneValue()) + R = DAG.getNode(ISD::MUL, DL, VT, R, DAG.getConstant(AbsDiff, DL, VT)); - // Add the base if non-zero. - if (FalseC->getAPIntValue() != 0) - Cond = DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond, - SDValue(FalseC, 0)); - return Cond; - } + // Add the base if non-zero. + if (!FalseC->isNullValue()) + R = DAG.getNode(ISD::ADD, DL, VT, R, SDValue(FalseC, 0)); + + return R; } return SDValue(); @@ -30231,26 +31520,6 @@ static bool combineBitcastForMaskedOp(SDValue OrigOp, SelectionDAG &DAG, unsigned Opcode = Op.getOpcode(); switch (Opcode) { - case X86ISD::PALIGNR: - // PALIGNR can be converted to VALIGND/Q for 128-bit vectors. 
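// [Editor's sketch, not part of the patch] The rewritten
// combineSelectOfTwoConstants emits "select Cond, TC, FC" as
// zext(Cond) * (TC - FC) + FC, flipping Cond first when TC < FC so the
// multiplier stays positive (a power of 2 becomes a shift; 3/5/9 become an
// LEA). Integer check of the identity, with Sel as an illustrative helper:
#include <cassert>
int main() {
  auto Sel = [](bool Cond, int TC, int FC) {
    if (TC < FC) { Cond = !Cond; int T = TC; TC = FC; FC = T; } // not(Cond)
    return int(Cond) * (TC - FC) + FC;       // zext(Cond) * Diff + FC
  };
  assert(Sel(true, 9, 1) == 9 && Sel(false, 9, 1) == 1);  // diff 8 -> shl 3
  assert(Sel(true, 2, 5) == 2 && Sel(false, 2, 5) == 5);  // flipped case
  return 0;
}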
- if (!VT.is128BitVector()) - return false; - Opcode = X86ISD::VALIGN; - LLVM_FALLTHROUGH; - case X86ISD::VALIGN: { - if (EltVT != MVT::i32 && EltVT != MVT::i64) - return false; - uint64_t Imm = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue(); - MVT OpEltVT = Op.getSimpleValueType().getVectorElementType(); - unsigned ShiftAmt = Imm * OpEltVT.getSizeInBits(); - unsigned EltSize = EltVT.getSizeInBits(); - // Make sure we can represent the same shift with the new VT. - if ((ShiftAmt % EltSize) != 0) - return false; - Imm = ShiftAmt / EltSize; - return BitcastAndCombineShuffle(Opcode, Op.getOperand(0), Op.getOperand(1), - DAG.getConstant(Imm, DL, MVT::i8)); - } case X86ISD::SHUF128: { if (EltVT.getSizeInBits() != 32 && EltVT.getSizeInBits() != 64) return false; @@ -30260,50 +31529,6 @@ static bool combineBitcastForMaskedOp(SDValue OrigOp, SelectionDAG &DAG, return BitcastAndCombineShuffle(Opcode, Op.getOperand(0), Op.getOperand(1), Op.getOperand(2)); } - case ISD::INSERT_SUBVECTOR: { - unsigned EltSize = EltVT.getSizeInBits(); - if (EltSize != 32 && EltSize != 64) - return false; - MVT OpEltVT = Op.getSimpleValueType().getVectorElementType(); - // Only change element size, not type. - if (EltVT.isInteger() != OpEltVT.isInteger()) - return false; - uint64_t Imm = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue(); - Imm = (Imm * OpEltVT.getSizeInBits()) / EltSize; - SDValue Op0 = DAG.getBitcast(VT, Op.getOperand(0)); - DCI.AddToWorklist(Op0.getNode()); - // Op1 needs to be bitcasted to a smaller vector with the same element type. - SDValue Op1 = Op.getOperand(1); - MVT Op1VT = MVT::getVectorVT(EltVT, - Op1.getSimpleValueType().getSizeInBits() / EltSize); - Op1 = DAG.getBitcast(Op1VT, Op1); - DCI.AddToWorklist(Op1.getNode()); - DCI.CombineTo(OrigOp.getNode(), - DAG.getNode(Opcode, DL, VT, Op0, Op1, - DAG.getIntPtrConstant(Imm, DL))); - return true; - } - case ISD::EXTRACT_SUBVECTOR: { - unsigned EltSize = EltVT.getSizeInBits(); - if (EltSize != 32 && EltSize != 64) - return false; - MVT OpEltVT = Op.getSimpleValueType().getVectorElementType(); - // Only change element size, not type. - if (EltVT.isInteger() != OpEltVT.isInteger()) - return false; - uint64_t Imm = cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue(); - Imm = (Imm * OpEltVT.getSizeInBits()) / EltSize; - // Op0 needs to be bitcasted to a larger vector with the same element type. - SDValue Op0 = Op.getOperand(0); - MVT Op0VT = MVT::getVectorVT(EltVT, - Op0.getSimpleValueType().getSizeInBits() / EltSize); - Op0 = DAG.getBitcast(Op0VT, Op0); - DCI.AddToWorklist(Op0.getNode()); - DCI.CombineTo(OrigOp.getNode(), - DAG.getNode(Opcode, DL, VT, Op0, - DAG.getIntPtrConstant(Imm, DL))); - return true; - } case X86ISD::SUBV_BROADCAST: { unsigned EltSize = EltVT.getSizeInBits(); if (EltSize != 32 && EltSize != 64) @@ -30661,9 +31886,10 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, // Check all uses of the condition operand to check whether it will be // consumed by non-BLEND instructions. Those may require that all bits // are set properly. - for (SDNode *U : Cond->uses()) { + for (SDNode::use_iterator UI = Cond->use_begin(), UE = Cond->use_end(); + UI != UE; ++UI) { // TODO: Add other opcodes eventually lowered into BLEND. - if (U->getOpcode() != ISD::VSELECT) + if (UI->getOpcode() != ISD::VSELECT || UI.getOperandNo() != 0) return SDValue(); } @@ -30717,7 +31943,8 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, /// i.e., reusing the EFLAGS produced by the LOCKed instruction. 
/// Note that this is only legal for some op/cc combinations. static SDValue combineSetCCAtomicArith(SDValue Cmp, X86::CondCode &CC, - SelectionDAG &DAG) { + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { // This combine only operates on CMP-like nodes. if (!(Cmp.getOpcode() == X86ISD::CMP || (Cmp.getOpcode() == X86ISD::SUB && !Cmp->hasAnyUseOfValue(0)))) @@ -30747,12 +31974,7 @@ static SDValue combineSetCCAtomicArith(SDValue Cmp, X86::CondCode &CC, if (!CmpLHS.hasOneUse()) return SDValue(); - auto *CmpRHSC = dyn_cast<ConstantSDNode>(CmpRHS); - if (!CmpRHSC || CmpRHSC->getZExtValue() != 0) - return SDValue(); - - const unsigned Opc = CmpLHS.getOpcode(); - + unsigned Opc = CmpLHS.getOpcode(); if (Opc != ISD::ATOMIC_LOAD_ADD && Opc != ISD::ATOMIC_LOAD_SUB) return SDValue(); @@ -30765,6 +31987,44 @@ static SDValue combineSetCCAtomicArith(SDValue Cmp, X86::CondCode &CC, if (Opc == ISD::ATOMIC_LOAD_SUB) Addend = -Addend; + auto *CmpRHSC = dyn_cast<ConstantSDNode>(CmpRHS); + if (!CmpRHSC) + return SDValue(); + + APInt Comparison = CmpRHSC->getAPIntValue(); + + // If the addend is the negation of the comparison value, then we can do + // a full comparison by emitting the atomic arithmetic as a locked sub. + if (Comparison == -Addend) { + // The CC is fine, but we need to rewrite the LHS of the comparison as an + // atomic sub. + auto *AN = cast<AtomicSDNode>(CmpLHS.getNode()); + auto AtomicSub = DAG.getAtomic( + ISD::ATOMIC_LOAD_SUB, SDLoc(CmpLHS), CmpLHS.getValueType(), + /*Chain*/ CmpLHS.getOperand(0), /*LHS*/ CmpLHS.getOperand(1), + /*RHS*/ DAG.getConstant(-Addend, SDLoc(CmpRHS), CmpRHS.getValueType()), + AN->getMemOperand()); + // If the comparision uses the CF flag we can't use INC/DEC instructions. + bool NeedCF = false; + switch (CC) { + default: break; + case X86::COND_A: case X86::COND_AE: + case X86::COND_B: case X86::COND_BE: + NeedCF = true; + break; + } + auto LockOp = lowerAtomicArithWithLOCK(AtomicSub, DAG, Subtarget, !NeedCF); + DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(0), + DAG.getUNDEF(CmpLHS.getValueType())); + DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(1), LockOp.getValue(1)); + return LockOp; + } + + // We can handle comparisons with zero in a number of cases by manipulating + // the CC used. + if (!Comparison.isNullValue()) + return SDValue(); + if (CC == X86::COND_S && Addend == 1) CC = X86::COND_LE; else if (CC == X86::COND_NS && Addend == 1) @@ -30776,7 +32036,7 @@ static SDValue combineSetCCAtomicArith(SDValue Cmp, X86::CondCode &CC, else return SDValue(); - SDValue LockOp = lowerAtomicArithWithLOCK(CmpLHS, DAG); + SDValue LockOp = lowerAtomicArithWithLOCK(CmpLHS, DAG, Subtarget); DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(0), DAG.getUNDEF(CmpLHS.getValueType())); DAG.ReplaceAllUsesOfValueWith(CmpLHS.getValue(1), LockOp.getValue(1)); @@ -30983,14 +32243,15 @@ static SDValue combineCarryThroughADD(SDValue EFLAGS) { /// into a simpler EFLAGS value, potentially returning a new \p CC and replacing /// uses of chain values. static SDValue combineSetCCEFLAGS(SDValue EFLAGS, X86::CondCode &CC, - SelectionDAG &DAG) { + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { if (CC == X86::COND_B) if (SDValue Flags = combineCarryThroughADD(EFLAGS)) return Flags; if (SDValue R = checkBoolTestSetCCCombine(EFLAGS, CC)) return R; - return combineSetCCAtomicArith(EFLAGS, CC, DAG); + return combineSetCCAtomicArith(EFLAGS, CC, DAG, Subtarget); } /// Optimize X86ISD::CMOV [LHS, RHS, CONDCODE (e.g. 
X86::COND_NE), CONDVAL] @@ -30999,10 +32260,6 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { SDLoc DL(N); - // If the flag operand isn't dead, don't touch this CMOV. - if (N->getNumValues() == 2 && !SDValue(N, 1).use_empty()) - return SDValue(); - SDValue FalseOp = N->getOperand(0); SDValue TrueOp = N->getOperand(1); X86::CondCode CC = (X86::CondCode)N->getConstantOperandVal(2); @@ -31021,11 +32278,11 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, // Try to simplify the EFLAGS and condition code operands. // We can't always do this as FCMOV only supports a subset of X86 cond. - if (SDValue Flags = combineSetCCEFLAGS(Cond, CC, DAG)) { + if (SDValue Flags = combineSetCCEFLAGS(Cond, CC, DAG, Subtarget)) { if (FalseOp.getValueType() != MVT::f80 || hasFPCMov(CC)) { SDValue Ops[] = {FalseOp, TrueOp, DAG.getConstant(CC, DL, MVT::i8), Flags}; - return DAG.getNode(X86ISD::CMOV, DL, N->getVTList(), Ops); + return DAG.getNode(X86ISD::CMOV, DL, N->getValueType(0), Ops); } } @@ -31054,8 +32311,6 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, unsigned ShAmt = TrueC->getAPIntValue().logBase2(); Cond = DAG.getNode(ISD::SHL, DL, Cond.getValueType(), Cond, DAG.getConstant(ShAmt, DL, MVT::i8)); - if (N->getNumValues() == 2) // Dead flag value? - return DCI.CombineTo(N, Cond, SDValue()); return Cond; } @@ -31069,9 +32324,6 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, FalseC->getValueType(0), Cond); Cond = DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond, SDValue(FalseC, 0)); - - if (N->getNumValues() == 2) // Dead flag value? - return DCI.CombineTo(N, Cond, SDValue()); return Cond; } @@ -31112,8 +32364,6 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, if (FalseC->getAPIntValue() != 0) Cond = DAG.getNode(ISD::ADD, DL, Cond.getValueType(), Cond, SDValue(FalseC, 0)); - if (N->getNumValues() == 2) // Dead flag value? - return DCI.CombineTo(N, Cond, SDValue()); return Cond; } } @@ -31153,7 +32403,7 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, CmpAgainst == dyn_cast<ConstantSDNode>(TrueOp)) { SDValue Ops[] = { FalseOp, Cond.getOperand(0), DAG.getConstant(CC, DL, MVT::i8), Cond }; - return DAG.getNode(X86ISD::CMOV, DL, N->getVTList (), Ops); + return DAG.getNode(X86ISD::CMOV, DL, N->getValueType(0), Ops); } } } @@ -31188,10 +32438,9 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG, SDValue LOps[] = {FalseOp, TrueOp, DAG.getConstant(CC0, DL, MVT::i8), Flags}; - SDValue LCMOV = DAG.getNode(X86ISD::CMOV, DL, N->getVTList(), LOps); + SDValue LCMOV = DAG.getNode(X86ISD::CMOV, DL, N->getValueType(0), LOps); SDValue Ops[] = {LCMOV, TrueOp, DAG.getConstant(CC1, DL, MVT::i8), Flags}; - SDValue CMOV = DAG.getNode(X86ISD::CMOV, DL, N->getVTList(), Ops); - DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), SDValue(CMOV.getNode(), 1)); + SDValue CMOV = DAG.getNode(X86ISD::CMOV, DL, N->getValueType(0), Ops); return CMOV; } } @@ -31307,7 +32556,7 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, // pmulld is supported since SSE41. It is better to use pmulld // instead of pmullw+pmulhw, except for subtargets where pmulld is slower than // the expansion. 
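// [Editor's sketch, not part of the patch] The arithmetic behind the
// locked-sub rewrite in combineSetCCAtomicArith a little further up: when
// the value returned by an atomic add of A is compared against C == -A,
// a single "lock sub [mem], C" performs the same memory update
// (old - C == old + A) and its EFLAGS are exactly those of "cmp old, C",
// so the separate compare can be dropped. Scalar check of the identity:
#include <cassert>
int main() {
  int Old = 5, A = -5, C = -A;            // atomicrmw add p, A ; cmp old, C
  assert(Old + A == Old - C);             // lock sub stores the same value
  assert(((Old - C) == 0) == (Old == C)); // and its ZF is the compare's ZF
  return 0;
}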
- bool OptForMinSize = DAG.getMachineFunction().getFunction()->optForMinSize(); + bool OptForMinSize = DAG.getMachineFunction().getFunction().optForMinSize(); if (Subtarget.hasSSE41() && (OptForMinSize || !Subtarget.isPMULLDSlow())) return SDValue(); @@ -31319,15 +32568,26 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getOperand(0).getValueType(); + unsigned NumElts = VT.getVectorNumElements(); + if ((NumElts % 2) != 0) + return SDValue(); + + // If the upper 17 bits of each element are zero then we can use PMADD. + APInt Mask17 = APInt::getHighBitsSet(32, 17); + if (VT == MVT::v4i32 && DAG.MaskedValueIsZero(N0, Mask17) && + DAG.MaskedValueIsZero(N1, Mask17)) + return DAG.getNode(X86ISD::VPMADDWD, DL, VT, DAG.getBitcast(MVT::v8i16, N0), + DAG.getBitcast(MVT::v8i16, N1)); + unsigned RegSize = 128; MVT OpsVT = MVT::getVectorVT(MVT::i16, RegSize / 16); - EVT ReducedVT = - EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements()); + EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts); + // Shrink the operands of mul. SDValue NewN0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N0); SDValue NewN1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N1); - if (VT.getVectorNumElements() >= OpsVT.getVectorNumElements()) { + if (NumElts >= OpsVT.getVectorNumElements()) { // Generate the lower part of mul: pmullw. For MULU8/MULS8, only the // lower part is needed. SDValue MulLo = DAG.getNode(ISD::MUL, DL, ReducedVT, NewN0, NewN1); @@ -31335,7 +32595,7 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, return DAG.getNode((Mode == MULU8) ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL, VT, MulLo); } else { - MVT ResVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2); + MVT ResVT = MVT::getVectorVT(MVT::i32, NumElts / 2); // Generate the higher part of mul: pmulhw/pmulhuw. For MULU16/MULS16, // the higher part is also needed. SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL, @@ -31344,22 +32604,22 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, // Repack the lower part and higher part result of mul into a wider // result. // Generate shuffle functioning as punpcklwd. - SmallVector<int, 16> ShuffleMask(VT.getVectorNumElements()); - for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) { + SmallVector<int, 16> ShuffleMask(NumElts); + for (unsigned i = 0, e = NumElts / 2; i < e; i++) { ShuffleMask[2 * i] = i; - ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements(); + ShuffleMask[2 * i + 1] = i + NumElts; } SDValue ResLo = DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, ShuffleMask); - ResLo = DAG.getNode(ISD::BITCAST, DL, ResVT, ResLo); + ResLo = DAG.getBitcast(ResVT, ResLo); // Generate shuffle functioning as punpckhwd. 
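// [Editor's sketch, not part of the patch] The new VPMADDWD shortcut in
// reduceVMULWidth: if the upper 17 bits of every 32-bit lane are zero, each
// lane is a non-negative 15-bit value, so PMADDWD's "multiply 16-bit pairs
// and add" degenerates to lo*lo + hi*0, i.e. the full 32-bit product.
// Per-lane check:
#include <cassert>
#include <cstdint>
int main() {
  uint32_t A = 0x1234, B = 0x5678;                       // upper 17 bits clear
  int16_t ALo = int16_t(A), AHi = int16_t(A >> 16);
  int16_t BLo = int16_t(B), BHi = int16_t(B >> 16);
  int32_t Lane = int32_t(ALo) * BLo + int32_t(AHi) * BHi; // pmaddwd lane
  assert(uint32_t(Lane) == A * B);
  return 0;
}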
- for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) { - ShuffleMask[2 * i] = i + VT.getVectorNumElements() / 2; - ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements() * 3 / 2; + for (unsigned i = 0, e = NumElts / 2; i < e; i++) { + ShuffleMask[2 * i] = i + NumElts / 2; + ShuffleMask[2 * i + 1] = i + NumElts * 3 / 2; } SDValue ResHi = DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, ShuffleMask); - ResHi = DAG.getNode(ISD::BITCAST, DL, ResVT, ResHi); + ResHi = DAG.getBitcast(ResVT, ResHi); return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi); } } else { @@ -31405,8 +32665,8 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG, // Repack the lower part and higher part result of mul into a wider // result. Make sure the type of mul result is VT. MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32); - SDValue Res = DAG.getNode(X86ISD::UNPCKL, DL, OpsVT, MulLo, MulHi); - Res = DAG.getNode(ISD::BITCAST, DL, ResVT, Res); + SDValue Res = getUnpackl(DAG, DL, OpsVT, MulLo, MulHi); + Res = DAG.getBitcast(ResVT, Res); return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res, DAG.getIntPtrConstant(0, DL)); } @@ -31496,7 +32756,7 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG, if (!MulConstantOptimization) return SDValue(); // An imul is usually smaller than the alternative sequence. - if (DAG.getMachineFunction().getFunction()->optForMinSize()) + if (DAG.getMachineFunction().getFunction().optForMinSize()) return SDValue(); if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer()) @@ -31653,7 +32913,7 @@ static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG) { return SDValue(); } -static SDValue combineShiftRightAlgebraic(SDNode *N, SelectionDAG &DAG) { +static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); @@ -31670,7 +32930,7 @@ static SDValue combineShiftRightAlgebraic(SDNode *N, SelectionDAG &DAG) { // 1. MOVs can write to a register that differs from source // 2. MOVs accept memory operands - if (!VT.isInteger() || VT.isVector() || N1.getOpcode() != ISD::Constant || + if (VT.isVector() || N1.getOpcode() != ISD::Constant || N0.getOpcode() != ISD::SHL || !N0.hasOneUse() || N0.getOperand(1).getOpcode() != ISD::Constant) return SDValue(); @@ -31684,11 +32944,11 @@ static SDValue combineShiftRightAlgebraic(SDNode *N, SelectionDAG &DAG) { if (SarConst.isNegative()) return SDValue(); - for (MVT SVT : MVT::integer_valuetypes()) { + for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) { unsigned ShiftSize = SVT.getSizeInBits(); // skipping types without corresponding sext/zext and // ShlConst that is not one of [56,48,32,24,16] - if (ShiftSize < 8 || ShiftSize > 64 || ShlConst != Size - ShiftSize) + if (ShiftSize >= Size || ShlConst != Size - ShiftSize) continue; SDLoc DL(N); SDValue NN = @@ -31706,34 +32966,38 @@ static SDValue combineShiftRightAlgebraic(SDNode *N, SelectionDAG &DAG) { return SDValue(); } -/// \brief Returns a vector of 0s if the node in input is a vector logical -/// shift by a constant amount which is known to be bigger than or equal -/// to the vector element size in bits. 
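// [Editor's sketch, not part of the patch] The shl+sra fold kept in
// combineShiftRightArithmetic above: when the left-shift amount equals the
// width minus a smaller legal type, the pair is a sign-extension of that
// smaller type followed by the remaining arithmetic shift, e.g. on i32
// (x << 24) >> 26 is sext(i8 x) >> 2 (movsbl + sar). Checked here assuming
// the usual two's-complement arithmetic right shift:
#include <cassert>
#include <cstdint>
int main() {
  int32_t X = 0x1234ABCD;
  int32_t A = int32_t(uint32_t(X) << 24) >> 26;   // shl 24, sar 26
  int32_t B = int32_t(int8_t(X)) >> 2;            // sext i8, sar 2
  assert(A == B);
  return 0;
}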
-static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - EVT VT = N->getValueType(0); +static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N0.getValueType(); - if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 && - (!Subtarget.hasInt256() || - (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16))) + // Try to improve a sequence of srl (and X, C1), C2 by inverting the order. + // TODO: This is a generic DAG combine that became an x86-only combine to + // avoid shortcomings in other folds such as bswap, bit-test ('bt'), and + // and-not ('andn'). + if (N0.getOpcode() != ISD::AND || !N0.hasOneUse()) return SDValue(); - SDValue Amt = N->getOperand(1); - SDLoc DL(N); - if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Amt)) - if (auto *AmtSplat = AmtBV->getConstantSplatNode()) { - const APInt &ShiftAmt = AmtSplat->getAPIntValue(); - unsigned MaxAmount = - VT.getSimpleVT().getScalarSizeInBits(); - - // SSE2/AVX2 logical shifts always return a vector of 0s - // if the shift amount is bigger than or equal to - // the element size. The constant shift amount will be - // encoded as a 8-bit immediate. - if (ShiftAmt.trunc(8).uge(MaxAmount)) - return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, DL); - } + auto *ShiftC = dyn_cast<ConstantSDNode>(N1); + auto *AndC = dyn_cast<ConstantSDNode>(N0.getOperand(1)); + if (!ShiftC || !AndC) + return SDValue(); + // If we can shrink the constant mask below 8-bits or 32-bits, then this + // transform should reduce code size. It may also enable secondary transforms + // from improved known-bits analysis or instruction selection. + APInt MaskVal = AndC->getAPIntValue(); + APInt NewMaskVal = MaskVal.lshr(ShiftC->getAPIntValue()); + unsigned OldMaskSize = MaskVal.getMinSignedBits(); + unsigned NewMaskSize = NewMaskVal.getMinSignedBits(); + if ((OldMaskSize > 8 && NewMaskSize <= 8) || + (OldMaskSize > 32 && NewMaskSize <= 32)) { + // srl (and X, AndC), ShiftC --> and (srl X, ShiftC), (AndC >> ShiftC) + SDLoc DL(N); + SDValue NewMask = DAG.getConstant(NewMaskVal, DL, VT); + SDValue NewShift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), N1); + return DAG.getNode(ISD::AND, DL, VT, NewShift, NewMask); + } return SDValue(); } @@ -31745,17 +33009,100 @@ static SDValue combineShift(SDNode* N, SelectionDAG &DAG, return V; if (N->getOpcode() == ISD::SRA) - if (SDValue V = combineShiftRightAlgebraic(N, DAG)) + if (SDValue V = combineShiftRightArithmetic(N, DAG)) return V; - // Try to fold this logical shift into a zero vector. - if (N->getOpcode() != ISD::SRA) - if (SDValue V = performShiftToAllZeros(N, DAG, Subtarget)) + if (N->getOpcode() == ISD::SRL) + if (SDValue V = combineShiftRightLogical(N, DAG)) return V; return SDValue(); } +static SDValue combineVectorPack(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + unsigned Opcode = N->getOpcode(); + assert((X86ISD::PACKSS == Opcode || X86ISD::PACKUS == Opcode) && + "Unexpected shift opcode"); + + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + unsigned DstBitsPerElt = VT.getScalarSizeInBits(); + unsigned SrcBitsPerElt = 2 * DstBitsPerElt; + assert(N0.getScalarValueSizeInBits() == SrcBitsPerElt && + N1.getScalarValueSizeInBits() == SrcBitsPerElt && + "Unexpected PACKSS/PACKUS input type"); + + // Constant Folding. 
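// [Editor's sketch, not part of the patch] combineShiftRightLogical above
// swaps "and then shift" into "shift then and" when that shrinks the mask
// constant into a smaller immediate encoding (here below 8 bits); the value
// is unchanged:
#include <cassert>
#include <cstdint>
int main() {
  uint32_t X = 0xDEADBEEFu;
  // srl (and X, 0x7F00), 8  -->  and (srl X, 8), 0x7F
  assert(((X & 0x7F00u) >> 8) == ((X >> 8) & 0x7Fu));
  return 0;
}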
+ APInt UndefElts0, UndefElts1; + SmallVector<APInt, 32> EltBits0, EltBits1; + if ((N0->isUndef() || N->isOnlyUserOf(N0.getNode())) && + (N1->isUndef() || N->isOnlyUserOf(N1.getNode())) && + getTargetConstantBitsFromNode(N0, SrcBitsPerElt, UndefElts0, EltBits0) && + getTargetConstantBitsFromNode(N1, SrcBitsPerElt, UndefElts1, EltBits1)) { + unsigned NumLanes = VT.getSizeInBits() / 128; + unsigned NumDstElts = VT.getVectorNumElements(); + unsigned NumSrcElts = NumDstElts / 2; + unsigned NumDstEltsPerLane = NumDstElts / NumLanes; + unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; + bool IsSigned = (X86ISD::PACKSS == Opcode); + + APInt Undefs(NumDstElts, 0); + SmallVector<APInt, 32> Bits(NumDstElts, APInt::getNullValue(DstBitsPerElt)); + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) { + unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane; + auto &UndefElts = (Elt >= NumSrcEltsPerLane ? UndefElts1 : UndefElts0); + auto &EltBits = (Elt >= NumSrcEltsPerLane ? EltBits1 : EltBits0); + + if (UndefElts[SrcIdx]) { + Undefs.setBit(Lane * NumDstEltsPerLane + Elt); + continue; + } + + APInt &Val = EltBits[SrcIdx]; + if (IsSigned) { + // PACKSS: Truncate signed value with signed saturation. + // Source values less than dst minint are saturated to minint. + // Source values greater than dst maxint are saturated to maxint. + if (Val.isSignedIntN(DstBitsPerElt)) + Val = Val.trunc(DstBitsPerElt); + else if (Val.isNegative()) + Val = APInt::getSignedMinValue(DstBitsPerElt); + else + Val = APInt::getSignedMaxValue(DstBitsPerElt); + } else { + // PACKUS: Truncate signed value with unsigned saturation. + // Source values less than zero are saturated to zero. + // Source values greater than dst maxuint are saturated to maxuint. + if (Val.isIntN(DstBitsPerElt)) + Val = Val.trunc(DstBitsPerElt); + else if (Val.isNegative()) + Val = APInt::getNullValue(DstBitsPerElt); + else + Val = APInt::getAllOnesValue(DstBitsPerElt); + } + Bits[Lane * NumDstEltsPerLane + Elt] = Val; + } + } + + return getConstVector(Bits, Undefs, VT.getSimpleVT(), DAG, SDLoc(N)); + } + + // Attempt to combine as shuffle. + SDValue Op(N, 0); + if (SDValue Res = combineX86ShufflesRecursively( + {Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, DCI, Subtarget)) { + DCI.CombineTo(N, Res); + return SDValue(); + } + + return SDValue(); +} + static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { @@ -31796,15 +33143,24 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG, N0.getOpcode() == X86ISD::VSRAI) return DAG.getNode(X86ISD::VSRLI, SDLoc(N), VT, N0.getOperand(0), N1); + // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1 + if (Opcode == X86ISD::VSRAI && N0.getOpcode() == X86ISD::VSHLI && + N1 == N0.getOperand(1)) { + SDValue N00 = N0.getOperand(0); + unsigned NumSignBits = DAG.ComputeNumSignBits(N00); + if (ShiftVal.ult(NumSignBits)) + return N00; + } + // We can decode 'whole byte' logical bit shifts as shuffles. if (LogicalShift && (ShiftVal.getZExtValue() % 8) == 0) { SDValue Op(N, 0); - SmallVector<int, 1> NonceMask; // Just a placeholder. - NonceMask.push_back(0); - if (combineX86ShufflesRecursively({Op}, 0, Op, NonceMask, {}, - /*Depth*/ 1, /*HasVarMask*/ false, DAG, - DCI, Subtarget)) - return SDValue(); // This routine will use CombineTo to replace N. 
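// [Editor's sketch, not part of the patch] The per-lane saturation rules the
// new PACKSS/PACKUS constant folding applies, shown for one 32-bit -> 16-bit
// lane (both pack forms take signed input; they differ in the output range).
// PackSS/PackUS are illustrative helpers, not names from the patch:
#include <algorithm>
#include <cassert>
#include <cstdint>
int main() {
  auto PackSS = [](int32_t V) {                        // signed saturation
    return int16_t(std::clamp<int32_t>(V, INT16_MIN, INT16_MAX));
  };
  auto PackUS = [](int32_t V) {                        // unsigned saturation
    return uint16_t(std::clamp<int32_t>(V, 0, UINT16_MAX));
  };
  assert(PackSS(70000) == INT16_MAX && PackSS(-70000) == INT16_MIN);
  assert(PackUS(-5) == 0 && PackUS(70000) == UINT16_MAX);
  assert(PackSS(1234) == 1234 && PackUS(1234) == 1234);
  return 0;
}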
+ if (SDValue Res = combineX86ShufflesRecursively( + {Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, DCI, Subtarget)) { + DCI.CombineTo(N, Res); + return SDValue(); + } } // Constant Folding. @@ -31840,11 +33196,13 @@ static SDValue combineVectorInsert(SDNode *N, SelectionDAG &DAG, // Attempt to combine PINSRB/PINSRW patterns to a shuffle. SDValue Op(N, 0); - SmallVector<int, 1> NonceMask; // Just a placeholder. - NonceMask.push_back(0); - combineX86ShufflesRecursively({Op}, 0, Op, NonceMask, {}, - /*Depth*/ 1, /*HasVarMask*/ false, DAG, - DCI, Subtarget); + if (SDValue Res = combineX86ShufflesRecursively( + {Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, DCI, Subtarget)) { + DCI.CombineTo(N, Res); + return SDValue(); + } + return SDValue(); } @@ -31911,8 +33269,9 @@ static SDValue combineCompareEqual(SDNode *N, SelectionDAG &DAG, SDValue FSetCC = DAG.getNode(X86ISD::FSETCCM, DL, MVT::v1i1, CMP00, CMP01, DAG.getConstant(x86cc, DL, MVT::i8)); - return DAG.getNode(X86ISD::VEXTRACT, DL, N->getSimpleValueType(0), - FSetCC, DAG.getIntPtrConstant(0, DL)); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, + N->getSimpleValueType(0), FSetCC, + DAG.getIntPtrConstant(0, DL)); } SDValue OnesOrZeroesF = DAG.getNode(X86ISD::FSETCC, DL, CMP00.getValueType(), CMP00, CMP01, @@ -31976,21 +33335,20 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) { // register. In most cases we actually compare or select YMM-sized registers // and mixing the two types creates horrible code. This method optimizes // some of the transition sequences. +// Even with AVX-512 this is still useful for removing casts around logical +// operations on vXi1 mask types. static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { EVT VT = N->getValueType(0); - if (!VT.is256BitVector()) - return SDValue(); + assert(VT.isVector() && "Expected vector type"); assert((N->getOpcode() == ISD::ANY_EXTEND || N->getOpcode() == ISD::ZERO_EXTEND || N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node"); SDValue Narrow = N->getOperand(0); - EVT NarrowVT = Narrow->getValueType(0); - if (!NarrowVT.is128BitVector()) - return SDValue(); + EVT NarrowVT = Narrow.getValueType(); if (Narrow->getOpcode() != ISD::XOR && Narrow->getOpcode() != ISD::AND && @@ -32006,51 +33364,40 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG, return SDValue(); // The type of the truncated inputs. - EVT WideVT = N0->getOperand(0)->getValueType(0); - if (WideVT != VT) + if (N0->getOperand(0).getValueType() != VT) return SDValue(); // The right side has to be a 'trunc' or a constant vector. - bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE; - ConstantSDNode *RHSConstSplat = nullptr; - if (auto *RHSBV = dyn_cast<BuildVectorSDNode>(N1)) - RHSConstSplat = RHSBV->getConstantSplatNode(); - if (!RHSTrunc && !RHSConstSplat) + bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE && + N1.getOperand(0).getValueType() == VT; + if (!RHSTrunc && + !ISD::isBuildVectorOfConstantSDNodes(N1.getNode())) return SDValue(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), WideVT)) + if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT)) return SDValue(); // Set N0 and N1 to hold the inputs to the new wide operation. 
N0 = N0->getOperand(0); - if (RHSConstSplat) { - N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getVectorElementType(), - SDValue(RHSConstSplat, 0)); - N1 = DAG.getSplatBuildVector(WideVT, DL, N1); - } else if (RHSTrunc) { + if (RHSTrunc) N1 = N1->getOperand(0); - } + else + N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1); // Generate the wide operation. - SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, WideVT, N0, N1); + SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1); unsigned Opcode = N->getOpcode(); switch (Opcode) { + default: llvm_unreachable("Unexpected opcode"); case ISD::ANY_EXTEND: return Op; - case ISD::ZERO_EXTEND: { - unsigned InBits = NarrowVT.getScalarSizeInBits(); - APInt Mask = APInt::getAllOnesValue(InBits); - Mask = Mask.zext(VT.getScalarSizeInBits()); - return DAG.getNode(ISD::AND, DL, VT, - Op, DAG.getConstant(Mask, DL, VT)); - } + case ISD::ZERO_EXTEND: + return DAG.getZeroExtendInReg(Op, DL, NarrowVT.getScalarType()); case ISD::SIGN_EXTEND: return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op, DAG.getValueType(NarrowVT)); - default: - llvm_unreachable("Unexpected opcode"); } } @@ -32103,8 +33450,7 @@ static SDValue combineAndMaskToShift(SDNode *N, SelectionDAG &DAG, return SDValue(); APInt SplatVal; - if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal, - /*AllowShrink*/false) || + if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal) || !SplatVal.isMask()) return SDValue(); @@ -32122,9 +33468,137 @@ static SDValue combineAndMaskToShift(SDNode *N, SelectionDAG &DAG, return DAG.getBitcast(N->getValueType(0), Shift); } +// Get the index node from the lowered DAG of a GEP IR instruction with one +// indexing dimension. +static SDValue getIndexFromUnindexedLoad(LoadSDNode *Ld) { + if (Ld->isIndexed()) + return SDValue(); + + SDValue Base = Ld->getBasePtr(); + + if (Base.getOpcode() != ISD::ADD) + return SDValue(); + + SDValue ShiftedIndex = Base.getOperand(0); + + if (ShiftedIndex.getOpcode() != ISD::SHL) + return SDValue(); + + return ShiftedIndex.getOperand(0); + +} + +static bool hasBZHI(const X86Subtarget &Subtarget, MVT VT) { + if (Subtarget.hasBMI2() && VT.isScalarInteger()) { + switch (VT.getSizeInBits()) { + default: return false; + case 64: return Subtarget.is64Bit() ? true : false; + case 32: return true; + } + } + return false; +} + +// This function recognizes cases where X86 bzhi instruction can replace and +// 'and-load' sequence. +// In case of loading integer value from an array of constants which is defined +// as follows: +// +// int array[SIZE] = {0x0, 0x1, 0x3, 0x7, 0xF ..., 2^(SIZE-1) - 1} +// +// then applying a bitwise and on the result with another input. +// It's equivalent to performing bzhi (zero high bits) on the input, with the +// same index of the load. +static SDValue combineAndLoadToBZHI(SDNode *Node, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + MVT VT = Node->getSimpleValueType(0); + SDLoc dl(Node); + + // Check if subtarget has BZHI instruction for the node's type + if (!hasBZHI(Subtarget, VT)) + return SDValue(); + + // Try matching the pattern for both operands. 
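// [Editor's sketch, not part of the patch] The property combineAndLoadToBZHI
// relies on: if arr[i] == (1 << i) - 1, then "x & arr[i]" clears every bit of
// x from position i upward, which is exactly what BZHI computes. Scalar
// check (i kept in 1..7 so the shift below stays well defined):
#include <cassert>
#include <cstdint>
int main() {
  uint32_t Arr[8] = {0x0, 0x1, 0x3, 0x7, 0xF, 0x1F, 0x3F, 0x7F};
  uint32_t X = 0xDEADBEEFu;
  for (uint32_t I = 1; I != 8; ++I) {
    uint32_t Bzhi = X & (0xFFFFFFFFu >> (32 - I)); // what bzhi x, i computes
    assert((X & Arr[I]) == Bzhi);
  }
  return 0;
}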
+ for (unsigned i = 0; i < 2; i++) { + SDValue N = Node->getOperand(i); + LoadSDNode *Ld = dyn_cast<LoadSDNode>(N.getNode()); + + // continue if the operand is not a load instruction + if (!Ld) + return SDValue(); + + const Value *MemOp = Ld->getMemOperand()->getValue(); + + if (!MemOp) + return SDValue(); + + if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(MemOp)) { + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) { + if (GV->isConstant() && GV->hasDefinitiveInitializer()) { + + Constant *Init = GV->getInitializer(); + Type *Ty = Init->getType(); + if (!isa<ConstantDataArray>(Init) || + !Ty->getArrayElementType()->isIntegerTy() || + Ty->getArrayElementType()->getScalarSizeInBits() != + VT.getSizeInBits() || + Ty->getArrayNumElements() > + Ty->getArrayElementType()->getScalarSizeInBits()) + continue; + + // Check if the array's constant elements are suitable to our case. + uint64_t ArrayElementCount = Init->getType()->getArrayNumElements(); + bool ConstantsMatch = true; + for (uint64_t j = 0; j < ArrayElementCount; j++) { + ConstantInt *Elem = + dyn_cast<ConstantInt>(Init->getAggregateElement(j)); + if (Elem->getZExtValue() != (((uint64_t)1 << j) - 1)) { + ConstantsMatch = false; + break; + } + } + if (!ConstantsMatch) + continue; + + // Do the transformation (For 32-bit type): + // -> (and (load arr[idx]), inp) + // <- (and (srl 0xFFFFFFFF, (sub 32, idx))) + // that will be replaced with one bzhi instruction. + SDValue Inp = (i == 0) ? Node->getOperand(1) : Node->getOperand(0); + SDValue SizeC = DAG.getConstant(VT.getSizeInBits(), dl, VT); + + // Get the Node which indexes into the array. + SDValue Index = getIndexFromUnindexedLoad(Ld); + if (!Index) + return SDValue(); + Index = DAG.getZExtOrTrunc(Index, dl, VT); + + SDValue Sub = DAG.getNode(ISD::SUB, dl, VT, SizeC, Index); + + SDValue AllOnes = DAG.getAllOnesConstant(dl, VT); + SDValue LShr = DAG.getNode(ISD::SRL, dl, VT, AllOnes, Sub); + + return DAG.getNode(ISD::AND, dl, VT, Inp, LShr); + } + } + } + } + return SDValue(); +} + static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { + EVT VT = N->getValueType(0); + + // If this is SSE1 only convert to FAND to avoid scalarization. + if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && VT == MVT::v4i32) { + return DAG.getBitcast( + MVT::v4i32, DAG.getNode(X86ISD::FAND, SDLoc(N), MVT::v4f32, + DAG.getBitcast(MVT::v4f32, N->getOperand(0)), + DAG.getBitcast(MVT::v4f32, N->getOperand(1)))); + } + if (DCI.isBeforeLegalizeOps()) return SDValue(); @@ -32140,45 +33614,59 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, if (SDValue ShiftRight = combineAndMaskToShift(N, DAG, Subtarget)) return ShiftRight; - EVT VT = N->getValueType(0); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDLoc DL(N); + if (SDValue R = combineAndLoadToBZHI(N, DAG, Subtarget)) + return R; // Attempt to recursively combine a bitmask AND with shuffles. if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) { SDValue Op(N, 0); - SmallVector<int, 1> NonceMask; // Just a placeholder. - NonceMask.push_back(0); - if (combineX86ShufflesRecursively({Op}, 0, Op, NonceMask, {}, - /*Depth*/ 1, /*HasVarMask*/ false, DAG, - DCI, Subtarget)) - return SDValue(); // This routine will use CombineTo to replace N. 
+ if (SDValue Res = combineX86ShufflesRecursively( + {Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, DCI, Subtarget)) { + DCI.CombineTo(N, Res); + return SDValue(); + } } - // Create BEXTR instructions - // BEXTR is ((X >> imm) & (2**size-1)) - if (VT != MVT::i32 && VT != MVT::i64) - return SDValue(); + // Attempt to combine a scalar bitmask AND with an extracted shuffle. + if ((VT.getScalarSizeInBits() % 8) == 0 && + N->getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT && + isa<ConstantSDNode>(N->getOperand(0).getOperand(1))) { + SDValue BitMask = N->getOperand(1); + SDValue SrcVec = N->getOperand(0).getOperand(0); + EVT SrcVecVT = SrcVec.getValueType(); - if (!Subtarget.hasBMI() && !Subtarget.hasTBM()) - return SDValue(); - if (N0.getOpcode() != ISD::SRA && N0.getOpcode() != ISD::SRL) - return SDValue(); + // Check that the constant bitmask masks whole bytes. + APInt UndefElts; + SmallVector<APInt, 64> EltBits; + if (VT == SrcVecVT.getScalarType() && + N->getOperand(0)->isOnlyUserOf(SrcVec.getNode()) && + getTargetConstantBitsFromNode(BitMask, 8, UndefElts, EltBits) && + llvm::all_of(EltBits, [](APInt M) { + return M.isNullValue() || M.isAllOnesValue(); + })) { + unsigned NumElts = SrcVecVT.getVectorNumElements(); + unsigned Scale = SrcVecVT.getScalarSizeInBits() / 8; + unsigned Idx = N->getOperand(0).getConstantOperandVal(1); + + // Create a root shuffle mask from the byte mask and the extracted index. + SmallVector<int, 16> ShuffleMask(NumElts * Scale, SM_SentinelUndef); + for (unsigned i = 0; i != Scale; ++i) { + if (UndefElts[i]) + continue; + int VecIdx = Scale * Idx + i; + ShuffleMask[VecIdx] = + EltBits[i].isNullValue() ? SM_SentinelZero : VecIdx; + } - ConstantSDNode *MaskNode = dyn_cast<ConstantSDNode>(N1); - ConstantSDNode *ShiftNode = dyn_cast<ConstantSDNode>(N0.getOperand(1)); - if (MaskNode && ShiftNode) { - uint64_t Mask = MaskNode->getZExtValue(); - uint64_t Shift = ShiftNode->getZExtValue(); - if (isMask_64(Mask)) { - uint64_t MaskSize = countPopulation(Mask); - if (Shift + MaskSize <= VT.getSizeInBits()) - return DAG.getNode(X86ISD::BEXTR, DL, VT, N0.getOperand(0), - DAG.getConstant(Shift | (MaskSize << 8), DL, - VT)); + if (SDValue Shuffle = combineX86ShufflesRecursively( + {SrcVec}, 0, SrcVec, ShuffleMask, {}, /*Depth*/ 2, + /*HasVarMask*/ false, DAG, DCI, Subtarget)) + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), VT, Shuffle, + N->getOperand(0).getOperand(1)); } } + return SDValue(); } @@ -32411,6 +33899,18 @@ static SDValue combineOrCmpEqZeroToCtlzSrl(SDNode *N, SelectionDAG &DAG, static SDValue combineOr(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getValueType(0); + + // If this is SSE1 only convert to FOR to avoid scalarization. 
+ if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && VT == MVT::v4i32) { + return DAG.getBitcast(MVT::v4i32, + DAG.getNode(X86ISD::FOR, SDLoc(N), MVT::v4f32, + DAG.getBitcast(MVT::v4f32, N0), + DAG.getBitcast(MVT::v4f32, N1))); + } + if (DCI.isBeforeLegalizeOps()) return SDValue(); @@ -32423,15 +33923,11 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, if (SDValue R = combineLogicBlendIntoPBLENDV(N, DAG, Subtarget)) return R; - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N->getValueType(0); - if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64) return SDValue(); // fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c) - bool OptForSize = DAG.getMachineFunction().getFunction()->optForSize(); + bool OptForSize = DAG.getMachineFunction().getFunction().optForSize(); // SHLD/SHRD instructions have lower register pressure, but on some // platforms they have higher latency than the equivalent @@ -32521,38 +34017,6 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, return SDValue(); } -/// Generate NEG and CMOV for integer abs. -static SDValue combineIntegerAbs(SDNode *N, SelectionDAG &DAG) { - EVT VT = N->getValueType(0); - - // Since X86 does not have CMOV for 8-bit integer, we don't convert - // 8-bit integer abs to NEG and CMOV. - if (VT.isInteger() && VT.getSizeInBits() == 8) - return SDValue(); - - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDLoc DL(N); - - // Check pattern of XOR(ADD(X,Y), Y) where Y is SRA(X, size(X)-1) - // and change it to SUB and CMOV. - if (VT.isInteger() && N->getOpcode() == ISD::XOR && - N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1 && - N1.getOpcode() == ISD::SRA && N1.getOperand(0) == N0.getOperand(0)) { - auto *Y1C = dyn_cast<ConstantSDNode>(N1.getOperand(1)); - if (Y1C && Y1C->getAPIntValue() == VT.getSizeInBits() - 1) { - // Generate SUB & CMOV. - SDValue Neg = DAG.getNode(X86ISD::SUB, DL, DAG.getVTList(VT, MVT::i32), - DAG.getConstant(0, DL, VT), N0.getOperand(0)); - SDValue Ops[] = {N0.getOperand(0), Neg, - DAG.getConstant(X86::COND_GE, DL, MVT::i8), - SDValue(Neg.getNode(), 1)}; - return DAG.getNode(X86ISD::CMOV, DL, DAG.getVTList(VT, MVT::Glue), Ops); - } - } - return SDValue(); -} - /// Try to turn tests against the signbit in the form of: /// XOR(TRUNCATE(SRL(X, size(X)-1)), 1) /// into: @@ -32688,8 +34152,7 @@ static SDValue detectUSatPattern(SDValue In, EVT VT) { "Unexpected types for truncate operation"); APInt C; - if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C, - /*AllowShrink*/false)) { + if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C)) { // C should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according // the element size of the destination type. return C.isMask(VT.getScalarSizeInBits()) ? In.getOperand(0) : @@ -32746,16 +34209,6 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, if (!Subtarget.hasSSE2()) return SDValue(); - if (Subtarget.hasBWI()) { - if (VT.getSizeInBits() > 512) - return SDValue(); - } else if (Subtarget.hasAVX2()) { - if (VT.getSizeInBits() > 256) - return SDValue(); - } else { - if (VT.getSizeInBits() > 128) - return SDValue(); - } // Detect the following pattern: // @@ -32767,7 +34220,6 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, // %6 = trunc <N x i32> %5 to <N x i8> // // In AVX512, the last instruction can also be a trunc store. 
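For the (or (shl x, c), (srl y, 64 - c)) fold mentioned in combineOr, the scalar shape is roughly the following (illustrative only; whether SHLD is actually used depends on the size/latency heuristics noted in the code):

    unsigned long long shld_like(unsigned long long x, unsigned long long y,
                                 unsigned c) {
      // For c in 1..63 this is exactly what a 64-bit SHLD computes.
      return (x << c) | (y >> (64 - c));
    }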
- if (In.getOpcode() != ISD::SRL) return SDValue(); @@ -32788,6 +34240,35 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, return true; }; + // Split vectors to legal target size and apply AVG. + auto LowerToAVG = [&](SDValue Op0, SDValue Op1) { + unsigned NumSubs = 1; + if (Subtarget.hasBWI()) { + if (VT.getSizeInBits() > 512) + NumSubs = VT.getSizeInBits() / 512; + } else if (Subtarget.hasAVX2()) { + if (VT.getSizeInBits() > 256) + NumSubs = VT.getSizeInBits() / 256; + } else { + if (VT.getSizeInBits() > 128) + NumSubs = VT.getSizeInBits() / 128; + } + + if (NumSubs == 1) + return DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1); + + SmallVector<SDValue, 4> Subs; + EVT SubVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), + VT.getVectorNumElements() / NumSubs); + for (unsigned i = 0; i != NumSubs; ++i) { + unsigned Idx = i * SubVT.getVectorNumElements(); + SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits()); + SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits()); + Subs.push_back(DAG.getNode(X86ISD::AVG, DL, SubVT, LHS, RHS)); + } + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs); + }; + // Check if each element of the vector is left-shifted by one. auto LHS = In.getOperand(0); auto RHS = In.getOperand(1); @@ -32811,8 +34292,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, SDValue VecOnes = DAG.getConstant(1, DL, InVT); Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes); Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]); - return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0), - Operands[1]); + return LowerToAVG(Operands[0].getOperand(0), Operands[1]); } if (Operands[0].getOpcode() == ISD::ADD) @@ -32836,8 +34316,7 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG, return SDValue(); // The pattern is detected, emit X86ISD::AVG instruction. - return DAG.getNode(X86ISD::AVG, DL, VT, Operands[0].getOperand(0), - Operands[1].getOperand(0)); + return LowerToAVG(Operands[0].getOperand(0), Operands[1].getOperand(0)); } return SDValue(); @@ -32878,15 +34357,14 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG, Ptr = DAG.getMemBasePlusOffset(Ptr, 16, dl); SDValue Load2 = - DAG.getLoad(HalfVT, dl, Ld->getChain(), Ptr, Ld->getPointerInfo(), - std::min(16U, Alignment), Ld->getMemOperand()->getFlags()); + DAG.getLoad(HalfVT, dl, Ld->getChain(), Ptr, + Ld->getPointerInfo().getWithOffset(16), + MinAlign(Alignment, 16U), Ld->getMemOperand()->getFlags()); SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Load1.getValue(1), Load2.getValue(1)); - SDValue NewVec = DAG.getUNDEF(RegVT); - NewVec = insert128BitVector(NewVec, Load1, 0, DAG, dl); - NewVec = insert128BitVector(NewVec, Load2, NumElems / 2, DAG, dl); + SDValue NewVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, RegVT, Load1, Load2); return DCI.CombineTo(N, NewVec, TF, true); } @@ -33081,6 +34559,7 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG, WideSrc0 = DAG.getVectorShuffle(WideVecVT, dl, WideSrc0, DAG.getUNDEF(WideVecVT), ShuffleVec); } + // Prepare the new mask. 
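The pattern detectAVGPattern matches corresponds to a rounding-average loop like the one below; the new LowerToAVG helper additionally splits vectors wider than the legal register size into 128/256/512-bit pieces instead of bailing out (sketch, not from the patch):

    void avg_u8(unsigned char *dst, const unsigned char *a,
                const unsigned char *b, int n) {
      for (int i = 0; i < n; ++i)
        dst[i] = (unsigned char)((a[i] + b[i] + 1) >> 1);   // PAVGB per lane
    }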
SDValue NewMask; SDValue Mask = Mld->getMask(); @@ -33103,12 +34582,9 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG, WidenNumElts); unsigned NumConcat = WidenNumElts / MaskNumElts; - SmallVector<SDValue, 16> Ops(NumConcat); SDValue ZeroVal = DAG.getConstant(0, dl, Mask.getValueType()); + SmallVector<SDValue, 16> Ops(NumConcat, ZeroVal); Ops[0] = Mask; - for (unsigned i = 1; i != NumConcat; ++i) - Ops[i] = ZeroVal; - NewMask = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewMaskVT, Ops); } @@ -33154,8 +34630,33 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG, if (Mst->isCompressingStore()) return SDValue(); - if (!Mst->isTruncatingStore()) - return reduceMaskedStoreToScalarStore(Mst, DAG); + if (!Mst->isTruncatingStore()) { + if (SDValue ScalarStore = reduceMaskedStoreToScalarStore(Mst, DAG)) + return ScalarStore; + + // If the mask is checking (0 > X), we're creating a vector with all-zeros + // or all-ones elements based on the sign bits of X. AVX1 masked store only + // cares about the sign bit of each mask element, so eliminate the compare: + // mstore val, ptr, (pcmpgt 0, X) --> mstore val, ptr, X + // Note that by waiting to match an x86-specific PCMPGT node, we're + // eliminating potentially more complex matching of a setcc node which has + // a full range of predicates. + SDValue Mask = Mst->getMask(); + if (Mask.getOpcode() == X86ISD::PCMPGT && + ISD::isBuildVectorAllZeros(Mask.getOperand(0).getNode())) { + assert(Mask.getValueType() == Mask.getOperand(1).getValueType() && + "Unexpected type for PCMPGT"); + return DAG.getMaskedStore( + Mst->getChain(), SDLoc(N), Mst->getValue(), Mst->getBasePtr(), + Mask.getOperand(1), Mst->getMemoryVT(), Mst->getMemOperand()); + } + + // TODO: AVX512 targets should also be able to simplify something like the + // pattern above, but that pattern will be different. It will either need to + // match setcc more generally or match PCMPGTM later (in tablegen?). + + return SDValue(); + } // Resolve truncating stores. 
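The masked-store change above drops a compare-against-zero feeding an AVX masked store, because the hardware only reads the sign bit of each mask element. In intrinsics terms the two forms below store the same lanes (illustrative only, not code from the patch):

    #include <immintrin.h>

    void store_where_negative(float *p, __m256 v, __m256i x) {
      // Before: mask = (0 > x), then maskstore on the compare result:
      //   __m256i m = _mm256_cmpgt_epi32(_mm256_setzero_si256(), x);
      //   _mm256_maskstore_ps(p, m, v);
      // After the combine, x itself serves as the mask:
      _mm256_maskstore_ps(p, x, v);
    }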
EVT VT = Mst->getValue().getValueType(); @@ -33226,12 +34727,9 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG, WidenNumElts); unsigned NumConcat = WidenNumElts / MaskNumElts; - SmallVector<SDValue, 16> Ops(NumConcat); SDValue ZeroVal = DAG.getConstant(0, dl, Mask.getValueType()); + SmallVector<SDValue, 16> Ops(NumConcat, ZeroVal); Ops[0] = Mask; - for (unsigned i = 1; i != NumConcat; ++i) - Ops[i] = ZeroVal; - NewMask = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewMaskVT, Ops); } @@ -33272,8 +34770,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, DAG.getStore(St->getChain(), dl, Value0, Ptr0, St->getPointerInfo(), Alignment, St->getMemOperand()->getFlags()); SDValue Ch1 = - DAG.getStore(St->getChain(), dl, Value1, Ptr1, St->getPointerInfo(), - std::min(16U, Alignment), St->getMemOperand()->getFlags()); + DAG.getStore(St->getChain(), dl, Value1, Ptr1, + St->getPointerInfo().getWithOffset(16), + MinAlign(Alignment, 16U), St->getMemOperand()->getFlags()); return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Ch0, Ch1); } @@ -33384,8 +34883,8 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, if (VT.getSizeInBits() != 64) return SDValue(); - const Function *F = DAG.getMachineFunction().getFunction(); - bool NoImplicitFloatOps = F->hasFnAttribute(Attribute::NoImplicitFloat); + const Function &F = DAG.getMachineFunction().getFunction(); + bool NoImplicitFloatOps = F.hasFnAttribute(Attribute::NoImplicitFloat); bool F64IsLegal = !Subtarget.useSoftFloat() && !NoImplicitFloatOps && Subtarget.hasSSE2(); if ((VT.isVector() || @@ -33393,28 +34892,10 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, isa<LoadSDNode>(St->getValue()) && !cast<LoadSDNode>(St->getValue())->isVolatile() && St->getChain().hasOneUse() && !St->isVolatile()) { - SDNode* LdVal = St->getValue().getNode(); - LoadSDNode *Ld = nullptr; - int TokenFactorIndex = -1; + LoadSDNode *Ld = cast<LoadSDNode>(St->getValue().getNode()); SmallVector<SDValue, 8> Ops; - SDNode* ChainVal = St->getChain().getNode(); - // Must be a store of a load. We currently handle two cases: the load - // is a direct child, and it's under an intervening TokenFactor. It is - // possible to dig deeper under nested TokenFactors. - if (ChainVal == LdVal) - Ld = cast<LoadSDNode>(St->getChain()); - else if (St->getValue().hasOneUse() && - ChainVal->getOpcode() == ISD::TokenFactor) { - for (unsigned i = 0, e = ChainVal->getNumOperands(); i != e; ++i) { - if (ChainVal->getOperand(i).getNode() == LdVal) { - TokenFactorIndex = i; - Ld = cast<LoadSDNode>(St->getValue()); - } else - Ops.push_back(ChainVal->getOperand(i)); - } - } - if (!Ld || !ISD::isNormalLoad(Ld)) + if (!ISD::isNormalLoad(Ld)) return SDValue(); // If this is not the MMX case, i.e. we are just turning i64 load/store @@ -33431,17 +34912,12 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, if (Subtarget.is64Bit() || F64IsLegal) { MVT LdVT = Subtarget.is64Bit() ? MVT::i64 : MVT::f64; SDValue NewLd = DAG.getLoad(LdVT, LdDL, Ld->getChain(), Ld->getBasePtr(), - Ld->getPointerInfo(), Ld->getAlignment(), - Ld->getMemOperand()->getFlags()); + Ld->getMemOperand()); + // Make sure new load is placed in same chain order. 
- SDValue NewChain = DAG.makeEquivalentMemoryOrdering(Ld, NewLd); - if (TokenFactorIndex >= 0) { - Ops.push_back(NewChain); - NewChain = DAG.getNode(ISD::TokenFactor, LdDL, MVT::Other, Ops); - } - return DAG.getStore(NewChain, StDL, NewLd, St->getBasePtr(), - St->getPointerInfo(), St->getAlignment(), - St->getMemOperand()->getFlags()); + DAG.makeEquivalentMemoryOrdering(Ld, NewLd); + return DAG.getStore(St->getChain(), StDL, NewLd, St->getBasePtr(), + St->getMemOperand()); } // Otherwise, lower to two pairs of 32-bit loads / stores. @@ -33456,23 +34932,19 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, MinAlign(Ld->getAlignment(), 4), Ld->getMemOperand()->getFlags()); // Make sure new loads are placed in same chain order. - SDValue NewChain = DAG.makeEquivalentMemoryOrdering(Ld, LoLd); - NewChain = DAG.makeEquivalentMemoryOrdering(Ld, HiLd); - - if (TokenFactorIndex >= 0) { - Ops.push_back(NewChain); - NewChain = DAG.getNode(ISD::TokenFactor, LdDL, MVT::Other, Ops); - } + DAG.makeEquivalentMemoryOrdering(Ld, LoLd); + DAG.makeEquivalentMemoryOrdering(Ld, HiLd); LoAddr = St->getBasePtr(); HiAddr = DAG.getMemBasePlusOffset(LoAddr, 4, StDL); SDValue LoSt = - DAG.getStore(NewChain, StDL, LoLd, LoAddr, St->getPointerInfo(), + DAG.getStore(St->getChain(), StDL, LoLd, LoAddr, St->getPointerInfo(), St->getAlignment(), St->getMemOperand()->getFlags()); - SDValue HiSt = DAG.getStore( - NewChain, StDL, HiLd, HiAddr, St->getPointerInfo().getWithOffset(4), - MinAlign(St->getAlignment(), 4), St->getMemOperand()->getFlags()); + SDValue HiSt = DAG.getStore(St->getChain(), StDL, HiLd, HiAddr, + St->getPointerInfo().getWithOffset(4), + MinAlign(St->getAlignment(), 4), + St->getMemOperand()->getFlags()); return DAG.getNode(ISD::TokenFactor, StDL, MVT::Other, LoSt, HiSt); } @@ -33722,10 +35194,11 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - its // better to truncate if we have the chance. if (SrcVT.getScalarType() == MVT::i64 && TLI.isOperationLegal(Opcode, VT) && - !TLI.isOperationLegal(Opcode, SrcVT)) + !Subtarget.hasDQI()) return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1)); LLVM_FALLTHROUGH; case ISD::ADD: { + // TODO: ISD::SUB should be here but interferes with combineSubToSubus. SDValue Op0 = Src.getOperand(0); SDValue Op1 = Src.getOperand(1); if (TLI.isOperationLegal(Opcode, VT) && @@ -33882,8 +35355,9 @@ static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG, return SDValue(); } -/// This function transforms vector truncation of 'all or none' bits values. -/// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS operations. +/// This function transforms vector truncation of 'extended sign-bits' or +/// 'extended zero-bits' values. +/// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations. static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL, SelectionDAG &DAG, const X86Subtarget &Subtarget) { @@ -33904,12 +35378,6 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL, MVT InVT = In.getValueType().getSimpleVT(); MVT InSVT = InVT.getScalarType(); - // Use PACKSS if the input is a splatted sign bit. - // e.g. Comparison result, sext_in_reg, etc. - unsigned NumSignBits = DAG.ComputeNumSignBits(In); - if (NumSignBits != InSVT.getSizeInBits()) - return SDValue(); - // Check we have a truncation suited for PACKSS. 
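The combineStore rewrite above simplifies the i64 load/store path on 32-bit targets: the manual TokenFactor bookkeeping is replaced by makeEquivalentMemoryOrdering, while the generated code remains a single 64-bit move through an XMM register (or two 32-bit load/store pairs without SSE2). A minimal example of the code shape it applies to:

    void copy_i64(const long long *src, long long *dst) {
      *dst = *src;   // on i386 with SSE2: one movq/movsd load plus one store
    }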
if (!VT.is128BitVector() && !VT.is256BitVector()) return SDValue(); @@ -33918,7 +35386,23 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, SDLoc &DL, if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64) return SDValue(); - return truncateVectorCompareWithPACKSS(VT, In, DL, DAG, Subtarget); + // Use PACKSS if the input has sign-bits that extend all the way to the + // packed/truncated value. e.g. Comparison result, sext_in_reg, etc. + unsigned NumSignBits = DAG.ComputeNumSignBits(In); + unsigned NumPackedBits = std::min<unsigned>(SVT.getSizeInBits(), 16); + if (NumSignBits > (InSVT.getSizeInBits() - NumPackedBits)) + return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget); + + // Use PACKUS if the input has zero-bits that extend all the way to the + // packed/truncated value. e.g. masks, zext_in_reg, etc. + KnownBits Known; + DAG.computeKnownBits(In, Known); + unsigned NumLeadingZeroBits = Known.countMinLeadingZeros(); + NumPackedBits = Subtarget.hasSSE41() ? NumPackedBits : 8; + if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedBits)) + return truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget); + + return SDValue(); } static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, @@ -33947,7 +35431,7 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, return DAG.getNode(X86ISD::MMX_MOVD2W, DL, MVT::i32, BCSrc); } - // Try to truncate extended sign bits with PACKSS. + // Try to truncate extended sign/zero bits with PACKSS/PACKUS. if (SDValue V = combineVectorSignBitsTruncation(N, DL, DAG, Subtarget)) return V; @@ -34036,12 +35520,12 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG, // If we're negating an FMA node, then we can adjust the // instruction to include the extra negation. unsigned NewOpcode = 0; - if (Arg.hasOneUse()) { + if (Arg.hasOneUse() && Subtarget.hasAnyFMA()) { switch (Arg.getOpcode()) { - case X86ISD::FMADD: NewOpcode = X86ISD::FNMSUB; break; + case ISD::FMA: NewOpcode = X86ISD::FNMSUB; break; case X86ISD::FMSUB: NewOpcode = X86ISD::FNMADD; break; case X86ISD::FNMADD: NewOpcode = X86ISD::FMSUB; break; - case X86ISD::FNMSUB: NewOpcode = X86ISD::FMADD; break; + case X86ISD::FNMSUB: NewOpcode = ISD::FMA; break; case X86ISD::FMADD_RND: NewOpcode = X86ISD::FNMSUB_RND; break; case X86ISD::FMSUB_RND: NewOpcode = X86ISD::FNMADD_RND; break; case X86ISD::FNMADD_RND: NewOpcode = X86ISD::FMSUB_RND; break; @@ -34083,22 +35567,47 @@ static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG, return SDValue(); } + +/// Fold a xor(setcc cond, val), 1 --> setcc (inverted(cond), val) +static SDValue foldXor1SetCC(SDNode *N, SelectionDAG &DAG) { + if (N->getOpcode() != ISD::XOR) + return SDValue(); + + SDValue LHS = N->getOperand(0); + auto *RHSC = dyn_cast<ConstantSDNode>(N->getOperand(1)); + if (!RHSC || RHSC->getZExtValue() != 1 || LHS->getOpcode() != X86ISD::SETCC) + return SDValue(); + + X86::CondCode NewCC = X86::GetOppositeBranchCondition( + X86::CondCode(LHS->getConstantOperandVal(0))); + SDLoc DL(N); + return getSETCC(NewCC, LHS->getOperand(1), DL, DAG); +} + static SDValue combineXor(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) { + // If this is SSE1 only convert to FXOR to avoid scalarization. 
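combineVectorSignBitsTruncation now handles two lane shapes: values whose lanes are pure sign-bit patterns (compare results) truncate with PACKSS, and values with enough known-zero high bits truncate with PACKUS. Scalar sketches of both shapes (illustrative; the PACKUS path for 32-to-16 packing needs SSE4.1 per the code above):

    // Lanes are 0 or -1, preserved by signed saturation (PACKSSWB/PACKSSDW).
    void cmp_to_bytes(signed char *dst, const short *a, const short *b, int n) {
      for (int i = 0; i < n; ++i)
        dst[i] = (signed char)(a[i] == b[i] ? -1 : 0);
    }

    // Lanes already fit in 0..255, preserved by unsigned saturation (PACKUSWB).
    void low_bytes(unsigned char *dst, const unsigned short *a, int n) {
      for (int i = 0; i < n; ++i)
        dst[i] = (unsigned char)(a[i] & 0xFF);
    }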
+ if (Subtarget.hasSSE1() && !Subtarget.hasSSE2() && + N->getValueType(0) == MVT::v4i32) { + return DAG.getBitcast( + MVT::v4i32, DAG.getNode(X86ISD::FXOR, SDLoc(N), MVT::v4f32, + DAG.getBitcast(MVT::v4f32, N->getOperand(0)), + DAG.getBitcast(MVT::v4f32, N->getOperand(1)))); + } + if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget)) return Cmp; if (DCI.isBeforeLegalizeOps()) return SDValue(); + if (SDValue SetCC = foldXor1SetCC(N, DAG)) + return SetCC; + if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG)) return RV; - if (Subtarget.hasCMov()) - if (SDValue RV = combineIntegerAbs(N, DAG)) - return RV; - if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget)) return FPLogic; @@ -34138,10 +35647,13 @@ static SDValue combineFAndFNotToFAndn(SDNode *N, SelectionDAG &DAG, // Vector types are handled in combineANDXORWithAllOnesIntoANDNP(). if (!((VT == MVT::f32 && Subtarget.hasSSE1()) || - (VT == MVT::f64 && Subtarget.hasSSE2()))) + (VT == MVT::f64 && Subtarget.hasSSE2()) || + (VT == MVT::v4f32 && Subtarget.hasSSE1() && !Subtarget.hasSSE2()))) return SDValue(); auto isAllOnesConstantFP = [](SDValue V) { + if (V.getSimpleValueType().isVector()) + return ISD::isBuildVectorAllOnes(V.getNode()); auto *C = dyn_cast<ConstantFPSDNode>(V); return C && C->getConstantFPValue()->isAllOnesValue(); }; @@ -34247,7 +35759,7 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG, // This takes at least 3 instructions, so favor a library call when operating // on a scalar and minimizing code size. - if (!VT.isVector() && DAG.getMachineFunction().getFunction()->optForMinSize()) + if (!VT.isVector() && DAG.getMachineFunction().getFunction().optForMinSize()) return SDValue(); SDValue Op0 = N->getOperand(0); @@ -34301,12 +35813,12 @@ static SDValue combineAndnp(SDNode *N, SelectionDAG &DAG, // Attempt to recursively combine a bitmask ANDNP with shuffles. if (VT.isVector() && (VT.getScalarSizeInBits() % 8) == 0) { SDValue Op(N, 0); - SmallVector<int, 1> NonceMask; // Just a placeholder. - NonceMask.push_back(0); - if (combineX86ShufflesRecursively({Op}, 0, Op, NonceMask, {}, - /*Depth*/ 1, /*HasVarMask*/ false, DAG, - DCI, Subtarget)) - return SDValue(); // This routine will use CombineTo to replace N. + if (SDValue Res = combineX86ShufflesRecursively( + {Op}, 0, Op, {0}, {}, /*Depth*/ 1, + /*HasVarMask*/ false, DAG, DCI, Subtarget)) { + DCI.CombineTo(N, Res); + return SDValue(); + } } return SDValue(); @@ -34314,19 +35826,15 @@ static SDValue combineAndnp(SDNode *N, SelectionDAG &DAG, static SDValue combineBT(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + // BT ignores high bits in the bit index operand. 
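foldXor1SetCC, called from combineXor above, replaces a SETcc followed by XOR 1 with a SETcc on the inverted condition, e.g.:

    int not_less(int a, int b) {
      return !(a < b);   // setl + xor $1  ==>  setge
    }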
- SDValue Op1 = N->getOperand(1); - if (Op1.hasOneUse()) { - unsigned BitWidth = Op1.getValueSizeInBits(); - APInt DemandedMask = APInt::getLowBitsSet(BitWidth, Log2_32(BitWidth)); - KnownBits Known; - TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), - !DCI.isBeforeLegalizeOps()); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (TLI.ShrinkDemandedConstant(Op1, DemandedMask, TLO) || - TLI.SimplifyDemandedBits(Op1, DemandedMask, Known, TLO)) - DCI.CommitTargetLoweringOpt(TLO); - } + unsigned BitWidth = N1.getValueSizeInBits(); + APInt DemandedMask = APInt::getLowBitsSet(BitWidth, Log2_32(BitWidth)); + if (SDValue DemandedN1 = DAG.GetDemandedBits(N1, DemandedMask)) + return DAG.getNode(X86ISD::BT, SDLoc(N), MVT::i32, N0, DemandedN1); + return SDValue(); } @@ -34444,18 +35952,152 @@ static SDValue getDivRem8(SDNode *N, SelectionDAG &DAG) { EVT VT = N->getValueType(0); EVT InVT = N0.getValueType(); - if (N0.getResNo() != 1 || InVT != MVT::i8 || VT != MVT::i32) + if (N0.getResNo() != 1 || InVT != MVT::i8 || + !(VT == MVT::i32 || VT == MVT::i64)) return SDValue(); - SDVTList NodeTys = DAG.getVTList(MVT::i8, VT); + SDVTList NodeTys = DAG.getVTList(MVT::i8, MVT::i32); auto DivRemOpcode = OpcodeN0 == ISD::SDIVREM ? X86ISD::SDIVREM8_SEXT_HREG : X86ISD::UDIVREM8_ZEXT_HREG; SDValue R = DAG.getNode(DivRemOpcode, SDLoc(N), NodeTys, N0.getOperand(0), N0.getOperand(1)); DAG.ReplaceAllUsesOfValueWith(N0.getValue(0), R.getValue(0)); + // If this was a 64-bit extend, complete it. + if (VT == MVT::i64) + return DAG.getNode(OpcodeN, SDLoc(N), VT, R.getValue(1)); return R.getValue(1); } +// If we face {ANY,SIGN,ZERO}_EXTEND that is applied to a CMOV with constant +// operands and the result of CMOV is not used anywhere else - promote CMOV +// itself instead of promoting its result. This could be beneficial, because: +// 1) X86TargetLowering::EmitLoweredSelect later can do merging of two +// (or more) pseudo-CMOVs only when they go one-after-another and +// getting rid of result extension code after CMOV will help that. +// 2) Promotion of constant CMOV arguments is free, hence the +// {ANY,SIGN,ZERO}_EXTEND will just be deleted. +// 3) 16-bit CMOV encoding is 4 bytes, 32-bit CMOV is 3-byte, so this +// promotion is also good in terms of code-size. +// (64-bit CMOV is 4-bytes, that's why we don't do 32-bit => 64-bit +// promotion). +static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) { + SDValue CMovN = Extend->getOperand(0); + if (CMovN.getOpcode() != X86ISD::CMOV) + return SDValue(); + + EVT TargetVT = Extend->getValueType(0); + unsigned ExtendOpcode = Extend->getOpcode(); + SDLoc DL(Extend); + + EVT VT = CMovN.getValueType(); + SDValue CMovOp0 = CMovN.getOperand(0); + SDValue CMovOp1 = CMovN.getOperand(1); + + bool DoPromoteCMOV = + (VT == MVT::i16 && (TargetVT == MVT::i32 || TargetVT == MVT::i64)) && + CMovN.hasOneUse() && + (isa<ConstantSDNode>(CMovOp0.getNode()) && + isa<ConstantSDNode>(CMovOp1.getNode())); + + if (!DoPromoteCMOV) + return SDValue(); + + CMovOp0 = DAG.getNode(ExtendOpcode, DL, TargetVT, CMovOp0); + CMovOp1 = DAG.getNode(ExtendOpcode, DL, TargetVT, CMovOp1); + + return DAG.getNode(X86ISD::CMOV, DL, TargetVT, CMovOp0, CMovOp1, + CMovN.getOperand(2), CMovN.getOperand(3)); +} + +// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)). +// This is more or less the reverse of combineBitcastvxi1. 
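combineToExtendCMOV promotes a 16-bit CMOV between constants when its only use is an extension, so the extend disappears and the shorter 32-bit CMOV encoding is used (per the encoding-size comment above). Roughly the source shape involved (illustrative):

    int pick(int cond) {
      short t = cond ? 17 : 42;   // i16 cmov between constants...
      return t;                   // ...whose result is only ever sign-extended
    }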
+static SDValue +combineToExtendBoolVectorInReg(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + unsigned Opcode = N->getOpcode(); + if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND && + Opcode != ISD::ANY_EXTEND) + return SDValue(); + if (!DCI.isBeforeLegalizeOps()) + return SDValue(); + if (!Subtarget.hasSSE2() || Subtarget.hasAVX512()) + return SDValue(); + + SDValue N0 = N->getOperand(0); + EVT VT = N->getValueType(0); + EVT SVT = VT.getScalarType(); + EVT InSVT = N0.getValueType().getScalarType(); + unsigned EltSizeInBits = SVT.getSizeInBits(); + + // Input type must be extending a bool vector (bit-casted from a scalar + // integer) to legal integer types. + if (!VT.isVector()) + return SDValue(); + if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8) + return SDValue(); + if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST) + return SDValue(); + + SDValue N00 = N0.getOperand(0); + EVT SclVT = N0.getOperand(0).getValueType(); + if (!SclVT.isScalarInteger()) + return SDValue(); + + SDLoc DL(N); + SDValue Vec; + SmallVector<int, 32> ShuffleMask; + unsigned NumElts = VT.getVectorNumElements(); + assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size"); + + // Broadcast the scalar integer to the vector elements. + if (NumElts > EltSizeInBits) { + // If the scalar integer is greater than the vector element size, then we + // must split it down into sub-sections for broadcasting. For example: + // i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections. + // i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections. + assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale"); + unsigned Scale = NumElts / EltSizeInBits; + EVT BroadcastVT = + EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits); + Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00); + Vec = DAG.getBitcast(VT, Vec); + + for (unsigned i = 0; i != Scale; ++i) + ShuffleMask.append(EltSizeInBits, i); + } else { + // For smaller scalar integers, we can simply any-extend it to the vector + // element size (we don't care about the upper bits) and broadcast it to all + // elements. + SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT); + Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl); + ShuffleMask.append(NumElts, 0); + } + Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask); + + // Now, mask the relevant bit in each element. + SmallVector<SDValue, 32> Bits; + for (unsigned i = 0; i != NumElts; ++i) { + int BitIdx = (i % EltSizeInBits); + APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1); + Bits.push_back(DAG.getConstant(Bit, DL, SVT)); + } + SDValue BitMask = DAG.getBuildVector(VT, DL, Bits); + Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask); + + // Compare against the bitmask and extend the result. + EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts); + Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ); + Vec = DAG.getSExtOrTrunc(Vec, DL, VT); + + // For SEXT, this is now done, otherwise shift the result down for + // zero-extension. 
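A scalar reference for what the combineToExtendBoolVectorInReg sequence computes (broadcast the scalar mask, AND lane i with bit i, compare against that same bit, then sign- or zero-extend); the loop below is only a model of the per-lane result, not code from the patch:

    void expand_mask(int out[8], unsigned char m) {
      for (int i = 0; i < 8; ++i)
        out[i] = (m & (1u << i)) ? -1 : 0;   // sext case; zext would give 1/0
    }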
+ if (Opcode == ISD::SIGN_EXTEND) + return Vec; + return DAG.getNode(ISD::SRL, DL, VT, Vec, + DAG.getConstant(EltSizeInBits - 1, DL, VT)); +} + /// Convert a SEXT or ZEXT of a vector to a SIGN_EXTEND_VECTOR_INREG or /// ZERO_EXTEND_VECTOR_INREG, this requires the splitting (or concatenating /// with UNDEFs) of the input to vectors of the same size as the target type @@ -34570,14 +36212,11 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, if (SDValue DivRem8 = getDivRem8(N, DAG)) return DivRem8; - if (!DCI.isBeforeLegalizeOps()) { - if (InVT == MVT::i1) { - SDValue Zero = DAG.getConstant(0, DL, VT); - SDValue AllOnes = DAG.getAllOnesConstant(DL, VT); - return DAG.getSelect(DL, VT, N0, AllOnes, Zero); - } + if (SDValue NewCMov = combineToExtendCMOV(N, DAG)) + return NewCMov; + + if (!DCI.isBeforeLegalizeOps()) return SDValue(); - } if (InVT == MVT::i1 && N0.getOpcode() == ISD::XOR && isAllOnesConstant(N0.getOperand(1)) && N0.hasOneUse()) { @@ -34592,7 +36231,10 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, if (SDValue V = combineToExtendVectorInReg(N, DAG, DCI, Subtarget)) return V; - if (Subtarget.hasAVX() && VT.is256BitVector()) + if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget)) + return V; + + if (VT.isVector()) if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget)) return R; @@ -34604,6 +36246,7 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG, static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { + // TODO: Handle FMSUB/FNMADD/FNMSUB as the starting opcode. SDLoc dl(N); EVT VT = N->getValueType(0); @@ -34629,48 +36272,112 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, // Do not convert the passthru input of scalar intrinsics. // FIXME: We could allow negations of the lower element only. - bool NegA = N->getOpcode() != X86ISD::FMADDS1_RND && invertIfNegative(A); + bool NegA = N->getOpcode() != X86ISD::FMADDS1 && + N->getOpcode() != X86ISD::FMADDS1_RND && invertIfNegative(A); bool NegB = invertIfNegative(B); - bool NegC = N->getOpcode() != X86ISD::FMADDS3_RND && invertIfNegative(C); + bool NegC = N->getOpcode() != X86ISD::FMADDS3 && + N->getOpcode() != X86ISD::FMADDS3_RND && invertIfNegative(C); // Negative multiplication when NegA xor NegB bool NegMul = (NegA != NegB); + bool HasNeg = NegA || NegB || NegC; unsigned NewOpcode; if (!NegMul) - NewOpcode = (!NegC) ? X86ISD::FMADD : X86ISD::FMSUB; + NewOpcode = (!NegC) ? unsigned(ISD::FMA) : unsigned(X86ISD::FMSUB); else NewOpcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB; + // For FMA, we risk reconstructing the node we started with. + // In order to avoid this, we check for negation or opcode change. If + // one of the two happened, then it is a new node and we return it. 
+ if (N->getOpcode() == ISD::FMA) { + if (HasNeg || NewOpcode != N->getOpcode()) + return DAG.getNode(NewOpcode, dl, VT, A, B, C); + return SDValue(); + } if (N->getOpcode() == X86ISD::FMADD_RND) { switch (NewOpcode) { - case X86ISD::FMADD: NewOpcode = X86ISD::FMADD_RND; break; + case ISD::FMA: NewOpcode = X86ISD::FMADD_RND; break; case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB_RND; break; case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD_RND; break; case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB_RND; break; } + } else if (N->getOpcode() == X86ISD::FMADDS1) { + switch (NewOpcode) { + case ISD::FMA: NewOpcode = X86ISD::FMADDS1; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS1; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS1; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS1; break; + } + } else if (N->getOpcode() == X86ISD::FMADDS3) { + switch (NewOpcode) { + case ISD::FMA: NewOpcode = X86ISD::FMADDS3; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS3; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS3; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS3; break; + } } else if (N->getOpcode() == X86ISD::FMADDS1_RND) { switch (NewOpcode) { - case X86ISD::FMADD: NewOpcode = X86ISD::FMADDS1_RND; break; + case ISD::FMA: NewOpcode = X86ISD::FMADDS1_RND; break; case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS1_RND; break; case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS1_RND; break; case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS1_RND; break; } } else if (N->getOpcode() == X86ISD::FMADDS3_RND) { switch (NewOpcode) { - case X86ISD::FMADD: NewOpcode = X86ISD::FMADDS3_RND; break; + case ISD::FMA: NewOpcode = X86ISD::FMADDS3_RND; break; case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUBS3_RND; break; case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADDS3_RND; break; case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUBS3_RND; break; } + } else if (N->getOpcode() == X86ISD::FMADD4S) { + switch (NewOpcode) { + case ISD::FMA: NewOpcode = X86ISD::FMADD4S; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB4S; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD4S; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB4S; break; + } } else { - assert((N->getOpcode() == X86ISD::FMADD || N->getOpcode() == ISD::FMA) && - "Unexpected opcode!"); + llvm_unreachable("Unexpected opcode!"); + } + + // Only return the node is the opcode was changed or one of the + // operand was negated. If not, we'll just recreate the same node. 
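combineFMA folds explicit negations of FMA operands into the opcode itself (FMSUB/FNMADD/FNMSUB and their scalar/rounding variants) instead of emitting separate sign flips. With FMA contraction enabled (e.g. -mfma), expressions like these typically become single instructions:

    float fmsub_like (float a, float b, float c) { return a * b - c; }  // FMSUB
    float fnmadd_like(float a, float b, float c) { return c - a * b; }  // FNMADD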
+ if (HasNeg || NewOpcode != N->getOpcode()) { + if (N->getNumOperands() == 4) + return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3)); return DAG.getNode(NewOpcode, dl, VT, A, B, C); } - return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3)); + return SDValue(); +} + +// Combine FMADDSUB(A, B, FNEG(C)) -> FMSUBADD(A, B, C) +static SDValue combineFMADDSUB(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + SDLoc dl(N); + EVT VT = N->getValueType(0); + + SDValue NegVal = isFNEG(N->getOperand(2).getNode()); + if (!NegVal) + return SDValue(); + + unsigned NewOpcode; + switch (N->getOpcode()) { + default: llvm_unreachable("Unexpected opcode!"); + case X86ISD::FMADDSUB: NewOpcode = X86ISD::FMSUBADD; break; + case X86ISD::FMADDSUB_RND: NewOpcode = X86ISD::FMSUBADD_RND; break; + case X86ISD::FMSUBADD: NewOpcode = X86ISD::FMADDSUB; break; + case X86ISD::FMSUBADD_RND: NewOpcode = X86ISD::FMADDSUB_RND; break; + } + + if (N->getNumOperands() == 4) + return DAG.getNode(NewOpcode, dl, VT, N->getOperand(0), N->getOperand(1), + NegVal, N->getOperand(3)); + return DAG.getNode(NewOpcode, dl, VT, N->getOperand(0), N->getOperand(1), + NegVal); } static SDValue combineZext(SDNode *N, SelectionDAG &DAG, @@ -34710,10 +36417,16 @@ static SDValue combineZext(SDNode *N, SelectionDAG &DAG, } } + if (SDValue NewCMov = combineToExtendCMOV(N, DAG)) + return NewCMov; + if (SDValue V = combineToExtendVectorInReg(N, DAG, DCI, Subtarget)) return V; - if (VT.is256BitVector()) + if (SDValue V = combineToExtendBoolVectorInReg(N, DAG, DCI, Subtarget)) + return V; + + if (VT.isVector()) if (SDValue R = WidenMaskArithmetic(N, DAG, DCI, Subtarget)) return R; @@ -34736,13 +36449,23 @@ static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG, ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get(); assert((CC == ISD::SETNE || CC == ISD::SETEQ) && "Bad comparison predicate"); - // We're looking for an oversized integer equality comparison, but ignore a - // comparison with zero because that gets special treatment in EmitTest(). + // We're looking for an oversized integer equality comparison. SDValue X = SetCC->getOperand(0); SDValue Y = SetCC->getOperand(1); EVT OpVT = X.getValueType(); unsigned OpSize = OpVT.getSizeInBits(); - if (!OpVT.isScalarInteger() || OpSize < 128 || isNullConstant(Y)) + if (!OpVT.isScalarInteger() || OpSize < 128) + return SDValue(); + + // Ignore a comparison with zero because that gets special treatment in + // EmitTest(). But make an exception for the special case of a pair of + // logically-combined vector-sized operands compared to zero. This pattern may + // be generated by the memcmp expansion pass with oversized integer compares + // (see PR33325). + bool IsOrXorXorCCZero = isNullConstant(Y) && X.getOpcode() == ISD::OR && + X.getOperand(0).getOpcode() == ISD::XOR && + X.getOperand(1).getOpcode() == ISD::XOR; + if (isNullConstant(Y) && !IsOrXorXorCCZero) return SDValue(); // Bail out if we know that this is not really just an oversized integer. @@ -34757,15 +36480,29 @@ static SDValue combineVectorSizedSetCCEquality(SDNode *SetCC, SelectionDAG &DAG, if ((OpSize == 128 && Subtarget.hasSSE2()) || (OpSize == 256 && Subtarget.hasAVX2())) { EVT VecVT = OpSize == 128 ? 
MVT::v16i8 : MVT::v32i8; - SDValue VecX = DAG.getBitcast(VecVT, X); - SDValue VecY = DAG.getBitcast(VecVT, Y); - + SDValue Cmp; + if (IsOrXorXorCCZero) { + // This is a bitwise-combined equality comparison of 2 pairs of vectors: + // setcc i128 (or (xor A, B), (xor C, D)), 0, eq|ne + // Use 2 vector equality compares and 'and' the results before doing a + // MOVMSK. + SDValue A = DAG.getBitcast(VecVT, X.getOperand(0).getOperand(0)); + SDValue B = DAG.getBitcast(VecVT, X.getOperand(0).getOperand(1)); + SDValue C = DAG.getBitcast(VecVT, X.getOperand(1).getOperand(0)); + SDValue D = DAG.getBitcast(VecVT, X.getOperand(1).getOperand(1)); + SDValue Cmp1 = DAG.getNode(X86ISD::PCMPEQ, DL, VecVT, A, B); + SDValue Cmp2 = DAG.getNode(X86ISD::PCMPEQ, DL, VecVT, C, D); + Cmp = DAG.getNode(ISD::AND, DL, VecVT, Cmp1, Cmp2); + } else { + SDValue VecX = DAG.getBitcast(VecVT, X); + SDValue VecY = DAG.getBitcast(VecVT, Y); + Cmp = DAG.getNode(X86ISD::PCMPEQ, DL, VecVT, VecX, VecY); + } // If all bytes match (bitmask is 0x(FFFF)FFFF), that's equality. // setcc i128 X, Y, eq --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, eq // setcc i128 X, Y, ne --> setcc (pmovmskb (pcmpeqb X, Y)), 0xFFFF, ne // setcc i256 X, Y, eq --> setcc (vpmovmskb (vpcmpeqb X, Y)), 0xFFFFFFFF, eq // setcc i256 X, Y, ne --> setcc (vpmovmskb (vpcmpeqb X, Y)), 0xFFFFFFFF, ne - SDValue Cmp = DAG.getNode(X86ISD::PCMPEQ, DL, VecVT, VecX, VecY); SDValue MovMsk = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Cmp); SDValue FFFFs = DAG.getConstant(OpSize == 128 ? 0xFFFF : 0xFFFFFFFF, DL, MVT::i32); @@ -34804,23 +36541,19 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, return V; } - if (VT.getScalarType() == MVT::i1 && + if (VT.isVector() && VT.getVectorElementType() == MVT::i1 && (CC == ISD::SETNE || CC == ISD::SETEQ || ISD::isSignedIntSetCC(CC))) { - bool IsSEXT0 = - (LHS.getOpcode() == ISD::SIGN_EXTEND) && - (LHS.getOperand(0).getValueType().getScalarType() == MVT::i1); - bool IsVZero1 = ISD::isBuildVectorAllZeros(RHS.getNode()); - - if (!IsSEXT0 || !IsVZero1) { - // Swap the operands and update the condition code. + // Put build_vectors on the right. + if (LHS.getOpcode() == ISD::BUILD_VECTOR) { std::swap(LHS, RHS); CC = ISD::getSetCCSwappedOperands(CC); - - IsSEXT0 = (LHS.getOpcode() == ISD::SIGN_EXTEND) && - (LHS.getOperand(0).getValueType().getScalarType() == MVT::i1); - IsVZero1 = ISD::isBuildVectorAllZeros(RHS.getNode()); } + bool IsSEXT0 = + (LHS.getOpcode() == ISD::SIGN_EXTEND) && + (LHS.getOperand(0).getValueType().getVectorElementType() == MVT::i1); + bool IsVZero1 = ISD::isBuildVectorAllZeros(RHS.getNode()); + if (IsSEXT0 && IsVZero1) { assert(VT == LHS.getOperand(0).getValueType() && "Uexpected operand type"); @@ -34846,17 +36579,108 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, return SDValue(); } -static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG) { +static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI) { + SDValue Src = N->getOperand(0); + MVT SrcVT = Src.getSimpleValueType(); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), + !DCI.isBeforeLegalizeOps()); + + // MOVMSK only uses the MSB from each vector element. 
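The or(xor, xor) compared-to-zero special case in combineVectorSizedSetCCEquality exists for the memcmp expansion pass (PR33325, per the comment above); a 32-byte equality test like the one below can then lower to two PCMPEQB, an AND, and a single MOVMSK compare:

    #include <string.h>

    int eq32(const void *a, const void *b) {
      return memcmp(a, b, 32) == 0;
    }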
+ KnownBits Known; + APInt DemandedMask(APInt::getSignMask(SrcVT.getScalarSizeInBits())); + if (TLI.SimplifyDemandedBits(Src, DemandedMask, Known, TLO)) { + DCI.AddToWorklist(Src.getNode()); + DCI.CommitTargetLoweringOpt(TLO); + return SDValue(N, 0); + } + + return SDValue(); +} + +static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { SDLoc DL(N); + + if (DCI.isBeforeLegalizeOps()) { + SDValue Index = N->getOperand(4); + // Remove any sign extends from 32 or smaller to larger than 32. + // Only do this before LegalizeOps in case we need the sign extend for + // legalization. + if (Index.getOpcode() == ISD::SIGN_EXTEND) { + if (Index.getScalarValueSizeInBits() > 32 && + Index.getOperand(0).getScalarValueSizeInBits() <= 32) { + SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); + NewOps[4] = Index.getOperand(0); + DAG.UpdateNodeOperands(N, NewOps); + // The original sign extend has less users, add back to worklist in case + // it needs to be removed + DCI.AddToWorklist(Index.getNode()); + DCI.AddToWorklist(N); + return SDValue(N, 0); + } + } + + // Make sure the index is either i32 or i64 + unsigned ScalarSize = Index.getScalarValueSizeInBits(); + if (ScalarSize != 32 && ScalarSize != 64) { + MVT EltVT = ScalarSize > 32 ? MVT::i64 : MVT::i32; + EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT, + Index.getValueType().getVectorNumElements()); + Index = DAG.getSExtOrTrunc(Index, DL, IndexVT); + SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); + NewOps[4] = Index; + DAG.UpdateNodeOperands(N, NewOps); + DCI.AddToWorklist(N); + return SDValue(N, 0); + } + + // Try to remove zero extends from 32->64 if we know the sign bit of + // the input is zero. + if (Index.getOpcode() == ISD::ZERO_EXTEND && + Index.getScalarValueSizeInBits() == 64 && + Index.getOperand(0).getScalarValueSizeInBits() == 32) { + if (DAG.SignBitIsZero(Index.getOperand(0))) { + SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); + NewOps[4] = Index.getOperand(0); + DAG.UpdateNodeOperands(N, NewOps); + // The original zero extend has less users, add back to worklist in case + // it needs to be removed + DCI.AddToWorklist(Index.getNode()); + DCI.AddToWorklist(N); + return SDValue(N, 0); + } + } + } + // Gather and Scatter instructions use k-registers for masks. The type of // the masks is v*i1. So the mask will be truncated anyway. // The SIGN_EXTEND_INREG my be dropped. SDValue Mask = N->getOperand(2); - if (Mask.getOpcode() == ISD::SIGN_EXTEND_INREG) { + if (Subtarget.hasAVX512() && Mask.getOpcode() == ISD::SIGN_EXTEND_INREG) { SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end()); NewOps[2] = Mask.getOperand(0); DAG.UpdateNodeOperands(N, NewOps); + return SDValue(N, 0); + } + + // With AVX2 we only demand the upper bit of the mask. + if (!Subtarget.hasAVX512()) { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), + !DCI.isBeforeLegalizeOps()); + KnownBits Known; + APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits())); + if (TLI.SimplifyDemandedBits(Mask, DemandedMask, Known, TLO)) { + DCI.AddToWorklist(Mask.getNode()); + DCI.CommitTargetLoweringOpt(TLO); + return SDValue(N, 0); + } } + return SDValue(); } @@ -34868,7 +36692,7 @@ static SDValue combineX86SetCC(SDNode *N, SelectionDAG &DAG, SDValue EFLAGS = N->getOperand(1); // Try to simplify the EFLAGS and condition code operands. 
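The new gather/scatter index handling strips sign extensions of 32-bit index vectors (and zero extensions whose sign bit is known clear), since the gather instructions consume dword indices directly, e.g. VPGATHERDQ gathers qword elements from dword indices. The usual source of such indices is address arithmetic like the loop below; whether a gather is actually emitted depends on the target and cost model:

    void gather(long long *out, const long long *base, const int *idx, int n) {
      for (int i = 0; i < n; ++i)
        out[i] = base[idx[i]];   // idx[i] is sign-extended only for addressing
    }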
- if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG)) + if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG, Subtarget)) return getSETCC(CC, Flags, DL, DAG); return SDValue(); @@ -34884,7 +36708,7 @@ static SDValue combineBrCond(SDNode *N, SelectionDAG &DAG, // Try to simplify the EFLAGS and condition code operands. // Make sure to not keep references to operands, as combineSetCCEFLAGS can // RAUW them under us. - if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG)) { + if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG, Subtarget)) { SDValue Cond = DAG.getConstant(CC, DL, MVT::i8); return DAG.getNode(X86ISD::BRCOND, DL, N->getVTList(), N->getOperand(0), N->getOperand(1), Cond, Flags); @@ -34909,7 +36733,7 @@ static SDValue combineVectorCompareAndMaskUnaryOp(SDNode *N, EVT VT = N->getValueType(0); if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND || N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC || - VT.getSizeInBits() != N->getOperand(0)->getValueType(0).getSizeInBits()) + VT.getSizeInBits() != N->getOperand(0).getValueSizeInBits()) return SDValue(); // Now check that the other operand of the AND is a constant. We could @@ -34945,7 +36769,6 @@ static SDValue combineUIntToFP(SDNode *N, SelectionDAG &DAG, EVT VT = N->getValueType(0); EVT InVT = Op0.getValueType(); EVT InSVT = InVT.getScalarType(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); // UINT_TO_FP(vXi8) -> SINT_TO_FP(ZEXT(vXi8 to vXi32)) // UINT_TO_FP(vXi16) -> SINT_TO_FP(ZEXT(vXi16 to vXi32)) @@ -34955,9 +36778,7 @@ static SDValue combineUIntToFP(SDNode *N, SelectionDAG &DAG, InVT.getVectorNumElements()); SDValue P = DAG.getNode(ISD::ZERO_EXTEND, dl, DstVT, Op0); - if (TLI.isOperationLegal(ISD::UINT_TO_FP, DstVT)) - return DAG.getNode(ISD::UINT_TO_FP, dl, VT, P); - + // UINT_TO_FP isn't legal without AVX512 so use SINT_TO_FP. return DAG.getNode(ISD::SINT_TO_FP, dl, VT, P); } @@ -35049,7 +36870,7 @@ static SDValue combineSBB(SDNode *N, SelectionDAG &DAG) { // Optimize RES, EFLAGS = X86ISD::ADC LHS, RHS, EFLAGS static SDValue combineADC(SDNode *N, SelectionDAG &DAG, - X86TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI) { // If the LHS and RHS of the ADC node are zero, then it can't overflow and // the result is either zero or one (depending on the input carry bit). // Strength reduce this down to a "set on carry" aka SETCC_CARRY&1. @@ -35260,6 +37081,9 @@ static SDValue combineAddOrSubToADCOrSBB(SDNode *N, SelectionDAG &DAG) { static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { + if (!Subtarget.hasSSE2()) + return SDValue(); + SDValue MulOp = N->getOperand(0); SDValue Phi = N->getOperand(1); @@ -35305,6 +37129,9 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG, static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { + if (!Subtarget.hasSSE2()) + return SDValue(); + SDLoc DL(N); EVT VT = N->getValueType(0); SDValue Op0 = N->getOperand(0); @@ -35362,16 +37189,13 @@ static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, Sad = DAG.getNode(ISD::TRUNCATE, DL, VT, Sad); if (VT.getSizeInBits() > ResVT.getSizeInBits()) { - // Update part of elements of the reduction vector. This is done by first - // extracting a sub-vector from it, updating this sub-vector, and inserting - // it back. 
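combineLoopMAddPattern (now guarded by an SSE2 check, as added above) recognizes the widening multiply-accumulate reduction that maps onto PMADDWD:

    int dot_i16(const short *a, const short *b, int n) {
      int sum = 0;
      for (int i = 0; i < n; ++i)
        sum += a[i] * b[i];   // 16x16 -> 32-bit products, accumulated
      return sum;
    }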
- SDValue SubPhi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Phi, - DAG.getIntPtrConstant(0, DL)); - SDValue Res = DAG.getNode(ISD::ADD, DL, ResVT, Sad, SubPhi); - return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Phi, Res, - DAG.getIntPtrConstant(0, DL)); - } else - return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi); + // Fill the upper elements with zero to match the add width. + SDValue Zero = DAG.getConstant(0, DL, VT); + Sad = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Zero, Sad, + DAG.getIntPtrConstant(0, DL)); + } + + return DAG.getNode(ISD::ADD, DL, VT, Sad, Phi); } /// Convert vector increment or decrement to sub/add with an all-ones constant: @@ -35392,7 +37216,7 @@ static SDValue combineIncDecVector(SDNode *N, SelectionDAG &DAG) { SDNode *N1 = N->getOperand(1).getNode(); APInt SplatVal; - if (!ISD::isConstantSplatVector(N1, SplatVal, /*AllowShrink*/false) || + if (!ISD::isConstantSplatVector(N1, SplatVal) || !SplatVal.isOneValue()) return SDValue(); @@ -35426,6 +37250,89 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, return combineAddOrSubToADCOrSBB(N, DAG); } +static SDValue combineSubToSubus(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + EVT VT = N->getValueType(0); + + // PSUBUS is supported, starting from SSE2, but special preprocessing + // for v8i32 requires umin, which appears in SSE41. + if (!(Subtarget.hasSSE2() && (VT == MVT::v16i8 || VT == MVT::v8i16)) && + !(Subtarget.hasSSE41() && (VT == MVT::v8i32)) && + !(Subtarget.hasAVX2() && (VT == MVT::v32i8 || VT == MVT::v16i16)) && + !(Subtarget.hasAVX512() && Subtarget.hasBWI() && + (VT == MVT::v64i8 || VT == MVT::v32i16 || VT == MVT::v16i32 || + VT == MVT::v8i64))) + return SDValue(); + + SDValue SubusLHS, SubusRHS; + // Try to find umax(a,b) - b or a - umin(a,b) patterns + // they may be converted to subus(a,b). + // TODO: Need to add IR cannonicialization for this code. + if (Op0.getOpcode() == ISD::UMAX) { + SubusRHS = Op1; + SDValue MaxLHS = Op0.getOperand(0); + SDValue MaxRHS = Op0.getOperand(1); + if (MaxLHS == Op1) + SubusLHS = MaxRHS; + else if (MaxRHS == Op1) + SubusLHS = MaxLHS; + else + return SDValue(); + } else if (Op1.getOpcode() == ISD::UMIN) { + SubusLHS = Op0; + SDValue MinLHS = Op1.getOperand(0); + SDValue MinRHS = Op1.getOperand(1); + if (MinLHS == Op0) + SubusRHS = MinRHS; + else if (MinRHS == Op0) + SubusRHS = MinLHS; + else + return SDValue(); + } else + return SDValue(); + + // PSUBUS doesn't support v8i32/v8i64/v16i32, but it can be enabled with + // special preprocessing in some cases. + if (VT != MVT::v8i32 && VT != MVT::v16i32 && VT != MVT::v8i64) + return DAG.getNode(X86ISD::SUBUS, SDLoc(N), VT, SubusLHS, SubusRHS); + + // Special preprocessing case can be only applied + // if the value was zero extended from 16 bit, + // so we require first 16 bits to be zeros for 32 bit + // values, or first 48 bits for 64 bit values. + KnownBits Known; + DAG.computeKnownBits(SubusLHS, Known); + unsigned NumZeros = Known.countMinLeadingZeros(); + if ((VT == MVT::v8i64 && NumZeros < 48) || NumZeros < 16) + return SDValue(); + + EVT ExtType = SubusLHS.getValueType(); + EVT ShrinkedType; + if (VT == MVT::v8i32 || VT == MVT::v8i64) + ShrinkedType = MVT::v8i16; + else + ShrinkedType = NumZeros >= 24 ? MVT::v16i8 : MVT::v16i16; + + // If SubusLHS is zeroextended - truncate SubusRHS to it's + // size SubusRHS = umin(0xFFF.., SubusRHS). 
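Similarly, combineLoopSADPattern targets the sum-of-absolute-differences reduction (PSADBW); the change above also merges the partial SAD result into the accumulator by zero-filling its upper elements rather than an extract/insert pair:

    #include <stdlib.h>

    int sad_u8(const unsigned char *a, const unsigned char *b, int n) {
      int sum = 0;
      for (int i = 0; i < n; ++i)
        sum += abs(a[i] - b[i]);
      return sum;
    }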
+ SDValue SaturationConst = + DAG.getConstant(APInt::getLowBitsSet(ExtType.getScalarSizeInBits(), + ShrinkedType.getScalarSizeInBits()), + SDLoc(SubusLHS), ExtType); + SDValue UMin = DAG.getNode(ISD::UMIN, SDLoc(SubusLHS), ExtType, SubusRHS, + SaturationConst); + SDValue NewSubusLHS = + DAG.getZExtOrTrunc(SubusLHS, SDLoc(SubusLHS), ShrinkedType); + SDValue NewSubusRHS = DAG.getZExtOrTrunc(UMin, SDLoc(SubusRHS), ShrinkedType); + SDValue Psubus = DAG.getNode(X86ISD::SUBUS, SDLoc(N), ShrinkedType, + NewSubusLHS, NewSubusRHS); + // Zero extend the result, it may be used somewhere as 32 bit, + // if not zext and following trunc will shrink. + return DAG.getZExtOrTrunc(Psubus, SDLoc(N), ExtType); +} + static SDValue combineSub(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { SDValue Op0 = N->getOperand(0); @@ -35459,6 +37366,10 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG, if (SDValue V = combineIncDecVector(N, DAG)) return V; + // Try to create PSUBUS if SUB's argument is max/min + if (SDValue V = combineSubToSubus(N, DAG, Subtarget)) + return V; + return combineAddOrSubToADCOrSBB(N, DAG); } @@ -35554,39 +37465,26 @@ static SDValue combineVSZext(SDNode *N, SelectionDAG &DAG, return SDValue(); } -/// Canonicalize (LSUB p, 1) -> (LADD p, -1). -static SDValue combineLockSub(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - SDValue Chain = N->getOperand(0); - SDValue LHS = N->getOperand(1); - SDValue RHS = N->getOperand(2); - MVT VT = RHS.getSimpleValueType(); - SDLoc DL(N); - - auto *C = dyn_cast<ConstantSDNode>(RHS); - if (!C || C->getZExtValue() != 1) - return SDValue(); - - RHS = DAG.getConstant(-1, DL, VT); - MachineMemOperand *MMO = cast<MemSDNode>(N)->getMemOperand(); - return DAG.getMemIntrinsicNode(X86ISD::LADD, DL, - DAG.getVTList(MVT::i32, MVT::Other), - {Chain, LHS, RHS}, VT, MMO); -} - -// TEST (AND a, b) ,(AND a, b) -> TEST a, b -static SDValue combineTestM(SDNode *N, SelectionDAG &DAG) { +static SDValue combineTestM(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDValue Op0 = N->getOperand(0); SDValue Op1 = N->getOperand(1); - if (Op0 != Op1 || Op1->getOpcode() != ISD::AND) - return SDValue(); - - EVT VT = N->getValueType(0); + MVT VT = N->getSimpleValueType(0); SDLoc DL(N); - return DAG.getNode(X86ISD::TESTM, DL, VT, - Op0->getOperand(0), Op0->getOperand(1)); + // TEST (AND a, b) ,(AND a, b) -> TEST a, b + if (Op0 == Op1 && Op1->getOpcode() == ISD::AND) + return DAG.getNode(X86ISD::TESTM, DL, VT, Op0->getOperand(0), + Op0->getOperand(1)); + + // TEST op0, BUILD_VECTOR(all_zero) -> BUILD_VECTOR(all_zero) + // TEST BUILD_VECTOR(all_zero), op1 -> BUILD_VECTOR(all_zero) + if (ISD::isBuildVectorAllZeros(Op0.getNode()) || + ISD::isBuildVectorAllZeros(Op1.getNode())) + return getZeroVector(VT, Subtarget, DAG, DL); + + return SDValue(); } static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG, @@ -35610,21 +37508,55 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, if (DCI.isBeforeLegalizeOps()) return SDValue(); + MVT OpVT = N->getSimpleValueType(0); + + // Early out for mask vectors. 
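combineSubToSubus maps both the umax(a, b) - b and a - umin(a, b) forms onto unsigned saturating subtraction (PSUBUS); per lane they both compute max(a - b, 0):

    void subus_u8(unsigned char *d, const unsigned char *a,
                  const unsigned char *b, int n) {
      for (int i = 0; i < n; ++i) {
        unsigned char m = a[i] > b[i] ? a[i] : b[i];   // umax(a, b)
        d[i] = (unsigned char)(m - b[i]);              // == a[i] >= b[i] ? a[i] - b[i] : 0
      }
    }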
+ if (OpVT.getVectorElementType() == MVT::i1) + return SDValue(); + SDLoc dl(N); SDValue Vec = N->getOperand(0); SDValue SubVec = N->getOperand(1); - SDValue Idx = N->getOperand(2); - unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); - MVT OpVT = N->getSimpleValueType(0); + unsigned IdxVal = N->getConstantOperandVal(2); MVT SubVecVT = SubVec.getSimpleValueType(); + if (ISD::isBuildVectorAllZeros(Vec.getNode())) { + // Inserting zeros into zeros is a nop. + if (ISD::isBuildVectorAllZeros(SubVec.getNode())) + return Vec; + + // If we're inserting into a zero vector and then into a larger zero vector, + // just insert into the larger zero vector directly. + if (SubVec.getOpcode() == ISD::INSERT_SUBVECTOR && + ISD::isBuildVectorAllZeros(SubVec.getOperand(0).getNode())) { + unsigned Idx2Val = SubVec.getConstantOperandVal(2); + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, Vec, + SubVec.getOperand(1), + DAG.getIntPtrConstant(IdxVal + Idx2Val, dl)); + } + + // If we're inserting a bitcast into zeros, rewrite the insert and move the + // bitcast to the other side. This helps with detecting zero extending + // during isel. + // TODO: Is this useful for other indices than 0? + if (SubVec.getOpcode() == ISD::BITCAST && IdxVal == 0) { + MVT CastVT = SubVec.getOperand(0).getSimpleValueType(); + unsigned NumElems = OpVT.getSizeInBits() / CastVT.getScalarSizeInBits(); + MVT NewVT = MVT::getVectorVT(CastVT.getVectorElementType(), NumElems); + SDValue Insert = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, NewVT, + DAG.getBitcast(NewVT, Vec), + SubVec.getOperand(0), N->getOperand(2)); + return DAG.getBitcast(OpVT, Insert); + } + } + // If this is an insert of an extract, combine to a shuffle. Don't do this - // if the insert or extract can be represented with a subvector operation. + // if the insert or extract can be represented with a subregister operation. if (SubVec.getOpcode() == ISD::EXTRACT_SUBVECTOR && SubVec.getOperand(0).getSimpleValueType() == OpVT && (IdxVal != 0 || !Vec.isUndef())) { - int ExtIdxVal = cast<ConstantSDNode>(SubVec.getOperand(1))->getZExtValue(); + int ExtIdxVal = SubVec.getConstantOperandVal(1); if (ExtIdxVal != 0) { int VecNumElts = OpVT.getVectorNumElements(); int SubVecNumElts = SubVecVT.getVectorNumElements(); @@ -35679,17 +37611,36 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, } // If lower/upper loads are the same and the only users of the load, then // lower to a VBROADCASTF128/VBROADCASTI128/etc. - if (auto *Ld = dyn_cast<LoadSDNode>(peekThroughOneUseBitcasts(SubVec2))) { + if (auto *Ld = dyn_cast<LoadSDNode>(peekThroughOneUseBitcasts(SubVec2))) if (SubVec2 == SubVec && ISD::isNormalLoad(Ld) && - SDNode::areOnlyUsersOf({N, Vec.getNode()}, SubVec2.getNode())) { + SDNode::areOnlyUsersOf({N, Vec.getNode()}, SubVec2.getNode())) return DAG.getNode(X86ISD::SUBV_BROADCAST, dl, OpVT, SubVec); - } - } + // If this is subv_broadcast insert into both halves, use a larger // subv_broadcast. - if (SubVec.getOpcode() == X86ISD::SUBV_BROADCAST && SubVec == SubVec2) { + if (SubVec.getOpcode() == X86ISD::SUBV_BROADCAST && SubVec == SubVec2) return DAG.getNode(X86ISD::SUBV_BROADCAST, dl, OpVT, SubVec.getOperand(0)); + + // If we're inserting all zeros into the upper half, change this to + // an insert into an all zeros vector. We will match this to a move + // with implicit upper bit zeroing during isel. 
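The insert-subvector combines above keep "value placed into an all-zero wider vector" patterns in a shape isel can recognize as an implicitly zero-extending 128-bit move (VEX-encoded moves clear the upper lanes). In intrinsics terms, the idiom looks roughly like this (illustrative only):

    #include <immintrin.h>

    __m256i widen_to_256(__m128i x) {
      // Low half = x, upper half = zero; ideally a single vmovdqa xmm, xmm.
      return _mm256_inserti128_si256(_mm256_setzero_si256(), x, 0);
    }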
+ if (ISD::isBuildVectorAllZeros(SubVec.getNode())) + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, + getZeroVector(OpVT, Subtarget, DAG, dl), SubVec2, + Vec.getOperand(2)); + + // If we are inserting into both halves of the vector, the starting + // vector should be undef. If it isn't, make it so. Only do this if the + // the early insert has no other uses. + // TODO: Should this be a generic DAG combine? + if (!Vec.getOperand(0).isUndef() && Vec.hasOneUse()) { + Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, DAG.getUNDEF(OpVT), + SubVec2, Vec.getOperand(2)); + DCI.AddToWorklist(Vec.getNode()); + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, OpVT, Vec, SubVec, + N->getOperand(2)); + } } } @@ -35697,6 +37648,32 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG, return SDValue(); } +static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + + MVT OpVT = N->getSimpleValueType(0); + SDValue InVec = N->getOperand(0); + unsigned IdxVal = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); + + if (ISD::isBuildVectorAllZeros(InVec.getNode())) + return getZeroVector(OpVT, Subtarget, DAG, SDLoc(N)); + + if (ISD::isBuildVectorAllOnes(InVec.getNode())) { + if (OpVT.getScalarType() == MVT::i1) + return DAG.getConstant(1, SDLoc(N), OpVT); + return getOnesVector(OpVT, DAG, SDLoc(N)); + } + + if (InVec.getOpcode() == ISD::BUILD_VECTOR) + return DAG.getBuildVector( + OpVT, SDLoc(N), + InVec.getNode()->ops().slice(IdxVal, OpVT.getVectorNumElements())); + + return SDValue(); +} SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { @@ -35704,12 +37681,13 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, switch (N->getOpcode()) { default: break; case ISD::EXTRACT_VECTOR_ELT: - return combineExtractVectorElt(N, DAG, DCI, Subtarget); case X86ISD::PEXTRW: case X86ISD::PEXTRB: - return combineExtractVectorElt_SSE(N, DAG, DCI, Subtarget); + return combineExtractVectorElt(N, DAG, DCI, Subtarget); case ISD::INSERT_SUBVECTOR: return combineInsertSubvector(N, DAG, DCI, Subtarget); + case ISD::EXTRACT_SUBVECTOR: + return combineExtractSubvector(N, DAG, DCI, Subtarget); case ISD::VSELECT: case ISD::SELECT: case X86ISD::SHRUNKBLEND: return combineSelect(N, DAG, DCI, Subtarget); @@ -35753,6 +37731,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case ISD::SETCC: return combineSetCC(N, DAG, Subtarget); case X86ISD::SETCC: return combineX86SetCC(N, DAG, Subtarget); case X86ISD::BRCOND: return combineBrCond(N, DAG, Subtarget); + case X86ISD::PACKSS: + case X86ISD::PACKUS: return combineVectorPack(N, DAG, DCI, Subtarget); case X86ISD::VSHLI: case X86ISD::VSRAI: case X86ISD::VSRLI: @@ -35784,6 +37764,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::MOVDDUP: case X86ISD::MOVSS: case X86ISD::MOVSD: + case X86ISD::VBROADCAST: case X86ISD::VPPERM: case X86ISD::VPERMI: case X86ISD::VPERMV: @@ -35795,15 +37776,23 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::VPERM2X128: case X86ISD::VZEXT_MOVL: case ISD::VECTOR_SHUFFLE: return combineShuffle(N, DAG, DCI,Subtarget); - case X86ISD::FMADD: case X86ISD::FMADD_RND: case X86ISD::FMADDS1_RND: case X86ISD::FMADDS3_RND: + case X86ISD::FMADDS1: + case X86ISD::FMADDS3: + case X86ISD::FMADD4S: case ISD::FMA: return combineFMA(N, DAG, Subtarget); + case X86ISD::FMADDSUB_RND: + case X86ISD::FMSUBADD_RND: + case 
X86ISD::FMADDSUB: + case X86ISD::FMSUBADD: return combineFMADDSUB(N, DAG, Subtarget); + case X86ISD::MOVMSK: return combineMOVMSK(N, DAG, DCI); + case X86ISD::MGATHER: + case X86ISD::MSCATTER: case ISD::MGATHER: - case ISD::MSCATTER: return combineGatherScatter(N, DAG); - case X86ISD::LSUB: return combineLockSub(N, DAG, Subtarget); - case X86ISD::TESTM: return combineTestM(N, DAG); + case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI, Subtarget); + case X86ISD::TESTM: return combineTestM(N, DAG, Subtarget); case X86ISD::PCMPEQ: case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget); } @@ -35910,6 +37899,27 @@ bool X86TargetLowering::IsDesirableToPromoteOp(SDValue Op, EVT &PVT) const { return Promote; } +bool X86TargetLowering:: + isDesirableToCombineBuildVectorToShuffleTruncate( + ArrayRef<int> ShuffleMask, EVT SrcVT, EVT TruncVT) const { + + assert(SrcVT.getVectorNumElements() == ShuffleMask.size() && + "Element count mismatch"); + assert( + Subtarget.getTargetLowering()->isShuffleMaskLegal(ShuffleMask, SrcVT) && + "Shuffle Mask expected to be legal"); + + // For 32-bit elements VPERMD is better than shuffle+truncate. + // TODO: After we improve lowerBuildVector, add execption for VPERMW. + if (SrcVT.getScalarSizeInBits() == 32 || !Subtarget.hasAVX2()) + return false; + + if (is128BitLaneCrossingShuffleMask(SrcVT.getSimpleVT(), ShuffleMask)) + return false; + + return true; +} + //===----------------------------------------------------------------------===// // X86 Inline Assembly Support //===----------------------------------------------------------------------===// @@ -36041,8 +38051,8 @@ X86TargetLowering::getConstraintType(StringRef Constraint) const { case 'v': case 'Y': case 'l': - return C_RegisterClass; case 'k': // AVX512 masking registers. + return C_RegisterClass; case 'a': case 'b': case 'c': @@ -36074,8 +38084,15 @@ X86TargetLowering::getConstraintType(StringRef Constraint) const { switch (Constraint[1]) { default: break; - case 'k': + case 'z': + case '0': return C_Register; + case 'i': + case 'm': + case 'k': + case 't': + case '2': + return C_RegisterClass; } } } @@ -36123,15 +38140,42 @@ TargetLowering::ConstraintWeight if (type->isX86_MMXTy() && Subtarget.hasMMX()) weight = CW_SpecificReg; break; - case 'Y': - // Other "Y<x>" (e.g. "Yk") constraints should be implemented below. - if (constraint[1] == 'k') { - // Support for 'Yk' (similarly to the 'k' variant below). - weight = CW_SpecificReg; + case 'Y': { + unsigned Size = StringRef(constraint).size(); + // Pick 'i' as the next char as 'Yi' and 'Y' are synonymous, when matching 'Y' + char NextChar = Size == 2 ? constraint[1] : 'i'; + if (Size > 2) break; + switch (NextChar) { + default: + return CW_Invalid; + // XMM0 + case 'z': + case '0': + if ((type->getPrimitiveSizeInBits() == 128) && Subtarget.hasSSE1()) + return CW_SpecificReg; + return CW_Invalid; + // Conditional OpMask regs (AVX512) + case 'k': + if ((type->getPrimitiveSizeInBits() == 64) && Subtarget.hasAVX512()) + return CW_Register; + return CW_Invalid; + // Any MMX reg + case 'm': + if (type->isX86_MMXTy() && Subtarget.hasMMX()) + return weight; + return CW_Invalid; + // Any SSE reg when ISA >= SSE2, same as 'Y' + case 'i': + case 't': + case '2': + if (!Subtarget.hasSSE2()) + return CW_Invalid; + break; } - // Else fall through (handle "Y" constraint). + // Fall through (handle "Y" constraint). 
LLVM_FALLTHROUGH; + } case 'v': if ((type->getPrimitiveSizeInBits() == 512) && Subtarget.hasAVX512()) weight = CW_Register; @@ -36143,7 +38187,8 @@ TargetLowering::ConstraintWeight break; case 'k': // Enable conditional vector operations using %k<#> registers. - weight = CW_SpecificReg; + if ((type->getPrimitiveSizeInBits() == 64) && Subtarget.hasAVX512()) + weight = CW_Register; break; case 'I': if (ConstantInt *C = dyn_cast<ConstantInt>(info.CallOperandVal)) { @@ -36545,6 +38590,17 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, switch (Constraint[1]) { default: break; + case 'i': + case 't': + case '2': + return getRegForInlineAsmConstraint(TRI, "Y", VT); + case 'm': + if (!Subtarget.hasMMX()) break; + return std::make_pair(0U, &X86::VR64RegClass); + case 'z': + case '0': + if (!Subtarget.hasSSE1()) break; + return std::make_pair(X86::XMM0, &X86::VR128RegClass); case 'k': // This register class doesn't allocate k0 for masked vector operation. if (Subtarget.hasAVX512()) { // Only supported in AVX512. @@ -36637,12 +38693,14 @@ X86TargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, if (Size == 1) Size = 8; unsigned DestReg = getX86SubSuperRegisterOrZero(Res.first, Size); if (DestReg > 0) { - Res.first = DestReg; - Res.second = Size == 8 ? &X86::GR8RegClass - : Size == 16 ? &X86::GR16RegClass - : Size == 32 ? &X86::GR32RegClass - : &X86::GR64RegClass; - assert(Res.second->contains(Res.first) && "Register in register class"); + bool is64Bit = Subtarget.is64Bit(); + const TargetRegisterClass *RC = + Size == 8 ? (is64Bit ? &X86::GR8RegClass : &X86::GR8_NOREXRegClass) + : Size == 16 ? (is64Bit ? &X86::GR16RegClass : &X86::GR16_NOREXRegClass) + : Size == 32 ? (is64Bit ? &X86::GR32RegClass : &X86::GR32_NOREXRegClass) + : &X86::GR64RegClass; + if (RC->contains(DestReg)) + Res = std::make_pair(DestReg, RC); } else { // No register found/type mismatch. Res.first = 0; @@ -36750,7 +38808,7 @@ void X86TargetLowering::insertCopiesSplitCSR( // fine for CXX_FAST_TLS since the C++-style TLS access functions should be // nounwind. If we want to generalize this later, we may need to emit // CFI pseudo-instructions. - assert(Entry->getParent()->getFunction()->hasFnAttribute( + assert(Entry->getParent()->getFunction().hasFnAttribute( Attribute::NoUnwind) && "Function should be nounwind in insertCopiesSplitCSR!"); Entry->addLiveIn(*I); @@ -36773,8 +38831,8 @@ bool X86TargetLowering::supportSwiftError() const { /// string if not applicable. StringRef X86TargetLowering::getStackProbeSymbolName(MachineFunction &MF) const { // If the function specifically requests stack probes, emit them. - if (MF.getFunction()->hasFnAttribute("probe-stack")) - return MF.getFunction()->getFnAttribute("probe-stack").getValueAsString(); + if (MF.getFunction().hasFnAttribute("probe-stack")) + return MF.getFunction().getFnAttribute("probe-stack").getValueAsString(); // Generally, if we aren't on Windows, the platform ABI does not include // support for stack probes, so don't emit them. |
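For context on the inline-asm hunks above: they teach constraint classification, weighting, and register selection about the two-letter "Y&lt;x&gt;" constraints ("Yz"/"Y0" pin an operand to XMM0, "Yk" selects an AVX-512 mask register, "Ym" an MMX register, and "Yi"/"Yt"/"Y2" behave like plain "Y" when SSE2 is available). The snippet below is a hedged illustration of what that enables at the source level; it is not part of the diff, the helper name is invented, and it assumes an SSE4.1-capable target (compile with -msse4.1).

```cpp
// Minimal sketch (not from the patch): the non-VEX PBLENDVB encoding reads
// its byte-select mask implicitly from %xmm0, so the "Yz" constraint is used
// to force `mask` into that register; "x" allows any XMM register.
#include <immintrin.h>

static inline __m128i blend_bytes(__m128i a, __m128i b, __m128i mask) {
  // For each byte: result = (MSB of the mask byte set) ? b : a.
  __asm__("pblendvb %1, %0"
          : "+x"(a)        // destination/source, any XMM register
          : "x"(b),        // source, any XMM register
            "Yz"(mask));   // implicit PBLENDVB mask, pinned to %xmm0
  return a;
}
```

Under the same assumptions, a "Yk" operand would similarly let the compiler place a 64-bit mask value into one of the %k1-%k7 registers on AVX-512 targets, which is exactly the register class the 'k'/'Yk' cases above return.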
