summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlon Zakai <alonzakai@gmail.com>2016-04-07 21:28:26 -0700
committerAlon Zakai <alonzakai@gmail.com>2016-04-07 21:28:26 -0700
commitd30b98d47697daa167333db66ac0fe3d8a693eae (patch)
tree75ad13773b0422d8f42b184ab8544d69384ef7da /src
parentc0f0be986d9009a05a3bbaf42c841b863d9b83c1 (diff)
parent540056ededd811b859e0cf4db9782d8cb7711215 (diff)
downloadbinaryen-d30b98d47697daa167333db66ac0fe3d8a693eae.tar.gz
binaryen-d30b98d47697daa167333db66ac0fe3d8a693eae.tar.bz2
binaryen-d30b98d47697daa167333db66ac0fe3d8a693eae.zip
Merge pull request #319 from WebAssembly/traversal
Refactor traversal into its own header
Diffstat (limited to 'src')
-rw-r--r--src/asm2wasm.h1
-rw-r--r--src/ast_utils.h63
-rw-r--r--src/binaryen-shell.cpp3
-rw-r--r--src/pass.h1
-rw-r--r--src/passes/SimplifyLocals.cpp106
-rw-r--r--src/passes/Vacuum.cpp48
-rw-r--r--src/wasm-binary.h1
-rw-r--r--src/wasm-traversal.h462
-rw-r--r--src/wasm.h295
9 files changed, 672 insertions, 308 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h
index 2f18b98f1..910134024 100644
--- a/src/asm2wasm.h
+++ b/src/asm2wasm.h
@@ -1606,6 +1606,7 @@ void Asm2WasmBuilder::optimize() {
passRunner.add("optimize-instructions");
passRunner.add("simplify-locals");
passRunner.add("reorder-locals");
+ passRunner.add("vacuum");
if (maxGlobal < 1024) {
passRunner.add("post-emscripten");
}
diff --git a/src/ast_utils.h b/src/ast_utils.h
index df2ffe578..5ab427178 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -18,6 +18,7 @@
#define wasm_ast_utils_h
#include "wasm.h"
+#include "wasm-traversal.h"
namespace wasm {
@@ -38,6 +39,68 @@ struct BreakSeeker : public WasmWalker<BreakSeeker> {
}
};
+// Look for side effects, including control flow
+// TODO: look at individual locals
+
+struct EffectAnalyzer : public WasmWalker<EffectAnalyzer> {
+ bool branches = false;
+ bool calls = false;
+ bool readsLocal = false;
+ bool writesLocal = false;
+ bool readsMemory = false;
+ bool writesMemory = false;
+
+ bool accessesLocal() { return readsLocal || writesLocal; }
+ bool accessesMemory() { return calls || readsMemory || writesMemory; }
+ bool hasSideEffects() { return calls || writesLocal || writesMemory; }
+ bool hasAnything() { return branches || calls || readsLocal || writesLocal || readsMemory || writesMemory; }
+
+ // checks if these effects would invalidate another set (e.g., if we write, we invalidate someone that reads, they can't be moved past us)
+ bool invalidates(EffectAnalyzer& other) {
+ return branches || other.branches
+ || ((writesMemory || calls) && other.accessesMemory()) || (writesLocal && other.accessesLocal())
+ || (accessesMemory() && (other.writesMemory || other.calls)) || (accessesLocal() && other.writesLocal);
+ }
+
+ // the checks above happen after the node's children were processed, in the order of execution
+ // we must also check for control flow that happens before the children, i.e., loops
+ bool checkPre(Expression* curr) {
+ if (curr->is<Loop>()) {
+ branches = true;
+ return true;
+ }
+ return false;
+ }
+
+ bool checkPost(Expression* curr) {
+ visit(curr);
+ return hasAnything();
+ }
+
+ void visitBlock(Block *curr) { branches = true; }
+ void visitLoop(Loop *curr) { branches = true; }
+ void visitIf(If *curr) { branches = true; }
+ void visitBreak(Break *curr) { branches = true; }
+ void visitSwitch(Switch *curr) { branches = true; }
+ void visitCall(Call *curr) { calls = true; }
+ void visitCallImport(CallImport *curr) { calls = true; }
+ void visitCallIndirect(CallIndirect *curr) { calls = true; }
+ void visitGetLocal(GetLocal *curr) { readsLocal = true; }
+ void visitSetLocal(SetLocal *curr) { writesLocal = true; }
+ void visitLoad(Load *curr) { readsMemory = true; }
+ void visitStore(Store *curr) { writesMemory = true; }
+ void visitReturn(Return *curr) { branches = true; }
+ void visitHost(Host *curr) { calls = true; }
+ void visitUnreachable(Unreachable *curr) { branches = true; }
+};
+
+struct ExpressionManipulator {
+ // Nop is the smallest node, so we can always nop-ify another node in our arena
+ static void nop(Expression* target) {
+ *static_cast<Nop*>(target) = Nop();
+ }
+};
+
} // namespace wasm
#endif // wasm_ast_utils_h
diff --git a/src/binaryen-shell.cpp b/src/binaryen-shell.cpp
index 14cc66406..7f5b3077e 100644
--- a/src/binaryen-shell.cpp
+++ b/src/binaryen-shell.cpp
@@ -177,7 +177,8 @@ int main(int argc, const char* argv[]) {
static const char* default_passes[] = {"remove-unused-brs",
"remove-unused-names", "merge-blocks",
"optimize-instructions",
- "simplify-locals", "reorder-locals"};
+ "simplify-locals", "reorder-locals",
+ "vacuum"};
Options options("binaryen-shell", "Execute .wast files");
options
diff --git a/src/pass.h b/src/pass.h
index 41ef30b90..e3716545d 100644
--- a/src/pass.h
+++ b/src/pass.h
@@ -20,6 +20,7 @@
#include <functional>
#include "wasm.h"
+#include "wasm-traversal.h"
#include "mixed_arena.h"
namespace wasm {
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index cbfc0dd66..0d59b8759 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -15,28 +15,110 @@
*/
//
-// Miscellaneous locals-related optimizations
+// Locals-related optimizations
//
+// This "sinks" set_locals, pushing them to the next get_local where possible
#include <wasm.h>
+#include <wasm-traversal.h>
#include <pass.h>
+#include <ast_utils.h>
namespace wasm {
-struct SimplifyLocals : public WalkerPass<WasmWalker<SimplifyLocals>> {
+struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> {
+ struct SinkableInfo {
+ Expression** item;
+ EffectAnalyzer effects;
+
+ SinkableInfo(Expression** item) : item(item) {
+ effects.walk(*item);
+ }
+ };
+
+ // locals in current linear execution trace, which we try to sink
+ std::map<Name, SinkableInfo> sinkables;
+
+ void noteNonLinear() {
+ sinkables.clear();
+ }
+
void visitBlock(Block *curr) {
- // look for pairs of setlocal-getlocal, which can be just a setlocal (since it returns a value)
- if (curr->list.size() == 0) return;
- for (size_t i = 0; i < curr->list.size() - 1; i++) {
- auto set = curr->list[i]->dyn_cast<SetLocal>();
- if (!set) continue;
- auto get = curr->list[i + 1]->dyn_cast<GetLocal>();
- if (!get) continue;
- if (set->name != get->name) continue;
- curr->list.erase(curr->list.begin() + i + 1);
- i -= 1;
+ // note locals, we can sink them from here TODO sink from elsewhere?
+ derecurseBlocks(curr, [&](Block* block) {
+ // curr was already checked by walk()
+ if (block != curr) checkPre(block);
+ }, [&](Block* block, Expression*& child) {
+ walk(child);
+ if (child->is<SetLocal>()) {
+ Name name = child->cast<SetLocal>()->name;
+ assert(sinkables.count(name) == 0);
+ sinkables.emplace(std::make_pair(name, SinkableInfo(&child)));
+ }
+ }, [&](Block* block) {
+ if (block != curr) checkPost(block);
+ });
+ }
+
+ void visitGetLocal(GetLocal *curr) {
+ auto found = sinkables.find(curr->name);
+ if (found != sinkables.end()) {
+ // sink it, and nop the origin TODO: clean up nops
+ replaceCurrent(*found->second.item);
+ // reuse the getlocal that is dying
+ *found->second.item = curr;
+ ExpressionManipulator::nop(curr);
+ sinkables.erase(found);
+ }
+ }
+
+ void visitSetLocal(SetLocal *curr) {
+ walk(curr->value);
+ // if we are a potentially-sinkable thing, forget it - this
+ // write overrides the last TODO: optimizable
+ // TODO: if no get_locals left, can remove the set as well (== expressionizer in emscripten optimizer)
+ auto found = sinkables.find(curr->name);
+ if (found != sinkables.end()) {
+ sinkables.erase(found);
+ }
+ }
+
+ void checkInvalidations(EffectAnalyzer& effects) {
+ // TODO: this is O(bad)
+ std::vector<Name> invalidated;
+ for (auto& sinkable : sinkables) {
+ if (effects.invalidates(sinkable.second.effects)) {
+ invalidated.push_back(sinkable.first);
+ }
+ }
+ for (auto name : invalidated) {
+ sinkables.erase(name);
}
}
+
+ void checkPre(Expression* curr) {
+ EffectAnalyzer effects;
+ if (effects.checkPre(curr)) {
+ checkInvalidations(effects);
+ }
+ }
+
+ void checkPost(Expression* curr) {
+ EffectAnalyzer effects;
+ if (effects.checkPost(curr)) {
+ checkInvalidations(effects);
+ }
+ }
+
+ void walk(Expression*& curr) override {
+ if (!curr) return;
+
+ checkPre(curr);
+
+ FastExecutionWalker::walk(curr);
+
+ checkPost(curr);
+ }
};
static RegisterPass<SimplifyLocals> registerPass("simplify-locals", "miscellaneous locals-related optimizations");
diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp
new file mode 100644
index 000000000..f9704ed8d
--- /dev/null
+++ b/src/passes/Vacuum.cpp
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Removes obviously unneeded code
+//
+
+#include <wasm.h>
+#include <pass.h>
+
+namespace wasm {
+
+struct Vacuum : public WalkerPass<WasmWalker<Vacuum>> {
+ void visitBlock(Block *curr) {
+ // compress out nops
+ int skip = 0;
+ auto& list = curr->list;
+ size_t size = list.size();
+ for (size_t z = 0; z < size; z++) {
+ if (list[z]->is<Nop>()) {
+ skip++;
+ } else if (skip > 0) {
+ list[z - skip] = list[z];
+ }
+ }
+ if (skip > 0) {
+ list.resize(size - skip);
+ }
+ }
+};
+
+static RegisterPass<Vacuum> registerPass("vacuum", "removes obviously unneeded code");
+
+} // namespace wasm
+
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index f73ec7c44..f7fb4b8f3 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -25,6 +25,7 @@
#include <ostream>
#include "wasm.h"
+#include "wasm-traversal.h"
#include "shared-constants.h"
#include "asm_v_wasm.h"
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
new file mode 100644
index 000000000..24ec4905c
--- /dev/null
+++ b/src/wasm-traversal.h
@@ -0,0 +1,462 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// WebAssembly AST visitor. Useful for anything that wants to do something
+// different for each AST node type, like printing, interpreting, etc.
+//
+// This class is specifically designed as a template to avoid virtual function
+// call overhead. To write a visitor, derive from this class as follows:
+//
+// struct MyVisitor : public WasmVisitor<MyVisitor> { .. }
+//
+
+#ifndef wasm_traversal_h
+#define wasm_traversal_h
+
+#include "wasm.h"
+
+namespace wasm {
+
+template<typename SubType, typename ReturnType>
+struct WasmVisitor {
+ virtual ~WasmVisitor() {}
+ // Expression visitors
+ ReturnType visitBlock(Block *curr) { abort(); }
+ ReturnType visitIf(If *curr) { abort(); }
+ ReturnType visitLoop(Loop *curr) { abort(); }
+ ReturnType visitBreak(Break *curr) { abort(); }
+ ReturnType visitSwitch(Switch *curr) { abort(); }
+ ReturnType visitCall(Call *curr) { abort(); }
+ ReturnType visitCallImport(CallImport *curr) { abort(); }
+ ReturnType visitCallIndirect(CallIndirect *curr) { abort(); }
+ ReturnType visitGetLocal(GetLocal *curr) { abort(); }
+ ReturnType visitSetLocal(SetLocal *curr) { abort(); }
+ ReturnType visitLoad(Load *curr) { abort(); }
+ ReturnType visitStore(Store *curr) { abort(); }
+ ReturnType visitConst(Const *curr) { abort(); }
+ ReturnType visitUnary(Unary *curr) { abort(); }
+ ReturnType visitBinary(Binary *curr) { abort(); }
+ ReturnType visitSelect(Select *curr) { abort(); }
+ ReturnType visitReturn(Return *curr) { abort(); }
+ ReturnType visitHost(Host *curr) { abort(); }
+ ReturnType visitNop(Nop *curr) { abort(); }
+ ReturnType visitUnreachable(Unreachable *curr) { abort(); }
+ // Module-level visitors
+ ReturnType visitFunctionType(FunctionType *curr) { abort(); }
+ ReturnType visitImport(Import *curr) { abort(); }
+ ReturnType visitExport(Export *curr) { abort(); }
+ ReturnType visitFunction(Function *curr) { abort(); }
+ ReturnType visitTable(Table *curr) { abort(); }
+ ReturnType visitMemory(Memory *curr) { abort(); }
+ ReturnType visitModule(Module *curr) { abort(); }
+
+#define DELEGATE(CLASS_TO_VISIT) \
+ return static_cast<SubType*>(this)-> \
+ visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))
+
+ ReturnType visit(Expression *curr) {
+ assert(curr);
+ switch (curr->_id) {
+ case Expression::Id::InvalidId: abort();
+ case Expression::Id::BlockId: DELEGATE(Block);
+ case Expression::Id::IfId: DELEGATE(If);
+ case Expression::Id::LoopId: DELEGATE(Loop);
+ case Expression::Id::BreakId: DELEGATE(Break);
+ case Expression::Id::SwitchId: DELEGATE(Switch);
+ case Expression::Id::CallId: DELEGATE(Call);
+ case Expression::Id::CallImportId: DELEGATE(CallImport);
+ case Expression::Id::CallIndirectId: DELEGATE(CallIndirect);
+ case Expression::Id::GetLocalId: DELEGATE(GetLocal);
+ case Expression::Id::SetLocalId: DELEGATE(SetLocal);
+ case Expression::Id::LoadId: DELEGATE(Load);
+ case Expression::Id::StoreId: DELEGATE(Store);
+ case Expression::Id::ConstId: DELEGATE(Const);
+ case Expression::Id::UnaryId: DELEGATE(Unary);
+ case Expression::Id::BinaryId: DELEGATE(Binary);
+ case Expression::Id::SelectId: DELEGATE(Select);
+ case Expression::Id::ReturnId: DELEGATE(Return);
+ case Expression::Id::HostId: DELEGATE(Host);
+ case Expression::Id::NopId: DELEGATE(Nop);
+ case Expression::Id::UnreachableId: DELEGATE(Unreachable);
+ default: WASM_UNREACHABLE();
+ }
+ }
+
+#undef DELEGATE
+
+ // Helper method to de-recurse blocks, which often nest in their first position very heavily
+ void derecurseBlocks(Block* block, std::function<void (Block*)> preBlock,
+ std::function<void (Block*, Expression*&)> onChild,
+ std::function<void (Block*)> postBlock) {
+ std::vector<Block*> stack;
+ stack.push_back(block);
+ while (block->list.size() > 0 && block->list[0]->is<Block>()) {
+ block = block->list[0]->cast<Block>();
+ stack.push_back(block);
+ }
+ for (size_t i = 0; i < stack.size(); i++) {
+ preBlock(stack[i]);
+ }
+ for (int i = int(stack.size()) - 1; i >= 0; i--) {
+ auto* block = stack[i];
+ auto& list = block->list;
+ for (size_t j = 0; j < list.size(); j++) {
+ if (i < int(stack.size()) - 1 && j == 0) {
+ // nested block, we already called its pre
+ } else {
+ onChild(block, list[j]);
+ }
+ }
+ postBlock(block);
+ }
+ }
+};
+
+//
+// Base class for all WasmWalkers
+//
+template<typename SubType, typename ReturnType = void>
+struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> {
+ virtual void walk(Expression*& curr) { abort(); }
+
+ void startWalk(Function *func) {
+ walk(func->body);
+ }
+
+ void startWalk(Module *module) {
+ // Dispatch statically through the SubType.
+ SubType* self = static_cast<SubType*>(this);
+ for (auto curr : module->functionTypes) {
+ self->visitFunctionType(curr);
+ }
+ for (auto curr : module->imports) {
+ self->visitImport(curr);
+ }
+ for (auto curr : module->exports) {
+ self->visitExport(curr);
+ }
+ for (auto curr : module->functions) {
+ startWalk(curr);
+ self->visitFunction(curr);
+ }
+ self->visitTable(&module->table);
+ self->visitMemory(&module->memory);
+ self->visitModule(module);
+ }
+};
+
+template<typename ParentType>
+struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> {
+ ParentType& parent;
+
+ ChildWalker(ParentType& parent) : parent(parent) {}
+
+ void visitBlock(Block *curr) {
+ ExpressionList& list = curr->list;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitIf(If *curr) {
+ parent.walk(curr->condition);
+ parent.walk(curr->ifTrue);
+ parent.walk(curr->ifFalse);
+ }
+ void visitLoop(Loop *curr) {
+ parent.walk(curr->body);
+ }
+ void visitBreak(Break *curr) {
+ parent.walk(curr->condition);
+ parent.walk(curr->value);
+ }
+ void visitSwitch(Switch *curr) {
+ parent.walk(curr->condition);
+ if (curr->value) parent.walk(curr->value);
+ }
+ void visitCall(Call *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitCallImport(CallImport *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitCallIndirect(CallIndirect *curr) {
+ parent.walk(curr->target);
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitGetLocal(GetLocal *curr) {}
+ void visitSetLocal(SetLocal *curr) {
+ parent.walk(curr->value);
+ }
+ void visitLoad(Load *curr) {
+ parent.walk(curr->ptr);
+ }
+ void visitStore(Store *curr) {
+ parent.walk(curr->ptr);
+ parent.walk(curr->value);
+ }
+ void visitConst(Const *curr) {}
+ void visitUnary(Unary *curr) {
+ parent.walk(curr->value);
+ }
+ void visitBinary(Binary *curr) {
+ parent.walk(curr->left);
+ parent.walk(curr->right);
+ }
+ void visitSelect(Select *curr) {
+ parent.walk(curr->ifTrue);
+ parent.walk(curr->ifFalse);
+ parent.walk(curr->condition);
+ }
+ void visitReturn(Return *curr) {
+ parent.walk(curr->value);
+ }
+ void visitHost(Host *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitNop(Nop *curr) {}
+ void visitUnreachable(Unreachable *curr) {}
+};
+
+// Walker that allows replacements
+template<typename SubType, typename ReturnType = void>
+struct WasmReplacerWalker : public WasmWalkerBase<SubType, ReturnType> {
+ Expression* replace = nullptr;
+
+ // methods can call this to replace the current node
+ void replaceCurrent(Expression *expression) {
+ replace = expression;
+ }
+
+ void walk(Expression*& curr) override {
+ if (!curr) return;
+
+ this->visit(curr);
+
+ if (replace) {
+ curr = replace;
+ replace = nullptr;
+ }
+ }
+};
+
+//
+// Simple WebAssembly children-first walking (i.e., post-order, if you look
+// at the children as subtrees of the current node), with the ability to replace
+// the current expression node. Useful for writing optimization passes.
+//
+
+template<typename SubType, typename ReturnType = void>
+struct WasmWalker : public WasmReplacerWalker<SubType, ReturnType> {
+ // By default, do nothing
+ ReturnType visitBlock(Block *curr) {}
+ ReturnType visitIf(If *curr) {}
+ ReturnType visitLoop(Loop *curr) {}
+ ReturnType visitBreak(Break *curr) {}
+ ReturnType visitSwitch(Switch *curr) {}
+ ReturnType visitCall(Call *curr) {}
+ ReturnType visitCallImport(CallImport *curr) {}
+ ReturnType visitCallIndirect(CallIndirect *curr) {}
+ ReturnType visitGetLocal(GetLocal *curr) {}
+ ReturnType visitSetLocal(SetLocal *curr) {}
+ ReturnType visitLoad(Load *curr) {}
+ ReturnType visitStore(Store *curr) {}
+ ReturnType visitConst(Const *curr) {}
+ ReturnType visitUnary(Unary *curr) {}
+ ReturnType visitBinary(Binary *curr) {}
+ ReturnType visitSelect(Select *curr) {}
+ ReturnType visitReturn(Return *curr) {}
+ ReturnType visitHost(Host *curr) {}
+ ReturnType visitNop(Nop *curr) {}
+ ReturnType visitUnreachable(Unreachable *curr) {}
+
+ ReturnType visitFunctionType(FunctionType *curr) {}
+ ReturnType visitImport(Import *curr) {}
+ ReturnType visitExport(Export *curr) {}
+ ReturnType visitFunction(Function *curr) {}
+ ReturnType visitTable(Table *curr) {}
+ ReturnType visitMemory(Memory *curr) {}
+ ReturnType visitModule(Module *curr) {}
+
+ // children-first
+ void walk(Expression*& curr) override {
+ if (!curr) return;
+
+ // special-case Block, because Block nesting (in their first element) can be incredibly deep
+ if (curr->is<Block>()) {
+ auto* block = curr->dyn_cast<Block>();
+ std::vector<Block*> stack;
+ stack.push_back(block);
+ while (block->list.size() > 0 && block->list[0]->is<Block>()) {
+ block = block->list[0]->cast<Block>();
+ stack.push_back(block);
+ }
+ // walk all the children
+ for (int i = int(stack.size()) - 1; i >= 0; i--) {
+ auto* block = stack[i];
+ auto& children = block->list;
+ for (size_t j = 0; j < children.size(); j++) {
+ if (i < int(stack.size()) - 1 && j == 0) {
+ // this is one of the stacked blocks, no need to walk its children, we are doing that ourselves
+ WasmReplacerWalker<SubType, ReturnType>::walk(children[0]);
+ } else {
+ this->walk(children[j]);
+ }
+ }
+ }
+ // we walked all the children, and can rejoin later below to visit this node itself
+ } else {
+ // generic child-walking
+ ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr);
+ }
+
+ WasmReplacerWalker<SubType, ReturnType>::walk(curr);
+ }
+};
+
+// Traversal in the order of execution. This is quick and simple, but
+// does not provide the same comprehensive information that a full
+// conversion to basic blocks would. What it does give is a quick
+// way to view straightline execution traces, i.e., that have no
+// branching. This can let optimizations get most of what they
+// want without the cost of creating another AST.
+//
+// When execution is no longer linear, this notifies via a call
+// to noteNonLinear().
+
+template<typename SubType>
+struct FastExecutionWalker : public WasmReplacerWalker<SubType> {
+ FastExecutionWalker() {}
+
+ void noteNonLinear() {}
+
+#define DELEGATE_noteNonLinear() \
+ static_cast<SubType*>(this)->noteNonLinear()
+#define DELEGATE_walk(ARG) \
+ static_cast<SubType*>(this)->walk(ARG)
+
+ void visitBlock(Block *curr) {
+ ExpressionList& list = curr->list;
+ for (size_t z = 0; z < list.size(); z++) {
+ DELEGATE_walk(list[z]);
+ }
+ }
+ void visitIf(If *curr) {
+ DELEGATE_walk(curr->condition);
+ DELEGATE_noteNonLinear();
+ DELEGATE_walk(curr->ifTrue);
+ DELEGATE_noteNonLinear();
+ DELEGATE_walk(curr->ifFalse);
+ DELEGATE_noteNonLinear();
+ }
+ void visitLoop(Loop *curr) {
+ DELEGATE_noteNonLinear();
+ DELEGATE_walk(curr->body);
+ }
+ void visitBreak(Break *curr) {
+ if (curr->value) DELEGATE_walk(curr->value);
+ if (curr->condition) DELEGATE_walk(curr->condition);
+ DELEGATE_noteNonLinear();
+ }
+ void visitSwitch(Switch *curr) {
+ DELEGATE_walk(curr->condition);
+ if (curr->value) DELEGATE_walk(curr->value);
+ DELEGATE_noteNonLinear();
+ }
+ void visitCall(Call *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ DELEGATE_walk(list[z]);
+ }
+ }
+ void visitCallImport(CallImport *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ DELEGATE_walk(list[z]);
+ }
+ }
+ void visitCallIndirect(CallIndirect *curr) {
+ DELEGATE_walk(curr->target);
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ DELEGATE_walk(list[z]);
+ }
+ }
+ void visitGetLocal(GetLocal *curr) {}
+ void visitSetLocal(SetLocal *curr) {
+ DELEGATE_walk(curr->value);
+ }
+ void visitLoad(Load *curr) {
+ DELEGATE_walk(curr->ptr);
+ }
+ void visitStore(Store *curr) {
+ DELEGATE_walk(curr->ptr);
+ DELEGATE_walk(curr->value);
+ }
+ void visitConst(Const *curr) {}
+ void visitUnary(Unary *curr) {
+ DELEGATE_walk(curr->value);
+ }
+ void visitBinary(Binary *curr) {
+ DELEGATE_walk(curr->left);
+ DELEGATE_walk(curr->right);
+ }
+ void visitSelect(Select *curr) {
+ DELEGATE_walk(curr->ifTrue);
+ DELEGATE_walk(curr->ifFalse);
+ DELEGATE_walk(curr->condition);
+ }
+ void visitReturn(Return *curr) {
+ DELEGATE_walk(curr->value);
+ DELEGATE_noteNonLinear();
+ }
+ void visitHost(Host *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ DELEGATE_walk(list[z]);
+ }
+ }
+ void visitNop(Nop *curr) {}
+ void visitUnreachable(Unreachable *curr) {}
+
+ void visitFunctionType(FunctionType *curr) {}
+ void visitImport(Import *curr) {}
+ void visitExport(Export *curr) {}
+ void visitFunction(Function *curr) {}
+ void visitTable(Table *curr) {}
+ void visitMemory(Memory *curr) {}
+ void visitModule(Module *curr) {}
+
+#undef DELEGATE_noteNonLinear
+#undef DELEGATE_walk
+
+};
+
+} // namespace wasm
+
+#endif // wasm_traversal_h
diff --git a/src/wasm.h b/src/wasm.h
index a07ab3079..f985e9b59 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1212,301 +1212,6 @@ class AllocatingModule : public Module {
MixedArena allocator;
};
-//
-// WebAssembly AST visitor. Useful for anything that wants to do something
-// different for each AST node type, like printing, interpreting, etc.
-//
-// This class is specifically designed as a template to avoid virtual function
-// call overhead. To write a visitor, derive from this class as follows:
-//
-// struct MyVisitor : public WasmVisitor<MyVisitor> { .. }
-//
-
-template<typename SubType, typename ReturnType>
-struct WasmVisitor {
- virtual ~WasmVisitor() {}
- // should be pure virtual, but https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51048
- // Expression visitors
- ReturnType visitBlock(Block *curr) { abort(); }
- ReturnType visitIf(If *curr) { abort(); }
- ReturnType visitLoop(Loop *curr) { abort(); }
- ReturnType visitBreak(Break *curr) { abort(); }
- ReturnType visitSwitch(Switch *curr) { abort(); }
- ReturnType visitCall(Call *curr) { abort(); }
- ReturnType visitCallImport(CallImport *curr) { abort(); }
- ReturnType visitCallIndirect(CallIndirect *curr) { abort(); }
- ReturnType visitGetLocal(GetLocal *curr) { abort(); }
- ReturnType visitSetLocal(SetLocal *curr) { abort(); }
- ReturnType visitLoad(Load *curr) { abort(); }
- ReturnType visitStore(Store *curr) { abort(); }
- ReturnType visitConst(Const *curr) { abort(); }
- ReturnType visitUnary(Unary *curr) { abort(); }
- ReturnType visitBinary(Binary *curr) { abort(); }
- ReturnType visitSelect(Select *curr) { abort(); }
- ReturnType visitReturn(Return *curr) { abort(); }
- ReturnType visitHost(Host *curr) { abort(); }
- ReturnType visitNop(Nop *curr) { abort(); }
- ReturnType visitUnreachable(Unreachable *curr) { abort(); }
- // Module-level visitors
- ReturnType visitFunctionType(FunctionType *curr) { abort(); }
- ReturnType visitImport(Import *curr) { abort(); }
- ReturnType visitExport(Export *curr) { abort(); }
- ReturnType visitFunction(Function *curr) { abort(); }
- ReturnType visitTable(Table *curr) { abort(); }
- ReturnType visitMemory(Memory *curr) { abort(); }
- ReturnType visitModule(Module *curr) { abort(); }
-
-#define DELEGATE(CLASS_TO_VISIT) \
- return static_cast<SubType*>(this)-> \
- visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))
-
- ReturnType visit(Expression *curr) {
- assert(curr);
- switch (curr->_id) {
- case Expression::Id::InvalidId: abort();
- case Expression::Id::BlockId: DELEGATE(Block);
- case Expression::Id::IfId: DELEGATE(If);
- case Expression::Id::LoopId: DELEGATE(Loop);
- case Expression::Id::BreakId: DELEGATE(Break);
- case Expression::Id::SwitchId: DELEGATE(Switch);
- case Expression::Id::CallId: DELEGATE(Call);
- case Expression::Id::CallImportId: DELEGATE(CallImport);
- case Expression::Id::CallIndirectId: DELEGATE(CallIndirect);
- case Expression::Id::GetLocalId: DELEGATE(GetLocal);
- case Expression::Id::SetLocalId: DELEGATE(SetLocal);
- case Expression::Id::LoadId: DELEGATE(Load);
- case Expression::Id::StoreId: DELEGATE(Store);
- case Expression::Id::ConstId: DELEGATE(Const);
- case Expression::Id::UnaryId: DELEGATE(Unary);
- case Expression::Id::BinaryId: DELEGATE(Binary);
- case Expression::Id::SelectId: DELEGATE(Select);
- case Expression::Id::ReturnId: DELEGATE(Return);
- case Expression::Id::HostId: DELEGATE(Host);
- case Expression::Id::NopId: DELEGATE(Nop);
- case Expression::Id::UnreachableId: DELEGATE(Unreachable);
- default: WASM_UNREACHABLE();
- }
- }
-};
-
-//
-// Base class for all WasmWalkers
-//
-template<typename SubType, typename ReturnType = void>
-struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> {
- virtual void walk(Expression*& curr) { abort(); }
- virtual void startWalk(Function *func) { abort(); }
- virtual void startWalk(Module *module) { abort(); }
-};
-
-template<typename ParentType>
-struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> {
- ParentType& parent;
-
- ChildWalker(ParentType& parent) : parent(parent) {}
-
- void visitBlock(Block *curr) {
- ExpressionList& list = curr->list;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitIf(If *curr) {
- parent.walk(curr->condition);
- parent.walk(curr->ifTrue);
- parent.walk(curr->ifFalse);
- }
- void visitLoop(Loop *curr) {
- parent.walk(curr->body);
- }
- void visitBreak(Break *curr) {
- parent.walk(curr->condition);
- parent.walk(curr->value);
- }
- void visitSwitch(Switch *curr) {
- parent.walk(curr->condition);
- if (curr->value) parent.walk(curr->value);
- }
- void visitCall(Call *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitCallImport(CallImport *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitCallIndirect(CallIndirect *curr) {
- parent.walk(curr->target);
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitGetLocal(GetLocal *curr) {}
- void visitSetLocal(SetLocal *curr) {
- parent.walk(curr->value);
- }
- void visitLoad(Load *curr) {
- parent.walk(curr->ptr);
- }
- void visitStore(Store *curr) {
- parent.walk(curr->ptr);
- parent.walk(curr->value);
- }
- void visitConst(Const *curr) {}
- void visitUnary(Unary *curr) {
- parent.walk(curr->value);
- }
- void visitBinary(Binary *curr) {
- parent.walk(curr->left);
- parent.walk(curr->right);
- }
- void visitSelect(Select *curr) {
- parent.walk(curr->ifTrue);
- parent.walk(curr->ifFalse);
- parent.walk(curr->condition);
- }
- void visitReturn(Return *curr) {
- parent.walk(curr->value);
- }
- void visitHost(Host *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitNop(Nop *curr) {}
- void visitUnreachable(Unreachable *curr) {}
-};
-
-//
-// Simple WebAssembly children-first walking (i.e., post-order, if you look
-// at the children as subtrees of the current node), with the ability to replace
-// the current expression node. Useful for writing optimization passes.
-//
-
-template<typename SubType, typename ReturnType = void>
-struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> {
- Expression* replace;
-
- WasmWalker() : replace(nullptr) {}
-
- // the visit* methods can call this to replace the current node
- void replaceCurrent(Expression *expression) {
- replace = expression;
- }
-
- // By default, do nothing
- ReturnType visitBlock(Block *curr) {}
- ReturnType visitIf(If *curr) {}
- ReturnType visitLoop(Loop *curr) {}
- ReturnType visitBreak(Break *curr) {}
- ReturnType visitSwitch(Switch *curr) {}
- ReturnType visitCall(Call *curr) {}
- ReturnType visitCallImport(CallImport *curr) {}
- ReturnType visitCallIndirect(CallIndirect *curr) {}
- ReturnType visitGetLocal(GetLocal *curr) {}
- ReturnType visitSetLocal(SetLocal *curr) {}
- ReturnType visitLoad(Load *curr) {}
- ReturnType visitStore(Store *curr) {}
- ReturnType visitConst(Const *curr) {}
- ReturnType visitUnary(Unary *curr) {}
- ReturnType visitBinary(Binary *curr) {}
- ReturnType visitSelect(Select *curr) {}
- ReturnType visitReturn(Return *curr) {}
- ReturnType visitHost(Host *curr) {}
- ReturnType visitNop(Nop *curr) {}
- ReturnType visitUnreachable(Unreachable *curr) {}
-
- ReturnType visitFunctionType(FunctionType *curr) {}
- ReturnType visitImport(Import *curr) {}
- ReturnType visitExport(Export *curr) {}
- ReturnType visitFunction(Function *curr) {}
- ReturnType visitTable(Table *curr) {}
- ReturnType visitMemory(Memory *curr) {}
- ReturnType visitModule(Module *curr) {}
-
- // children-first
- void walk(Expression*& curr) override {
- if (!curr) return;
-
- // special-case Block, because Block nesting (in their first element) can be incredibly deep
- if (curr->is<Block>()) {
- auto* block = curr->dyn_cast<Block>();
- std::vector<Block*> stack;
- stack.push_back(block);
- while (block->list.size() > 0 && block->list[0]->is<Block>()) {
- block = block->list[0]->cast<Block>();
- stack.push_back(block);
- }
- // walk all the children
- for (int i = int(stack.size()) - 1; i >= 0; i--) {
- auto* block = stack[i];
- auto& children = block->list;
- for (size_t j = 0; j < children.size(); j++) {
- if (i < int(stack.size()) - 1 && j == 0) {
- // this is one of the stacked blocks, no need to walk its children, we are doing that ourselves
- this->visit(children[0]);
- if (replace) {
- children[0] = replace;
- replace = nullptr;
- }
- } else {
- this->walk(children[j]);
- }
- }
- }
- // we walked all the children, and can rejoin later below to visit this node itself
- } else {
- // generic child-walking
- ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr);
- }
-
- this->visit(curr);
-
- if (replace) {
- curr = replace;
- replace = nullptr;
- }
- }
-
- void startWalk(Function *func) override {
- walk(func->body);
- }
-
- void startWalk(Module *module) override {
- // Dispatch statically through the SubType.
- SubType* self = static_cast<SubType*>(this);
- for (auto curr : module->functionTypes) {
- self->visitFunctionType(curr);
- assert(!replace);
- }
- for (auto curr : module->imports) {
- self->visitImport(curr);
- assert(!replace);
- }
- for (auto curr : module->exports) {
- self->visitExport(curr);
- assert(!replace);
- }
- for (auto curr : module->functions) {
- startWalk(curr);
- self->visitFunction(curr);
- assert(!replace);
- }
- self->visitTable(&module->table);
- assert(!replace);
- self->visitMemory(&module->memory);
- assert(!replace);
- self->visitModule(module);
- assert(!replace);
- }
-};
-
} // namespace wasm
#endif // wasm_wasm_h