From 97d203e095c75f6efe0bdfa7584c3957d1bf5ee0 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Tue, 31 Oct 2023 18:59:08 -0700 Subject: [PATCH 1/9] [analysis] Add an experimental TypeGeneralizing optimization This new optimization will eventually weaken casts by generalizing (i.e. un-refining) their output types. If a cast is weakened enough that its output type is a supertype of its input type, the cast will be able to be removed by OptimizeInstructions. Unlike refining cast inputs, generalizing cast outputs can break module validation. For example, if the result of a cast is stored to a local and the cast is weakened enough that its output type is no longer a subtype of that local's type, then the local.set after the cast will no longer validate. To avoid this validation failure, this optimization would have to generalize the type of the local as well. In general, the more we can generalize the types of program locations, the more we can weaken casts of values that flow into those locations. This initial implementation only generalizes the types of locals and does not actually weaken casts yet. It serves as a proof of concept for the analysis required to perform the full optimization, though. The analysis uses the new analysis framework to perform a reverse analysis tracking type requirements for each local and reference-typed stack value in a function. Planned and potential future work includes: - Taking updated local constraints into account when determining what blocks may need to be re-analyzed after the current block. - Implementing the transfer function for all kinds of expressions. - Tracking requirements on the dynamic types of each location to generalize allocations as well. - Making the analysis interprocedural and generalizing the types of more program locations. - Optimizing tuple-typed locations. 
- Generalizing only those locations necessary to eliminate at least one cast (although this would make the analysis bidirectional, so it is probably better left to separate passes). --- src/passes/CMakeLists.txt | 1 + src/passes/TypeGeneralizing.cpp | 438 +++++++++++++++++++++++++ src/passes/pass.cpp | 3 + src/passes/passes.h | 1 + src/wasm-type.h | 3 + src/wasm/wasm-type.cpp | 14 + test/lit/passes/type-generalizing.wast | 198 +++++++++++ 7 files changed, 658 insertions(+) create mode 100644 src/passes/TypeGeneralizing.cpp create mode 100644 test/lit/passes/type-generalizing.wast diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index bd1dd8598d3..8ffcd949c6d 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -103,6 +103,7 @@ set(passes_SOURCES ReorderLocals.cpp ReReloop.cpp TrapMode.cpp + TypeGeneralizing.cpp TypeRefining.cpp TypeMerging.cpp TypeSSA.cpp diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp new file mode 100644 index 00000000000..b73bfcac035 --- /dev/null +++ b/src/passes/TypeGeneralizing.cpp @@ -0,0 +1,438 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "analysis/cfg.h" +#include "analysis/lattice.h" +#include "analysis/lattices/inverted.h" +#include "analysis/lattices/shared.h" +#include "analysis/lattices/stack.h" +#include "analysis/lattices/tuple.h" +#include "analysis/lattices/valtype.h" +#include "analysis/lattices/vector.h" +#include "analysis/monotone-analyzer.h" +#include "ir/utils.h" +#include "pass.h" +#include "wasm-traversal.h" +#include "wasm.h" + +#define TYPE_GENERALIZING_DEBUG 0 + +#if TYPE_GENERALIZING_DEBUG +#define DBG(statement) statement +#else +#define DBG(statement) +#endif + +// Generalize the types of program locations as much as possible, both to +// eliminate unnecessarily refined types from the type section and (TODO) to +// weaken casts that cast to unnecessarily refined types. If the casts are +// weakened enough, they will be able to be removed by OptimizeInstructions. +// +// Perform a backward analysis tracking requirements on the types of program +// locations (currently just locals and stack values) to discover how much the +// type of each location can be generalized without breaking validation or +// changing program behavior. + +namespace wasm { + +namespace { + +using namespace analysis; + +// We will learn stricter and stricter requirements as we perform the analysis, +// so more specific types need to be higher up the lattice. +using TypeRequirement = Inverted; + +// Record a type requirement for each local variable. Shared the requirements +// across basic blocks. +using LocalTypeRequirements = Shared>; + +// The type requirements for each reference-typed value on the stack at a +// particular location. +using ValueStackTypeRequirements = Stack; + +// The full lattice used for the analysis. +using StateLattice = + analysis::Tuple; + +// Equip the state lattice with helpful accessors. 
+struct State : StateLattice { + using Element = StateLattice::Element; + + static constexpr int LocalsIndex = 0; + static constexpr int StackIndex = 1; + + State(Function* func) : StateLattice{Shared{initLocals(func)}, initStack()} {} + + void push(Element& elem, Type type) const noexcept { + stackLattice().push(stack(elem), std::move(type)); + } + + Type pop(Element& elem) const noexcept { + return stackLattice().pop(stack(elem)); + } + + void clearStack(Element& elem) const noexcept { + stack(elem) = stackLattice().getBottom(); + } + + const std::vector& getLocals(Element& elem) const noexcept { + return *locals(elem); + } + + const std::vector& getLocals() const noexcept { + return *locals(getBottom()); + } + + Type getLocal(Element& elem, Index i) const noexcept { + return getLocals(elem)[i]; + } + + void updateLocal(Element& elem, Index i, Type type) const noexcept { + localsLattice().join( + locals(elem), + Vector::SingletonElement(i, std::move(type))); + } + +private: + static LocalTypeRequirements initLocals(Function* func) noexcept { + return Shared{Vector{Inverted{ValType{}}, func->getNumLocals()}}; + } + + static ValueStackTypeRequirements initStack() noexcept { + return Stack{Inverted{ValType{}}}; + } + + const LocalTypeRequirements& localsLattice() const noexcept { + return std::get(lattices); + } + + const ValueStackTypeRequirements& stackLattice() const noexcept { + return std::get(lattices); + } + + typename LocalTypeRequirements::Element& + locals(Element& elem) const noexcept { + return std::get(elem); + } + + const typename LocalTypeRequirements::Element& + locals(const Element& elem) const noexcept { + return std::get(elem); + } + + typename ValueStackTypeRequirements::Element& + stack(Element& elem) const noexcept { + return std::get(elem); + } + + const typename ValueStackTypeRequirements::Element& + stack(const Element& elem) const noexcept { + return std::get(elem); + } +}; + +struct TransferFn : OverriddenVisitor { + Module& wasm; + 
Function* func; + State lattice; + typename State::Element* state = nullptr; + + TransferFn(Module& wasm, Function* func) + : wasm(wasm), func(func), lattice(func) {} + + Type pop() noexcept { return lattice.pop(*state); } + void push(Type type) noexcept { lattice.push(*state, type); } + void clearStack() noexcept { lattice.clearStack(*state); } + Type getLocal(Index i) noexcept { return lattice.getLocal(*state, i); } + void updateLocal(Index i, Type type) noexcept { + // TODO: Collect possible successor blocks that might depend on an updated + // local requirement here. + return lattice.updateLocal(*state, i, type); + } + + void dumpState() { +#if TYPE_GENERALIZING_DEBUG + std::cerr << "locals: "; + for (size_t i = 0, n = lattice.getLocals(*state).size(); i < n; ++i) { + if (i != 0) { + std::cerr << ", "; + } + std::cerr << getLocal(i); + } + std::cerr << "\nstack: "; + auto& stack = std::get<1>(*state); + for (size_t i = 0, n = stack.size(); i < n; ++i) { + if (i != 0) { + std::cerr << ", "; + } + std::cerr << stack[i]; + } + std::cerr << "\n"; +#endif // TYPE_GENERALIZING_DEBUG + } + + std::vector + transfer(const BasicBlock& bb, typename State::Element& elem) noexcept { + DBG(std::cerr << "transferring bb " << bb.getIndex() << "\n"); + state = &elem; + // This is a backward analysis: The constraints on a type depend on how it + // will be used in the future. Traverse the basic block in reverse. + dumpState(); + if (bb.isExit()) { + DBG(std::cerr << "visiting exit\n"); + visitFunctionExit(); + dumpState(); + } + for (auto it = bb.rbegin(); it != bb.rend(); ++it) { + DBG(std::cerr << "visiting " << ShallowExpression{*it} << "\n"); + visit(*it); + dumpState(); + } + DBG(std::cerr << "\n"); + + state = nullptr; + + // Return the blocks that may need to be re-analyzed. 
+ const auto& preds = bb.preds(); + std::vector dependents; + dependents.reserve(preds.size()); + dependents.insert(dependents.end(), preds.begin(), preds.end()); + return dependents; + } + + void visitFunctionExit() { + // We cannot change the types of parameters, so require that they have their + // original types. + Index i = 0; + Index numParams = func->getNumParams(); + Index numLocals = func->getNumLocals(); + for (; i < numParams; ++i) { + updateLocal(i, func->getLocalType(i)); + } + // We also cannot change the types of any other non-ref locals. For + // reference-typed locals, we cannot generalize beyond their top type. + for (; i < numLocals; ++i) { + auto type = func->getLocalType(i); + // TODO: Support optimizing tuple locals. + if (type.isRef()) { + updateLocal(i, Type(type.getHeapType().getTop(), Nullable)); + } else { + updateLocal(i, type); + } + } + // We similarly cannot change the types of results. Push requirements that + // the stack end up with the correct type. + if (auto result = func->getResults(); result.isRef()) { + push(result); + } + } + + void visitNop(Nop* curr) {} + void visitBlock(Block* curr) {} + void visitIf(If* curr) {} + void visitLoop(Loop* curr) {} + void visitBreak(Break* curr) { + // TODO: pop extra elements off stack, keeping only those at the top that + // will be sent along. + WASM_UNREACHABLE("TODO"); + } + + void visitSwitch(Switch* curr) { + // TODO: pop extra elements off stack, keeping only those at the top that + // will be sent along. + WASM_UNREACHABLE("TODO"); + } + + void visitCall(Call* curr) { + // TODO: pop ref types from results, push ref types from params + WASM_UNREACHABLE("TODO"); + } + + void visitCallIndirect(CallIndirect* curr) { + // TODO: pop ref types from results, push ref types from params + WASM_UNREACHABLE("TODO"); + } + + void visitLocalGet(LocalGet* curr) { + if (!curr->type.isRef()) { + return; + } + // Propagate the requirement on the local.get output to the local. 
+ updateLocal(curr->index, pop()); + } + + void visitLocalSet(LocalSet* curr) { + if (!curr->value->type.isRef()) { + return; + } + if (curr->isTee()) { + // Same as the local.get. + updateLocal(curr->index, pop()); + } + // Propagate the requirement on the local to our input value. + push(getLocal(curr->index)); + } + + void visitGlobalGet(GlobalGet* curr) { WASM_UNREACHABLE("TODO"); } + void visitGlobalSet(GlobalSet* curr) { WASM_UNREACHABLE("TODO"); } + void visitLoad(Load* curr) { WASM_UNREACHABLE("TODO"); } + void visitStore(Store* curr) { WASM_UNREACHABLE("TODO"); } + void visitAtomicRMW(AtomicRMW* curr) { WASM_UNREACHABLE("TODO"); } + void visitAtomicCmpxchg(AtomicCmpxchg* curr) { WASM_UNREACHABLE("TODO"); } + void visitAtomicWait(AtomicWait* curr) { WASM_UNREACHABLE("TODO"); } + void visitAtomicNotify(AtomicNotify* curr) { WASM_UNREACHABLE("TODO"); } + void visitAtomicFence(AtomicFence* curr) { WASM_UNREACHABLE("TODO"); } + void visitSIMDExtract(SIMDExtract* curr) { WASM_UNREACHABLE("TODO"); } + void visitSIMDReplace(SIMDReplace* curr) { WASM_UNREACHABLE("TODO"); } + void visitSIMDShuffle(SIMDShuffle* curr) { WASM_UNREACHABLE("TODO"); } + void visitSIMDTernary(SIMDTernary* curr) { WASM_UNREACHABLE("TODO"); } + void visitSIMDShift(SIMDShift* curr) { WASM_UNREACHABLE("TODO"); } + void visitSIMDLoad(SIMDLoad* curr) { WASM_UNREACHABLE("TODO"); } + void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) { + WASM_UNREACHABLE("TODO"); + } + void visitMemoryInit(MemoryInit* curr) { WASM_UNREACHABLE("TODO"); } + void visitDataDrop(DataDrop* curr) { WASM_UNREACHABLE("TODO"); } + void visitMemoryCopy(MemoryCopy* curr) { WASM_UNREACHABLE("TODO"); } + void visitMemoryFill(MemoryFill* curr) { WASM_UNREACHABLE("TODO"); } + void visitConst(Const* curr) {} + void visitUnary(Unary* curr) {} + void visitBinary(Binary* curr) {} + void visitSelect(Select* curr) { WASM_UNREACHABLE("TODO"); } + void visitDrop(Drop* curr) { + if (curr->type.isRef()) { + pop(); + } + } + void 
visitReturn(Return* curr) { WASM_UNREACHABLE("TODO"); } + void visitMemorySize(MemorySize* curr) { WASM_UNREACHABLE("TODO"); } + void visitMemoryGrow(MemoryGrow* curr) { WASM_UNREACHABLE("TODO"); } + void visitUnreachable(Unreachable* curr) { + // Require nothing about values flowing into an unreachable. + clearStack(); + } + void visitPop(Pop* curr) { WASM_UNREACHABLE("TODO"); } + void visitRefNull(RefNull* curr) { WASM_UNREACHABLE("TODO"); } + void visitRefIsNull(RefIsNull* curr) { WASM_UNREACHABLE("TODO"); } + void visitRefFunc(RefFunc* curr) { WASM_UNREACHABLE("TODO"); } + void visitRefEq(RefEq* curr) { WASM_UNREACHABLE("TODO"); } + void visitTableGet(TableGet* curr) { WASM_UNREACHABLE("TODO"); } + void visitTableSet(TableSet* curr) { WASM_UNREACHABLE("TODO"); } + void visitTableSize(TableSize* curr) { WASM_UNREACHABLE("TODO"); } + void visitTableGrow(TableGrow* curr) { WASM_UNREACHABLE("TODO"); } + void visitTableFill(TableFill* curr) { WASM_UNREACHABLE("TODO"); } + void visitTry(Try* curr) { WASM_UNREACHABLE("TODO"); } + void visitThrow(Throw* curr) { WASM_UNREACHABLE("TODO"); } + void visitRethrow(Rethrow* curr) { WASM_UNREACHABLE("TODO"); } + void visitTupleMake(TupleMake* curr) { WASM_UNREACHABLE("TODO"); } + void visitTupleExtract(TupleExtract* curr) { WASM_UNREACHABLE("TODO"); } + void visitRefI31(RefI31* curr) { pop(); } + void visitI31Get(I31Get* curr) { + // Do not allow relaxing to nullable if the input is already non-nullable. 
+ if (curr->i31->type.isNonNullable()) { + push(Type(HeapType::i31, NonNullable)); + } else { + push(Type(HeapType::i31, Nullable)); + } + } + void visitCallRef(CallRef* curr) { WASM_UNREACHABLE("TODO"); } + void visitRefTest(RefTest* curr) { WASM_UNREACHABLE("TODO"); } + void visitRefCast(RefCast* curr) { WASM_UNREACHABLE("TODO"); } + void visitBrOn(BrOn* curr) { WASM_UNREACHABLE("TODO"); } + void visitStructNew(StructNew* curr) { WASM_UNREACHABLE("TODO"); } + void visitStructGet(StructGet* curr) { WASM_UNREACHABLE("TODO"); } + void visitStructSet(StructSet* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayNew(ArrayNew* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayNewData(ArrayNewData* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayNewElem(ArrayNewElem* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayNewFixed(ArrayNewFixed* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayGet(ArrayGet* curr) { WASM_UNREACHABLE("TODO"); } + void visitArraySet(ArraySet* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayLen(ArrayLen* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayCopy(ArrayCopy* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayFill(ArrayFill* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayInitData(ArrayInitData* curr) { WASM_UNREACHABLE("TODO"); } + void visitArrayInitElem(ArrayInitElem* curr) { WASM_UNREACHABLE("TODO"); } + void visitRefAs(RefAs* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringNew(StringNew* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringConst(StringConst* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringMeasure(StringMeasure* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringEncode(StringEncode* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringConcat(StringConcat* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringEq(StringEq* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringAs(StringAs* curr) { WASM_UNREACHABLE("TODO"); } + void 
visitStringWTF8Advance(StringWTF8Advance* curr) { + WASM_UNREACHABLE("TODO"); + } + void visitStringWTF16Get(StringWTF16Get* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringIterNext(StringIterNext* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringIterMove(StringIterMove* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringSliceWTF(StringSliceWTF* curr) { WASM_UNREACHABLE("TODO"); } + void visitStringSliceIter(StringSliceIter* curr) { WASM_UNREACHABLE("TODO"); } +}; + +struct TypeGeneralizing : WalkerPass> { + std::vector localTypes; + bool refinalize = false; + + bool isFunctionParallel() override { return true; } + std::unique_ptr create() { + return std::make_unique(); + } + + void runOnFunction(Module* wasm, Function* func) override { + TransferFn txfn(*wasm, func); + auto cfg = CFG::fromFunction(func); + DBG(cfg.print(std::cerr)); + MonotoneCFGAnalyzer analyzer(txfn.lattice, txfn, cfg); + analyzer.evaluate(); + + // Optimize local types. TODO: Optimize casts as well. + localTypes = txfn.lattice.getLocals(); + auto numParams = func->getNumParams(); + for (Index i = numParams; i < localTypes.size(); ++i) { + func->vars[i - numParams] = localTypes[i]; + } + + // Update gets and sets accordingly. 
+ super::runOnFunction(wasm, func); + + if (refinalize) { + ReFinalize().walkFunctionInModule(func, wasm); + } + } + + void visitLocalGet(LocalGet* curr) { + if (localTypes[curr->index] != curr->type) { + curr->type = localTypes[curr->index]; + refinalize = true; + } + } + + void visitLocalSet(LocalSet* curr) { + if (curr->isTee() && localTypes[curr->index] != curr->type) { + curr->type = localTypes[curr->index]; + refinalize = true; + } + } +}; + +} // anonymous namespace + +Pass* createTypeGeneralizingPass() { return new TypeGeneralizing; } + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 35085c201da..c9fe95743e4 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -509,6 +509,9 @@ void PassRegistry::registerPasses() { registerTestPass("catch-pop-fixup", "fixup nested pops within catches", createCatchPopFixupPass); + registerTestPass("experimental-type-generalizing", + "generalize types (not yet sound)", + createTypeGeneralizingPass); } void PassRunner::addIfNoDWARFIssues(std::string passName) { diff --git a/src/passes/passes.h b/src/passes/passes.h index 2bace5bcc28..81d99cb97c3 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -154,6 +154,7 @@ Pass* createSSAifyNoMergePass(); Pass* createTrapModeClamp(); Pass* createTrapModeJS(); Pass* createTupleOptimizationPass(); +Pass* createTypeGeneralizingPass(); Pass* createTypeRefiningPass(); Pass* createTypeFinalizingPass(); Pass* createTypeMergingPass(); diff --git a/src/wasm-type.h b/src/wasm-type.h index 0061b9626d7..573cd9102b3 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -390,6 +390,9 @@ class HeapType { // Get the bottom heap type for this heap type's hierarchy. BasicHeapType getBottom() const; + // Get the top heap type for this heap type's hierarchy. + BasicHeapType getTop() const; + // Get the recursion group for this non-basic type. 
RecGroup getRecGroup() const; size_t getRecGroupIndex() const; diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index cab68d00d81..dce0eb64583 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -1386,6 +1386,20 @@ HeapType::BasicHeapType HeapType::getBottom() const { WASM_UNREACHABLE("unexpected kind"); } +HeapType::BasicHeapType HeapType::getTop() const { + switch (getBottom()) { + case none: + return any; + case nofunc: + return func; + case noext: + return ext; + default: + break; + } + WASM_UNREACHABLE("unexpected type"); +} + bool HeapType::isSubType(HeapType left, HeapType right) { // As an optimization, in the common case do not even construct a SubTyper. if (left == right) { diff --git a/test/lit/passes/type-generalizing.wast b/test/lit/passes/type-generalizing.wast new file mode 100644 index 00000000000..b514c0eaf78 --- /dev/null +++ b/test/lit/passes/type-generalizing.wast @@ -0,0 +1,198 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. +;; RUN: foreach %s %t wasm-opt --dce --experimental-type-generalizing -all -S -o - | filecheck %s + +(module + + ;; CHECK: (type $0 (func)) + + ;; CHECK: (type $1 (func (result eqref))) + + ;; CHECK: (type $2 (func (param anyref))) + + ;; CHECK: (type $3 (func (param i31ref))) + + ;; CHECK: (type $4 (func (param anyref eqref))) + + ;; CHECK: (type $5 (func (param eqref))) + + ;; CHECK: (func $unconstrained (type $0) + ;; CHECK-NEXT: (local $x i32) + ;; CHECK-NEXT: (local $y anyref) + ;; CHECK-NEXT: (local $z (anyref i32)) + ;; CHECK-NEXT: (nop) + ;; CHECK-NEXT: ) + (func $unconstrained + ;; This non-ref local should be unmodified + (local $x i32) + ;; There is no constraint on the type of this local, so make it top. + (local $y anyref) + ;; We cannot optimize tuple locals yet, so leave it unchanged. 
+ (local $z (anyref i32)) + ) + + ;; CHECK: (func $implicit-return (type $1) (result eqref) + ;; CHECK-NEXT: (local $var eqref) + ;; CHECK-NEXT: (local.get $var) + ;; CHECK-NEXT: ) + (func $implicit-return (result eqref) + ;; This will be optimized, but only to eqref because of the constraint from the + ;; implicit return. + (local $var i31ref) + (local.get $var) + ) + + ;; CHECK: (func $implicit-return-unreachable (type $1) (result eqref) + ;; CHECK-NEXT: (local $var none) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $implicit-return-unreachable (result eqref) + ;; We will optimize this all the way to anyref because we don't analyze + ;; unreachable code. This would not validate if we didn't run DCE first. + (local $var i31ref) + (unreachable) + (local.get $var) + ) + + ;; CHECK: (func $local-set (type $0) + ;; CHECK-NEXT: (local $var anyref) + ;; CHECK-NEXT: (local.set $var + ;; CHECK-NEXT: (ref.i31 + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $local-set + ;; This will be optimized to anyref. + (local $var i31ref) + ;; Require that (ref i31) <: typeof($var). + (local.set $var + (i31.new + (i32.const 0) + ) + ) + ) + + ;; CHECK: (func $local-get-set (type $2) (param $dest anyref) + ;; CHECK-NEXT: (local $var anyref) + ;; CHECK-NEXT: (local.set $dest + ;; CHECK-NEXT: (local.get $var) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $local-get-set (param $dest anyref) + ;; This will be optimized to anyref. + (local $var i31ref) + ;; Require that typeof($var) <: typeof($dest). + (local.set $dest + (local.get $var) + ) + ) + + ;; CHECK: (func $local-get-set-unreachable (type $3) (param $dest i31ref) + ;; CHECK-NEXT: (local $var none) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $local-get-set-unreachable (param $dest i31ref) + ;; This is not constrained by reachable code, so we will optimize it. 
+ (local $var i31ref) + (unreachable) + ;; This would require that typeof($var) <: typeof($dest), except it is + ;; unreachable. This would not validate if we didn't run DCE first. + (local.set $dest + (local.tee $var + (local.get $var) + ) + ) + ) + + ;; CHECK: (func $local-get-set-join (type $4) (param $dest1 anyref) (param $dest2 eqref) + ;; CHECK-NEXT: (local $var eqref) + ;; CHECK-NEXT: (local.set $dest1 + ;; CHECK-NEXT: (local.get $var) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $dest2 + ;; CHECK-NEXT: (local.get $var) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $local-get-set-join (param $dest1 anyref) (param $dest2 eqref) + ;; This will be optimized to eqref. + (local $var i31ref) + ;; Require that typeof($var) <: typeof($dest1). + (local.set $dest1 + (local.get $var) + ) + ;; Also require that typeof($var) <: typeof($dest2). + (local.set $dest2 + (local.get $var) + ) + ) + + ;; CHECK: (func $local-tee (type $5) (param $dest eqref) + ;; CHECK-NEXT: (local $var eqref) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.tee $dest + ;; CHECK-NEXT: (local.tee $var + ;; CHECK-NEXT: (ref.i31 + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $local-tee (param $dest eqref) + ;; This will be optimized to eqref. 
+ (local $var i31ref) + (drop + (local.tee $dest + (local.tee $var + (i31.new + (i32.const 0) + ) + ) + ) + ) + ) + + ;; CHECK: (func $i31-get (type $0) + ;; CHECK-NEXT: (local $nullable i31ref) + ;; CHECK-NEXT: (local $nonnullable (ref i31)) + ;; CHECK-NEXT: (local.set $nonnullable + ;; CHECK-NEXT: (ref.i31 + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i31.get_s + ;; CHECK-NEXT: (local.get $nullable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i31.get_u + ;; CHECK-NEXT: (local.get $nonnullable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $i31-get + ;; This must stay an i31ref. + (local $nullable i31ref) + ;; This one could be relaxed to be nullable in principle, but we keep it non-nullable. + (local $nonnullable (ref i31)) + ;; Initialize the non-nullable local for validation purposes. + (local.set $nonnullable + (i31.new + (i32.const 0) + ) + ) + (drop + ;; Require that typeof($nullable) <: i31ref. + (i31.get_s + (local.get $nullable) + ) + ) + (drop + ;; Require that typeof($nonnullable) <: i31ref. 
+ (i31.get_u + (local.get $nonnullable) + ) + ) + ) +) From b92aa0a4edbca45e6d17abe2517e8c380e45aaef Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Fri, 3 Nov 2023 18:54:37 -0700 Subject: [PATCH 2/9] add test with nontrivial control flow --- test/lit/passes/type-generalizing.wast | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/lit/passes/type-generalizing.wast b/test/lit/passes/type-generalizing.wast index b514c0eaf78..a328ee7b255 100644 --- a/test/lit/passes/type-generalizing.wast +++ b/test/lit/passes/type-generalizing.wast @@ -53,6 +53,27 @@ (local.get $var) ) + ;; CHECK: (func $if (type $1) (result eqref) + ;; CHECK-NEXT: (local $x eqref) + ;; CHECK-NEXT: (local $y eqref) + ;; CHECK-NEXT: (if (result eqref) + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $if (result (eqref)) + (local $x i31ref) + (local $y i31ref) + (if (result i31ref) + (i32.const 0) + ;; Require that typeof($x) <: eqref. + (local.get $x) + ;; Require that typeof($y) <: eqref. 
+ (local.get $y) + ) + ) + ;; CHECK: (func $local-set (type $0) ;; CHECK-NEXT: (local $var anyref) ;; CHECK-NEXT: (local.set $var From 19f73aaf25238aa2f70351856ffec3c72f44fed3 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Mon, 6 Nov 2023 16:46:14 -0800 Subject: [PATCH 3/9] handle locals at function entry instead of exit --- src/passes/TypeGeneralizing.cpp | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp index b73bfcac035..ead7c093c61 100644 --- a/src/passes/TypeGeneralizing.cpp +++ b/src/passes/TypeGeneralizing.cpp @@ -28,7 +28,7 @@ #include "wasm-traversal.h" #include "wasm.h" -#define TYPE_GENERALIZING_DEBUG 0 +#define TYPE_GENERALIZING_DEBUG 1 #if TYPE_GENERALIZING_DEBUG #define DBG(statement) statement @@ -202,6 +202,11 @@ struct TransferFn : OverriddenVisitor { visit(*it); dumpState(); } + if (bb.isEntry()) { + DBG(std::cerr << "visiting entry\n"); + visitFunctionEntry(); + dumpState(); + } DBG(std::cerr << "\n"); state = nullptr; @@ -215,6 +220,14 @@ struct TransferFn : OverriddenVisitor { } void visitFunctionExit() { + // We cannot change the types of results. Push requirements that + // the stack end up with the correct type. + if (auto result = func->getResults(); result.isRef()) { + push(result); + } + } + + void visitFunctionEntry() { // We cannot change the types of parameters, so require that they have their // original types. Index i = 0; @@ -225,7 +238,7 @@ struct TransferFn : OverriddenVisitor { } // We also cannot change the types of any other non-ref locals. For // reference-typed locals, we cannot generalize beyond their top type. - for (; i < numLocals; ++i) { + for (Index i = numParams; i < numLocals; ++i) { auto type = func->getLocalType(i); // TODO: Support optimizing tuple locals. 
 if (type.isRef()) { @@ -234,11 +247,6 @@ updateLocal(i, type); } } - // We similarly cannot change the types of results. Push requirements that - // the stack end up with the correct type. - if (auto result = func->getResults(); result.isRef()) { - push(result); - } } void visitNop(Nop* curr) {} @@ -335,6 +343,7 @@ void visitTableSize(TableSize* curr) { WASM_UNREACHABLE("TODO"); } void visitTableGrow(TableGrow* curr) { WASM_UNREACHABLE("TODO"); } void visitTableFill(TableFill* curr) { WASM_UNREACHABLE("TODO"); } + void visitTableCopy(TableCopy* curr) { WASM_UNREACHABLE("TODO"); } void visitTry(Try* curr) { WASM_UNREACHABLE("TODO"); } void visitThrow(Throw* curr) { WASM_UNREACHABLE("TODO"); } void visitRethrow(Rethrow* curr) { WASM_UNREACHABLE("TODO"); } From ceadc937ea876bf99dcac37b872a965553172cc9 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Fri, 3 Nov 2023 18:43:22 -0700 Subject: [PATCH 4/9] [TypeGeneralizing] Properly re-analyze blocks when locals are updated Whenever the constraint on a local is updated, any block that does a local.set on that local may need to be re-analyzed. Update the TypeGeneralizing transfer function to include these blocks in the set of dependent blocks it returns. Add a test that depends on this logic to validate. 
--- src/passes/TypeGeneralizing.cpp | 58 ++++++++++++++------ test/lit/passes/type-generalizing.wast | 73 +++++++++++++++++++++++--- 2 files changed, 108 insertions(+), 23 deletions(-) diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp index ead7c093c61..fadcaf941a8 100644 --- a/src/passes/TypeGeneralizing.cpp +++ b/src/passes/TypeGeneralizing.cpp @@ -101,8 +101,8 @@ struct State : StateLattice { return getLocals(elem)[i]; } - void updateLocal(Element& elem, Index i, Type type) const noexcept { - localsLattice().join( + bool updateLocal(Element& elem, Index i, Type type) const noexcept { + return localsLattice().join( locals(elem), Vector::SingletonElement(i, std::move(type))); } @@ -151,17 +151,43 @@ struct TransferFn : OverriddenVisitor { State lattice; typename State::Element* state = nullptr; - TransferFn(Module& wasm, Function* func) - : wasm(wasm), func(func), lattice(func) {} + // For each local, the set of blocks we may need to re-analyze when we update + // the constraint on the local. + std::vector> localDependents; + + // The set of basic blocks that may depend on the result of the current + // transfer. + std::unordered_set currDependents; + + TransferFn(Module& wasm, Function* func, CFG& cfg) + : wasm(wasm), func(func), lattice(func), + localDependents(func->getNumLocals()) { + // Initialize `localDependents`. Any block containing a `local.set l` may + // need to be re-analyzed whenever the constraint on `l` is updated. 
+ auto numLocals = func->getNumLocals(); + std::vector> dependentSets(numLocals); + for (const auto& bb : cfg) { + for (const auto* inst : bb) { + if (auto set = inst->dynCast()) { + dependentSets[set->index].insert(&bb); + } + } + } + for (size_t i = 0, n = dependentSets.size(); i < n; ++i) { + localDependents[i] = std::vector( + dependentSets[i].begin(), dependentSets[i].end()); + } + } Type pop() noexcept { return lattice.pop(*state); } void push(Type type) noexcept { lattice.push(*state, type); } void clearStack() noexcept { lattice.clearStack(*state); } Type getLocal(Index i) noexcept { return lattice.getLocal(*state, i); } void updateLocal(Index i, Type type) noexcept { - // TODO: Collect possible successor blocks that might depend on an updated - // local requirement here. - return lattice.updateLocal(*state, i, type); + if (lattice.updateLocal(*state, i, type)) { + currDependents.insert(localDependents[i].begin(), + localDependents[i].end()); + } } void dumpState() { @@ -185,12 +211,18 @@ struct TransferFn : OverriddenVisitor { #endif // TYPE_GENERALIZING_DEBUG } - std::vector + std::unordered_set transfer(const BasicBlock& bb, typename State::Element& elem) noexcept { DBG(std::cerr << "transferring bb " << bb.getIndex() << "\n"); state = &elem; + // This is a backward analysis: The constraints on a type depend on how it - // will be used in the future. Traverse the basic block in reverse. + // will be used in the future. Traverse the basic block in reverse and + // return the predecessors as the dependent blocks. + assert(currDependents.empty()); + const auto& preds = bb.preds(); + currDependents.insert(preds.begin(), preds.end()); + dumpState(); if (bb.isExit()) { DBG(std::cerr << "visiting exit\n"); @@ -212,11 +244,7 @@ struct TransferFn : OverriddenVisitor { state = nullptr; // Return the blocks that may need to be re-analyzed. 
- const auto& preds = bb.preds(); - std::vector dependents; - dependents.reserve(preds.size()); - dependents.insert(dependents.end(), preds.begin(), preds.end()); - return dependents; + return std::move(currDependents); } void visitFunctionExit() { @@ -404,9 +432,9 @@ struct TypeGeneralizing : WalkerPass> { } void runOnFunction(Module* wasm, Function* func) override { - TransferFn txfn(*wasm, func); auto cfg = CFG::fromFunction(func); DBG(cfg.print(std::cerr)); + TransferFn txfn(*wasm, func, cfg); MonotoneCFGAnalyzer analyzer(txfn.lattice, txfn, cfg); analyzer.evaluate(); diff --git a/test/lit/passes/type-generalizing.wast b/test/lit/passes/type-generalizing.wast index a328ee7b255..a2d075ef10c 100644 --- a/test/lit/passes/type-generalizing.wast +++ b/test/lit/passes/type-generalizing.wast @@ -3,9 +3,9 @@ (module - ;; CHECK: (type $0 (func)) + ;; CHECK: (type $0 (func (result eqref))) - ;; CHECK: (type $1 (func (result eqref))) + ;; CHECK: (type $1 (func)) ;; CHECK: (type $2 (func (param anyref))) @@ -15,7 +15,7 @@ ;; CHECK: (type $5 (func (param eqref))) - ;; CHECK: (func $unconstrained (type $0) + ;; CHECK: (func $unconstrained (type $1) ;; CHECK-NEXT: (local $x i32) ;; CHECK-NEXT: (local $y anyref) ;; CHECK-NEXT: (local $z (anyref i32)) @@ -30,7 +30,7 @@ (local $z (anyref i32)) ) - ;; CHECK: (func $implicit-return (type $1) (result eqref) + ;; CHECK: (func $implicit-return (type $0) (result eqref) ;; CHECK-NEXT: (local $var eqref) ;; CHECK-NEXT: (local.get $var) ;; CHECK-NEXT: ) @@ -41,7 +41,7 @@ (local.get $var) ) - ;; CHECK: (func $implicit-return-unreachable (type $1) (result eqref) + ;; CHECK: (func $implicit-return-unreachable (type $0) (result eqref) ;; CHECK-NEXT: (local $var none) ;; CHECK-NEXT: (unreachable) ;; CHECK-NEXT: ) @@ -53,7 +53,7 @@ (local.get $var) ) - ;; CHECK: (func $if (type $1) (result eqref) + ;; CHECK: (func $if (type $0) (result eqref) ;; CHECK-NEXT: (local $x eqref) ;; CHECK-NEXT: (local $y eqref) ;; CHECK-NEXT: (if (result eqref) @@ 
-74,7 +74,7 @@ ) ) - ;; CHECK: (func $local-set (type $0) + ;; CHECK: (func $local-set (type $1) ;; CHECK-NEXT: (local $var anyref) ;; CHECK-NEXT: (local.set $var ;; CHECK-NEXT: (ref.i31 @@ -147,6 +147,63 @@ ) ) + ;; CHECK: (func $local-get-set-chain (type $0) (result eqref) + ;; CHECK-NEXT: (local $a eqref) + ;; CHECK-NEXT: (local $b eqref) + ;; CHECK-NEXT: (local $c eqref) + ;; CHECK-NEXT: (local.set $b + ;; CHECK-NEXT: (local.get $a) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $c + ;; CHECK-NEXT: (local.get $b) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.get $c) + ;; CHECK-NEXT: ) + (func $local-get-set-chain (result eqref) + (local $a i31ref) + (local $b i31ref) + (local $c i31ref) + ;; Require that typeof($a) <: typeof($b). + (local.set $b + (local.get $a) + ) + ;; Require that typeof($b) <: typeof($c). + (local.set $c + (local.get $b) + ) + ;; Require that typeof($c) <: eqref. + (local.get $c) + ) + + ;; CHECK: (func $local-get-set-chain-out-of-order (type $0) (result eqref) + ;; CHECK-NEXT: (local $a eqref) + ;; CHECK-NEXT: (local $b eqref) + ;; CHECK-NEXT: (local $c eqref) + ;; CHECK-NEXT: (local.set $c + ;; CHECK-NEXT: (local.get $b) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $b + ;; CHECK-NEXT: (local.get $a) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.get $c) + ;; CHECK-NEXT: ) + (func $local-get-set-chain-out-of-order (result eqref) + (local $a i31ref) + (local $b i31ref) + (local $c i31ref) + ;; Require that typeof($b) <: typeof($c). + (local.set $c + (local.get $b) + ) + ;; Require that typeof($a) <: typeof($b). We don't know until we evaluate the + ;; set above that this will constrain $a to eqref. + (local.set $b + (local.get $a) + ) + ;; Require that typeof($c) <: eqref. 
+ (local.get $c) + ) + ;; CHECK: (func $local-tee (type $5) (param $dest eqref) ;; CHECK-NEXT: (local $var eqref) ;; CHECK-NEXT: (drop @@ -173,7 +230,7 @@ ) ) - ;; CHECK: (func $i31-get (type $0) + ;; CHECK: (func $i31-get (type $1) ;; CHECK-NEXT: (local $nullable i31ref) ;; CHECK-NEXT: (local $nonnullable (ref i31)) ;; CHECK-NEXT: (local.set $nonnullable From c7e5e88379130a80736bfaf5d885f3fb4bd4c50c Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Mon, 6 Nov 2023 17:06:36 -0800 Subject: [PATCH 5/9] fix unreachable tests --- test/lit/passes/type-generalizing.wast | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/lit/passes/type-generalizing.wast b/test/lit/passes/type-generalizing.wast index a2d075ef10c..eb6228a088a 100644 --- a/test/lit/passes/type-generalizing.wast +++ b/test/lit/passes/type-generalizing.wast @@ -42,7 +42,7 @@ ) ;; CHECK: (func $implicit-return-unreachable (type $0) (result eqref) - ;; CHECK-NEXT: (local $var none) + ;; CHECK-NEXT: (local $var anyref) ;; CHECK-NEXT: (unreachable) ;; CHECK-NEXT: ) (func $implicit-return-unreachable (result eqref) @@ -109,7 +109,7 @@ ) ;; CHECK: (func $local-get-set-unreachable (type $3) (param $dest i31ref) - ;; CHECK-NEXT: (local $var none) + ;; CHECK-NEXT: (local $var anyref) ;; CHECK-NEXT: (unreachable) ;; CHECK-NEXT: ) (func $local-get-set-unreachable (param $dest i31ref) From bf0dc7de98fd33f6578ce209cc2ed48a0d070cbd Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Mon, 6 Nov 2023 17:47:19 -0800 Subject: [PATCH 6/9] address comments --- src/passes/TypeGeneralizing.cpp | 11 +++++++---- test/lit/passes/type-generalizing.wast | 10 +++++++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp index fadcaf941a8..b6ff0f8b7d5 100644 --- a/src/passes/TypeGeneralizing.cpp +++ b/src/passes/TypeGeneralizing.cpp @@ -28,7 +28,7 @@ #include "wasm-traversal.h" #include "wasm.h" -#define TYPE_GENERALIZING_DEBUG 1 
+#define TYPE_GENERALIZING_DEBUG 0 #if TYPE_GENERALIZING_DEBUG #define DBG(statement) statement @@ -248,8 +248,8 @@ struct TransferFn : OverriddenVisitor { } void visitFunctionExit() { - // We cannot change the types of results. Push requirements that - // the stack end up with the correct type. + // We cannot change the types of results. Push a requirement that the stack + // end up with the correct type. if (auto result = func->getResults(); result.isRef()) { push(result); } @@ -379,7 +379,10 @@ struct TransferFn : OverriddenVisitor { void visitTupleExtract(TupleExtract* curr) { WASM_UNREACHABLE("TODO"); } void visitRefI31(RefI31* curr) { pop(); } void visitI31Get(I31Get* curr) { - // Do not allow relaxing to nullable if the input is already non-nullable. + // Do not allow relaxing to nullable if the input is already non-nullable to + // avoid the engine having to do a null check it would not otherwise have + // had to do. This could prevent us from optimizing out a previous explicit + // null check in principle, but should not affect heap type casts. if (curr->i31->type.isNonNullable()) { push(Type(HeapType::i31, NonNullable)); } else { diff --git a/test/lit/passes/type-generalizing.wast b/test/lit/passes/type-generalizing.wast index eb6228a088a..bd0408aed45 100644 --- a/test/lit/passes/type-generalizing.wast +++ b/test/lit/passes/type-generalizing.wast @@ -1,4 +1,8 @@ ;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. + +;; Run DCE first because type-generalizing does not analyze unreachable blocks +;; and could produce incorrect code if they were present. + ;; RUN: foreach %s %t wasm-opt --dce --experimental-type-generalizing -all -S -o - | filecheck %s (module @@ -25,7 +29,7 @@ ;; This non-ref local should be unmodified (local $x i32) ;; There is no constraint on the type of this local, so make it top. - (local $y anyref) + (local $y i31ref) ;; We cannot optimize tuple locals yet, so leave it unchanged. 
(local $z (anyref i32)) ) @@ -35,7 +39,7 @@ ;; CHECK-NEXT: (local.get $var) ;; CHECK-NEXT: ) (func $implicit-return (result eqref) - ;; This will be optimized, but only to eqref because of the constaint from the + ;; This will be optimized, but only to eqref because of the constraint from the ;; implicit return. (local $var i31ref) (local.get $var) @@ -46,7 +50,7 @@ ;; CHECK-NEXT: (unreachable) ;; CHECK-NEXT: ) (func $implicit-return-unreachable (result eqref) - ;; Now will optimize this all the way to anyref because we don't analyze + ;; We will optimize this all the way to anyref because we don't analyze ;; unreachable code. This would not validate if we didn't run DCE first. (local $var i31ref) (unreachable) From c6e818408e3be0f316c34f5debdeca221b5b51b8 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Mon, 6 Nov 2023 17:50:21 -0800 Subject: [PATCH 7/9] fix lint --- src/passes/TypeGeneralizing.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp index b6ff0f8b7d5..dbedbb6046b 100644 --- a/src/passes/TypeGeneralizing.cpp +++ b/src/passes/TypeGeneralizing.cpp @@ -430,7 +430,7 @@ struct TypeGeneralizing : WalkerPass> { bool refinalize = false; bool isFunctionParallel() override { return true; } - std::unique_ptr create() { + std::unique_ptr create() override { return std::make_unique(); } From 714e353925b4795ae4074f2b1d0de58f4f6ad172 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Tue, 7 Nov 2023 11:00:21 -0800 Subject: [PATCH 8/9] run DCE internally --- src/passes/TypeGeneralizing.cpp | 7 +++++++ test/lit/passes/type-generalizing.wast | 5 +---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp index dbedbb6046b..eff24625333 100644 --- a/src/passes/TypeGeneralizing.cpp +++ b/src/passes/TypeGeneralizing.cpp @@ -435,6 +435,13 @@ struct TypeGeneralizing : WalkerPass> { } void runOnFunction(Module* wasm, 
Function* func) override { + // First, remove unreachable code. If we didn't, the unreachable code could + // become invalid after this optimization because we do not materialize or + // analyze unreachable blocks. + PassRunner runner(getPassRunner()); + runner.add("dce"); + runner.runOnFunction(func); + auto cfg = CFG::fromFunction(func); DBG(cfg.print(std::cerr)); TransferFn txfn(*wasm, func, cfg); diff --git a/test/lit/passes/type-generalizing.wast b/test/lit/passes/type-generalizing.wast index bd0408aed45..4a9ca3274ee 100644 --- a/test/lit/passes/type-generalizing.wast +++ b/test/lit/passes/type-generalizing.wast @@ -1,9 +1,6 @@ ;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. -;; Run DCE first because type-generalizing does not analyze unreachable blocks -;; and could produce incorrect code if they were present. - -;; RUN: foreach %s %t wasm-opt --dce --experimental-type-generalizing -all -S -o - | filecheck %s +;; RUN: foreach %s %t wasm-opt --experimental-type-generalizing -all -S -o - | filecheck %s (module From 3e286a06e410e42af63c36ff58156d03dd75b646 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Tue, 7 Nov 2023 16:49:30 -0800 Subject: [PATCH 9/9] simplify i31.get --- src/passes/TypeGeneralizing.cpp | 12 +----------- test/lit/passes/type-generalizing.wast | 4 ++-- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp index eff24625333..0d811aa520d 100644 --- a/src/passes/TypeGeneralizing.cpp +++ b/src/passes/TypeGeneralizing.cpp @@ -378,17 +378,7 @@ struct TransferFn : OverriddenVisitor { void visitTupleMake(TupleMake* curr) { WASM_UNREACHABLE("TODO"); } void visitTupleExtract(TupleExtract* curr) { WASM_UNREACHABLE("TODO"); } void visitRefI31(RefI31* curr) { pop(); } - void visitI31Get(I31Get* curr) { - // Do not allow relaxing to nullable if the input is already non-nullable to - // avoid the engine having to do a null check 
it would not otherwise have - // had to do. This could prevent us from optimizing out a previous explicit - // null check in principle, but should not affect heap type casts. - if (curr->i31->type.isNonNullable()) { - push(Type(HeapType::i31, NonNullable)); - } else { - push(Type(HeapType::i31, Nullable)); - } - } + void visitI31Get(I31Get* curr) { push(Type(HeapType::i31, Nullable)); } void visitCallRef(CallRef* curr) { WASM_UNREACHABLE("TODO"); } void visitRefTest(RefTest* curr) { WASM_UNREACHABLE("TODO"); } void visitRefCast(RefCast* curr) { WASM_UNREACHABLE("TODO"); } diff --git a/test/lit/passes/type-generalizing.wast b/test/lit/passes/type-generalizing.wast index 4a9ca3274ee..fed32772753 100644 --- a/test/lit/passes/type-generalizing.wast +++ b/test/lit/passes/type-generalizing.wast @@ -233,7 +233,7 @@ ;; CHECK: (func $i31-get (type $1) ;; CHECK-NEXT: (local $nullable i31ref) - ;; CHECK-NEXT: (local $nonnullable (ref i31)) + ;; CHECK-NEXT: (local $nonnullable i31ref) ;; CHECK-NEXT: (local.set $nonnullable ;; CHECK-NEXT: (ref.i31 ;; CHECK-NEXT: (i32.const 0) @@ -253,7 +253,7 @@ (func $i31-get ;; This must stay an i31ref. (local $nullable i31ref) - ;; This one could be relaxed to be nullable in principle, but we keep it non-nullable. + ;; We relax this one to be nullable i31ref as well. (local $nonnullable (ref i31)) ;; Initialize the non-nullable local for validation purposes. (local.set $nonnullable