Skip to content

Commit

Permalink
[wasm] Exploit unallocated zero page to remove null checks in jiterpreter (#86403)
Browse files Browse the repository at this point in the history

* Fuse null and length checks for getelema1 and getchr
* Sense whether zero page optimizations are safe based on the location of the emscripten stack and contents of memory near zero
* Improve jiterpreter stats formatting
  • Loading branch information
kg authored May 31, 2023
1 parent 4d709cd commit 9628be4
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 18 deletions.
2 changes: 2 additions & 0 deletions src/mono/mono/utils/options-def.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ DEFINE_BOOL(jiterpreter_eliminate_null_checks, "jiterpreter-eliminate-null-check
DEFINE_BOOL(jiterpreter_backward_branches_enabled, "jiterpreter-backward-branches-enabled", TRUE, "Enable performing backward branches without exiting traces")
// Attempt to use WASM v128 opcodes to implement SIMD interpreter opcodes
DEFINE_BOOL(jiterpreter_enable_simd, "jiterpreter-simd-enabled", TRUE, "Attempt to use WebAssembly SIMD support")
// Since the zero page is unallocated, loading array/string/span lengths from null ptrs will yield zero
DEFINE_BOOL(jiterpreter_zero_page_optimization, "jiterpreter-zero-page-optimization", TRUE, "Exploit the zero page being unallocated to optimize out null checks")
// When compiling a jit_call wrapper, bypass sharedvt wrappers if possible by inlining their
// logic into the compiled wrapper and calling the target AOTed function with native call convention
DEFINE_BOOL(jiterpreter_direct_jit_call, "jiterpreter-direct-jit-calls", TRUE, "Bypass gsharedvt wrappers when compiling JIT call wrappers")
Expand Down
2 changes: 2 additions & 0 deletions src/mono/wasm/runtime/cwraps.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ const fn_signatures: SigLine[] = [
[true, "mono_jiterp_get_simd_opcode", "number", ["number", "number"]],
[true, "mono_jiterp_get_arg_offset", "number", ["number", "number", "number"]],
[true, "mono_jiterp_get_opcode_info", "number", ["number", "number"]],
[true, "mono_wasm_is_zero_page_reserved", "number", []],
...legacy_interop_cwraps
];

Expand Down Expand Up @@ -255,6 +256,7 @@ export interface t_Cwraps {
mono_jiterp_get_simd_opcode(arity: number, index: number): number;
mono_jiterp_get_arg_offset(imethod: number, sig: number, index: number): number;
mono_jiterp_get_opcode_info(opcode: number, type: number): number;
mono_wasm_is_zero_page_reserved(): number;
}

const wrapped_c_functions: t_Cwraps = <any>{};
Expand Down
12 changes: 12 additions & 0 deletions src/mono/wasm/runtime/driver.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
#include <emscripten.h>
#include <emscripten/stack.h>
#include <stdio.h>
#include <stddef.h>
#include <stdlib.h>
Expand Down Expand Up @@ -1413,3 +1414,14 @@ EMSCRIPTEN_KEEPALIVE double mono_wasm_get_f64_unaligned (const double *src) {
// Exported helper: reads one int32_t through `src` and returns it.
// NOTE(review): per the name, `src` is presumably allowed to be unaligned; the
// body is a plain dereference, so this relies on the target tolerating such
// loads -- confirm before reusing this helper on other platforms.
EMSCRIPTEN_KEEPALIVE int32_t mono_wasm_get_i32_unaligned (const int32_t *src) {
return *src;
}

// Returns nonzero when both ends of the emscripten stack lie above the first
// 512 bytes of linear memory, and zero otherwise.
//
// Why this matters: when the stack does not occupy the start of memory it is
// safe to optimize out null checks for operations that also perform a bounds
// check (string and array element loads). We already know that Emscripten
// malloc will never allocate data at 0. This layout is the default behavior for
// Emscripten release builds and is controlled by the emscripten GLOBAL_BASE
// option (default value 1024).
// clang/llvm may perform this optimization if --low-memory-unused is set.
// https://github.com/emscripten-core/emscripten/issues/19389
EMSCRIPTEN_KEEPALIVE int mono_wasm_is_zero_page_reserved () {
	// The stack base is inspected first; the stack end is only queried
	// when the base is already clear of the low 512 bytes.
	if (emscripten_stack_get_base() <= 512)
		return 0;

	return emscripten_stack_get_end() > 512;
}
19 changes: 19 additions & 0 deletions src/mono/wasm/runtime/jiterpreter-support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,7 @@ export const counters = {
failures: 0,
bytesGenerated: 0,
nullChecksEliminated: 0,
nullChecksFused: 0,
backBranchesEmitted: 0,
backBranchesNotEmitted: 0,
simdFallback: simdFallbackCounters,
Expand Down Expand Up @@ -1777,6 +1778,22 @@ export function bytesFromHex(hex: string): Uint8Array {
return bytes;
}

/**
 * Decides whether the start of linear memory can be treated as reserved and
 * zero-filled.
 *
 * Two conditions must hold: the native helper `mono_wasm_is_zero_page_reserved`
 * must report a truthy value, and the first four 32-bit words of the heap must
 * all currently read as zero.
 *
 * @returns `true` when both conditions hold, `false` otherwise.
 */
export function isZeroPageReserved(): boolean {
    // FIXME: This check will always return true on worker threads.
    // Right now the jiterpreter is disabled when threading is active, so that's not an issue.
    const stackIsClearOfZeroPage = cwraps.mono_wasm_is_zero_page_reserved();
    if (!stackIsClearOfZeroPage)
        return false;

    // Emscripten's stack checker (or some other troublemaker) may have written
    // junk at the start of memory. The native call above only tells us where the
    // stack lives (on the main thread), so inspect the memory itself here.
    // This can't be done in the C helper because emcc/asan might be checking pointers.
    for (let wordIndex = 0; wordIndex < 4; wordIndex++) {
        if (Module.HEAPU32[wordIndex] !== 0)
            return false;
    }

    return true;
}

export type JiterpreterOptions = {
enableAll?: boolean;
enableTraces: boolean;
Expand All @@ -1786,6 +1803,7 @@ export type JiterpreterOptions = {
enableCallResume: boolean;
enableWasmEh: boolean;
enableSimd: boolean;
zeroPageOptimization: boolean;
// For locations where the jiterpreter heuristic says we will be unable to generate
// a trace, insert an entry point opcode anyway. This enables collecting accurate
// stats for options like estimateHeat, but raises overhead.
Expand Down Expand Up @@ -1828,6 +1846,7 @@ const optionNames: { [jsName: string]: string } = {
"enableCallResume": "jiterpreter-call-resume-enabled",
"enableWasmEh": "jiterpreter-wasm-eh-enabled",
"enableSimd": "jiterpreter-simd-enabled",
"zeroPageOptimization": "jiterpreter-zero-page-optimization",
"enableStats": "jiterpreter-stats-enabled",
"disableHeuristic": "jiterpreter-disable-heuristic",
"estimateHeat": "jiterpreter-estimate-heat",
Expand Down
70 changes: 57 additions & 13 deletions src/mono/wasm/runtime/jiterpreter-trace-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
append_memmove_dest_src, try_append_memset_fast,
try_append_memmove_fast, counters, getOpcodeTableValue,
getMemberOffset, JiterpMember, BailoutReason,
isZeroPageReserved
} from "./jiterpreter-support";
import { compileSimdFeatureDetect } from "./jiterpreter-feature-detect";
import {
Expand Down Expand Up @@ -598,8 +599,25 @@ export function generateWasmBody(
append_ldloc(builder, getArgU16(ip, 3), WasmOpcode.i32_load);
// stash it, we'll be using it multiple times
builder.local("math_lhs32", WasmOpcode.tee_local);

/*
const constantIndex = get_known_constant_value(getArgU16(ip, 3));
if (typeof (constantIndex) === "number")
console.log(`getchr in ${builder.functions[0].name} with constant index ${constantIndex}`);
*/

// str
append_ldloc_cknull(builder, getArgU16(ip, 2), ip, true);
const ptrLocal = builder.options.zeroPageOptimization ? "math_rhs32" : "cknull_ptr";
if (builder.options.zeroPageOptimization && isZeroPageReserved()) {
// load string ptr and stash it
// if the string ptr is null, the length check will fail and we will bail out,
// so the null check is not necessary
counters.nullChecksFused++;
append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.i32_load);
builder.local(ptrLocal, WasmOpcode.tee_local);
} else
append_ldloc_cknull(builder, getArgU16(ip, 2), ip, true);

// get string length
builder.appendU8(WasmOpcode.i32_load);
builder.appendMemarg(getMemberOffset(JiterpMember.StringLength), 2);
Expand All @@ -624,7 +642,7 @@ export function generateWasmBody(
builder.local("math_lhs32");
builder.i32_const(2);
builder.appendU8(WasmOpcode.i32_mul);
builder.local("cknull_ptr");
builder.local(ptrLocal);
builder.appendU8(WasmOpcode.i32_add);
// Load char
builder.appendU8(WasmOpcode.i32_load16_u);
Expand Down Expand Up @@ -1871,17 +1889,21 @@ function emit_fieldop(
append_ldloca(builder, localOffset, sizeBytes);
// src
builder.local("cknull_ptr");
builder.i32_const(fieldOffset);
builder.appendU8(WasmOpcode.i32_add);
if (fieldOffset !== 0) {
builder.i32_const(fieldOffset);
builder.appendU8(WasmOpcode.i32_add);
}
append_memmove_dest_src(builder, sizeBytes);
return true;
}
case MintOpcode.MINT_STFLD_VT: {
const klass = get_imethod_data(frame, getArgU16(ip, 4));
// dest = (char*)o + ip [3]
builder.local("cknull_ptr");
builder.i32_const(fieldOffset);
builder.appendU8(WasmOpcode.i32_add);
if (fieldOffset !== 0) {
builder.i32_const(fieldOffset);
builder.appendU8(WasmOpcode.i32_add);
}
// src = locals + ip [2]
append_ldloca(builder, localOffset, 0);
builder.ptr_const(klass);
Expand All @@ -1892,8 +1914,10 @@ function emit_fieldop(
const sizeBytes = getArgU16(ip, 4);
// dest
builder.local("cknull_ptr");
builder.i32_const(fieldOffset);
builder.appendU8(WasmOpcode.i32_add);
if (fieldOffset !== 0) {
builder.i32_const(fieldOffset);
builder.appendU8(WasmOpcode.i32_add);
}
// src
append_ldloca(builder, localOffset, 0);
append_memmove_dest_src(builder, sizeBytes);
Expand All @@ -1905,8 +1929,10 @@ function emit_fieldop(
builder.local("pLocals");
// cknull_ptr isn't always initialized here
append_ldloc(builder, objectOffset, WasmOpcode.i32_load);
builder.i32_const(fieldOffset);
builder.appendU8(WasmOpcode.i32_add);
if (fieldOffset !== 0) {
builder.i32_const(fieldOffset);
builder.appendU8(WasmOpcode.i32_add);
}
append_stloc_tail(builder, localOffset, setter);
return true;

Expand Down Expand Up @@ -2812,18 +2838,35 @@ function append_getelema1(
) {
builder.block();

/*
const constantIndex = get_known_constant_value(indexOffset);
if (typeof (constantIndex) === "number")
console.log(`getelema1 in ${builder.functions[0].name} with constant index ${constantIndex}`);
*/

// load index for check
append_ldloc(builder, indexOffset, WasmOpcode.i32_load);
// stash it since we need it twice
builder.local("math_lhs32", WasmOpcode.tee_local);
// array null check
append_ldloc_cknull(builder, objectOffset, ip, true);

const ptrLocal = builder.options.zeroPageOptimization ? "math_rhs32" : "cknull_ptr";
if (builder.options.zeroPageOptimization && isZeroPageReserved()) {
// load array ptr and stash it
// if the array ptr is null, the length check will fail and we will bail out
counters.nullChecksFused++;
append_ldloc(builder, objectOffset, WasmOpcode.i32_load);
builder.local(ptrLocal, WasmOpcode.tee_local);
} else
// array null check
append_ldloc_cknull(builder, objectOffset, ip, true);

// load array length
builder.appendU8(WasmOpcode.i32_load);
builder.appendMemarg(getMemberOffset(JiterpMember.ArrayLength), 2);
// check index < array.length, unsigned. if index is negative it will be interpreted as
// a massive value which is naturally going to be bigger than array.length. interp.c
// exploits this property so we can too
// for a null array pointer array.length will also be zero thanks to the zero page optimization
builder.appendU8(WasmOpcode.i32_lt_u);
// bailout unless (index < array.length)
builder.appendU8(WasmOpcode.br_if);
Expand All @@ -2832,7 +2875,7 @@ function append_getelema1(
builder.endBlock();

// We did a null check and bounds check so we can now compute the actual address
builder.local("cknull_ptr");
builder.local(ptrLocal);
builder.i32_const(getMemberOffset(JiterpMember.ArrayData));
builder.appendU8(WasmOpcode.i32_add);

Expand All @@ -2858,6 +2901,7 @@ function emit_arrayop(builder: WasmBuilder, frame: NativePointer, ip: MintOpcode
case MintOpcode.MINT_LDLEN: {
builder.local("pLocals");
// array null check
// note: zero page optimization is not valid here since we want to throw on null
append_ldloc_cknull(builder, objectOffset, ip, true);
// load array length
builder.appendU8(WasmOpcode.i32_load);
Expand Down
18 changes: 13 additions & 5 deletions src/mono/wasm/runtime/jiterpreter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { MintOpcode } from "./mintops";
import cwraps from "./cwraps";
import {
MintOpcodePtr, WasmValtype, WasmBuilder, addWasmFunctionPointer,
_now, elapsedTimes,
_now, elapsedTimes, isZeroPageReserved,
counters, getRawCwrap, importDef,
JiterpreterOptions, getOptions, recordFailure,
JiterpMember, getMemberOffset,
Expand Down Expand Up @@ -1034,10 +1034,18 @@ export function jiterpreter_dump_stats(b?: boolean, concise?: boolean) {
if (!mostRecentOptions.enableStats && (b !== undefined))
return;

mono_log_info(`// jitted ${counters.bytesGenerated} bytes; ${counters.tracesCompiled} traces (${counters.traceCandidates} candidates, ${(counters.tracesCompiled / counters.traceCandidates * 100).toFixed(1)}%); ${counters.jitCallsCompiled} jit_calls (${(counters.directJitCallsCompiled / counters.jitCallsCompiled * 100).toFixed(1)}% direct); ${counters.entryWrappersCompiled} interp_entries`);
const backBranchHitRate = (counters.backBranchesEmitted / (counters.backBranchesEmitted + counters.backBranchesNotEmitted)) * 100;
const tracesRejected = cwraps.mono_jiterp_get_rejected_trace_count();
mono_log_info(`// time: ${elapsedTimes.generation | 0}ms generating, ${elapsedTimes.compilation | 0}ms compiling wasm. ${counters.nullChecksEliminated} cknulls removed. ${counters.backBranchesEmitted} back-branches (${counters.backBranchesNotEmitted} failed, ${backBranchHitRate.toFixed(1)}%), ${tracesRejected} traces rejected`);
const backBranchHitRate = (counters.backBranchesEmitted / (counters.backBranchesEmitted + counters.backBranchesNotEmitted)) * 100,
tracesRejected = cwraps.mono_jiterp_get_rejected_trace_count(),
nullChecksEliminatedText = mostRecentOptions.eliminateNullChecks ? counters.nullChecksEliminated.toString() : "off",
nullChecksFusedText = (mostRecentOptions.zeroPageOptimization ? counters.nullChecksFused.toString() + (isZeroPageReserved() ? "" : " (disabled)") : "off"),
backBranchesEmittedText = mostRecentOptions.enableBackwardBranches ? `emitted: ${counters.backBranchesEmitted}, failed: ${counters.backBranchesNotEmitted} (${backBranchHitRate.toFixed(1)}%)` : ": off",
directJitCallsText = counters.jitCallsCompiled ? (
mostRecentOptions.directJitCalls ? `direct jit calls: ${counters.directJitCallsCompiled} (${(counters.directJitCallsCompiled / counters.jitCallsCompiled * 100).toFixed(1)}%)` : "direct jit calls: off"
) : "";

mono_log_info(`// jitted ${counters.bytesGenerated} bytes; ${counters.tracesCompiled} traces (${(counters.tracesCompiled / counters.traceCandidates * 100).toFixed(1)}%) (${tracesRejected} rejected); ${counters.jitCallsCompiled} jit_calls; ${counters.entryWrappersCompiled} interp_entries`);
mono_log_info(`// cknulls eliminated: ${nullChecksEliminatedText}, fused: ${nullChecksFusedText}; back-branches ${backBranchesEmittedText}; ${directJitCallsText}`);
mono_log_info(`// time: ${elapsedTimes.generation | 0}ms generating, ${elapsedTimes.compilation | 0}ms compiling wasm.`);
if (concise)
return;

Expand Down

0 comments on commit 9628be4

Please sign in to comment.