diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ast_utils.h | 4 | ||||
-rw-r--r-- | src/pass.h | 2 | ||||
-rw-r--r-- | src/passes/LowerIfElse.cpp | 2 | ||||
-rw-r--r-- | src/passes/MergeBlocks.cpp | 2 | ||||
-rw-r--r-- | src/passes/Metrics.cpp | 8 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 2 | ||||
-rw-r--r-- | src/passes/PostEmscripten.cpp | 2 | ||||
-rw-r--r-- | src/passes/Print.cpp | 2 | ||||
-rw-r--r-- | src/passes/RemoveImports.cpp | 2 | ||||
-rw-r--r-- | src/passes/RemoveUnusedBrs.cpp | 2 | ||||
-rw-r--r-- | src/passes/RemoveUnusedNames.cpp | 2 | ||||
-rw-r--r-- | src/passes/ReorderLocals.cpp | 2 | ||||
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 72 | ||||
-rw-r--r-- | src/passes/Vacuum.cpp | 2 | ||||
-rw-r--r-- | src/s2wasm.h | 2 | ||||
-rw-r--r-- | src/wasm-binary.h | 2 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 2 | ||||
-rw-r--r-- | src/wasm-traversal.h | 634 | ||||
-rw-r--r-- | src/wasm-validator.h | 4 | ||||
-rw-r--r-- | src/wasm2asm.h | 12 |
20 files changed, 378 insertions, 384 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h index 5ab427178..561e12983 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -22,7 +22,7 @@ namespace wasm { -struct BreakSeeker : public WasmWalker<BreakSeeker> { +struct BreakSeeker : public PostWalker<BreakSeeker> { Name target; // look for this one size_t found; @@ -42,7 +42,7 @@ struct BreakSeeker : public WasmWalker<BreakSeeker> { // Look for side effects, including control flow // TODO: look at individual locals -struct EffectAnalyzer : public WasmWalker<EffectAnalyzer> { +struct EffectAnalyzer : public PostWalker<EffectAnalyzer> { bool branches = false; bool calls = false; bool readsLocal = false; diff --git a/src/pass.h b/src/pass.h index e3716545d..f1f99e3aa 100644 --- a/src/pass.h +++ b/src/pass.h @@ -133,7 +133,7 @@ public: // e.g. through PassRunner::getLast // Handles names in a module, in particular adding names without duplicates -class NameManager : public WalkerPass<WasmWalker<NameManager>> { +class NameManager : public WalkerPass<PostWalker<NameManager>> { public: Name getUnique(std::string prefix); // TODO: getUniqueInFunction diff --git a/src/passes/LowerIfElse.cpp b/src/passes/LowerIfElse.cpp index 922a294d2..48c1c0f9c 100644 --- a/src/passes/LowerIfElse.cpp +++ b/src/passes/LowerIfElse.cpp @@ -32,7 +32,7 @@ namespace wasm { -struct LowerIfElse : public WalkerPass<WasmWalker<LowerIfElse, void>> { +struct LowerIfElse : public WalkerPass<PostWalker<LowerIfElse>> { MixedArena* allocator; std::unique_ptr<NameManager> namer; diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index c85f3d600..ab210123c 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -23,7 +23,7 @@ namespace wasm { -struct MergeBlocks : public WalkerPass<WasmWalker<MergeBlocks>> { +struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { void visitBlock(Block *curr) { bool more = true; while (more) { diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp index 9673b8873..19df981b8 100644 --- a/src/passes/Metrics.cpp +++ b/src/passes/Metrics.cpp @@ -24,16 +24,16 @@ namespace wasm { using namespace std; // Prints metrics between optimization passes. -struct Metrics : public WalkerPass<WasmWalker<Metrics>> { +struct Metrics : public WalkerPass<PostWalker<Metrics>> { static Metrics *lastMetricsPass; map<const char *, int> counts; - void walk(Expression *&curr) override { - WalkerPass::walk(curr); - if (!curr) return; + + void visitExpression(Expression* curr) { auto name = getExpressionName(curr); counts[name]++; } + void finalize(PassRunner *runner, Module *module) override { ostream &o = cout; o << "Counts" diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index de342b43f..3d89d7af4 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -25,7 +25,7 @@ namespace wasm { -struct OptimizeInstructions : public WalkerPass<WasmWalker<OptimizeInstructions>> { +struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions>> { void visitIf(If* curr) { // flip branches to get rid of an i32.eqz if (curr->ifFalse) { diff --git a/src/passes/PostEmscripten.cpp b/src/passes/PostEmscripten.cpp index 14476b727..99b172d65 100644 --- a/src/passes/PostEmscripten.cpp +++ b/src/passes/PostEmscripten.cpp @@ -24,7 +24,7 @@ namespace wasm { -struct PostEmscripten : public WalkerPass<WasmWalker<PostEmscripten>> { +struct PostEmscripten : public WalkerPass<PostWalker<PostEmscripten>> { // When we have a Load from a local value (typically a GetLocal) plus a constant offset, // we may be able to fold it in. // The semantics of the Add are to wrap, while wasm offset semantics purposefully do diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 8e3d503b8..7b68d956e 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -24,7 +24,7 @@ namespace wasm { -struct PrintSExpression : public WasmVisitor<PrintSExpression, void> { +struct PrintSExpression : public Visitor<PrintSExpression> { std::ostream& o; unsigned indent = 0; diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp index 6463df10c..4f42c526f 100644 --- a/src/passes/RemoveImports.cpp +++ b/src/passes/RemoveImports.cpp @@ -27,7 +27,7 @@ namespace wasm { -struct RemoveImports : public WalkerPass<WasmWalker<RemoveImports>> { +struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> { MixedArena* allocator; Module* module; diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp index 5302c368a..998142724 100644 --- a/src/passes/RemoveUnusedBrs.cpp +++ b/src/passes/RemoveUnusedBrs.cpp @@ -23,7 +23,7 @@ namespace wasm { -struct RemoveUnusedBrs : public WalkerPass<WasmWalker<RemoveUnusedBrs>> { +struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> { // preparation: try to unify branches, as the fewer there are, the higher a chance we can remove them // specifically for if-else, turn an if-else with branches to the same target at the end of each // child, and with a value, to a branch to that target containing the if-else diff --git a/src/passes/RemoveUnusedNames.cpp b/src/passes/RemoveUnusedNames.cpp index a8beaa15e..71569eefb 100644 --- a/src/passes/RemoveUnusedNames.cpp +++ b/src/passes/RemoveUnusedNames.cpp @@ -23,7 +23,7 @@ namespace wasm { -struct RemoveUnusedNames : public WalkerPass<WasmWalker<RemoveUnusedNames>> { +struct RemoveUnusedNames : public WalkerPass<PostWalker<RemoveUnusedNames>> { // We maintain a list of branches that we saw in children, then when we reach // a parent block, we know if it was branched to std::set<Name> branchesSeen; diff --git a/src/passes/ReorderLocals.cpp b/src/passes/ReorderLocals.cpp index 1b5e998c4..e378408b6 100644 --- a/src/passes/ReorderLocals.cpp +++ b/src/passes/ReorderLocals.cpp @@ -26,7 +26,7 @@ namespace wasm { -struct ReorderLocals : public WalkerPass<WasmWalker<ReorderLocals, void>> { +struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals>> { std::map<Name, uint32_t> counts; diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index 0d59b8759..53e77eb22 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -26,7 +26,7 @@ namespace wasm { -struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> { +struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals>> { struct SinkableInfo { Expression** item; EffectAnalyzer effects; @@ -43,23 +43,6 @@ struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> { sinkables.clear(); } - void visitBlock(Block *curr) { - // note locals, we can sink them from here TODO sink from elsewhere? - derecurseBlocks(curr, [&](Block* block) { - // curr was already checked by walk() - if (block != curr) checkPre(block); - }, [&](Block* block, Expression*& child) { - walk(child); - if (child->is<SetLocal>()) { - Name name = child->cast<SetLocal>()->name; - assert(sinkables.count(name) == 0); - sinkables.emplace(std::make_pair(name, SinkableInfo(&child))); - } - }, [&](Block* block) { - if (block != curr) checkPost(block); - }); - } - void visitGetLocal(GetLocal *curr) { auto found = sinkables.find(curr->name); if (found != sinkables.end()) { @@ -73,7 +56,6 @@ struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> { } void visitSetLocal(SetLocal *curr) { - walk(curr->value); // if we are a potentially-sinkable thing, forget it - this // write overrides the last TODO: optimizable // TODO: if no get_locals left, can remove the set as well (== expressionizer in emscripten optimizer) @@ -96,28 +78,56 @@ struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> { } } - void checkPre(Expression* curr) { + static void visitPre(SimplifyLocals* self, Expression** currp) { EffectAnalyzer effects; - if (effects.checkPre(curr)) { - checkInvalidations(effects); + if (effects.checkPre(*currp)) { + self->checkInvalidations(effects); } } - void checkPost(Expression* curr) { + static void visitPost(SimplifyLocals* self, Expression** currp) { EffectAnalyzer effects; - if (effects.checkPost(curr)) { - checkInvalidations(effects); + if (effects.checkPost(*currp)) { + self->checkInvalidations(effects); } } - void walk(Expression*& curr) override { - if (!curr) return; - - checkPre(curr); + static void tryMarkSinkable(SimplifyLocals* self, Expression** currp) { + auto* curr = (*currp)->dyn_cast<SetLocal>(); + if (curr) { + Name name = curr->name; + assert(self->sinkables.count(name) == 0); + self->sinkables.emplace(std::make_pair(name, SinkableInfo(currp))); + } + } - FastExecutionWalker::walk(curr); + // override scan to add a pre and a post check task to all nodes + static void scan(SimplifyLocals* self, Expression** currp) { + self->pushTask(visitPost, currp); + + auto* curr = *currp; + + + if (curr->is<Block>()) { + // special-case blocks, by marking their children as locals. + // TODO sink from elsewhere? (need to make sure value is not used) + self->pushTask(SimplifyLocals::doNoteNonLinear, currp); + auto& list = curr->cast<Block>()->list; + int size = list.size(); + // we can't sink the last element, as it might be a return value; + // and anyhow, control flow is nonlinear at the end of the block so + // it would be invalidated. + for (int i = size - 1; i >= 0; i--) { + if (i < size - 1) { + self->pushTask(tryMarkSinkable, &list[i]); + } + self->pushTask(scan, &list[i]); + } + } else { + WalkerPass<LinearExecutionWalker<SimplifyLocals>>::scan(self, currp); + } - checkPost(curr); + self->pushTask(visitPre, currp); } }; diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index f9704ed8d..ef83958c4 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -23,7 +23,7 @@ namespace wasm { -struct Vacuum : public WalkerPass<WasmWalker<Vacuum>> { +struct Vacuum : public WalkerPass<PostWalker<Vacuum>> { void visitBlock(Block *curr) { // compress out nops int skip = 0; diff --git a/src/s2wasm.h b/src/s2wasm.h index 614afdc80..9d166624c 100644 --- a/src/s2wasm.h +++ b/src/s2wasm.h @@ -1368,7 +1368,7 @@ public: o << ";; METADATA: { "; // find asmConst calls, and emit their metadata - struct AsmConstWalker : public WasmWalker<AsmConstWalker> { + struct AsmConstWalker : public PostWalker<AsmConstWalker> { S2WasmBuilder* parent; std::map<std::string, std::set<std::string>> sigsForCode; diff --git a/src/wasm-binary.h b/src/wasm-binary.h index f7fb4b8f3..d4e1bd2c1 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -431,7 +431,7 @@ int8_t binaryWasmType(WasmType type) { } } -class WasmBinaryWriter : public WasmVisitor<WasmBinaryWriter, void> { +class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> { Module* wasm; BufferWithRandomAccess& o; bool debug; diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index ed6411258..70564e4a4 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -207,7 +207,7 @@ private: #endif // Execute a statement - class ExpressionRunner : public WasmVisitor<ExpressionRunner, Flow> { + class ExpressionRunner : public Visitor<ExpressionRunner, Flow> { ModuleInstance& instance; FunctionScope& scope; diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 24ec4905c..d1fff2753 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -31,47 +31,47 @@ namespace wasm { -template<typename SubType, typename ReturnType> -struct WasmVisitor { - virtual ~WasmVisitor() {} +template<typename SubType, typename ReturnType = void> +struct Visitor { + virtual ~Visitor() {} // Expression visitors - ReturnType visitBlock(Block *curr) { abort(); } - ReturnType visitIf(If *curr) { abort(); } - ReturnType visitLoop(Loop *curr) { abort(); } - ReturnType visitBreak(Break *curr) { abort(); } - ReturnType visitSwitch(Switch *curr) { abort(); } - ReturnType visitCall(Call *curr) { abort(); } - ReturnType visitCallImport(CallImport *curr) { abort(); } - ReturnType visitCallIndirect(CallIndirect *curr) { abort(); } - ReturnType visitGetLocal(GetLocal *curr) { abort(); } - ReturnType visitSetLocal(SetLocal *curr) { abort(); } - ReturnType visitLoad(Load *curr) { abort(); } - ReturnType visitStore(Store *curr) { abort(); } - ReturnType visitConst(Const *curr) { abort(); } - ReturnType visitUnary(Unary *curr) { abort(); } - ReturnType visitBinary(Binary *curr) { abort(); } - ReturnType visitSelect(Select *curr) { abort(); } - ReturnType visitReturn(Return *curr) { abort(); } - ReturnType visitHost(Host *curr) { abort(); } - ReturnType visitNop(Nop *curr) { abort(); } - ReturnType visitUnreachable(Unreachable *curr) { abort(); } + ReturnType visitBlock(Block *curr) {} + ReturnType visitIf(If *curr) {} + ReturnType visitLoop(Loop *curr) {} + ReturnType visitBreak(Break *curr) {} + ReturnType visitSwitch(Switch *curr) {} + ReturnType visitCall(Call *curr) {} + ReturnType visitCallImport(CallImport *curr) {} + ReturnType visitCallIndirect(CallIndirect *curr) {} + ReturnType visitGetLocal(GetLocal *curr) {} + ReturnType visitSetLocal(SetLocal *curr) {} + ReturnType visitLoad(Load *curr) {} + ReturnType visitStore(Store *curr) {} + ReturnType visitConst(Const *curr) {} + ReturnType visitUnary(Unary *curr) {} + ReturnType visitBinary(Binary *curr) {} + ReturnType visitSelect(Select *curr) {} + ReturnType visitReturn(Return *curr) {} + ReturnType visitHost(Host *curr) {} + ReturnType visitNop(Nop *curr) {} + ReturnType visitUnreachable(Unreachable *curr) {} // Module-level visitors - ReturnType visitFunctionType(FunctionType *curr) { abort(); } - ReturnType visitImport(Import *curr) { abort(); } - ReturnType visitExport(Export *curr) { abort(); } - ReturnType visitFunction(Function *curr) { abort(); } - ReturnType visitTable(Table *curr) { abort(); } - ReturnType visitMemory(Memory *curr) { abort(); } - ReturnType visitModule(Module *curr) { abort(); } - -#define DELEGATE(CLASS_TO_VISIT) \ - return static_cast<SubType*>(this)-> \ - visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr)) + ReturnType visitFunctionType(FunctionType *curr) {} + ReturnType visitImport(Import *curr) {} + ReturnType visitExport(Export *curr) {} + ReturnType visitFunction(Function *curr) {} + ReturnType visitTable(Table *curr) {} + ReturnType visitMemory(Memory *curr) {} + ReturnType visitModule(Module *curr) {} ReturnType visit(Expression *curr) { assert(curr); + + #define DELEGATE(CLASS_TO_VISIT) \ + return static_cast<SubType*>(this)-> \ + visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr)) + switch (curr->_id) { - case Expression::Id::InvalidId: abort(); case Expression::Id::BlockId: DELEGATE(Block); case Expression::Id::IfId: DELEGATE(If); case Expression::Id::LoopId: DELEGATE(Loop); @@ -92,49 +92,41 @@ struct WasmVisitor { case Expression::Id::HostId: DELEGATE(Host); case Expression::Id::NopId: DELEGATE(Nop); case Expression::Id::UnreachableId: DELEGATE(Unreachable); + case Expression::Id::InvalidId: default: WASM_UNREACHABLE(); } - } -#undef DELEGATE - - // Helper method to de-recurse blocks, which often nest in their first position very heavily - void derecurseBlocks(Block* block, std::function<void (Block*)> preBlock, - std::function<void (Block*, Expression*&)> onChild, - std::function<void (Block*)> postBlock) { - std::vector<Block*> stack; - stack.push_back(block); - while (block->list.size() > 0 && block->list[0]->is<Block>()) { - block = block->list[0]->cast<Block>(); - stack.push_back(block); - } - for (size_t i = 0; i < stack.size(); i++) { - preBlock(stack[i]); - } - for (int i = int(stack.size()) - 1; i >= 0; i--) { - auto* block = stack[i]; - auto& list = block->list; - for (size_t j = 0; j < list.size(); j++) { - if (i < int(stack.size()) - 1 && j == 0) { - // nested block, we already called its pre - } else { - onChild(block, list[j]); - } - } - postBlock(block); - } + #undef DELEGATE } }; // -// Base class for all WasmWalkers +// Base class for all WasmWalkers, which can traverse an AST +// and provide the option to replace nodes while doing so. // -template<typename SubType, typename ReturnType = void> -struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> { - virtual void walk(Expression*& curr) { abort(); } +// Subclass and implement the visit*() +// calls to run code on different node types. +// +template<typename SubType> +struct Walker : public Visitor<SubType> { + // Extra generic visitor, called before each node's specific visitor. Useful for + // passes that need to do the same thing for every node type. + void visitExpression(Expression* curr) {} + + // Node replacing as we walk - call replaceCurrent from + // your visitors. + + Expression *replace = nullptr; + + void replaceCurrent(Expression *expression) { + replace = expression; + } + + // Walk starting void startWalk(Function *func) { - walk(func->body); + SubType* self = static_cast<SubType*>(this); + self->walk(func->body); } void startWalk(Module *module) { @@ -157,185 +149,209 @@ struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> { self->visitMemory(&module->memory); self->visitModule(module); } -}; -template<typename ParentType> -struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> { - ParentType& parent; + // Walk implementation. We don't use recursion as ASTs may be highly + // nested. - ChildWalker(ParentType& parent) : parent(parent) {} + // Tasks receive the this pointer and a pointer to the pointer to operate on + typedef void (*TaskFunc)(SubType*, Expression**); - void visitBlock(Block *curr) { - ExpressionList& list = curr->list; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitIf(If *curr) { - parent.walk(curr->condition); - parent.walk(curr->ifTrue); - parent.walk(curr->ifFalse); - } - void visitLoop(Loop *curr) { - parent.walk(curr->body); - } - void visitBreak(Break *curr) { - parent.walk(curr->condition); - parent.walk(curr->value); - } - void visitSwitch(Switch *curr) { - parent.walk(curr->condition); - if (curr->value) parent.walk(curr->value); - } - void visitCall(Call *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitCallImport(CallImport *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } + struct Task { + TaskFunc func; + Expression** currp; + Task(TaskFunc func, Expression** currp) : func(func), currp(currp) {} + }; + + std::vector<Task> stack; + + void pushTask(TaskFunc func, Expression** currp) { + stack.emplace_back(func, currp); } - void visitCallIndirect(CallIndirect *curr) { - parent.walk(curr->target); - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); + void maybePushTask(TaskFunc func, Expression** currp) { + if (*currp) { + stack.emplace_back(func, currp); } } - void visitGetLocal(GetLocal *curr) {} - void visitSetLocal(SetLocal *curr) { - parent.walk(curr->value); - } - void visitLoad(Load *curr) { - parent.walk(curr->ptr); - } - void visitStore(Store *curr) { - parent.walk(curr->ptr); - parent.walk(curr->value); - } - void visitConst(Const *curr) {} - void visitUnary(Unary *curr) { - parent.walk(curr->value); - } - void visitBinary(Binary *curr) { - parent.walk(curr->left); - parent.walk(curr->right); - } - void visitSelect(Select *curr) { - parent.walk(curr->ifTrue); - parent.walk(curr->ifFalse); - parent.walk(curr->condition); - } - void visitReturn(Return *curr) { - parent.walk(curr->value); - } - void visitHost(Host *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); + Task popTask() { + auto ret = stack.back(); + stack.pop_back(); + return ret; + } + + void walk(Expression*& root) { + assert(stack.size() == 0); + pushTask(SubType::scan, &root); + while (stack.size() > 0) { + auto task = popTask(); + assert(*task.currp); + task.func(static_cast<SubType*>(this), task.currp); + if (replace) { + *task.currp = replace; + replace = nullptr; + } } } - void visitNop(Nop *curr) {} - void visitUnreachable(Unreachable *curr) {} -}; - -// Walker that allows replacements -template<typename SubType, typename ReturnType = void> -struct WasmReplacerWalker : public WasmWalkerBase<SubType, ReturnType> { - Expression* replace = nullptr; - - // methods can call this to replace the current node - void replaceCurrent(Expression *expression) { - replace = expression; - } - void walk(Expression*& curr) override { - if (!curr) return; - - this->visit(curr); - - if (replace) { - curr = replace; - replace = nullptr; - } - } + // subclasses implement this to define the proper order of execution + static void scan(SubType* self, Expression** currp) { abort(); } + + // task hooks to call visitors + + static void doVisitBlock(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBlock((*currp)->cast<Block>()); } + static void doVisitIf(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitIf((*currp)->cast<If>()); } + static void doVisitLoop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitLoop((*currp)->cast<Loop>()); } + static void doVisitBreak(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBreak((*currp)->cast<Break>()); } + static void doVisitSwitch(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSwitch((*currp)->cast<Switch>()); } + static void doVisitCall(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCall((*currp)->cast<Call>()); } + static void doVisitCallImport(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCallImport((*currp)->cast<CallImport>()); } + static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCallIndirect((*currp)->cast<CallIndirect>()); } + static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitGetLocal((*currp)->cast<GetLocal>()); } + static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSetLocal((*currp)->cast<SetLocal>()); } + static void doVisitLoad(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitLoad((*currp)->cast<Load>()); } + static void doVisitStore(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitStore((*currp)->cast<Store>()); } + static void doVisitConst(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitConst((*currp)->cast<Const>()); } + static void doVisitUnary(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnary((*currp)->cast<Unary>()); } + static void doVisitBinary(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBinary((*currp)->cast<Binary>()); } + static void doVisitSelect(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSelect((*currp)->cast<Select>()); } + static void doVisitReturn(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitReturn((*currp)->cast<Return>()); } + static void doVisitHost(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitHost((*currp)->cast<Host>()); } + static void doVisitNop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitNop((*currp)->cast<Nop>()); } + static void doVisitUnreachable(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnreachable((*currp)->cast<Unreachable>()); } }; -// -// Simple WebAssembly children-first walking (i.e., post-order, if you look -// at the children as subtrees of the current node), with the ability to replace -// the current expression node. Useful for writing optimization passes. -// +// Walks in post-order, i.e., children first. When there isn't an obvious +// order to operands, we follow them in order of execution. -template<typename SubType, typename ReturnType = void> -struct WasmWalker : public WasmReplacerWalker<SubType, ReturnType> { - // By default, do nothing - ReturnType visitBlock(Block *curr) {} - ReturnType visitIf(If *curr) {} - ReturnType visitLoop(Loop *curr) {} - ReturnType visitBreak(Break *curr) {} - ReturnType visitSwitch(Switch *curr) {} - ReturnType visitCall(Call *curr) {} - ReturnType visitCallImport(CallImport *curr) {} - ReturnType visitCallIndirect(CallIndirect *curr) {} - ReturnType visitGetLocal(GetLocal *curr) {} - ReturnType visitSetLocal(SetLocal *curr) {} - ReturnType visitLoad(Load *curr) {} - ReturnType visitStore(Store *curr) {} - ReturnType visitConst(Const *curr) {} - ReturnType visitUnary(Unary *curr) {} - ReturnType visitBinary(Binary *curr) {} - ReturnType visitSelect(Select *curr) {} - ReturnType visitReturn(Return *curr) {} - ReturnType visitHost(Host *curr) {} - ReturnType visitNop(Nop *curr) {} - ReturnType visitUnreachable(Unreachable *curr) {} +template<typename SubType> +struct PostWalker : public Walker<SubType> { - ReturnType visitFunctionType(FunctionType *curr) {} - ReturnType visitImport(Import *curr) {} - ReturnType visitExport(Export *curr) {} - ReturnType visitFunction(Function *curr) {} - ReturnType visitTable(Table *curr) {} - ReturnType visitMemory(Memory *curr) {} - ReturnType visitModule(Module *curr) {} + static void scan(SubType* self, Expression** currp) { - // children-first - void walk(Expression*& curr) override { - if (!curr) return; - - // special-case Block, because Block nesting (in their first element) can be incredibly deep - if (curr->is<Block>()) { - auto* block = curr->dyn_cast<Block>(); - std::vector<Block*> stack; - stack.push_back(block); - while (block->list.size() > 0 && block->list[0]->is<Block>()) { - block = block->list[0]->cast<Block>(); - stack.push_back(block); + Expression* curr = *currp; + switch (curr->_id) { + case Expression::Id::InvalidId: abort(); + case Expression::Id::BlockId: { + self->pushTask(SubType::doVisitBlock, currp); + auto& list = curr->cast<Block>()->list; + for (int i = int(list.size()) - 1; i >= 0; i--) { + self->pushTask(SubType::scan, &list[i]); + } + break; + } + case Expression::Id::IfId: { + self->pushTask(SubType::doVisitIf, currp); + self->maybePushTask(SubType::scan, &curr->cast<If>()->ifFalse); + self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue); + self->pushTask(SubType::scan, &curr->cast<If>()->condition); + break; + } + case Expression::Id::LoopId: { + self->pushTask(SubType::doVisitLoop, currp); + self->pushTask(SubType::scan, &curr->cast<Loop>()->body); + break; + } + case Expression::Id::BreakId: { + self->pushTask(SubType::doVisitBreak, currp); + self->maybePushTask(SubType::scan, &curr->cast<Break>()->condition); + self->maybePushTask(SubType::scan, &curr->cast<Break>()->value); + break; + } + case Expression::Id::SwitchId: { + self->pushTask(SubType::doVisitSwitch, currp); + self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value); + self->pushTask(SubType::scan, &curr->cast<Switch>()->condition); + break; + } + case Expression::Id::CallId: { + self->pushTask(SubType::doVisitCall, currp); + auto& list = curr->cast<Call>()->operands; + for (int i = int(list.size()) - 1; i >= 0; i--) { + self->pushTask(SubType::scan, &list[i]); + } + break; + } + case Expression::Id::CallImportId: { + self->pushTask(SubType::doVisitCallImport, currp); + auto& list = curr->cast<CallImport>()->operands; + for (int i = int(list.size()) - 1; i >= 0; i--) { + self->pushTask(SubType::scan, &list[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + self->pushTask(SubType::doVisitCallIndirect, currp); + auto& list = curr->cast<CallIndirect>()->operands; + for (int i = int(list.size()) - 1; i >= 0; i--) { + self->pushTask(SubType::scan, &list[i]); + } + self->pushTask(SubType::scan, &curr->cast<CallIndirect>()->target); + break; + } + case Expression::Id::GetLocalId: { + self->pushTask(SubType::doVisitGetLocal, currp); // TODO: optimize leaves with a direct call? + break; + } + case Expression::Id::SetLocalId: { + self->pushTask(SubType::doVisitSetLocal, currp); + self->pushTask(SubType::scan, &curr->cast<SetLocal>()->value); + break; + } + case Expression::Id::LoadId: { + self->pushTask(SubType::doVisitLoad, currp); + self->pushTask(SubType::scan, &curr->cast<Load>()->ptr); + break; + } + case Expression::Id::StoreId: { + self->pushTask(SubType::doVisitStore, currp); + self->pushTask(SubType::scan, &curr->cast<Store>()->value); + self->pushTask(SubType::scan, &curr->cast<Store>()->ptr); + break; } - // walk all the children - for (int i = int(stack.size()) - 1; i >= 0; i--) { - auto* block = stack[i]; - auto& children = block->list; - for (size_t j = 0; j < children.size(); j++) { - if (i < int(stack.size()) - 1 && j == 0) { - // this is one of the stacked blocks, no need to walk its children, we are doing that ourselves - WasmReplacerWalker<SubType, ReturnType>::walk(children[0]); - } else { - this->walk(children[j]); - } + case Expression::Id::ConstId: { + self->pushTask(SubType::doVisitConst, currp); + break; + } + case Expression::Id::UnaryId: { + self->pushTask(SubType::doVisitUnary, currp); + self->pushTask(SubType::scan, &curr->cast<Unary>()->value); + break; + } + case Expression::Id::BinaryId: { + self->pushTask(SubType::doVisitBinary, currp); + self->pushTask(SubType::scan, &curr->cast<Binary>()->right); + self->pushTask(SubType::scan, &curr->cast<Binary>()->left); + break; + } + case Expression::Id::SelectId: { + self->pushTask(SubType::doVisitSelect, currp); + self->pushTask(SubType::scan, &curr->cast<Select>()->condition); + self->pushTask(SubType::scan, &curr->cast<Select>()->ifFalse); + self->pushTask(SubType::scan, &curr->cast<Select>()->ifTrue); + break; + } + case Expression::Id::ReturnId: { + self->pushTask(SubType::doVisitReturn, currp); + self->maybePushTask(SubType::scan, &curr->cast<Return>()->value); + break; + } + case Expression::Id::HostId: { + self->pushTask(SubType::doVisitHost, currp); + auto& list = curr->cast<Host>()->operands; + for (int i = int(list.size()) - 1; i >= 0; i--) { + self->pushTask(SubType::scan, &list[i]); } + break; + } + case Expression::Id::NopId: { + self->pushTask(SubType::doVisitNop, currp); + break; } - // we walked all the children, and can rejoin later below to visit this node itself - } else { - // generic child-walking - ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr); + case Expression::Id::UnreachableId: { + self->pushTask(SubType::doVisitUnreachable, currp); + break; + } + default: WASM_UNREACHABLE(); } - - WasmReplacerWalker<SubType, ReturnType>::walk(curr); } }; @@ -350,111 +366,73 @@ struct WasmWalker : public WasmReplacerWalker<SubType, ReturnType> { // to noteNonLinear(). template<typename SubType> -struct FastExecutionWalker : public WasmReplacerWalker<SubType> { - FastExecutionWalker() {} +struct LinearExecutionWalker : public PostWalker<SubType> { + LinearExecutionWalker() {} - void noteNonLinear() {} + // subclasses should implement this + void noteNonLinear() { abort(); } -#define DELEGATE_noteNonLinear() \ - static_cast<SubType*>(this)->noteNonLinear() -#define DELEGATE_walk(ARG) \ - static_cast<SubType*>(this)->walk(ARG) - - void visitBlock(Block *curr) { - ExpressionList& list = curr->list; - for (size_t z = 0; z < list.size(); z++) { - DELEGATE_walk(list[z]); - } - } - void visitIf(If *curr) { - DELEGATE_walk(curr->condition); - DELEGATE_noteNonLinear(); - DELEGATE_walk(curr->ifTrue); - DELEGATE_noteNonLinear(); - DELEGATE_walk(curr->ifFalse); - DELEGATE_noteNonLinear(); - } - void visitLoop(Loop *curr) { - DELEGATE_noteNonLinear(); - DELEGATE_walk(curr->body); - } - void visitBreak(Break *curr) { - if (curr->value) DELEGATE_walk(curr->value); - if (curr->condition) DELEGATE_walk(curr->condition); - DELEGATE_noteNonLinear(); - } - void visitSwitch(Switch *curr) { - DELEGATE_walk(curr->condition); - if (curr->value) DELEGATE_walk(curr->value); - DELEGATE_noteNonLinear(); - } - void visitCall(Call *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - DELEGATE_walk(list[z]); - } - } - void visitCallImport(CallImport *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - DELEGATE_walk(list[z]); - } - } - void visitCallIndirect(CallIndirect *curr) { - DELEGATE_walk(curr->target); - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - DELEGATE_walk(list[z]); - } + static void doNoteNonLinear(SubType* self, Expression** currp) { + self->noteNonLinear(); } - void visitGetLocal(GetLocal *curr) {} - void visitSetLocal(SetLocal *curr) { - DELEGATE_walk(curr->value); - } - void visitLoad(Load *curr) { - DELEGATE_walk(curr->ptr); - } - void visitStore(Store *curr) { - DELEGATE_walk(curr->ptr); - DELEGATE_walk(curr->value); - } - void visitConst(Const *curr) {} - void visitUnary(Unary *curr) { - DELEGATE_walk(curr->value); - } - void visitBinary(Binary *curr) { - DELEGATE_walk(curr->left); - DELEGATE_walk(curr->right); - } - void visitSelect(Select *curr) { - DELEGATE_walk(curr->ifTrue); - DELEGATE_walk(curr->ifFalse); - DELEGATE_walk(curr->condition); - } - void visitReturn(Return *curr) { - DELEGATE_walk(curr->value); - DELEGATE_noteNonLinear(); - } - void visitHost(Host *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - DELEGATE_walk(list[z]); - } - } - void visitNop(Nop *curr) {} - void visitUnreachable(Unreachable *curr) {} - void visitFunctionType(FunctionType *curr) {} - void visitImport(Import *curr) {} - void visitExport(Export *curr) {} - void visitFunction(Function *curr) {} - void visitTable(Table *curr) {} - void visitMemory(Memory *curr) {} - void visitModule(Module *curr) {} + static void scan(SubType* self, Expression** currp) { -#undef DELEGATE_noteNonLinear -#undef DELEGATE_walk + Expression* curr = *currp; + switch (curr->_id) { + case Expression::Id::InvalidId: abort(); + case Expression::Id::BlockId: { + self->pushTask(SubType::doVisitBlock, currp); + self->pushTask(SubType::doNoteNonLinear, currp); + auto& list = curr->cast<Block>()->list; + for (int i = int(list.size()) - 1; i >= 0; i--) { + self->pushTask(SubType::scan, &list[i]); + } + break; + } + case Expression::Id::IfId: { + self->pushTask(SubType::doVisitIf, currp); + self->pushTask(SubType::doNoteNonLinear, currp); + self->maybePushTask(SubType::scan, &curr->cast<If>()->ifFalse); + self->pushTask(SubType::doNoteNonLinear, currp); + self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue); + self->pushTask(SubType::doNoteNonLinear, currp); + self->pushTask(SubType::scan, &curr->cast<If>()->condition); + break; + } + case Expression::Id::LoopId: { + self->pushTask(SubType::doVisitLoop, currp); + self->pushTask(SubType::scan, &curr->cast<Loop>()->body); + self->pushTask(SubType::doNoteNonLinear, currp); + break; + } + case Expression::Id::BreakId: { + self->pushTask(SubType::doVisitBreak, currp); + self->pushTask(SubType::doNoteNonLinear, currp); + self->maybePushTask(SubType::scan, &curr->cast<Break>()->condition); + self->maybePushTask(SubType::scan, &curr->cast<Break>()->value); + break; + } + case Expression::Id::SwitchId: { + self->pushTask(SubType::doVisitSwitch, currp); + self->pushTask(SubType::doNoteNonLinear, currp); + self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value); + self->pushTask(SubType::scan, &curr->cast<Switch>()->condition); + break; + } + case Expression::Id::ReturnId: { + self->pushTask(SubType::doVisitReturn, currp); + self->pushTask(SubType::doNoteNonLinear, currp); + self->maybePushTask(SubType::scan, &curr->cast<Return>()->value); + break; + } + default: { + // other node types do not have control flow, use regular post-order + PostWalker<SubType>::scan(self, currp); + } + } + } }; } // namespace wasm diff --git a/src/wasm-validator.h b/src/wasm-validator.h index 77f475990..7b2987dd9 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -25,7 +25,7 @@ namespace wasm { -struct WasmValidator : public WasmWalker<WasmValidator> { +struct WasmValidator : public PostWalker<WasmValidator> { bool valid; std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type @@ -118,7 +118,7 @@ public: private: // the "in" label has a none type, since no one can receive its value. make sure no one breaks to it with a value. - struct LoopChildChecker : public WasmWalker<LoopChildChecker> { + struct LoopChildChecker : public PostWalker<LoopChildChecker> { Name in; bool valid = true; diff --git a/src/wasm2asm.h b/src/wasm2asm.h index 0d8422ab1..2c330e273 100644 --- a/src/wasm2asm.h +++ b/src/wasm2asm.h @@ -392,7 +392,7 @@ Ref Wasm2AsmBuilder::processFunction(Function* func) { } void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) { - struct ExpressionScanner : public WasmWalker<ExpressionScanner> { + struct ExpressionScanner : public PostWalker<ExpressionScanner> { Wasm2AsmBuilder* parent; ExpressionScanner(Wasm2AsmBuilder* parent) : parent(parent) {} @@ -467,6 +467,9 @@ void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) { parent->setStatement(curr); } } + void visitReturn(Return *curr) { + abort(); + } void visitHost(Host *curr) { for (auto item : curr->operands) { if (parent->isStatement(item)) { @@ -480,7 +483,7 @@ void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) { } Ref Wasm2AsmBuilder::processFunctionBody(Expression* curr, IString result) { - struct ExpressionProcessor : public WasmVisitor<ExpressionProcessor, Ref> { + struct ExpressionProcessor : public Visitor<ExpressionProcessor, Ref> { Wasm2AsmBuilder* parent; IString result; ExpressionProcessor(Wasm2AsmBuilder* parent) : parent(parent) {} @@ -521,7 +524,7 @@ Ref Wasm2AsmBuilder::processFunctionBody(Expression* curr, IString result) { Ref visit(Expression* curr, IString nextResult) { IString old = result; result = nextResult; - Ref ret = WasmVisitor::visit(curr); + Ref ret = Visitor::visit(curr); result = old; // keep it consistent for the rest of this frame, which may call visit on multiple children return ret; } @@ -1083,6 +1086,9 @@ Ref Wasm2AsmBuilder::processFunctionBody(Expression* curr, IString result) { ) ); } + Ref visitReturn(Return *curr) { + abort(); + } Ref visitHost(Host *curr) { abort(); } |