Skip to content

Commit

Permalink
[WebAssembly] Protect memory.fill and memory.copy from zero-length ra…
Browse files Browse the repository at this point in the history
…nges.

WebAssembly's `memory.fill` and `memory.copy` instructions trap if the
pointers are out of bounds, even if the length is zero. This is
different from LLVM, which expects that it can call `memcpy` on
arbitrary invalid pointers if the length is zero. To avoid spurious
traps, branch around `memory.fill` and `memory.copy` when the
length is zero.
  • Loading branch information
sunfishcode committed Oct 16, 2024
1 parent 6fcea43 commit 27f66c9
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 71 deletions.
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISD.def
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,7 @@ HANDLE_MEM_NODETYPE(GLOBAL_GET)
HANDLE_MEM_NODETYPE(GLOBAL_SET)
HANDLE_MEM_NODETYPE(TABLE_GET)
HANDLE_MEM_NODETYPE(TABLE_SET)

// Bulk memory instructions that require branching to handle empty ranges.
HANDLE_NODETYPE(MEMCPY)
HANDLE_NODETYPE(MEMSET)
140 changes: 140 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,138 @@ static MachineBasicBlock *LowerFPToInt(MachineInstr &MI, DebugLoc DL,
return DoneMBB;
}

// Lower a `MEMCPY` instruction into a CFG triangle around a `MEMORY_COPY`
// instuction to handle the zero-length case.
static MachineBasicBlock *LowerMemcpy(MachineInstr &MI, DebugLoc DL,
MachineBasicBlock *BB,
const TargetInstrInfo &TII, bool Int64) {
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();

MachineOperand DstMem = MI.getOperand(0);
MachineOperand SrcMem = MI.getOperand(1);
MachineOperand Dst = MI.getOperand(2);
MachineOperand Src = MI.getOperand(3);
MachineOperand Len = MI.getOperand(4);

// We're going to add an extra use to `Len` to test if it's zero; that
// use shouldn't be a kill, even if the original use is.
MachineOperand NoKillLen = Len;
NoKillLen.setIsKill(false);

// Decide on which `MachineInstr` opcode we're going to use.
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
unsigned MemoryCopy =
Int64 ? WebAssembly::MEMORY_COPY_A64 : WebAssembly::MEMORY_COPY_A32;

// Create two new basic blocks; one for the new `memory.fill` that we can
// branch over, and one for the rest of the instructions after the original
// `memory.fill`.
const BasicBlock *LLVMBB = BB->getBasicBlock();
MachineFunction *F = BB->getParent();
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);

MachineFunction::iterator It = ++BB->getIterator();
F->insert(It, TrueMBB);
F->insert(It, DoneMBB);

// Transfer the remainder of BB and its successor edges to DoneMBB.
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);

// Connect the CFG edges.
BB->addSuccessor(TrueMBB);
BB->addSuccessor(DoneMBB);
TrueMBB->addSuccessor(DoneMBB);

// Create a virtual register for the `Eqz` result.
unsigned EqzReg;
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);

// Erase the original `memory.copy`.
MI.eraseFromParent();

// Test if `Len` is zero.
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);

// Insert a new `memory.copy`.
BuildMI(TrueMBB, DL, TII.get(MemoryCopy))
.add(DstMem)
.add(SrcMem)
.add(Dst)
.add(Src)
.add(Len);

// Create the CFG triangle.
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);

return DoneMBB;
}

// Lower a `MEMSET` instruction into a CFG triangle around a `MEMORY_FILL`
// instuction to handle the zero-length case.
static MachineBasicBlock *LowerMemset(MachineInstr &MI, DebugLoc DL,
MachineBasicBlock *BB,
const TargetInstrInfo &TII, bool Int64) {
MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();

MachineOperand Mem = MI.getOperand(0);
MachineOperand Dst = MI.getOperand(1);
MachineOperand Val = MI.getOperand(2);
MachineOperand Len = MI.getOperand(3);

// We're going to add an extra use to `Len` to test if it's zero; that
// use shouldn't be a kill, even if the original use is.
MachineOperand NoKillLen = Len;
NoKillLen.setIsKill(false);

// Decide on which `MachineInstr` opcode we're going to use.
unsigned Eqz = Int64 ? WebAssembly::EQZ_I64 : WebAssembly::EQZ_I32;
unsigned MemoryFill =
Int64 ? WebAssembly::MEMORY_FILL_A64 : WebAssembly::MEMORY_FILL_A32;

// Create two new basic blocks; one for the new `memory.fill` that we can
// branch over, and one for the rest of the instructions after the original
// `memory.fill`.
const BasicBlock *LLVMBB = BB->getBasicBlock();
MachineFunction *F = BB->getParent();
MachineBasicBlock *TrueMBB = F->CreateMachineBasicBlock(LLVMBB);
MachineBasicBlock *DoneMBB = F->CreateMachineBasicBlock(LLVMBB);

MachineFunction::iterator It = ++BB->getIterator();
F->insert(It, TrueMBB);
F->insert(It, DoneMBB);

// Transfer the remainder of BB and its successor edges to DoneMBB.
DoneMBB->splice(DoneMBB->begin(), BB, std::next(MI.getIterator()), BB->end());
DoneMBB->transferSuccessorsAndUpdatePHIs(BB);

// Connect the CFG edges.
BB->addSuccessor(TrueMBB);
BB->addSuccessor(DoneMBB);
TrueMBB->addSuccessor(DoneMBB);

// Create a virtual register for the `Eqz` result.
unsigned EqzReg;
EqzReg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);

// Erase the original `memory.fill`.
MI.eraseFromParent();

// Test if `Len` is zero.
BuildMI(BB, DL, TII.get(Eqz), EqzReg).add(NoKillLen);

// Insert a new `memory.copy`.
BuildMI(TrueMBB, DL, TII.get(MemoryFill)).add(Mem).add(Dst).add(Val).add(Len);

// Create the CFG triangle.
BuildMI(BB, DL, TII.get(WebAssembly::BR_IF)).addMBB(DoneMBB).addReg(EqzReg);
BuildMI(TrueMBB, DL, TII.get(WebAssembly::BR)).addMBB(DoneMBB);

return DoneMBB;
}

static MachineBasicBlock *
LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
const WebAssemblySubtarget *Subtarget,
Expand Down Expand Up @@ -718,6 +850,14 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
case WebAssembly::FP_TO_UINT_I64_F64:
return LowerFPToInt(MI, DL, BB, TII, true, true, true,
WebAssembly::I64_TRUNC_U_F64);
case WebAssembly::MEMCPY_A32:
return LowerMemcpy(MI, DL, BB, TII, false);
case WebAssembly::MEMCPY_A64:
return LowerMemcpy(MI, DL, BB, TII, true);
case WebAssembly::MEMSET_A32:
return LowerMemset(MI, DL, BB, TII, false);
case WebAssembly::MEMSET_A64:
return LowerMemset(MI, DL, BB, TII, true);
case WebAssembly::CALL_RESULTS:
case WebAssembly::RET_CALL_RESULTS:
return LowerCallResults(MI, DL, BB, Subtarget, TII);
Expand Down
99 changes: 87 additions & 12 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrBulkMemory.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,33 @@ multiclass BULK_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
}

// Bespoke types and nodes for bulk memory ops

// memory.copy (may trap on empty ranges)
def wasm_memory_copy_t : SDTypeProfile<0, 5,
[SDTCisInt<0>, SDTCisInt<1>, SDTCisPtrTy<2>, SDTCisPtrTy<3>, SDTCisInt<4>]
>;
def wasm_memory_copy : SDNode<"WebAssemblyISD::MEMORY_COPY", wasm_memory_copy_t,
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;

// memory.copy with a branch to avoid trapping
def wasm_memcpy_t : SDTypeProfile<0, 5,
[SDTCisInt<0>, SDTCisInt<1>, SDTCisPtrTy<2>, SDTCisPtrTy<3>, SDTCisInt<4>]
>;
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMORY_COPY", wasm_memcpy_t,
def wasm_memcpy : SDNode<"WebAssemblyISD::MEMCPY", wasm_memcpy_t,
[SDNPHasChain, SDNPMayLoad, SDNPMayStore]>;

// memory.fill (may trap on empty ranges)
def wasm_memory_fill_t : SDTypeProfile<0, 4,
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>, SDTCisInt<3>]
>;
def wasm_memory_fill : SDNode<"WebAssemblyISD::MEMORY_FILL", wasm_memory_fill_t,
[SDNPHasChain, SDNPMayStore]>;

// memory.fill with a branch to avoid trapping
def wasm_memset_t : SDTypeProfile<0, 4,
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>, SDTCisInt<3>]
>;
def wasm_memset : SDNode<"WebAssemblyISD::MEMORY_FILL", wasm_memset_t,
def wasm_memset : SDNode<"WebAssemblyISD::MEMSET", wasm_memset_t,
[SDNPHasChain, SDNPMayStore]>;

multiclass BulkMemoryOps<WebAssemblyRegClass rc, string B> {
Expand All @@ -51,25 +68,83 @@ defm DATA_DROP :
[],
"data.drop\t$seg", "data.drop\t$seg", 0x09>;

}

defm : BulkMemoryOps<I32, "32">;
defm : BulkMemoryOps<I64, "64">;

// Define copy/fill manually instead of using the `BulkMemoryOps` multiclass
// because when a multiclass defines opcodes, it gives them anonymous names
// and we need opcodes with names so that we can handle them with custom code.

let mayLoad = 1, mayStore = 1 in
defm MEMORY_COPY_A#B :
defm MEMORY_COPY_A32 :
BULK_I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
rc:$dst, rc:$src, rc:$len),
I32:$dst, I32:$src, I32:$len),
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
rc:$dst, rc:$src, rc:$len
[(wasm_memory_copy (i32 imm:$src_idx), (i32 imm:$dst_idx),
I32:$dst, I32:$src, I32:$len
)],
"memory.copy\t$src_idx, $dst_idx, $dst, $src, $len",
"memory.copy\t$src_idx, $dst_idx", 0x0a>;

let mayStore = 1 in
defm MEMORY_FILL_A#B :
BULK_I<(outs), (ins i32imm_op:$idx, rc:$dst, I32:$value, rc:$size),
defm MEMORY_FILL_A32 :
BULK_I<(outs), (ins i32imm_op:$idx, I32:$dst, I32:$value, I32:$size),
(outs), (ins i32imm_op:$idx),
[(wasm_memset (i32 imm:$idx), rc:$dst, I32:$value, rc:$size)],
[(wasm_memory_fill (i32 imm:$idx), I32:$dst, I32:$value, I32:$size)],
"memory.fill\t$idx, $dst, $value, $size",
"memory.fill\t$idx", 0x0b>;
}

defm : BulkMemoryOps<I32, "32">;
defm : BulkMemoryOps<I64, "64">;
let mayLoad = 1, mayStore = 1 in
defm MEMORY_COPY_A64 :
BULK_I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
I64:$dst, I64:$src, I64:$len),
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
[(wasm_memory_copy (i32 imm:$src_idx), (i32 imm:$dst_idx),
I64:$dst, I64:$src, I64:$len
)],
"memory.copy\t$src_idx, $dst_idx, $dst, $src, $len",
"memory.copy\t$src_idx, $dst_idx", 0x0a>;

let mayStore = 1 in
defm MEMORY_FILL_A64 :
BULK_I<(outs), (ins i32imm_op:$idx, I64:$dst, I32:$value, I64:$size),
(outs), (ins i32imm_op:$idx),
[(wasm_memory_fill (i32 imm:$idx), I64:$dst, I32:$value, I64:$size)],
"memory.fill\t$idx, $dst, $value, $size",
"memory.fill\t$idx", 0x0b>;

let usesCustomInserter = 1, isCodeGenOnly = 1, mayLoad = 1, mayStore = 1 in
defm MEMCPY_A32 : I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
I32:$dst, I32:$src, I32:$len),
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
I32:$dst, I32:$src, I32:$len
)],
"", "", 0>,
Requires<[HasBulkMemory]>;

let usesCustomInserter = 1, isCodeGenOnly = 1, mayStore = 1 in
defm MEMSET_A32 : I<(outs), (ins i32imm_op:$idx, I32:$dst, I32:$value, I32:$size),
(outs), (ins i32imm_op:$idx),
[(wasm_memset (i32 imm:$idx), I32:$dst, I32:$value, I32:$size)],
"", "", 0>,
Requires<[HasBulkMemory]>;

let usesCustomInserter = 1, isCodeGenOnly = 1, mayLoad = 1, mayStore = 1 in
defm MEMCPY_A64 : I<(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx,
I64:$dst, I64:$src, I64:$len),
(outs), (ins i32imm_op:$src_idx, i32imm_op:$dst_idx),
[(wasm_memcpy (i32 imm:$src_idx), (i32 imm:$dst_idx),
I64:$dst, I64:$src, I64:$len
)],
"", "", 0>,
Requires<[HasBulkMemory]>;

let usesCustomInserter = 1, isCodeGenOnly = 1, mayStore = 1 in
defm MEMSET_A64 : I<(outs), (ins i32imm_op:$idx, I64:$dst, I32:$value, I64:$size),
(outs), (ins i32imm_op:$idx),
[(wasm_memset (i32 imm:$idx), I64:$dst, I32:$value, I64:$size)],
"", "", 0>,
Requires<[HasBulkMemory]>;
19 changes: 14 additions & 5 deletions llvm/lib/Target/WebAssembly/WebAssemblySelectionDAGInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemcpy(

SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;
return DAG.getNode(WebAssemblyISD::MEMORY_COPY, DL, MVT::Other,
{Chain, MemIdx, MemIdx, Dst, Src,
DAG.getZExtOrTrunc(Size, DL, LenMVT)});

// Use `MEMCPY` here instead of `MEMORY_COPY` because `memory.copy` traps
// if the pointers are invalid even if the length is zero. `MEMCPY` gets
// extra code to handle this in the way that LLVM IR expects.
return DAG.getNode(
WebAssemblyISD::MEMCPY, DL, MVT::Other,
{Chain, MemIdx, MemIdx, Dst, Src, DAG.getZExtOrTrunc(Size, DL, LenMVT)});
}

SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemmove(
Expand All @@ -52,8 +56,13 @@ SDValue WebAssemblySelectionDAGInfo::EmitTargetCodeForMemset(

SDValue MemIdx = DAG.getConstant(0, DL, MVT::i32);
auto LenMVT = ST.hasAddr64() ? MVT::i64 : MVT::i32;

// Use `MEMSET` here instead of `MEMORY_FILL` because `memory.fill` traps
// if the pointers are invalid even if the length is zero. `MEMSET` gets
// extra code to handle this in the way that LLVM IR expects.
//
// Only low byte matters for val argument, so anyext the i8
return DAG.getNode(WebAssemblyISD::MEMORY_FILL, DL, MVT::Other, Chain, MemIdx,
Dst, DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
return DAG.getNode(WebAssemblyISD::MEMSET, DL, MVT::Other, Chain, MemIdx, Dst,
DAG.getAnyExtOrTrunc(Val, DL, MVT::i32),
DAG.getZExtOrTrunc(Size, DL, LenMVT));
}
Loading

0 comments on commit 27f66c9

Please sign in to comment.