Skip to content

Commit

Permalink
[NVPTX] support switch statement with brx.idx (llvm#102400)
Browse files Browse the repository at this point in the history
Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX
instruction ([PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx]
(https://docs.nvidia.com/cuda/parallel-thread-execution/#control-flow-instructions-brx-idx)).
Depending on the heuristics in DAG selection, `switch` statements may
now be lowered using `brx.idx`
  • Loading branch information
AlexMaclean authored Aug 8, 2024
1 parent 8c3b6bd commit ba97697
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 8 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
virtual unsigned getJumpTableEncoding() const;

virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
return getPointerTy(DL);
}

virtual const MCExpr *
LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,
Expand Down
11 changes: 6 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
// Emit the code for the jump table
assert(JT.SL && "Should set SDLoc for SelectionDAG!");
assert(JT.Reg != -1U && "Should lower JT Header first!");
EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
Expand Down Expand Up @@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
// This value may be smaller or larger than the target's pointer type, and
// therefore require extension or truncating.
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
SwitchOp =
DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));

unsigned JumpTableReg =
FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
JumpTableReg, SwitchOp);
FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
SDValue CopyTo =
DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
JT.Reg = JumpTableReg;

if (!JTH.FallthroughUnreachable) {
Expand Down
45 changes: 42 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
Expand Down Expand Up @@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::ROTR, MVT::i8, Expand);
setOperationAction(ISD::BSWAP, MVT::i16, Expand);

// Indirect branch is not supported.
// This also disables Jump Table creation.
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
setOperationAction(ISD::BRIND, MVT::Other, Expand);

setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
Expand Down Expand Up @@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::Dummy)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
MAKE_CASE(NVPTXISD::BrxEnd)
MAKE_CASE(NVPTXISD::BrxItem)
MAKE_CASE(NVPTXISD::BrxStart)
MAKE_CASE(NVPTXISD::Tex1DFloatS32)
MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
Expand Down Expand Up @@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerFP_ROUND(Op, DAG);
case ISD::FP_EXTEND:
return LowerFP_EXTEND(Op, DAG);
case ISD::BR_JT:
return LowerBR_JT(Op, DAG);
case ISD::VAARG:
return LowerVAARG(Op, DAG);
case ISD::VASTART:
Expand All @@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
}
}

SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
SDLoc DL(Op);
SDValue Chain = Op.getOperand(0);
const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
SDValue Index = Op.getOperand(2);

unsigned JId = JT->getIndex();
MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;

SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);

// Generate BrxStart node
SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);

// Generate BrxItem nodes
assert(!MBBs.empty());
for (MachineBasicBlock *MBB : MBBs.drop_back())
Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
DAG.getBasicBlock(MBB), Chain.getValue(1));

// Generate BrxEnd nodes
SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
IdV, Chain.getValue(1)};
SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);

return BrxEnd;
}

// This will prevent AsmPrinter from trying to print the jump tables itself.
unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
return MachineJumpTableInfo::EK_Inline;
}

// This function is almost a copy of SelectionDAG::expandVAArg().
// The only diff is that this one produces loads from local address space.
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ enum NodeType : unsigned {
BFI,
PRMT,
DYNAMIC_STACKALLOC,
BrxStart,
BrxItem,
BrxEnd,
Dummy,

LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
Expand Down Expand Up @@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
return true;
}

// The default is the same as pointer type, but brx.idx only accepts i32
MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }

unsigned getJumpTableEncoding() const override;

bool enableAggressiveFMAFusion(EVT VT) const override { return true; }

// The default is to transform llvm.ctlz(x, false) (where false indicates that
Expand Down Expand Up @@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {

SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;

Expand Down
38 changes: 38 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -3880,6 +3880,44 @@ def DYNAMIC_STACKALLOC64 :
[(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
Requires<[hasPTX<73>, hasSM<52>]>;


//
// BRX
//

def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>;

def brx_start :
SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
[SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
def brx_item :
SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def brx_end :
SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
[SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;

let isTerminator = 1, isBranch = 1, isIndirectBranch = 1 in {

def BRX_START :
NVPTXInst<(outs), (ins i32imm:$id),
"$$L_brx_$id: .branchtargets",
[(brx_start (i32 imm:$id))]>;

def BRX_ITEM :
NVPTXInst<(outs), (ins brtarget:$target),
"\t$target,",
[(brx_item bb:$target)]>;

def BRX_END :
NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
"\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
[(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]>;
}


include "NVPTXIntrinsics.td"

//-----------------------------------
Expand Down
69 changes: 69 additions & 0 deletions llvm/test/CodeGen/NVPTX/jump-table.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s | FileCheck %s
; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}

target triple = "nvptx64-nvidia-cuda"

@out = addrspace(1) global i32 0, align 4

define void @foo(i32 %i) {
; CHECK-LABEL: foo(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .b32 %r<7>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.u32 %r2, [foo_param_0];
; CHECK-NEXT: setp.gt.u32 %p1, %r2, 3;
; CHECK-NEXT: @%p1 bra $L__BB0_6;
; CHECK-NEXT: // %bb.1: // %entry
; CHECK-NEXT: $L_brx_0: .branchtargets
; CHECK-NEXT: $L__BB0_2,
; CHECK-NEXT: $L__BB0_3,
; CHECK-NEXT: $L__BB0_4,
; CHECK-NEXT: $L__BB0_5;
; CHECK-NEXT: brx.idx %r2, $L_brx_0;
; CHECK-NEXT: $L__BB0_2: // %case0
; CHECK-NEXT: mov.b32 %r6, 0;
; CHECK-NEXT: st.global.u32 [out], %r6;
; CHECK-NEXT: bra.uni $L__BB0_6;
; CHECK-NEXT: $L__BB0_4: // %case2
; CHECK-NEXT: mov.b32 %r4, 2;
; CHECK-NEXT: st.global.u32 [out], %r4;
; CHECK-NEXT: bra.uni $L__BB0_6;
; CHECK-NEXT: $L__BB0_5: // %case3
; CHECK-NEXT: mov.b32 %r3, 3;
; CHECK-NEXT: st.global.u32 [out], %r3;
; CHECK-NEXT: bra.uni $L__BB0_6;
; CHECK-NEXT: $L__BB0_3: // %case1
; CHECK-NEXT: mov.b32 %r5, 1;
; CHECK-NEXT: st.global.u32 [out], %r5;
; CHECK-NEXT: $L__BB0_6: // %end
; CHECK-NEXT: ret;
entry:
switch i32 %i, label %end [
i32 0, label %case0
i32 1, label %case1
i32 2, label %case2
i32 3, label %case3
]

case0:
store i32 0, ptr addrspace(1) @out, align 4
br label %end

case1:
store i32 1, ptr addrspace(1) @out, align 4
br label %end

case2:
store i32 2, ptr addrspace(1) @out, align 4
br label %end

case3:
store i32 3, ptr addrspace(1) @out, align 4
br label %end

end:
ret void
}

0 comments on commit ba97697

Please sign in to comment.