summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_utils.h4
-rw-r--r--src/pass.h2
-rw-r--r--src/passes/LowerIfElse.cpp2
-rw-r--r--src/passes/MergeBlocks.cpp2
-rw-r--r--src/passes/Metrics.cpp8
-rw-r--r--src/passes/OptimizeInstructions.cpp2
-rw-r--r--src/passes/PostEmscripten.cpp2
-rw-r--r--src/passes/Print.cpp2
-rw-r--r--src/passes/RemoveImports.cpp2
-rw-r--r--src/passes/RemoveUnusedBrs.cpp2
-rw-r--r--src/passes/RemoveUnusedNames.cpp2
-rw-r--r--src/passes/ReorderLocals.cpp2
-rw-r--r--src/passes/SimplifyLocals.cpp72
-rw-r--r--src/passes/Vacuum.cpp2
-rw-r--r--src/s2wasm.h2
-rw-r--r--src/wasm-binary.h2
-rw-r--r--src/wasm-interpreter.h2
-rw-r--r--src/wasm-traversal.h634
-rw-r--r--src/wasm-validator.h4
-rw-r--r--src/wasm2asm.h12
20 files changed, 378 insertions, 384 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 5ab427178..561e12983 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -22,7 +22,7 @@
namespace wasm {
-struct BreakSeeker : public WasmWalker<BreakSeeker> {
+struct BreakSeeker : public PostWalker<BreakSeeker> {
Name target; // look for this one
size_t found;
@@ -42,7 +42,7 @@ struct BreakSeeker : public WasmWalker<BreakSeeker> {
// Look for side effects, including control flow
// TODO: look at individual locals
-struct EffectAnalyzer : public WasmWalker<EffectAnalyzer> {
+struct EffectAnalyzer : public PostWalker<EffectAnalyzer> {
bool branches = false;
bool calls = false;
bool readsLocal = false;
diff --git a/src/pass.h b/src/pass.h
index e3716545d..f1f99e3aa 100644
--- a/src/pass.h
+++ b/src/pass.h
@@ -133,7 +133,7 @@ public:
// e.g. through PassRunner::getLast
// Handles names in a module, in particular adding names without duplicates
-class NameManager : public WalkerPass<WasmWalker<NameManager>> {
+class NameManager : public WalkerPass<PostWalker<NameManager>> {
public:
Name getUnique(std::string prefix);
// TODO: getUniqueInFunction
diff --git a/src/passes/LowerIfElse.cpp b/src/passes/LowerIfElse.cpp
index 922a294d2..48c1c0f9c 100644
--- a/src/passes/LowerIfElse.cpp
+++ b/src/passes/LowerIfElse.cpp
@@ -32,7 +32,7 @@
namespace wasm {
-struct LowerIfElse : public WalkerPass<WasmWalker<LowerIfElse, void>> {
+struct LowerIfElse : public WalkerPass<PostWalker<LowerIfElse>> {
MixedArena* allocator;
std::unique_ptr<NameManager> namer;
diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp
index c85f3d600..ab210123c 100644
--- a/src/passes/MergeBlocks.cpp
+++ b/src/passes/MergeBlocks.cpp
@@ -23,7 +23,7 @@
namespace wasm {
-struct MergeBlocks : public WalkerPass<WasmWalker<MergeBlocks>> {
+struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
void visitBlock(Block *curr) {
bool more = true;
while (more) {
diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp
index 9673b8873..19df981b8 100644
--- a/src/passes/Metrics.cpp
+++ b/src/passes/Metrics.cpp
@@ -24,16 +24,16 @@ namespace wasm {
using namespace std;
// Prints metrics between optimization passes.
-struct Metrics : public WalkerPass<WasmWalker<Metrics>> {
+struct Metrics : public WalkerPass<PostWalker<Metrics>> {
static Metrics *lastMetricsPass;
map<const char *, int> counts;
- void walk(Expression *&curr) override {
- WalkerPass::walk(curr);
- if (!curr) return;
+
+ void visitExpression(Expression* curr) {
auto name = getExpressionName(curr);
counts[name]++;
}
+
void finalize(PassRunner *runner, Module *module) override {
ostream &o = cout;
o << "Counts"
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index de342b43f..3d89d7af4 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -25,7 +25,7 @@
namespace wasm {
-struct OptimizeInstructions : public WalkerPass<WasmWalker<OptimizeInstructions>> {
+struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions>> {
void visitIf(If* curr) {
// flip branches to get rid of an i32.eqz
if (curr->ifFalse) {
diff --git a/src/passes/PostEmscripten.cpp b/src/passes/PostEmscripten.cpp
index 14476b727..99b172d65 100644
--- a/src/passes/PostEmscripten.cpp
+++ b/src/passes/PostEmscripten.cpp
@@ -24,7 +24,7 @@
namespace wasm {
-struct PostEmscripten : public WalkerPass<WasmWalker<PostEmscripten>> {
+struct PostEmscripten : public WalkerPass<PostWalker<PostEmscripten>> {
// When we have a Load from a local value (typically a GetLocal) plus a constant offset,
// we may be able to fold it in.
// The semantics of the Add are to wrap, while wasm offset semantics purposefully do
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 8e3d503b8..7b68d956e 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -24,7 +24,7 @@
namespace wasm {
-struct PrintSExpression : public WasmVisitor<PrintSExpression, void> {
+struct PrintSExpression : public Visitor<PrintSExpression> {
std::ostream& o;
unsigned indent = 0;
diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp
index 6463df10c..4f42c526f 100644
--- a/src/passes/RemoveImports.cpp
+++ b/src/passes/RemoveImports.cpp
@@ -27,7 +27,7 @@
namespace wasm {
-struct RemoveImports : public WalkerPass<WasmWalker<RemoveImports>> {
+struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> {
MixedArena* allocator;
Module* module;
diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp
index 5302c368a..998142724 100644
--- a/src/passes/RemoveUnusedBrs.cpp
+++ b/src/passes/RemoveUnusedBrs.cpp
@@ -23,7 +23,7 @@
namespace wasm {
-struct RemoveUnusedBrs : public WalkerPass<WasmWalker<RemoveUnusedBrs>> {
+struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
// preparation: try to unify branches, as the fewer there are, the higher a chance we can remove them
// specifically for if-else, turn an if-else with branches to the same target at the end of each
// child, and with a value, to a branch to that target containing the if-else
diff --git a/src/passes/RemoveUnusedNames.cpp b/src/passes/RemoveUnusedNames.cpp
index a8beaa15e..71569eefb 100644
--- a/src/passes/RemoveUnusedNames.cpp
+++ b/src/passes/RemoveUnusedNames.cpp
@@ -23,7 +23,7 @@
namespace wasm {
-struct RemoveUnusedNames : public WalkerPass<WasmWalker<RemoveUnusedNames>> {
+struct RemoveUnusedNames : public WalkerPass<PostWalker<RemoveUnusedNames>> {
// We maintain a list of branches that we saw in children, then when we reach
// a parent block, we know if it was branched to
std::set<Name> branchesSeen;
diff --git a/src/passes/ReorderLocals.cpp b/src/passes/ReorderLocals.cpp
index 1b5e998c4..e378408b6 100644
--- a/src/passes/ReorderLocals.cpp
+++ b/src/passes/ReorderLocals.cpp
@@ -26,7 +26,7 @@
namespace wasm {
-struct ReorderLocals : public WalkerPass<WasmWalker<ReorderLocals, void>> {
+struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals>> {
std::map<Name, uint32_t> counts;
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index 0d59b8759..53e77eb22 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -26,7 +26,7 @@
namespace wasm {
-struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> {
+struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals>> {
struct SinkableInfo {
Expression** item;
EffectAnalyzer effects;
@@ -43,23 +43,6 @@ struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> {
sinkables.clear();
}
- void visitBlock(Block *curr) {
- // note locals, we can sink them from here TODO sink from elsewhere?
- derecurseBlocks(curr, [&](Block* block) {
- // curr was already checked by walk()
- if (block != curr) checkPre(block);
- }, [&](Block* block, Expression*& child) {
- walk(child);
- if (child->is<SetLocal>()) {
- Name name = child->cast<SetLocal>()->name;
- assert(sinkables.count(name) == 0);
- sinkables.emplace(std::make_pair(name, SinkableInfo(&child)));
- }
- }, [&](Block* block) {
- if (block != curr) checkPost(block);
- });
- }
-
void visitGetLocal(GetLocal *curr) {
auto found = sinkables.find(curr->name);
if (found != sinkables.end()) {
@@ -73,7 +56,6 @@ struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> {
}
void visitSetLocal(SetLocal *curr) {
- walk(curr->value);
// if we are a potentially-sinkable thing, forget it - this
// write overrides the last TODO: optimizable
// TODO: if no get_locals left, can remove the set as well (== expressionizer in emscripten optimizer)
@@ -96,28 +78,56 @@ struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> {
}
}
- void checkPre(Expression* curr) {
+ static void visitPre(SimplifyLocals* self, Expression** currp) {
EffectAnalyzer effects;
- if (effects.checkPre(curr)) {
- checkInvalidations(effects);
+ if (effects.checkPre(*currp)) {
+ self->checkInvalidations(effects);
}
}
- void checkPost(Expression* curr) {
+ static void visitPost(SimplifyLocals* self, Expression** currp) {
EffectAnalyzer effects;
- if (effects.checkPost(curr)) {
- checkInvalidations(effects);
+ if (effects.checkPost(*currp)) {
+ self->checkInvalidations(effects);
}
}
- void walk(Expression*& curr) override {
- if (!curr) return;
-
- checkPre(curr);
+ static void tryMarkSinkable(SimplifyLocals* self, Expression** currp) {
+ auto* curr = (*currp)->dyn_cast<SetLocal>();
+ if (curr) {
+ Name name = curr->name;
+ assert(self->sinkables.count(name) == 0);
+ self->sinkables.emplace(std::make_pair(name, SinkableInfo(currp)));
+ }
+ }
- FastExecutionWalker::walk(curr);
+ // override scan to add a pre and a post check task to all nodes
+ static void scan(SimplifyLocals* self, Expression** currp) {
+ self->pushTask(visitPost, currp);
+
+ auto* curr = *currp;
+
+
+ if (curr->is<Block>()) {
+ // special-case blocks, by marking their children as locals.
+ // TODO sink from elsewhere? (need to make sure value is not used)
+ self->pushTask(SimplifyLocals::doNoteNonLinear, currp);
+ auto& list = curr->cast<Block>()->list;
+ int size = list.size();
+ // we can't sink the last element, as it might be a return value;
+ // and anyhow, control flow is nonlinear at the end of the block so
+ // it would be invalidated.
+ for (int i = size - 1; i >= 0; i--) {
+ if (i < size - 1) {
+ self->pushTask(tryMarkSinkable, &list[i]);
+ }
+ self->pushTask(scan, &list[i]);
+ }
+ } else {
+ WalkerPass<LinearExecutionWalker<SimplifyLocals>>::scan(self, currp);
+ }
- checkPost(curr);
+ self->pushTask(visitPre, currp);
}
};
diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp
index f9704ed8d..ef83958c4 100644
--- a/src/passes/Vacuum.cpp
+++ b/src/passes/Vacuum.cpp
@@ -23,7 +23,7 @@
namespace wasm {
-struct Vacuum : public WalkerPass<WasmWalker<Vacuum>> {
+struct Vacuum : public WalkerPass<PostWalker<Vacuum>> {
void visitBlock(Block *curr) {
// compress out nops
int skip = 0;
diff --git a/src/s2wasm.h b/src/s2wasm.h
index 614afdc80..9d166624c 100644
--- a/src/s2wasm.h
+++ b/src/s2wasm.h
@@ -1368,7 +1368,7 @@ public:
o << ";; METADATA: { ";
// find asmConst calls, and emit their metadata
- struct AsmConstWalker : public WasmWalker<AsmConstWalker> {
+ struct AsmConstWalker : public PostWalker<AsmConstWalker> {
S2WasmBuilder* parent;
std::map<std::string, std::set<std::string>> sigsForCode;
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index f7fb4b8f3..d4e1bd2c1 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -431,7 +431,7 @@ int8_t binaryWasmType(WasmType type) {
}
}
-class WasmBinaryWriter : public WasmVisitor<WasmBinaryWriter, void> {
+class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
Module* wasm;
BufferWithRandomAccess& o;
bool debug;
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index ed6411258..70564e4a4 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -207,7 +207,7 @@ private:
#endif
// Execute a statement
- class ExpressionRunner : public WasmVisitor<ExpressionRunner, Flow> {
+ class ExpressionRunner : public Visitor<ExpressionRunner, Flow> {
ModuleInstance& instance;
FunctionScope& scope;
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index 24ec4905c..d1fff2753 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -31,47 +31,47 @@
namespace wasm {
-template<typename SubType, typename ReturnType>
-struct WasmVisitor {
- virtual ~WasmVisitor() {}
+template<typename SubType, typename ReturnType = void>
+struct Visitor {
+ virtual ~Visitor() {}
// Expression visitors
- ReturnType visitBlock(Block *curr) { abort(); }
- ReturnType visitIf(If *curr) { abort(); }
- ReturnType visitLoop(Loop *curr) { abort(); }
- ReturnType visitBreak(Break *curr) { abort(); }
- ReturnType visitSwitch(Switch *curr) { abort(); }
- ReturnType visitCall(Call *curr) { abort(); }
- ReturnType visitCallImport(CallImport *curr) { abort(); }
- ReturnType visitCallIndirect(CallIndirect *curr) { abort(); }
- ReturnType visitGetLocal(GetLocal *curr) { abort(); }
- ReturnType visitSetLocal(SetLocal *curr) { abort(); }
- ReturnType visitLoad(Load *curr) { abort(); }
- ReturnType visitStore(Store *curr) { abort(); }
- ReturnType visitConst(Const *curr) { abort(); }
- ReturnType visitUnary(Unary *curr) { abort(); }
- ReturnType visitBinary(Binary *curr) { abort(); }
- ReturnType visitSelect(Select *curr) { abort(); }
- ReturnType visitReturn(Return *curr) { abort(); }
- ReturnType visitHost(Host *curr) { abort(); }
- ReturnType visitNop(Nop *curr) { abort(); }
- ReturnType visitUnreachable(Unreachable *curr) { abort(); }
+ ReturnType visitBlock(Block *curr) {}
+ ReturnType visitIf(If *curr) {}
+ ReturnType visitLoop(Loop *curr) {}
+ ReturnType visitBreak(Break *curr) {}
+ ReturnType visitSwitch(Switch *curr) {}
+ ReturnType visitCall(Call *curr) {}
+ ReturnType visitCallImport(CallImport *curr) {}
+ ReturnType visitCallIndirect(CallIndirect *curr) {}
+ ReturnType visitGetLocal(GetLocal *curr) {}
+ ReturnType visitSetLocal(SetLocal *curr) {}
+ ReturnType visitLoad(Load *curr) {}
+ ReturnType visitStore(Store *curr) {}
+ ReturnType visitConst(Const *curr) {}
+ ReturnType visitUnary(Unary *curr) {}
+ ReturnType visitBinary(Binary *curr) {}
+ ReturnType visitSelect(Select *curr) {}
+ ReturnType visitReturn(Return *curr) {}
+ ReturnType visitHost(Host *curr) {}
+ ReturnType visitNop(Nop *curr) {}
+ ReturnType visitUnreachable(Unreachable *curr) {}
// Module-level visitors
- ReturnType visitFunctionType(FunctionType *curr) { abort(); }
- ReturnType visitImport(Import *curr) { abort(); }
- ReturnType visitExport(Export *curr) { abort(); }
- ReturnType visitFunction(Function *curr) { abort(); }
- ReturnType visitTable(Table *curr) { abort(); }
- ReturnType visitMemory(Memory *curr) { abort(); }
- ReturnType visitModule(Module *curr) { abort(); }
-
-#define DELEGATE(CLASS_TO_VISIT) \
- return static_cast<SubType*>(this)-> \
- visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))
+ ReturnType visitFunctionType(FunctionType *curr) {}
+ ReturnType visitImport(Import *curr) {}
+ ReturnType visitExport(Export *curr) {}
+ ReturnType visitFunction(Function *curr) {}
+ ReturnType visitTable(Table *curr) {}
+ ReturnType visitMemory(Memory *curr) {}
+ ReturnType visitModule(Module *curr) {}
ReturnType visit(Expression *curr) {
assert(curr);
+
+ #define DELEGATE(CLASS_TO_VISIT) \
+ return static_cast<SubType*>(this)-> \
+ visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))
+
switch (curr->_id) {
- case Expression::Id::InvalidId: abort();
case Expression::Id::BlockId: DELEGATE(Block);
case Expression::Id::IfId: DELEGATE(If);
case Expression::Id::LoopId: DELEGATE(Loop);
@@ -92,49 +92,41 @@ struct WasmVisitor {
case Expression::Id::HostId: DELEGATE(Host);
case Expression::Id::NopId: DELEGATE(Nop);
case Expression::Id::UnreachableId: DELEGATE(Unreachable);
+ case Expression::Id::InvalidId:
default: WASM_UNREACHABLE();
}
- }
-#undef DELEGATE
-
- // Helper method to de-recurse blocks, which often nest in their first position very heavily
- void derecurseBlocks(Block* block, std::function<void (Block*)> preBlock,
- std::function<void (Block*, Expression*&)> onChild,
- std::function<void (Block*)> postBlock) {
- std::vector<Block*> stack;
- stack.push_back(block);
- while (block->list.size() > 0 && block->list[0]->is<Block>()) {
- block = block->list[0]->cast<Block>();
- stack.push_back(block);
- }
- for (size_t i = 0; i < stack.size(); i++) {
- preBlock(stack[i]);
- }
- for (int i = int(stack.size()) - 1; i >= 0; i--) {
- auto* block = stack[i];
- auto& list = block->list;
- for (size_t j = 0; j < list.size(); j++) {
- if (i < int(stack.size()) - 1 && j == 0) {
- // nested block, we already called its pre
- } else {
- onChild(block, list[j]);
- }
- }
- postBlock(block);
- }
+ #undef DELEGATE
}
};
//
-// Base class for all WasmWalkers
+// Base class for all WasmWalkers, which can traverse an AST
+// and provide the option to replace nodes while doing so.
//
-template<typename SubType, typename ReturnType = void>
-struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> {
- virtual void walk(Expression*& curr) { abort(); }
+// Subclass and implement the visit*()
+// calls to run code on different node types.
+//
+template<typename SubType>
+struct Walker : public Visitor<SubType> {
+ // Extra generic visitor, called before each node's specific visitor. Useful for
+ // passes that need to do the same thing for every node type.
+ void visitExpression(Expression* curr) {}
+
+ // Node replacing as we walk - call replaceCurrent from
+ // your visitors.
+
+ Expression *replace = nullptr;
+
+ void replaceCurrent(Expression *expression) {
+ replace = expression;
+ }
+
+ // Walk starting
void startWalk(Function *func) {
- walk(func->body);
+ SubType* self = static_cast<SubType*>(this);
+ self->walk(func->body);
}
void startWalk(Module *module) {
@@ -157,185 +149,209 @@ struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> {
self->visitMemory(&module->memory);
self->visitModule(module);
}
-};
-template<typename ParentType>
-struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> {
- ParentType& parent;
+ // Walk implementation. We don't use recursion as ASTs may be highly
+ // nested.
- ChildWalker(ParentType& parent) : parent(parent) {}
+ // Tasks receive the this pointer and a pointer to the pointer to operate on
+ typedef void (*TaskFunc)(SubType*, Expression**);
- void visitBlock(Block *curr) {
- ExpressionList& list = curr->list;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitIf(If *curr) {
- parent.walk(curr->condition);
- parent.walk(curr->ifTrue);
- parent.walk(curr->ifFalse);
- }
- void visitLoop(Loop *curr) {
- parent.walk(curr->body);
- }
- void visitBreak(Break *curr) {
- parent.walk(curr->condition);
- parent.walk(curr->value);
- }
- void visitSwitch(Switch *curr) {
- parent.walk(curr->condition);
- if (curr->value) parent.walk(curr->value);
- }
- void visitCall(Call *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitCallImport(CallImport *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
+ struct Task {
+ TaskFunc func;
+ Expression** currp;
+ Task(TaskFunc func, Expression** currp) : func(func), currp(currp) {}
+ };
+
+ std::vector<Task> stack;
+
+ void pushTask(TaskFunc func, Expression** currp) {
+ stack.emplace_back(func, currp);
}
- void visitCallIndirect(CallIndirect *curr) {
- parent.walk(curr->target);
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
+ void maybePushTask(TaskFunc func, Expression** currp) {
+ if (*currp) {
+ stack.emplace_back(func, currp);
}
}
- void visitGetLocal(GetLocal *curr) {}
- void visitSetLocal(SetLocal *curr) {
- parent.walk(curr->value);
- }
- void visitLoad(Load *curr) {
- parent.walk(curr->ptr);
- }
- void visitStore(Store *curr) {
- parent.walk(curr->ptr);
- parent.walk(curr->value);
- }
- void visitConst(Const *curr) {}
- void visitUnary(Unary *curr) {
- parent.walk(curr->value);
- }
- void visitBinary(Binary *curr) {
- parent.walk(curr->left);
- parent.walk(curr->right);
- }
- void visitSelect(Select *curr) {
- parent.walk(curr->ifTrue);
- parent.walk(curr->ifFalse);
- parent.walk(curr->condition);
- }
- void visitReturn(Return *curr) {
- parent.walk(curr->value);
- }
- void visitHost(Host *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
+ Task popTask() {
+ auto ret = stack.back();
+ stack.pop_back();
+ return ret;
+ }
+
+ void walk(Expression*& root) {
+ assert(stack.size() == 0);
+ pushTask(SubType::scan, &root);
+ while (stack.size() > 0) {
+ auto task = popTask();
+ assert(*task.currp);
+ task.func(static_cast<SubType*>(this), task.currp);
+ if (replace) {
+ *task.currp = replace;
+ replace = nullptr;
+ }
}
}
- void visitNop(Nop *curr) {}
- void visitUnreachable(Unreachable *curr) {}
-};
-
-// Walker that allows replacements
-template<typename SubType, typename ReturnType = void>
-struct WasmReplacerWalker : public WasmWalkerBase<SubType, ReturnType> {
- Expression* replace = nullptr;
-
- // methods can call this to replace the current node
- void replaceCurrent(Expression *expression) {
- replace = expression;
- }
- void walk(Expression*& curr) override {
- if (!curr) return;
-
- this->visit(curr);
-
- if (replace) {
- curr = replace;
- replace = nullptr;
- }
- }
+ // subclasses implement this to define the proper order of execution
+ static void scan(SubType* self, Expression** currp) { abort(); }
+
+ // task hooks to call visitors
+
+ static void doVisitBlock(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBlock((*currp)->cast<Block>()); }
+ static void doVisitIf(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitIf((*currp)->cast<If>()); }
+ static void doVisitLoop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitLoop((*currp)->cast<Loop>()); }
+ static void doVisitBreak(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBreak((*currp)->cast<Break>()); }
+ static void doVisitSwitch(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSwitch((*currp)->cast<Switch>()); }
+ static void doVisitCall(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCall((*currp)->cast<Call>()); }
+ static void doVisitCallImport(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCallImport((*currp)->cast<CallImport>()); }
+ static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCallIndirect((*currp)->cast<CallIndirect>()); }
+ static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitGetLocal((*currp)->cast<GetLocal>()); }
+ static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSetLocal((*currp)->cast<SetLocal>()); }
+ static void doVisitLoad(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitLoad((*currp)->cast<Load>()); }
+ static void doVisitStore(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitStore((*currp)->cast<Store>()); }
+ static void doVisitConst(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitConst((*currp)->cast<Const>()); }
+ static void doVisitUnary(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnary((*currp)->cast<Unary>()); }
+ static void doVisitBinary(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBinary((*currp)->cast<Binary>()); }
+ static void doVisitSelect(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSelect((*currp)->cast<Select>()); }
+ static void doVisitReturn(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitReturn((*currp)->cast<Return>()); }
+ static void doVisitHost(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitHost((*currp)->cast<Host>()); }
+ static void doVisitNop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitNop((*currp)->cast<Nop>()); }
+ static void doVisitUnreachable(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnreachable((*currp)->cast<Unreachable>()); }
};
-//
-// Simple WebAssembly children-first walking (i.e., post-order, if you look
-// at the children as subtrees of the current node), with the ability to replace
-// the current expression node. Useful for writing optimization passes.
-//
+// Walks in post-order, i.e., children first. When there isn't an obvious
+// order to operands, we follow them in order of execution.
-template<typename SubType, typename ReturnType = void>
-struct WasmWalker : public WasmReplacerWalker<SubType, ReturnType> {
- // By default, do nothing
- ReturnType visitBlock(Block *curr) {}
- ReturnType visitIf(If *curr) {}
- ReturnType visitLoop(Loop *curr) {}
- ReturnType visitBreak(Break *curr) {}
- ReturnType visitSwitch(Switch *curr) {}
- ReturnType visitCall(Call *curr) {}
- ReturnType visitCallImport(CallImport *curr) {}
- ReturnType visitCallIndirect(CallIndirect *curr) {}
- ReturnType visitGetLocal(GetLocal *curr) {}
- ReturnType visitSetLocal(SetLocal *curr) {}
- ReturnType visitLoad(Load *curr) {}
- ReturnType visitStore(Store *curr) {}
- ReturnType visitConst(Const *curr) {}
- ReturnType visitUnary(Unary *curr) {}
- ReturnType visitBinary(Binary *curr) {}
- ReturnType visitSelect(Select *curr) {}
- ReturnType visitReturn(Return *curr) {}
- ReturnType visitHost(Host *curr) {}
- ReturnType visitNop(Nop *curr) {}
- ReturnType visitUnreachable(Unreachable *curr) {}
+template<typename SubType>
+struct PostWalker : public Walker<SubType> {
- ReturnType visitFunctionType(FunctionType *curr) {}
- ReturnType visitImport(Import *curr) {}
- ReturnType visitExport(Export *curr) {}
- ReturnType visitFunction(Function *curr) {}
- ReturnType visitTable(Table *curr) {}
- ReturnType visitMemory(Memory *curr) {}
- ReturnType visitModule(Module *curr) {}
+ static void scan(SubType* self, Expression** currp) {
- // children-first
- void walk(Expression*& curr) override {
- if (!curr) return;
-
- // special-case Block, because Block nesting (in their first element) can be incredibly deep
- if (curr->is<Block>()) {
- auto* block = curr->dyn_cast<Block>();
- std::vector<Block*> stack;
- stack.push_back(block);
- while (block->list.size() > 0 && block->list[0]->is<Block>()) {
- block = block->list[0]->cast<Block>();
- stack.push_back(block);
+ Expression* curr = *currp;
+ switch (curr->_id) {
+ case Expression::Id::InvalidId: abort();
+ case Expression::Id::BlockId: {
+ self->pushTask(SubType::doVisitBlock, currp);
+ auto& list = curr->cast<Block>()->list;
+ for (int i = int(list.size()) - 1; i >= 0; i--) {
+ self->pushTask(SubType::scan, &list[i]);
+ }
+ break;
+ }
+ case Expression::Id::IfId: {
+ self->pushTask(SubType::doVisitIf, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<If>()->ifFalse);
+ self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue);
+ self->pushTask(SubType::scan, &curr->cast<If>()->condition);
+ break;
+ }
+ case Expression::Id::LoopId: {
+ self->pushTask(SubType::doVisitLoop, currp);
+ self->pushTask(SubType::scan, &curr->cast<Loop>()->body);
+ break;
+ }
+ case Expression::Id::BreakId: {
+ self->pushTask(SubType::doVisitBreak, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Break>()->condition);
+ self->maybePushTask(SubType::scan, &curr->cast<Break>()->value);
+ break;
+ }
+ case Expression::Id::SwitchId: {
+ self->pushTask(SubType::doVisitSwitch, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value);
+ self->pushTask(SubType::scan, &curr->cast<Switch>()->condition);
+ break;
+ }
+ case Expression::Id::CallId: {
+ self->pushTask(SubType::doVisitCall, currp);
+ auto& list = curr->cast<Call>()->operands;
+ for (int i = int(list.size()) - 1; i >= 0; i--) {
+ self->pushTask(SubType::scan, &list[i]);
+ }
+ break;
+ }
+ case Expression::Id::CallImportId: {
+ self->pushTask(SubType::doVisitCallImport, currp);
+ auto& list = curr->cast<CallImport>()->operands;
+ for (int i = int(list.size()) - 1; i >= 0; i--) {
+ self->pushTask(SubType::scan, &list[i]);
+ }
+ break;
+ }
+ case Expression::Id::CallIndirectId: {
+ self->pushTask(SubType::doVisitCallIndirect, currp);
+ auto& list = curr->cast<CallIndirect>()->operands;
+ for (int i = int(list.size()) - 1; i >= 0; i--) {
+ self->pushTask(SubType::scan, &list[i]);
+ }
+ self->pushTask(SubType::scan, &curr->cast<CallIndirect>()->target);
+ break;
+ }
+ case Expression::Id::GetLocalId: {
+ self->pushTask(SubType::doVisitGetLocal, currp); // TODO: optimize leaves with a direct call?
+ break;
+ }
+ case Expression::Id::SetLocalId: {
+ self->pushTask(SubType::doVisitSetLocal, currp);
+ self->pushTask(SubType::scan, &curr->cast<SetLocal>()->value);
+ break;
+ }
+ case Expression::Id::LoadId: {
+ self->pushTask(SubType::doVisitLoad, currp);
+ self->pushTask(SubType::scan, &curr->cast<Load>()->ptr);
+ break;
+ }
+ case Expression::Id::StoreId: {
+ self->pushTask(SubType::doVisitStore, currp);
+ self->pushTask(SubType::scan, &curr->cast<Store>()->value);
+ self->pushTask(SubType::scan, &curr->cast<Store>()->ptr);
+ break;
}
- // walk all the children
- for (int i = int(stack.size()) - 1; i >= 0; i--) {
- auto* block = stack[i];
- auto& children = block->list;
- for (size_t j = 0; j < children.size(); j++) {
- if (i < int(stack.size()) - 1 && j == 0) {
- // this is one of the stacked blocks, no need to walk its children, we are doing that ourselves
- WasmReplacerWalker<SubType, ReturnType>::walk(children[0]);
- } else {
- this->walk(children[j]);
- }
+ case Expression::Id::ConstId: {
+ self->pushTask(SubType::doVisitConst, currp);
+ break;
+ }
+ case Expression::Id::UnaryId: {
+ self->pushTask(SubType::doVisitUnary, currp);
+ self->pushTask(SubType::scan, &curr->cast<Unary>()->value);
+ break;
+ }
+ case Expression::Id::BinaryId: {
+ self->pushTask(SubType::doVisitBinary, currp);
+ self->pushTask(SubType::scan, &curr->cast<Binary>()->right);
+ self->pushTask(SubType::scan, &curr->cast<Binary>()->left);
+ break;
+ }
+ case Expression::Id::SelectId: {
+ self->pushTask(SubType::doVisitSelect, currp);
+ self->pushTask(SubType::scan, &curr->cast<Select>()->condition);
+ self->pushTask(SubType::scan, &curr->cast<Select>()->ifFalse);
+ self->pushTask(SubType::scan, &curr->cast<Select>()->ifTrue);
+ break;
+ }
+ case Expression::Id::ReturnId: {
+ self->pushTask(SubType::doVisitReturn, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Return>()->value);
+ break;
+ }
+ case Expression::Id::HostId: {
+ self->pushTask(SubType::doVisitHost, currp);
+ auto& list = curr->cast<Host>()->operands;
+ for (int i = int(list.size()) - 1; i >= 0; i--) {
+ self->pushTask(SubType::scan, &list[i]);
}
+ break;
+ }
+ case Expression::Id::NopId: {
+ self->pushTask(SubType::doVisitNop, currp);
+ break;
}
- // we walked all the children, and can rejoin later below to visit this node itself
- } else {
- // generic child-walking
- ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr);
+ case Expression::Id::UnreachableId: {
+ self->pushTask(SubType::doVisitUnreachable, currp);
+ break;
+ }
+ default: WASM_UNREACHABLE();
}
-
- WasmReplacerWalker<SubType, ReturnType>::walk(curr);
}
};
@@ -350,111 +366,73 @@ struct WasmWalker : public WasmReplacerWalker<SubType, ReturnType> {
// to noteNonLinear().
template<typename SubType>
-struct FastExecutionWalker : public WasmReplacerWalker<SubType> {
- FastExecutionWalker() {}
+struct LinearExecutionWalker : public PostWalker<SubType> {
+ LinearExecutionWalker() {}
- void noteNonLinear() {}
+ // subclasses should implement this
+ void noteNonLinear() { abort(); }
-#define DELEGATE_noteNonLinear() \
- static_cast<SubType*>(this)->noteNonLinear()
-#define DELEGATE_walk(ARG) \
- static_cast<SubType*>(this)->walk(ARG)
-
- void visitBlock(Block *curr) {
- ExpressionList& list = curr->list;
- for (size_t z = 0; z < list.size(); z++) {
- DELEGATE_walk(list[z]);
- }
- }
- void visitIf(If *curr) {
- DELEGATE_walk(curr->condition);
- DELEGATE_noteNonLinear();
- DELEGATE_walk(curr->ifTrue);
- DELEGATE_noteNonLinear();
- DELEGATE_walk(curr->ifFalse);
- DELEGATE_noteNonLinear();
- }
- void visitLoop(Loop *curr) {
- DELEGATE_noteNonLinear();
- DELEGATE_walk(curr->body);
- }
- void visitBreak(Break *curr) {
- if (curr->value) DELEGATE_walk(curr->value);
- if (curr->condition) DELEGATE_walk(curr->condition);
- DELEGATE_noteNonLinear();
- }
- void visitSwitch(Switch *curr) {
- DELEGATE_walk(curr->condition);
- if (curr->value) DELEGATE_walk(curr->value);
- DELEGATE_noteNonLinear();
- }
- void visitCall(Call *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- DELEGATE_walk(list[z]);
- }
- }
- void visitCallImport(CallImport *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- DELEGATE_walk(list[z]);
- }
- }
- void visitCallIndirect(CallIndirect *curr) {
- DELEGATE_walk(curr->target);
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- DELEGATE_walk(list[z]);
- }
+ static void doNoteNonLinear(SubType* self, Expression** currp) {
+ self->noteNonLinear();
}
- void visitGetLocal(GetLocal *curr) {}
- void visitSetLocal(SetLocal *curr) {
- DELEGATE_walk(curr->value);
- }
- void visitLoad(Load *curr) {
- DELEGATE_walk(curr->ptr);
- }
- void visitStore(Store *curr) {
- DELEGATE_walk(curr->ptr);
- DELEGATE_walk(curr->value);
- }
- void visitConst(Const *curr) {}
- void visitUnary(Unary *curr) {
- DELEGATE_walk(curr->value);
- }
- void visitBinary(Binary *curr) {
- DELEGATE_walk(curr->left);
- DELEGATE_walk(curr->right);
- }
- void visitSelect(Select *curr) {
- DELEGATE_walk(curr->ifTrue);
- DELEGATE_walk(curr->ifFalse);
- DELEGATE_walk(curr->condition);
- }
- void visitReturn(Return *curr) {
- DELEGATE_walk(curr->value);
- DELEGATE_noteNonLinear();
- }
- void visitHost(Host *curr) {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- DELEGATE_walk(list[z]);
- }
- }
- void visitNop(Nop *curr) {}
- void visitUnreachable(Unreachable *curr) {}
- void visitFunctionType(FunctionType *curr) {}
- void visitImport(Import *curr) {}
- void visitExport(Export *curr) {}
- void visitFunction(Function *curr) {}
- void visitTable(Table *curr) {}
- void visitMemory(Memory *curr) {}
- void visitModule(Module *curr) {}
+ static void scan(SubType* self, Expression** currp) {
-#undef DELEGATE_noteNonLinear
-#undef DELEGATE_walk
+ Expression* curr = *currp;
+ switch (curr->_id) {
+ case Expression::Id::InvalidId: abort();
+ case Expression::Id::BlockId: {
+ self->pushTask(SubType::doVisitBlock, currp);
+ self->pushTask(SubType::doNoteNonLinear, currp);
+ auto& list = curr->cast<Block>()->list;
+ for (int i = int(list.size()) - 1; i >= 0; i--) {
+ self->pushTask(SubType::scan, &list[i]);
+ }
+ break;
+ }
+ case Expression::Id::IfId: {
+ self->pushTask(SubType::doVisitIf, currp);
+ self->pushTask(SubType::doNoteNonLinear, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<If>()->ifFalse);
+ self->pushTask(SubType::doNoteNonLinear, currp);
+ self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue);
+ self->pushTask(SubType::doNoteNonLinear, currp);
+ self->pushTask(SubType::scan, &curr->cast<If>()->condition);
+ break;
+ }
+ case Expression::Id::LoopId: {
+ self->pushTask(SubType::doVisitLoop, currp);
+ self->pushTask(SubType::scan, &curr->cast<Loop>()->body);
+ self->pushTask(SubType::doNoteNonLinear, currp);
+ break;
+ }
+ case Expression::Id::BreakId: {
+ self->pushTask(SubType::doVisitBreak, currp);
+ self->pushTask(SubType::doNoteNonLinear, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Break>()->condition);
+ self->maybePushTask(SubType::scan, &curr->cast<Break>()->value);
+ break;
+ }
+ case Expression::Id::SwitchId: {
+ self->pushTask(SubType::doVisitSwitch, currp);
+ self->pushTask(SubType::doNoteNonLinear, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value);
+ self->pushTask(SubType::scan, &curr->cast<Switch>()->condition);
+ break;
+ }
+ case Expression::Id::ReturnId: {
+ self->pushTask(SubType::doVisitReturn, currp);
+ self->pushTask(SubType::doNoteNonLinear, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Return>()->value);
+ break;
+ }
+ default: {
+ // other node types do not have control flow, use regular post-order
+ PostWalker<SubType>::scan(self, currp);
+ }
+ }
+ }
};
} // namespace wasm
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 77f475990..7b2987dd9 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -25,7 +25,7 @@
namespace wasm {
-struct WasmValidator : public WasmWalker<WasmValidator> {
+struct WasmValidator : public PostWalker<WasmValidator> {
bool valid;
std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type
@@ -118,7 +118,7 @@ public:
private:
// the "in" label has a none type, since no one can receive its value. make sure no one breaks to it with a value.
- struct LoopChildChecker : public WasmWalker<LoopChildChecker> {
+ struct LoopChildChecker : public PostWalker<LoopChildChecker> {
Name in;
bool valid = true;
diff --git a/src/wasm2asm.h b/src/wasm2asm.h
index 0d8422ab1..2c330e273 100644
--- a/src/wasm2asm.h
+++ b/src/wasm2asm.h
@@ -392,7 +392,7 @@ Ref Wasm2AsmBuilder::processFunction(Function* func) {
}
void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) {
- struct ExpressionScanner : public WasmWalker<ExpressionScanner> {
+ struct ExpressionScanner : public PostWalker<ExpressionScanner> {
Wasm2AsmBuilder* parent;
ExpressionScanner(Wasm2AsmBuilder* parent) : parent(parent) {}
@@ -467,6 +467,9 @@ void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) {
parent->setStatement(curr);
}
}
+ void visitReturn(Return *curr) {
+ abort();
+ }
void visitHost(Host *curr) {
for (auto item : curr->operands) {
if (parent->isStatement(item)) {
@@ -480,7 +483,7 @@ void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) {
}
Ref Wasm2AsmBuilder::processFunctionBody(Expression* curr, IString result) {
- struct ExpressionProcessor : public WasmVisitor<ExpressionProcessor, Ref> {
+ struct ExpressionProcessor : public Visitor<ExpressionProcessor, Ref> {
Wasm2AsmBuilder* parent;
IString result;
ExpressionProcessor(Wasm2AsmBuilder* parent) : parent(parent) {}
@@ -521,7 +524,7 @@ Ref Wasm2AsmBuilder::processFunctionBody(Expression* curr, IString result) {
Ref visit(Expression* curr, IString nextResult) {
IString old = result;
result = nextResult;
- Ref ret = WasmVisitor::visit(curr);
+ Ref ret = Visitor::visit(curr);
result = old; // keep it consistent for the rest of this frame, which may call visit on multiple children
return ret;
}
@@ -1083,6 +1086,9 @@ Ref Wasm2AsmBuilder::processFunctionBody(Expression* curr, IString result) {
)
);
}
+ Ref visitReturn(Return *curr) {
+ abort();
+ }
Ref visitHost(Host *curr) {
abort();
}