Diffstat (limited to 'src')
-rw-r--r--  src/passes/CMakeLists.txt        |   1
-rw-r--r--  src/passes/TypeGeneralizing.cpp  | 475
-rw-r--r--  src/passes/pass.cpp              |   3
-rw-r--r--  src/passes/passes.h              |   1
-rw-r--r--  src/wasm-type.h                  |   3
-rw-r--r--  src/wasm/wasm-type.cpp           |  14
6 files changed, 497 insertions(+), 0 deletions(-)
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index 2930898a9..ec57f108d 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -103,6 +103,7 @@ set(passes_SOURCES
   ReorderLocals.cpp
   ReReloop.cpp
   TrapMode.cpp
+  TypeGeneralizing.cpp
   TypeRefining.cpp
   TypeMerging.cpp
   TypeSSA.cpp
diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp
new file mode 100644
index 000000000..0d811aa52
--- /dev/null
+++ b/src/passes/TypeGeneralizing.cpp
@@ -0,0 +1,475 @@
+/*
+ * Copyright 2023 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "analysis/cfg.h"
+#include "analysis/lattice.h"
+#include "analysis/lattices/inverted.h"
+#include "analysis/lattices/shared.h"
+#include "analysis/lattices/stack.h"
+#include "analysis/lattices/tuple.h"
+#include "analysis/lattices/valtype.h"
+#include "analysis/lattices/vector.h"
+#include "analysis/monotone-analyzer.h"
+#include "ir/utils.h"
+#include "pass.h"
+#include "wasm-traversal.h"
+#include "wasm.h"
+
+#define TYPE_GENERALIZING_DEBUG 0
+
+#if TYPE_GENERALIZING_DEBUG
+#define DBG(statement) statement
+#else
+#define DBG(statement)
+#endif
+
+// Generalize the types of program locations as much as possible, both to
+// eliminate unnecessarily refined types from the type section and (TODO) to
+// weaken casts that cast to unnecessarily refined types. If the casts are
+// weakened enough, OptimizeInstructions can then remove them entirely.
+//
+// Perform a backward analysis tracking requirements on the types of program
+// locations (currently just locals and stack values) to discover how much the
+// type of each location can be generalized without breaking validation or
+// changing program behavior.
+
+namespace wasm {
+
+namespace {
+
+using namespace analysis;
+
+// We will learn stricter and stricter requirements as we perform the analysis,
+// so more specific types need to be higher up the lattice.
+using TypeRequirement = Inverted<ValType>;
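+
+// For example, a use that requires a value to be a subtype of eqref,
+// combined with another use that requires a subtype of i31ref, produces the
+// stricter requirement i31ref: joining in Inverted<ValType> is a meet over
+// the underlying types, so accumulated requirements only ever become more
+// restrictive.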
+
+// Record a type requirement for each local variable. Share the requirements
+// across basic blocks.
+using LocalTypeRequirements = Shared<Vector<TypeRequirement>>;
+
+// The type requirements for each reference-typed value on the stack at a
+// particular location.
+using ValueStackTypeRequirements = Stack<TypeRequirement>;
+
+// The full lattice used for the analysis.
+using StateLattice =
+  analysis::Tuple<LocalTypeRequirements, ValueStackTypeRequirements>;
+
+// Equip the state lattice with helpful accessors.
+struct State : StateLattice {
+  using Element = StateLattice::Element;
+
+  static constexpr int LocalsIndex = 0;
+  static constexpr int StackIndex = 1;
+
+  State(Function* func) : StateLattice{Shared{initLocals(func)}, initStack()} {}
+
+  void push(Element& elem, Type type) const noexcept {
+    stackLattice().push(stack(elem), std::move(type));
+  }
+
+  Type pop(Element& elem) const noexcept {
+    return stackLattice().pop(stack(elem));
+  }
+
+  void clearStack(Element& elem) const noexcept {
+    stack(elem) = stackLattice().getBottom();
+  }
+
+  const std::vector<Type>& getLocals(Element& elem) const noexcept {
+    return *locals(elem);
+  }
+
+  const std::vector<Type>& getLocals() const noexcept {
+    return *locals(getBottom());
+  }
+
+  Type getLocal(Element& elem, Index i) const noexcept {
+    return getLocals(elem)[i];
+  }
+
+  bool updateLocal(Element& elem, Index i, Type type) const noexcept {
+    return localsLattice().join(
+      locals(elem),
+      Vector<TypeRequirement>::SingletonElement(i, std::move(type)));
+  }
+
+private:
+  static LocalTypeRequirements initLocals(Function* func) noexcept {
+    return Shared{Vector{Inverted{ValType{}}, func->getNumLocals()}};
+  }
+
+  static ValueStackTypeRequirements initStack() noexcept {
+    return Stack{Inverted{ValType{}}};
+  }
+
+  const LocalTypeRequirements& localsLattice() const noexcept {
+    return std::get<LocalsIndex>(lattices);
+  }
+
+  const ValueStackTypeRequirements& stackLattice() const noexcept {
+    return std::get<StackIndex>(lattices);
+  }
+
+  typename LocalTypeRequirements::Element&
+  locals(Element& elem) const noexcept {
+    return std::get<LocalsIndex>(elem);
+  }
+
+  const typename LocalTypeRequirements::Element&
+  locals(const Element& elem) const noexcept {
+    return std::get<LocalsIndex>(elem);
+  }
+
+  typename ValueStackTypeRequirements::Element&
+  stack(Element& elem) const noexcept {
+    return std::get<StackIndex>(elem);
+  }
+
+  const typename ValueStackTypeRequirements::Element&
+  stack(const Element& elem) const noexcept {
+    return std::get<StackIndex>(elem);
+  }
+};
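+
+// Note the asymmetry between the two components: the locals component is a
+// Shared lattice, so every basic block observes a single function-wide
+// requirement per local, while the stack component is tracked separately at
+// each program point. Correspondingly, updateLocal returns whether the shared
+// requirement actually changed, and the transfer function below uses that
+// signal to decide which blocks need re-analysis.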
+
+struct TransferFn : OverriddenVisitor<TransferFn> {
+  Module& wasm;
+  Function* func;
+  State lattice;
+  typename State::Element* state = nullptr;
+
+  // For each local, the set of blocks we may need to re-analyze when we
+  // update the constraint on the local.
+  std::vector<std::vector<const BasicBlock*>> localDependents;
+
+  // The set of basic blocks that may depend on the result of the current
+  // transfer.
+  std::unordered_set<const BasicBlock*> currDependents;
+
+  TransferFn(Module& wasm, Function* func, CFG& cfg)
+    : wasm(wasm), func(func), lattice(func),
+      localDependents(func->getNumLocals()) {
+    // Initialize `localDependents`. Any block containing a `local.set l` may
+    // need to be re-analyzed whenever the constraint on `l` is updated.
+    auto numLocals = func->getNumLocals();
+    std::vector<std::unordered_set<const BasicBlock*>> dependentSets(numLocals);
+    for (const auto& bb : cfg) {
+      for (const auto* inst : bb) {
+        if (auto set = inst->dynCast<LocalSet>()) {
+          dependentSets[set->index].insert(&bb);
+        }
+      }
+    }
+    for (size_t i = 0, n = dependentSets.size(); i < n; ++i) {
+      localDependents[i] = std::vector<const BasicBlock*>(
+        dependentSets[i].begin(), dependentSets[i].end());
+    }
+  }
+
+  Type pop() noexcept { return lattice.pop(*state); }
+  void push(Type type) noexcept { lattice.push(*state, type); }
+  void clearStack() noexcept { lattice.clearStack(*state); }
+  Type getLocal(Index i) noexcept { return lattice.getLocal(*state, i); }
+  void updateLocal(Index i, Type type) noexcept {
+    if (lattice.updateLocal(*state, i, type)) {
+      currDependents.insert(localDependents[i].begin(),
+                            localDependents[i].end());
+    }
+  }
+
+  void dumpState() {
+#if TYPE_GENERALIZING_DEBUG
+    std::cerr << "locals: ";
+    for (size_t i = 0, n = lattice.getLocals(*state).size(); i < n; ++i) {
+      if (i != 0) {
+        std::cerr << ", ";
+      }
+      std::cerr << getLocal(i);
+    }
+    std::cerr << "\nstack: ";
+    auto& stack = std::get<1>(*state);
+    for (size_t i = 0, n = stack.size(); i < n; ++i) {
+      if (i != 0) {
+        std::cerr << ", ";
+      }
+      std::cerr << stack[i];
+    }
+    std::cerr << "\n";
+#endif // TYPE_GENERALIZING_DEBUG
+  }
+
+  std::unordered_set<const BasicBlock*>
+  transfer(const BasicBlock& bb, typename State::Element& elem) noexcept {
+    DBG(std::cerr << "transferring bb " << bb.getIndex() << "\n");
+    state = &elem;
+
+    // This is a backward analysis: the constraints on a type depend on how it
+    // will be used in the future. Traverse the basic block in reverse and
+    // return the predecessors as the dependent blocks.
+    assert(currDependents.empty());
+    const auto& preds = bb.preds();
+    currDependents.insert(preds.begin(), preds.end());
+
+    dumpState();
+    if (bb.isExit()) {
+      DBG(std::cerr << "visiting exit\n");
+      visitFunctionExit();
+      dumpState();
+    }
+    for (auto it = bb.rbegin(); it != bb.rend(); ++it) {
+      DBG(std::cerr << "visiting " << ShallowExpression{*it} << "\n");
+      visit(*it);
+      dumpState();
+    }
+    if (bb.isEntry()) {
+      DBG(std::cerr << "visiting entry\n");
+      visitFunctionEntry();
+      dumpState();
+    }
+    DBG(std::cerr << "\n");
+
+    state = nullptr;
+
+    // Return the blocks that may need to be re-analyzed.
+    return std::move(currDependents);
+  }
+
+  void visitFunctionExit() {
+    // We cannot change the types of results. Push a requirement that the
+    // stack end up with the correct type.
+    if (auto result = func->getResults(); result.isRef()) {
+      push(result);
+    }
+  }
+
+  void visitFunctionEntry() {
+    // We cannot change the types of parameters, so require that they have
+    // their original types.
+    Index numParams = func->getNumParams();
+    Index numLocals = func->getNumLocals();
+    for (Index i = 0; i < numParams; ++i) {
+      updateLocal(i, func->getLocalType(i));
+    }
+    // We also cannot change the types of any other non-ref locals. For
+    // reference-typed locals, we cannot generalize beyond their top type.
+    for (Index i = numParams; i < numLocals; ++i) {
+      auto type = func->getLocalType(i);
+      // TODO: Support optimizing tuple locals.
+      if (type.isRef()) {
+        updateLocal(i, Type(type.getHeapType().getTop(), Nullable));
+      } else {
+        updateLocal(i, type);
+      }
+    }
+  }
+
+  void visitNop(Nop* curr) {}
+  void visitBlock(Block* curr) {}
+  void visitIf(If* curr) {}
+  void visitLoop(Loop* curr) {}
+  void visitBreak(Break* curr) {
+    // TODO: Pop extra elements off the stack, keeping only those at the top
+    // that will be sent along.
+    WASM_UNREACHABLE("TODO");
+  }
+
+  void visitSwitch(Switch* curr) {
+    // TODO: Pop extra elements off the stack, keeping only those at the top
+    // that will be sent along.
+    WASM_UNREACHABLE("TODO");
+  }
+
+  void visitCall(Call* curr) {
+    // TODO: Pop ref types from the results, push ref types from the params.
+    WASM_UNREACHABLE("TODO");
+  }
+
+  void visitCallIndirect(CallIndirect* curr) {
+    // TODO: Pop ref types from the results, push ref types from the params.
+    WASM_UNREACHABLE("TODO");
+  }
+
+  void visitLocalGet(LocalGet* curr) {
+    if (!curr->type.isRef()) {
+      return;
+    }
+    // Propagate the requirement on the local.get output to the local.
+    updateLocal(curr->index, pop());
+  }
+
+  void visitLocalSet(LocalSet* curr) {
+    if (!curr->value->type.isRef()) {
+      return;
+    }
+    if (curr->isTee()) {
+      // Same as the local.get.
+      updateLocal(curr->index, pop());
+    }
+    // Propagate the requirement on the local to our input value.
+    push(getLocal(curr->index));
+  }
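+
+  // To make the backward flow concrete, consider
+  //   (local.set $x (local.get $y))
+  // in a function whose later uses (visited first) require $x to be eqref.
+  // Walking in reverse, visitLocalSet runs before visitLocalGet: it pushes
+  // getLocal($x) = eqref as the requirement on the value it consumes, and
+  // visitLocalGet then pops that requirement and records it for $y via
+  // updateLocal. The requirement thus flows from the use of $x back to $y,
+  // which can now be generalized at most to eqref.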
+
+  void visitGlobalGet(GlobalGet* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitGlobalSet(GlobalSet* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitLoad(Load* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStore(Store* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitAtomicRMW(AtomicRMW* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitAtomicCmpxchg(AtomicCmpxchg* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitAtomicWait(AtomicWait* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitAtomicNotify(AtomicNotify* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitAtomicFence(AtomicFence* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitSIMDExtract(SIMDExtract* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitSIMDReplace(SIMDReplace* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitSIMDShuffle(SIMDShuffle* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitSIMDTernary(SIMDTernary* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitSIMDShift(SIMDShift* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitSIMDLoad(SIMDLoad* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) {
+    WASM_UNREACHABLE("TODO");
+  }
+  void visitMemoryInit(MemoryInit* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitDataDrop(DataDrop* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitMemoryCopy(MemoryCopy* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitMemoryFill(MemoryFill* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitConst(Const* curr) {}
+  void visitUnary(Unary* curr) {}
+  void visitBinary(Binary* curr) {}
+  void visitSelect(Select* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitDrop(Drop* curr) {
+    if (curr->type.isRef()) {
+      pop();
+    }
+  }
+  void visitReturn(Return* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitMemorySize(MemorySize* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitMemoryGrow(MemoryGrow* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitUnreachable(Unreachable* curr) {
+    // Require nothing about values flowing into an unreachable.
+    clearStack();
+  }
+  void visitPop(Pop* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRefNull(RefNull* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRefIsNull(RefIsNull* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRefFunc(RefFunc* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRefEq(RefEq* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTableGet(TableGet* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTableSet(TableSet* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTableSize(TableSize* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTableGrow(TableGrow* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTableFill(TableFill* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTableCopy(TableCopy* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTry(Try* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitThrow(Throw* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRethrow(Rethrow* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTupleMake(TupleMake* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitTupleExtract(TupleExtract* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRefI31(RefI31* curr) { pop(); }
+  void visitI31Get(I31Get* curr) { push(Type(HeapType::i31, Nullable)); }
+  void visitCallRef(CallRef* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRefTest(RefTest* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRefCast(RefCast* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitBrOn(BrOn* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStructNew(StructNew* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStructGet(StructGet* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStructSet(StructSet* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayNew(ArrayNew* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayNewData(ArrayNewData* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayNewElem(ArrayNewElem* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayNewFixed(ArrayNewFixed* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayGet(ArrayGet* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArraySet(ArraySet* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayLen(ArrayLen* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayCopy(ArrayCopy* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayFill(ArrayFill* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayInitData(ArrayInitData* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitArrayInitElem(ArrayInitElem* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitRefAs(RefAs* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringNew(StringNew* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringConst(StringConst* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringMeasure(StringMeasure* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringEncode(StringEncode* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringConcat(StringConcat* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringEq(StringEq* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringAs(StringAs* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringWTF8Advance(StringWTF8Advance* curr) {
+    WASM_UNREACHABLE("TODO");
+  }
+  void visitStringWTF16Get(StringWTF16Get* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringIterNext(StringIterNext* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringIterMove(StringIterMove* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringSliceWTF(StringSliceWTF* curr) { WASM_UNREACHABLE("TODO"); }
+  void visitStringSliceIter(StringSliceIter* curr) {
+    WASM_UNREACHABLE("TODO");
+  }
+};
+
+struct TypeGeneralizing : WalkerPass<PostWalker<TypeGeneralizing>> {
+  std::vector<Type> localTypes;
+  bool refinalize = false;
+
+  bool isFunctionParallel() override { return true; }
+  std::unique_ptr<Pass> create() override {
+    return std::make_unique<TypeGeneralizing>();
+  }
+
+  void runOnFunction(Module* wasm, Function* func) override {
+    // First, remove unreachable code. If we didn't, the unreachable code
+    // could become invalid after this optimization because we do not
+    // materialize or analyze unreachable blocks.
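+    // For example, an unreachable (struct.get $T 0 (local.get $l)) would
+    // never be seen by the analysis, so it would impose no requirement on
+    // $l, and generalizing $l to its top type would leave behind a
+    // struct.get whose operand no longer validates.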
+    PassRunner runner(getPassRunner());
+    runner.add("dce");
+    runner.runOnFunction(func);
+
+    auto cfg = CFG::fromFunction(func);
+    DBG(cfg.print(std::cerr));
+    TransferFn txfn(*wasm, func, cfg);
+    MonotoneCFGAnalyzer analyzer(txfn.lattice, txfn, cfg);
+    analyzer.evaluate();
+
+    // Optimize local types. TODO: Optimize casts as well.
+    localTypes = txfn.lattice.getLocals();
+    auto numParams = func->getNumParams();
+    for (Index i = numParams; i < localTypes.size(); ++i) {
+      func->vars[i - numParams] = localTypes[i];
+    }
+
+    // Update gets and sets accordingly.
+    super::runOnFunction(wasm, func);
+
+    if (refinalize) {
+      ReFinalize().walkFunctionInModule(func, wasm);
+    }
+  }
+
+  void visitLocalGet(LocalGet* curr) {
+    if (localTypes[curr->index] != curr->type) {
+      curr->type = localTypes[curr->index];
+      refinalize = true;
+    }
+  }
+
+  void visitLocalSet(LocalSet* curr) {
+    if (curr->isTee() && localTypes[curr->index] != curr->type) {
+      curr->type = localTypes[curr->index];
+      refinalize = true;
+    }
+  }
+};
+
+} // anonymous namespace
+
+Pass* createTypeGeneralizingPass() { return new TypeGeneralizing; }
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index a7f85e595..b2b9551c9 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -512,6 +512,9 @@ void PassRegistry::registerPasses() {
   registerTestPass("catch-pop-fixup",
                    "fixup nested pops within catches",
                    createCatchPopFixupPass);
+  registerTestPass("experimental-type-generalizing",
+                   "generalize types (not yet sound)",
+                   createTypeGeneralizingPass);
 }
 
 void PassRunner::addIfNoDWARFIssues(std::string passName) {
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 041e828a7..3ce2af4fd 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -155,6 +155,7 @@ Pass* createSSAifyNoMergePass();
 Pass* createTrapModeClamp();
 Pass* createTrapModeJS();
 Pass* createTupleOptimizationPass();
+Pass* createTypeGeneralizingPass();
 Pass* createTypeRefiningPass();
 Pass* createTypeFinalizingPass();
 Pass* createTypeMergingPass();
diff --git a/src/wasm-type.h b/src/wasm-type.h
index 0061b9626..573cd9102 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -390,6 +390,9 @@ public:
   // Get the bottom heap type for this heap type's hierarchy.
   BasicHeapType getBottom() const;
 
+  // Get the top heap type for this heap type's hierarchy.
+  BasicHeapType getTop() const;
+
   // Get the recursion group for this non-basic type.
   RecGroup getRecGroup() const;
   size_t getRecGroupIndex() const;
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index cab68d00d..dce0eb645 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -1386,6 +1386,20 @@ HeapType::BasicHeapType HeapType::getBottom() const {
   WASM_UNREACHABLE("unexpected kind");
 }
 
+HeapType::BasicHeapType HeapType::getTop() const {
+  switch (getBottom()) {
+    case none:
+      return any;
+    case nofunc:
+      return func;
+    case noext:
+      return ext;
+    default:
+      break;
+  }
+  WASM_UNREACHABLE("unexpected type");
+}
+
 bool HeapType::isSubType(HeapType left, HeapType right) {
   // As an optimization, in the common case do not even construct a SubTyper.
   if (left == right) {
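
A small standalone sketch of how the new getTop helper behaves, assuming a build against Binaryen's headers (the specific heap types chosen here are arbitrary illustrations; the assertion values follow from the switch in wasm-type.cpp above):

  #include <cassert>

  #include "wasm-type.h"

  using namespace wasm;

  int main() {
    // Each hierarchy's top is determined by its bottom, mirroring the switch
    // added to wasm-type.cpp above.
    assert(HeapType(HeapType::none).getTop() == HeapType::any);
    assert(HeapType(HeapType::nofunc).getTop() == HeapType::func);
    assert(HeapType(HeapType::noext).getTop() == HeapType::ext);

    // visitFunctionEntry uses this to compute the weakest allowed type for a
    // reference-typed local: a nullable reference to the top of its
    // hierarchy. For an eqref local that is anyref.
    Type weakest(HeapType(HeapType::eq).getTop(), Nullable);
    assert(weakest == Type(HeapType::any, Nullable));
  }

Since the pass is registered via registerTestPass under the name experimental-type-generalizing, it should be invocable by that name from wasm-opt (exact flag usage is an assumption based on how Binaryen exposes registered passes), and the registration's own description warns that it is not yet sound.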