[AArch64][PAC] Select auth+load into LDRA*.
ahmedbougacha committed Sep 19, 2023
1 parent 1e5f318 commit b46cf39
Showing 8 changed files with 1,249 additions and 2 deletions.
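To illustrate the pattern this change targets, consider a 64-bit load whose address comes straight from the @llvm.ptrauth.auth intrinsic, using the DA key (key code 2) and a zero discriminator. This is a hypothetical sketch, not one of the commit's test files (the test diffs are not shown below), and the function name is made up:

  define i64 @ldraa_example(i64 %signed_ptr) {
    ; Authenticate the signed pointer with the DA key and a zero discriminator,
    ; then load through the authenticated address.
    %authed = call i64 @llvm.ptrauth.auth(i64 %signed_ptr, i32 2, i64 0)
    %addr = inttoptr i64 %authed to ptr
    %val = load i64, ptr %addr
    ret i64 %val
  }

  declare i64 @llvm.ptrauth.auth(i64, i32, i64)

With this patch, on a FEAT_PAuth target, the authentication and the load should be selectable into a single "ldraa x0, [x0]" rather than a separate autda followed by a plain ldr.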
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/AArch64Combine.td
@@ -179,6 +179,14 @@ def form_truncstore : GICombineRule<
  (apply [{ applyFormTruncstore(*${root}, MRI, B, Observer, ${matchinfo}); }])
>;

def form_auth_load_matchdata : GIDefMatchData<"AuthLoadMatchInfo">;
def form_auth_load : GICombineRule<
  (defs root:$root, form_auth_load_matchdata:$matchinfo),
  (match (wip_match_opcode G_LOAD):$root,
         [{ return matchFormAuthLoad(*${root}, MRI, Helper, ${matchinfo}); }]),
  (apply [{ applyFormAuthLoad(*${root}, MRI, B, Helper, Observer, ${matchinfo}); }])
>;

def fold_merge_to_zext : GICombineRule<
(defs root:$d),
(match (wip_match_opcode G_MERGE_VALUES):$d,
@@ -231,6 +239,7 @@ def AArch64PostLegalizerLowering
      [shuffle_vector_lowering, vashr_vlshr_imm,
       icmp_lowering, build_vector_lowering,
       lower_vector_fcmp, form_truncstore,
       form_auth_load,
       vector_sext_inreg_to_shift,
       unmerge_ext_to_unmerge]> {
}
124 changes: 124 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ExpandHardenedPseudos.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "AArch64ExpandImm.h"
#include "AArch64InstrInfo.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64Subtarget.h"
@@ -54,6 +55,7 @@ class AArch64ExpandHardenedPseudos : public MachineFunctionPass {

private:
  bool expandPtrAuthPseudo(MachineInstr &MI);
  bool expandAuthLoad(MachineInstr &MI);
  bool expandMI(MachineInstr &MI);
};

@@ -306,13 +308,135 @@ bool AArch64ExpandHardenedPseudos::expandPtrAuthPseudo(MachineInstr &MI) {
  return true;
}

bool AArch64ExpandHardenedPseudos::expandAuthLoad(MachineInstr &MI) {
  MachineBasicBlock &MBB = *MI.getParent();
  MachineFunction &MF = *MBB.getParent();
  DebugLoc DL = MI.getDebugLoc();
  auto MBBI = MI.getIterator();

  const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
  const AArch64InstrInfo *TII = STI.getInstrInfo();

  LLVM_DEBUG(dbgs() << "Expanding: " << MI << "\n");

  bool IsPre = MI.getOpcode() == AArch64::LDRApre;

  MachineOperand DstOp = MI.getOperand(0);
  int64_t Offset = MI.getOperand(1).getImm();
  auto Key = (AArch64PACKey::ID)MI.getOperand(2).getImm();
  uint64_t Disc = MI.getOperand(3).getImm();
  unsigned AddrDisc = MI.getOperand(4).getReg();

  unsigned DiscReg = AddrDisc;
  if (Disc) {
    assert(isUInt<16>(Disc) && "Integer discriminator is too wide");

    if (AddrDisc != AArch64::XZR) {
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ORRXrs), AArch64::X17)
          .addReg(AArch64::XZR)
          .addReg(AddrDisc)
          .addImm(0);
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::MOVKXi), AArch64::X17)
          .addReg(AArch64::X17)
          .addImm(Disc)
          .addImm(/*shift=*/48);
    } else {
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::MOVZXi), AArch64::X17)
          .addImm(Disc)
          .addImm(/*shift=*/0);
    }
    DiscReg = AArch64::X17;
  }

  unsigned AUTOpc = getAUTOpcodeForKey(Key, DiscReg == AArch64::XZR);
  auto MIB = BuildMI(MBB, MBBI, DL, TII->get(AUTOpc), AArch64::X16)
                 .addReg(AArch64::X16);
  if (DiscReg != AArch64::XZR)
    MIB.addReg(DiscReg);

  // We have a few options for offset folding:
  // - 0 offset: LDRXui
  // - no wb, uimm12s8 offset: LDRXui
  // - no wb, simm9 offset: LDURXi
  // - wb, simm9 offset: LDRXpre
  // - no wb, any offset: expanded MOVImm + LDRXroX
  // - wb, any offset: expanded MOVImm + ADD + LDRXui
  if (!Offset || (!IsPre && isShiftedUInt<12, 3>(Offset))) {
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDRXui))
        .add(DstOp)
        .addUse(AArch64::X16)
        .addImm(Offset / 8);
  } else if (!IsPre && Offset && isInt<9>(Offset)) {
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDURXi))
        .add(DstOp)
        .addUse(AArch64::X16)
        .addImm(Offset);
  } else if (IsPre && Offset && isInt<9>(Offset)) {
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDRXpre), AArch64::X16)
        .add(DstOp)
        .addUse(AArch64::X16)
        .addImm(Offset);
  } else {
    SmallVector<AArch64_IMM::ImmInsnModel, 4> ImmInsns;
    AArch64_IMM::expandMOVImm(Offset, 64, ImmInsns);

    // X17 is dead at this point, use it as the offset register
    for (auto &ImmI : ImmInsns) {
      switch (ImmI.Opcode) {
      default: llvm_unreachable("invalid ldra imm expansion opc!"); break;

      case AArch64::ORRXri:
        BuildMI(MBB, MBBI, DL, TII->get(ImmI.Opcode), AArch64::X17)
            .addReg(AArch64::XZR)
            .addImm(ImmI.Op2);
        break;
      case AArch64::MOVNXi:
      case AArch64::MOVZXi: {
        BuildMI(MBB, MBBI, DL, TII->get(ImmI.Opcode), AArch64::X17)
            .addImm(ImmI.Op1)
            .addImm(ImmI.Op2);
      } break;
      case AArch64::MOVKXi: {
        BuildMI(MBB, MBBI, DL, TII->get(ImmI.Opcode), AArch64::X17)
            .addReg(AArch64::X17)
            .addImm(ImmI.Op1)
            .addImm(ImmI.Op2);
      } break;
      }
    }

    if (IsPre) {
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXrs), AArch64::X16)
          .addReg(AArch64::X16)
          .addReg(AArch64::X17)
          .addImm(0);
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDRXui))
          .add(DstOp)
          .addUse(AArch64::X16)
          .addImm(/*Offset=*/0);
    } else {
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDRXroX))
          .add(DstOp)
          .addReg(AArch64::X16)
          .addReg(AArch64::X17)
          .addImm(0)
          .addImm(0);
    }
  }

  return true;
}

bool AArch64ExpandHardenedPseudos::expandMI(MachineInstr &MI) {
  switch (MI.getOpcode()) {
  case AArch64::BR_JumpTable:
  case AArch64::LOADauthptrgot:
  case AArch64::LOADgotPAC:
  case AArch64::MOVaddrPAC:
    return expandPtrAuthPseudo(MI);
  case AArch64::LDRA:
  case AArch64::LDRApre:
    return expandAuthLoad(MI);
  default:
    return false;
  }
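To illustrate the general pseudo path expanded above (a discriminator that is not zero, so LDRAA/LDRAB cannot be used), here is a hypothetical example, not taken from the commit's tests; the function name, the constant discriminator 1234, and the byte offset 16 are arbitrary choices:

  define i64 @ldra_blend_example(i64 %signed_ptr, i64 %addr_disc) {
    ; Blend an address discriminator with a constant discriminator, authenticate
    ; with the DA key, then load 64 bits at byte offset 16.
    %disc = call i64 @llvm.ptrauth.blend(i64 %addr_disc, i64 1234)
    %authed = call i64 @llvm.ptrauth.auth(i64 %signed_ptr, i32 2, i64 %disc)
    %base = inttoptr i64 %authed to ptr
    %gep = getelementptr i64, ptr %base, i64 2
    %val = load i64, ptr %gep
    ret i64 %val
  }

  declare i64 @llvm.ptrauth.blend(i64, i64)
  declare i64 @llvm.ptrauth.auth(i64, i32, i64)

  ; With the signed pointer copied into x16, expandAuthLoad would be expected
  ; to emit roughly:
  ;   mov   x17, x1               ; ORRXrs: copy the address discriminator
  ;   movk  x17, #1234, lsl #48   ; MOVKXi: blend in the constant discriminator
  ;   autda x16, x17              ; authenticate with the DA key
  ;   ldr   x0, [x16, #16]        ; LDRXui with the folded uimm12s8 offset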
163 changes: 161 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -358,6 +358,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {

  bool tryIndexedLoad(SDNode *N);

  bool tryAuthLoad(SDNode *N);

  void SelectPtrauthAuth(SDNode *N);
  void SelectPtrauthResign(SDNode *N);

@@ -1640,6 +1642,163 @@ bool AArch64DAGToDAGISel::tryIndexedLoad(SDNode *N) {
  return true;
}

bool AArch64DAGToDAGISel::tryAuthLoad(SDNode *N) {
  LoadSDNode *LD = cast<LoadSDNode>(N);
  EVT VT = LD->getMemoryVT();
  if (VT != MVT::i64)
    return false;

  assert(LD->getExtensionType() == ISD::NON_EXTLOAD && "invalid 64bit extload");

  ISD::MemIndexedMode AM = LD->getAddressingMode();
  if (AM != ISD::PRE_INC && AM != ISD::UNINDEXED)
    return false;
  bool IsPre = AM == ISD::PRE_INC;

  SDValue Chain = LD->getChain();
  SDValue Ptr = LD->getBasePtr();

  SDValue Base = Ptr;

  int64_t OffsetVal = 0;
  if (IsPre) {
    OffsetVal = cast<ConstantSDNode>(LD->getOffset())->getSExtValue();
  } else if (CurDAG->isBaseWithConstantOffset(Base)) {
    // We support both 'base' and 'base + constant offset' modes.
    ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Base.getOperand(1));
    if (!RHS)
      return false;
    OffsetVal = RHS->getSExtValue();
    Base = Base.getOperand(0);
  }

  // The base must be of the form:
  //   (int_ptrauth_auth <signedbase>, <key>, <disc>)
  // with disc being either a constant int, or:
  //   (int_ptrauth_blend <addrdisc>, <const int disc>)
  if (Base.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
    return false;

  unsigned IntID = cast<ConstantSDNode>(Base.getOperand(0))->getZExtValue();
  if (IntID != Intrinsic::ptrauth_auth)
    return false;

  unsigned KeyC = cast<ConstantSDNode>(Base.getOperand(2))->getZExtValue();
  bool IsDKey = KeyC == AArch64PACKey::DA || KeyC == AArch64PACKey::DB;
  SDValue Disc = Base.getOperand(3);

  Base = Base.getOperand(1);

  bool ZeroDisc = isNullConstant(Disc);
  SDValue IntDisc, AddrDisc;
  std::tie(IntDisc, AddrDisc) = extractPtrauthBlendDiscriminators(Disc, CurDAG);

  // If this is an indexed pre-inc load, we obviously need the writeback form.
  bool needsWriteback = IsPre;
  // If not, but the base authenticated pointer has any other use, it's
  // beneficial to use the writeback form, to "writeback" the auth, even if
  // there is no base+offset addition.
  if (!Ptr.hasOneUse()) {
    needsWriteback = true;

    // However, we can only do that if we don't introduce cycles between the
    // load node and any other user of the pointer computation nodes. That can
    // happen if the load node uses any of said other users.
    // In other words: we can only do this transformation if none of the other
    // uses of the pointer computation to be folded are predecessors of the load
    // we're folding into.
    //
    // Visited is a cache containing nodes that are known predecessors of N.
    // Worklist is the set of nodes we're looking for predecessors of.
    // For the first lookup, that only contains the load node N. Each call to
    // hasPredecessorHelper adds any of the potential predecessors of N to the
    // Worklist.
    SmallPtrSet<const SDNode *, 32> Visited;
    SmallVector<const SDNode *, 16> Worklist;
    Worklist.push_back(N);
    for (SDNode *U : Ptr.getNode()->uses())
      if (SDNode::hasPredecessorHelper(U, Visited, Worklist, /*Max=*/32,
                                       /*TopologicalPrune=*/true))
        return false;
  }

  // We have 2 main isel alternatives:
  // - LDRAA/LDRAB, writeback or indexed. Zero disc, small offsets, D key.
  // - LDRA/LDRApre. Pointer needs to be in X16.
  SDLoc DL(N);
  MachineSDNode *Res = nullptr;
  SDValue Writeback, ResVal, OutChain;

  // If the discriminator is zero and the offset fits, we can use LDRAA/LDRAB.
  // Do that here to avoid needlessly constraining regalloc into using X16.
  if (ZeroDisc && isShiftedInt<10, 3>(OffsetVal) && IsDKey) {
    unsigned Opc = 0;
    switch (KeyC) {
    case AArch64PACKey::DA:
      Opc = needsWriteback ? AArch64::LDRAAwriteback : AArch64::LDRAAindexed;
      break;
    case AArch64PACKey::DB:
      Opc = needsWriteback ? AArch64::LDRABwriteback : AArch64::LDRABindexed;
      break;
    default:
      llvm_unreachable("Invalid key for LDRAA/LDRAB");
    }
    // The offset is encoded as scaled, for an element size of 8 bytes.
    SDValue Offset = CurDAG->getTargetConstant(OffsetVal / 8, DL, MVT::i64);
    SDValue Ops[] = {Base, Offset, Chain};
    Res = needsWriteback
              ? CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::i64, MVT::Other,
                                       Ops)
              : CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::Other, Ops);
    if (needsWriteback) {
      Writeback = SDValue(Res, 0);
      ResVal = SDValue(Res, 1);
      OutChain = SDValue(Res, 2);
    } else {
      ResVal = SDValue(Res, 0);
      OutChain = SDValue(Res, 1);
    }
  } else {
    // Otherwise, use the generalized LDRA pseudos.
    unsigned Opc = needsWriteback ? AArch64::LDRApre : AArch64::LDRA;

    SDValue X16Copy =
        CurDAG->getCopyToReg(Chain, DL, AArch64::X16, Base, SDValue());
    SDValue Offset = CurDAG->getTargetConstant(OffsetVal, DL, MVT::i64);
    SDValue Key = CurDAG->getTargetConstant(KeyC, DL, MVT::i32);
    SDValue Ops[] = {Offset, Key, IntDisc, AddrDisc, X16Copy.getValue(1)};
    Res = CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::Other, MVT::Glue, Ops);
    if (needsWriteback)
      Writeback = CurDAG->getCopyFromReg(SDValue(Res, 1), DL, AArch64::X16,
                                         MVT::i64, SDValue(Res, 2));
    ResVal = SDValue(Res, 0);
    OutChain = SDValue(Res, 1);
  }

  if (IsPre) {
    // If the original load was pre-inc, the resulting LDRA is writeback.
    assert(needsWriteback && "preinc loads can't be selected into non-wb ldra");
    ReplaceUses(SDValue(N, 1), Writeback); // writeback
    ReplaceUses(SDValue(N, 0), ResVal);    // loaded value
    ReplaceUses(SDValue(N, 2), OutChain);  // chain
  } else if (needsWriteback) {
    // If the original load was unindexed, but we emitted a writeback form,
    // we need to replace the uses of the original auth(signedbase)[+offset]
    // computation.
    ReplaceUses(Ptr, Writeback);          // writeback
    ReplaceUses(SDValue(N, 0), ResVal);   // loaded value
    ReplaceUses(SDValue(N, 1), OutChain); // chain
  } else {
    // Otherwise, we selected a simple load to a simple non-wb ldra.
    assert(Ptr.hasOneUse() && "reused auth ptr should be folded into ldra");
    ReplaceUses(SDValue(N, 0), ResVal);   // loaded value
    ReplaceUses(SDValue(N, 1), OutChain); // chain
  }

  CurDAG->RemoveDeadNode(N);
  return true;
}

void AArch64DAGToDAGISel::SelectLoad(SDNode *N, unsigned NumVecs, unsigned Opc,
                                     unsigned SubRegIdx) {
  SDLoc dl(N);
@@ -4359,8 +4518,8 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
    break;

  case ISD::LOAD: {
    // Try to select as an indexed load. Fall through to normal processing
    // if we can't.
    if (tryAuthLoad(Node))
      return;
    if (tryIndexedLoad(Node))
      return;
    break;
17 changes: 17 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrGISel.td
@@ -25,6 +25,23 @@ def G_ADD_LOW : AArch64GenericInstruction {
  let hasSideEffects = 0;
}

// Represents an auth-load instruction. Produced post-legalization from
// G_LOADs of ptrauth_auth intrinsics, with variants for keys/discriminators.
def G_LDRA : AArch64GenericInstruction {
  let OutOperandList = (outs type0:$dst);
  let InOperandList = (ins type1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
  let hasSideEffects = 0;
  let mayLoad = 1;
}

// Represents a pre-inc writeback auth-load instruction. Similar to G_LDRA.
def G_LDRApre : AArch64GenericInstruction {
  let OutOperandList = (outs type0:$dst, ptype1:$newaddr);
  let InOperandList = (ins ptype1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
  let hasSideEffects = 0;
  let mayLoad = 1;
}

// Pseudo for a rev16 instruction. Produced post-legalization from
// G_SHUFFLE_VECTORs with appropriate masks.
def G_REV16 : AArch64GenericInstruction {
(Diffs for the remaining 4 changed files are not shown.)
