Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX] support switch statement with brx.idx #102400

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
virtual unsigned getJumpTableEncoding() const;

virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
return getPointerTy(DL);
}

virtual const MCExpr *
LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,
Expand Down
11 changes: 6 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
// Emit the code for the jump table
assert(JT.SL && "Should set SDLoc for SelectionDAG!");
assert(JT.Reg != -1U && "Should lower JT Header first!");
EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
Expand Down Expand Up @@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
// This value may be smaller or larger than the target's pointer type, and
// therefore require extension or truncating.
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
SwitchOp =
DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));

unsigned JumpTableReg =
FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
JumpTableReg, SwitchOp);
FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
SDValue CopyTo =
DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
JT.Reg = JumpTableReg;

if (!JTH.FallthroughUnreachable) {
Expand Down
45 changes: 42 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
Expand Down Expand Up @@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::ROTR, MVT::i8, Expand);
setOperationAction(ISD::BSWAP, MVT::i16, Expand);

// Indirect branch is not supported.
// This also disables Jump Table creation.
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
setOperationAction(ISD::BRIND, MVT::Other, Expand);

setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
Expand Down Expand Up @@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::Dummy)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
MAKE_CASE(NVPTXISD::BrxEnd)
MAKE_CASE(NVPTXISD::BrxItem)
MAKE_CASE(NVPTXISD::BrxStart)
MAKE_CASE(NVPTXISD::Tex1DFloatS32)
MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
Expand Down Expand Up @@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerFP_ROUND(Op, DAG);
case ISD::FP_EXTEND:
return LowerFP_EXTEND(Op, DAG);
case ISD::BR_JT:
return LowerBR_JT(Op, DAG);
case ISD::VAARG:
return LowerVAARG(Op, DAG);
case ISD::VASTART:
Expand All @@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
}
}

SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
SDLoc DL(Op);
SDValue Chain = Op.getOperand(0);
const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
SDValue Index = Op.getOperand(2);

unsigned JId = JT->getIndex();
MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;

SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);

// Generate BrxStart node
SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);

// Generate BrxItem nodes
assert(!MBBs.empty());
for (MachineBasicBlock *MBB : MBBs.drop_back())
Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
DAG.getBasicBlock(MBB), Chain.getValue(1));

// Generate BrxEnd nodes
SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
IdV, Chain.getValue(1)};
SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);

return BrxEnd;
}

// This will prevent AsmPrinter from trying to print the jump tables itself.
unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
return MachineJumpTableInfo::EK_Inline;
}

// This function is almost a copy of SelectionDAG::expandVAArg().
// The only diff is that this one produces loads from local address space.
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ enum NodeType : unsigned {
BFI,
PRMT,
DYNAMIC_STACKALLOC,
BrxStart,
BrxItem,
BrxEnd,
Dummy,

LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
Expand Down Expand Up @@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
return true;
}

// The default is the same as pointer type, but brx.idx only accepts i32
MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }

unsigned getJumpTableEncoding() const override;

bool enableAggressiveFMAFusion(EVT VT) const override { return true; }

// The default is to transform llvm.ctlz(x, false) (where false indicates that
Expand Down Expand Up @@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {

SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;

Expand Down
38 changes: 38 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -3880,6 +3880,44 @@ def DYNAMIC_STACKALLOC64 :
[(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
Requires<[hasPTX<73>, hasSM<52>]>;


//
// BRX
//

def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>;

def brx_start :
SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
[SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
def brx_item :
SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def brx_end :
SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
[SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;

let isTerminator = 1, isBranch = 1, isIndirectBranch = 1 in {

def BRX_START :
NVPTXInst<(outs), (ins i32imm:$id),
"$$L_brx_$id: .branchtargets",
[(brx_start (i32 imm:$id))]>;

def BRX_ITEM :
NVPTXInst<(outs), (ins brtarget:$target),
"\t$target,",
[(brx_item bb:$target)]>;

def BRX_END :
NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
"\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
[(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]>;
}


include "NVPTXIntrinsics.td"

//-----------------------------------
Expand Down
69 changes: 69 additions & 0 deletions llvm/test/CodeGen/NVPTX/jump-table.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s | FileCheck %s
; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}

target triple = "nvptx64-nvidia-cuda"

@out = addrspace(1) global i32 0, align 4

define void @foo(i32 %i) {
; CHECK-LABEL: foo(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .b32 %r<7>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.u32 %r2, [foo_param_0];
; CHECK-NEXT: setp.gt.u32 %p1, %r2, 3;
; CHECK-NEXT: @%p1 bra $L__BB0_6;
; CHECK-NEXT: // %bb.1: // %entry
; CHECK-NEXT: $L_brx_0: .branchtargets
; CHECK-NEXT: $L__BB0_2,
; CHECK-NEXT: $L__BB0_3,
; CHECK-NEXT: $L__BB0_4,
; CHECK-NEXT: $L__BB0_5;
; CHECK-NEXT: brx.idx %r2, $L_brx_0;
; CHECK-NEXT: $L__BB0_2: // %case0
; CHECK-NEXT: mov.b32 %r6, 0;
; CHECK-NEXT: st.global.u32 [out], %r6;
; CHECK-NEXT: bra.uni $L__BB0_6;
; CHECK-NEXT: $L__BB0_4: // %case2
; CHECK-NEXT: mov.b32 %r4, 2;
; CHECK-NEXT: st.global.u32 [out], %r4;
; CHECK-NEXT: bra.uni $L__BB0_6;
; CHECK-NEXT: $L__BB0_5: // %case3
; CHECK-NEXT: mov.b32 %r3, 3;
; CHECK-NEXT: st.global.u32 [out], %r3;
; CHECK-NEXT: bra.uni $L__BB0_6;
; CHECK-NEXT: $L__BB0_3: // %case1
; CHECK-NEXT: mov.b32 %r5, 1;
; CHECK-NEXT: st.global.u32 [out], %r5;
; CHECK-NEXT: $L__BB0_6: // %end
; CHECK-NEXT: ret;
entry:
switch i32 %i, label %end [
i32 0, label %case0
i32 1, label %case1
i32 2, label %case2
i32 3, label %case3
]

case0:
store i32 0, ptr addrspace(1) @out, align 4
br label %end

case1:
store i32 1, ptr addrspace(1) @out, align 4
br label %end

case2:
store i32 2, ptr addrspace(1) @out, align 4
br label %end

case3:
store i32 3, ptr addrspace(1) @out, align 4
br label %end

end:
ret void
}
Loading