diff options
author | Alon Zakai <alonzakai@gmail.com> | 2016-04-07 21:28:26 -0700 |
---|---|---|
committer | Alon Zakai <alonzakai@gmail.com> | 2016-04-07 21:28:26 -0700 |
commit | d30b98d47697daa167333db66ac0fe3d8a693eae (patch) | |
tree | 75ad13773b0422d8f42b184ab8544d69384ef7da /src | |
parent | c0f0be986d9009a05a3bbaf42c841b863d9b83c1 (diff) | |
parent | 540056ededd811b859e0cf4db9782d8cb7711215 (diff) | |
download | binaryen-d30b98d47697daa167333db66ac0fe3d8a693eae.tar.gz binaryen-d30b98d47697daa167333db66ac0fe3d8a693eae.tar.bz2 binaryen-d30b98d47697daa167333db66ac0fe3d8a693eae.zip |
Merge pull request #319 from WebAssembly/traversal
Refactor traversal into its own header
Diffstat (limited to 'src')
-rw-r--r-- | src/asm2wasm.h | 1 | ||||
-rw-r--r-- | src/ast_utils.h | 63 | ||||
-rw-r--r-- | src/binaryen-shell.cpp | 3 | ||||
-rw-r--r-- | src/pass.h | 1 | ||||
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 106 | ||||
-rw-r--r-- | src/passes/Vacuum.cpp | 48 | ||||
-rw-r--r-- | src/wasm-binary.h | 1 | ||||
-rw-r--r-- | src/wasm-traversal.h | 462 | ||||
-rw-r--r-- | src/wasm.h | 295 |
9 files changed, 672 insertions, 308 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h index 2f18b98f1..910134024 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -1606,6 +1606,7 @@ void Asm2WasmBuilder::optimize() { passRunner.add("optimize-instructions"); passRunner.add("simplify-locals"); passRunner.add("reorder-locals"); + passRunner.add("vacuum"); if (maxGlobal < 1024) { passRunner.add("post-emscripten"); } diff --git a/src/ast_utils.h b/src/ast_utils.h index df2ffe578..5ab427178 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -18,6 +18,7 @@ #define wasm_ast_utils_h #include "wasm.h" +#include "wasm-traversal.h" namespace wasm { @@ -38,6 +39,68 @@ struct BreakSeeker : public WasmWalker<BreakSeeker> { } }; +// Look for side effects, including control flow +// TODO: look at individual locals + +struct EffectAnalyzer : public WasmWalker<EffectAnalyzer> { + bool branches = false; + bool calls = false; + bool readsLocal = false; + bool writesLocal = false; + bool readsMemory = false; + bool writesMemory = false; + + bool accessesLocal() { return readsLocal || writesLocal; } + bool accessesMemory() { return calls || readsMemory || writesMemory; } + bool hasSideEffects() { return calls || writesLocal || writesMemory; } + bool hasAnything() { return branches || calls || readsLocal || writesLocal || readsMemory || writesMemory; } + + // checks if these effects would invalidate another set (e.g., if we write, we invalidate someone that reads, they can't be moved past us) + bool invalidates(EffectAnalyzer& other) { + return branches || other.branches + || ((writesMemory || calls) && other.accessesMemory()) || (writesLocal && other.accessesLocal()) + || (accessesMemory() && (other.writesMemory || other.calls)) || (accessesLocal() && other.writesLocal); + } + + // the checks above happen after the node's children were processed, in the order of execution + // we must also check for control flow that happens before the children, i.e., loops + bool checkPre(Expression* curr) { + if (curr->is<Loop>()) { + branches = true; + return true; + } + return false; + } + + bool checkPost(Expression* curr) { + visit(curr); + return hasAnything(); + } + + void visitBlock(Block *curr) { branches = true; } + void visitLoop(Loop *curr) { branches = true; } + void visitIf(If *curr) { branches = true; } + void visitBreak(Break *curr) { branches = true; } + void visitSwitch(Switch *curr) { branches = true; } + void visitCall(Call *curr) { calls = true; } + void visitCallImport(CallImport *curr) { calls = true; } + void visitCallIndirect(CallIndirect *curr) { calls = true; } + void visitGetLocal(GetLocal *curr) { readsLocal = true; } + void visitSetLocal(SetLocal *curr) { writesLocal = true; } + void visitLoad(Load *curr) { readsMemory = true; } + void visitStore(Store *curr) { writesMemory = true; } + void visitReturn(Return *curr) { branches = true; } + void visitHost(Host *curr) { calls = true; } + void visitUnreachable(Unreachable *curr) { branches = true; } +}; + +struct ExpressionManipulator { + // Nop is the smallest node, so we can always nop-ify another node in our arena + static void nop(Expression* target) { + *static_cast<Nop*>(target) = Nop(); + } +}; + } // namespace wasm #endif // wasm_ast_utils_h diff --git a/src/binaryen-shell.cpp b/src/binaryen-shell.cpp index 14cc66406..7f5b3077e 100644 --- a/src/binaryen-shell.cpp +++ b/src/binaryen-shell.cpp @@ -177,7 +177,8 @@ int main(int argc, const char* argv[]) { static const char* default_passes[] = {"remove-unused-brs", "remove-unused-names", "merge-blocks", "optimize-instructions", - "simplify-locals", "reorder-locals"}; + "simplify-locals", "reorder-locals", + "vacuum"}; Options options("binaryen-shell", "Execute .wast files"); options diff --git a/src/pass.h b/src/pass.h index 41ef30b90..e3716545d 100644 --- a/src/pass.h +++ b/src/pass.h @@ -20,6 +20,7 @@ #include <functional> #include "wasm.h" +#include "wasm-traversal.h" #include "mixed_arena.h" namespace wasm { diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index cbfc0dd66..0d59b8759 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -15,28 +15,110 @@ */ // -// Miscellaneous locals-related optimizations +// Locals-related optimizations // +// This "sinks" set_locals, pushing them to the next get_local where possible #include <wasm.h> +#include <wasm-traversal.h> #include <pass.h> +#include <ast_utils.h> namespace wasm { -struct SimplifyLocals : public WalkerPass<WasmWalker<SimplifyLocals>> { +struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> { + struct SinkableInfo { + Expression** item; + EffectAnalyzer effects; + + SinkableInfo(Expression** item) : item(item) { + effects.walk(*item); + } + }; + + // locals in current linear execution trace, which we try to sink + std::map<Name, SinkableInfo> sinkables; + + void noteNonLinear() { + sinkables.clear(); + } + void visitBlock(Block *curr) { - // look for pairs of setlocal-getlocal, which can be just a setlocal (since it returns a value) - if (curr->list.size() == 0) return; - for (size_t i = 0; i < curr->list.size() - 1; i++) { - auto set = curr->list[i]->dyn_cast<SetLocal>(); - if (!set) continue; - auto get = curr->list[i + 1]->dyn_cast<GetLocal>(); - if (!get) continue; - if (set->name != get->name) continue; - curr->list.erase(curr->list.begin() + i + 1); - i -= 1; + // note locals, we can sink them from here TODO sink from elsewhere? + derecurseBlocks(curr, [&](Block* block) { + // curr was already checked by walk() + if (block != curr) checkPre(block); + }, [&](Block* block, Expression*& child) { + walk(child); + if (child->is<SetLocal>()) { + Name name = child->cast<SetLocal>()->name; + assert(sinkables.count(name) == 0); + sinkables.emplace(std::make_pair(name, SinkableInfo(&child))); + } + }, [&](Block* block) { + if (block != curr) checkPost(block); + }); + } + + void visitGetLocal(GetLocal *curr) { + auto found = sinkables.find(curr->name); + if (found != sinkables.end()) { + // sink it, and nop the origin TODO: clean up nops + replaceCurrent(*found->second.item); + // reuse the getlocal that is dying + *found->second.item = curr; + ExpressionManipulator::nop(curr); + sinkables.erase(found); + } + } + + void visitSetLocal(SetLocal *curr) { + walk(curr->value); + // if we are a potentially-sinkable thing, forget it - this + // write overrides the last TODO: optimizable + // TODO: if no get_locals left, can remove the set as well (== expressionizer in emscripten optimizer) + auto found = sinkables.find(curr->name); + if (found != sinkables.end()) { + sinkables.erase(found); + } + } + + void checkInvalidations(EffectAnalyzer& effects) { + // TODO: this is O(bad) + std::vector<Name> invalidated; + for (auto& sinkable : sinkables) { + if (effects.invalidates(sinkable.second.effects)) { + invalidated.push_back(sinkable.first); + } + } + for (auto name : invalidated) { + sinkables.erase(name); } } + + void checkPre(Expression* curr) { + EffectAnalyzer effects; + if (effects.checkPre(curr)) { + checkInvalidations(effects); + } + } + + void checkPost(Expression* curr) { + EffectAnalyzer effects; + if (effects.checkPost(curr)) { + checkInvalidations(effects); + } + } + + void walk(Expression*& curr) override { + if (!curr) return; + + checkPre(curr); + + FastExecutionWalker::walk(curr); + + checkPost(curr); + } }; static RegisterPass<SimplifyLocals> registerPass("simplify-locals", "miscellaneous locals-related optimizations"); diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp new file mode 100644 index 000000000..f9704ed8d --- /dev/null +++ b/src/passes/Vacuum.cpp @@ -0,0 +1,48 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Removes obviously unneeded code +// + +#include <wasm.h> +#include <pass.h> + +namespace wasm { + +struct Vacuum : public WalkerPass<WasmWalker<Vacuum>> { + void visitBlock(Block *curr) { + // compress out nops + int skip = 0; + auto& list = curr->list; + size_t size = list.size(); + for (size_t z = 0; z < size; z++) { + if (list[z]->is<Nop>()) { + skip++; + } else if (skip > 0) { + list[z - skip] = list[z]; + } + } + if (skip > 0) { + list.resize(size - skip); + } + } +}; + +static RegisterPass<Vacuum> registerPass("vacuum", "removes obviously unneeded code"); + +} // namespace wasm + diff --git a/src/wasm-binary.h b/src/wasm-binary.h index f73ec7c44..f7fb4b8f3 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -25,6 +25,7 @@ #include <ostream> #include "wasm.h" +#include "wasm-traversal.h" #include "shared-constants.h" #include "asm_v_wasm.h" diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h new file mode 100644 index 000000000..24ec4905c --- /dev/null +++ b/src/wasm-traversal.h @@ -0,0 +1,462 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// WebAssembly AST visitor. Useful for anything that wants to do something +// different for each AST node type, like printing, interpreting, etc. +// +// This class is specifically designed as a template to avoid virtual function +// call overhead. To write a visitor, derive from this class as follows: +// +// struct MyVisitor : public WasmVisitor<MyVisitor> { .. } +// + +#ifndef wasm_traversal_h +#define wasm_traversal_h + +#include "wasm.h" + +namespace wasm { + +template<typename SubType, typename ReturnType> +struct WasmVisitor { + virtual ~WasmVisitor() {} + // Expression visitors + ReturnType visitBlock(Block *curr) { abort(); } + ReturnType visitIf(If *curr) { abort(); } + ReturnType visitLoop(Loop *curr) { abort(); } + ReturnType visitBreak(Break *curr) { abort(); } + ReturnType visitSwitch(Switch *curr) { abort(); } + ReturnType visitCall(Call *curr) { abort(); } + ReturnType visitCallImport(CallImport *curr) { abort(); } + ReturnType visitCallIndirect(CallIndirect *curr) { abort(); } + ReturnType visitGetLocal(GetLocal *curr) { abort(); } + ReturnType visitSetLocal(SetLocal *curr) { abort(); } + ReturnType visitLoad(Load *curr) { abort(); } + ReturnType visitStore(Store *curr) { abort(); } + ReturnType visitConst(Const *curr) { abort(); } + ReturnType visitUnary(Unary *curr) { abort(); } + ReturnType visitBinary(Binary *curr) { abort(); } + ReturnType visitSelect(Select *curr) { abort(); } + ReturnType visitReturn(Return *curr) { abort(); } + ReturnType visitHost(Host *curr) { abort(); } + ReturnType visitNop(Nop *curr) { abort(); } + ReturnType visitUnreachable(Unreachable *curr) { abort(); } + // Module-level visitors + ReturnType visitFunctionType(FunctionType *curr) { abort(); } + ReturnType visitImport(Import *curr) { abort(); } + ReturnType visitExport(Export *curr) { abort(); } + ReturnType visitFunction(Function *curr) { abort(); } + ReturnType visitTable(Table *curr) { abort(); } + ReturnType visitMemory(Memory *curr) { abort(); } + ReturnType visitModule(Module *curr) { abort(); } + +#define DELEGATE(CLASS_TO_VISIT) \ + return static_cast<SubType*>(this)-> \ + visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr)) + + ReturnType visit(Expression *curr) { + assert(curr); + switch (curr->_id) { + case Expression::Id::InvalidId: abort(); + case Expression::Id::BlockId: DELEGATE(Block); + case Expression::Id::IfId: DELEGATE(If); + case Expression::Id::LoopId: DELEGATE(Loop); + case Expression::Id::BreakId: DELEGATE(Break); + case Expression::Id::SwitchId: DELEGATE(Switch); + case Expression::Id::CallId: DELEGATE(Call); + case Expression::Id::CallImportId: DELEGATE(CallImport); + case Expression::Id::CallIndirectId: DELEGATE(CallIndirect); + case Expression::Id::GetLocalId: DELEGATE(GetLocal); + case Expression::Id::SetLocalId: DELEGATE(SetLocal); + case Expression::Id::LoadId: DELEGATE(Load); + case Expression::Id::StoreId: DELEGATE(Store); + case Expression::Id::ConstId: DELEGATE(Const); + case Expression::Id::UnaryId: DELEGATE(Unary); + case Expression::Id::BinaryId: DELEGATE(Binary); + case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::ReturnId: DELEGATE(Return); + case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::NopId: DELEGATE(Nop); + case Expression::Id::UnreachableId: DELEGATE(Unreachable); + default: WASM_UNREACHABLE(); + } + } + +#undef DELEGATE + + // Helper method to de-recurse blocks, which often nest in their first position very heavily + void derecurseBlocks(Block* block, std::function<void (Block*)> preBlock, + std::function<void (Block*, Expression*&)> onChild, + std::function<void (Block*)> postBlock) { + std::vector<Block*> stack; + stack.push_back(block); + while (block->list.size() > 0 && block->list[0]->is<Block>()) { + block = block->list[0]->cast<Block>(); + stack.push_back(block); + } + for (size_t i = 0; i < stack.size(); i++) { + preBlock(stack[i]); + } + for (int i = int(stack.size()) - 1; i >= 0; i--) { + auto* block = stack[i]; + auto& list = block->list; + for (size_t j = 0; j < list.size(); j++) { + if (i < int(stack.size()) - 1 && j == 0) { + // nested block, we already called its pre + } else { + onChild(block, list[j]); + } + } + postBlock(block); + } + } +}; + +// +// Base class for all WasmWalkers +// +template<typename SubType, typename ReturnType = void> +struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> { + virtual void walk(Expression*& curr) { abort(); } + + void startWalk(Function *func) { + walk(func->body); + } + + void startWalk(Module *module) { + // Dispatch statically through the SubType. + SubType* self = static_cast<SubType*>(this); + for (auto curr : module->functionTypes) { + self->visitFunctionType(curr); + } + for (auto curr : module->imports) { + self->visitImport(curr); + } + for (auto curr : module->exports) { + self->visitExport(curr); + } + for (auto curr : module->functions) { + startWalk(curr); + self->visitFunction(curr); + } + self->visitTable(&module->table); + self->visitMemory(&module->memory); + self->visitModule(module); + } +}; + +template<typename ParentType> +struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> { + ParentType& parent; + + ChildWalker(ParentType& parent) : parent(parent) {} + + void visitBlock(Block *curr) { + ExpressionList& list = curr->list; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitIf(If *curr) { + parent.walk(curr->condition); + parent.walk(curr->ifTrue); + parent.walk(curr->ifFalse); + } + void visitLoop(Loop *curr) { + parent.walk(curr->body); + } + void visitBreak(Break *curr) { + parent.walk(curr->condition); + parent.walk(curr->value); + } + void visitSwitch(Switch *curr) { + parent.walk(curr->condition); + if (curr->value) parent.walk(curr->value); + } + void visitCall(Call *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitCallImport(CallImport *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitCallIndirect(CallIndirect *curr) { + parent.walk(curr->target); + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitGetLocal(GetLocal *curr) {} + void visitSetLocal(SetLocal *curr) { + parent.walk(curr->value); + } + void visitLoad(Load *curr) { + parent.walk(curr->ptr); + } + void visitStore(Store *curr) { + parent.walk(curr->ptr); + parent.walk(curr->value); + } + void visitConst(Const *curr) {} + void visitUnary(Unary *curr) { + parent.walk(curr->value); + } + void visitBinary(Binary *curr) { + parent.walk(curr->left); + parent.walk(curr->right); + } + void visitSelect(Select *curr) { + parent.walk(curr->ifTrue); + parent.walk(curr->ifFalse); + parent.walk(curr->condition); + } + void visitReturn(Return *curr) { + parent.walk(curr->value); + } + void visitHost(Host *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitNop(Nop *curr) {} + void visitUnreachable(Unreachable *curr) {} +}; + +// Walker that allows replacements +template<typename SubType, typename ReturnType = void> +struct WasmReplacerWalker : public WasmWalkerBase<SubType, ReturnType> { + Expression* replace = nullptr; + + // methods can call this to replace the current node + void replaceCurrent(Expression *expression) { + replace = expression; + } + + void walk(Expression*& curr) override { + if (!curr) return; + + this->visit(curr); + + if (replace) { + curr = replace; + replace = nullptr; + } + } +}; + +// +// Simple WebAssembly children-first walking (i.e., post-order, if you look +// at the children as subtrees of the current node), with the ability to replace +// the current expression node. Useful for writing optimization passes. +// + +template<typename SubType, typename ReturnType = void> +struct WasmWalker : public WasmReplacerWalker<SubType, ReturnType> { + // By default, do nothing + ReturnType visitBlock(Block *curr) {} + ReturnType visitIf(If *curr) {} + ReturnType visitLoop(Loop *curr) {} + ReturnType visitBreak(Break *curr) {} + ReturnType visitSwitch(Switch *curr) {} + ReturnType visitCall(Call *curr) {} + ReturnType visitCallImport(CallImport *curr) {} + ReturnType visitCallIndirect(CallIndirect *curr) {} + ReturnType visitGetLocal(GetLocal *curr) {} + ReturnType visitSetLocal(SetLocal *curr) {} + ReturnType visitLoad(Load *curr) {} + ReturnType visitStore(Store *curr) {} + ReturnType visitConst(Const *curr) {} + ReturnType visitUnary(Unary *curr) {} + ReturnType visitBinary(Binary *curr) {} + ReturnType visitSelect(Select *curr) {} + ReturnType visitReturn(Return *curr) {} + ReturnType visitHost(Host *curr) {} + ReturnType visitNop(Nop *curr) {} + ReturnType visitUnreachable(Unreachable *curr) {} + + ReturnType visitFunctionType(FunctionType *curr) {} + ReturnType visitImport(Import *curr) {} + ReturnType visitExport(Export *curr) {} + ReturnType visitFunction(Function *curr) {} + ReturnType visitTable(Table *curr) {} + ReturnType visitMemory(Memory *curr) {} + ReturnType visitModule(Module *curr) {} + + // children-first + void walk(Expression*& curr) override { + if (!curr) return; + + // special-case Block, because Block nesting (in their first element) can be incredibly deep + if (curr->is<Block>()) { + auto* block = curr->dyn_cast<Block>(); + std::vector<Block*> stack; + stack.push_back(block); + while (block->list.size() > 0 && block->list[0]->is<Block>()) { + block = block->list[0]->cast<Block>(); + stack.push_back(block); + } + // walk all the children + for (int i = int(stack.size()) - 1; i >= 0; i--) { + auto* block = stack[i]; + auto& children = block->list; + for (size_t j = 0; j < children.size(); j++) { + if (i < int(stack.size()) - 1 && j == 0) { + // this is one of the stacked blocks, no need to walk its children, we are doing that ourselves + WasmReplacerWalker<SubType, ReturnType>::walk(children[0]); + } else { + this->walk(children[j]); + } + } + } + // we walked all the children, and can rejoin later below to visit this node itself + } else { + // generic child-walking + ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr); + } + + WasmReplacerWalker<SubType, ReturnType>::walk(curr); + } +}; + +// Traversal in the order of execution. This is quick and simple, but +// does not provide the same comprehensive information that a full +// conversion to basic blocks would. What it does give is a quick +// way to view straightline execution traces, i.e., that have no +// branching. This can let optimizations get most of what they +// want without the cost of creating another AST. +// +// When execution is no longer linear, this notifies via a call +// to noteNonLinear(). + +template<typename SubType> +struct FastExecutionWalker : public WasmReplacerWalker<SubType> { + FastExecutionWalker() {} + + void noteNonLinear() {} + +#define DELEGATE_noteNonLinear() \ + static_cast<SubType*>(this)->noteNonLinear() +#define DELEGATE_walk(ARG) \ + static_cast<SubType*>(this)->walk(ARG) + + void visitBlock(Block *curr) { + ExpressionList& list = curr->list; + for (size_t z = 0; z < list.size(); z++) { + DELEGATE_walk(list[z]); + } + } + void visitIf(If *curr) { + DELEGATE_walk(curr->condition); + DELEGATE_noteNonLinear(); + DELEGATE_walk(curr->ifTrue); + DELEGATE_noteNonLinear(); + DELEGATE_walk(curr->ifFalse); + DELEGATE_noteNonLinear(); + } + void visitLoop(Loop *curr) { + DELEGATE_noteNonLinear(); + DELEGATE_walk(curr->body); + } + void visitBreak(Break *curr) { + if (curr->value) DELEGATE_walk(curr->value); + if (curr->condition) DELEGATE_walk(curr->condition); + DELEGATE_noteNonLinear(); + } + void visitSwitch(Switch *curr) { + DELEGATE_walk(curr->condition); + if (curr->value) DELEGATE_walk(curr->value); + DELEGATE_noteNonLinear(); + } + void visitCall(Call *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + DELEGATE_walk(list[z]); + } + } + void visitCallImport(CallImport *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + DELEGATE_walk(list[z]); + } + } + void visitCallIndirect(CallIndirect *curr) { + DELEGATE_walk(curr->target); + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + DELEGATE_walk(list[z]); + } + } + void visitGetLocal(GetLocal *curr) {} + void visitSetLocal(SetLocal *curr) { + DELEGATE_walk(curr->value); + } + void visitLoad(Load *curr) { + DELEGATE_walk(curr->ptr); + } + void visitStore(Store *curr) { + DELEGATE_walk(curr->ptr); + DELEGATE_walk(curr->value); + } + void visitConst(Const *curr) {} + void visitUnary(Unary *curr) { + DELEGATE_walk(curr->value); + } + void visitBinary(Binary *curr) { + DELEGATE_walk(curr->left); + DELEGATE_walk(curr->right); + } + void visitSelect(Select *curr) { + DELEGATE_walk(curr->ifTrue); + DELEGATE_walk(curr->ifFalse); + DELEGATE_walk(curr->condition); + } + void visitReturn(Return *curr) { + DELEGATE_walk(curr->value); + DELEGATE_noteNonLinear(); + } + void visitHost(Host *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + DELEGATE_walk(list[z]); + } + } + void visitNop(Nop *curr) {} + void visitUnreachable(Unreachable *curr) {} + + void visitFunctionType(FunctionType *curr) {} + void visitImport(Import *curr) {} + void visitExport(Export *curr) {} + void visitFunction(Function *curr) {} + void visitTable(Table *curr) {} + void visitMemory(Memory *curr) {} + void visitModule(Module *curr) {} + +#undef DELEGATE_noteNonLinear +#undef DELEGATE_walk + +}; + +} // namespace wasm + +#endif // wasm_traversal_h diff --git a/src/wasm.h b/src/wasm.h index a07ab3079..f985e9b59 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1212,301 +1212,6 @@ class AllocatingModule : public Module { MixedArena allocator; }; -// -// WebAssembly AST visitor. Useful for anything that wants to do something -// different for each AST node type, like printing, interpreting, etc. -// -// This class is specifically designed as a template to avoid virtual function -// call overhead. To write a visitor, derive from this class as follows: -// -// struct MyVisitor : public WasmVisitor<MyVisitor> { .. } -// - -template<typename SubType, typename ReturnType> -struct WasmVisitor { - virtual ~WasmVisitor() {} - // should be pure virtual, but https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51048 - // Expression visitors - ReturnType visitBlock(Block *curr) { abort(); } - ReturnType visitIf(If *curr) { abort(); } - ReturnType visitLoop(Loop *curr) { abort(); } - ReturnType visitBreak(Break *curr) { abort(); } - ReturnType visitSwitch(Switch *curr) { abort(); } - ReturnType visitCall(Call *curr) { abort(); } - ReturnType visitCallImport(CallImport *curr) { abort(); } - ReturnType visitCallIndirect(CallIndirect *curr) { abort(); } - ReturnType visitGetLocal(GetLocal *curr) { abort(); } - ReturnType visitSetLocal(SetLocal *curr) { abort(); } - ReturnType visitLoad(Load *curr) { abort(); } - ReturnType visitStore(Store *curr) { abort(); } - ReturnType visitConst(Const *curr) { abort(); } - ReturnType visitUnary(Unary *curr) { abort(); } - ReturnType visitBinary(Binary *curr) { abort(); } - ReturnType visitSelect(Select *curr) { abort(); } - ReturnType visitReturn(Return *curr) { abort(); } - ReturnType visitHost(Host *curr) { abort(); } - ReturnType visitNop(Nop *curr) { abort(); } - ReturnType visitUnreachable(Unreachable *curr) { abort(); } - // Module-level visitors - ReturnType visitFunctionType(FunctionType *curr) { abort(); } - ReturnType visitImport(Import *curr) { abort(); } - ReturnType visitExport(Export *curr) { abort(); } - ReturnType visitFunction(Function *curr) { abort(); } - ReturnType visitTable(Table *curr) { abort(); } - ReturnType visitMemory(Memory *curr) { abort(); } - ReturnType visitModule(Module *curr) { abort(); } - -#define DELEGATE(CLASS_TO_VISIT) \ - return static_cast<SubType*>(this)-> \ - visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr)) - - ReturnType visit(Expression *curr) { - assert(curr); - switch (curr->_id) { - case Expression::Id::InvalidId: abort(); - case Expression::Id::BlockId: DELEGATE(Block); - case Expression::Id::IfId: DELEGATE(If); - case Expression::Id::LoopId: DELEGATE(Loop); - case Expression::Id::BreakId: DELEGATE(Break); - case Expression::Id::SwitchId: DELEGATE(Switch); - case Expression::Id::CallId: DELEGATE(Call); - case Expression::Id::CallImportId: DELEGATE(CallImport); - case Expression::Id::CallIndirectId: DELEGATE(CallIndirect); - case Expression::Id::GetLocalId: DELEGATE(GetLocal); - case Expression::Id::SetLocalId: DELEGATE(SetLocal); - case Expression::Id::LoadId: DELEGATE(Load); - case Expression::Id::StoreId: DELEGATE(Store); - case Expression::Id::ConstId: DELEGATE(Const); - case Expression::Id::UnaryId: DELEGATE(Unary); - case Expression::Id::BinaryId: DELEGATE(Binary); - case Expression::Id::SelectId: DELEGATE(Select); - case Expression::Id::ReturnId: DELEGATE(Return); - case Expression::Id::HostId: DELEGATE(Host); - case Expression::Id::NopId: DELEGATE(Nop); - case Expression::Id::UnreachableId: DELEGATE(Unreachable); - default: WASM_UNREACHABLE(); - } - } -}; - -// -// Base class for all WasmWalkers -// -template<typename SubType, typename ReturnType = void> -struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> { - virtual void walk(Expression*& curr) { abort(); } - virtual void startWalk(Function *func) { abort(); } - virtual void startWalk(Module *module) { abort(); } -}; - -template<typename ParentType> -struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> { - ParentType& parent; - - ChildWalker(ParentType& parent) : parent(parent) {} - - void visitBlock(Block *curr) { - ExpressionList& list = curr->list; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitIf(If *curr) { - parent.walk(curr->condition); - parent.walk(curr->ifTrue); - parent.walk(curr->ifFalse); - } - void visitLoop(Loop *curr) { - parent.walk(curr->body); - } - void visitBreak(Break *curr) { - parent.walk(curr->condition); - parent.walk(curr->value); - } - void visitSwitch(Switch *curr) { - parent.walk(curr->condition); - if (curr->value) parent.walk(curr->value); - } - void visitCall(Call *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitCallImport(CallImport *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitCallIndirect(CallIndirect *curr) { - parent.walk(curr->target); - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitGetLocal(GetLocal *curr) {} - void visitSetLocal(SetLocal *curr) { - parent.walk(curr->value); - } - void visitLoad(Load *curr) { - parent.walk(curr->ptr); - } - void visitStore(Store *curr) { - parent.walk(curr->ptr); - parent.walk(curr->value); - } - void visitConst(Const *curr) {} - void visitUnary(Unary *curr) { - parent.walk(curr->value); - } - void visitBinary(Binary *curr) { - parent.walk(curr->left); - parent.walk(curr->right); - } - void visitSelect(Select *curr) { - parent.walk(curr->ifTrue); - parent.walk(curr->ifFalse); - parent.walk(curr->condition); - } - void visitReturn(Return *curr) { - parent.walk(curr->value); - } - void visitHost(Host *curr) { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitNop(Nop *curr) {} - void visitUnreachable(Unreachable *curr) {} -}; - -// -// Simple WebAssembly children-first walking (i.e., post-order, if you look -// at the children as subtrees of the current node), with the ability to replace -// the current expression node. Useful for writing optimization passes. -// - -template<typename SubType, typename ReturnType = void> -struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> { - Expression* replace; - - WasmWalker() : replace(nullptr) {} - - // the visit* methods can call this to replace the current node - void replaceCurrent(Expression *expression) { - replace = expression; - } - - // By default, do nothing - ReturnType visitBlock(Block *curr) {} - ReturnType visitIf(If *curr) {} - ReturnType visitLoop(Loop *curr) {} - ReturnType visitBreak(Break *curr) {} - ReturnType visitSwitch(Switch *curr) {} - ReturnType visitCall(Call *curr) {} - ReturnType visitCallImport(CallImport *curr) {} - ReturnType visitCallIndirect(CallIndirect *curr) {} - ReturnType visitGetLocal(GetLocal *curr) {} - ReturnType visitSetLocal(SetLocal *curr) {} - ReturnType visitLoad(Load *curr) {} - ReturnType visitStore(Store *curr) {} - ReturnType visitConst(Const *curr) {} - ReturnType visitUnary(Unary *curr) {} - ReturnType visitBinary(Binary *curr) {} - ReturnType visitSelect(Select *curr) {} - ReturnType visitReturn(Return *curr) {} - ReturnType visitHost(Host *curr) {} - ReturnType visitNop(Nop *curr) {} - ReturnType visitUnreachable(Unreachable *curr) {} - - ReturnType visitFunctionType(FunctionType *curr) {} - ReturnType visitImport(Import *curr) {} - ReturnType visitExport(Export *curr) {} - ReturnType visitFunction(Function *curr) {} - ReturnType visitTable(Table *curr) {} - ReturnType visitMemory(Memory *curr) {} - ReturnType visitModule(Module *curr) {} - - // children-first - void walk(Expression*& curr) override { - if (!curr) return; - - // special-case Block, because Block nesting (in their first element) can be incredibly deep - if (curr->is<Block>()) { - auto* block = curr->dyn_cast<Block>(); - std::vector<Block*> stack; - stack.push_back(block); - while (block->list.size() > 0 && block->list[0]->is<Block>()) { - block = block->list[0]->cast<Block>(); - stack.push_back(block); - } - // walk all the children - for (int i = int(stack.size()) - 1; i >= 0; i--) { - auto* block = stack[i]; - auto& children = block->list; - for (size_t j = 0; j < children.size(); j++) { - if (i < int(stack.size()) - 1 && j == 0) { - // this is one of the stacked blocks, no need to walk its children, we are doing that ourselves - this->visit(children[0]); - if (replace) { - children[0] = replace; - replace = nullptr; - } - } else { - this->walk(children[j]); - } - } - } - // we walked all the children, and can rejoin later below to visit this node itself - } else { - // generic child-walking - ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr); - } - - this->visit(curr); - - if (replace) { - curr = replace; - replace = nullptr; - } - } - - void startWalk(Function *func) override { - walk(func->body); - } - - void startWalk(Module *module) override { - // Dispatch statically through the SubType. - SubType* self = static_cast<SubType*>(this); - for (auto curr : module->functionTypes) { - self->visitFunctionType(curr); - assert(!replace); - } - for (auto curr : module->imports) { - self->visitImport(curr); - assert(!replace); - } - for (auto curr : module->exports) { - self->visitExport(curr); - assert(!replace); - } - for (auto curr : module->functions) { - startWalk(curr); - self->visitFunction(curr); - assert(!replace); - } - self->visitTable(&module->table); - assert(!replace); - self->visitMemory(&module->memory); - assert(!replace); - self->visitModule(module); - assert(!replace); - } -}; - } // namespace wasm #endif // wasm_wasm_h |