diff options
Diffstat (limited to 'src/ast_utils.h')
-rw-r--r-- | src/ast_utils.h | 378 |
1 files changed, 378 insertions, 0 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h index a43fc6b2f..ee5c76b69 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -17,6 +17,7 @@ #ifndef wasm_ast_utils_h #define wasm_ast_utils_h +#include "support/hash.h" #include "wasm.h" #include "wasm-traversal.h" @@ -241,6 +242,383 @@ struct ExpressionAnalyzer { // The value might be used, so it depends on if the function returns return func->result != none; } + + static bool equal(Expression* left, Expression* right) { + std::vector<Name> nameStack; + std::map<Name, std::vector<Name>> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively + Nop popNameMarker; + std::vector<Expression*> leftStack; + std::vector<Expression*> rightStack; + + auto noteNames = [&](Name left, Name right) { + if (left.is() != right.is()) return false; + if (left.is()) { + nameStack.push_back(left); + rightNames[left].push_back(right); + leftStack.push_back(&popNameMarker); + rightStack.push_back(&popNameMarker); + } + return true; + }; + auto checkNames = [&](Name left, Name right) { + auto iter = rightNames.find(left); + if (iter == rightNames.end()) return left == right; // non-internal name + return iter->second.back() == right; + }; + auto popName = [&]() { + auto left = nameStack.back(); + nameStack.pop_back(); + rightNames[left].pop_back(); + }; + + leftStack.push_back(left); + rightStack.push_back(right); + + while (leftStack.size() > 0 && rightStack.size() > 0) { + left = leftStack.back(); + leftStack.pop_back(); + right = rightStack.back(); + rightStack.pop_back(); + if (!left != !right) return false; + if (!left) continue; + if (left == &popNameMarker) { + popName(); + continue; + } + if (left->_id != right->_id) return false; + #define PUSH(clazz, what) \ + leftStack.push_back(left->cast<clazz>()->what); \ + rightStack.push_back(right->cast<clazz>()->what); + #define CHECK(clazz, what) \ + if (left->cast<clazz>()->what != right->cast<clazz>()->what) return false; + switch (left->_id) { + case Expression::Id::BlockId: { + if (!noteNames(left->cast<Block>()->name, right->cast<Block>()->name)) return false; + CHECK(Block, list.size()); + for (Index i = 0; i < left->cast<Block>()->list.size(); i++) { + PUSH(Block, list[i]); + } + break; + } + case Expression::Id::IfId: { + PUSH(If, condition); + PUSH(If, ifTrue); + PUSH(If, ifFalse); + break; + } + case Expression::Id::LoopId: { + if (!noteNames(left->cast<Loop>()->out, right->cast<Loop>()->out)) return false; + if (!noteNames(left->cast<Loop>()->in, right->cast<Loop>()->in)) return false; + PUSH(Loop, body); + break; + } + case Expression::Id::BreakId: { + if (!checkNames(left->cast<Break>()->name, right->cast<Break>()->name)) return false; + PUSH(Break, condition); + PUSH(Break, value); + break; + } + case Expression::Id::SwitchId: { + CHECK(Switch, targets.size()); + for (Index i = 0; i < left->cast<Switch>()->targets.size(); i++) { + if (!checkNames(left->cast<Switch>()->targets[i], right->cast<Switch>()->targets[i])) return false; + } + if (!checkNames(left->cast<Switch>()->default_, right->cast<Switch>()->default_)) return false; + PUSH(Switch, condition); + PUSH(Switch, value); + break; + } + case Expression::Id::CallId: { + CHECK(Call, target); + CHECK(Call, operands.size()); + for (Index i = 0; i < left->cast<Call>()->operands.size(); i++) { + PUSH(Call, operands[i]); + } + break; + } + case Expression::Id::CallImportId: { + CHECK(CallImport, target); + CHECK(CallImport, operands.size()); + for (Index i = 0; i < left->cast<CallImport>()->operands.size(); i++) { + PUSH(CallImport, operands[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + PUSH(CallIndirect, target); + CHECK(CallIndirect, fullType); + CHECK(CallIndirect, operands.size()); + for (Index i = 0; i < left->cast<CallIndirect>()->operands.size(); i++) { + PUSH(CallIndirect, operands[i]); + } + break; + } + case Expression::Id::GetLocalId: { + CHECK(GetLocal, index); + break; + } + case Expression::Id::SetLocalId: { + CHECK(SetLocal, index); + PUSH(SetLocal, value); + break; + } + case Expression::Id::LoadId: { + CHECK(Load, bytes); + CHECK(Load, signed_); + CHECK(Load, offset); + CHECK(Load, align); + PUSH(Load, ptr); + break; + } + case Expression::Id::StoreId: { + CHECK(Store, bytes); + CHECK(Store, offset); + CHECK(Store, align); + PUSH(Store, ptr); + PUSH(Store, value); + break; + } + case Expression::Id::ConstId: { + CHECK(Const, value); + break; + } + case Expression::Id::UnaryId: { + CHECK(Unary, op); + PUSH(Unary, value); + break; + } + case Expression::Id::BinaryId: { + CHECK(Binary, op); + PUSH(Binary, left); + PUSH(Binary, right); + break; + } + case Expression::Id::SelectId: { + PUSH(Select, ifTrue); + PUSH(Select, ifFalse); + PUSH(Select, condition); + break; + } + case Expression::Id::ReturnId: { + PUSH(Return, value); + break; + } + case Expression::Id::HostId: { + CHECK(Host, op); + CHECK(Host, nameOperand); + CHECK(Host, operands.size()); + for (Index i = 0; i < left->cast<Host>()->operands.size(); i++) { + PUSH(Host, operands[i]); + } + break; + } + case Expression::Id::NopId: { + break; + } + case Expression::Id::UnreachableId: { + break; + } + default: WASM_UNREACHABLE(); + } + #undef CHECK + #undef PUSH + } + if (leftStack.size() > 0 || rightStack.size() > 0) return false; + return true; + } + + // hash an expression, ignoring superficial details like specific internal names + static uint32_t hash(Expression* curr) { + uint32_t digest = 0; + + auto hash = [&digest](uint32_t hash) { + digest = rehash(digest, hash); + }; + auto hash64 = [&digest](uint64_t hash) { + digest = rehash(rehash(digest, hash >> 32), uint32_t(hash)); + }; + + std::vector<Name> nameStack; + Index internalCounter = 0; + std::map<Name, std::vector<Index>> internalNames; // for each internal name, a vector if unique ids + Nop popNameMarker; + std::vector<Expression*> stack; + + auto noteName = [&](Name curr) { + if (curr.is()) { + nameStack.push_back(curr); + internalNames[curr].push_back(internalCounter++); + stack.push_back(&popNameMarker); + } + return true; + }; + auto hashName = [&](Name curr) { + auto iter = internalNames.find(curr); + if (iter == internalNames.end()) hash64(uint64_t(curr.str)); + else hash(iter->second.back()); + }; + auto popName = [&]() { + auto curr = nameStack.back(); + nameStack.pop_back(); + internalNames[curr].pop_back(); + }; + + stack.push_back(curr); + + while (stack.size() > 0) { + curr = stack.back(); + stack.pop_back(); + if (!curr) continue; + if (curr == &popNameMarker) { + popName(); + continue; + } + hash(curr->_id); + #define PUSH(clazz, what) \ + stack.push_back(curr->cast<clazz>()->what); + #define HASH(clazz, what) \ + hash(curr->cast<clazz>()->what); + #define HASH64(clazz, what) \ + hash64(curr->cast<clazz>()->what); + #define HASH_NAME(clazz, what) \ + hash64(uint64_t(curr->cast<clazz>()->what.str)); + #define HASH_PTR(clazz, what) \ + hash64(uint64_t(curr->cast<clazz>()->what)); + switch (curr->_id) { + case Expression::Id::BlockId: { + noteName(curr->cast<Block>()->name); + HASH(Block, list.size()); + for (Index i = 0; i < curr->cast<Block>()->list.size(); i++) { + PUSH(Block, list[i]); + } + break; + } + case Expression::Id::IfId: { + PUSH(If, condition); + PUSH(If, ifTrue); + PUSH(If, ifFalse); + break; + } + case Expression::Id::LoopId: { + noteName(curr->cast<Loop>()->out); + noteName(curr->cast<Loop>()->in); + PUSH(Loop, body); + break; + } + case Expression::Id::BreakId: { + hashName(curr->cast<Break>()->name); + PUSH(Break, condition); + PUSH(Break, value); + break; + } + case Expression::Id::SwitchId: { + HASH(Switch, targets.size()); + for (Index i = 0; i < curr->cast<Switch>()->targets.size(); i++) { + hashName(curr->cast<Switch>()->targets[i]); + } + hashName(curr->cast<Switch>()->default_); + PUSH(Switch, condition); + PUSH(Switch, value); + break; + } + case Expression::Id::CallId: { + HASH_NAME(Call, target); + HASH(Call, operands.size()); + for (Index i = 0; i < curr->cast<Call>()->operands.size(); i++) { + PUSH(Call, operands[i]); + } + break; + } + case Expression::Id::CallImportId: { + HASH_NAME(CallImport, target); + HASH(CallImport, operands.size()); + for (Index i = 0; i < curr->cast<CallImport>()->operands.size(); i++) { + PUSH(CallImport, operands[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + PUSH(CallIndirect, target); + HASH_PTR(CallIndirect, fullType); + HASH(CallIndirect, operands.size()); + for (Index i = 0; i < curr->cast<CallIndirect>()->operands.size(); i++) { + PUSH(CallIndirect, operands[i]); + } + break; + } + case Expression::Id::GetLocalId: { + HASH(GetLocal, index); + break; + } + case Expression::Id::SetLocalId: { + HASH(SetLocal, index); + PUSH(SetLocal, value); + break; + } + case Expression::Id::LoadId: { + HASH(Load, bytes); + HASH(Load, signed_); + HASH(Load, offset); + HASH(Load, align); + PUSH(Load, ptr); + break; + } + case Expression::Id::StoreId: { + HASH(Store, bytes); + HASH(Store, offset); + HASH(Store, align); + PUSH(Store, ptr); + PUSH(Store, value); + break; + } + case Expression::Id::ConstId: { + HASH(Const, value.type); + HASH64(Const, value.getBits()); + break; + } + case Expression::Id::UnaryId: { + HASH(Unary, op); + PUSH(Unary, value); + break; + } + case Expression::Id::BinaryId: { + HASH(Binary, op); + PUSH(Binary, left); + PUSH(Binary, right); + break; + } + case Expression::Id::SelectId: { + PUSH(Select, ifTrue); + PUSH(Select, ifFalse); + PUSH(Select, condition); + break; + } + case Expression::Id::ReturnId: { + PUSH(Return, value); + break; + } + case Expression::Id::HostId: { + HASH(Host, op); + HASH_NAME(Host, nameOperand); + HASH(Host, operands.size()); + for (Index i = 0; i < curr->cast<Host>()->operands.size(); i++) { + PUSH(Host, operands[i]); + } + break; + } + case Expression::Id::NopId: { + break; + } + case Expression::Id::UnreachableId: { + break; + } + default: WASM_UNREACHABLE(); + } + #undef HASH + #undef PUSH + } + return digest; + } }; } // namespace wasm |