author     Thomas Lively <tlively@google.com>        2023-11-08 22:20:15 +0100
committer  GitHub <noreply@github.com>               2023-11-08 13:20:15 -0800
commit     d6df91bcd0d9a67c63e336ae05f095cbcbf68df7 (patch)
tree       02f8b9c3dc21595bde5f45a1cb330cc8bcfe0c30
parent     784960180eac208a34eb33415267d977034971df (diff)
[analysis] Add an experimental TypeGeneralizing optimization (#6080)
This new optimization will eventually weaken casts by generalizing (i.e. un-refining) their output types. If a cast is weakened enough that its output type is a supertype of its input type, the cast can be removed by OptimizeInstructions.

Unlike refining cast inputs, generalizing cast outputs can break module validation. For example, if the result of a cast is stored to a local and the cast is weakened enough that its output type is no longer a subtype of that local's type, then the local.set after the cast will no longer validate. To avoid this validation failure, this optimization would have to generalize the type of the local as well. In general, the more we can generalize the types of program locations, the more we can weaken casts of values that flow into those locations.

This initial implementation only generalizes the types of locals and does not actually weaken casts yet, but it serves as a proof of concept for the analysis required to perform the full optimization. The analysis uses the new analysis framework to perform a reverse analysis tracking type requirements for each local and reference-typed stack value in a function.

Planned and potential future work includes:

- Implementing the transfer function for all kinds of expressions.
- Tracking requirements on the dynamic types of each location to generalize allocations as well.
- Making the analysis interprocedural and generalizing the types of more program locations.
- Optimizing tuple-typed locations.
- Generalizing only those locations necessary to eliminate at least one cast (although this would make the analysis bidirectional, so it is probably better left to separate passes).
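To make the validation hazard described above concrete, here is a small hand-written sketch (the types $super and $sub and the function $example are illustrative only and do not appear in this patch): if the cast's output type is weakened below the local's declared type, the enclosing local.set stops validating unless the local's type is generalized along with it.

    (module
      (type $super (sub (struct)))
      (type $sub (sub $super (struct)))
      (func $example (param $p anyref)
        (local $l (ref null $sub))
        ;; Weakening this cast's output type to (ref null $super) breaks the
        ;; enclosing local.set unless $l is also generalized to (ref null $super).
        (local.set $l
          (ref.cast (ref null $sub)
            (local.get $p)))))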
-rw-r--r--  src/passes/CMakeLists.txt               |   1
-rw-r--r--  src/passes/TypeGeneralizing.cpp         | 475
-rw-r--r--  src/passes/pass.cpp                     |   3
-rw-r--r--  src/passes/passes.h                     |   1
-rw-r--r--  src/wasm-type.h                         |   3
-rw-r--r--  src/wasm/wasm-type.cpp                  |  14
-rw-r--r--  test/lit/passes/type-generalizing.wast  | 277
7 files changed, 774 insertions, 0 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index 2930898a9..ec57f108d 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -103,6 +103,7 @@ set(passes_SOURCES
ReorderLocals.cpp
ReReloop.cpp
TrapMode.cpp
+ TypeGeneralizing.cpp
TypeRefining.cpp
TypeMerging.cpp
TypeSSA.cpp
diff --git a/src/passes/TypeGeneralizing.cpp b/src/passes/TypeGeneralizing.cpp
new file mode 100644
index 000000000..0d811aa52
--- /dev/null
+++ b/src/passes/TypeGeneralizing.cpp
@@ -0,0 +1,475 @@
+/*
+ * Copyright 2023 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "analysis/cfg.h"
+#include "analysis/lattice.h"
+#include "analysis/lattices/inverted.h"
+#include "analysis/lattices/shared.h"
+#include "analysis/lattices/stack.h"
+#include "analysis/lattices/tuple.h"
+#include "analysis/lattices/valtype.h"
+#include "analysis/lattices/vector.h"
+#include "analysis/monotone-analyzer.h"
+#include "ir/utils.h"
+#include "pass.h"
+#include "wasm-traversal.h"
+#include "wasm.h"
+
+#define TYPE_GENERALIZING_DEBUG 0
+
+#if TYPE_GENERALIZING_DEBUG
+#define DBG(statement) statement
+#else
+#define DBG(statement)
+#endif
+
+// Generalize the types of program locations as much as possible, both to
+// eliminate unnecessarily refined types from the type section and (TODO) to
+// weaken casts that cast to unnecessarily refined types. If the casts are
+// weakened enough, they will be able to be removed by OptimizeInstructions.
+//
+// Perform a backward analysis tracking requirements on the types of program
+// locations (currently just locals and stack values) to discover how much the
+// type of each location can be generalized without breaking validation or
+// changing program behavior.
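+//
+// Illustrative sketch (not part of the original comment): in
+//
+//   (func $f (param $y eqref)
+//     (local $x i31ref)
+//     (local.set $y
+//       (local.get $x)))
+//
+// the analysis walks backward from the local.set, records the requirement
+// typeof($x) <: typeof($y) = eqref, and so can generalize $x from i31ref to
+// eqref.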
+
+namespace wasm {
+
+namespace {
+
+using namespace analysis;
+
+// We will learn stricter and stricter requirements as we perform the analysis,
+// so more specific types need to be higher up the lattice.
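+//
+// For example, joining a requirement of eqref with a requirement of i31ref
+// yields i31ref, the stricter of the two.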
+using TypeRequirement = Inverted<ValType>;
+
+// Record a type requirement for each local variable. Share the requirements
+// across basic blocks.
+using LocalTypeRequirements = Shared<Vector<TypeRequirement>>;
+
+// The type requirements for each reference-typed value on the stack at a
+// particular location.
+using ValueStackTypeRequirements = Stack<TypeRequirement>;
+
+// The full lattice used for the analysis.
+using StateLattice =
+ analysis::Tuple<LocalTypeRequirements, ValueStackTypeRequirements>;
+
+// Equip the state lattice with helpful accessors.
+struct State : StateLattice {
+ using Element = StateLattice::Element;
+
+ static constexpr int LocalsIndex = 0;
+ static constexpr int StackIndex = 1;
+
+ State(Function* func) : StateLattice{Shared{initLocals(func)}, initStack()} {}
+
+ void push(Element& elem, Type type) const noexcept {
+ stackLattice().push(stack(elem), std::move(type));
+ }
+
+ Type pop(Element& elem) const noexcept {
+ return stackLattice().pop(stack(elem));
+ }
+
+ void clearStack(Element& elem) const noexcept {
+ stack(elem) = stackLattice().getBottom();
+ }
+
+ const std::vector<Type>& getLocals(Element& elem) const noexcept {
+ return *locals(elem);
+ }
+
+ const std::vector<Type>& getLocals() const noexcept {
+ return *locals(getBottom());
+ }
+
+ Type getLocal(Element& elem, Index i) const noexcept {
+ return getLocals(elem)[i];
+ }
+
+ bool updateLocal(Element& elem, Index i, Type type) const noexcept {
+ return localsLattice().join(
+ locals(elem),
+ Vector<TypeRequirement>::SingletonElement(i, std::move(type)));
+ }
+
+private:
+ static LocalTypeRequirements initLocals(Function* func) noexcept {
+ return Shared{Vector{Inverted{ValType{}}, func->getNumLocals()}};
+ }
+
+ static ValueStackTypeRequirements initStack() noexcept {
+ return Stack{Inverted{ValType{}}};
+ }
+
+ const LocalTypeRequirements& localsLattice() const noexcept {
+ return std::get<LocalsIndex>(lattices);
+ }
+
+ const ValueStackTypeRequirements& stackLattice() const noexcept {
+ return std::get<StackIndex>(lattices);
+ }
+
+ typename LocalTypeRequirements::Element&
+ locals(Element& elem) const noexcept {
+ return std::get<LocalsIndex>(elem);
+ }
+
+ const typename LocalTypeRequirements::Element&
+ locals(const Element& elem) const noexcept {
+ return std::get<LocalsIndex>(elem);
+ }
+
+ typename ValueStackTypeRequirements::Element&
+ stack(Element& elem) const noexcept {
+ return std::get<StackIndex>(elem);
+ }
+
+ const typename ValueStackTypeRequirements::Element&
+ stack(const Element& elem) const noexcept {
+ return std::get<StackIndex>(elem);
+ }
+};
+
+struct TransferFn : OverriddenVisitor<TransferFn> {
+ Module& wasm;
+ Function* func;
+ State lattice;
+ typename State::Element* state = nullptr;
+
+ // For each local, the set of blocks we may need to re-analyze when we update
+ // the constraint on the local.
+ std::vector<std::vector<const BasicBlock*>> localDependents;
+
+ // The set of basic blocks that may depend on the result of the current
+ // transfer.
+ std::unordered_set<const BasicBlock*> currDependents;
+
+ TransferFn(Module& wasm, Function* func, CFG& cfg)
+ : wasm(wasm), func(func), lattice(func),
+ localDependents(func->getNumLocals()) {
+ // Initialize `localDependents`. Any block containing a `local.set l` may
+ // need to be re-analyzed whenever the constraint on `l` is updated.
+ auto numLocals = func->getNumLocals();
+ std::vector<std::unordered_set<const BasicBlock*>> dependentSets(numLocals);
+ for (const auto& bb : cfg) {
+ for (const auto* inst : bb) {
+ if (auto set = inst->dynCast<LocalSet>()) {
+ dependentSets[set->index].insert(&bb);
+ }
+ }
+ }
+ for (size_t i = 0, n = dependentSets.size(); i < n; ++i) {
+ localDependents[i] = std::vector<const BasicBlock*>(
+ dependentSets[i].begin(), dependentSets[i].end());
+ }
+ }
+
+ Type pop() noexcept { return lattice.pop(*state); }
+ void push(Type type) noexcept { lattice.push(*state, type); }
+ void clearStack() noexcept { lattice.clearStack(*state); }
+ Type getLocal(Index i) noexcept { return lattice.getLocal(*state, i); }
+ void updateLocal(Index i, Type type) noexcept {
+ if (lattice.updateLocal(*state, i, type)) {
+ currDependents.insert(localDependents[i].begin(),
+ localDependents[i].end());
+ }
+ }
+
+ void dumpState() {
+#if TYPE_GENERALIZING_DEBUG
+ std::cerr << "locals: ";
+ for (size_t i = 0, n = lattice.getLocals(*state).size(); i < n; ++i) {
+ if (i != 0) {
+ std::cerr << ", ";
+ }
+ std::cerr << getLocal(i);
+ }
+ std::cerr << "\nstack: ";
+ auto& stack = std::get<1>(*state);
+ for (size_t i = 0, n = stack.size(); i < n; ++i) {
+ if (i != 0) {
+ std::cerr << ", ";
+ }
+ std::cerr << stack[i];
+ }
+ std::cerr << "\n";
+#endif // TYPE_GENERALIZING_DEBUG
+ }
+
+ std::unordered_set<const BasicBlock*>
+ transfer(const BasicBlock& bb, typename State::Element& elem) noexcept {
+ DBG(std::cerr << "transferring bb " << bb.getIndex() << "\n");
+ state = &elem;
+
+ // This is a backward analysis: The constraints on a type depend on how it
+ // will be used in the future. Traverse the basic block in reverse and
+ // return the predecessors as the dependent blocks.
+ assert(currDependents.empty());
+ const auto& preds = bb.preds();
+ currDependents.insert(preds.begin(), preds.end());
+
+ dumpState();
+ if (bb.isExit()) {
+ DBG(std::cerr << "visiting exit\n");
+ visitFunctionExit();
+ dumpState();
+ }
+ for (auto it = bb.rbegin(); it != bb.rend(); ++it) {
+ DBG(std::cerr << "visiting " << ShallowExpression{*it} << "\n");
+ visit(*it);
+ dumpState();
+ }
+ if (bb.isEntry()) {
+ DBG(std::cerr << "visiting entry\n");
+ visitFunctionEntry();
+ dumpState();
+ }
+ DBG(std::cerr << "\n");
+
+ state = nullptr;
+
+ // Return the blocks that may need to be re-analyzed.
+ return std::move(currDependents);
+ }
+
+ void visitFunctionExit() {
+ // We cannot change the types of results. Push a requirement that the stack
+ // end up with the correct type.
+ if (auto result = func->getResults(); result.isRef()) {
+ push(result);
+ }
+ }
+
+ void visitFunctionEntry() {
+ // We cannot change the types of parameters, so require that they have their
+ // original types.
+ Index i = 0;
+ Index numParams = func->getNumParams();
+ Index numLocals = func->getNumLocals();
+ for (; i < numParams; ++i) {
+ updateLocal(i, func->getLocalType(i));
+ }
+ // We also cannot change the types of any other non-ref locals. For
+ // reference-typed locals, we cannot generalize beyond their top type.
+ for (Index i = numParams; i < numLocals; ++i) {
+ auto type = func->getLocalType(i);
+ // TODO: Support optimizing tuple locals.
+ if (type.isRef()) {
+ updateLocal(i, Type(type.getHeapType().getTop(), Nullable));
+ } else {
+ updateLocal(i, type);
+ }
+ }
+ }
+
+ void visitNop(Nop* curr) {}
+ void visitBlock(Block* curr) {}
+ void visitIf(If* curr) {}
+ void visitLoop(Loop* curr) {}
+ void visitBreak(Break* curr) {
+ // TODO: pop extra elements off stack, keeping only those at the top that
+ // will be sent along.
+ WASM_UNREACHABLE("TODO");
+ }
+
+ void visitSwitch(Switch* curr) {
+ // TODO: pop extra elements off stack, keeping only those at the top that
+ // will be sent along.
+ WASM_UNREACHABLE("TODO");
+ }
+
+ void visitCall(Call* curr) {
+ // TODO: pop ref types from results, push ref types from params
+ WASM_UNREACHABLE("TODO");
+ }
+
+ void visitCallIndirect(CallIndirect* curr) {
+ // TODO: pop ref types from results, push ref types from params
+ WASM_UNREACHABLE("TODO");
+ }
+
+ void visitLocalGet(LocalGet* curr) {
+ if (!curr->type.isRef()) {
+ return;
+ }
+ // Propagate the requirement on the local.get output to the local.
+ updateLocal(curr->index, pop());
+ }
+
+ void visitLocalSet(LocalSet* curr) {
+ if (!curr->value->type.isRef()) {
+ return;
+ }
+ if (curr->isTee()) {
+ // Same as the local.get.
+ updateLocal(curr->index, pop());
+ }
+ // Propagate the requirement on the local to our input value.
+ push(getLocal(curr->index));
+ }
+
+ void visitGlobalGet(GlobalGet* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitGlobalSet(GlobalSet* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitLoad(Load* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStore(Store* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitAtomicRMW(AtomicRMW* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitAtomicCmpxchg(AtomicCmpxchg* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitAtomicWait(AtomicWait* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitAtomicNotify(AtomicNotify* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitAtomicFence(AtomicFence* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitSIMDExtract(SIMDExtract* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitSIMDReplace(SIMDReplace* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitSIMDShuffle(SIMDShuffle* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitSIMDTernary(SIMDTernary* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitSIMDShift(SIMDShift* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitSIMDLoad(SIMDLoad* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) {
+ WASM_UNREACHABLE("TODO");
+ }
+ void visitMemoryInit(MemoryInit* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitDataDrop(DataDrop* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitMemoryCopy(MemoryCopy* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitMemoryFill(MemoryFill* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitConst(Const* curr) {}
+ void visitUnary(Unary* curr) {}
+ void visitBinary(Binary* curr) {}
+ void visitSelect(Select* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitDrop(Drop* curr) {
+ if (curr->type.isRef()) {
+ pop();
+ }
+ }
+ void visitReturn(Return* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitMemorySize(MemorySize* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitMemoryGrow(MemoryGrow* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitUnreachable(Unreachable* curr) {
+ // Require nothing about values flowing into an unreachable.
+ clearStack();
+ }
+ void visitPop(Pop* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRefNull(RefNull* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRefIsNull(RefIsNull* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRefFunc(RefFunc* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRefEq(RefEq* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTableGet(TableGet* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTableSet(TableSet* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTableSize(TableSize* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTableGrow(TableGrow* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTableFill(TableFill* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTableCopy(TableCopy* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTry(Try* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitThrow(Throw* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRethrow(Rethrow* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTupleMake(TupleMake* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitTupleExtract(TupleExtract* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRefI31(RefI31* curr) { pop(); }
+ void visitI31Get(I31Get* curr) { push(Type(HeapType::i31, Nullable)); }
+ void visitCallRef(CallRef* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRefTest(RefTest* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRefCast(RefCast* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitBrOn(BrOn* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStructNew(StructNew* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStructGet(StructGet* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStructSet(StructSet* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayNew(ArrayNew* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayNewData(ArrayNewData* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayNewElem(ArrayNewElem* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayNewFixed(ArrayNewFixed* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayGet(ArrayGet* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArraySet(ArraySet* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayLen(ArrayLen* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayCopy(ArrayCopy* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayFill(ArrayFill* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayInitData(ArrayInitData* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitArrayInitElem(ArrayInitElem* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitRefAs(RefAs* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringNew(StringNew* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringConst(StringConst* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringMeasure(StringMeasure* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringEncode(StringEncode* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringConcat(StringConcat* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringEq(StringEq* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringAs(StringAs* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringWTF8Advance(StringWTF8Advance* curr) {
+ WASM_UNREACHABLE("TODO");
+ }
+ void visitStringWTF16Get(StringWTF16Get* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringIterNext(StringIterNext* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringIterMove(StringIterMove* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringSliceWTF(StringSliceWTF* curr) { WASM_UNREACHABLE("TODO"); }
+ void visitStringSliceIter(StringSliceIter* curr) { WASM_UNREACHABLE("TODO"); }
+};
+
+struct TypeGeneralizing : WalkerPass<PostWalker<TypeGeneralizing>> {
+ std::vector<Type> localTypes;
+ bool refinalize = false;
+
+ bool isFunctionParallel() override { return true; }
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<TypeGeneralizing>();
+ }
+
+ void runOnFunction(Module* wasm, Function* func) override {
+ // First, remove unreachable code. If we didn't, the unreachable code could
+ // become invalid after this optimization because we do not materialize or
+ // analyze unreachable blocks.
+ PassRunner runner(getPassRunner());
+ runner.add("dce");
+ runner.runOnFunction(func);
+
+ auto cfg = CFG::fromFunction(func);
+ DBG(cfg.print(std::cerr));
+ TransferFn txfn(*wasm, func, cfg);
+ MonotoneCFGAnalyzer analyzer(txfn.lattice, txfn, cfg);
+ analyzer.evaluate();
+
+ // Optimize local types. TODO: Optimize casts as well.
+ localTypes = txfn.lattice.getLocals();
+ auto numParams = func->getNumParams();
+ for (Index i = numParams; i < localTypes.size(); ++i) {
+ func->vars[i - numParams] = localTypes[i];
+ }
+
+ // Update gets and sets accordingly.
+ super::runOnFunction(wasm, func);
+
+ if (refinalize) {
+ ReFinalize().walkFunctionInModule(func, wasm);
+ }
+ }
+
+ void visitLocalGet(LocalGet* curr) {
+ if (localTypes[curr->index] != curr->type) {
+ curr->type = localTypes[curr->index];
+ refinalize = true;
+ }
+ }
+
+ void visitLocalSet(LocalSet* curr) {
+ if (curr->isTee() && localTypes[curr->index] != curr->type) {
+ curr->type = localTypes[curr->index];
+ refinalize = true;
+ }
+ }
+};
+
+} // anonymous namespace
+
+Pass* createTypeGeneralizingPass() { return new TypeGeneralizing; }
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index a7f85e595..b2b9551c9 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -512,6 +512,9 @@ void PassRegistry::registerPasses() {
registerTestPass("catch-pop-fixup",
"fixup nested pops within catches",
createCatchPopFixupPass);
+ registerTestPass("experimental-type-generalizing",
+ "generalize types (not yet sound)",
+ createTypeGeneralizingPass);
}
void PassRunner::addIfNoDWARFIssues(std::string passName) {
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 041e828a7..3ce2af4fd 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -155,6 +155,7 @@ Pass* createSSAifyNoMergePass();
Pass* createTrapModeClamp();
Pass* createTrapModeJS();
Pass* createTupleOptimizationPass();
+Pass* createTypeGeneralizingPass();
Pass* createTypeRefiningPass();
Pass* createTypeFinalizingPass();
Pass* createTypeMergingPass();
diff --git a/src/wasm-type.h b/src/wasm-type.h
index 0061b9626..573cd9102 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -390,6 +390,9 @@ public:
// Get the bottom heap type for this heap type's hierarchy.
BasicHeapType getBottom() const;
+ // Get the top heap type for this heap type's hierarchy.
+ BasicHeapType getTop() const;
+
// Get the recursion group for this non-basic type.
RecGroup getRecGroup() const;
size_t getRecGroupIndex() const;
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index cab68d00d..dce0eb645 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -1386,6 +1386,20 @@ HeapType::BasicHeapType HeapType::getBottom() const {
WASM_UNREACHABLE("unexpected kind");
}
+HeapType::BasicHeapType HeapType::getTop() const {
+ switch (getBottom()) {
+ case none:
+ return any;
+ case nofunc:
+ return func;
+ case noext:
+ return ext;
+ default:
+ break;
+ }
+ WASM_UNREACHABLE("unexpected type");
+}
+
bool HeapType::isSubType(HeapType left, HeapType right) {
// As an optimization, in the common case do not even construct a SubTyper.
if (left == right) {
diff --git a/test/lit/passes/type-generalizing.wast b/test/lit/passes/type-generalizing.wast
new file mode 100644
index 000000000..fed327727
--- /dev/null
+++ b/test/lit/passes/type-generalizing.wast
@@ -0,0 +1,277 @@
+;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.
+
+;; RUN: foreach %s %t wasm-opt --experimental-type-generalizing -all -S -o - | filecheck %s
+
+(module
+
+ ;; CHECK: (type $0 (func (result eqref)))
+
+ ;; CHECK: (type $1 (func))
+
+ ;; CHECK: (type $2 (func (param anyref)))
+
+ ;; CHECK: (type $3 (func (param i31ref)))
+
+ ;; CHECK: (type $4 (func (param anyref eqref)))
+
+ ;; CHECK: (type $5 (func (param eqref)))
+
+ ;; CHECK: (func $unconstrained (type $1)
+ ;; CHECK-NEXT: (local $x i32)
+ ;; CHECK-NEXT: (local $y anyref)
+ ;; CHECK-NEXT: (local $z (anyref i32))
+ ;; CHECK-NEXT: (nop)
+ ;; CHECK-NEXT: )
+ (func $unconstrained
+ ;; This non-ref local should be unmodified
+ (local $x i32)
+ ;; There is no constraint on the type of this local, so make it top.
+ (local $y i31ref)
+ ;; We cannot optimize tuple locals yet, so leave it unchanged.
+ (local $z (anyref i32))
+ )
+
+ ;; CHECK: (func $implicit-return (type $0) (result eqref)
+ ;; CHECK-NEXT: (local $var eqref)
+ ;; CHECK-NEXT: (local.get $var)
+ ;; CHECK-NEXT: )
+ (func $implicit-return (result eqref)
+ ;; This will be optimized, but only to eqref because of the constraint from the
+ ;; implicit return.
+ (local $var i31ref)
+ (local.get $var)
+ )
+
+ ;; CHECK: (func $implicit-return-unreachable (type $0) (result eqref)
+ ;; CHECK-NEXT: (local $var anyref)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ (func $implicit-return-unreachable (result eqref)
+ ;; We will optimize this all the way to anyref because we don't analyze
+ ;; unreachable code. This would not validate if we didn't run DCE first.
+ (local $var i31ref)
+ (unreachable)
+ (local.get $var)
+ )
+
+ ;; CHECK: (func $if (type $0) (result eqref)
+ ;; CHECK-NEXT: (local $x eqref)
+ ;; CHECK-NEXT: (local $y eqref)
+ ;; CHECK-NEXT: (if (result eqref)
+ ;; CHECK-NEXT: (i32.const 0)
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $if (result eqref)
+ (local $x i31ref)
+ (local $y i31ref)
+ (if (result i31ref)
+ (i32.const 0)
+ ;; Require that typeof($x) <: eqref.
+ (local.get $x)
+ ;; Require that typeof($y) <: eqref.
+ (local.get $y)
+ )
+ )
+
+ ;; CHECK: (func $local-set (type $1)
+ ;; CHECK-NEXT: (local $var anyref)
+ ;; CHECK-NEXT: (local.set $var
+ ;; CHECK-NEXT: (ref.i31
+ ;; CHECK-NEXT: (i32.const 0)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $local-set
+ ;; This will be optimized to anyref.
+ (local $var i31ref)
+ ;; Require that (ref i31) <: typeof($var).
+ (local.set $var
+ (i31.new
+ (i32.const 0)
+ )
+ )
+ )
+
+ ;; CHECK: (func $local-get-set (type $2) (param $dest anyref)
+ ;; CHECK-NEXT: (local $var anyref)
+ ;; CHECK-NEXT: (local.set $dest
+ ;; CHECK-NEXT: (local.get $var)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $local-get-set (param $dest anyref)
+ ;; This will be optimized to anyref.
+ (local $var i31ref)
+ ;; Require that typeof($var) <: typeof($dest).
+ (local.set $dest
+ (local.get $var)
+ )
+ )
+
+ ;; CHECK: (func $local-get-set-unreachable (type $3) (param $dest i31ref)
+ ;; CHECK-NEXT: (local $var anyref)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ (func $local-get-set-unreachable (param $dest i31ref)
+ ;; This is not constrained by reachable code, so we will optimize it.
+ (local $var i31ref)
+ (unreachable)
+ ;; This would require that typeof($var) <: typeof($dest), except it is
+ ;; unreachable. This would not validate if we didn't run DCE first.
+ (local.set $dest
+ (local.tee $var
+ (local.get $var)
+ )
+ )
+ )
+
+ ;; CHECK: (func $local-get-set-join (type $4) (param $dest1 anyref) (param $dest2 eqref)
+ ;; CHECK-NEXT: (local $var eqref)
+ ;; CHECK-NEXT: (local.set $dest1
+ ;; CHECK-NEXT: (local.get $var)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.set $dest2
+ ;; CHECK-NEXT: (local.get $var)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $local-get-set-join (param $dest1 anyref) (param $dest2 eqref)
+ ;; This will be optimized to eqref.
+ (local $var i31ref)
+ ;; Require that typeof($var) <: typeof($dest1).
+ (local.set $dest1
+ (local.get $var)
+ )
+ ;; Also require that typeof($var) <: typeof($dest2).
+ (local.set $dest2
+ (local.get $var)
+ )
+ )
+
+ ;; CHECK: (func $local-get-set-chain (type $0) (result eqref)
+ ;; CHECK-NEXT: (local $a eqref)
+ ;; CHECK-NEXT: (local $b eqref)
+ ;; CHECK-NEXT: (local $c eqref)
+ ;; CHECK-NEXT: (local.set $b
+ ;; CHECK-NEXT: (local.get $a)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.set $c
+ ;; CHECK-NEXT: (local.get $b)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.get $c)
+ ;; CHECK-NEXT: )
+ (func $local-get-set-chain (result eqref)
+ (local $a i31ref)
+ (local $b i31ref)
+ (local $c i31ref)
+ ;; Require that typeof($a) <: typeof($b).
+ (local.set $b
+ (local.get $a)
+ )
+ ;; Require that typeof($b) <: typeof($c).
+ (local.set $c
+ (local.get $b)
+ )
+ ;; Require that typeof($c) <: eqref.
+ (local.get $c)
+ )
+
+ ;; CHECK: (func $local-get-set-chain-out-of-order (type $0) (result eqref)
+ ;; CHECK-NEXT: (local $a eqref)
+ ;; CHECK-NEXT: (local $b eqref)
+ ;; CHECK-NEXT: (local $c eqref)
+ ;; CHECK-NEXT: (local.set $c
+ ;; CHECK-NEXT: (local.get $b)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.set $b
+ ;; CHECK-NEXT: (local.get $a)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.get $c)
+ ;; CHECK-NEXT: )
+ (func $local-get-set-chain-out-of-order (result eqref)
+ (local $a i31ref)
+ (local $b i31ref)
+ (local $c i31ref)
+ ;; Require that typeof($b) <: typeof($c).
+ (local.set $c
+ (local.get $b)
+ )
+ ;; Require that typeof($a) <: typeof($b). We don't know until we evaluate the
+ ;; set above that this will constrain $a to eqref.
+ (local.set $b
+ (local.get $a)
+ )
+ ;; Require that typeof($c) <: eqref.
+ (local.get $c)
+ )
+
+ ;; CHECK: (func $local-tee (type $5) (param $dest eqref)
+ ;; CHECK-NEXT: (local $var eqref)
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (local.tee $dest
+ ;; CHECK-NEXT: (local.tee $var
+ ;; CHECK-NEXT: (ref.i31
+ ;; CHECK-NEXT: (i32.const 0)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $local-tee (param $dest eqref)
+ ;; This will be optimized to eqref.
+ (local $var i31ref)
+ (drop
+ (local.tee $dest
+ (local.tee $var
+ (i31.new
+ (i32.const 0)
+ )
+ )
+ )
+ )
+ )
+
+ ;; CHECK: (func $i31-get (type $1)
+ ;; CHECK-NEXT: (local $nullable i31ref)
+ ;; CHECK-NEXT: (local $nonnullable i31ref)
+ ;; CHECK-NEXT: (local.set $nonnullable
+ ;; CHECK-NEXT: (ref.i31
+ ;; CHECK-NEXT: (i32.const 0)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (i31.get_s
+ ;; CHECK-NEXT: (local.get $nullable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (i31.get_u
+ ;; CHECK-NEXT: (local.get $nonnullable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $i31-get
+ ;; This must stay an i31ref.
+ (local $nullable i31ref)
+ ;; We relax this one to be nullable i31ref as well.
+ (local $nonnullable (ref i31))
+ ;; Initialize the non-nullable local for validation purposes.
+ (local.set $nonnullable
+ (i31.new
+ (i32.const 0)
+ )
+ )
+ (drop
+ ;; Require that typeof($nullable) <: i31ref.
+ (i31.get_s
+ (local.get $nullable)
+ )
+ )
+ (drop
+ ;; Require that typeof($nonnullable) <: i31ref.
+ (i31.get_u
+ (local.get $nonnullable)
+ )
+ )
+ )
+)