Skip to content

Commit

Permalink
[AArch64][PAC] Move emission of LR checks in tail calls to AsmPrinter (
Browse files Browse the repository at this point in the history
…llvm#110705)

Move the emission of the checks performed on the authenticated LR value
during tail calls to AArch64AsmPrinter class, so that different checker
sequences can be reused by pseudo instructions expanded there.
This adds one more option to AuthCheckMethod enumeration, the generic
XPAC variant which is not restricted to checking the LR register.
  • Loading branch information
atrosinenko authored Nov 12, 2024
1 parent 469520e commit 44076c9
Show file tree
Hide file tree
Showing 12 changed files with 367 additions and 321 deletions.
151 changes: 120 additions & 31 deletions llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,13 @@ class AArch64AsmPrinter : public AsmPrinter {
void emitPtrauthCheckAuthenticatedValue(Register TestedReg,
Register ScratchReg,
AArch64PACKey::ID Key,
AArch64PAuth::AuthCheckMethod Method,
bool ShouldTrap,
const MCSymbol *OnFailure);

// Check authenticated LR before tail calling.
void emitPtrauthTailCallHardening(const MachineInstr *TC);

// Emit the sequence for AUT or AUTPAC.
void emitPtrauthAuthResign(const MachineInstr *MI);

Expand Down Expand Up @@ -1751,7 +1755,8 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
/// of proceeding to the next instruction (only if ShouldTrap is false).
void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
Register TestedReg, Register ScratchReg, AArch64PACKey::ID Key,
bool ShouldTrap, const MCSymbol *OnFailure) {
AArch64PAuth::AuthCheckMethod Method, bool ShouldTrap,
const MCSymbol *OnFailure) {
// Insert a sequence to check if authentication of TestedReg succeeded,
// such as:
//
Expand All @@ -1777,38 +1782,70 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
// Lsuccess:
// ...
//
// This sequence is expensive, but we need more information to be able to
// do better.
//
// We can't TBZ the poison bit because EnhancedPAC2 XORs the PAC bits
// on failure.
// We can't TST the PAC bits because we don't always know how the address
// space is setup for the target environment (and the bottom PAC bit is
// based on that).
// Either way, we also don't always know whether TBI is enabled or not for
// the specific target environment.
// See the documentation on AuthCheckMethod enumeration constants for
// the specific code sequences that can be used to perform the check.
using AArch64PAuth::AuthCheckMethod;

unsigned XPACOpc = getXPACOpcodeForKey(Key);
if (Method == AuthCheckMethod::None)
return;
if (Method == AuthCheckMethod::DummyLoad) {
EmitToStreamer(MCInstBuilder(AArch64::LDRWui)
.addReg(getWRegFromXReg(ScratchReg))
.addReg(TestedReg)
.addImm(0));
assert(ShouldTrap && !OnFailure && "DummyLoad always traps on error");
return;
}

MCSymbol *SuccessSym = createTempSymbol("auth_success_");
if (Method == AuthCheckMethod::XPAC || Method == AuthCheckMethod::XPACHint) {
// mov Xscratch, Xtested
emitMovXReg(ScratchReg, TestedReg);

// mov Xscratch, Xtested
emitMovXReg(ScratchReg, TestedReg);

// xpac(i|d) Xscratch
EmitToStreamer(MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
if (Method == AuthCheckMethod::XPAC) {
// xpac(i|d) Xscratch
unsigned XPACOpc = getXPACOpcodeForKey(Key);
EmitToStreamer(
MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
} else {
// xpaclri

// Note that this method applies XPAC to TestedReg instead of ScratchReg.
assert(TestedReg == AArch64::LR &&
"XPACHint mode is only compatible with checking the LR register");
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
"XPACHint mode is only compatible with I-keys");
EmitToStreamer(MCInstBuilder(AArch64::XPACLRI));
}

// cmp Xtested, Xscratch
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
.addReg(AArch64::XZR)
.addReg(TestedReg)
.addReg(ScratchReg)
.addImm(0));
// cmp Xtested, Xscratch
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
.addReg(AArch64::XZR)
.addReg(TestedReg)
.addReg(ScratchReg)
.addImm(0));

// b.eq Lsuccess
EmitToStreamer(MCInstBuilder(AArch64::Bcc)
.addImm(AArch64CC::EQ)
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
// b.eq Lsuccess
EmitToStreamer(
MCInstBuilder(AArch64::Bcc)
.addImm(AArch64CC::EQ)
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
} else if (Method == AuthCheckMethod::HighBitsNoTBI) {
// eor Xscratch, Xtested, Xtested, lsl #1
EmitToStreamer(MCInstBuilder(AArch64::EORXrs)
.addReg(ScratchReg)
.addReg(TestedReg)
.addReg(TestedReg)
.addImm(1));
// tbz Xscratch, #62, Lsuccess
EmitToStreamer(
MCInstBuilder(AArch64::TBZX)
.addReg(ScratchReg)
.addImm(62)
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
} else {
llvm_unreachable("Unsupported check method");
}

if (ShouldTrap) {
assert(!OnFailure && "Cannot specify OnFailure with ShouldTrap");
Expand All @@ -1822,9 +1859,26 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
// Note that this can introduce an authentication oracle (such as based on
// the high bits of the re-signed value).

// FIXME: Can we simply return the AUT result, already in TestedReg?
// mov Xtested, Xscratch
emitMovXReg(TestedReg, ScratchReg);
// FIXME: The XPAC method can be optimized by applying XPAC to TestedReg
// instead of ScratchReg, thus eliminating one `mov` instruction.
// Both XPAC and XPACHint can be further optimized by not using a
// conditional branch jumping over an unconditional one.

switch (Method) {
case AuthCheckMethod::XPACHint:
// LR is already XPAC-ed at this point.
break;
case AuthCheckMethod::XPAC:
// mov Xtested, Xscratch
emitMovXReg(TestedReg, ScratchReg);
break;
default:
// If Xtested was not XPAC-ed so far, emit XPAC here.
// xpac(i|d) Xtested
unsigned XPACOpc = getXPACOpcodeForKey(Key);
EmitToStreamer(
MCInstBuilder(XPACOpc).addReg(TestedReg).addReg(TestedReg));
}

if (OnFailure) {
// b Lend
Expand All @@ -1839,6 +1893,30 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
OutStreamer->emitLabel(SuccessSym);
}

// With Pointer Authentication, it may be needed to explicitly check the
// authenticated value in LR before performing a tail call.
// Otherwise, the callee may re-sign the invalid return address,
// introducing a signing oracle.
void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) {
if (!AArch64FI->shouldSignReturnAddress(*MF))
return;

auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
return;

const AArch64RegisterInfo *TRI = STI->getRegisterInfo();
Register ScratchReg =
TC->readsRegister(AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
assert(!TC->readsRegister(ScratchReg, TRI) &&
"Neither x16 nor x17 is available as a scratch register");
AArch64PACKey::ID Key =
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
emitPtrauthCheckAuthenticatedValue(
AArch64::LR, ScratchReg, Key, LRCheckMethod,
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
}

void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;

Expand All @@ -1850,7 +1928,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
// ; sign x16 (if AUTPAC)
// Lend: ; if not trapping on failure
//
// with the checking sequence chosen depending on whether we should check
// with the checking sequence chosen depending on whether/how we should check
// the pointer and whether we should trap on failure.

// By default, auth/resign sequences check for auth failures.
Expand Down Expand Up @@ -1910,6 +1988,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
EndSym = createTempSymbol("resign_end_");

emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AUTKey,
AArch64PAuth::AuthCheckMethod::XPAC,
ShouldTrap, EndSym);
}

Expand Down Expand Up @@ -2194,6 +2273,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
: AArch64PACKey::DA);

emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AuthKey,
AArch64PAuth::AuthCheckMethod::XPAC,
/*ShouldTrap=*/true,
/*OnFailure=*/nullptr);
}
Expand Down Expand Up @@ -2326,6 +2406,7 @@ void AArch64AsmPrinter::LowerLOADgotAUTH(const MachineInstr &MI) {
(AuthOpcode == AArch64::AUTIA ? AArch64PACKey::IA : AArch64PACKey::DA);

emitPtrauthCheckAuthenticatedValue(AuthResultReg, AArch64::X17, AuthKey,
AArch64PAuth::AuthCheckMethod::XPAC,
/*ShouldTrap=*/true,
/*OnFailure=*/nullptr);

Expand Down Expand Up @@ -2395,6 +2476,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
// Do any manual lowerings.
switch (MI->getOpcode()) {
default:
assert(!AArch64InstrInfo::isTailCallReturnInst(*MI) &&
"Unhandled tail call instruction");
break;
case AArch64::HINT: {
// CurrentPatchableFunctionEntrySym can be CurrentFnBegin only for
Expand Down Expand Up @@ -2538,6 +2621,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
? AArch64::X17
: AArch64::X16;

emitPtrauthTailCallHardening(MI);

unsigned DiscReg = AddrDisc;
if (Disc) {
if (AddrDisc != AArch64::NoRegister) {
Expand Down Expand Up @@ -2568,13 +2653,17 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
case AArch64::TCRETURNrix17:
case AArch64::TCRETURNrinotx16:
case AArch64::TCRETURNriALL: {
emitPtrauthTailCallHardening(MI);

MCInst TmpInst;
TmpInst.setOpcode(AArch64::BR);
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
EmitToStreamer(*OutStreamer, TmpInst);
return;
}
case AArch64::TCRETURNdi: {
emitPtrauthTailCallHardening(MI);

MCOperand Dest;
MCInstLowering.lowerOperand(MI->getOperand(0), Dest);
MCInst TmpInst;
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
unsigned NumBytes = 0;
const MCInstrDesc &Desc = MI.getDesc();

if (!MI.isBundle() && isTailCallReturnInst(MI)) {
NumBytes = Desc.getSize() ? Desc.getSize() : 4;

const auto *MFI = MF->getInfo<AArch64FunctionInfo>();
if (!MFI->shouldSignReturnAddress(MF))
return NumBytes;

const auto &STI = MF->getSubtarget<AArch64Subtarget>();
auto Method = STI.getAuthenticatedLRCheckMethod(*MF);
NumBytes += AArch64PAuth::getCheckerSizeInBytes(Method);
return NumBytes;
}

// Size should be preferably set in
// llvm/lib/Target/AArch64/AArch64InstrInfo.td (default case).
// Specific cases handle instructions of variable sizes
Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1964,30 +1964,36 @@ let Predicates = [HasPAuth] in {
}

// Size 16: 4 fixed + 8 variable, to compute discriminator.
// The size returned by getInstSizeInBytes() is incremented according
// to the variant of LR check.
// As the check requires either x16 or x17 as a scratch register and
// authenticated tail call instructions have two register operands,
// make sure at least one register is usable as a scratch one - for that
// purpose, use tcGPRnotx16x17 register class for one of the operands.
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
Uses = [SP] in {
def AUTH_TCRETURN
: Pseudo<(outs), (ins tcGPR64:$dst, i32imm:$FPDiff, i32imm:$Key,
: Pseudo<(outs), (ins tcGPRnotx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
i64imm:$Disc, tcGPR64:$AddrDisc),
[]>, Sched<[WriteBrReg]>;
def AUTH_TCRETURN_BTI
: Pseudo<(outs), (ins tcGPRx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
i64imm:$Disc, tcGPR64:$AddrDisc),
i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
[]>, Sched<[WriteBrReg]>;
}

let Predicates = [TailCallAny] in
def : Pat<(AArch64authtcret tcGPR64:$dst, (i32 timm:$FPDiff), (i32 timm:$Key),
def : Pat<(AArch64authtcret tcGPRnotx16x17:$dst, (i32 timm:$FPDiff), (i32 timm:$Key),
(i64 timm:$Disc), tcGPR64:$AddrDisc),
(AUTH_TCRETURN tcGPR64:$dst, imm:$FPDiff, imm:$Key, imm:$Disc,
(AUTH_TCRETURN tcGPRnotx16x17:$dst, imm:$FPDiff, imm:$Key, imm:$Disc,
tcGPR64:$AddrDisc)>;

let Predicates = [TailCallX16X17] in
def : Pat<(AArch64authtcret tcGPRx16x17:$dst, (i32 timm:$FPDiff),
(i32 timm:$Key), (i64 timm:$Disc),
tcGPR64:$AddrDisc),
tcGPRnotx16x17:$AddrDisc),
(AUTH_TCRETURN_BTI tcGPRx16x17:$dst, imm:$FPDiff, imm:$Key,
imm:$Disc, tcGPR64:$AddrDisc)>;
imm:$Disc, tcGPRnotx16x17:$AddrDisc)>;
}

// v9.5-A pointer authentication extensions
Expand Down
Loading

0 comments on commit 44076c9

Please sign in to comment.