diff options
Diffstat (limited to 'src/ast_utils.h')
-rw-r--r-- | src/ast_utils.h | 616 |
1 files changed, 14 insertions, 602 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h index 159762d0b..2a151d8f7 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -17,7 +17,6 @@ #ifndef wasm_ast_utils_h #define wasm_ast_utils_h -#include "support/hash.h" #include "wasm.h" #include "wasm-traversal.h" #include "wasm-builder.h" @@ -243,202 +242,28 @@ struct ExpressionManipulator { return output; } - template<typename T> - static Expression* flexibleCopy(Expression* original, Module& wasm, T& custom) { - struct Copier : public Visitor<Copier, Expression*> { - Module& wasm; - T& custom; - - Builder builder; - - Copier(Module& wasm, T& custom) : wasm(wasm), custom(custom), builder(wasm) {} - - Expression* copy(Expression* curr) { - if (!curr) return nullptr; - auto* ret = custom.copy(curr); - if (ret) return ret; - return Visitor<Copier, Expression*>::visit(curr); - } - - Expression* visitBlock(Block *curr) { - auto* ret = builder.makeBlock(); - for (Index i = 0; i < curr->list.size(); i++) { - ret->list.push_back(copy(curr->list[i])); - } - ret->name = curr->name; - ret->finalize(curr->type); - return ret; - } - Expression* visitIf(If *curr) { - return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); - } - Expression* visitLoop(Loop *curr) { - return builder.makeLoop(curr->name, copy(curr->body)); - } - Expression* visitBreak(Break *curr) { - return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition)); - } - Expression* visitSwitch(Switch *curr) { - return builder.makeSwitch(curr->targets, curr->default_, copy(curr->condition), copy(curr->value)); - } - Expression* visitCall(Call *curr) { - auto* ret = builder.makeCall(curr->target, {}, curr->type); - for (Index i = 0; i < curr->operands.size(); i++) { - ret->operands.push_back(copy(curr->operands[i])); - } - return ret; - } - Expression* visitCallImport(CallImport *curr) { - auto* ret = builder.makeCallImport(curr->target, {}, curr->type); - for (Index i = 0; i < curr->operands.size(); i++) { - ret->operands.push_back(copy(curr->operands[i])); - } - return ret; - } - Expression* visitCallIndirect(CallIndirect *curr) { - auto* ret = builder.makeCallIndirect(curr->fullType, copy(curr->target), {}, curr->type); - for (Index i = 0; i < curr->operands.size(); i++) { - ret->operands.push_back(copy(curr->operands[i])); - } - return ret; - } - Expression* visitGetLocal(GetLocal *curr) { - return builder.makeGetLocal(curr->index, curr->type); - } - Expression* visitSetLocal(SetLocal *curr) { - if (curr->isTee()) { - return builder.makeTeeLocal(curr->index, copy(curr->value)); - } else { - return builder.makeSetLocal(curr->index, copy(curr->value)); - } - } - Expression* visitGetGlobal(GetGlobal *curr) { - return builder.makeGetGlobal(curr->name, curr->type); - } - Expression* visitSetGlobal(SetGlobal *curr) { - return builder.makeSetGlobal(curr->name, copy(curr->value)); - } - Expression* visitLoad(Load *curr) { - return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type); - } - Expression* visitStore(Store *curr) { - return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType); - } - Expression* visitConst(Const *curr) { - return builder.makeConst(curr->value); - } - Expression* visitUnary(Unary *curr) { - return builder.makeUnary(curr->op, copy(curr->value)); - } - Expression* visitBinary(Binary *curr) { - return builder.makeBinary(curr->op, copy(curr->left), copy(curr->right)); - } - Expression* visitSelect(Select *curr) { - return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); - } - Expression* visitDrop(Drop *curr) { - return builder.makeDrop(copy(curr->value)); - } - Expression* visitReturn(Return *curr) { - return builder.makeReturn(copy(curr->value)); - } - Expression* visitHost(Host *curr) { - assert(curr->operands.size() == 0); - return builder.makeHost(curr->op, curr->nameOperand, {}); - } - Expression* visitNop(Nop *curr) { - return builder.makeNop(); - } - Expression* visitUnreachable(Unreachable *curr) { - return builder.makeUnreachable(); - } - }; - - Copier copier(wasm, custom); - return copier.copy(original); - } + using CustomCopier = std::function<Expression*(Expression*)>; + static Expression* flexibleCopy(Expression* original, Module& wasm, CustomCopier custom); static Expression* copy(Expression* original, Module& wasm) { - struct Copier { - Expression* copy(Expression* curr) { + auto copy = [](Expression* curr) { return nullptr; - } - } copier; - return flexibleCopy(original, wasm, copier); + }; + return flexibleCopy(original, wasm, copy); } // Splice an item into the middle of a block's list - static void spliceIntoBlock(Block* block, Index index, Expression* add) { - auto& list = block->list; - if (index == list.size()) { - list.push_back(add); // simple append - } else { - // we need to make room - list.push_back(nullptr); - for (Index i = list.size() - 1; i > index; i--) { - list[i] = list[i - 1]; - } - list[index] = add; - } - } + static void spliceIntoBlock(Block* block, Index index, Expression* add); }; struct ExpressionAnalyzer { // Given a stack of expressions, checks if the topmost is used as a result. // For example, if the parent is a block and the node is before the last position, // it is not used. - static bool isResultUsed(std::vector<Expression*> stack, Function* func) { - for (int i = int(stack.size()) - 2; i >= 0; i--) { - auto* curr = stack[i]; - auto* above = stack[i + 1]; - // only if and block can drop values (pre-drop expression was added) FIXME - if (curr->is<Block>()) { - auto* block = curr->cast<Block>(); - for (size_t j = 0; j < block->list.size() - 1; j++) { - if (block->list[j] == above) return false; - } - assert(block->list.back() == above); - // continue down - } else if (curr->is<If>()) { - auto* iff = curr->cast<If>(); - if (above == iff->condition) return true; - if (!iff->ifFalse) return false; - assert(above == iff->ifTrue || above == iff->ifFalse); - // continue down - } else { - if (curr->is<Drop>()) return false; - return true; // all other node types use the result - } - } - // The value might be used, so it depends on if the function returns - return func->result != none; - } + static bool isResultUsed(std::vector<Expression*> stack, Function* func); // Checks if a value is dropped. - static bool isResultDropped(std::vector<Expression*> stack) { - for (int i = int(stack.size()) - 2; i >= 0; i--) { - auto* curr = stack[i]; - auto* above = stack[i + 1]; - if (curr->is<Block>()) { - auto* block = curr->cast<Block>(); - for (size_t j = 0; j < block->list.size() - 1; j++) { - if (block->list[j] == above) return false; - } - assert(block->list.back() == above); - // continue down - } else if (curr->is<If>()) { - auto* iff = curr->cast<If>(); - if (above == iff->condition) return false; - if (!iff->ifFalse) return false; - assert(above == iff->ifTrue || above == iff->ifFalse); - // continue down - } else { - if (curr->is<Drop>()) return true; // dropped - return false; // all other node types use the result - } - } - return false; - } + static bool isResultDropped(std::vector<Expression*> stack); // Checks if a break is a simple - no condition, no value, just a plain branching static bool isSimple(Break* curr) { @@ -457,431 +282,18 @@ struct ExpressionAnalyzer { return false; } - template<typename T> - static bool flexibleEqual(Expression* left, Expression* right, T& comparer) { - std::vector<Name> nameStack; - std::map<Name, std::vector<Name>> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively - Nop popNameMarker; - std::vector<Expression*> leftStack; - std::vector<Expression*> rightStack; - - auto noteNames = [&](Name left, Name right) { - if (left.is() != right.is()) return false; - if (left.is()) { - nameStack.push_back(left); - rightNames[left].push_back(right); - leftStack.push_back(&popNameMarker); - rightStack.push_back(&popNameMarker); - } - return true; - }; - auto checkNames = [&](Name left, Name right) { - auto iter = rightNames.find(left); - if (iter == rightNames.end()) return left == right; // non-internal name - return iter->second.back() == right; - }; - auto popName = [&]() { - auto left = nameStack.back(); - nameStack.pop_back(); - rightNames[left].pop_back(); - }; - - leftStack.push_back(left); - rightStack.push_back(right); - - while (leftStack.size() > 0 && rightStack.size() > 0) { - left = leftStack.back(); - leftStack.pop_back(); - right = rightStack.back(); - rightStack.pop_back(); - if (!left != !right) return false; - if (!left) continue; - if (left == &popNameMarker) { - popName(); - continue; - } - if (comparer.compare(left, right)) continue; // comparison hook, before all the rest - // continue with normal structural comparison - if (left->_id != right->_id) return false; - #define PUSH(clazz, what) \ - leftStack.push_back(left->cast<clazz>()->what); \ - rightStack.push_back(right->cast<clazz>()->what); - #define CHECK(clazz, what) \ - if (left->cast<clazz>()->what != right->cast<clazz>()->what) return false; - switch (left->_id) { - case Expression::Id::BlockId: { - if (!noteNames(left->cast<Block>()->name, right->cast<Block>()->name)) return false; - CHECK(Block, list.size()); - for (Index i = 0; i < left->cast<Block>()->list.size(); i++) { - PUSH(Block, list[i]); - } - break; - } - case Expression::Id::IfId: { - PUSH(If, condition); - PUSH(If, ifTrue); - PUSH(If, ifFalse); - break; - } - case Expression::Id::LoopId: { - if (!noteNames(left->cast<Loop>()->name, right->cast<Loop>()->name)) return false; - PUSH(Loop, body); - break; - } - case Expression::Id::BreakId: { - if (!checkNames(left->cast<Break>()->name, right->cast<Break>()->name)) return false; - PUSH(Break, condition); - PUSH(Break, value); - break; - } - case Expression::Id::SwitchId: { - CHECK(Switch, targets.size()); - for (Index i = 0; i < left->cast<Switch>()->targets.size(); i++) { - if (!checkNames(left->cast<Switch>()->targets[i], right->cast<Switch>()->targets[i])) return false; - } - if (!checkNames(left->cast<Switch>()->default_, right->cast<Switch>()->default_)) return false; - PUSH(Switch, condition); - PUSH(Switch, value); - break; - } - case Expression::Id::CallId: { - CHECK(Call, target); - CHECK(Call, operands.size()); - for (Index i = 0; i < left->cast<Call>()->operands.size(); i++) { - PUSH(Call, operands[i]); - } - break; - } - case Expression::Id::CallImportId: { - CHECK(CallImport, target); - CHECK(CallImport, operands.size()); - for (Index i = 0; i < left->cast<CallImport>()->operands.size(); i++) { - PUSH(CallImport, operands[i]); - } - break; - } - case Expression::Id::CallIndirectId: { - PUSH(CallIndirect, target); - CHECK(CallIndirect, fullType); - CHECK(CallIndirect, operands.size()); - for (Index i = 0; i < left->cast<CallIndirect>()->operands.size(); i++) { - PUSH(CallIndirect, operands[i]); - } - break; - } - case Expression::Id::GetLocalId: { - CHECK(GetLocal, index); - break; - } - case Expression::Id::SetLocalId: { - CHECK(SetLocal, index); - CHECK(SetLocal, type); // for tee/set - PUSH(SetLocal, value); - break; - } - case Expression::Id::GetGlobalId: { - CHECK(GetGlobal, name); - break; - } - case Expression::Id::SetGlobalId: { - CHECK(SetGlobal, name); - PUSH(SetGlobal, value); - break; - } - case Expression::Id::LoadId: { - CHECK(Load, bytes); - CHECK(Load, signed_); - CHECK(Load, offset); - CHECK(Load, align); - PUSH(Load, ptr); - break; - } - case Expression::Id::StoreId: { - CHECK(Store, bytes); - CHECK(Store, offset); - CHECK(Store, align); - CHECK(Store, valueType); - PUSH(Store, ptr); - PUSH(Store, value); - break; - } - case Expression::Id::ConstId: { - CHECK(Const, value); - break; - } - case Expression::Id::UnaryId: { - CHECK(Unary, op); - PUSH(Unary, value); - break; - } - case Expression::Id::BinaryId: { - CHECK(Binary, op); - PUSH(Binary, left); - PUSH(Binary, right); - break; - } - case Expression::Id::SelectId: { - PUSH(Select, ifTrue); - PUSH(Select, ifFalse); - PUSH(Select, condition); - break; - } - case Expression::Id::DropId: { - PUSH(Drop, value); - break; - } - case Expression::Id::ReturnId: { - PUSH(Return, value); - break; - } - case Expression::Id::HostId: { - CHECK(Host, op); - CHECK(Host, nameOperand); - CHECK(Host, operands.size()); - for (Index i = 0; i < left->cast<Host>()->operands.size(); i++) { - PUSH(Host, operands[i]); - } - break; - } - case Expression::Id::NopId: { - break; - } - case Expression::Id::UnreachableId: { - break; - } - default: WASM_UNREACHABLE(); - } - #undef CHECK - #undef PUSH - } - if (leftStack.size() > 0 || rightStack.size() > 0) return false; - return true; - } + using ExprComparer = std::function<bool(Expression*, Expression*)>; + static bool flexibleEqual(Expression* left, Expression* right, ExprComparer comparer); static bool equal(Expression* left, Expression* right) { - struct Comparer { - bool compare(Expression* left, Expression* right) { - return false; - } - } comparer; + auto comparer = [](Expression* left, Expression* right) { + return false; + }; return flexibleEqual(left, right, comparer); } // hash an expression, ignoring superficial details like specific internal names - static uint32_t hash(Expression* curr) { - uint32_t digest = 0; - - auto hash = [&digest](uint32_t hash) { - digest = rehash(digest, hash); - }; - auto hash64 = [&digest](uint64_t hash) { - digest = rehash(rehash(digest, hash >> 32), uint32_t(hash)); - }; - - std::vector<Name> nameStack; - Index internalCounter = 0; - std::map<Name, std::vector<Index>> internalNames; // for each internal name, a vector if unique ids - Nop popNameMarker; - std::vector<Expression*> stack; - - auto noteName = [&](Name curr) { - if (curr.is()) { - nameStack.push_back(curr); - internalNames[curr].push_back(internalCounter++); - stack.push_back(&popNameMarker); - } - return true; - }; - auto hashName = [&](Name curr) { - auto iter = internalNames.find(curr); - if (iter == internalNames.end()) hash64(uint64_t(curr.str)); - else hash(iter->second.back()); - }; - auto popName = [&]() { - auto curr = nameStack.back(); - nameStack.pop_back(); - internalNames[curr].pop_back(); - }; - - stack.push_back(curr); - - while (stack.size() > 0) { - curr = stack.back(); - stack.pop_back(); - if (!curr) continue; - if (curr == &popNameMarker) { - popName(); - continue; - } - hash(curr->_id); - // we often don't need to hash the type, as it is tied to other values - // we are hashing anyhow, but there are exceptions: for example, a - // get_local's type is determined by the function, so if we are - // hashing only expression fragments, then two from different - // functions may turn out the same even if the type differs. Likewise, - // if we hash between modules, then we need to take int account - // call_imports type, etc. The simplest thing is just to hash the - // type for all of them. - hash(curr->type); - - #define PUSH(clazz, what) \ - stack.push_back(curr->cast<clazz>()->what); - #define HASH(clazz, what) \ - hash(curr->cast<clazz>()->what); - #define HASH64(clazz, what) \ - hash64(curr->cast<clazz>()->what); - #define HASH_NAME(clazz, what) \ - hash64(uint64_t(curr->cast<clazz>()->what.str)); - #define HASH_PTR(clazz, what) \ - hash64(uint64_t(curr->cast<clazz>()->what)); - switch (curr->_id) { - case Expression::Id::BlockId: { - noteName(curr->cast<Block>()->name); - HASH(Block, list.size()); - for (Index i = 0; i < curr->cast<Block>()->list.size(); i++) { - PUSH(Block, list[i]); - } - break; - } - case Expression::Id::IfId: { - PUSH(If, condition); - PUSH(If, ifTrue); - PUSH(If, ifFalse); - break; - } - case Expression::Id::LoopId: { - noteName(curr->cast<Loop>()->name); - PUSH(Loop, body); - break; - } - case Expression::Id::BreakId: { - hashName(curr->cast<Break>()->name); - PUSH(Break, condition); - PUSH(Break, value); - break; - } - case Expression::Id::SwitchId: { - HASH(Switch, targets.size()); - for (Index i = 0; i < curr->cast<Switch>()->targets.size(); i++) { - hashName(curr->cast<Switch>()->targets[i]); - } - hashName(curr->cast<Switch>()->default_); - PUSH(Switch, condition); - PUSH(Switch, value); - break; - } - case Expression::Id::CallId: { - HASH_NAME(Call, target); - HASH(Call, operands.size()); - for (Index i = 0; i < curr->cast<Call>()->operands.size(); i++) { - PUSH(Call, operands[i]); - } - break; - } - case Expression::Id::CallImportId: { - HASH_NAME(CallImport, target); - HASH(CallImport, operands.size()); - for (Index i = 0; i < curr->cast<CallImport>()->operands.size(); i++) { - PUSH(CallImport, operands[i]); - } - break; - } - case Expression::Id::CallIndirectId: { - PUSH(CallIndirect, target); - HASH_NAME(CallIndirect, fullType); - HASH(CallIndirect, operands.size()); - for (Index i = 0; i < curr->cast<CallIndirect>()->operands.size(); i++) { - PUSH(CallIndirect, operands[i]); - } - break; - } - case Expression::Id::GetLocalId: { - HASH(GetLocal, index); - break; - } - case Expression::Id::SetLocalId: { - HASH(SetLocal, index); - PUSH(SetLocal, value); - break; - } - case Expression::Id::GetGlobalId: { - HASH_NAME(GetGlobal, name); - break; - } - case Expression::Id::SetGlobalId: { - HASH_NAME(SetGlobal, name); - PUSH(SetGlobal, value); - break; - } - case Expression::Id::LoadId: { - HASH(Load, bytes); - HASH(Load, signed_); - HASH(Load, offset); - HASH(Load, align); - PUSH(Load, ptr); - break; - } - case Expression::Id::StoreId: { - HASH(Store, bytes); - HASH(Store, offset); - HASH(Store, align); - HASH(Store, valueType); - PUSH(Store, ptr); - PUSH(Store, value); - break; - } - case Expression::Id::ConstId: { - HASH(Const, value.type); - HASH64(Const, value.getBits()); - break; - } - case Expression::Id::UnaryId: { - HASH(Unary, op); - PUSH(Unary, value); - break; - } - case Expression::Id::BinaryId: { - HASH(Binary, op); - PUSH(Binary, left); - PUSH(Binary, right); - break; - } - case Expression::Id::SelectId: { - PUSH(Select, ifTrue); - PUSH(Select, ifFalse); - PUSH(Select, condition); - break; - } - case Expression::Id::DropId: { - PUSH(Drop, value); - break; - } - case Expression::Id::ReturnId: { - PUSH(Return, value); - break; - } - case Expression::Id::HostId: { - HASH(Host, op); - HASH_NAME(Host, nameOperand); - HASH(Host, operands.size()); - for (Index i = 0; i < curr->cast<Host>()->operands.size(); i++) { - PUSH(Host, operands[i]); - } - break; - } - case Expression::Id::NopId: { - break; - } - case Expression::Id::UnreachableId: { - break; - } - default: WASM_UNREACHABLE(); - } - #undef HASH - #undef PUSH - } - return digest; - } + static uint32_t hash(Expression* curr); }; // Finalizes a node |