Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebAssembly] Protect memory.fill and memory.copy from zero-length ranges. #112617

Merged
merged 9 commits into from
Oct 24, 2024
9 changes: 7 additions & 2 deletions llvm/lib/Target/WebAssembly/WebAssemblyISD.def
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ HANDLE_NODETYPE(PROMOTE_LOW)
HANDLE_NODETYPE(TRUNC_SAT_ZERO_S)
HANDLE_NODETYPE(TRUNC_SAT_ZERO_U)
HANDLE_NODETYPE(DEMOTE_ZERO)
HANDLE_NODETYPE(MEMORY_COPY)
HANDLE_NODETYPE(MEMORY_FILL)
HANDLE_NODETYPE(I64_ADD128)
HANDLE_NODETYPE(I64_SUB128)
HANDLE_NODETYPE(I64_MUL_WIDE_S)
Expand All @@ -54,3 +52,10 @@ HANDLE_MEM_NODETYPE(GLOBAL_GET)
HANDLE_MEM_NODETYPE(GLOBAL_SET)
HANDLE_MEM_NODETYPE(TABLE_GET)
HANDLE_MEM_NODETYPE(TABLE_SET)

// Bulk memory instructions. These follow LLVM's expected semantics of
// supporting out-of-bounds pointers if the length is zero, by insertig
sunfishcode marked this conversation as resolved.
Show resolved Hide resolved
// a branch around Wasm's `memory.copy` and `memory.fill`, which would
// otherwise trap.
HANDLE_NODETYPE(MEMCPY)
HANDLE_NODETYPE(MEMSET)
140 changes: 140 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,138 @@ static MachineBasicBlock *LowerFPToInt(MachineInstr &MI, DebugLoc DL,
return DoneMBB;
}

// Lower a `MEMCPY` instruction into a CFG triangle around a `MEMORY_COPY`
// instuction to handle the zero-length case.
static MachineBasicBlock *LowerMemcpy(MachineInstr &MI, DebugLoc DL,
MachineBasicBlock *BB,
const TargetInstrInfo &TII, bool Int64) {
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();

MachineOperand DstMem = MI.getOperand(0);
MachineOperand SrcMem = MI.getOperand(1);
MachineOperand Dst = MI.getOperand(2);
MachineOperand Src = MI.getOperand(3);
MachineOperand Len = MI.getOperand(4);

// We're going to add an extra use to `Len` to test if it's zero; that
// use shouldn't be a kill, even if the original use is.
MachineOperand NoKillLen = Len;
NoKillLen.setIsKill(false);

// Decide on which `MachineInstr` opcode we're going to use.
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
unsigned MemoryCopy =
Int64 ? WebAssembly::MEMORY_COPY_A64 : WebAssembly::MEMORY_COPY_A32;

// Create two new basic blocks; one for the new `memory.fill` that we can
// branch over, and one for the rest of the instructions after the original
// `memory.fill`.
const BasicBlock *LLVMBB = BB->getBasicBlock();
MachineFunction *F = BB->getParent();
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);

MachineFunction::iterator It = ++BB->getIterator();
F->insert(It, TrueMBB);
F->insert(It, DoneMBB);

// Transfer the remainder of BB and its successor edges to DoneMBB.
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);

// Connect the CFG edges.
BB->addSuccessor(TrueMBB);
BB->addSuccessor(DoneMBB);
TrueMBB->addSuccessor(DoneMBB);

// Create a virtual register for the `Eqz` result.
unsigned EqzReg;
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);

// Erase the original `memory.copy`.
MI.eraseFromParent();

// Test if `Len` is zero.
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);

// Insert a new `memory.copy`.
BuildMI(TrueMBB, DL, TII.get(MemoryCopy))
.add(DstMem)
.add(SrcMem)
.add(Dst)
.add(Src)
.add(Len);

// Create the CFG triangle.
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);

return DoneMBB;
}

// Lower a `MEMSET` instruction into a CFG triangle around a `MEMORY_FILL`
// instuction to handle the zero-length case.
static MachineBasicBlock *LowerMemset(MachineInstr &MI, DebugLoc DL,
MachineBasicBlock *BB,
const TargetInstrInfo &TII, bool Int64) {
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();

MachineOperand Mem = MI.getOperand(0);
MachineOperand Dst = MI.getOperand(1);
MachineOperand Val = MI.getOperand(2);
MachineOperand Len = MI.getOperand(3);

// We're going to add an extra use to `Len` to test if it's zero; that
// use shouldn't be a kill, even if the original use is.
MachineOperand NoKillLen = Len;
NoKillLen.setIsKill(false);

// Decide on which `MachineInstr` opcode we're going to use.
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
unsigned MemoryFill =
Int64 ? WebAssembly::MEMORY_FILL_A64 : WebAssembly::MEMORY_FILL_A32;

// Create two new basic blocks; one for the new `memory.fill` that we can
// branch over, and one for the rest of the instructions after the original
// `memory.fill`.
const BasicBlock *LLVMBB = BB->getBasicBlock();
MachineFunction *F = BB->getParent();
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);

MachineFunction::iterator It = ++BB->getIterator();
F->insert(It, TrueMBB);
F->insert(It, DoneMBB);

// Transfer the remainder of BB and its successor edges to DoneMBB.
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);

// Connect the CFG edges.
BB->addSuccessor(TrueMBB);
BB->addSuccessor(DoneMBB);
TrueMBB->addSuccessor(DoneMBB);

// Create a virtual register for the `Eqz` result.
unsigned EqzReg;
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);

// Erase the original `memory.fill`.
MI.eraseFromParent();

// Test if `Len` is zero.
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);

// Insert a new `memory.copy`.
BuildMI(TrueMBB, DL, TII.get(MemoryFill)).add(Mem).add(Dst).add(Val).add(Len);

// Create the CFG triangle.
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);

return DoneMBB;
}

static MachineBasicBlock *
LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
const WebAssemblySubtarget *Subtarget,
Expand Down Expand Up @@ -725,6 +857,14 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
case WebAssembly::FP_TO_UINT_I64_F64:
return LowerFPToInt(MI, DL, BB, TII, true, true, true,
WebAssembly::I64_TRUNC_U_F64);
case WebAssembly::MEMCPY_A32:
return LowerMemcpy(MI, DL, BB, TII, false);
case WebAssembly::MEMCPY_A64:
return LowerMemcpy(MI, DL, BB, TII, true);
case WebAssembly::MEMSET_A32:
return LowerMemset(MI, DL, BB, TII, false);
case WebAssembly::MEMSET_A64:
return LowerMemset(MI, DL, BB, TII, true);
case WebAssembly::CALL_RESULTS:
case WebAssembly::RET_CALL_RESULTS:
return LowerCallResults(MI, DL, BB, Subtarget, TII);
Expand Down
73 changes: 54 additions & 19 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrBulkMemory.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,31 @@ multiclass BULK_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
}

// Bespoke types and nodes for bulk memory ops

def wasm_memcpy_t : SDTypeProfile<0, 5,
[SDTCisInt<0>, SDTCisInt<1>, SDTCisPtrTy<2>, SDTCisPtrTy<3>, SDTCisInt<4>]
>;
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMORY_COPY", wasm_memcpy_t,
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;

def wasm_memset_t : SDTypeProfile<0, 4,
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>, SDTCisInt<3>]
>;
def wasm_memset : SDNode<"WebAssemblyISD::MEMORY_FILL", wasm_memset_t,

// memory.copy with a branch to avoid trapping in the case of out-of-bounds
// pointers with empty ranges.
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMCPY", wasm_memcpy_t,
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;

// memory.fill with a branch to avoid trapping in the case of out-of-bounds
// pointers with empty ranges.
def wasm_memset : SDNode<"WebAssemblyISD::MEMSET", wasm_memset_t,
[SDNPHasChain, SDNPMayStore]>;

// A multiclass for defining Wasm's raw bulk-memory `memory.*` instructions.
// `memory.copy` and `memory.fill` have Wasm's behavior rather than
// `memcpy`/`memset` behavior.
multiclass BulkMemoryOps<WebAssemblyRegClass rc, string B> {

let mayStore = 1, hasSideEffects = 1 in
defm MEMORY_INIT_A#B :
defm INIT_A#B :
BULK_I<(outs),
(ins i32imm_op:$seg, i32imm_op:$idx, rc:$dest,
I32:$offset, I32:$size),
Expand All @@ -45,31 +54,57 @@ defm MEMORY_INIT_A#B :
"memory.init\t$seg, $idx, $dest, $offset, $size",
"memory.init\t$seg, $idx", 0x08>;

let hasSideEffects = 1 in
defm DATA_DROP :
BULK_I<(outs), (ins i32imm_op:$seg), (outs), (ins i32imm_op:$seg),
[],
"data.drop\t$seg", "data.drop\t$seg", 0x09>;

let mayLoad = 1, mayStore = 1 in
defm MEMORY_COPY_A#B :
defm COPY_A#B :
BULK_I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
rc:$dst, rc:$src, rc:$len),
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
rc:$dst, rc:$src, rc:$len
)],
[],
"memory.copy\t$src_idx, $dst_idx, $dst, $src, $len",
"memory.copy\t$src_idx, $dst_idx", 0x0a>;

let mayStore = 1 in
defm MEMORY_FILL_A#B :
defm FILL_A#B :
BULK_I<(outs), (ins i32imm_op:$idx, rc:$dst, I32:$value, rc:$size),
(outs), (ins i32imm_op:$idx),
[(wasm_memset (i32 imm:$idx), rc:$dst, I32:$value, rc:$size)],
[],
"memory.fill\t$idx, $dst, $value, $size",
"memory.fill\t$idx", 0x0b>;
}

defm : BulkMemoryOps<I32, "32">;
defm : BulkMemoryOps<I64, "64">;
defm MEMORY_ : BulkMemoryOps<I32, "32">;
defm MEMORY_ : BulkMemoryOps<I64, "64">;

// A multiclass for defining `memcpy`/`memset` pseudo instructions. These have
// the behavior the rest of LLVM CodeGen expects, and we lower them into code
// sequences that include the Wasm `memory.fill` and `memory.copy` instructions
// using custom inserters, because they introduce new control flow.
multiclass BulkMemOps<WebAssemblyRegClass rc, string B> {

let usesCustomInserter = 1, isCodeGenOnly = 1, mayLoad = 1, mayStore = 1 in
defm CPY_A#B : I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
rc:$dst, rc:$src, rc:$len),
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
rc:$dst, rc:$src, rc:$len
)],
"", "", 0>,
sunfishcode marked this conversation as resolved.
Show resolved Hide resolved
Requires<[HasBulkMemory]>;

let usesCustomInserter = 1, isCodeGenOnly = 1, mayStore = 1 in
defm SET_A#B : I<(outs), (ins i32imm_op:$idx, rc:$dst, I32:$value, rc:$size),
(outs), (ins i32imm_op:$idx),
[(wasm_memset (i32 imm:$idx), rc:$dst, I32:$value, rc:$size)],
"", "", 0>,
Requires<[HasBulkMemory]>;
sunfishcode marked this conversation as resolved.
Show resolved Hide resolved

}

defm MEM : BulkMemOps<I32, "32">;
defm MEM : BulkMemOps<I64, "64">;

let hasSideEffects = 1 in
defm DATA_DROP :
BULK_I<(outs), (ins i32imm_op:$seg), (outs), (ins i32imm_op:$seg),
[],
"data.drop\t$seg", "data.drop\t$seg", 0x09>;
19 changes: 14 additions & 5 deletions llvm/lib/Target/WebAssembly/WebAssemblySelectionDAGInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemcpy(

SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;
return DAG.getNode(WebAssemblyISD::MEMORY_COPY, DL, MVT::Other,
{Chain, MemIdx, MemIdx, Dst, Src,
DAG.getZExtOrTrunc(Size, DL, LenMVT)});

// Use `MEMCPY` here instead of `MEMORY_COPY` because `memory.copy` traps
// if the pointers are invalid even if the length is zero. `MEMCPY` gets
// extra code to handle this in the way that LLVM IR expects.
return DAG.getNode(
WebAssemblyISD::MEMCPY, DL, MVT::Other,
{Chain, MemIdx, MemIdx, Dst, Src, DAG.getZExtOrTrunc(Size, DL, LenMVT)});
}

SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemmove(
Expand All @@ -52,8 +56,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemset(

SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;

// Use `MEMSET` here instead of `MEMORY_FILL` because `memory.fill` traps
// if the pointers are invalid even if the length is zero. `MEMSET` gets
// extra code to handle this in the way that LLVM IR expects.
//
// Only low byte matters for val argument, so anyext the i8
return DAG.getNode(WebAssemblyISD::MEMORY_FILL, DL, MVT::Other, Chain, MemIdx,
Dst, DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
return DAG.getNode(WebAssemblyISD::MEMSET, DL, MVT::Other, Chain, MemIdx, Dst,
DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
DAG.getZExtOrTrunc(Size, DL, LenMVT));
}
Loading
Loading