summaryrefslogtreecommitdiff
path: root/src/ast_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast_utils.h')
-rw-r--r--src/ast_utils.h378
1 files changed, 378 insertions, 0 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h
index a43fc6b2f..ee5c76b69 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -17,6 +17,7 @@
#ifndef wasm_ast_utils_h
#define wasm_ast_utils_h
+#include "support/hash.h"
#include "wasm.h"
#include "wasm-traversal.h"
@@ -241,6 +242,383 @@ struct ExpressionAnalyzer {
// The value might be used, so it depends on if the function returns
return func->result != none;
}
+
+ static bool equal(Expression* left, Expression* right) {
+ std::vector<Name> nameStack;
+ std::map<Name, std::vector<Name>> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively
+ Nop popNameMarker;
+ std::vector<Expression*> leftStack;
+ std::vector<Expression*> rightStack;
+
+ auto noteNames = [&](Name left, Name right) {
+ if (left.is() != right.is()) return false;
+ if (left.is()) {
+ nameStack.push_back(left);
+ rightNames[left].push_back(right);
+ leftStack.push_back(&popNameMarker);
+ rightStack.push_back(&popNameMarker);
+ }
+ return true;
+ };
+ auto checkNames = [&](Name left, Name right) {
+ auto iter = rightNames.find(left);
+ if (iter == rightNames.end()) return left == right; // non-internal name
+ return iter->second.back() == right;
+ };
+ auto popName = [&]() {
+ auto left = nameStack.back();
+ nameStack.pop_back();
+ rightNames[left].pop_back();
+ };
+
+ leftStack.push_back(left);
+ rightStack.push_back(right);
+
+ while (leftStack.size() > 0 && rightStack.size() > 0) {
+ left = leftStack.back();
+ leftStack.pop_back();
+ right = rightStack.back();
+ rightStack.pop_back();
+ if (!left != !right) return false;
+ if (!left) continue;
+ if (left == &popNameMarker) {
+ popName();
+ continue;
+ }
+ if (left->_id != right->_id) return false;
+ #define PUSH(clazz, what) \
+ leftStack.push_back(left->cast<clazz>()->what); \
+ rightStack.push_back(right->cast<clazz>()->what);
+ #define CHECK(clazz, what) \
+ if (left->cast<clazz>()->what != right->cast<clazz>()->what) return false;
+ switch (left->_id) {
+ case Expression::Id::BlockId: {
+ if (!noteNames(left->cast<Block>()->name, right->cast<Block>()->name)) return false;
+ CHECK(Block, list.size());
+ for (Index i = 0; i < left->cast<Block>()->list.size(); i++) {
+ PUSH(Block, list[i]);
+ }
+ break;
+ }
+ case Expression::Id::IfId: {
+ PUSH(If, condition);
+ PUSH(If, ifTrue);
+ PUSH(If, ifFalse);
+ break;
+ }
+ case Expression::Id::LoopId: {
+ if (!noteNames(left->cast<Loop>()->out, right->cast<Loop>()->out)) return false;
+ if (!noteNames(left->cast<Loop>()->in, right->cast<Loop>()->in)) return false;
+ PUSH(Loop, body);
+ break;
+ }
+ case Expression::Id::BreakId: {
+ if (!checkNames(left->cast<Break>()->name, right->cast<Break>()->name)) return false;
+ PUSH(Break, condition);
+ PUSH(Break, value);
+ break;
+ }
+ case Expression::Id::SwitchId: {
+ CHECK(Switch, targets.size());
+ for (Index i = 0; i < left->cast<Switch>()->targets.size(); i++) {
+ if (!checkNames(left->cast<Switch>()->targets[i], right->cast<Switch>()->targets[i])) return false;
+ }
+ if (!checkNames(left->cast<Switch>()->default_, right->cast<Switch>()->default_)) return false;
+ PUSH(Switch, condition);
+ PUSH(Switch, value);
+ break;
+ }
+ case Expression::Id::CallId: {
+ CHECK(Call, target);
+ CHECK(Call, operands.size());
+ for (Index i = 0; i < left->cast<Call>()->operands.size(); i++) {
+ PUSH(Call, operands[i]);
+ }
+ break;
+ }
+ case Expression::Id::CallImportId: {
+ CHECK(CallImport, target);
+ CHECK(CallImport, operands.size());
+ for (Index i = 0; i < left->cast<CallImport>()->operands.size(); i++) {
+ PUSH(CallImport, operands[i]);
+ }
+ break;
+ }
+ case Expression::Id::CallIndirectId: {
+ PUSH(CallIndirect, target);
+ CHECK(CallIndirect, fullType);
+ CHECK(CallIndirect, operands.size());
+ for (Index i = 0; i < left->cast<CallIndirect>()->operands.size(); i++) {
+ PUSH(CallIndirect, operands[i]);
+ }
+ break;
+ }
+ case Expression::Id::GetLocalId: {
+ CHECK(GetLocal, index);
+ break;
+ }
+ case Expression::Id::SetLocalId: {
+ CHECK(SetLocal, index);
+ PUSH(SetLocal, value);
+ break;
+ }
+ case Expression::Id::LoadId: {
+ CHECK(Load, bytes);
+ CHECK(Load, signed_);
+ CHECK(Load, offset);
+ CHECK(Load, align);
+ PUSH(Load, ptr);
+ break;
+ }
+ case Expression::Id::StoreId: {
+ CHECK(Store, bytes);
+ CHECK(Store, offset);
+ CHECK(Store, align);
+ PUSH(Store, ptr);
+ PUSH(Store, value);
+ break;
+ }
+ case Expression::Id::ConstId: {
+ CHECK(Const, value);
+ break;
+ }
+ case Expression::Id::UnaryId: {
+ CHECK(Unary, op);
+ PUSH(Unary, value);
+ break;
+ }
+ case Expression::Id::BinaryId: {
+ CHECK(Binary, op);
+ PUSH(Binary, left);
+ PUSH(Binary, right);
+ break;
+ }
+ case Expression::Id::SelectId: {
+ PUSH(Select, ifTrue);
+ PUSH(Select, ifFalse);
+ PUSH(Select, condition);
+ break;
+ }
+ case Expression::Id::ReturnId: {
+ PUSH(Return, value);
+ break;
+ }
+ case Expression::Id::HostId: {
+ CHECK(Host, op);
+ CHECK(Host, nameOperand);
+ CHECK(Host, operands.size());
+ for (Index i = 0; i < left->cast<Host>()->operands.size(); i++) {
+ PUSH(Host, operands[i]);
+ }
+ break;
+ }
+ case Expression::Id::NopId: {
+ break;
+ }
+ case Expression::Id::UnreachableId: {
+ break;
+ }
+ default: WASM_UNREACHABLE();
+ }
+ #undef CHECK
+ #undef PUSH
+ }
+ if (leftStack.size() > 0 || rightStack.size() > 0) return false;
+ return true;
+ }
+
+ // hash an expression, ignoring superficial details like specific internal names
+ static uint32_t hash(Expression* curr) {
+ uint32_t digest = 0;
+
+ auto hash = [&digest](uint32_t hash) {
+ digest = rehash(digest, hash);
+ };
+ auto hash64 = [&digest](uint64_t hash) {
+ digest = rehash(rehash(digest, hash >> 32), uint32_t(hash));
+ };
+
+ std::vector<Name> nameStack;
+ Index internalCounter = 0;
+ std::map<Name, std::vector<Index>> internalNames; // for each internal name, a vector if unique ids
+ Nop popNameMarker;
+ std::vector<Expression*> stack;
+
+ auto noteName = [&](Name curr) {
+ if (curr.is()) {
+ nameStack.push_back(curr);
+ internalNames[curr].push_back(internalCounter++);
+ stack.push_back(&popNameMarker);
+ }
+ return true;
+ };
+ auto hashName = [&](Name curr) {
+ auto iter = internalNames.find(curr);
+ if (iter == internalNames.end()) hash64(uint64_t(curr.str));
+ else hash(iter->second.back());
+ };
+ auto popName = [&]() {
+ auto curr = nameStack.back();
+ nameStack.pop_back();
+ internalNames[curr].pop_back();
+ };
+
+ stack.push_back(curr);
+
+ while (stack.size() > 0) {
+ curr = stack.back();
+ stack.pop_back();
+ if (!curr) continue;
+ if (curr == &popNameMarker) {
+ popName();
+ continue;
+ }
+ hash(curr->_id);
+ #define PUSH(clazz, what) \
+ stack.push_back(curr->cast<clazz>()->what);
+ #define HASH(clazz, what) \
+ hash(curr->cast<clazz>()->what);
+ #define HASH64(clazz, what) \
+ hash64(curr->cast<clazz>()->what);
+ #define HASH_NAME(clazz, what) \
+ hash64(uint64_t(curr->cast<clazz>()->what.str));
+ #define HASH_PTR(clazz, what) \
+ hash64(uint64_t(curr->cast<clazz>()->what));
+ switch (curr->_id) {
+ case Expression::Id::BlockId: {
+ noteName(curr->cast<Block>()->name);
+ HASH(Block, list.size());
+ for (Index i = 0; i < curr->cast<Block>()->list.size(); i++) {
+ PUSH(Block, list[i]);
+ }
+ break;
+ }
+ case Expression::Id::IfId: {
+ PUSH(If, condition);
+ PUSH(If, ifTrue);
+ PUSH(If, ifFalse);
+ break;
+ }
+ case Expression::Id::LoopId: {
+ noteName(curr->cast<Loop>()->out);
+ noteName(curr->cast<Loop>()->in);
+ PUSH(Loop, body);
+ break;
+ }
+ case Expression::Id::BreakId: {
+ hashName(curr->cast<Break>()->name);
+ PUSH(Break, condition);
+ PUSH(Break, value);
+ break;
+ }
+ case Expression::Id::SwitchId: {
+ HASH(Switch, targets.size());
+ for (Index i = 0; i < curr->cast<Switch>()->targets.size(); i++) {
+ hashName(curr->cast<Switch>()->targets[i]);
+ }
+ hashName(curr->cast<Switch>()->default_);
+ PUSH(Switch, condition);
+ PUSH(Switch, value);
+ break;
+ }
+ case Expression::Id::CallId: {
+ HASH_NAME(Call, target);
+ HASH(Call, operands.size());
+ for (Index i = 0; i < curr->cast<Call>()->operands.size(); i++) {
+ PUSH(Call, operands[i]);
+ }
+ break;
+ }
+ case Expression::Id::CallImportId: {
+ HASH_NAME(CallImport, target);
+ HASH(CallImport, operands.size());
+ for (Index i = 0; i < curr->cast<CallImport>()->operands.size(); i++) {
+ PUSH(CallImport, operands[i]);
+ }
+ break;
+ }
+ case Expression::Id::CallIndirectId: {
+ PUSH(CallIndirect, target);
+ HASH_PTR(CallIndirect, fullType);
+ HASH(CallIndirect, operands.size());
+ for (Index i = 0; i < curr->cast<CallIndirect>()->operands.size(); i++) {
+ PUSH(CallIndirect, operands[i]);
+ }
+ break;
+ }
+ case Expression::Id::GetLocalId: {
+ HASH(GetLocal, index);
+ break;
+ }
+ case Expression::Id::SetLocalId: {
+ HASH(SetLocal, index);
+ PUSH(SetLocal, value);
+ break;
+ }
+ case Expression::Id::LoadId: {
+ HASH(Load, bytes);
+ HASH(Load, signed_);
+ HASH(Load, offset);
+ HASH(Load, align);
+ PUSH(Load, ptr);
+ break;
+ }
+ case Expression::Id::StoreId: {
+ HASH(Store, bytes);
+ HASH(Store, offset);
+ HASH(Store, align);
+ PUSH(Store, ptr);
+ PUSH(Store, value);
+ break;
+ }
+ case Expression::Id::ConstId: {
+ HASH(Const, value.type);
+ HASH64(Const, value.getBits());
+ break;
+ }
+ case Expression::Id::UnaryId: {
+ HASH(Unary, op);
+ PUSH(Unary, value);
+ break;
+ }
+ case Expression::Id::BinaryId: {
+ HASH(Binary, op);
+ PUSH(Binary, left);
+ PUSH(Binary, right);
+ break;
+ }
+ case Expression::Id::SelectId: {
+ PUSH(Select, ifTrue);
+ PUSH(Select, ifFalse);
+ PUSH(Select, condition);
+ break;
+ }
+ case Expression::Id::ReturnId: {
+ PUSH(Return, value);
+ break;
+ }
+ case Expression::Id::HostId: {
+ HASH(Host, op);
+ HASH_NAME(Host, nameOperand);
+ HASH(Host, operands.size());
+ for (Index i = 0; i < curr->cast<Host>()->operands.size(); i++) {
+ PUSH(Host, operands[i]);
+ }
+ break;
+ }
+ case Expression::Id::NopId: {
+ break;
+ }
+ case Expression::Id::UnreachableId: {
+ break;
+ }
+ default: WASM_UNREACHABLE();
+ }
+ #undef HASH
+ #undef PUSH
+ }
+ return digest;
+ }
};
} // namespace wasm