From 2a695f2a37d6c6dc77268246368d2d7d7f1adea2 Mon Sep 17 00:00:00 2001 From: Alex MacLean Date: Tue, 6 Aug 2024 21:24:48 +0000 Subject: [PATCH 1/2] [NVPTX] support switch statement with brx.idx --- llvm/include/llvm/CodeGen/TargetLowering.h | 4 ++ .../SelectionDAG/SelectionDAGBuilder.cpp | 11 +-- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 45 +++++++++++- llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 10 +++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 38 ++++++++++ llvm/test/CodeGen/NVPTX/jump-table.ll | 69 +++++++++++++++++++ 6 files changed, 169 insertions(+), 8 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/jump-table.ll diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 9ccdbab008aec8..5b2214fa66c40b 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase { /// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum. 
virtual unsigned getJumpTableEncoding() const; + virtual MVT getJumpTableRegTy(const DataLayout &DL) const { + return getPointerTy(DL); + } + virtual const MCExpr * LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/, const MachineBasicBlock * /*MBB*/, unsigned /*uid*/, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 9d617c7acd13c2..192fbf74b02dc0 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) { // Emit the code for the jump table assert(JT.SL && "Should set SDLoc for SelectionDAG!"); assert(JT.Reg != -1U && "Should lower JT Header first!"); - EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout()); + EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout()); SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy); SDValue Table = DAG.getJumpTable(JT.JTI, PTy); SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other, @@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT, // This value may be smaller or larger than the target's pointer type, and // therefore require extension or truncating. 
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout())); + SwitchOp = + DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout())); unsigned JumpTableReg = - FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout())); - SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl, - JumpTableReg, SwitchOp); + FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout())); + SDValue CopyTo = + DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp); JT.Reg = JumpTableReg; if (!JTH.FallthroughUnreachable) { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 516fc7339a4bf3..bf647c88f00e28 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -25,6 +25,7 @@ #include "llvm/CodeGen/Analysis.h" #include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" @@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::ROTR, MVT::i8, Expand); setOperationAction(ISD::BSWAP, MVT::i16, Expand); - // Indirect branch is not supported. - // This also disables Jump Table creation. 
- setOperationAction(ISD::BR_JT, MVT::Other, Expand); + setOperationAction(ISD::BR_JT, MVT::Other, Custom); setOperationAction(ISD::BRIND, MVT::Other, Expand); setOperationAction(ISD::GlobalAddress, MVT::i32, Custom); @@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(NVPTXISD::Dummy) MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED) MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED) + MAKE_CASE(NVPTXISD::BrxEnd) + MAKE_CASE(NVPTXISD::BrxItem) + MAKE_CASE(NVPTXISD::BrxStart) MAKE_CASE(NVPTXISD::Tex1DFloatS32) MAKE_CASE(NVPTXISD::Tex1DFloatFloat) MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel) @@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { return LowerFP_ROUND(Op, DAG); case ISD::FP_EXTEND: return LowerFP_EXTEND(Op, DAG); + case ISD::BR_JT: + return LowerBR_JT(Op, DAG); case ISD::VAARG: return LowerVAARG(Op, DAG); case ISD::VASTART: @@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { } } +SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue Chain = Op.getOperand(0); + const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1)); + SDValue Index = Op.getOperand(2); + + unsigned JId = JT->getIndex(); + MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo(); + ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs; + + SDValue IdV = DAG.getConstant(JId, DL, MVT::i32); + + // Generate BrxStart node + SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue); + Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV); + + // Generate BrxItem nodes + assert(!MBBs.empty()); + for (MachineBasicBlock *MBB : MBBs.drop_back()) + Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0), + DAG.getBasicBlock(MBB), Chain.getValue(1)); + + // Generate BrxEnd nodes + SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index, + IdV, Chain.getValue(1)}; + SDValue BrxEnd = 
DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps); + + return BrxEnd; +} + +// This will prevent AsmPrinter from trying to print the jump tables itself. +unsigned NVPTXTargetLowering::getJumpTableEncoding() const { + return MachineJumpTableInfo::EK_Inline; +} + // This function is almost a copy of SelectionDAG::expandVAArg(). // The only diff is that this one produces loads from local address space. SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 63262961b363ed..32e6b044b0de1f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -62,6 +62,9 @@ enum NodeType : unsigned { BFI, PRMT, DYNAMIC_STACKALLOC, + BrxStart, + BrxItem, + BrxEnd, Dummy, LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE, @@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering { return true; } + // The default is the same as pointer type, but brx.idx only accepts i32 + MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; } + + unsigned getJumpTableEncoding() const override; + bool enableAggressiveFMAFusion(EVT VT) const override { return true; } // The default is to transform llvm.ctlz(x, false) (where false indicates that @@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering { SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 6a096fa5acea7c..cec7f20255d352 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -3880,6 +3880,44 @@ def DYNAMIC_STACKALLOC64 : [(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>, 
Requires<[hasPTX<73>, hasSM<52>]>; + +// +// BRX +// + +def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>; +def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>; +def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>; + +def brx_start : + SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile, + [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>; +def brx_item : + SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def brx_end : + SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile, + [SDNPHasChain, SDNPInGlue, SDNPSideEffect]>; + +let isTerminator = 1, isBranch = 1, isIndirectBranch = 1 in { + + def BRX_START : + NVPTXInst<(outs), (ins i32imm:$id), + "$$L_brx_$id: .branchtargets", + [(brx_start (i32 imm:$id))]>; + + def BRX_ITEM : + NVPTXInst<(outs), (ins brtarget:$target), + "$target,", + [(brx_item bb:$target)]>; + + def BRX_END : + NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id), + "$target;\n\tbrx.idx \t$val, $$L_brx_$id;", + [(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]>; +} + + include "NVPTXIntrinsics.td" //----------------------------------- diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll new file mode 100644 index 00000000000000..8dd4115e2feb63 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/jump-table.ll @@ -0,0 +1,69 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s | FileCheck %s +; RUN: %if ptxas %{ llc < %s | %ptxas-verify %} + +target triple = "nvptx64-nvidia-cuda" + +@out = addrspace(1) global i32 0, align 4 + +define void @foo(i32 %i) { +; CHECK-LABEL: foo( +; CHECK: { +; CHECK-NEXT: .reg .pred %p<2>; +; CHECK-NEXT: .reg .b32 %r<7>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: // %entry +; CHECK-NEXT: ld.param.u32 %r2, [foo_param_0]; +; CHECK-NEXT: setp.gt.u32 %p1, %r2, 3; +; CHECK-NEXT: @%p1 bra $L__BB0_6; 
+; CHECK-NEXT: // %bb.1: // %entry +; CHECK-NEXT: $L_brx_0: .branchtargets +; CHECK-NEXT: $L__BB0_2, +; CHECK-NEXT: $L__BB0_3, +; CHECK-NEXT: $L__BB0_4, +; CHECK-NEXT: $L__BB0_5; +; CHECK-NEXT: brx.idx %r2, $L_brx_0; +; CHECK-NEXT: $L__BB0_2: // %case0 +; CHECK-NEXT: mov.b32 %r6, 0; +; CHECK-NEXT: st.global.u32 [out], %r6; +; CHECK-NEXT: bra.uni $L__BB0_6; +; CHECK-NEXT: $L__BB0_4: // %case2 +; CHECK-NEXT: mov.b32 %r4, 2; +; CHECK-NEXT: st.global.u32 [out], %r4; +; CHECK-NEXT: bra.uni $L__BB0_6; +; CHECK-NEXT: $L__BB0_5: // %case3 +; CHECK-NEXT: mov.b32 %r3, 3; +; CHECK-NEXT: st.global.u32 [out], %r3; +; CHECK-NEXT: bra.uni $L__BB0_6; +; CHECK-NEXT: $L__BB0_3: // %case1 +; CHECK-NEXT: mov.b32 %r5, 1; +; CHECK-NEXT: st.global.u32 [out], %r5; +; CHECK-NEXT: $L__BB0_6: // %end +; CHECK-NEXT: ret; +entry: + switch i32 %i, label %end [ + i32 0, label %case0 + i32 1, label %case1 + i32 2, label %case2 + i32 3, label %case3 + ] + +case0: + store i32 0, ptr addrspace(1) @out, align 4 + br label %end + +case1: + store i32 1, ptr addrspace(1) @out, align 4 + br label %end + +case2: + store i32 2, ptr addrspace(1) @out, align 4 + br label %end + +case3: + store i32 3, ptr addrspace(1) @out, align 4 + br label %end + +end: + ret void +} From b4ef23e309b045690c1201f3bad5830ea2164a5e Mon Sep 17 00:00:00 2001 From: Alex MacLean Date: Thu, 8 Aug 2024 16:11:15 +0000 Subject: [PATCH 2/2] address comments --- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 4 ++-- llvm/test/CodeGen/NVPTX/jump-table.ll | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index cec7f20255d352..904574b2e1d660 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -3908,12 +3908,12 @@ let isTerminator = 1, isBranch = 1, isIndirectBranch = 1 in { def BRX_ITEM : NVPTXInst<(outs), (ins brtarget:$target), - "$target,", + "\t$target,", [(brx_item bb:$target)]>; def 
BRX_END : NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id), - "$target;\n\tbrx.idx \t$val, $$L_brx_$id;", + "\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;", [(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]>; } diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll index 8dd4115e2feb63..867e171a5840ae 100644 --- a/llvm/test/CodeGen/NVPTX/jump-table.ll +++ b/llvm/test/CodeGen/NVPTX/jump-table.ll @@ -18,10 +18,10 @@ define void @foo(i32 %i) { ; CHECK-NEXT: @%p1 bra $L__BB0_6; ; CHECK-NEXT: // %bb.1: // %entry ; CHECK-NEXT: $L_brx_0: .branchtargets -; CHECK-NEXT: $L__BB0_2, -; CHECK-NEXT: $L__BB0_3, -; CHECK-NEXT: $L__BB0_4, -; CHECK-NEXT: $L__BB0_5; +; CHECK-NEXT: $L__BB0_2, +; CHECK-NEXT: $L__BB0_3, +; CHECK-NEXT: $L__BB0_4, +; CHECK-NEXT: $L__BB0_5; ; CHECK-NEXT: brx.idx %r2, $L_brx_0; ; CHECK-NEXT: $L__BB0_2: // %case0 ; CHECK-NEXT: mov.b32 %r6, 0;