summaryrefslogtreecommitdiff
path: root/src/ast_utils.h
diff options
context:
space:
mode:
authorAlon Zakai <alonzakai@gmail.com>2016-06-26 11:00:02 -0700
committerGitHub <noreply@github.com>2016-06-26 11:00:02 -0700
commit45b358706c86415c5982f9e777fa9e19a33b27a3 (patch)
treed1caa4180c8d0f4a76319fd11f8b18f9f446e6c3 /src/ast_utils.h
parentc410d93d3af9813f889b4011f964d4becf43bc27 (diff)
parent87f3020cf4e666a6eb6620106e48ee042cd2f666 (diff)
downloadbinaryen-45b358706c86415c5982f9e777fa9e19a33b27a3.tar.gz
binaryen-45b358706c86415c5982f9e777fa9e19a33b27a3.tar.bz2
binaryen-45b358706c86415c5982f9e777fa9e19a33b27a3.zip
Merge pull request #602 from WebAssembly/dsl-nice
Use a DSL in OptimizeInstructions
Diffstat (limited to 'src/ast_utils.h')
-rw-r--r--src/ast_utils.h126
1 files changed, 125 insertions, 1 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 8952114bc..77bfaf1f3 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -20,6 +20,7 @@
#include "support/hash.h"
#include "wasm.h"
#include "wasm-traversal.h"
+#include "wasm-builder.h"
namespace wasm {
@@ -210,6 +211,117 @@ struct ExpressionManipulator {
new (output) OutputType(allocator);
return output;
}
+
+ template<typename T>
+ static Expression* flexibleCopy(Expression* original, Module& wasm, T& custom) {
+ struct Copier : public Visitor<Copier, Expression*> {
+ Module& wasm;
+ T& custom;
+
+ Builder builder;
+
+ Copier(Module& wasm, T& custom) : wasm(wasm), custom(custom), builder(wasm) {}
+
+ Expression* copy(Expression* curr) {
+ if (!curr) return nullptr;
+ auto* ret = custom.copy(curr);
+ if (ret) return ret;
+ return Visitor<Copier, Expression*>::visit(curr);
+ }
+
+ Expression* visitBlock(Block *curr) {
+ auto* ret = builder.makeBlock();
+ for (Index i = 0; i < curr->list.size(); i++) {
+ ret->list.push_back(copy(curr->list[i]));
+ }
+ ret->name = curr->name;
+ ret->finalize(curr->type);
+ return ret;
+ }
+ Expression* visitIf(If *curr) {
+ return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse));
+ }
+ Expression* visitLoop(Loop *curr) {
+ return builder.makeLoop(curr->out, curr->in, copy(curr->body));
+ }
+ Expression* visitBreak(Break *curr) {
+ return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition));
+ }
+ Expression* visitSwitch(Switch *curr) {
+ return builder.makeSwitch(curr->targets, curr->default_, copy(curr->condition), copy(curr->value));
+ }
+ Expression* visitCall(Call *curr) {
+ auto* ret = builder.makeCall(curr->target, {}, curr->type);
+ for (Index i = 0; i < curr->operands.size(); i++) {
+ ret->operands.push_back(copy(curr->operands[i]));
+ }
+ return ret;
+ }
+ Expression* visitCallImport(CallImport *curr) {
+ auto* ret = builder.makeCallImport(curr->target, {}, curr->type);
+ for (Index i = 0; i < curr->operands.size(); i++) {
+ ret->operands.push_back(copy(curr->operands[i]));
+ }
+ return ret;
+ }
+ Expression* visitCallIndirect(CallIndirect *curr) {
+ auto* ret = builder.makeCallIndirect(curr->fullType, curr->target, {}, curr->type);
+ for (Index i = 0; i < curr->operands.size(); i++) {
+ ret->operands.push_back(copy(curr->operands[i]));
+ }
+ return ret;
+ }
+ Expression* visitGetLocal(GetLocal *curr) {
+ return builder.makeGetLocal(curr->index, curr->type);
+ }
+ Expression* visitSetLocal(SetLocal *curr) {
+ return builder.makeSetLocal(curr->index, copy(curr->value));
+ }
+ Expression* visitLoad(Load *curr) {
+ return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type);
+ }
+ Expression* visitStore(Store *curr) {
+ return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value));
+ }
+ Expression* visitConst(Const *curr) {
+ return builder.makeConst(curr->value);
+ }
+ Expression* visitUnary(Unary *curr) {
+ return builder.makeUnary(curr->op, copy(curr->value));
+ }
+ Expression* visitBinary(Binary *curr) {
+ return builder.makeBinary(curr->op, copy(curr->left), copy(curr->right));
+ }
+ Expression* visitSelect(Select *curr) {
+ return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse));
+ }
+ Expression* visitReturn(Return *curr) {
+ return builder.makeReturn(copy(curr->value));
+ }
+ Expression* visitHost(Host *curr) {
+ assert(curr->operands.size() == 0);
+ return builder.makeHost(curr->op, curr->nameOperand, {});
+ }
+ Expression* visitNop(Nop *curr) {
+ return builder.makeNop();
+ }
+ Expression* visitUnreachable(Unreachable *curr) {
+ return builder.makeUnreachable();
+ }
+ };
+
+ Copier copier(wasm, custom);
+ return copier.copy(original);
+ }
+
+ static Expression* copy(Expression* original, Module& wasm) {
+ struct Copier {
+ Expression* copy(Expression* curr) {
+ return nullptr;
+ }
+ } copier;
+ return flexibleCopy(original, wasm, copier);
+ }
};
struct ExpressionAnalyzer {
@@ -242,7 +354,8 @@ struct ExpressionAnalyzer {
return func->result != none;
}
- static bool equal(Expression* left, Expression* right) {
+ template<typename T>
+ static bool flexibleEqual(Expression* left, Expression* right, T& comparer) {
std::vector<Name> nameStack;
std::map<Name, std::vector<Name>> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively
Nop popNameMarker;
@@ -284,6 +397,8 @@ struct ExpressionAnalyzer {
popName();
continue;
}
+ if (comparer.compare(left, right)) continue; // comparison hook, before all the rest
+ // continue with normal structural comparison
if (left->_id != right->_id) return false;
#define PUSH(clazz, what) \
leftStack.push_back(left->cast<clazz>()->what); \
@@ -426,6 +541,15 @@ struct ExpressionAnalyzer {
return true;
}
+ static bool equal(Expression* left, Expression* right) {
+ struct Comparer {
+ bool compare(Expression* left, Expression* right) {
+ return false;
+ }
+ } comparer;
+ return flexibleEqual(left, right, comparer);
+ }
+
// hash an expression, ignoring superficial details like specific internal names
static uint32_t hash(Expression* curr) {
uint32_t digest = 0;