summaryrefslogtreecommitdiff
path: root/src/wasm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm.h')
-rw-r--r--src/wasm.h151
1 files changed, 151 insertions, 0 deletions
diff --git a/src/wasm.h b/src/wasm.h
index d13eb0795..290fc5d47 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -2,6 +2,9 @@
// WebAssembly representation and processing library
//
+#ifndef __wasm_h__
+#define __wasm_h__
+
#include <cassert>
#include <cstddef>
#include <cstdint>
@@ -796,5 +799,153 @@ public:
}
};
+//
+// Simple WebAssembly AST walker
+//
+
+struct WasmWalker {
+ wasm::Arena* allocator; // use an existing allocator, or null if no allocations
+
+ WasmWalker() : allocator(nullptr) {}
+ WasmWalker(wasm::Arena* allocator) : allocator(allocator) {}
+
+ // Each method receives an AST pointer, and it is replaced with what is returned.
+ virtual Expression* walkBlock(Block *curr) { return curr; };
+ virtual Expression* walkIf(If *curr) { return curr; };
+ virtual Expression* walkLoop(Loop *curr) { return curr; };
+ virtual Expression* walkLabel(Label *curr) { return curr; };
+ virtual Expression* walkBreak(Break *curr) { return curr; };
+ virtual Expression* walkSwitch(Switch *curr) { return curr; };
+ virtual Expression* walkCall(Call *curr) { return curr; };
+ virtual Expression* walkCallImport(CallImport *curr) { return curr; };
+ virtual Expression* walkCallIndirect(CallIndirect *curr) { return curr; };
+ virtual Expression* walkGetLocal(GetLocal *curr) { return curr; };
+ virtual Expression* walkSetLocal(SetLocal *curr) { return curr; };
+ virtual Expression* walkLoad(Load *curr) { return curr; };
+ virtual Expression* walkStore(Store *curr) { return curr; };
+ virtual Expression* walkConst(Const *curr) { return curr; };
+ virtual Expression* walkUnary(Unary *curr) { return curr; };
+ virtual Expression* walkBinary(Binary *curr) { return curr; };
+ virtual Expression* walkCompare(Compare *curr) { return curr; };
+ virtual Expression* walkConvert(Convert *curr) { return curr; };
+ virtual Expression* walkHost(Host *curr) { return curr; };
+ virtual Expression* walkNop(Nop *curr) { return curr; };
+
+ // children-first
+ Expression *walk(Expression *curr) {
+ if (!curr) return curr;
+
+ if (Block *cast = dynamic_cast<Block*>(curr)) {
+ ExpressionList& list = cast->list;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = walk(list[z]);
+ }
+ return walkBlock(cast);
+ }
+ if (If *cast = dynamic_cast<If*>(curr)) {
+ cast->condition = walk(cast->condition);
+ cast->ifTrue = walk(cast->ifTrue);
+ cast->ifFalse = walk(cast->ifFalse);
+ return walkIf(cast);
+ }
+ if (Loop *cast = dynamic_cast<Loop*>(curr)) {
+ cast->body = walk(cast->body);
+ return walkLoop(cast);
+ }
+ if (Label *cast = dynamic_cast<Label*>(curr)) {
+ return walkLabel(cast);
+ }
+ if (Break *cast = dynamic_cast<Break*>(curr)) {
+ cast->condition = walk(cast->condition);
+ cast->value = walk(cast->value);
+ return walkBreak(cast);
+ }
+ if (Switch *cast = dynamic_cast<Switch*>(curr)) {
+ cast->value = walk(cast->value);
+ for (auto& curr : cast->cases) {
+ curr.body = walk(curr.body);
+ }
+ cast->default_ = walk(cast->default_);
+ return walkSwitch(cast);
+ }
+ if (Call *cast = dynamic_cast<Call*>(curr)) {
+ ExpressionList& list = cast->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = walk(list[z]);
+ }
+ return walkCall(cast);
+ }
+ if (CallImport *cast = dynamic_cast<CallImport*>(curr)) {
+ ExpressionList& list = cast->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = walk(list[z]);
+ }
+ return walkCallImport(cast);
+ }
+ if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) {
+ cast->target = walk(cast->target);
+ ExpressionList& list = cast->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = walk(list[z]);
+ }
+ return walkCallIndirect(cast);
+ }
+ if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) {
+ return walkGetLocal(cast);
+ }
+ if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) {
+ cast->value = walk(cast->value);
+ return walkSetLocal(cast);
+ }
+ if (Load *cast = dynamic_cast<Load*>(curr)) {
+ cast->ptr = walk(cast->ptr);
+ return walkLoad(cast);
+ }
+ if (Store *cast = dynamic_cast<Store*>(curr)) {
+ cast->ptr = walk(cast->ptr);
+ cast->value = walk(cast->value);
+ return walkStore(cast);
+ }
+ if (Const *cast = dynamic_cast<Const*>(curr)) {
+ return walkConst(cast);
+ }
+ if (Unary *cast = dynamic_cast<Unary*>(curr)) {
+ cast->value = walk(cast->value);
+ return walkUnary(cast);
+ }
+ if (Binary *cast = dynamic_cast<Binary*>(curr)) {
+ cast->left = walk(cast->left);
+ cast->right = walk(cast->right);
+ return walkBinary(cast);
+ }
+ if (Compare *cast = dynamic_cast<Compare*>(curr)) {
+ cast->left = walk(cast->left);
+ cast->right = walk(cast->right);
+ return walkCompare(cast);
+ }
+ if (Convert *cast = dynamic_cast<Convert*>(curr)) {
+ cast->value = walk(cast->value);
+ return walkConvert(cast);
+ }
+ if (Host *cast = dynamic_cast<Host*>(curr)) {
+ ExpressionList& list = cast->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = walk(list[z]);
+ }
+ return walkHost(cast);
+ }
+ if (Nop *cast = dynamic_cast<Nop*>(curr)) {
+ return walkNop(cast);
+ }
+ abort();
+ }
+
+ void startWalk(Function *func) {
+ func->body = walk(func->body);
+ }
+};
+
} // namespace wasm
+#endif // __wasm_h__
+