diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 9ccdbab008aec8..5b2214fa66c40b 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase { /// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum. virtual unsigned getJumpTableEncoding() const; + virtual MVT getJumpTableRegTy(const DataLayout &DL) const { + return getPointerTy(DL); + } + virtual const MCExpr * LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/, const MachineBasicBlock * /*MBB*/, unsigned /*uid*/, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 1f4436fb3a4966..37ba62911ec70b 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) { // Emit the code for the jump table assert(JT.SL && "Should set SDLoc for SelectionDAG!"); assert(JT.Reg != -1U && "Should lower JT Header first!"); - EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout()); + EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout()); SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy); SDValue Table = DAG.getJumpTable(JT.JTI, PTy); SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other, @@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT, // This value may be smaller or larger than the target's pointer type, and // therefore require extension or truncating. const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout())); + SwitchOp = + DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout())); unsigned JumpTableReg = - FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout())); - SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl, - JumpTableReg, SwitchOp); + FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout())); + SDValue CopyTo = + DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp); JT.Reg = JumpTableReg; if (!JTH.FallthroughUnreachable) { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 800f2ba693f53b..059cfff1f7e692 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -25,6 +25,7 @@ #include "llvm/CodeGen/Analysis.h" #include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" @@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::ROTR, MVT::i8, Expand); setOperationAction(ISD::BSWAP, MVT::i16, Expand); - // Indirect branch is not supported. - // This also disables Jump Table creation. - setOperationAction(ISD::BR_JT, MVT::Other, Expand); + setOperationAction(ISD::BR_JT, MVT::Other, Custom); setOperationAction(ISD::BRIND, MVT::Other, Expand); setOperationAction(ISD::GlobalAddress, MVT::i32, Custom); @@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(NVPTXISD::Dummy) MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED) MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED) + MAKE_CASE(NVPTXISD::BrxEnd) + MAKE_CASE(NVPTXISD::BrxItem) + MAKE_CASE(NVPTXISD::BrxStart) MAKE_CASE(NVPTXISD::Tex1DFloatS32) MAKE_CASE(NVPTXISD::Tex1DFloatFloat) MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel) @@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { return LowerFP_ROUND(Op, DAG); case ISD::FP_EXTEND: return LowerFP_EXTEND(Op, DAG); + case ISD::BR_JT: + return LowerBR_JT(Op, DAG); case ISD::VAARG: return LowerVAARG(Op, DAG); case ISD::VASTART: @@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { } } +SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue Chain = Op.getOperand(0); + const auto *JT = cast(Op.getOperand(1)); + SDValue Index = Op.getOperand(2); + + unsigned JId = JT->getIndex(); + MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo(); + ArrayRef MBBs = MJTI->getJumpTables()[JId].MBBs; + + SDValue IdV = DAG.getConstant(JId, DL, MVT::i32); + + // Generate BrxStart node + SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue); + Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV); + + // Generate BrxItem nodes + assert(!MBBs.empty()); + for (MachineBasicBlock *MBB : MBBs.drop_back()) + Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0), + DAG.getBasicBlock(MBB), Chain.getValue(1)); + + // Generate BrxEnd nodes + SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index, + IdV, Chain.getValue(1)}; + SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps); + + return BrxEnd; +} + +// This will prevent AsmPrinter from trying to print the jump tables itself. +unsigned NVPTXTargetLowering::getJumpTableEncoding() const { + return MachineJumpTableInfo::EK_Inline; +} + // This function is almost a copy of SelectionDAG::expandVAArg(). // The only diff is that this one produces loads from local address space. SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 63262961b363ed..32e6b044b0de1f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -62,6 +62,9 @@ enum NodeType : unsigned { BFI, PRMT, DYNAMIC_STACKALLOC, + BrxStart, + BrxItem, + BrxEnd, Dummy, LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE, @@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering { return true; } + // The default is the same as pointer type, but brx.idx only accepts i32 + MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; } + + unsigned getJumpTableEncoding() const override; + bool enableAggressiveFMAFusion(EVT VT) const override { return true; } // The default is to transform llvm.ctlz(x, false) (where false indicates that @@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering { SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 6a096fa5acea7c..d75dc8781f7802 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -3880,6 +3880,46 @@ def DYNAMIC_STACKALLOC64 : [(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>, Requires<[hasPTX<73>, hasSM<52>]>; + +// +// BRX +// + +def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>; +def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>; +def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>; + +def brx_start : + SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile, + [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>; +def brx_item : + SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def brx_end : + SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile, + [SDNPHasChain, SDNPInGlue, SDNPSideEffect]>; + +let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in { + + def BRX_START : + NVPTXInst<(outs), (ins i32imm:$id), + "$$L_brx_$id: .branchtargets", + [(brx_start (i32 imm:$id))]>; + + def BRX_ITEM : + NVPTXInst<(outs), (ins brtarget:$target), + "\t$target,", + [(brx_item bb:$target)]>; + + def BRX_END : + NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id), + "\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;", + [(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]> { + let isBarrier = 1; + } +} + + include "NVPTXIntrinsics.td" //----------------------------------- diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll new file mode 100644 index 00000000000000..b201fb98f3e6bb --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/jump-table.ll @@ -0,0 +1,170 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s | FileCheck %s +; RUN: %if ptxas %{ llc < %s | %ptxas-verify %} + +target triple = "nvptx64-nvidia-cuda" + +@out = addrspace(1) global i32 0, align 4 + +define void @foo(i32 %i) { +; CHECK-LABEL: foo( +; CHECK: { +; CHECK-NEXT: .reg .pred %p<2>; +; CHECK-NEXT: .reg .b32 %r<7>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.u32 %r2, [foo_param_0]; +; CHECK-NEXT: setp.gt.u32 %p1, %r2, 3; +; CHECK-NEXT: @%p1 bra $L__BB0_6; +; CHECK-NEXT: // %bb.1: // %entry +; CHECK-NEXT: $L_brx_0: .branchtargets +; CHECK-NEXT: $L__BB0_2, +; CHECK-NEXT: $L__BB0_3, +; CHECK-NEXT: $L__BB0_4, +; CHECK-NEXT: $L__BB0_5; +; CHECK-NEXT: brx.idx %r2, $L_brx_0; +; CHECK-NEXT: $L__BB0_2: // %case0 +; CHECK-NEXT: mov.b32 %r6, 0; +; CHECK-NEXT: st.global.u32 [out], %r6; +; CHECK-NEXT: bra.uni $L__BB0_6; +; CHECK-NEXT: $L__BB0_4: // %case2 +; CHECK-NEXT: mov.b32 %r4, 2; +; CHECK-NEXT: st.global.u32 [out], %r4; +; CHECK-NEXT: bra.uni $L__BB0_6; +; CHECK-NEXT: $L__BB0_5: // %case3 +; CHECK-NEXT: mov.b32 %r3, 3; +; CHECK-NEXT: st.global.u32 [out], %r3; +; CHECK-NEXT: bra.uni $L__BB0_6; +; CHECK-NEXT: $L__BB0_3: // %case1 +; CHECK-NEXT: mov.b32 %r5, 1; +; CHECK-NEXT: st.global.u32 [out], %r5; +; CHECK-NEXT: $L__BB0_6: // %end +; CHECK-NEXT: ret; +entry: + switch i32 %i, label %end [ + i32 0, label %case0 + i32 1, label %case1 + i32 2, label %case2 + i32 3, label %case3 + ] + +case0: + store i32 0, ptr addrspace(1) @out, align 4 + br label %end + +case1: + store i32 1, ptr addrspace(1) @out, align 4 + br label %end + +case2: + store i32 2, ptr addrspace(1) @out, align 4 + br label %end + +case3: + store i32 3, ptr addrspace(1) @out, align 4 + br label %end + +end: + ret void +} + + +define i32 @test2(i32 %tmp158) { +; CHECK-LABEL: test2( +; CHECK: { +; CHECK-NEXT: .reg .pred %p<6>; +; CHECK-NEXT: .reg .b32 %r<10>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.u32 %r1, [test2_param_0]; +; CHECK-NEXT: setp.gt.s32 %p1, %r1, 119; +; CHECK-NEXT: @%p1 bra $L__BB1_4; +; CHECK-NEXT: // %bb.1: // %entry +; CHECK-NEXT: setp.lt.u32 %p4, %r1, 6; +; CHECK-NEXT: @%p4 bra $L__BB1_3; +; CHECK-NEXT: // %bb.2: // %entry +; CHECK-NEXT: setp.lt.s32 %p5, %r1, -2147483645; +; CHECK-NEXT: @%p5 bra $L__BB1_3; +; CHECK-NEXT: bra.uni $L__BB1_6; +; CHECK-NEXT: $L__BB1_4: // %entry +; CHECK-NEXT: add.s32 %r2, %r1, -120; +; CHECK-NEXT: setp.gt.u32 %p2, %r2, 5; +; CHECK-NEXT: @%p2 bra $L__BB1_5; +; CHECK-NEXT: // %bb.12: // %entry +; CHECK-NEXT: $L_brx_0: .branchtargets +; CHECK-NEXT: $L__BB1_3, +; CHECK-NEXT: $L__BB1_7, +; CHECK-NEXT: $L__BB1_8, +; CHECK-NEXT: $L__BB1_9, +; CHECK-NEXT: $L__BB1_10, +; CHECK-NEXT: $L__BB1_11; +; CHECK-NEXT: brx.idx %r2, $L_brx_0; +; CHECK-NEXT: $L__BB1_7: // %bb339 +; CHECK-NEXT: mov.b32 %r7, 12; +; CHECK-NEXT: st.param.b32 [func_retval0+0], %r7; +; CHECK-NEXT: ret; +; CHECK-NEXT: $L__BB1_5: // %entry +; CHECK-NEXT: setp.eq.s32 %p3, %r1, 1024; +; CHECK-NEXT: @%p3 bra $L__BB1_3; +; CHECK-NEXT: bra.uni $L__BB1_6; +; CHECK-NEXT: $L__BB1_3: // %bb338 +; CHECK-NEXT: mov.b32 %r8, 11; +; CHECK-NEXT: st.param.b32 [func_retval0+0], %r8; +; CHECK-NEXT: ret; +; CHECK-NEXT: $L__BB1_10: // %bb342 +; CHECK-NEXT: mov.b32 %r4, 15; +; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4; +; CHECK-NEXT: ret; +; CHECK-NEXT: $L__BB1_6: // %bb336 +; CHECK-NEXT: mov.b32 %r9, 10; +; CHECK-NEXT: st.param.b32 [func_retval0+0], %r9; +; CHECK-NEXT: ret; +; CHECK-NEXT: $L__BB1_8: // %bb340 +; CHECK-NEXT: mov.b32 %r6, 13; +; CHECK-NEXT: st.param.b32 [func_retval0+0], %r6; +; CHECK-NEXT: ret; +; CHECK-NEXT: $L__BB1_9: // %bb341 +; CHECK-NEXT: mov.b32 %r5, 14; +; CHECK-NEXT: st.param.b32 [func_retval0+0], %r5; +; CHECK-NEXT: ret; +; CHECK-NEXT: $L__BB1_11: // %bb343 +; CHECK-NEXT: mov.b32 %r3, 18; +; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3; +; CHECK-NEXT: ret; +entry: + switch i32 %tmp158, label %bb336 [ + i32 -2147483648, label %bb338 + i32 -2147483647, label %bb338 + i32 -2147483646, label %bb338 + i32 120, label %bb338 + i32 121, label %bb339 + i32 122, label %bb340 + i32 123, label %bb341 + i32 124, label %bb342 + i32 125, label %bb343 + i32 126, label %bb336 + i32 1024, label %bb338 + i32 0, label %bb338 + i32 1, label %bb338 + i32 2, label %bb338 + i32 3, label %bb338 + i32 4, label %bb338 + i32 5, label %bb338 + ] + +bb336: + ret i32 10 +bb338: + ret i32 11 +bb339: + ret i32 12 +bb340: + ret i32 13 +bb341: + ret i32 14 +bb342: + ret i32 15 +bb343: + ret i32 18 + +}