Diffstat (limited to 'src')
-rw-r--r--  src/ir/hashed.h                  |    2
-rw-r--r--  src/ir/module-utils.h            |   34
-rw-r--r--  src/ir/utils.h                   |   96
-rw-r--r--  src/pass.h                       |   24
-rw-r--r--  src/passes/CMakeLists.txt        |    1
-rw-r--r--  src/passes/I64ToI32Lowering.cpp  |   16
-rw-r--r--  src/passes/MemoryPacking.cpp     |    2
-rw-r--r--  src/passes/Metrics.cpp           |    2
-rw-r--r--  src/passes/Print.cpp             |    5
-rw-r--r--  src/passes/PrintCallGraph.cpp    |    2
-rw-r--r--  src/passes/RemoveNonJSOps.cpp    |    6
-rw-r--r--  src/passes/StackIR.cpp           |  393
-rw-r--r--  src/passes/pass.cpp              |  138
-rw-r--r--  src/passes/passes.h              |    3
-rw-r--r--  src/support/hash.h               |   12
-rw-r--r--  src/wasm-binary.h                |  138
-rw-r--r--  src/wasm-stack.h                 | 1244
-rw-r--r--  src/wasm.h                       |   44
-rw-r--r--  src/wasm/wasm-binary.cpp         |  797
-rw-r--r--  src/wasm/wasm-validator.cpp      |    2
20 files changed, 1988 insertions, 973 deletions
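
This commit introduces Stack IR: a second, linear IR generated from Binaryen IR at the very end of the optimization pipeline. To make the new pipeline concrete before the diff itself, here is a minimal sketch of driving the passes this commit registers in pass.cpp ("generate-stack-ir", "optimize-stack-ir", "print-stack-ir") through the existing PassRunner API; the wrapper function and hard-coded option values are illustrative, not part of the diff:

    #include "pass.h"
    #include "wasm.h"

    using namespace wasm;

    // Sketch: run the Stack IR pipeline added by this commit on a module.
    // addDefaultGlobalOptimizationPostPasses() below schedules the first two
    // passes automatically when optimizeLevel >= 2 or shrinkLevel >= 1.
    void runStackIRPipeline(Module& module) {
      PassRunner runner(&module);
      runner.options.optimizeLevel = 3; // >= 3 (or shrinkLevel >= 1) also enables local2Stack
      runner.add("generate-stack-ir");  // Binaryen IR -> Stack IR, per function
      runner.add("optimize-stack-ir");  // dce, local2Stack, removeUnneededBlocks
      runner.add("print-stack-ir");     // debugging aid, also added in this commit
      runner.run();
    }
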
diff --git a/src/ir/hashed.h b/src/ir/hashed.h
index 0771da6ec..2aaa98db2 100644
--- a/src/ir/hashed.h
+++ b/src/ir/hashed.h
@@ -58,8 +58,6 @@ class HashedExpressionMap : public std::unordered_map<HashedExpression, T, Expre
 struct FunctionHasher : public WalkerPass<PostWalker<FunctionHasher>> {
   bool isFunctionParallel() override { return true; }
 
-  typedef uint32_t HashType;
-
   struct Map : public std::map<Function*, HashType> {};
 
   FunctionHasher(Map* output) : output(output) {}
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index 83625809f..4205514e5 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -52,6 +52,23 @@ struct BinaryIndexes {
   }
 };
 
+inline Function* copyFunction(Function* func, Module& out) {
+  auto* ret = new Function();
+  ret->name = func->name;
+  ret->result = func->result;
+  ret->params = func->params;
+  ret->vars = func->vars;
+  ret->type = Name(); // start with no named type; the names in the other module may differ
+  ret->localNames = func->localNames;
+  ret->localIndices = func->localIndices;
+  ret->debugLocations = func->debugLocations;
+  ret->body = ExpressionManipulator::copy(func->body, out);
+  // TODO: copy Stack IR
+  assert(!func->stackIR);
+  out.addFunction(ret);
+  return ret;
+}
+
 inline void copyModule(Module& in, Module& out) {
   // we use names throughout, not raw pointers, so simple copying is fine
   // for everything *but* expressions
@@ -65,9 +82,7 @@ inline void copyModule(Module& in, Module& out) {
     out.addExport(new Export(*curr));
   }
   for (auto& curr : in.functions) {
-    auto* func = new Function(*curr);
-    func->body = ExpressionManipulator::copy(func->body, out);
-    out.addFunction(func);
+    copyFunction(curr.get(), out);
   }
   for (auto& curr : in.globals) {
     out.addGlobal(new Global(*curr));
@@ -85,19 +100,6 @@ inline void copyModule(Module& in, Module& out) {
   out.debugInfoFileNames = in.debugInfoFileNames;
 }
 
-inline Function* copyFunction(Module& in, Module& out, Name name) {
-  Function *ret = out.getFunctionOrNull(name);
-  if (ret != nullptr) {
-    return ret;
-  }
-  auto* curr = in.getFunction(name);
-  auto* func = new Function(*curr);
-  func->body = ExpressionManipulator::copy(func->body, out);
-  func->type = Name();
-  out.addFunction(func);
-  return func;
-}
-
 } // namespace ModuleUtils
 
 } // namespace wasm
diff --git a/src/ir/utils.h b/src/ir/utils.h
index fa08122b5..92bfcdab3 100644
--- a/src/ir/utils.h
+++ b/src/ir/utils.h
@@ -88,7 +88,7 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<R
 
   std::map<Name, Type> breakValues;
 
-  void visitBlock(Block *curr) {
+  void visitBlock(Block* curr) {
     if (curr->list.size() == 0) {
       curr->type = none;
       return;
@@ -129,13 +129,13 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<R
       }
     }
   }
-  void visitIf(If *curr) { curr->finalize(); }
-  void visitLoop(Loop *curr) { curr->finalize(); }
-  void visitBreak(Break *curr) {
+  void visitIf(If* curr) { curr->finalize(); }
+  void visitLoop(Loop* curr) { curr->finalize(); }
+  void visitBreak(Break* curr) {
     curr->finalize();
     updateBreakValueType(curr->name, getValueType(curr->value));
   }
-  void visitSwitch(Switch *curr) {
+  void visitSwitch(Switch* curr) {
     curr->finalize();
     auto valueType = getValueType(curr->value);
     for (auto target : curr->targets) {
@@ -143,28 +143,28 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<R
     }
     updateBreakValueType(curr->default_, valueType);
   }
-  void visitCall(Call *curr) { curr->finalize(); }
-  void visitCallImport(CallImport *curr) {
curr->finalize(); } - void visitCallIndirect(CallIndirect *curr) { curr->finalize(); } - void visitGetLocal(GetLocal *curr) { curr->finalize(); } - void visitSetLocal(SetLocal *curr) { curr->finalize(); } - void visitGetGlobal(GetGlobal *curr) { curr->finalize(); } - void visitSetGlobal(SetGlobal *curr) { curr->finalize(); } - void visitLoad(Load *curr) { curr->finalize(); } - void visitStore(Store *curr) { curr->finalize(); } - void visitAtomicRMW(AtomicRMW *curr) { curr->finalize(); } - void visitAtomicCmpxchg(AtomicCmpxchg *curr) { curr->finalize(); } + void visitCall(Call* curr) { curr->finalize(); } + void visitCallImport(CallImport* curr) { curr->finalize(); } + void visitCallIndirect(CallIndirect* curr) { curr->finalize(); } + void visitGetLocal(GetLocal* curr) { curr->finalize(); } + void visitSetLocal(SetLocal* curr) { curr->finalize(); } + void visitGetGlobal(GetGlobal* curr) { curr->finalize(); } + void visitSetGlobal(SetGlobal* curr) { curr->finalize(); } + void visitLoad(Load* curr) { curr->finalize(); } + void visitStore(Store* curr) { curr->finalize(); } + void visitAtomicRMW(AtomicRMW* curr) { curr->finalize(); } + void visitAtomicCmpxchg(AtomicCmpxchg* curr) { curr->finalize(); } void visitAtomicWait(AtomicWait* curr) { curr->finalize(); } void visitAtomicWake(AtomicWake* curr) { curr->finalize(); } - void visitConst(Const *curr) { curr->finalize(); } - void visitUnary(Unary *curr) { curr->finalize(); } - void visitBinary(Binary *curr) { curr->finalize(); } - void visitSelect(Select *curr) { curr->finalize(); } - void visitDrop(Drop *curr) { curr->finalize(); } - void visitReturn(Return *curr) { curr->finalize(); } - void visitHost(Host *curr) { curr->finalize(); } - void visitNop(Nop *curr) { curr->finalize(); } - void visitUnreachable(Unreachable *curr) { curr->finalize(); } + void visitConst(Const* curr) { curr->finalize(); } + void visitUnary(Unary* curr) { curr->finalize(); } + void visitBinary(Binary* curr) { curr->finalize(); } + void visitSelect(Select* curr) { curr->finalize(); } + void visitDrop(Drop* curr) { curr->finalize(); } + void visitReturn(Return* curr) { curr->finalize(); } + void visitHost(Host* curr) { curr->finalize(); } + void visitNop(Nop* curr) { curr->finalize(); } + void visitUnreachable(Unreachable* curr) { curr->finalize(); } void visitFunction(Function* curr) { // we may have changed the body from unreachable to none, which might be bad @@ -197,33 +197,33 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<R // Re-finalize a single node. 
This is slow, if you want to refinalize // an entire ast, use ReFinalize struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> { - void visitBlock(Block *curr) { curr->finalize(); } - void visitIf(If *curr) { curr->finalize(); } - void visitLoop(Loop *curr) { curr->finalize(); } - void visitBreak(Break *curr) { curr->finalize(); } - void visitSwitch(Switch *curr) { curr->finalize(); } - void visitCall(Call *curr) { curr->finalize(); } - void visitCallImport(CallImport *curr) { curr->finalize(); } - void visitCallIndirect(CallIndirect *curr) { curr->finalize(); } - void visitGetLocal(GetLocal *curr) { curr->finalize(); } - void visitSetLocal(SetLocal *curr) { curr->finalize(); } - void visitGetGlobal(GetGlobal *curr) { curr->finalize(); } - void visitSetGlobal(SetGlobal *curr) { curr->finalize(); } - void visitLoad(Load *curr) { curr->finalize(); } - void visitStore(Store *curr) { curr->finalize(); } + void visitBlock(Block* curr) { curr->finalize(); } + void visitIf(If* curr) { curr->finalize(); } + void visitLoop(Loop* curr) { curr->finalize(); } + void visitBreak(Break* curr) { curr->finalize(); } + void visitSwitch(Switch* curr) { curr->finalize(); } + void visitCall(Call* curr) { curr->finalize(); } + void visitCallImport(CallImport* curr) { curr->finalize(); } + void visitCallIndirect(CallIndirect* curr) { curr->finalize(); } + void visitGetLocal(GetLocal* curr) { curr->finalize(); } + void visitSetLocal(SetLocal* curr) { curr->finalize(); } + void visitGetGlobal(GetGlobal* curr) { curr->finalize(); } + void visitSetGlobal(SetGlobal* curr) { curr->finalize(); } + void visitLoad(Load* curr) { curr->finalize(); } + void visitStore(Store* curr) { curr->finalize(); } void visitAtomicRMW(AtomicRMW* curr) { curr->finalize(); } void visitAtomicCmpxchg(AtomicCmpxchg* curr) { curr->finalize(); } void visitAtomicWait(AtomicWait* curr) { curr->finalize(); } void visitAtomicWake(AtomicWake* curr) { curr->finalize(); } - void visitConst(Const *curr) { curr->finalize(); } - void visitUnary(Unary *curr) { curr->finalize(); } - void visitBinary(Binary *curr) { curr->finalize(); } - void visitSelect(Select *curr) { curr->finalize(); } - void visitDrop(Drop *curr) { curr->finalize(); } - void visitReturn(Return *curr) { curr->finalize(); } - void visitHost(Host *curr) { curr->finalize(); } - void visitNop(Nop *curr) { curr->finalize(); } - void visitUnreachable(Unreachable *curr) { curr->finalize(); } + void visitConst(Const* curr) { curr->finalize(); } + void visitUnary(Unary* curr) { curr->finalize(); } + void visitBinary(Binary* curr) { curr->finalize(); } + void visitSelect(Select* curr) { curr->finalize(); } + void visitDrop(Drop* curr) { curr->finalize(); } + void visitReturn(Return* curr) { curr->finalize(); } + void visitHost(Host* curr) { curr->finalize(); } + void visitNop(Nop* curr) { curr->finalize(); } + void visitUnreachable(Unreachable* curr) { curr->finalize(); } void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); } void visitImport(Import* curr) { WASM_UNREACHABLE(); } diff --git a/src/pass.h b/src/pass.h index 30e16761d..384dcf791 100644 --- a/src/pass.h +++ b/src/pass.h @@ -76,6 +76,10 @@ struct PassOptions { ret.setDefaultOptimizationOptions(); return ret; } + + static PassOptions getWithoutOptimization() { + return PassOptions(); // defaults are to not optimize + } }; // @@ -137,6 +141,9 @@ struct PassRunner { // Adds the default optimization passes that work on // entire modules as a whole, and make sense to // run after function passes. 
+ // This is run at the very end of the optimization + // process - you can assume no other opts will be run + // afterwards. void addDefaultGlobalOptimizationPostPasses(); // Run the passes on the module @@ -174,7 +181,16 @@ protected: private: void doAdd(Pass* pass); + void runPass(Pass* pass); void runPassOnFunction(Pass* pass, Function* func); + + // After running a pass, handle any changes due to + // how the pass is defined, such as clearing away any + // temporary data structures that the pass declares it + // invalidates. + // If a function is passed, we operate just on that function; + // otherwise, the whole module. + void handleAfterEffects(Pass* pass, Function* func=nullptr); }; // @@ -223,6 +239,14 @@ public: // this will create the parent class. virtual Pass* create() { WASM_UNREACHABLE(); } + // Whether this pass modifies the Binaryen IR in the module. This is true for + // most passes, except for passes that have no side effects, or passes that + // only modify other things than Binaryen IR (for example, the Stack IR + // passes only modify that IR). + // This property is important as if Binaryen IR is modified, we need to throw + // out any Stack IR - it would need to be regenerated and optimized. + virtual bool modifiesBinaryenIR() { return true; } + std::string name; protected: diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index bbb2a4610..25a1828dc 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -32,6 +32,7 @@ SET(passes_SOURCES Precompute.cpp Print.cpp PrintCallGraph.cpp + StackIR.cpp RedundantSetElimination.cpp RelooperJumpThreading.cpp ReReloop.cpp diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp index e501107bd..2986deb9f 100644 --- a/src/passes/I64ToI32Lowering.cpp +++ b/src/passes/I64ToI32Lowering.cpp @@ -27,6 +27,7 @@ #include "emscripten-optimizer/istring.h" #include "support/name.h" #include "wasm-builder.h" +#include "ir/module-utils.h" #include "ir/names.h" #include "asmjs/shared-constants.h" @@ -143,19 +144,20 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { highBitVars.clear(); labelHighBitVars.clear(); freeTemps.clear(); - Function oldFunc(*func); + Module temp; + auto* oldFunc = ModuleUtils::copyFunction(func, temp); func->params.clear(); func->vars.clear(); func->localNames.clear(); func->localIndices.clear(); Index newIdx = 0; - Names::ensureNames(&oldFunc); - for (Index i = 0; i < oldFunc.getNumLocals(); ++i) { - assert(oldFunc.hasLocalName(i)); - Name lowName = oldFunc.getLocalName(i); + Names::ensureNames(oldFunc); + for (Index i = 0; i < oldFunc->getNumLocals(); ++i) { + assert(oldFunc->hasLocalName(i)); + Name lowName = oldFunc->getLocalName(i); Name highName = makeHighName(lowName); - Type paramType = oldFunc.getLocalType(i); - auto builderFunc = (i < oldFunc.getVarIndexBase()) ? + Type paramType = oldFunc->getLocalType(i); + auto builderFunc = (i < oldFunc->getVarIndexBase()) ? 
Builder::addParam : static_cast<Index (*)(Function*, Name, Type)>(Builder::addVar); if (paramType == i64) { diff --git a/src/passes/MemoryPacking.cpp b/src/passes/MemoryPacking.cpp index 1ba004886..c7b20c582 100644 --- a/src/passes/MemoryPacking.cpp +++ b/src/passes/MemoryPacking.cpp @@ -24,6 +24,8 @@ namespace wasm { const Index OVERHEAD = 8; struct MemoryPacking : public Pass { + bool modifiesBinaryenIR() override { return false; } + void run(PassRunner* runner, Module* module) override { if (!module->memory.exists) return; std::vector<Memory::Segment> packed; diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp index 8cbf96db5..81706042b 100644 --- a/src/passes/Metrics.cpp +++ b/src/passes/Metrics.cpp @@ -32,6 +32,8 @@ static Counts lastCounts; // Prints metrics between optimization passes. struct Metrics : public WalkerPass<PostWalker<Metrics, UnifiedExpressionVisitor<Metrics>>> { + bool modifiesBinaryenIR() override { return false; } + bool byFunction; Counts counts; diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index b1082dcc1..dfdfa46d4 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -725,6 +725,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } o << " (; " << functionIndexes[curr->name] << " ;)"; } + if (curr->stackIR && !minify) { + o << " (; has Stack IR ;)"; + } if (curr->type.is()) { o << maybeSpace << "(type " << curr->type << ')'; } @@ -888,6 +891,8 @@ public: Printer() : o(std::cout) {} Printer(std::ostream* o) : o(*o) {} + bool modifiesBinaryenIR() override { return false; } + void run(PassRunner* runner, Module* module) override { PrintSExpression print(o); print.visitModule(module); diff --git a/src/passes/PrintCallGraph.cpp b/src/passes/PrintCallGraph.cpp index ac11dfb8b..fa58e3859 100644 --- a/src/passes/PrintCallGraph.cpp +++ b/src/passes/PrintCallGraph.cpp @@ -29,6 +29,8 @@ namespace wasm { struct PrintCallGraph : public Pass { + bool modifiesBinaryenIR() override { return false; } + void run(PassRunner* runner, Module* module) override { std::ostream &o = std::cout; o << "digraph call {\n" diff --git a/src/passes/RemoveNonJSOps.cpp b/src/passes/RemoveNonJSOps.cpp index 76c9528cb..ef8c7531c 100644 --- a/src/passes/RemoveNonJSOps.cpp +++ b/src/passes/RemoveNonJSOps.cpp @@ -89,7 +89,11 @@ struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> { // copy we then walk the function to rewrite any non-js operations it has // as well. for (auto &name : neededFunctions) { - doWalkFunction(ModuleUtils::copyFunction(intrinsicsModule, *module, name)); + auto* func = module->getFunctionOrNull(name); + if (!func) { + func = ModuleUtils::copyFunction(intrinsicsModule.getFunction(name), *module); + } + doWalkFunction(func); } neededFunctions.clear(); } diff --git a/src/passes/StackIR.cpp b/src/passes/StackIR.cpp new file mode 100644 index 000000000..43c95608e --- /dev/null +++ b/src/passes/StackIR.cpp @@ -0,0 +1,393 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Operations on Stack IR. +// + +#include "wasm.h" +#include "pass.h" +#include "wasm-stack.h" +#include "ir/iteration.h" +#include "ir/local-graph.h" + +namespace wasm { + +// Generate Stack IR from Binaryen IR + +struct GenerateStackIR : public WalkerPass<PostWalker<GenerateStackIR>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new GenerateStackIR; } + + bool modifiesBinaryenIR() override { return false; } + + void doWalkFunction(Function* func) { + BufferWithRandomAccess buffer; + // a shim for the parent that a stackWriter expects - we don't need + // it to do anything, as we are just writing to Stack IR + struct Parent { + Module* module; + Parent(Module* module) : module(module) {} + + Module* getModule() { + return module; + } + void writeDebugLocation(Expression* curr, Function* func) { + WASM_UNREACHABLE(); + } + Index getFunctionIndex(Name name) { + WASM_UNREACHABLE(); + } + Index getFunctionTypeIndex(Name name) { + WASM_UNREACHABLE(); + } + Index getGlobalIndex(Name name) { + WASM_UNREACHABLE(); + } + } parent(getModule()); + StackWriter<StackWriterMode::Binaryen2Stack, Parent> stackWriter(parent, buffer, false); + stackWriter.setFunction(func); + stackWriter.visitPossibleBlockContents(func->body); + func->stackIR = make_unique<StackIR>(); + func->stackIR->swap(stackWriter.stackIR); + } +}; + +Pass* createGenerateStackIRPass() { + return new GenerateStackIR(); +} + +// Print (for debugging purposes) + +struct PrintStackIR : public WalkerPass<PostWalker<PrintStackIR>> { + // Not parallel: this pass is just for testing and debugging; keep the output + // sorted by function order. + bool isFunctionParallel() override { return false; } + + Pass* create() override { return new PrintStackIR; } + + bool modifiesBinaryenIR() override { return false; } + + void doWalkFunction(Function* func) { + std::cout << func->name << ":\n"; + if (func->stackIR) { + std::cout << *func->stackIR; + } else { + std::cout << " (no stack ir)"; + } + std::cout << '\n'; + } +}; + +Pass* createPrintStackIRPass() { + return new PrintStackIR(); +} + +// Optimize + +class StackIROptimizer { + Function* func; + PassOptions& passOptions; + StackIR& insts; + +public: + StackIROptimizer(Function* func, PassOptions& passOptions) : + func(func), passOptions(passOptions), insts(*func->stackIR.get()) { + assert(func->stackIR); + } + + void run() { + dce(); + // FIXME: local2Stack is currently rather slow (due to localGraph), + // so for now run it only when really optimizing + if (passOptions.optimizeLevel >= 3 || passOptions.shrinkLevel >= 1) { + local2Stack(); + } + removeUnneededBlocks(); + dce(); + } + +private: + // Passes. + + // Remove unreachable code. + void dce() { + bool inUnreachableCode = false; + for (Index i = 0; i < insts.size(); i++) { + auto* inst = insts[i]; + if (!inst) continue; + if (inUnreachableCode) { + // Does the unreachable code end here? + if (isControlFlowBarrier(inst)) { + inUnreachableCode = false; + } else { + // We can remove this. + removeAt(i); + } + } else if (inst->type == unreachable) { + inUnreachableCode = true; + } + } + } + + // If ordered properly, we can avoid a set_local/get_local pair, + // and use the value directly from the stack, for example + // [..produce a value on the stack..] + // set_local $x + // [..much code..] 
+ // get_local $x + // call $foo ;; use the value, foo(value) + // As long as the code in between does not modify $x, and has + // no control flow branching out, we can remove both the set + // and the get. + void local2Stack() { + // We use the localGraph to tell us if a get-set pair is indeed + // a set that is read by that get, and only that get. Note that we run + // this on the Binaryen IR, so we are assuming that no previous opt + // has changed the interaction of local operations. + // TODO: we can do this a lot faster, as we just care about linear + // control flow. + LocalGraph localGraph(func); + localGraph.computeInfluences(); + // We maintain a stack of relevant values. This contains: + // * a null for each actual value that the value stack would have + // * an index of each SetLocal that *could* be on the value + // stack at that location. + const Index null = -1; + std::vector<Index> values; + // We also maintain a stack of values vectors for control flow, + // saving the stack as we enter and restoring it when we exit. + std::vector<std::vector<Index>> savedValues; +#ifdef STACK_OPT_DEBUG + std::cout << "func: " << func->name << '\n' << insts << '\n'; +#endif + for (Index i = 0; i < insts.size(); i++) { + auto* inst = insts[i]; + if (!inst) continue; + // First, consume values from the stack as required. + auto consumed = getNumConsumedValues(inst); +#ifdef STACK_OPT_DEBUG + std::cout << " " << i << " : " << *inst << ", " << values.size() << " on stack, will consume " << consumed << "\n "; + for (auto s : values) std::cout << s << ' '; + std::cout << '\n'; +#endif + // TODO: currently we run dce before this, but if we didn't, we'd need + // to handle unreachable code here - it's ok to pop multiple values + // there even if the stack is at size 0. + while (consumed > 0) { + assert(values.size() > 0); + // Whenever we hit a possible stack value, kill it - it would + // be consumed here, so we can never optimize to it. + while (values.back() != null) { + values.pop_back(); + assert(values.size() > 0); + } + // Finally, consume the actual value that is consumed here. + values.pop_back(); + consumed--; + } + // After consuming, we can see what to do with this. First, handle + // control flow. + if (isControlFlowBegin(inst)) { + // Save the stack for when we end this control flow. + savedValues.push_back(values); // TODO: optimize copies + values.clear(); + } else if (isControlFlowEnd(inst)) { + assert(!savedValues.empty()); + values = savedValues.back(); + savedValues.pop_back(); + } else if (isControlFlow(inst)) { + // Otherwise, in the middle of control flow, just clear it + values.clear(); + } + // This is something we should handle, look into it. + if (isConcreteType(inst->type)) { + bool optimized = false; + if (auto* get = inst->origin->dynCast<GetLocal>()) { + // This is a potential optimization opportunity! See if we + // can reach the set. + if (values.size() > 0) { + Index j = values.size() - 1; + while (1) { + // If there's an actual value in the way, we've failed. + auto index = values[j]; + if (index == null) break; + auto* set = insts[index]->origin->cast<SetLocal>(); + if (set->index == get->index) { + // This might be a proper set-get pair, where the set is + // used by this get and nothing else, check that. + auto& sets = localGraph.getSetses[get]; + if (sets.size() == 1 && *sets.begin() == set) { + auto& setInfluences = localGraph.setInfluences[set]; + if (setInfluences.size() == 1) { + assert(*setInfluences.begin() == get); + // Do it! 
The set and the get can go away, the proper + // value is on the stack. +#ifdef STACK_OPT_DEBUG + std::cout << " stackify the get\n"; +#endif + insts[index] = nullptr; + insts[i] = nullptr; + // Continuing on from here, replace this on the stack + // with a null, representing a regular value. We + // keep possible values above us active - they may + // be optimized later, as they would be pushed after + // us, and used before us, so there is no conflict. + values[j] = null; + optimized = true; + break; + } + } + } + // We failed here. Can we look some more? + if (j == 0) break; + j--; + } + } + } + if (!optimized) { + // This is an actual regular value on the value stack. + values.push_back(null); + } + } else if (inst->origin->is<SetLocal>() && inst->type == none) { + // This set is potentially optimizable later, add to stack. + values.push_back(i); + } + } + } + + // There may be unnecessary blocks we can remove: blocks + // without branches to them are always ok to remove. + // TODO: a branch to a block in an if body can become + // a branch to that if body + void removeUnneededBlocks() { + for (auto*& inst : insts) { + if (!inst) continue; + if (auto* block = inst->origin->dynCast<Block>()) { + if (!BranchUtils::BranchSeeker::hasNamed(block, block->name)) { + // TODO optimize, maybe run remove-unused-names + inst = nullptr; + } + } + } + } + + // Utilities. + + // A control flow "barrier" - a point where stack machine + // unreachability ends. + bool isControlFlowBarrier(StackInst* inst) { + switch (inst->op) { + case StackInst::BlockEnd: + case StackInst::IfElse: + case StackInst::IfEnd: + case StackInst::LoopEnd: { + return true; + } + default: { + return false; + } + } + } + + // A control flow beginning. + bool isControlFlowBegin(StackInst* inst) { + switch (inst->op) { + case StackInst::BlockBegin: + case StackInst::IfBegin: + case StackInst::LoopBegin: { + return true; + } + default: { + return false; + } + } + } + + // A control flow ending. + bool isControlFlowEnd(StackInst* inst) { + switch (inst->op) { + case StackInst::BlockEnd: + case StackInst::IfEnd: + case StackInst::LoopEnd: { + return true; + } + default: { + return false; + } + } + } + + bool isControlFlow(StackInst* inst) { + return inst->op != StackInst::Basic; + } + + // Remove the instruction at index i. If the instruction + // is control flow, and so has been expanded to multiple + // instructions, remove them as well. + void removeAt(Index i) { + auto* inst = insts[i]; + insts[i] = nullptr; + if (inst->op == StackInst::Basic) { + return; // that was it + } + auto* origin = inst->origin; + while (1) { + i++; + assert(i < insts.size()); + inst = insts[i]; + insts[i] = nullptr; + if (inst && inst->origin == origin && isControlFlowEnd(inst)) { + return; // that's it, we removed it all + } + } + } + + Index getNumConsumedValues(StackInst* inst) { + if (isControlFlow(inst)) { + // If consumes 1; that's it. + if (inst->op == StackInst::IfBegin) { + return 1; + } + return 0; + } + // Otherwise, for basic instructions, just count the expression children. 
+ return ChildIterator(inst->origin).children.size(); + } +}; + +struct OptimizeStackIR : public WalkerPass<PostWalker<OptimizeStackIR>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new OptimizeStackIR; } + + bool modifiesBinaryenIR() override { return false; } + + void doWalkFunction(Function* func) { + if (!func->stackIR) { + return; + } + StackIROptimizer(func, getPassOptions()).run(); + } +}; + +Pass* createOptimizeStackIRPass() { + return new OptimizeStackIR(); +} + +} // namespace wasm + diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index c0354524d..97151f847 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -22,6 +22,7 @@ #include <pass.h> #include <wasm-validator.h> #include <wasm-io.h> +#include "ir/hashed.h" namespace wasm { @@ -76,6 +77,7 @@ void PassRegistry::registerPasses() { registerPass("flatten", "flattens out code, removing nesting", createFlattenPass); registerPass("fpcast-emu", "emulates function pointer casts, allowing incorrect indirect calls to (sometimes) work", createFuncCastEmulationPass); registerPass("func-metrics", "reports function metrics", createFunctionMetricsPass); + registerPass("generate-stack-ir", "generate Stack IR", createGenerateStackIRPass); registerPass("inlining", "inline functions (you probably want inlining-optimizing)", createInliningPass); registerPass("inlining-optimizing", "inline functions and optimizes where we inlined", createInliningOptimizingPass); registerPass("legalize-js-interface", "legalizes i64 types on the import/export boundary", createLegalizeJSInterfacePass); @@ -90,6 +92,7 @@ void PassRegistry::registerPasses() { registerPass("metrics", "reports metrics", createMetricsPass); registerPass("nm", "name list", createNameListPass); registerPass("optimize-instructions", "optimizes instruction combinations", createOptimizeInstructionsPass); + registerPass("optimize-stack-ir", "optimize Stack IR", createOptimizeStackIRPass); registerPass("pick-load-signs", "pick load signs based on their uses", createPickLoadSignsPass); registerPass("post-emscripten", "miscellaneous optimizations for Emscripten-generated code", createPostEmscriptenPass); registerPass("precompute", "computes compile-time evaluatable expressions", createPrecomputePass); @@ -98,6 +101,7 @@ void PassRegistry::registerPasses() { registerPass("print-minified", "print in minified s-expression format", createMinifiedPrinterPass); registerPass("print-full", "print in full s-expression format", createFullPrinterPass); registerPass("print-call-graph", "print call graph", createPrintCallGraphPass); + registerPass("print-stack-ir", "print out Stack IR (useful for internal debugging)", createPrintStackIRPass); registerPass("relooper-jump-threading", "thread relooper jumps (fastcomp output only)", createRelooperJumpThreadingPass); registerPass("remove-non-js-ops", "removes operations incompatible with js", createRemoveNonJSOpsPass); registerPass("remove-imports", "removes imports and replaces them with nops", createRemoveImportsPass); @@ -201,6 +205,12 @@ void PassRunner::addDefaultGlobalOptimizationPostPasses() { add("duplicate-function-elimination"); // optimizations show more functions as duplicate add("remove-unused-module-elements"); add("memory-packing"); + // perform Stack IR optimizations here, at the very end of the + // optimization pipeline + if (options.optimizeLevel >= 2 || options.shrinkLevel >= 1) { + add("generate-stack-ir"); + add("optimize-stack-ir"); + } } static void dumpWast(Name name, Module* 
wasm) { @@ -252,7 +262,7 @@ void PassRunner::run() { runPassOnFunction(pass, func.get()); } } else { - pass->run(this, wasm); + runPass(pass); } auto after = std::chrono::steady_clock::now(); std::chrono::duration<double> diff = after - before; @@ -320,7 +330,7 @@ void PassRunner::run() { stack.push_back(pass); } else { flush(); - pass->run(this, wasm); + runPass(pass); } } flush(); @@ -347,11 +357,135 @@ void PassRunner::doAdd(Pass* pass) { pass->prepareToRun(this, wasm); } +// Checks that the state is valid before and after a +// pass runs on a function. We run these extra checks when +// pass-debug mode is enabled. +struct AfterEffectFunctionChecker { + Function* func; + Name name; + + // Check Stack IR state: if the main IR changes, there should be no + // stack IR, as the stack IR would be wrong. + bool beganWithStackIR; + HashType originalFunctionHash; + + // In the creator we can scan the state of the module and function before the + // pass runs. + AfterEffectFunctionChecker(Function* func) : func(func), name(func->name) { + beganWithStackIR = func->stackIR != nullptr; + if (beganWithStackIR) { + originalFunctionHash = FunctionHasher::hashFunction(func); + } + } + + // This is called after the pass is run, at which time we can check things. + void check() { + assert(func->name == name); // no global module changes should have occurred + if (beganWithStackIR && func->stackIR) { + auto after = FunctionHasher::hashFunction(func); + if (after != originalFunctionHash) { + Fatal() << "[PassRunner] PASS_DEBUG check failed: had Stack IR before and after the pass ran, and the pass modified the main IR, which invalidates Stack IR - pass should have been marked 'modifiesBinaryenIR'"; + } + } + } +}; + +// Runs checks on the entire module, in a non-function-parallel pass. +// In particular, in such a pass functions may be removed or renamed, track that. +struct AfterEffectModuleChecker { + Module* module; + + std::vector<AfterEffectFunctionChecker> checkers; + + bool beganWithAnyStackIR; + + AfterEffectModuleChecker(Module* module) : module(module) { + for (auto& func : module->functions) { + checkers.emplace_back(func.get()); + } + beganWithAnyStackIR = hasAnyStackIR(); + } + + void check() { + if (beganWithAnyStackIR && hasAnyStackIR()) { + // If anything changed to the functions, that's not good. + if (checkers.size() != module->functions.size()) { + error(); + } + for (Index i = 0; i < checkers.size(); i++) { + // Did a pointer change? (a deallocated function could cause that) + if (module->functions[i].get() != checkers[i].func || + module->functions[i]->body != checkers[i].func->body) { + error(); + } + // Did a name change? + if (module->functions[i]->name != checkers[i].name) { + error(); + } + } + // Global function state appears to not have been changed: the same + // functions are there. Look into their contents. 
+ for (auto& checker : checkers) { + checker.check(); + } + } + } + + void error() { + Fatal() << "[PassRunner] PASS_DEBUG check failed: had Stack IR before and after the pass ran, and the pass modified global function state - pass should have been marked 'modifiesBinaryenIR'"; + } + + bool hasAnyStackIR() { + for (auto& func : module->functions) { + if (func->stackIR) { + return true; + } + } + return false; + } +}; + +void PassRunner::runPass(Pass* pass) { + std::unique_ptr<AfterEffectModuleChecker> checker; + if (getPassDebug()) { + checker = std::unique_ptr<AfterEffectModuleChecker>( + new AfterEffectModuleChecker(wasm)); + } + pass->run(this, wasm); + handleAfterEffects(pass); + if (getPassDebug()) { + checker->check(); + } +} + void PassRunner::runPassOnFunction(Pass* pass, Function* func) { assert(pass->isFunctionParallel()); // function-parallel passes get a new instance per function auto instance = std::unique_ptr<Pass>(pass->create()); + std::unique_ptr<AfterEffectFunctionChecker> checker; + if (getPassDebug()) { + checker = std::unique_ptr<AfterEffectFunctionChecker>( + new AfterEffectFunctionChecker(func)); + } instance->runOnFunction(this, wasm, func); + handleAfterEffects(pass, func); + if (getPassDebug()) { + checker->check(); + } +} + +void PassRunner::handleAfterEffects(Pass* pass, Function* func) { + if (pass->modifiesBinaryenIR()) { + // If Binaryen IR is modified, Stack IR must be cleared - it would + // be out of sync in a potentially dangerous way. + if (func) { + func->stackIR.reset(nullptr); + } else { + for (auto& func : wasm->functions) { + func->stackIR.reset(nullptr); + } + } + } } int PassRunner::getPassDebug() { diff --git a/src/passes/passes.h b/src/passes/passes.h index 7a96799b3..1e26dc777 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -34,6 +34,7 @@ Pass* createFlattenPass(); Pass* createFuncCastEmulationPass(); Pass* createFullPrinterPass(); Pass* createFunctionMetricsPass(); +Pass* createGenerateStackIRPass(); Pass* createI64ToI32LoweringPass(); Pass* createInliningPass(); Pass* createInliningOptimizingPass(); @@ -49,12 +50,14 @@ Pass* createMinifiedPrinterPass(); Pass* createMetricsPass(); Pass* createNameListPass(); Pass* createOptimizeInstructionsPass(); +Pass* createOptimizeStackIRPass(); Pass* createPickLoadSignsPass(); Pass* createPostEmscriptenPass(); Pass* createPrecomputePass(); Pass* createPrecomputePropagatePass(); Pass* createPrinterPass(); Pass* createPrintCallGraphPass(); +Pass* createPrintStackIRPass(); Pass* createRelooperJumpThreadingPass(); Pass* createRemoveNonJSOpsPass(); Pass* createRemoveImportsPass(); diff --git a/src/support/hash.h b/src/support/hash.h index 158f20773..98d7ceead 100644 --- a/src/support/hash.h +++ b/src/support/hash.h @@ -22,9 +22,11 @@ namespace wasm { -inline uint32_t rehash(uint32_t x, uint32_t y) { +typedef uint32_t HashType; + +inline HashType rehash(HashType x, HashType y) { // see http://www.cse.yorku.ca/~oz/hash.html and https://stackoverflow.com/a/2595226/1176841 - uint32_t hash = 5381; + HashType hash = 5381; while (x) { hash = ((hash << 5) + hash) ^ (x & 0xff); x >>= 8; @@ -37,9 +39,9 @@ inline uint32_t rehash(uint32_t x, uint32_t y) { } inline uint64_t rehash(uint64_t x, uint64_t y) { - auto ret = rehash(uint32_t(x), uint32_t(x >> 32)); - ret = rehash(ret, uint32_t(y)); - return rehash(ret, uint32_t(y >> 32)); + auto ret = rehash(HashType(x), HashType(x >> 32)); + ret = rehash(ret, HashType(y)); + return rehash(ret, HashType(y >> 32)); } } // namespace wasm diff --git a/src/wasm-binary.h 
b/src/wasm-binary.h index cdac878c1..670084401 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -652,96 +652,14 @@ inline S32LEB binaryType(Type type) { return S32LEB(ret); } -class WasmBinaryWriter; - -// Writes out binary format stack machine code for a Binaryen IR expression - -class StackWriter : public Visitor<StackWriter> { -public: - // Without a function (offset for a global thing, etc.) - StackWriter(WasmBinaryWriter& parent, BufferWithRandomAccess& o, bool debug=false) - : func(nullptr), parent(parent), o(o), sourceMap(false), debug(debug) {} - - // With a function - one is created for the entire function - StackWriter(Function* func, WasmBinaryWriter& parent, BufferWithRandomAccess& o, bool sourceMap=false, bool debug=false) - : func(func), parent(parent), o(o), sourceMap(sourceMap), debug(debug) { - mapLocals(); - } - - std::map<Type, size_t> numLocalsByType; // type => number of locals of that type in the compact form - - // visits a node, emitting the proper code for it - void visit(Expression* curr); - // emits a node, but if it is a block with no name, emit a list of its contents - void visitPossibleBlockContents(Expression* curr); - - void visitBlock(Block *curr); - void visitIf(If *curr); - void visitLoop(Loop *curr); - void visitBreak(Break *curr); - void visitSwitch(Switch *curr); - void visitCall(Call *curr); - void visitCallImport(CallImport *curr); - void visitCallIndirect(CallIndirect *curr); - void visitGetLocal(GetLocal *curr); - void visitSetLocal(SetLocal *curr); - void visitGetGlobal(GetGlobal *curr); - void visitSetGlobal(SetGlobal *curr); - void visitLoad(Load *curr); - void visitStore(Store *curr); - void visitAtomicRMW(AtomicRMW *curr); - void visitAtomicCmpxchg(AtomicCmpxchg *curr); - void visitAtomicWait(AtomicWait *curr); - void visitAtomicWake(AtomicWake *curr); - void visitConst(Const *curr); - void visitUnary(Unary *curr); - void visitBinary(Binary *curr); - void visitSelect(Select *curr); - void visitReturn(Return *curr); - void visitHost(Host *curr); - void visitNop(Nop *curr); - void visitUnreachable(Unreachable *curr); - void visitDrop(Drop *curr); - -private: - Function* func; - WasmBinaryWriter& parent; - BufferWithRandomAccess& o; - bool sourceMap; - bool debug; - - std::map<Index, size_t> mappedLocals; // local index => index in compact form of [all int32s][all int64s]etc - - std::vector<Name> breakStack; - - int32_t getBreakIndex(Name name); - void emitMemoryAccess(size_t alignment, size_t bytes, uint32_t offset); - - void mapLocals(); -}; - // Writes out wasm to the binary format class WasmBinaryWriter { - Module* wasm; - BufferWithRandomAccess& o; - bool debug; - bool debugInfo = true; - std::ostream* sourceMap = nullptr; - std::string sourceMapUrl; - std::string symbolMap; - - MixedArena allocator; - - // storage of source map locations until the section is placed at its final location - // (shrinking LEBs may cause changes there) - std::vector<std::pair<size_t, const Function::DebugLocation*>> sourceMapLocations; - size_t sourceMapLocationsSizeAtSectionStart; - Function::DebugLocation lastDebugLocation; - - void prepare(); public: - WasmBinaryWriter(Module* input, BufferWithRandomAccess& o, bool debug = false) : wasm(input), o(o), debug(debug) { + WasmBinaryWriter(Module* input, + BufferWithRandomAccess& o, + bool debug = false) : + wasm(input), o(o), debug(debug) { prepare(); } @@ -817,6 +735,28 @@ public: void emitBuffer(const char* data, size_t size); void emitString(const char *str); void finishUp(); + + Module* getModule() { 
return wasm; } + +private: + Module* wasm; + BufferWithRandomAccess& o; + bool debug; + + bool debugInfo = true; + std::ostream* sourceMap = nullptr; + std::string sourceMapUrl; + std::string symbolMap; + + MixedArena allocator; + + // storage of source map locations until the section is placed at its final location + // (shrinking LEBs may cause changes there) + std::vector<std::pair<size_t, const Function::DebugLocation*>> sourceMapLocations; + size_t sourceMapLocationsSizeAtSectionStart; + Function::DebugLocation lastDebugLocation; + + void prepare(); }; class WasmBinaryBuilder { @@ -967,16 +907,16 @@ public: BinaryConsts::ASTNodes readExpression(Expression*& curr); void pushBlockElements(Block* curr, size_t start, size_t end); - void visitBlock(Block *curr); + void visitBlock(Block* curr); // Gets a block of expressions. If it's just one, return that singleton. Expression* getBlockOrSingleton(Type type); - void visitIf(If *curr); - void visitLoop(Loop *curr); + void visitIf(If* curr); + void visitLoop(Loop* curr); BreakTarget getBreakTarget(int32_t offset); void visitBreak(Break *curr, uint8_t code); - void visitSwitch(Switch *curr); + void visitSwitch(Switch* curr); template<typename T> void fillCall(T* call, FunctionType* type) { @@ -990,11 +930,11 @@ public: } Expression* visitCall(); - void visitCallIndirect(CallIndirect *curr); - void visitGetLocal(GetLocal *curr); + void visitCallIndirect(CallIndirect* curr); + void visitGetLocal(GetLocal* curr); void visitSetLocal(SetLocal *curr, uint8_t code); - void visitGetGlobal(GetGlobal *curr); - void visitSetGlobal(SetGlobal *curr); + void visitGetGlobal(GetGlobal* curr); + void visitSetGlobal(SetGlobal* curr); void readMemoryAccess(Address& alignment, Address& offset); bool maybeVisitLoad(Expression*& out, uint8_t code, bool isAtomic); bool maybeVisitStore(Expression*& out, uint8_t code, bool isAtomic); @@ -1005,12 +945,12 @@ public: bool maybeVisitConst(Expression*& out, uint8_t code); bool maybeVisitUnary(Expression*& out, uint8_t code); bool maybeVisitBinary(Expression*& out, uint8_t code); - void visitSelect(Select *curr); - void visitReturn(Return *curr); + void visitSelect(Select* curr); + void visitReturn(Return* curr); bool maybeVisitHost(Expression*& out, uint8_t code); - void visitNop(Nop *curr); - void visitUnreachable(Unreachable *curr); - void visitDrop(Drop *curr); + void visitNop(Nop* curr); + void visitUnreachable(Unreachable* curr); + void visitDrop(Drop* curr); void throwError(std::string text); }; diff --git a/src/wasm-stack.h b/src/wasm-stack.h new file mode 100644 index 000000000..8648148b6 --- /dev/null +++ b/src/wasm-stack.h @@ -0,0 +1,1244 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef wasm_stack_h +#define wasm_stack_h + +#include "wasm.h" +#include "wasm-binary.h" +#include "wasm-traversal.h" +#include "ir/branch-utils.h" +#include "pass.h" + +namespace wasm { + +// Stack IR: an IR that represents code at the wasm binary format level, +// that is, a stack machine. Binaryen IR is *almost* identical to this, +// but as documented in README.md, there are a few differences, intended +// to make Binaryen IR fast and flexible for maximal optimization. Stack +// IR, on the other hand, is designed to optimize a few final things that +// can only really be done when modeling the stack machine format precisely. + +// Currently the benefits of Stack IR are minor, less than 1% reduction in +// code size. For that reason it is just a secondary IR, run optionally +// after the main IR has been optimized. However, if we improve Stack IR +// optimizations to a point where they have a significant impact, it's +// possible that could motivate investigating replacing the main IR with Stack +// IR (so that we have just a single IR). + +// A StackIR instance (see wasm.h) contains a linear sequence of +// stack instructions. This representation is very simple: just a single vector of +// all instructions, in order. +// * nullptr is allowed in the vector, representing something to skip. +// This is useful as a common thing optimizations do is remove instructions, +// so this way we can do so without compacting the vector all the time. + +// A Stack IR instruction. Most just directly reflect a Binaryen IR node, +// but we need extra ones for certain things. +class StackInst { +public: + StackInst(MixedArena&) {} + + enum Op { + Basic, // an instruction directly corresponding to a non-control-flow + // Binaryen IR node + BlockBegin, // the beginning of a block + BlockEnd, // the ending of a block + IfBegin, // the beginning of a if + IfElse, // the else of a if + IfEnd, // the ending of a if + LoopBegin, // the beginning of a loop + LoopEnd, // the ending of a loop + } op; + + Expression* origin; // the expression this originates from + + Type type; // the type - usually identical to the origin type, but + // e.g. 
wasm has no unreachable blocks, they must be none +}; + +} // namespace wasm + +namespace std { + +inline std::ostream& operator<<(std::ostream& o, wasm::StackInst& inst) { + switch (inst.op) { + case wasm::StackInst::Basic: { + std::cout << wasm::getExpressionName(inst.origin) << " (" << wasm::printType(inst.type) << ')'; + break; + } + case wasm::StackInst::BlockBegin: + case wasm::StackInst::IfBegin: + case wasm::StackInst::LoopBegin: { + std::cout << wasm::getExpressionName(inst.origin); + break; + } + case wasm::StackInst::BlockEnd: + case wasm::StackInst::IfEnd: + case wasm::StackInst::LoopEnd: { + std::cout << "end (" << wasm::printType(inst.type) << ')'; + break; + } + case wasm::StackInst::IfElse: { + std::cout << "else"; + break; + } + default: WASM_UNREACHABLE(); + } + return o; +} + +inline std::ostream& operator<<(std::ostream& o, wasm::StackIR& insts) { + wasm::Index index = 0; + for (wasm::Index i = 0; i < insts.size(); i++) { + auto* inst = insts[i]; + if (!inst) continue; + std::cout << index++ << ' ' << *inst << '\n'; + } + return o; +} + +} // namespace std + +namespace wasm { + +// +// StackWriter: Writes out binary format stack machine code for a Binaryen IR expression +// +// A stack writer has one of three modes: +// * Binaryen2Binary: directly writes the expression to wasm binary +// * Binaryen2Stack: queues the expressions linearly, in Stack IR (SIR) +// * Stack2Binary: emits SIR to wasm binary +// +// Direct writing, in Binaryen2Binary, is fast. Otherwise, Binaryen2Stack +// lets you optimize the Stack IR before running Stack2Binary (but the cost +// is that the extra IR in the middle makes things 20% slower than direct +// Binaryen2Binary). +// +// To reduce the amount of boilerplate code here, we implement all 3 in +// a single class, templated on the mode. This allows compilers to trivially +// optimize out irrelevant code paths, and there should be no runtime +// downside. +// + +enum class StackWriterMode { + Binaryen2Binary, Binaryen2Stack, Stack2Binary +}; + +template<StackWriterMode Mode, typename Parent> +class StackWriter : public Visitor<StackWriter<Mode, Parent>> { +public: + StackWriter(Parent& parent, BufferWithRandomAccess& o, bool sourceMap=false, bool debug=false) + : parent(parent), o(o), sourceMap(sourceMap), debug(debug), allocator(parent.getModule()->allocator) {} + + StackIR stackIR; // filled in Binaryen2Stack, read in Stack2Binary + + std::map<Type, size_t> numLocalsByType; // type => number of locals of that type in the compact form + + // visits a node, emitting the proper code for it + void visit(Expression* curr); + // emits a node, but if it is a block with no name, emit a list of its contents + void visitPossibleBlockContents(Expression* curr); + // visits a child node. 
(in some modes we may not want to visit children, + // that logic is handled here) + void visitChild(Expression* curr); + + void visitBlock(Block* curr); + void visitBlockEnd(Block* curr); + + void visitIf(If* curr); + void visitIfElse(If* curr); + void visitIfEnd(If* curr); + + void visitLoop(Loop* curr); + void visitLoopEnd(Loop* curr); + + void visitBreak(Break* curr); + void visitSwitch(Switch* curr); + void visitCall(Call* curr); + void visitCallImport(CallImport* curr); + void visitCallIndirect(CallIndirect* curr); + void visitGetLocal(GetLocal* curr); + void visitSetLocal(SetLocal* curr); + void visitGetGlobal(GetGlobal* curr); + void visitSetGlobal(SetGlobal* curr); + void visitLoad(Load* curr); + void visitStore(Store* curr); + void visitAtomicRMW(AtomicRMW* curr); + void visitAtomicCmpxchg(AtomicCmpxchg* curr); + void visitAtomicWait(AtomicWait* curr); + void visitAtomicWake(AtomicWake* curr); + void visitConst(Const* curr); + void visitUnary(Unary* curr); + void visitBinary(Binary* curr); + void visitSelect(Select* curr); + void visitReturn(Return* curr); + void visitHost(Host* curr); + void visitNop(Nop* curr); + void visitUnreachable(Unreachable* curr); + void visitDrop(Drop* curr); + + // We need to emit extra unreachable opcodes in some cases + void emitExtraUnreachable(); + + // If we are in Binaryen2Stack, then this adds the item to the + // stack IR and returns true, which is all we need to do for + // non-control flow expressions. + bool justAddToStack(Expression* curr); + + void setFunction(Function* funcInit) { + func = funcInit; + } + + void mapLocalsAndEmitHeader(); + +protected: + Parent& parent; + BufferWithRandomAccess& o; + bool sourceMap; + bool debug; + + MixedArena& allocator; + + Function* func; + + std::map<Index, size_t> mappedLocals; // local index => index in compact form of [all int32s][all int64s]etc + + std::vector<Name> breakStack; + + int32_t getBreakIndex(Name name); + void emitMemoryAccess(size_t alignment, size_t bytes, uint32_t offset); + + void finishFunctionBody(); + + StackInst* makeStackInst(StackInst::Op op, Expression* origin); + StackInst* makeStackInst(Expression* origin) { + return makeStackInst(StackInst::Basic, origin); + } +}; + +// Write out a single expression, such as an offset for a global segment. +template<typename Parent> +class ExpressionStackWriter : StackWriter<StackWriterMode::Binaryen2Binary, Parent> { +public: + ExpressionStackWriter(Expression* curr, Parent& parent, BufferWithRandomAccess& o, bool debug=false) : + StackWriter<StackWriterMode::Binaryen2Binary, Parent>(parent, o, /* sourceMap= */ false, debug) { + this->visit(curr); + } +}; + +// Write out a function body, including the local header info. 
+template<typename Parent> +class FunctionStackWriter : StackWriter<StackWriterMode::Binaryen2Binary, Parent> { +public: + FunctionStackWriter(Function* funcInit, Parent& parent, BufferWithRandomAccess& o, bool sourceMap=false, bool debug=false) : + StackWriter<StackWriterMode::Binaryen2Binary, Parent>(parent, o, sourceMap, debug) { + this->setFunction(funcInit); + this->mapLocalsAndEmitHeader(); + this->visitPossibleBlockContents(this->func->body); + this->finishFunctionBody(); + } +}; + +// Use Stack IR to write the function body +template<typename Parent> +class StackIRFunctionStackWriter : StackWriter<StackWriterMode::Stack2Binary, Parent> { +public: + StackIRFunctionStackWriter(Function* funcInit, Parent& parent, BufferWithRandomAccess& o, bool debug=false) : + StackWriter<StackWriterMode::Stack2Binary, Parent>(parent, o, false, debug) { + this->setFunction(funcInit); + this->mapLocalsAndEmitHeader(); + for (auto* inst : *funcInit->stackIR) { + if (!inst) continue; // a nullptr is just something we can skip + switch (inst->op) { + case StackInst::Basic: + case StackInst::BlockBegin: + case StackInst::IfBegin: + case StackInst::LoopBegin: { + this->visit(inst->origin); + break; + } + case StackInst::BlockEnd: { + this->visitBlockEnd(inst->origin->template cast<Block>()); + break; + } + case StackInst::IfElse: { + this->visitIfElse(inst->origin->template cast<If>()); + break; + } + case StackInst::IfEnd: { + this->visitIfEnd(inst->origin->template cast<If>()); + break; + } + case StackInst::LoopEnd: { + this->visitLoopEnd(inst->origin->template cast<Loop>()); + break; + } + default: WASM_UNREACHABLE(); + } + } + this->finishFunctionBody(); + } +}; + +// +// Implementations +// + +// StackWriter + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::mapLocalsAndEmitHeader() { + // Map them + for (Index i = 0; i < func->getNumParams(); i++) { + size_t curr = mappedLocals.size(); + mappedLocals[i] = curr; + } + for (auto type : func->vars) { + numLocalsByType[type]++; + } + std::map<Type, size_t> currLocalsByType; + for (Index i = func->getVarIndexBase(); i < func->getNumLocals(); i++) { + size_t index = func->getVarIndexBase(); + Type type = func->getLocalType(i); + currLocalsByType[type]++; // increment now for simplicity, must decrement it in returns + if (type == i32) { + mappedLocals[i] = index + currLocalsByType[i32] - 1; + continue; + } + index += numLocalsByType[i32]; + if (type == i64) { + mappedLocals[i] = index + currLocalsByType[i64] - 1; + continue; + } + index += numLocalsByType[i64]; + if (type == f32) { + mappedLocals[i] = index + currLocalsByType[f32] - 1; + continue; + } + index += numLocalsByType[f32]; + if (type == f64) { + mappedLocals[i] = index + currLocalsByType[f64] - 1; + continue; + } + WASM_UNREACHABLE(); + } + // Emit them. + o << U32LEB( + (numLocalsByType[i32] ? 1 : 0) + + (numLocalsByType[i64] ? 1 : 0) + + (numLocalsByType[f32] ? 1 : 0) + + (numLocalsByType[f64] ? 
1 : 0) + ); + if (numLocalsByType[i32]) o << U32LEB(numLocalsByType[i32]) << binaryType(i32); + if (numLocalsByType[i64]) o << U32LEB(numLocalsByType[i64]) << binaryType(i64); + if (numLocalsByType[f32]) o << U32LEB(numLocalsByType[f32]) << binaryType(f32); + if (numLocalsByType[f64]) o << U32LEB(numLocalsByType[f64]) << binaryType(f64); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visit(Expression* curr) { + if (Mode == StackWriterMode::Binaryen2Binary && sourceMap) { + parent.writeDebugLocation(curr, func); + } + Visitor<StackWriter>::visit(curr); +} + +// emits a node, but if it is a block with no name, emit a list of its contents +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitPossibleBlockContents(Expression* curr) { + auto* block = curr->dynCast<Block>(); + if (!block || BranchUtils::BranchSeeker::hasNamed(block, block->name)) { + visitChild(curr); + return; + } + for (auto* child : block->list) { + visitChild(child); + } + if (block->type == unreachable && block->list.back()->type != unreachable) { + // similar to in visitBlock, here we could skip emitting the block itself, + // but must still end the 'block' (the contents, really) with an unreachable + emitExtraUnreachable(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitChild(Expression* curr) { + // In stack => binary, we don't need to visit child nodes, everything + // is already in the linear stream. + if (Mode != StackWriterMode::Stack2Binary) { + visit(curr); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitBlock(Block* curr) { + if (Mode == StackWriterMode::Binaryen2Stack) { + stackIR.push_back(makeStackInst(StackInst::BlockBegin, curr)); + } else { + if (debug) std::cerr << "zz node: Block" << std::endl; + o << int8_t(BinaryConsts::Block); + o << binaryType(curr->type != unreachable ? curr->type : none); + } + breakStack.push_back(curr->name); // TODO: we don't need to do this in Binaryen2Stack + Index i = 0; + for (auto* child : curr->list) { + if (debug) std::cerr << " " << size_t(curr) << "\n zz Block element " << i++ << std::endl; + visitChild(child); + } + // in Stack2Binary the block ending is in the stream later on + if (Mode == StackWriterMode::Stack2Binary) { + return; + } + visitBlockEnd(curr); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitBlockEnd(Block* curr) { + if (curr->type == unreachable) { + // an unreachable block is one that cannot be exited. We cannot encode this directly + // in wasm, where blocks must be none,i32,i64,f32,f64. Since the block cannot be + // exited, we can emit an unreachable at the end, and that will always be valid, + // and then the block is ok as a none + emitExtraUnreachable(); + } + if (Mode == StackWriterMode::Binaryen2Stack) { + stackIR.push_back(makeStackInst(StackInst::BlockEnd, curr)); + } else { + o << int8_t(BinaryConsts::End); + } + assert(!breakStack.empty()); + breakStack.pop_back(); + if (curr->type == unreachable) { + // and emit an unreachable *outside* the block too, so later things can pop anything + emitExtraUnreachable(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitIf(If* curr) { + if (debug) std::cerr << "zz node: If" << std::endl; + if (curr->condition->type == unreachable) { + // this if-else is unreachable because of the condition, i.e., the condition + // does not exit. 
So don't emit the if, but do consume the condition
+ visitChild(curr->condition);
+ emitExtraUnreachable();
+ return;
+ }
+ visitChild(curr->condition);
+ if (Mode == StackWriterMode::Binaryen2Stack) {
+ stackIR.push_back(makeStackInst(StackInst::IfBegin, curr));
+ } else {
+ o << int8_t(BinaryConsts::If);
+ o << binaryType(curr->type != unreachable ? curr->type : none);
+ }
+ breakStack.push_back(IMPOSSIBLE_CONTINUE); // the binary format requires this; we have a block if we need one
+ // TODO: optimize this in Stack IR (if child is a block, we
+ // may break to this instead)
+ visitPossibleBlockContents(curr->ifTrue); // TODO: emit block contents directly, if possible
+ if (Mode == StackWriterMode::Stack2Binary) {
+ return;
+ }
+ if (curr->ifFalse) {
+ visitIfElse(curr);
+ }
+ visitIfEnd(curr);
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitIfElse(If* curr) {
+ assert(!breakStack.empty());
+ breakStack.pop_back();
+ if (Mode == StackWriterMode::Binaryen2Stack) {
+ stackIR.push_back(makeStackInst(StackInst::IfElse, curr));
+ } else {
+ o << int8_t(BinaryConsts::Else);
+ }
+ breakStack.push_back(IMPOSSIBLE_CONTINUE); // TODO ditto
+ visitPossibleBlockContents(curr->ifFalse);
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitIfEnd(If* curr) {
+ assert(!breakStack.empty());
+ breakStack.pop_back();
+ if (Mode == StackWriterMode::Binaryen2Stack) {
+ stackIR.push_back(makeStackInst(StackInst::IfEnd, curr));
+ } else {
+ o << int8_t(BinaryConsts::End);
+ }
+ if (curr->type == unreachable) {
+ // we already handled the case of the condition being unreachable. otherwise,
+ // we may still be unreachable, if we are an if-else with both sides unreachable.
+ // wasm does not allow this to be emitted directly, so we must do something more. we could do
+ // better, but for now we emit an extra unreachable instruction after the if, so that it is not consumed itself.
+ assert(curr->ifFalse);
+ emitExtraUnreachable();
+ }
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitLoop(Loop* curr) {
+ if (debug) std::cerr << "zz node: Loop" << std::endl;
+ if (Mode == StackWriterMode::Binaryen2Stack) {
+ stackIR.push_back(makeStackInst(StackInst::LoopBegin, curr));
+ } else {
+ o << int8_t(BinaryConsts::Loop);
+ o << binaryType(curr->type != unreachable ? 
curr->type : none); + } + breakStack.push_back(curr->name); + visitPossibleBlockContents(curr->body); + if (Mode == StackWriterMode::Stack2Binary) { + return; + } + visitLoopEnd(curr); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitLoopEnd(Loop* curr) { + assert(!breakStack.empty()); + breakStack.pop_back(); + if (curr->type == unreachable) { + // we emitted a loop without a return type, and the body might be + // block contents, so ensure it is not consumed + emitExtraUnreachable(); + } + if (Mode == StackWriterMode::Binaryen2Stack) { + stackIR.push_back(makeStackInst(StackInst::LoopEnd, curr)); + } else { + o << int8_t(BinaryConsts::End); + } + if (curr->type == unreachable) { + // we emitted a loop without a return type, so it must not be consumed + emitExtraUnreachable(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitBreak(Break* curr) { + if (debug) std::cerr << "zz node: Break" << std::endl; + if (curr->value) { + visitChild(curr->value); + } + if (curr->condition) visitChild(curr->condition); + if (!justAddToStack(curr)) { + o << int8_t(curr->condition ? BinaryConsts::BrIf : BinaryConsts::Br) + << U32LEB(getBreakIndex(curr->name)); + } + if (curr->condition && curr->type == unreachable) { + // a br_if is normally none or emits a value. if it is unreachable, + // then either the condition or the value is unreachable, which is + // extremely rare, and may require us to make the stack polymorphic + // (if the block we branch to has a value, we may lack one as we + // are not a reachable branch; the wasm spec on the other hand does + // presume the br_if emits a value of the right type, even if it + // popped unreachable) + emitExtraUnreachable(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitSwitch(Switch* curr) { + if (debug) std::cerr << "zz node: Switch" << std::endl; + if (curr->value) { + visitChild(curr->value); + } + visitChild(curr->condition); + if (!BranchUtils::isBranchReachable(curr)) { + // if the branch is not reachable, then it's dangerous to emit it, as + // wasm type checking rules are different, especially in unreachable + // code. so just don't emit that unreachable code. 
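+ // (an unreachable is valid in any context, so it can stand in for the br_table here)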
+ emitExtraUnreachable(); + return; + } + if (justAddToStack(curr)) return; + o << int8_t(BinaryConsts::TableSwitch) << U32LEB(curr->targets.size()); + for (auto target : curr->targets) { + o << U32LEB(getBreakIndex(target)); + } + o << U32LEB(getBreakIndex(curr->default_)); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitCall(Call* curr) { + if (debug) std::cerr << "zz node: Call" << std::endl; + for (auto* operand : curr->operands) { + visitChild(operand); + } + if (!justAddToStack(curr)) { + o << int8_t(BinaryConsts::CallFunction) << U32LEB(parent.getFunctionIndex(curr->target)); + } + if (curr->type == unreachable) { // TODO FIXME: this and similar can be removed + emitExtraUnreachable(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitCallImport(CallImport* curr) { + if (debug) std::cerr << "zz node: CallImport" << std::endl; + for (auto* operand : curr->operands) { + visitChild(operand); + } + if (justAddToStack(curr)) return; + o << int8_t(BinaryConsts::CallFunction) << U32LEB(parent.getFunctionIndex(curr->target)); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitCallIndirect(CallIndirect* curr) { + if (debug) std::cerr << "zz node: CallIndirect" << std::endl; + for (auto* operand : curr->operands) { + visitChild(operand); + } + visitChild(curr->target); + if (!justAddToStack(curr)) { + o << int8_t(BinaryConsts::CallIndirect) + << U32LEB(parent.getFunctionTypeIndex(curr->fullType)) + << U32LEB(0); // Reserved flags field + } + if (curr->type == unreachable) { + emitExtraUnreachable(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitGetLocal(GetLocal* curr) { + if (debug) std::cerr << "zz node: GetLocal " << (o.size() + 1) << std::endl; + if (justAddToStack(curr)) return; + o << int8_t(BinaryConsts::GetLocal) << U32LEB(mappedLocals[curr->index]); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitSetLocal(SetLocal* curr) { + if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl; + visitChild(curr->value); + if (!justAddToStack(curr)) { + o << int8_t(curr->isTee() ? BinaryConsts::TeeLocal : BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]); + } + if (curr->type == unreachable) { + emitExtraUnreachable(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitGetGlobal(GetGlobal* curr) { + if (debug) std::cerr << "zz node: GetGlobal " << (o.size() + 1) << std::endl; + if (justAddToStack(curr)) return; + o << int8_t(BinaryConsts::GetGlobal) << U32LEB(parent.getGlobalIndex(curr->name)); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitSetGlobal(SetGlobal* curr) { + if (debug) std::cerr << "zz node: SetGlobal" << std::endl; + visitChild(curr->value); + if (justAddToStack(curr)) return; + o << int8_t(BinaryConsts::SetGlobal) << U32LEB(parent.getGlobalIndex(curr->name)); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitLoad(Load* curr) { + if (debug) std::cerr << "zz node: Load" << std::endl; + visitChild(curr->ptr); + if (curr->type == unreachable) { + // don't even emit it; we don't know the right type + emitExtraUnreachable(); + return; + } + if (justAddToStack(curr)) return; + if (!curr->isAtomic) { + switch (curr->type) { + case i32: { + switch (curr->bytes) { + case 1: o << int8_t(curr->signed_ ? 
BinaryConsts::I32LoadMem8S : BinaryConsts::I32LoadMem8U); break; + case 2: o << int8_t(curr->signed_ ? BinaryConsts::I32LoadMem16S : BinaryConsts::I32LoadMem16U); break; + case 4: o << int8_t(BinaryConsts::I32LoadMem); break; + default: abort(); + } + break; + } + case i64: { + switch (curr->bytes) { + case 1: o << int8_t(curr->signed_ ? BinaryConsts::I64LoadMem8S : BinaryConsts::I64LoadMem8U); break; + case 2: o << int8_t(curr->signed_ ? BinaryConsts::I64LoadMem16S : BinaryConsts::I64LoadMem16U); break; + case 4: o << int8_t(curr->signed_ ? BinaryConsts::I64LoadMem32S : BinaryConsts::I64LoadMem32U); break; + case 8: o << int8_t(BinaryConsts::I64LoadMem); break; + default: abort(); + } + break; + } + case f32: o << int8_t(BinaryConsts::F32LoadMem); break; + case f64: o << int8_t(BinaryConsts::F64LoadMem); break; + case unreachable: return; // the pointer is unreachable, so we are never reached; just don't emit a load + default: WASM_UNREACHABLE(); + } + } else { + o << int8_t(BinaryConsts::AtomicPrefix); + switch (curr->type) { + case i32: { + switch (curr->bytes) { + case 1: o << int8_t(BinaryConsts::I32AtomicLoad8U); break; + case 2: o << int8_t(BinaryConsts::I32AtomicLoad16U); break; + case 4: o << int8_t(BinaryConsts::I32AtomicLoad); break; + default: WASM_UNREACHABLE(); + } + break; + } + case i64: { + switch (curr->bytes) { + case 1: o << int8_t(BinaryConsts::I64AtomicLoad8U); break; + case 2: o << int8_t(BinaryConsts::I64AtomicLoad16U); break; + case 4: o << int8_t(BinaryConsts::I64AtomicLoad32U); break; + case 8: o << int8_t(BinaryConsts::I64AtomicLoad); break; + default: WASM_UNREACHABLE(); + } + break; + } + case unreachable: return; + default: WASM_UNREACHABLE(); + } + } + emitMemoryAccess(curr->align, curr->bytes, curr->offset); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitStore(Store* curr) { + if (debug) std::cerr << "zz node: Store" << std::endl; + visitChild(curr->ptr); + visitChild(curr->value); + if (curr->type == unreachable) { + // don't even emit it; we don't know the right type + emitExtraUnreachable(); + return; + } + if (justAddToStack(curr)) return; + if (!curr->isAtomic) { + switch (curr->valueType) { + case i32: { + switch (curr->bytes) { + case 1: o << int8_t(BinaryConsts::I32StoreMem8); break; + case 2: o << int8_t(BinaryConsts::I32StoreMem16); break; + case 4: o << int8_t(BinaryConsts::I32StoreMem); break; + default: abort(); + } + break; + } + case i64: { + switch (curr->bytes) { + case 1: o << int8_t(BinaryConsts::I64StoreMem8); break; + case 2: o << int8_t(BinaryConsts::I64StoreMem16); break; + case 4: o << int8_t(BinaryConsts::I64StoreMem32); break; + case 8: o << int8_t(BinaryConsts::I64StoreMem); break; + default: abort(); + } + break; + } + case f32: o << int8_t(BinaryConsts::F32StoreMem); break; + case f64: o << int8_t(BinaryConsts::F64StoreMem); break; + default: abort(); + } + } else { + o << int8_t(BinaryConsts::AtomicPrefix); + switch (curr->valueType) { + case i32: { + switch (curr->bytes) { + case 1: o << int8_t(BinaryConsts::I32AtomicStore8); break; + case 2: o << int8_t(BinaryConsts::I32AtomicStore16); break; + case 4: o << int8_t(BinaryConsts::I32AtomicStore); break; + default: WASM_UNREACHABLE(); + } + break; + } + case i64: { + switch (curr->bytes) { + case 1: o << int8_t(BinaryConsts::I64AtomicStore8); break; + case 2: o << int8_t(BinaryConsts::I64AtomicStore16); break; + case 4: o << int8_t(BinaryConsts::I64AtomicStore32); break; + case 8: o << int8_t(BinaryConsts::I64AtomicStore); break; + 
default: WASM_UNREACHABLE(); + } + break; + } + default: WASM_UNREACHABLE(); + } + } + emitMemoryAccess(curr->align, curr->bytes, curr->offset); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitAtomicRMW(AtomicRMW* curr) { + if (debug) std::cerr << "zz node: AtomicRMW" << std::endl; + visitChild(curr->ptr); + // stop if the rest isn't reachable anyhow + if (curr->ptr->type == unreachable) return; + visitChild(curr->value); + if (curr->value->type == unreachable) return; + if (curr->type == unreachable) { + // don't even emit it; we don't know the right type + emitExtraUnreachable(); + return; + } + if (justAddToStack(curr)) return; + + o << int8_t(BinaryConsts::AtomicPrefix); + +#define CASE_FOR_OP(Op) \ + case Op: \ + switch (curr->type) { \ + case i32: \ + switch (curr->bytes) { \ + case 1: o << int8_t(BinaryConsts::I32AtomicRMW##Op##8U); break; \ + case 2: o << int8_t(BinaryConsts::I32AtomicRMW##Op##16U); break; \ + case 4: o << int8_t(BinaryConsts::I32AtomicRMW##Op); break; \ + default: WASM_UNREACHABLE(); \ + } \ + break; \ + case i64: \ + switch (curr->bytes) { \ + case 1: o << int8_t(BinaryConsts::I64AtomicRMW##Op##8U); break; \ + case 2: o << int8_t(BinaryConsts::I64AtomicRMW##Op##16U); break; \ + case 4: o << int8_t(BinaryConsts::I64AtomicRMW##Op##32U); break; \ + case 8: o << int8_t(BinaryConsts::I64AtomicRMW##Op); break; \ + default: WASM_UNREACHABLE(); \ + } \ + break; \ + default: WASM_UNREACHABLE(); \ + } \ + break + + switch(curr->op) { + CASE_FOR_OP(Add); + CASE_FOR_OP(Sub); + CASE_FOR_OP(And); + CASE_FOR_OP(Or); + CASE_FOR_OP(Xor); + CASE_FOR_OP(Xchg); + default: WASM_UNREACHABLE(); + } +#undef CASE_FOR_OP + + emitMemoryAccess(curr->bytes, curr->bytes, curr->offset); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitAtomicCmpxchg(AtomicCmpxchg* curr) { + if (debug) std::cerr << "zz node: AtomicCmpxchg" << std::endl; + visitChild(curr->ptr); + // stop if the rest isn't reachable anyhow + if (curr->ptr->type == unreachable) return; + visitChild(curr->expected); + if (curr->expected->type == unreachable) return; + visitChild(curr->replacement); + if (curr->replacement->type == unreachable) return; + if (curr->type == unreachable) { + // don't even emit it; we don't know the right type + emitExtraUnreachable(); + return; + } + if (justAddToStack(curr)) return; + + o << int8_t(BinaryConsts::AtomicPrefix); + switch (curr->type) { + case i32: + switch (curr->bytes) { + case 1: o << int8_t(BinaryConsts::I32AtomicCmpxchg8U); break; + case 2: o << int8_t(BinaryConsts::I32AtomicCmpxchg16U); break; + case 4: o << int8_t(BinaryConsts::I32AtomicCmpxchg); break; + default: WASM_UNREACHABLE(); + } + break; + case i64: + switch (curr->bytes) { + case 1: o << int8_t(BinaryConsts::I64AtomicCmpxchg8U); break; + case 2: o << int8_t(BinaryConsts::I64AtomicCmpxchg16U); break; + case 4: o << int8_t(BinaryConsts::I64AtomicCmpxchg32U); break; + case 8: o << int8_t(BinaryConsts::I64AtomicCmpxchg); break; + default: WASM_UNREACHABLE(); + } + break; + default: WASM_UNREACHABLE(); + } + emitMemoryAccess(curr->bytes, curr->bytes, curr->offset); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitAtomicWait(AtomicWait* curr) { + if (debug) std::cerr << "zz node: AtomicWait" << std::endl; + visitChild(curr->ptr); + // stop if the rest isn't reachable anyhow + if (curr->ptr->type == unreachable) return; + visitChild(curr->expected); + if (curr->expected->type == unreachable) return; 
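+ // the timeout operand is an i64 (nanoseconds in the threads proposal; negative waits indefinitely)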
+ visitChild(curr->timeout); + if (curr->timeout->type == unreachable) return; + if (justAddToStack(curr)) return; + + o << int8_t(BinaryConsts::AtomicPrefix); + switch (curr->expectedType) { + case i32: { + o << int8_t(BinaryConsts::I32AtomicWait); + emitMemoryAccess(4, 4, 0); + break; + } + case i64: { + o << int8_t(BinaryConsts::I64AtomicWait); + emitMemoryAccess(8, 8, 0); + break; + } + default: WASM_UNREACHABLE(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitAtomicWake(AtomicWake* curr) { + if (debug) std::cerr << "zz node: AtomicWake" << std::endl; + visitChild(curr->ptr); + // stop if the rest isn't reachable anyhow + if (curr->ptr->type == unreachable) return; + visitChild(curr->wakeCount); + if (curr->wakeCount->type == unreachable) return; + if (justAddToStack(curr)) return; + + o << int8_t(BinaryConsts::AtomicPrefix) << int8_t(BinaryConsts::AtomicWake); + emitMemoryAccess(4, 4, 0); +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitConst(Const* curr) { + if (debug) std::cerr << "zz node: Const" << curr << " : " << curr->type << std::endl; + if (justAddToStack(curr)) return; + switch (curr->type) { + case i32: { + o << int8_t(BinaryConsts::I32Const) << S32LEB(curr->value.geti32()); + break; + } + case i64: { + o << int8_t(BinaryConsts::I64Const) << S64LEB(curr->value.geti64()); + break; + } + case f32: { + o << int8_t(BinaryConsts::F32Const) << curr->value.reinterpreti32(); + break; + } + case f64: { + o << int8_t(BinaryConsts::F64Const) << curr->value.reinterpreti64(); + break; + } + default: abort(); + } + if (debug) std::cerr << "zz const node done.\n"; +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitUnary(Unary* curr) { + if (debug) std::cerr << "zz node: Unary" << std::endl; + visitChild(curr->value); + if (curr->type == unreachable) { + emitExtraUnreachable(); + return; + } + if (justAddToStack(curr)) return; + switch (curr->op) { + case ClzInt32: o << int8_t(BinaryConsts::I32Clz); break; + case CtzInt32: o << int8_t(BinaryConsts::I32Ctz); break; + case PopcntInt32: o << int8_t(BinaryConsts::I32Popcnt); break; + case EqZInt32: o << int8_t(BinaryConsts::I32EqZ); break; + case ClzInt64: o << int8_t(BinaryConsts::I64Clz); break; + case CtzInt64: o << int8_t(BinaryConsts::I64Ctz); break; + case PopcntInt64: o << int8_t(BinaryConsts::I64Popcnt); break; + case EqZInt64: o << int8_t(BinaryConsts::I64EqZ); break; + case NegFloat32: o << int8_t(BinaryConsts::F32Neg); break; + case AbsFloat32: o << int8_t(BinaryConsts::F32Abs); break; + case CeilFloat32: o << int8_t(BinaryConsts::F32Ceil); break; + case FloorFloat32: o << int8_t(BinaryConsts::F32Floor); break; + case TruncFloat32: o << int8_t(BinaryConsts::F32Trunc); break; + case NearestFloat32: o << int8_t(BinaryConsts::F32NearestInt); break; + case SqrtFloat32: o << int8_t(BinaryConsts::F32Sqrt); break; + case NegFloat64: o << int8_t(BinaryConsts::F64Neg); break; + case AbsFloat64: o << int8_t(BinaryConsts::F64Abs); break; + case CeilFloat64: o << int8_t(BinaryConsts::F64Ceil); break; + case FloorFloat64: o << int8_t(BinaryConsts::F64Floor); break; + case TruncFloat64: o << int8_t(BinaryConsts::F64Trunc); break; + case NearestFloat64: o << int8_t(BinaryConsts::F64NearestInt); break; + case SqrtFloat64: o << int8_t(BinaryConsts::F64Sqrt); break; + case ExtendSInt32: o << int8_t(BinaryConsts::I64STruncI32); break; + case ExtendUInt32: o << int8_t(BinaryConsts::I64UTruncI32); break; + case WrapInt64: 
o << int8_t(BinaryConsts::I32ConvertI64); break; + case TruncUFloat32ToInt32: o << int8_t(BinaryConsts::I32UTruncF32); break; + case TruncUFloat32ToInt64: o << int8_t(BinaryConsts::I64UTruncF32); break; + case TruncSFloat32ToInt32: o << int8_t(BinaryConsts::I32STruncF32); break; + case TruncSFloat32ToInt64: o << int8_t(BinaryConsts::I64STruncF32); break; + case TruncUFloat64ToInt32: o << int8_t(BinaryConsts::I32UTruncF64); break; + case TruncUFloat64ToInt64: o << int8_t(BinaryConsts::I64UTruncF64); break; + case TruncSFloat64ToInt32: o << int8_t(BinaryConsts::I32STruncF64); break; + case TruncSFloat64ToInt64: o << int8_t(BinaryConsts::I64STruncF64); break; + case ConvertUInt32ToFloat32: o << int8_t(BinaryConsts::F32UConvertI32); break; + case ConvertUInt32ToFloat64: o << int8_t(BinaryConsts::F64UConvertI32); break; + case ConvertSInt32ToFloat32: o << int8_t(BinaryConsts::F32SConvertI32); break; + case ConvertSInt32ToFloat64: o << int8_t(BinaryConsts::F64SConvertI32); break; + case ConvertUInt64ToFloat32: o << int8_t(BinaryConsts::F32UConvertI64); break; + case ConvertUInt64ToFloat64: o << int8_t(BinaryConsts::F64UConvertI64); break; + case ConvertSInt64ToFloat32: o << int8_t(BinaryConsts::F32SConvertI64); break; + case ConvertSInt64ToFloat64: o << int8_t(BinaryConsts::F64SConvertI64); break; + case DemoteFloat64: o << int8_t(BinaryConsts::F32ConvertF64); break; + case PromoteFloat32: o << int8_t(BinaryConsts::F64ConvertF32); break; + case ReinterpretFloat32: o << int8_t(BinaryConsts::I32ReinterpretF32); break; + case ReinterpretFloat64: o << int8_t(BinaryConsts::I64ReinterpretF64); break; + case ReinterpretInt32: o << int8_t(BinaryConsts::F32ReinterpretI32); break; + case ReinterpretInt64: o << int8_t(BinaryConsts::F64ReinterpretI64); break; + case ExtendS8Int32: o << int8_t(BinaryConsts::I32ExtendS8); break; + case ExtendS16Int32: o << int8_t(BinaryConsts::I32ExtendS16); break; + case ExtendS8Int64: o << int8_t(BinaryConsts::I64ExtendS8); break; + case ExtendS16Int64: o << int8_t(BinaryConsts::I64ExtendS16); break; + case ExtendS32Int64: o << int8_t(BinaryConsts::I64ExtendS32); break; + default: abort(); + } +} + +template<StackWriterMode Mode, typename Parent> +void StackWriter<Mode, Parent>::visitBinary(Binary* curr) { + if (debug) std::cerr << "zz node: Binary" << std::endl; + visitChild(curr->left); + visitChild(curr->right); + if (curr->type == unreachable) { + emitExtraUnreachable(); + return; + } + if (justAddToStack(curr)) return; + switch (curr->op) { + case AddInt32: o << int8_t(BinaryConsts::I32Add); break; + case SubInt32: o << int8_t(BinaryConsts::I32Sub); break; + case MulInt32: o << int8_t(BinaryConsts::I32Mul); break; + case DivSInt32: o << int8_t(BinaryConsts::I32DivS); break; + case DivUInt32: o << int8_t(BinaryConsts::I32DivU); break; + case RemSInt32: o << int8_t(BinaryConsts::I32RemS); break; + case RemUInt32: o << int8_t(BinaryConsts::I32RemU); break; + case AndInt32: o << int8_t(BinaryConsts::I32And); break; + case OrInt32: o << int8_t(BinaryConsts::I32Or); break; + case XorInt32: o << int8_t(BinaryConsts::I32Xor); break; + case ShlInt32: o << int8_t(BinaryConsts::I32Shl); break; + case ShrUInt32: o << int8_t(BinaryConsts::I32ShrU); break; + case ShrSInt32: o << int8_t(BinaryConsts::I32ShrS); break; + case RotLInt32: o << int8_t(BinaryConsts::I32RotL); break; + case RotRInt32: o << int8_t(BinaryConsts::I32RotR); break; + case EqInt32: o << int8_t(BinaryConsts::I32Eq); break; + case NeInt32: o << int8_t(BinaryConsts::I32Ne); break; + case LtSInt32: o << 
int8_t(BinaryConsts::I32LtS); break; + case LtUInt32: o << int8_t(BinaryConsts::I32LtU); break; + case LeSInt32: o << int8_t(BinaryConsts::I32LeS); break; + case LeUInt32: o << int8_t(BinaryConsts::I32LeU); break; + case GtSInt32: o << int8_t(BinaryConsts::I32GtS); break; + case GtUInt32: o << int8_t(BinaryConsts::I32GtU); break; + case GeSInt32: o << int8_t(BinaryConsts::I32GeS); break; + case GeUInt32: o << int8_t(BinaryConsts::I32GeU); break; + + case AddInt64: o << int8_t(BinaryConsts::I64Add); break; + case SubInt64: o << int8_t(BinaryConsts::I64Sub); break; + case MulInt64: o << int8_t(BinaryConsts::I64Mul); break; + case DivSInt64: o << int8_t(BinaryConsts::I64DivS); break; + case DivUInt64: o << int8_t(BinaryConsts::I64DivU); break; + case RemSInt64: o << int8_t(BinaryConsts::I64RemS); break; + case RemUInt64: o << int8_t(BinaryConsts::I64RemU); break; + case AndInt64: o << int8_t(BinaryConsts::I64And); break; + case OrInt64: o << int8_t(BinaryConsts::I64Or); break; + case XorInt64: o << int8_t(BinaryConsts::I64Xor); break; + case ShlInt64: o << int8_t(BinaryConsts::I64Shl); break; + case ShrUInt64: o << int8_t(BinaryConsts::I64ShrU); break; + case ShrSInt64: o << int8_t(BinaryConsts::I64ShrS); break; + case RotLInt64: o << int8_t(BinaryConsts::I64RotL); break; + case RotRInt64: o << int8_t(BinaryConsts::I64RotR); break; + case EqInt64: o << int8_t(BinaryConsts::I64Eq); break; + case NeInt64: o << int8_t(BinaryConsts::I64Ne); break; + case LtSInt64: o << int8_t(BinaryConsts::I64LtS); break; + case LtUInt64: o << int8_t(BinaryConsts::I64LtU); break; + case LeSInt64: o << int8_t(BinaryConsts::I64LeS); break; + case LeUInt64: o << int8_t(BinaryConsts::I64LeU); break; + case GtSInt64: o << int8_t(BinaryConsts::I64GtS); break; + case GtUInt64: o << int8_t(BinaryConsts::I64GtU); break; + case GeSInt64: o << int8_t(BinaryConsts::I64GeS); break; + case GeUInt64: o << int8_t(BinaryConsts::I64GeU); break; + + case AddFloat32: o << int8_t(BinaryConsts::F32Add); break; + case SubFloat32: o << int8_t(BinaryConsts::F32Sub); break; + case MulFloat32: o << int8_t(BinaryConsts::F32Mul); break; + case DivFloat32: o << int8_t(BinaryConsts::F32Div); break; + case CopySignFloat32: o << int8_t(BinaryConsts::F32CopySign);break; + case MinFloat32: o << int8_t(BinaryConsts::F32Min); break; + case MaxFloat32: o << int8_t(BinaryConsts::F32Max); break; + case EqFloat32: o << int8_t(BinaryConsts::F32Eq); break; + case NeFloat32: o << int8_t(BinaryConsts::F32Ne); break; + case LtFloat32: o << int8_t(BinaryConsts::F32Lt); break; + case LeFloat32: o << int8_t(BinaryConsts::F32Le); break; + case GtFloat32: o << int8_t(BinaryConsts::F32Gt); break; + case GeFloat32: o << int8_t(BinaryConsts::F32Ge); break; + + case AddFloat64: o << int8_t(BinaryConsts::F64Add); break; + case SubFloat64: o << int8_t(BinaryConsts::F64Sub); break; + case MulFloat64: o << int8_t(BinaryConsts::F64Mul); break; + case DivFloat64: o << int8_t(BinaryConsts::F64Div); break; + case CopySignFloat64: o << int8_t(BinaryConsts::F64CopySign);break; + case MinFloat64: o << int8_t(BinaryConsts::F64Min); break; + case MaxFloat64: o << int8_t(BinaryConsts::F64Max); break; + case EqFloat64: o << int8_t(BinaryConsts::F64Eq); break; + case NeFloat64: o << int8_t(BinaryConsts::F64Ne); break; + case LtFloat64: o << int8_t(BinaryConsts::F64Lt); break; + case LeFloat64: o << int8_t(BinaryConsts::F64Le); break; + case GtFloat64: o << int8_t(BinaryConsts::F64Gt); break; + case GeFloat64: o << int8_t(BinaryConsts::F64Ge); break; + default: abort(); + } +} + 
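As a standalone sketch of what this post-order emission produces (illustrative only, not from the patch): for (i32.add (get_local 0) (get_local 1)) the writer visits both operands and then appends the operator's opcode byte, exactly as visitBinary above does. Here a plain std::vector stands in for BufferWithRandomAccess, and 0x20/0x6a are the standard wasm opcodes for get_local and i32.add.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<uint8_t> o; // stands in for BufferWithRandomAccess
  auto emitGetLocal = [&](uint32_t index) {
    o.push_back(0x20);           // get_local opcode
    o.push_back(uint8_t(index)); // indices < 128 fit in a single LEB byte
  };
  emitGetLocal(0);   // left operand, visited first
  emitGetLocal(1);   // right operand, visited second
  o.push_back(0x6a); // i32.add: the operator follows its operands
  assert((o == std::vector<uint8_t>{0x20, 0x00, 0x20, 0x01, 0x6a}));
}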
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitSelect(Select* curr) {
+ if (debug) std::cerr << "zz node: Select" << std::endl;
+ visitChild(curr->ifTrue);
+ visitChild(curr->ifFalse);
+ visitChild(curr->condition);
+ if (curr->type == unreachable) {
+ emitExtraUnreachable();
+ return;
+ }
+ if (justAddToStack(curr)) return;
+ o << int8_t(BinaryConsts::Select);
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitReturn(Return* curr) {
+ if (debug) std::cerr << "zz node: Return" << std::endl;
+ if (curr->value) {
+ visitChild(curr->value);
+ }
+ if (justAddToStack(curr)) return;
+
+ o << int8_t(BinaryConsts::Return);
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitHost(Host* curr) {
+ if (debug) std::cerr << "zz node: Host" << std::endl;
+ switch (curr->op) {
+ case CurrentMemory: {
+ break;
+ }
+ case GrowMemory: {
+ visitChild(curr->operands[0]);
+ break;
+ }
+ default: WASM_UNREACHABLE();
+ }
+ if (justAddToStack(curr)) return;
+ switch (curr->op) {
+ case CurrentMemory: {
+ o << int8_t(BinaryConsts::CurrentMemory);
+ break;
+ }
+ case GrowMemory: {
+ o << int8_t(BinaryConsts::GrowMemory);
+ break;
+ }
+ default: WASM_UNREACHABLE();
+ }
+ o << U32LEB(0); // Reserved flags field
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitNop(Nop* curr) {
+ if (debug) std::cerr << "zz node: Nop" << std::endl;
+ if (justAddToStack(curr)) return;
+ o << int8_t(BinaryConsts::Nop);
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitUnreachable(Unreachable* curr) {
+ if (debug) std::cerr << "zz node: Unreachable" << std::endl;
+ if (justAddToStack(curr)) return;
+ o << int8_t(BinaryConsts::Unreachable);
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::visitDrop(Drop* curr) {
+ if (debug) std::cerr << "zz node: Drop" << std::endl;
+ visitChild(curr->value);
+ if (justAddToStack(curr)) return;
+ o << int8_t(BinaryConsts::Drop);
+}
+
+template<StackWriterMode Mode, typename Parent>
+int32_t StackWriter<Mode, Parent>::getBreakIndex(Name name) { // fails if the name is not found
+ for (int i = breakStack.size() - 1; i >= 0; i--) {
+ if (breakStack[i] == name) {
+ return breakStack.size() - 1 - i;
+ }
+ }
+ WASM_UNREACHABLE();
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::emitMemoryAccess(size_t alignment, size_t bytes, uint32_t offset) {
+ o << U32LEB(Log2(alignment ? 
alignment : bytes));
+ o << U32LEB(offset);
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::emitExtraUnreachable() {
+ if (Mode == StackWriterMode::Binaryen2Stack) {
+ stackIR.push_back(makeStackInst(Builder(allocator).makeUnreachable()));
+ } else if (Mode == StackWriterMode::Binaryen2Binary) {
+ o << int8_t(BinaryConsts::Unreachable);
+ }
+}
+
+template<StackWriterMode Mode, typename Parent>
+bool StackWriter<Mode, Parent>::justAddToStack(Expression* curr) {
+ if (Mode == StackWriterMode::Binaryen2Stack) {
+ stackIR.push_back(makeStackInst(curr));
+ return true;
+ }
+ return false;
+}
+
+template<StackWriterMode Mode, typename Parent>
+void StackWriter<Mode, Parent>::finishFunctionBody() {
+ o << int8_t(BinaryConsts::End);
+}
+
+template<StackWriterMode Mode, typename Parent>
+StackInst* StackWriter<Mode, Parent>::makeStackInst(StackInst::Op op, Expression* origin) {
+ auto* ret = allocator.alloc<StackInst>();
+ ret->op = op;
+ ret->origin = origin;
+ auto stackType = origin->type;
+ if (origin->is<Block>() || origin->is<Loop>() || origin->is<If>()) {
+ if (stackType == unreachable) {
+ // There are no unreachable blocks, loops, or ifs. we emit extra unreachables
+ // to fix that up, so that they are valid as having none type.
+ stackType = none;
+ } else if (op != StackInst::BlockEnd &&
+ op != StackInst::IfEnd &&
+ op != StackInst::LoopEnd) {
+ // If a concrete type is returned, we mark the end of the construct as
+ // having that type (as it is pushed to the value stack at that point);
+ // other parts are marked as none.
+ stackType = none;
+ }
+ }
+ ret->type = stackType;
+ return ret;
+}
+
+} // namespace wasm
+
+#endif // wasm_stack_h
+
diff --git a/src/wasm.h b/src/wasm.h
index 45881e519..49b0881b9 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -15,29 +15,10 @@
 */
 
 //
-// wasm.h: WebAssembly representation and processing library, in one
-// header file.
+// wasm.h: Define Binaryen IR, a representation for WebAssembly, with
+// all core parts in one simple header file.
 //
-// This represents WebAssembly in an AST format, with a focus on making
-// it easy to not just inspect but also to process. For example, some
-// things that this enables are:
-//
-// * Interpreting: See wasm-interpreter.h.
-// * Optimizing: See asm2wasm.h, which performs some optimizations
-// after code generation.
-// * Validation: See wasm-validator.h.
-// * Pretty-printing: See Print.cpp.
-//
-
-//
-// wasm.js internal WebAssembly representation design:
-//
-// * Unify where possible. Where size isn't a concern, combine
-// classes, so binary ops and relational ops are joined. This
-// simplifies that AST and makes traversals easier.
-// * Optimize for size? This might justify separating if and if_else
-// (so that if doesn't have an always-empty else; also it avoids
-// a branch).
+// For a fuller overview, see README.md
 //
 
 #ifndef wasm_wasm_h
@@ -601,6 +582,13 @@ public:
 
 // Globals
 
+// Forward declarations of Stack IR, as functions can contain it, see
+// the stackIR property.
+// Stack IR is a secondary IR to the main IR defined in this file (Binaryen
+// IR). See wasm-stack.h. 
+class StackInst; +typedef std::vector<StackInst*> StackIR; + class Function { public: Name name; @@ -608,8 +596,20 @@ public: std::vector<Type> params; // function locals are std::vector<Type> vars; // params plus vars Name type; // if null, it is implicit in params and result + + // The body of the function Expression* body; + // If present, this stack IR was generated from the main Binaryen IR body, + // and possibly optimized. If it is present when writing to wasm binary, + // it will be emitted instead of the main Binaryen IR. + // + // Note that no special care is taken to synchronize the two IRs - if you + // emit stack IR and then optimize the main IR, you need to recompute the + // stack IR. The Pass system will throw away Stack IR if a pass is run + // that declares it may modify Binaryen IR. + std::unique_ptr<StackIR> stackIR; + // local names. these are optional. std::map<Index, Name> localNames; std::map<Name, Index> localIndices; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 6a5775151..935660d95 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -19,7 +19,7 @@ #include "support/bits.h" #include "wasm-binary.h" -#include "ir/branch-utils.h" +#include "wasm-stack.h" #include "ir/module-utils.h" namespace wasm { @@ -217,7 +217,7 @@ void WasmBinaryWriter::writeFunctionSignatures() { } void WasmBinaryWriter::writeExpression(Expression* curr) { - StackWriter(*this, o, debug).visit(curr); + ExpressionStackWriter<WasmBinaryWriter>(curr, *this, o, debug); } void WasmBinaryWriter::writeFunctions() { @@ -231,22 +231,16 @@ void WasmBinaryWriter::writeFunctions() { if (debug) std::cerr << "write one at" << o.size() << std::endl; size_t sizePos = writeU32LEBPlaceholder(); size_t start = o.size(); - Function* function = wasm->functions[i].get(); - if (debug) std::cerr << "writing" << function->name << std::endl; - StackWriter stackWriter(function, *this, o, sourceMap, debug); - o << U32LEB( - (stackWriter.numLocalsByType[i32] ? 1 : 0) + - (stackWriter.numLocalsByType[i64] ? 1 : 0) + - (stackWriter.numLocalsByType[f32] ? 1 : 0) + - (stackWriter.numLocalsByType[f64] ? 
1 : 0) - ); - if (stackWriter.numLocalsByType[i32]) o << U32LEB(stackWriter.numLocalsByType[i32]) << binaryType(i32); - if (stackWriter.numLocalsByType[i64]) o << U32LEB(stackWriter.numLocalsByType[i64]) << binaryType(i64); - if (stackWriter.numLocalsByType[f32]) o << U32LEB(stackWriter.numLocalsByType[f32]) << binaryType(f32); - if (stackWriter.numLocalsByType[f64]) o << U32LEB(stackWriter.numLocalsByType[f64]) << binaryType(f64); - - stackWriter.visitPossibleBlockContents(function->body); - o << int8_t(BinaryConsts::End); + Function* func = wasm->functions[i].get(); + if (debug) std::cerr << "writing" << func->name << std::endl; + // Emit Stack IR if present, and if we can + if (func->stackIR && !sourceMap) { + if (debug) std::cerr << "write Stack IR" << std::endl; + StackIRFunctionStackWriter<WasmBinaryWriter>(func, *this, o, debug); + } else { + if (debug) std::cerr << "write Binaryen IR" << std::endl; + FunctionStackWriter<WasmBinaryWriter>(func, *this, o, sourceMap, debug); + } size_t size = o.size() - start; assert(size <= std::numeric_limits<uint32_t>::max()); if (debug) std::cerr << "body size: " << size << ", writing at " << sizePos << ", next starts at " << o.size() << std::endl; @@ -263,7 +257,7 @@ void WasmBinaryWriter::writeFunctions() { } } } - tableOfContents.functionBodies.emplace_back(function->name, sizePos + sizeFieldSize, size); + tableOfContents.functionBodies.emplace_back(func->name, sizePos + sizeFieldSize, size); } finishSection(start); } @@ -601,745 +595,6 @@ void WasmBinaryWriter::finishUp() { } } -// StackWriter - -void StackWriter::mapLocals() { - for (Index i = 0; i < func->getNumParams(); i++) { - size_t curr = mappedLocals.size(); - mappedLocals[i] = curr; - } - for (auto type : func->vars) { - numLocalsByType[type]++; - } - std::map<Type, size_t> currLocalsByType; - for (Index i = func->getVarIndexBase(); i < func->getNumLocals(); i++) { - size_t index = func->getVarIndexBase(); - Type type = func->getLocalType(i); - currLocalsByType[type]++; // increment now for simplicity, must decrement it in returns - if (type == i32) { - mappedLocals[i] = index + currLocalsByType[i32] - 1; - continue; - } - index += numLocalsByType[i32]; - if (type == i64) { - mappedLocals[i] = index + currLocalsByType[i64] - 1; - continue; - } - index += numLocalsByType[i64]; - if (type == f32) { - mappedLocals[i] = index + currLocalsByType[f32] - 1; - continue; - } - index += numLocalsByType[f32]; - if (type == f64) { - mappedLocals[i] = index + currLocalsByType[f64] - 1; - continue; - } - abort(); - } -} - -void StackWriter::visit(Expression* curr) { - if (sourceMap) { - parent.writeDebugLocation(curr, func); - } - Visitor<StackWriter>::visit(curr); -} - -static bool brokenTo(Block* block) { - return block->name.is() && BranchUtils::BranchSeeker::hasNamed(block, block->name); -} - -// emits a node, but if it is a block with no name, emit a list of its contents -void StackWriter::visitPossibleBlockContents(Expression* curr) { - auto* block = curr->dynCast<Block>(); - if (!block || brokenTo(block)) { - visit(curr); - return; - } - for (auto* child : block->list) { - visit(child); - } - if (block->type == unreachable && block->list.back()->type != unreachable) { - // similar to in visitBlock, here we could skip emitting the block itself, - // but must still end the 'block' (the contents, really) with an unreachable - o << int8_t(BinaryConsts::Unreachable); - } -} - -void StackWriter::visitBlock(Block *curr) { - if (debug) std::cerr << "zz node: Block" << std::endl; - o << 
int8_t(BinaryConsts::Block); - o << binaryType(curr->type != unreachable ? curr->type : none); - breakStack.push_back(curr->name); - Index i = 0; - for (auto* child : curr->list) { - if (debug) std::cerr << " " << size_t(curr) << "\n zz Block element " << i++ << std::endl; - visit(child); - } - breakStack.pop_back(); - if (curr->type == unreachable) { - // an unreachable block is one that cannot be exited. We cannot encode this directly - // in wasm, where blocks must be none,i32,i64,f32,f64. Since the block cannot be - // exited, we can emit an unreachable at the end, and that will always be valid, - // and then the block is ok as a none - o << int8_t(BinaryConsts::Unreachable); - } - o << int8_t(BinaryConsts::End); - if (curr->type == unreachable) { - // and emit an unreachable *outside* the block too, so later things can pop anything - o << int8_t(BinaryConsts::Unreachable); - } -} - -void StackWriter::visitIf(If *curr) { - if (debug) std::cerr << "zz node: If" << std::endl; - if (curr->condition->type == unreachable) { - // this if-else is unreachable because of the condition, i.e., the condition - // does not exit. So don't emit the if, but do consume the condition - visit(curr->condition); - o << int8_t(BinaryConsts::Unreachable); - return; - } - visit(curr->condition); - o << int8_t(BinaryConsts::If); - o << binaryType(curr->type != unreachable ? curr->type : none); - breakStack.push_back(IMPOSSIBLE_CONTINUE); // the binary format requires this; we have a block if we need one; TODO: optimize - visitPossibleBlockContents(curr->ifTrue); // TODO: emit block contents directly, if possible - breakStack.pop_back(); - if (curr->ifFalse) { - o << int8_t(BinaryConsts::Else); - breakStack.push_back(IMPOSSIBLE_CONTINUE); // TODO ditto - visitPossibleBlockContents(curr->ifFalse); - breakStack.pop_back(); - } - o << int8_t(BinaryConsts::End); - if (curr->type == unreachable) { - // we already handled the case of the condition being unreachable. otherwise, - // we may still be unreachable, if we are an if-else with both sides unreachable. - // wasm does not allow this to be emitted directly, so we must do something more. we could do - // better, but for now we emit an extra unreachable instruction after the if, so it is not consumed itself, - assert(curr->ifFalse); - o << int8_t(BinaryConsts::Unreachable); - } -} - -void StackWriter::visitLoop(Loop *curr) { - if (debug) std::cerr << "zz node: Loop" << std::endl; - o << int8_t(BinaryConsts::Loop); - o << binaryType(curr->type != unreachable ? curr->type : none); - breakStack.push_back(curr->name); - visitPossibleBlockContents(curr->body); - breakStack.pop_back(); - o << int8_t(BinaryConsts::End); - if (curr->type == unreachable) { - // we emitted a loop without a return type, so it must not be consumed - o << int8_t(BinaryConsts::Unreachable); - } -} - -void StackWriter::visitBreak(Break *curr) { - if (debug) std::cerr << "zz node: Break" << std::endl; - if (curr->value) { - visit(curr->value); - } - if (curr->condition) visit(curr->condition); - o << int8_t(curr->condition ? BinaryConsts::BrIf : BinaryConsts::Br) - << U32LEB(getBreakIndex(curr->name)); - if (curr->condition && curr->type == unreachable) { - // a br_if is normally none or emits a value. 
if it is unreachable, - // then either the condition or the value is unreachable, which is - // extremely rare, and may require us to make the stack polymorphic - // (if the block we branch to has a value, we may lack one as we - // are not a reachable branch; the wasm spec on the other hand does - // presume the br_if emits a value of the right type, even if it - // popped unreachable) - o << int8_t(BinaryConsts::Unreachable); - } -} - -void StackWriter::visitSwitch(Switch *curr) { - if (debug) std::cerr << "zz node: Switch" << std::endl; - if (curr->value) { - visit(curr->value); - } - visit(curr->condition); - if (!BranchUtils::isBranchReachable(curr)) { - // if the branch is not reachable, then it's dangerous to emit it, as - // wasm type checking rules are different, especially in unreachable - // code. so just don't emit that unreachable code. - o << int8_t(BinaryConsts::Unreachable); - return; - } - o << int8_t(BinaryConsts::TableSwitch) << U32LEB(curr->targets.size()); - for (auto target : curr->targets) { - o << U32LEB(getBreakIndex(target)); - } - o << U32LEB(getBreakIndex(curr->default_)); -} - -void StackWriter::visitCall(Call *curr) { - if (debug) std::cerr << "zz node: Call" << std::endl; - for (auto* operand : curr->operands) { - visit(operand); - } - o << int8_t(BinaryConsts::CallFunction) << U32LEB(parent.getFunctionIndex(curr->target)); - if (curr->type == unreachable) { - o << int8_t(BinaryConsts::Unreachable); - } -} - -void StackWriter::visitCallImport(CallImport *curr) { - if (debug) std::cerr << "zz node: CallImport" << std::endl; - for (auto* operand : curr->operands) { - visit(operand); - } - o << int8_t(BinaryConsts::CallFunction) << U32LEB(parent.getFunctionIndex(curr->target)); -} - -void StackWriter::visitCallIndirect(CallIndirect *curr) { - if (debug) std::cerr << "zz node: CallIndirect" << std::endl; - - for (auto* operand : curr->operands) { - visit(operand); - } - visit(curr->target); - o << int8_t(BinaryConsts::CallIndirect) - << U32LEB(parent.getFunctionTypeIndex(curr->fullType)) - << U32LEB(0); // Reserved flags field - if (curr->type == unreachable) { - o << int8_t(BinaryConsts::Unreachable); - } -} - -void StackWriter::visitGetLocal(GetLocal *curr) { - if (debug) std::cerr << "zz node: GetLocal " << (o.size() + 1) << std::endl; - o << int8_t(BinaryConsts::GetLocal) << U32LEB(mappedLocals[curr->index]); -} - -void StackWriter::visitSetLocal(SetLocal *curr) { - if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl; - visit(curr->value); - o << int8_t(curr->isTee() ? BinaryConsts::TeeLocal : BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]); - if (curr->type == unreachable) { - o << int8_t(BinaryConsts::Unreachable); - } -} - -void StackWriter::visitGetGlobal(GetGlobal *curr) { - if (debug) std::cerr << "zz node: GetGlobal " << (o.size() + 1) << std::endl; - o << int8_t(BinaryConsts::GetGlobal) << U32LEB(parent.getGlobalIndex(curr->name)); -} - -void StackWriter::visitSetGlobal(SetGlobal *curr) { - if (debug) std::cerr << "zz node: SetGlobal" << std::endl; - visit(curr->value); - o << int8_t(BinaryConsts::SetGlobal) << U32LEB(parent.getGlobalIndex(curr->name)); -} - -void StackWriter::visitLoad(Load *curr) { - if (debug) std::cerr << "zz node: Load" << std::endl; - visit(curr->ptr); - if (!curr->isAtomic) { - switch (curr->type) { - case i32: { - switch (curr->bytes) { - case 1: o << int8_t(curr->signed_ ? BinaryConsts::I32LoadMem8S : BinaryConsts::I32LoadMem8U); break; - case 2: o << int8_t(curr->signed_ ? 
BinaryConsts::I32LoadMem16S : BinaryConsts::I32LoadMem16U); break; - case 4: o << int8_t(BinaryConsts::I32LoadMem); break; - default: abort(); - } - break; - } - case i64: { - switch (curr->bytes) { - case 1: o << int8_t(curr->signed_ ? BinaryConsts::I64LoadMem8S : BinaryConsts::I64LoadMem8U); break; - case 2: o << int8_t(curr->signed_ ? BinaryConsts::I64LoadMem16S : BinaryConsts::I64LoadMem16U); break; - case 4: o << int8_t(curr->signed_ ? BinaryConsts::I64LoadMem32S : BinaryConsts::I64LoadMem32U); break; - case 8: o << int8_t(BinaryConsts::I64LoadMem); break; - default: abort(); - } - break; - } - case f32: o << int8_t(BinaryConsts::F32LoadMem); break; - case f64: o << int8_t(BinaryConsts::F64LoadMem); break; - case unreachable: return; // the pointer is unreachable, so we are never reached; just don't emit a load - default: WASM_UNREACHABLE(); - } - } else { - if (curr->type == unreachable) { - // don't even emit it; we don't know the right type - o << int8_t(BinaryConsts::Unreachable); - return; - } - o << int8_t(BinaryConsts::AtomicPrefix); - switch (curr->type) { - case i32: { - switch (curr->bytes) { - case 1: o << int8_t(BinaryConsts::I32AtomicLoad8U); break; - case 2: o << int8_t(BinaryConsts::I32AtomicLoad16U); break; - case 4: o << int8_t(BinaryConsts::I32AtomicLoad); break; - default: WASM_UNREACHABLE(); - } - break; - } - case i64: { - switch (curr->bytes) { - case 1: o << int8_t(BinaryConsts::I64AtomicLoad8U); break; - case 2: o << int8_t(BinaryConsts::I64AtomicLoad16U); break; - case 4: o << int8_t(BinaryConsts::I64AtomicLoad32U); break; - case 8: o << int8_t(BinaryConsts::I64AtomicLoad); break; - default: WASM_UNREACHABLE(); - } - break; - } - case unreachable: return; - default: WASM_UNREACHABLE(); - } - } - emitMemoryAccess(curr->align, curr->bytes, curr->offset); -} - -void StackWriter::visitStore(Store *curr) { - if (debug) std::cerr << "zz node: Store" << std::endl; - visit(curr->ptr); - visit(curr->value); - if (!curr->isAtomic) { - switch (curr->valueType) { - case i32: { - switch (curr->bytes) { - case 1: o << int8_t(BinaryConsts::I32StoreMem8); break; - case 2: o << int8_t(BinaryConsts::I32StoreMem16); break; - case 4: o << int8_t(BinaryConsts::I32StoreMem); break; - default: abort(); - } - break; - } - case i64: { - switch (curr->bytes) { - case 1: o << int8_t(BinaryConsts::I64StoreMem8); break; - case 2: o << int8_t(BinaryConsts::I64StoreMem16); break; - case 4: o << int8_t(BinaryConsts::I64StoreMem32); break; - case 8: o << int8_t(BinaryConsts::I64StoreMem); break; - default: abort(); - } - break; - } - case f32: o << int8_t(BinaryConsts::F32StoreMem); break; - case f64: o << int8_t(BinaryConsts::F64StoreMem); break; - default: abort(); - } - } else { - if (curr->type == unreachable) { - // don't even emit it; we don't know the right type - o << int8_t(BinaryConsts::Unreachable); - return; - } - o << int8_t(BinaryConsts::AtomicPrefix); - switch (curr->valueType) { - case i32: { - switch (curr->bytes) { - case 1: o << int8_t(BinaryConsts::I32AtomicStore8); break; - case 2: o << int8_t(BinaryConsts::I32AtomicStore16); break; - case 4: o << int8_t(BinaryConsts::I32AtomicStore); break; - default: WASM_UNREACHABLE(); - } - break; - } - case i64: { - switch (curr->bytes) { - case 1: o << int8_t(BinaryConsts::I64AtomicStore8); break; - case 2: o << int8_t(BinaryConsts::I64AtomicStore16); break; - case 4: o << int8_t(BinaryConsts::I64AtomicStore32); break; - case 8: o << int8_t(BinaryConsts::I64AtomicStore); break; - default: WASM_UNREACHABLE(); - } - break; - } - 
default: WASM_UNREACHABLE(); - } - } - emitMemoryAccess(curr->align, curr->bytes, curr->offset); -} - -void StackWriter::visitAtomicRMW(AtomicRMW *curr) { - if (debug) std::cerr << "zz node: AtomicRMW" << std::endl; - visit(curr->ptr); - // stop if the rest isn't reachable anyhow - if (curr->ptr->type == unreachable) return; - visit(curr->value); - if (curr->value->type == unreachable) return; - - if (curr->type == unreachable) { - // don't even emit it; we don't know the right type - o << int8_t(BinaryConsts::Unreachable); - return; - } - - o << int8_t(BinaryConsts::AtomicPrefix); - -#define CASE_FOR_OP(Op) \ - case Op: \ - switch (curr->type) { \ - case i32: \ - switch (curr->bytes) { \ - case 1: o << int8_t(BinaryConsts::I32AtomicRMW##Op##8U); break; \ - case 2: o << int8_t(BinaryConsts::I32AtomicRMW##Op##16U); break; \ - case 4: o << int8_t(BinaryConsts::I32AtomicRMW##Op); break; \ - default: WASM_UNREACHABLE(); \ - } \ - break; \ - case i64: \ - switch (curr->bytes) { \ - case 1: o << int8_t(BinaryConsts::I64AtomicRMW##Op##8U); break; \ - case 2: o << int8_t(BinaryConsts::I64AtomicRMW##Op##16U); break; \ - case 4: o << int8_t(BinaryConsts::I64AtomicRMW##Op##32U); break; \ - case 8: o << int8_t(BinaryConsts::I64AtomicRMW##Op); break; \ - default: WASM_UNREACHABLE(); \ - } \ - break; \ - default: WASM_UNREACHABLE(); \ - } \ - break - - switch(curr->op) { - CASE_FOR_OP(Add); - CASE_FOR_OP(Sub); - CASE_FOR_OP(And); - CASE_FOR_OP(Or); - CASE_FOR_OP(Xor); - CASE_FOR_OP(Xchg); - default: WASM_UNREACHABLE(); - } -#undef CASE_FOR_OP - - emitMemoryAccess(curr->bytes, curr->bytes, curr->offset); -} - -void StackWriter::visitAtomicCmpxchg(AtomicCmpxchg *curr) { - if (debug) std::cerr << "zz node: AtomicCmpxchg" << std::endl; - visit(curr->ptr); - // stop if the rest isn't reachable anyhow - if (curr->ptr->type == unreachable) return; - visit(curr->expected); - if (curr->expected->type == unreachable) return; - visit(curr->replacement); - if (curr->replacement->type == unreachable) return; - - if (curr->type == unreachable) { - // don't even emit it; we don't know the right type - o << int8_t(BinaryConsts::Unreachable); - return; - } - - o << int8_t(BinaryConsts::AtomicPrefix); - switch (curr->type) { - case i32: - switch (curr->bytes) { - case 1: o << int8_t(BinaryConsts::I32AtomicCmpxchg8U); break; - case 2: o << int8_t(BinaryConsts::I32AtomicCmpxchg16U); break; - case 4: o << int8_t(BinaryConsts::I32AtomicCmpxchg); break; - default: WASM_UNREACHABLE(); - } - break; - case i64: - switch (curr->bytes) { - case 1: o << int8_t(BinaryConsts::I64AtomicCmpxchg8U); break; - case 2: o << int8_t(BinaryConsts::I64AtomicCmpxchg16U); break; - case 4: o << int8_t(BinaryConsts::I64AtomicCmpxchg32U); break; - case 8: o << int8_t(BinaryConsts::I64AtomicCmpxchg); break; - default: WASM_UNREACHABLE(); - } - break; - default: WASM_UNREACHABLE(); - } - emitMemoryAccess(curr->bytes, curr->bytes, curr->offset); -} - -void StackWriter::visitAtomicWait(AtomicWait *curr) { - if (debug) std::cerr << "zz node: AtomicWait" << std::endl; - visit(curr->ptr); - // stop if the rest isn't reachable anyhow - if (curr->ptr->type == unreachable) return; - visit(curr->expected); - if (curr->expected->type == unreachable) return; - visit(curr->timeout); - if (curr->timeout->type == unreachable) return; - - o << int8_t(BinaryConsts::AtomicPrefix); - switch (curr->expectedType) { - case i32: { - o << int8_t(BinaryConsts::I32AtomicWait); - emitMemoryAccess(4, 4, 0); - break; - } - case i64: { - o << int8_t(BinaryConsts::I64AtomicWait); 
-      emitMemoryAccess(8, 8, 0);
-      break;
-    }
-    default: WASM_UNREACHABLE();
-  }
-}
-
-void StackWriter::visitAtomicWake(AtomicWake *curr) {
-  if (debug) std::cerr << "zz node: AtomicWake" << std::endl;
-  visit(curr->ptr);
-  // stop if the rest isn't reachable anyhow
-  if (curr->ptr->type == unreachable) return;
-  visit(curr->wakeCount);
-  if (curr->wakeCount->type == unreachable) return;
-
-  o << int8_t(BinaryConsts::AtomicPrefix) << int8_t(BinaryConsts::AtomicWake);
-  emitMemoryAccess(4, 4, 0);
-}
-
-void StackWriter::visitConst(Const *curr) {
-  if (debug) std::cerr << "zz node: Const" << curr << " : " << curr->type << std::endl;
-  switch (curr->type) {
-    case i32: {
-      o << int8_t(BinaryConsts::I32Const) << S32LEB(curr->value.geti32());
-      break;
-    }
-    case i64: {
-      o << int8_t(BinaryConsts::I64Const) << S64LEB(curr->value.geti64());
-      break;
-    }
-    case f32: {
-      o << int8_t(BinaryConsts::F32Const) << curr->value.reinterpreti32();
-      break;
-    }
-    case f64: {
-      o << int8_t(BinaryConsts::F64Const) << curr->value.reinterpreti64();
-      break;
-    }
-    default: abort();
-  }
-  if (debug) std::cerr << "zz const node done.\n";
-}
-
-void StackWriter::visitUnary(Unary *curr) {
-  if (debug) std::cerr << "zz node: Unary" << std::endl;
-  visit(curr->value);
-  switch (curr->op) {
-    case ClzInt32: o << int8_t(BinaryConsts::I32Clz); break;
-    case CtzInt32: o << int8_t(BinaryConsts::I32Ctz); break;
-    case PopcntInt32: o << int8_t(BinaryConsts::I32Popcnt); break;
-    case EqZInt32: o << int8_t(BinaryConsts::I32EqZ); break;
-    case ClzInt64: o << int8_t(BinaryConsts::I64Clz); break;
-    case CtzInt64: o << int8_t(BinaryConsts::I64Ctz); break;
-    case PopcntInt64: o << int8_t(BinaryConsts::I64Popcnt); break;
-    case EqZInt64: o << int8_t(BinaryConsts::I64EqZ); break;
-    case NegFloat32: o << int8_t(BinaryConsts::F32Neg); break;
-    case AbsFloat32: o << int8_t(BinaryConsts::F32Abs); break;
-    case CeilFloat32: o << int8_t(BinaryConsts::F32Ceil); break;
-    case FloorFloat32: o << int8_t(BinaryConsts::F32Floor); break;
-    case TruncFloat32: o << int8_t(BinaryConsts::F32Trunc); break;
-    case NearestFloat32: o << int8_t(BinaryConsts::F32NearestInt); break;
-    case SqrtFloat32: o << int8_t(BinaryConsts::F32Sqrt); break;
-    case NegFloat64: o << int8_t(BinaryConsts::F64Neg); break;
-    case AbsFloat64: o << int8_t(BinaryConsts::F64Abs); break;
-    case CeilFloat64: o << int8_t(BinaryConsts::F64Ceil); break;
-    case FloorFloat64: o << int8_t(BinaryConsts::F64Floor); break;
-    case TruncFloat64: o << int8_t(BinaryConsts::F64Trunc); break;
-    case NearestFloat64: o << int8_t(BinaryConsts::F64NearestInt); break;
-    case SqrtFloat64: o << int8_t(BinaryConsts::F64Sqrt); break;
-    case ExtendSInt32: o << int8_t(BinaryConsts::I64STruncI32); break;
-    case ExtendUInt32: o << int8_t(BinaryConsts::I64UTruncI32); break;
-    case WrapInt64: o << int8_t(BinaryConsts::I32ConvertI64); break;
-    case TruncUFloat32ToInt32: o << int8_t(BinaryConsts::I32UTruncF32); break;
-    case TruncUFloat32ToInt64: o << int8_t(BinaryConsts::I64UTruncF32); break;
-    case TruncSFloat32ToInt32: o << int8_t(BinaryConsts::I32STruncF32); break;
-    case TruncSFloat32ToInt64: o << int8_t(BinaryConsts::I64STruncF32); break;
-    case TruncUFloat64ToInt32: o << int8_t(BinaryConsts::I32UTruncF64); break;
-    case TruncUFloat64ToInt64: o << int8_t(BinaryConsts::I64UTruncF64); break;
-    case TruncSFloat64ToInt32: o << int8_t(BinaryConsts::I32STruncF64); break;
-    case TruncSFloat64ToInt64: o << int8_t(BinaryConsts::I64STruncF64); break;
-    case ConvertUInt32ToFloat32: o << int8_t(BinaryConsts::F32UConvertI32); break;
-    case ConvertUInt32ToFloat64: o << int8_t(BinaryConsts::F64UConvertI32); break;
-    case ConvertSInt32ToFloat32: o << int8_t(BinaryConsts::F32SConvertI32); break;
-    case ConvertSInt32ToFloat64: o << int8_t(BinaryConsts::F64SConvertI32); break;
-    case ConvertUInt64ToFloat32: o << int8_t(BinaryConsts::F32UConvertI64); break;
-    case ConvertUInt64ToFloat64: o << int8_t(BinaryConsts::F64UConvertI64); break;
-    case ConvertSInt64ToFloat32: o << int8_t(BinaryConsts::F32SConvertI64); break;
-    case ConvertSInt64ToFloat64: o << int8_t(BinaryConsts::F64SConvertI64); break;
-    case DemoteFloat64: o << int8_t(BinaryConsts::F32ConvertF64); break;
-    case PromoteFloat32: o << int8_t(BinaryConsts::F64ConvertF32); break;
-    case ReinterpretFloat32: o << int8_t(BinaryConsts::I32ReinterpretF32); break;
-    case ReinterpretFloat64: o << int8_t(BinaryConsts::I64ReinterpretF64); break;
-    case ReinterpretInt32: o << int8_t(BinaryConsts::F32ReinterpretI32); break;
-    case ReinterpretInt64: o << int8_t(BinaryConsts::F64ReinterpretI64); break;
-    case ExtendS8Int32: o << int8_t(BinaryConsts::I32ExtendS8); break;
-    case ExtendS16Int32: o << int8_t(BinaryConsts::I32ExtendS16); break;
-    case ExtendS8Int64: o << int8_t(BinaryConsts::I64ExtendS8); break;
-    case ExtendS16Int64: o << int8_t(BinaryConsts::I64ExtendS16); break;
-    case ExtendS32Int64: o << int8_t(BinaryConsts::I64ExtendS32); break;
-    default: abort();
-  }
-  if (curr->type == unreachable) {
-    o << int8_t(BinaryConsts::Unreachable);
-  }
-}
-
-void StackWriter::visitBinary(Binary *curr) {
-  if (debug) std::cerr << "zz node: Binary" << std::endl;
-  visit(curr->left);
-  visit(curr->right);
-
-  switch (curr->op) {
-    case AddInt32: o << int8_t(BinaryConsts::I32Add); break;
-    case SubInt32: o << int8_t(BinaryConsts::I32Sub); break;
-    case MulInt32: o << int8_t(BinaryConsts::I32Mul); break;
-    case DivSInt32: o << int8_t(BinaryConsts::I32DivS); break;
-    case DivUInt32: o << int8_t(BinaryConsts::I32DivU); break;
-    case RemSInt32: o << int8_t(BinaryConsts::I32RemS); break;
-    case RemUInt32: o << int8_t(BinaryConsts::I32RemU); break;
-    case AndInt32: o << int8_t(BinaryConsts::I32And); break;
-    case OrInt32: o << int8_t(BinaryConsts::I32Or); break;
-    case XorInt32: o << int8_t(BinaryConsts::I32Xor); break;
-    case ShlInt32: o << int8_t(BinaryConsts::I32Shl); break;
-    case ShrUInt32: o << int8_t(BinaryConsts::I32ShrU); break;
-    case ShrSInt32: o << int8_t(BinaryConsts::I32ShrS); break;
-    case RotLInt32: o << int8_t(BinaryConsts::I32RotL); break;
-    case RotRInt32: o << int8_t(BinaryConsts::I32RotR); break;
-    case EqInt32: o << int8_t(BinaryConsts::I32Eq); break;
-    case NeInt32: o << int8_t(BinaryConsts::I32Ne); break;
-    case LtSInt32: o << int8_t(BinaryConsts::I32LtS); break;
-    case LtUInt32: o << int8_t(BinaryConsts::I32LtU); break;
-    case LeSInt32: o << int8_t(BinaryConsts::I32LeS); break;
-    case LeUInt32: o << int8_t(BinaryConsts::I32LeU); break;
-    case GtSInt32: o << int8_t(BinaryConsts::I32GtS); break;
-    case GtUInt32: o << int8_t(BinaryConsts::I32GtU); break;
-    case GeSInt32: o << int8_t(BinaryConsts::I32GeS); break;
-    case GeUInt32: o << int8_t(BinaryConsts::I32GeU); break;
-
-    case AddInt64: o << int8_t(BinaryConsts::I64Add); break;
-    case SubInt64: o << int8_t(BinaryConsts::I64Sub); break;
-    case MulInt64: o << int8_t(BinaryConsts::I64Mul); break;
-    case DivSInt64: o << int8_t(BinaryConsts::I64DivS); break;
-    case DivUInt64: o << int8_t(BinaryConsts::I64DivU); break;
-    case RemSInt64: o << int8_t(BinaryConsts::I64RemS); break;
-    case RemUInt64: o << int8_t(BinaryConsts::I64RemU); break;
-    case AndInt64: o << int8_t(BinaryConsts::I64And); break;
-    case OrInt64: o << int8_t(BinaryConsts::I64Or); break;
-    case XorInt64: o << int8_t(BinaryConsts::I64Xor); break;
-    case ShlInt64: o << int8_t(BinaryConsts::I64Shl); break;
-    case ShrUInt64: o << int8_t(BinaryConsts::I64ShrU); break;
-    case ShrSInt64: o << int8_t(BinaryConsts::I64ShrS); break;
-    case RotLInt64: o << int8_t(BinaryConsts::I64RotL); break;
-    case RotRInt64: o << int8_t(BinaryConsts::I64RotR); break;
-    case EqInt64: o << int8_t(BinaryConsts::I64Eq); break;
-    case NeInt64: o << int8_t(BinaryConsts::I64Ne); break;
-    case LtSInt64: o << int8_t(BinaryConsts::I64LtS); break;
-    case LtUInt64: o << int8_t(BinaryConsts::I64LtU); break;
-    case LeSInt64: o << int8_t(BinaryConsts::I64LeS); break;
-    case LeUInt64: o << int8_t(BinaryConsts::I64LeU); break;
-    case GtSInt64: o << int8_t(BinaryConsts::I64GtS); break;
-    case GtUInt64: o << int8_t(BinaryConsts::I64GtU); break;
-    case GeSInt64: o << int8_t(BinaryConsts::I64GeS); break;
-    case GeUInt64: o << int8_t(BinaryConsts::I64GeU); break;
-
-    case AddFloat32: o << int8_t(BinaryConsts::F32Add); break;
-    case SubFloat32: o << int8_t(BinaryConsts::F32Sub); break;
-    case MulFloat32: o << int8_t(BinaryConsts::F32Mul); break;
-    case DivFloat32: o << int8_t(BinaryConsts::F32Div); break;
-    case CopySignFloat32: o << int8_t(BinaryConsts::F32CopySign);break;
-    case MinFloat32: o << int8_t(BinaryConsts::F32Min); break;
-    case MaxFloat32: o << int8_t(BinaryConsts::F32Max); break;
-    case EqFloat32: o << int8_t(BinaryConsts::F32Eq); break;
-    case NeFloat32: o << int8_t(BinaryConsts::F32Ne); break;
-    case LtFloat32: o << int8_t(BinaryConsts::F32Lt); break;
-    case LeFloat32: o << int8_t(BinaryConsts::F32Le); break;
-    case GtFloat32: o << int8_t(BinaryConsts::F32Gt); break;
-    case GeFloat32: o << int8_t(BinaryConsts::F32Ge); break;
-
-    case AddFloat64: o << int8_t(BinaryConsts::F64Add); break;
-    case SubFloat64: o << int8_t(BinaryConsts::F64Sub); break;
-    case MulFloat64: o << int8_t(BinaryConsts::F64Mul); break;
-    case DivFloat64: o << int8_t(BinaryConsts::F64Div); break;
-    case CopySignFloat64: o << int8_t(BinaryConsts::F64CopySign);break;
-    case MinFloat64: o << int8_t(BinaryConsts::F64Min); break;
-    case MaxFloat64: o << int8_t(BinaryConsts::F64Max); break;
-    case EqFloat64: o << int8_t(BinaryConsts::F64Eq); break;
-    case NeFloat64: o << int8_t(BinaryConsts::F64Ne); break;
-    case LtFloat64: o << int8_t(BinaryConsts::F64Lt); break;
-    case LeFloat64: o << int8_t(BinaryConsts::F64Le); break;
-    case GtFloat64: o << int8_t(BinaryConsts::F64Gt); break;
-    case GeFloat64: o << int8_t(BinaryConsts::F64Ge); break;
-    default: abort();
-  }
-  if (curr->type == unreachable) {
-    o << int8_t(BinaryConsts::Unreachable);
-  }
-}
-
-void StackWriter::visitSelect(Select *curr) {
-  if (debug) std::cerr << "zz node: Select" << std::endl;
-  visit(curr->ifTrue);
-  visit(curr->ifFalse);
-  visit(curr->condition);
-  o << int8_t(BinaryConsts::Select);
-  if (curr->type == unreachable) {
-    o << int8_t(BinaryConsts::Unreachable);
-  }
-}
-
-void StackWriter::visitReturn(Return *curr) {
-  if (debug) std::cerr << "zz node: Return" << std::endl;
-  if (curr->value) {
-    visit(curr->value);
-  }
-  o << int8_t(BinaryConsts::Return);
-}
-
-void StackWriter::visitHost(Host *curr) {
-  if (debug) std::cerr << "zz node: Host" << std::endl;
-  switch (curr->op) {
-    case CurrentMemory: {
-      o << int8_t(BinaryConsts::CurrentMemory);
-      break;
-    }
-    case GrowMemory: {
-      visit(curr->operands[0]);
-      o << int8_t(BinaryConsts::GrowMemory);
-      break;
-    }
-    default: abort();
-  }
-  o << U32LEB(0); // Reserved flags field
-}
-
-void StackWriter::visitNop(Nop *curr) {
-  if (debug) std::cerr << "zz node: Nop" << std::endl;
-  o << int8_t(BinaryConsts::Nop);
-}
-
-void StackWriter::visitUnreachable(Unreachable *curr) {
-  if (debug) std::cerr << "zz node: Unreachable" << std::endl;
-  o << int8_t(BinaryConsts::Unreachable);
-}
-
-void StackWriter::visitDrop(Drop *curr) {
-  if (debug) std::cerr << "zz node: Drop" << std::endl;
-  visit(curr->value);
-  o << int8_t(BinaryConsts::Drop);
-}
-
-int32_t StackWriter::getBreakIndex(Name name) { // -1 if not found
-  for (int i = breakStack.size() - 1; i >= 0; i--) {
-    if (breakStack[i] == name) {
-      return breakStack.size() - 1 - i;
-    }
-  }
-  std::cerr << "bad break: " << name << " in " << func->name << std::endl;
-  abort();
-}
-
-void StackWriter::emitMemoryAccess(size_t alignment, size_t bytes, uint32_t offset) {
-  o << U32LEB(Log2(alignment ? alignment : bytes));
-  o << U32LEB(offset);
-}
-
 // reader
 
 void WasmBinaryBuilder::read() {
@@ -2408,7 +1663,7 @@ void WasmBinaryBuilder::pushBlockElements(Block* curr, size_t start, size_t end)
   }
 }
 
-void WasmBinaryBuilder::visitBlock(Block *curr) {
+void WasmBinaryBuilder::visitBlock(Block* curr) {
   if (debug) std::cerr << "zz node: Block" << std::endl;
   // special-case Block and de-recurse nested blocks in their first position, as that is
   // a common pattern that can be very highly nested.
@@ -2475,7 +1730,7 @@ Expression* WasmBinaryBuilder::getBlockOrSingleton(Type type) {
   return block;
 }
 
-void WasmBinaryBuilder::visitIf(If *curr) {
+void WasmBinaryBuilder::visitIf(If* curr) {
   if (debug) std::cerr << "zz node: If" << std::endl;
   curr->type = getType();
   curr->condition = popNonVoidExpression();
@@ -2489,7 +1744,7 @@ void WasmBinaryBuilder::visitIf(If *curr) {
   }
 }
 
-void WasmBinaryBuilder::visitLoop(Loop *curr) {
+void WasmBinaryBuilder::visitLoop(Loop* curr) {
   if (debug) std::cerr << "zz node: Loop" << std::endl;
   curr->type = getType();
   curr->name = getNextLabel();
@@ -2546,7 +1801,7 @@ void WasmBinaryBuilder::visitBreak(Break *curr, uint8_t code) {
   curr->finalize();
 }
 
-void WasmBinaryBuilder::visitSwitch(Switch *curr) {
+void WasmBinaryBuilder::visitSwitch(Switch* curr) {
   if (debug) std::cerr << "zz node: Switch" << std::endl;
   curr->condition = popNonVoidExpression();
   auto numTargets = getU32LEB();
@@ -2592,7 +1847,7 @@ Expression* WasmBinaryBuilder::visitCall() {
   return ret;
 }
 
-void WasmBinaryBuilder::visitCallIndirect(CallIndirect *curr) {
+void WasmBinaryBuilder::visitCallIndirect(CallIndirect* curr) {
   if (debug) std::cerr << "zz node: CallIndirect" << std::endl;
   auto index = getU32LEB();
   if (index >= wasm.functionTypes.size()) {
@@ -2612,7 +1867,7 @@ void WasmBinaryBuilder::visitCallIndirect(CallIndirect *curr) {
   curr->finalize();
 }
 
-void WasmBinaryBuilder::visitGetLocal(GetLocal *curr) {
+void WasmBinaryBuilder::visitGetLocal(GetLocal* curr) {
  if (debug) std::cerr << "zz node: GetLocal " << pos << std::endl;
   requireFunctionContext("get_local");
   curr->index = getU32LEB();
@@ -2636,7 +1891,7 @@ void WasmBinaryBuilder::visitSetLocal(SetLocal *curr, uint8_t code) {
   curr->finalize();
 }
 
-void WasmBinaryBuilder::visitGetGlobal(GetGlobal *curr) {
+void WasmBinaryBuilder::visitGetGlobal(GetGlobal* curr) {
   if (debug) std::cerr << "zz node: GetGlobal " << pos << std::endl;
   auto index = getU32LEB();
   curr->name = getGlobalName(index);
@@ -2653,7 +1908,7 @@ void WasmBinaryBuilder::visitGetGlobal(GetGlobal *curr) {
   throwError("bad get_global");
 }
 
-void WasmBinaryBuilder::visitSetGlobal(SetGlobal *curr) {
+void WasmBinaryBuilder::visitSetGlobal(SetGlobal* curr) {
   if (debug) std::cerr << "zz node: SetGlobal" << std::endl;
   auto index = getU32LEB();
   curr->name = getGlobalName(index);
@@ -3012,7 +2267,7 @@ bool WasmBinaryBuilder::maybeVisitBinary(Expression*& out, uint8_t code) {
 #undef FLOAT_TYPED_CODE
 }
 
-void WasmBinaryBuilder::visitSelect(Select *curr) {
+void WasmBinaryBuilder::visitSelect(Select* curr) {
   if (debug) std::cerr << "zz node: Select" << std::endl;
   curr->condition = popNonVoidExpression();
   curr->ifFalse = popNonVoidExpression();
@@ -3020,7 +2275,7 @@ void WasmBinaryBuilder::visitSelect(Select *curr) {
   curr->finalize();
 }
 
-void WasmBinaryBuilder::visitReturn(Return *curr) {
+void WasmBinaryBuilder::visitReturn(Return* curr) {
   if (debug) std::cerr << "zz node: Return" << std::endl;
   requireFunctionContext("return");
   if (currFunction->result != none) {
@@ -3055,15 +2310,15 @@ bool WasmBinaryBuilder::maybeVisitHost(Expression*& out, uint8_t code) {
   return true;
 }
 
-void WasmBinaryBuilder::visitNop(Nop *curr) {
+void WasmBinaryBuilder::visitNop(Nop* curr) {
   if (debug) std::cerr << "zz node: Nop" << std::endl;
 }
 
-void WasmBinaryBuilder::visitUnreachable(Unreachable *curr) {
+void WasmBinaryBuilder::visitUnreachable(Unreachable* curr) {
   if (debug) std::cerr << "zz node: Unreachable" << std::endl;
 }
 
-void WasmBinaryBuilder::visitDrop(Drop *curr) {
+void WasmBinaryBuilder::visitDrop(Drop* curr) {
   if (debug) std::cerr << "zz node: Drop" << std::endl;
   curr->value = popNonVoidExpression();
   curr->finalize();
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 9b5384efe..4a6ed19be 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -171,6 +171,8 @@ struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> {
 
   Pass* create() override { return new FunctionValidator(&info); }
 
+  bool modifiesBinaryenIR() override { return false; }
+
   ValidationInfo& info;
 
   FunctionValidator(ValidationInfo* info) : info(*info) {}