diff options
Diffstat (limited to 'src/ir/ExpressionAnalyzer.cpp')
-rw-r--r-- | src/ir/ExpressionAnalyzer.cpp | 470 |
1 files changed, 142 insertions, 328 deletions
diff --git a/src/ir/ExpressionAnalyzer.cpp b/src/ir/ExpressionAnalyzer.cpp index 043fb259b..88dce6767 100644 --- a/src/ir/ExpressionAnalyzer.cpp +++ b/src/ir/ExpressionAnalyzer.cpp @@ -96,192 +96,6 @@ bool ExpressionAnalyzer::isResultDropped(ExpressionStack& stack) { return false; } -// -// Allows visiting the immediate fields of the expression. This is -// useful for comparisons and hashing. -// -// The passed-in visitor object must implement: -// * visitScopeName - a Name that represents a block or loop scope -// * visitNonScopeName - a non-scope name -// * visitInt - anything that has a short enumeration, including -// opcodes, # of bytes in a load, bools, etc. - must be -// guaranteed to fit in an int32 or less. -// * visitLiteral - a Literal -// * visitType - a Type -// * visitIndex - an Index -// * visitAddress - an Address -// - -namespace { - -template<typename T> void visitImmediates(Expression* curr, T& visitor) { - struct ImmediateVisitor : public OverriddenVisitor<ImmediateVisitor> { - T& visitor; - - ImmediateVisitor(Expression* curr, T& visitor) : visitor(visitor) { - this->visit(curr); - } - - void visitBlock(Block* curr) { visitor.visitScopeName(curr->name); } - void visitIf(If* curr) {} - void visitLoop(Loop* curr) { visitor.visitScopeName(curr->name); } - void visitBreak(Break* curr) { visitor.visitScopeName(curr->name); } - void visitSwitch(Switch* curr) { - for (auto target : curr->targets) { - visitor.visitScopeName(target); - } - visitor.visitScopeName(curr->default_); - } - void visitCall(Call* curr) { - visitor.visitNonScopeName(curr->target); - visitor.visitInt(curr->isReturn); - } - void visitCallIndirect(CallIndirect* curr) { - visitor.visitInt(curr->sig.params.getID()); - visitor.visitInt(curr->sig.results.getID()); - visitor.visitInt(curr->isReturn); - } - void visitLocalGet(LocalGet* curr) { visitor.visitIndex(curr->index); } - void visitLocalSet(LocalSet* curr) { visitor.visitIndex(curr->index); } - void visitGlobalGet(GlobalGet* curr) { - visitor.visitNonScopeName(curr->name); - } - void visitGlobalSet(GlobalSet* curr) { - visitor.visitNonScopeName(curr->name); - } - void visitLoad(Load* curr) { - visitor.visitInt(curr->bytes); - if (curr->type != Type::unreachable && - curr->bytes < curr->type.getByteSize()) { - visitor.visitInt(curr->signed_); - } - visitor.visitAddress(curr->offset); - visitor.visitAddress(curr->align); - visitor.visitInt(curr->isAtomic); - } - void visitStore(Store* curr) { - visitor.visitInt(curr->bytes); - visitor.visitAddress(curr->offset); - visitor.visitAddress(curr->align); - visitor.visitInt(curr->isAtomic); - visitor.visitInt(curr->valueType.getID()); - } - void visitAtomicRMW(AtomicRMW* curr) { - visitor.visitInt(curr->op); - visitor.visitInt(curr->bytes); - visitor.visitAddress(curr->offset); - } - void visitAtomicCmpxchg(AtomicCmpxchg* curr) { - visitor.visitInt(curr->bytes); - visitor.visitAddress(curr->offset); - } - void visitAtomicWait(AtomicWait* curr) { - visitor.visitAddress(curr->offset); - visitor.visitType(curr->expectedType); - } - void visitAtomicNotify(AtomicNotify* curr) { - visitor.visitAddress(curr->offset); - } - void visitAtomicFence(AtomicFence* curr) { visitor.visitInt(curr->order); } - void visitSIMDExtract(SIMDExtract* curr) { - visitor.visitInt(curr->op); - visitor.visitInt(curr->index); - } - void visitSIMDReplace(SIMDReplace* curr) { - visitor.visitInt(curr->op); - visitor.visitInt(curr->index); - } - void visitSIMDShuffle(SIMDShuffle* curr) { - for (auto x : curr->mask) { - visitor.visitInt(x); - } - } - void visitSIMDTernary(SIMDTernary* curr) { visitor.visitInt(curr->op); } - void visitSIMDShift(SIMDShift* curr) { visitor.visitInt(curr->op); } - void visitSIMDLoad(SIMDLoad* curr) { - visitor.visitInt(curr->op); - visitor.visitAddress(curr->offset); - visitor.visitAddress(curr->align); - } - void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) { - visitor.visitInt(curr->op); - visitor.visitAddress(curr->offset); - visitor.visitAddress(curr->align); - visitor.visitInt(curr->index); - } - void visitMemoryInit(MemoryInit* curr) { - visitor.visitIndex(curr->segment); - } - void visitDataDrop(DataDrop* curr) { visitor.visitIndex(curr->segment); } - void visitMemoryCopy(MemoryCopy* curr) {} - void visitMemoryFill(MemoryFill* curr) {} - void visitConst(Const* curr) { visitor.visitLiteral(curr->value); } - void visitUnary(Unary* curr) { visitor.visitInt(curr->op); } - void visitBinary(Binary* curr) { visitor.visitInt(curr->op); } - void visitSelect(Select* curr) {} - void visitDrop(Drop* curr) {} - void visitReturn(Return* curr) {} - void visitMemorySize(MemorySize* curr) {} - void visitMemoryGrow(MemoryGrow* curr) {} - void visitRefNull(RefNull* curr) { visitor.visitType(curr->type); } - void visitRefIsNull(RefIsNull* curr) {} - void visitRefFunc(RefFunc* curr) { visitor.visitNonScopeName(curr->func); } - void visitRefEq(RefEq* curr) {} - void visitTry(Try* curr) {} - void visitThrow(Throw* curr) { visitor.visitNonScopeName(curr->event); } - void visitRethrow(Rethrow* curr) {} - void visitBrOnExn(BrOnExn* curr) { - visitor.visitScopeName(curr->name); - visitor.visitNonScopeName(curr->event); - } - void visitNop(Nop* curr) {} - void visitUnreachable(Unreachable* curr) {} - void visitPop(Pop* curr) {} - void visitTupleMake(TupleMake* curr) {} - void visitTupleExtract(TupleExtract* curr) { - visitor.visitIndex(curr->index); - } - void visitI31New(I31New* curr) {} - void visitI31Get(I31Get* curr) { visitor.visitInt(curr->signed_); } - void visitRefTest(RefTest* curr) { - WASM_UNREACHABLE("TODO (gc): ref.test"); - } - void visitRefCast(RefCast* curr) { - WASM_UNREACHABLE("TODO (gc): ref.cast"); - } - void visitBrOnCast(BrOnCast* curr) { - WASM_UNREACHABLE("TODO (gc): br_on_cast"); - } - void visitRttCanon(RttCanon* curr) { - WASM_UNREACHABLE("TODO (gc): rtt.canon"); - } - void visitRttSub(RttSub* curr) { WASM_UNREACHABLE("TODO (gc): rtt.sub"); } - void visitStructNew(StructNew* curr) { - WASM_UNREACHABLE("TODO (gc): struct.new"); - } - void visitStructGet(StructGet* curr) { - WASM_UNREACHABLE("TODO (gc): struct.get"); - } - void visitStructSet(StructSet* curr) { - WASM_UNREACHABLE("TODO (gc): struct.set"); - } - void visitArrayNew(ArrayNew* curr) { - WASM_UNREACHABLE("TODO (gc): array.new"); - } - void visitArrayGet(ArrayGet* curr) { - WASM_UNREACHABLE("TODO (gc): array.get"); - } - void visitArraySet(ArraySet* curr) { - WASM_UNREACHABLE("TODO (gc): array.set"); - } - void visitArrayLen(ArrayLen* curr) { - WASM_UNREACHABLE("TODO (gc): array.len"); - } - } singleton(curr, visitor); -} - -} // namespace - bool ExpressionAnalyzer::flexibleEqual(Expression* left, Expression* right, ExprComparer comparer) { @@ -291,80 +105,6 @@ bool ExpressionAnalyzer::flexibleEqual(Expression* left, std::vector<Expression*> leftStack; std::vector<Expression*> rightStack; - struct Immediates { - Comparer& parent; - - Immediates(Comparer& parent) : parent(parent) {} - - SmallVector<Name, 1> scopeNames; - SmallVector<Name, 1> nonScopeNames; - SmallVector<int32_t, 3> ints; - SmallVector<Literal, 1> literals; - SmallVector<Type, 1> types; - SmallVector<Index, 1> indexes; - SmallVector<Address, 2> addresses; - - void visitScopeName(Name curr) { scopeNames.push_back(curr); } - void visitNonScopeName(Name curr) { nonScopeNames.push_back(curr); } - void visitInt(int32_t curr) { ints.push_back(curr); } - void visitLiteral(Literal curr) { literals.push_back(curr); } - void visitType(Type curr) { types.push_back(curr); } - void visitIndex(Index curr) { indexes.push_back(curr); } - void visitAddress(Address curr) { addresses.push_back(curr); } - - // Comparison is by value, except for names, which must match. - bool operator==(const Immediates& other) { - if (scopeNames.size() != other.scopeNames.size()) { - return false; - } - for (Index i = 0; i < scopeNames.size(); i++) { - auto leftName = scopeNames[i]; - auto rightName = other.scopeNames[i]; - auto iter = parent.rightNames.find(leftName); - // If it's not found, that means it was defined out of the expression - // being compared, in which case we can just treat it literally - it - // must be exactly identical. - if (iter != parent.rightNames.end()) { - leftName = iter->second; - } - if (leftName != rightName) { - return false; - } - } - if (nonScopeNames != other.nonScopeNames) { - return false; - } - if (ints != other.ints) { - return false; - } - if (literals != other.literals) { - return false; - } - if (types != other.types) { - return false; - } - if (indexes != other.indexes) { - return false; - } - if (addresses != other.addresses) { - return false; - } - return true; - } - - bool operator!=(const Immediates& other) { return !(*this == other); } - - void clear() { - scopeNames.clear(); - nonScopeNames.clear(); - ints.clear(); - literals.clear(); - types.clear(); - indexes.clear(); - addresses.clear(); - } - }; - bool noteNames(Name left, Name right) { if (left.is() != right.is()) { return false; @@ -377,8 +117,6 @@ bool ExpressionAnalyzer::flexibleEqual(Expression* left, } bool compare(Expression* left, Expression* right, ExprComparer comparer) { - Immediates leftImmediates(*this), rightImmediates(*this); - // The empty name is the same on both sides. rightNames[Name()] = Name(); @@ -396,45 +134,16 @@ bool ExpressionAnalyzer::flexibleEqual(Expression* left, if (!left) { continue; } + // There are actual expressions to compare here. Start with the custom + // comparer function that was provided. if (comparer(left, right)) { - continue; // comparison hook, before all the rest + continue; } - // continue with normal structural comparison - if (left->_id != right->_id) { + if (left->type != right->type) { return false; } - // Blocks and loops introduce scoping. - if (auto* block = left->dynCast<Block>()) { - if (!noteNames(block->name, right->cast<Block>()->name)) { - return false; - } - } else if (auto* loop = left->dynCast<Loop>()) { - if (!noteNames(loop->name, right->cast<Loop>()->name)) { - return false; - } - } else { - // For all other nodes, compare their immediate values - visitImmediates(left, leftImmediates); - visitImmediates(right, rightImmediates); - if (leftImmediates != rightImmediates) { - return false; - } - leftImmediates.clear(); - rightImmediates.clear(); - } - // Add child nodes. - Index counter = 0; - for (auto* child : ChildIterator(left)) { - leftStack.push_back(child); - counter++; - } - for (auto* child : ChildIterator(right)) { - rightStack.push_back(child); - counter--; - } - // The number of child nodes must match (e.g. return has an optional - // one). - if (counter != 0) { + // Do the actual comparison, updating the names and stacks accordingly. + if (!compareNodes(left, right)) { return false; } } @@ -443,6 +152,97 @@ bool ExpressionAnalyzer::flexibleEqual(Expression* left, } return true; } + + bool compareNodes(Expression* left, Expression* right) { + if (left->_id != right->_id) { + return false; + } + +#define DELEGATE_ID left->_id + +// Create cast versions of it for later operations. +#define DELEGATE_START(id) \ + auto* castLeft = left->cast<id>(); \ + WASM_UNUSED(castLeft); \ + auto* castRight = right->cast<id>(); \ + WASM_UNUSED(castRight); + +// Handle each type of field, comparing it appropriately. +#define DELEGATE_FIELD_CHILD(id, name) \ + leftStack.push_back(castLeft->name); \ + rightStack.push_back(castRight->name); + +#define DELEGATE_FIELD_CHILD_VECTOR(id, name) \ + if (castLeft->name.size() != castRight->name.size()) { \ + return false; \ + } \ + for (auto* child : castLeft->name) { \ + leftStack.push_back(child); \ + } \ + for (auto* child : castRight->name) { \ + rightStack.push_back(child); \ + } + +#define COMPARE_FIELD(name) \ + if (castLeft->name != castRight->name) { \ + return false; \ + } + +#define DELEGATE_FIELD_INT(id, name) COMPARE_FIELD(name) +#define DELEGATE_FIELD_LITERAL(id, name) COMPARE_FIELD(name) +#define DELEGATE_FIELD_NAME(id, name) COMPARE_FIELD(name) +#define DELEGATE_FIELD_SIGNATURE(id, name) COMPARE_FIELD(name) +#define DELEGATE_FIELD_TYPE(id, name) COMPARE_FIELD(name) +#define DELEGATE_FIELD_ADDRESS(id, name) COMPARE_FIELD(name) + +#define COMPARE_LIST(name) \ + if (castLeft->name.size() != castRight->name.size()) { \ + return false; \ + } \ + for (Index i = 0; i < castLeft->name.size(); i++) { \ + if (castLeft->name[i] != castRight->name[i]) { \ + return false; \ + } \ + } + +#define DELEGATE_FIELD_INT_ARRAY(id, name) COMPARE_LIST(name) + +#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, name) \ + if (castLeft->name.is() != castRight->name.is()) { \ + return false; \ + } \ + rightNames[castLeft->name] = castRight->name; + +#define DELEGATE_FIELD_SCOPE_NAME_USE(id, name) \ + if (!compareNames(castLeft->name, castRight->name)) { \ + return false; \ + } + +#define DELEGATE_FIELD_SCOPE_NAME_USE_VECTOR(id, name) \ + if (castLeft->name.size() != castRight->name.size()) { \ + return false; \ + } \ + for (Index i = 0; i < castLeft->name.size(); i++) { \ + if (!compareNames(castLeft->name[i], castRight->name[i])) { \ + return false; \ + } \ + } + +#include "wasm-delegations-fields.h" + + return true; + } + + bool compareNames(Name left, Name right) { + auto iter = rightNames.find(left); + // If it's not found, that means it was defined out of the expression + // being compared, in which case we can just treat it literally - it + // must be exactly identical. + if (iter != rightNames.end()) { + left = iter->second; + } + return left == right; + } }; return Comparer().compare(left, right, comparer); @@ -458,12 +258,6 @@ size_t ExpressionAnalyzer::hash(Expression* curr) { std::map<Name, Index> internalNames; ExpressionStack stack; - void noteScopeName(Name curr) { - if (curr.is()) { - internalNames[curr] = internalCounter++; - } - } - Hasher(Expression* curr) { stack.push_back(curr); @@ -471,6 +265,9 @@ size_t ExpressionAnalyzer::hash(Expression* curr) { curr = stack.back(); stack.pop_back(); if (!curr) { + // This was an optional child that was not present. Hash a 0 to + // represent that. + rehash(digest, 0); continue; } rehash(digest, curr->_id); @@ -483,27 +280,51 @@ size_t ExpressionAnalyzer::hash(Expression* curr) { // call_imports type, etc. The simplest thing is just to hash the // type for all of them. rehash(digest, curr->type.getID()); - // Blocks and loops introduce scoping. - if (auto* block = curr->dynCast<Block>()) { - noteScopeName(block->name); - } else if (auto* loop = curr->dynCast<Loop>()) { - noteScopeName(loop->name); - } else { - // For all other nodes, compare their immediate values - visitImmediates(curr, *this); - } - // Hash children - Index counter = 0; - for (auto* child : ChildIterator(curr)) { - stack.push_back(child); - counter++; - } - // Sometimes children are optional, e.g. return, so we must hash - // their number as well. - rehash(digest, counter); + // Hash the contents of the expression. + hashExpression(curr); } } + void hashExpression(Expression* curr) { + +#define DELEGATE_ID curr->_id + +// Create cast versions of it for later operations. +#define DELEGATE_START(id) \ + auto* cast = curr->cast<id>(); \ + WASM_UNUSED(cast); + +// Handle each type of field, comparing it appropriately. +#define DELEGATE_GET_FIELD(id, name) cast->name + +#define DELEGATE_FIELD_CHILD(id, name) stack.push_back(cast->name); + +#define HASH_FIELD(name) rehash(digest, cast->name); + +#define DELEGATE_FIELD_INT(id, name) HASH_FIELD(name) +#define DELEGATE_FIELD_LITERAL(id, name) HASH_FIELD(name) +#define DELEGATE_FIELD_SIGNATURE(id, name) HASH_FIELD(name) + +#define DELEGATE_FIELD_NAME(id, name) visitNonScopeName(cast->name) +#define DELEGATE_FIELD_TYPE(id, name) visitType(cast->name); +#define DELEGATE_FIELD_ADDRESS(id, name) visitAddress(cast->name); + +// Note that we only note the scope name, but do not also visit it. That means +// that (block $x) and (block) get the same hash. In other words, we only change +// the hash based on uses of scope names, that is when there is a noticeable +// difference in break targets. +#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, name) noteScopeName(cast->name); + +#define DELEGATE_FIELD_SCOPE_NAME_USE(id, name) visitScopeName(cast->name); + +#include "wasm-delegations-fields.h" + } + + void noteScopeName(Name curr) { + if (curr.is()) { + internalNames[curr] = internalCounter++; + } + } void visitScopeName(Name curr) { // Names are relative, we give the same hash for // (block $x (br $x)) @@ -514,14 +335,7 @@ size_t ExpressionAnalyzer::hash(Expression* curr) { rehash(digest, internalNames[curr]); } void visitNonScopeName(Name curr) { rehash(digest, uint64_t(curr.str)); } - void visitInt(int32_t curr) { rehash(digest, curr); } - void visitLiteral(Literal curr) { rehash(digest, curr); } void visitType(Type curr) { rehash(digest, curr.getID()); } - void visitIndex(Index curr) { - static_assert(sizeof(Index) == sizeof(uint32_t), - "wasm64 will need changes here"); - rehash(digest, curr); - } void visitAddress(Address curr) { rehash(digest, curr.addr); } }; |