/* * Copyright 2016 WebAssembly Community Group participants * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef wasm_ast_utils_h #define wasm_ast_utils_h #include "support/hash.h" #include "wasm.h" #include "wasm-traversal.h" #include "wasm-builder.h" #include "pass.h" namespace wasm { struct BreakSeeker : public PostWalker> { Name target; // look for this one XXX looking by name may fall prey to duplicate names Index found; BreakSeeker(Name target) : target(target), found(false) {} void visitBreak(Break *curr) { if (curr->name == target) found++; } void visitSwitch(Switch *curr) { for (auto name : curr->targets) { if (name == target) found++; } if (curr->default_ == target) found++; } static bool has(Expression* tree, Name target) { BreakSeeker breakSeeker(target); breakSeeker.walk(tree); return breakSeeker.found > 0; } static Index count(Expression* tree, Name target) { BreakSeeker breakSeeker(target); breakSeeker.walk(tree); return breakSeeker.found; } }; // Finds all functions that are reachable via direct calls. struct DirectCallGraphAnalyzer : public PostWalker> { Module *module; std::vector queue; std::unordered_set reachable; DirectCallGraphAnalyzer(Module* module, const std::vector& root) : module(module) { for (auto* curr : root) { queue.push_back(curr); } while (queue.size()) { auto* curr = queue.back(); queue.pop_back(); if (reachable.count(curr) == 0) { reachable.insert(curr); walk(curr->body); } } } void visitCall(Call *curr) { auto* target = module->getFunction(curr->target); if (reachable.count(target) == 0) { queue.push_back(target); } } }; // Look for side effects, including control flow // TODO: optimize struct EffectAnalyzer : public PostWalker> { EffectAnalyzer() {} EffectAnalyzer(Expression *ast) { walk(ast); // if we are left with breaks, they are external if (breakNames.size() > 0) branches = true; } bool branches = false; // branches out of this expression bool calls = false; std::set localsRead; std::set localsWritten; std::set globalsRead; std::set globalsWritten; bool readsMemory = false; bool writesMemory = false; bool accessesLocal() { return localsRead.size() + localsWritten.size() > 0; } bool accessesGlobal() { return globalsRead.size() + globalsWritten.size() > 0; } bool accessesMemory() { return calls || readsMemory || writesMemory; } bool hasSideEffects() { return calls || localsWritten.size() > 0 || writesMemory || branches || globalsWritten.size() > 0; } bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory || accessesGlobal(); } // checks if these effects would invalidate another set (e.g., if we write, we invalidate someone that reads, they can't be moved past us) bool invalidates(EffectAnalyzer& other) { if (branches || other.branches || ((writesMemory || calls) && other.accessesMemory()) || (accessesMemory() && (other.writesMemory || other.calls))) { return true; } for (auto local : localsWritten) { if (other.localsWritten.count(local) || other.localsRead.count(local)) { return true; } } for (auto local : localsRead) { if (other.localsWritten.count(local)) return true; } if ((accessesGlobal() && other.calls) || (other.accessesGlobal() && calls)) { return true; } for (auto global : globalsWritten) { if (other.globalsWritten.count(global) || other.globalsRead.count(global)) { return true; } } for (auto global : globalsRead) { if (other.globalsWritten.count(global)) return true; } return false; } // the checks above happen after the node's children were processed, in the order of execution // we must also check for control flow that happens before the children, i.e., loops bool checkPre(Expression* curr) { if (curr->is()) { branches = true; return true; } return false; } bool checkPost(Expression* curr) { visit(curr); if (curr->is()) { branches = true; } return hasAnything(); } std::set breakNames; void visitBreak(Break *curr) { breakNames.insert(curr->name); } void visitSwitch(Switch *curr) { for (auto name : curr->targets) { breakNames.insert(name); } breakNames.insert(curr->default_); } void visitBlock(Block* curr) { if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks } void visitLoop(Loop* curr) { if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks } void visitCall(Call *curr) { calls = true; } void visitCallImport(CallImport *curr) { calls = true; } void visitCallIndirect(CallIndirect *curr) { calls = true; } void visitGetLocal(GetLocal *curr) { localsRead.insert(curr->index); } void visitSetLocal(SetLocal *curr) { localsWritten.insert(curr->index); } void visitGetGlobal(GetGlobal *curr) { globalsRead.insert(curr->name); } void visitSetGlobal(SetGlobal *curr) { globalsWritten.insert(curr->name); } void visitLoad(Load *curr) { readsMemory = true; } void visitStore(Store *curr) { writesMemory = true; } void visitReturn(Return *curr) { branches = true; } void visitHost(Host *curr) { calls = true; } void visitUnreachable(Unreachable *curr) { branches = true; } }; // Meausure the size of an AST struct Measurer : public PostWalker> { Index size = 0; void visitExpression(Expression* curr) { size++; } static Index measure(Expression* tree) { Measurer measurer; measurer.walk(tree); return measurer.size; } }; // Manipulate expressions struct ExpressionManipulator { // Re-use a node's memory. This helps avoid allocation when optimizing. template static OutputType* convert(InputType *input) { static_assert(sizeof(OutputType) <= sizeof(InputType), "Can only convert to a smaller size Expression node"); input->~InputType(); // arena-allocaed, so no destructor, but avoid UB. OutputType* output = (OutputType*)(input); new (output) OutputType; return output; } // Convenience method for nop, which is a common conversion template static void nop(InputType* target) { convert(target); } // Convert a node that allocates template static OutputType* convert(InputType *input, MixedArena& allocator) { assert(sizeof(OutputType) <= sizeof(InputType)); input->~InputType(); // arena-allocaed, so no destructor, but avoid UB. OutputType* output = (OutputType*)(input); new (output) OutputType(allocator); return output; } template static Expression* flexibleCopy(Expression* original, Module& wasm, T& custom) { struct Copier : public Visitor { Module& wasm; T& custom; Builder builder; Copier(Module& wasm, T& custom) : wasm(wasm), custom(custom), builder(wasm) {} Expression* copy(Expression* curr) { if (!curr) return nullptr; auto* ret = custom.copy(curr); if (ret) return ret; return Visitor::visit(curr); } Expression* visitBlock(Block *curr) { auto* ret = builder.makeBlock(); for (Index i = 0; i < curr->list.size(); i++) { ret->list.push_back(copy(curr->list[i])); } ret->name = curr->name; ret->finalize(curr->type); return ret; } Expression* visitIf(If *curr) { return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); } Expression* visitLoop(Loop *curr) { return builder.makeLoop(curr->name, copy(curr->body)); } Expression* visitBreak(Break *curr) { return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition)); } Expression* visitSwitch(Switch *curr) { return builder.makeSwitch(curr->targets, curr->default_, copy(curr->condition), copy(curr->value)); } Expression* visitCall(Call *curr) { auto* ret = builder.makeCall(curr->target, {}, curr->type); for (Index i = 0; i < curr->operands.size(); i++) { ret->operands.push_back(copy(curr->operands[i])); } return ret; } Expression* visitCallImport(CallImport *curr) { auto* ret = builder.makeCallImport(curr->target, {}, curr->type); for (Index i = 0; i < curr->operands.size(); i++) { ret->operands.push_back(copy(curr->operands[i])); } return ret; } Expression* visitCallIndirect(CallIndirect *curr) { auto* ret = builder.makeCallIndirect(curr->fullType, curr->target, {}, curr->type); for (Index i = 0; i < curr->operands.size(); i++) { ret->operands.push_back(copy(curr->operands[i])); } return ret; } Expression* visitGetLocal(GetLocal *curr) { return builder.makeGetLocal(curr->index, curr->type); } Expression* visitSetLocal(SetLocal *curr) { if (curr->isTee()) { return builder.makeTeeLocal(curr->index, copy(curr->value)); } else { return builder.makeSetLocal(curr->index, copy(curr->value)); } } Expression* visitGetGlobal(GetGlobal *curr) { return builder.makeGetGlobal(curr->name, curr->type); } Expression* visitSetGlobal(SetGlobal *curr) { return builder.makeSetGlobal(curr->name, copy(curr->value)); } Expression* visitLoad(Load *curr) { return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type); } Expression* visitStore(Store *curr) { return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType); } Expression* visitConst(Const *curr) { return builder.makeConst(curr->value); } Expression* visitUnary(Unary *curr) { return builder.makeUnary(curr->op, copy(curr->value)); } Expression* visitBinary(Binary *curr) { return builder.makeBinary(curr->op, copy(curr->left), copy(curr->right)); } Expression* visitSelect(Select *curr) { return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); } Expression* visitDrop(Drop *curr) { return builder.makeDrop(copy(curr->value)); } Expression* visitReturn(Return *curr) { return builder.makeReturn(copy(curr->value)); } Expression* visitHost(Host *curr) { assert(curr->operands.size() == 0); return builder.makeHost(curr->op, curr->nameOperand, {}); } Expression* visitNop(Nop *curr) { return builder.makeNop(); } Expression* visitUnreachable(Unreachable *curr) { return builder.makeUnreachable(); } }; Copier copier(wasm, custom); return copier.copy(original); } static Expression* copy(Expression* original, Module& wasm) { struct Copier { Expression* copy(Expression* curr) { return nullptr; } } copier; return flexibleCopy(original, wasm, copier); } }; struct ExpressionAnalyzer { // Given a stack of expressions, checks if the topmost is used as a result. // For example, if the parent is a block and the node is before the last position, // it is not used. static bool isResultUsed(std::vector stack, Function* func) { for (int i = int(stack.size()) - 2; i >= 0; i--) { auto* curr = stack[i]; auto* above = stack[i + 1]; // only if and block can drop values (pre-drop expression was added) FIXME if (curr->is()) { auto* block = curr->cast(); for (size_t j = 0; j < block->list.size() - 1; j++) { if (block->list[j] == above) return false; } assert(block->list.back() == above); // continue down } else if (curr->is()) { auto* iff = curr->cast(); if (above == iff->condition) return true; if (!iff->ifFalse) return false; assert(above == iff->ifTrue || above == iff->ifFalse); // continue down } else { if (curr->is()) return false; return true; // all other node types use the result } } // The value might be used, so it depends on if the function returns return func->result != none; } // Checks if a break is a simple - no condition, no value, just a plain branching static bool isSimple(Break* curr) { return !curr->condition && !curr->value; } // Checks if an expression ends with a simple break, // and returns a pointer to it if so. // (It might also have other internal branches.) static Expression* getEndingSimpleBreak(Expression* curr) { if (auto* br = curr->dynCast()) { if (isSimple(br)) return br; return nullptr; } if (auto* block = curr->dynCast()) { if (block->list.size() > 0) return getEndingSimpleBreak(block->list.back()); } return nullptr; } template static bool flexibleEqual(Expression* left, Expression* right, T& comparer) { std::vector nameStack; std::map> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively Nop popNameMarker; std::vector leftStack; std::vector rightStack; auto noteNames = [&](Name left, Name right) { if (left.is() != right.is()) return false; if (left.is()) { nameStack.push_back(left); rightNames[left].push_back(right); leftStack.push_back(&popNameMarker); rightStack.push_back(&popNameMarker); } return true; }; auto checkNames = [&](Name left, Name right) { auto iter = rightNames.find(left); if (iter == rightNames.end()) return left == right; // non-internal name return iter->second.back() == right; }; auto popName = [&]() { auto left = nameStack.back(); nameStack.pop_back(); rightNames[left].pop_back(); }; leftStack.push_back(left); rightStack.push_back(right); while (leftStack.size() > 0 && rightStack.size() > 0) { left = leftStack.back(); leftStack.pop_back(); right = rightStack.back(); rightStack.pop_back(); if (!left != !right) return false; if (!left) continue; if (left == &popNameMarker) { popName(); continue; } if (comparer.compare(left, right)) continue; // comparison hook, before all the rest // continue with normal structural comparison if (left->_id != right->_id) return false; #define PUSH(clazz, what) \ leftStack.push_back(left->cast()->what); \ rightStack.push_back(right->cast()->what); #define CHECK(clazz, what) \ if (left->cast()->what != right->cast()->what) return false; switch (left->_id) { case Expression::Id::BlockId: { if (!noteNames(left->cast()->name, right->cast()->name)) return false; CHECK(Block, list.size()); for (Index i = 0; i < left->cast()->list.size(); i++) { PUSH(Block, list[i]); } break; } case Expression::Id::IfId: { PUSH(If, condition); PUSH(If, ifTrue); PUSH(If, ifFalse); break; } case Expression::Id::LoopId: { if (!noteNames(left->cast()->name, right->cast()->name)) return false; PUSH(Loop, body); break; } case Expression::Id::BreakId: { if (!checkNames(left->cast()->name, right->cast()->name)) return false; PUSH(Break, condition); PUSH(Break, value); break; } case Expression::Id::SwitchId: { CHECK(Switch, targets.size()); for (Index i = 0; i < left->cast()->targets.size(); i++) { if (!checkNames(left->cast()->targets[i], right->cast()->targets[i])) return false; } if (!checkNames(left->cast()->default_, right->cast()->default_)) return false; PUSH(Switch, condition); PUSH(Switch, value); break; } case Expression::Id::CallId: { CHECK(Call, target); CHECK(Call, operands.size()); for (Index i = 0; i < left->cast()->operands.size(); i++) { PUSH(Call, operands[i]); } break; } case Expression::Id::CallImportId: { CHECK(CallImport, target); CHECK(CallImport, operands.size()); for (Index i = 0; i < left->cast()->operands.size(); i++) { PUSH(CallImport, operands[i]); } break; } case Expression::Id::CallIndirectId: { PUSH(CallIndirect, target); CHECK(CallIndirect, fullType); CHECK(CallIndirect, operands.size()); for (Index i = 0; i < left->cast()->operands.size(); i++) { PUSH(CallIndirect, operands[i]); } break; } case Expression::Id::GetLocalId: { CHECK(GetLocal, index); break; } case Expression::Id::SetLocalId: { CHECK(SetLocal, index); CHECK(SetLocal, type); // for tee/set PUSH(SetLocal, value); break; } case Expression::Id::GetGlobalId: { CHECK(GetGlobal, name); break; } case Expression::Id::SetGlobalId: { CHECK(SetGlobal, name); PUSH(SetGlobal, value); break; } case Expression::Id::LoadId: { CHECK(Load, bytes); CHECK(Load, signed_); CHECK(Load, offset); CHECK(Load, align); PUSH(Load, ptr); break; } case Expression::Id::StoreId: { CHECK(Store, bytes); CHECK(Store, offset); CHECK(Store, align); CHECK(Store, valueType); PUSH(Store, ptr); PUSH(Store, value); break; } case Expression::Id::ConstId: { CHECK(Const, value); break; } case Expression::Id::UnaryId: { CHECK(Unary, op); PUSH(Unary, value); break; } case Expression::Id::BinaryId: { CHECK(Binary, op); PUSH(Binary, left); PUSH(Binary, right); break; } case Expression::Id::SelectId: { PUSH(Select, ifTrue); PUSH(Select, ifFalse); PUSH(Select, condition); break; } case Expression::Id::DropId: { PUSH(Drop, value); break; } case Expression::Id::ReturnId: { PUSH(Return, value); break; } case Expression::Id::HostId: { CHECK(Host, op); CHECK(Host, nameOperand); CHECK(Host, operands.size()); for (Index i = 0; i < left->cast()->operands.size(); i++) { PUSH(Host, operands[i]); } break; } case Expression::Id::NopId: { break; } case Expression::Id::UnreachableId: { break; } default: WASM_UNREACHABLE(); } #undef CHECK #undef PUSH } if (leftStack.size() > 0 || rightStack.size() > 0) return false; return true; } static bool equal(Expression* left, Expression* right) { struct Comparer { bool compare(Expression* left, Expression* right) { return false; } } comparer; return flexibleEqual(left, right, comparer); } // hash an expression, ignoring superficial details like specific internal names static uint32_t hash(Expression* curr) { uint32_t digest = 0; auto hash = [&digest](uint32_t hash) { digest = rehash(digest, hash); }; auto hash64 = [&digest](uint64_t hash) { digest = rehash(rehash(digest, hash >> 32), uint32_t(hash)); }; std::vector nameStack; Index internalCounter = 0; std::map> internalNames; // for each internal name, a vector if unique ids Nop popNameMarker; std::vector stack; auto noteName = [&](Name curr) { if (curr.is()) { nameStack.push_back(curr); internalNames[curr].push_back(internalCounter++); stack.push_back(&popNameMarker); } return true; }; auto hashName = [&](Name curr) { auto iter = internalNames.find(curr); if (iter == internalNames.end()) hash64(uint64_t(curr.str)); else hash(iter->second.back()); }; auto popName = [&]() { auto curr = nameStack.back(); nameStack.pop_back(); internalNames[curr].pop_back(); }; stack.push_back(curr); while (stack.size() > 0) { curr = stack.back(); stack.pop_back(); if (!curr) continue; if (curr == &popNameMarker) { popName(); continue; } hash(curr->_id); #define PUSH(clazz, what) \ stack.push_back(curr->cast()->what); #define HASH(clazz, what) \ hash(curr->cast()->what); #define HASH64(clazz, what) \ hash64(curr->cast()->what); #define HASH_NAME(clazz, what) \ hash64(uint64_t(curr->cast()->what.str)); #define HASH_PTR(clazz, what) \ hash64(uint64_t(curr->cast()->what)); switch (curr->_id) { case Expression::Id::BlockId: { noteName(curr->cast()->name); HASH(Block, list.size()); for (Index i = 0; i < curr->cast()->list.size(); i++) { PUSH(Block, list[i]); } break; } case Expression::Id::IfId: { PUSH(If, condition); PUSH(If, ifTrue); PUSH(If, ifFalse); break; } case Expression::Id::LoopId: { noteName(curr->cast()->name); PUSH(Loop, body); break; } case Expression::Id::BreakId: { hashName(curr->cast()->name); PUSH(Break, condition); PUSH(Break, value); break; } case Expression::Id::SwitchId: { HASH(Switch, targets.size()); for (Index i = 0; i < curr->cast()->targets.size(); i++) { hashName(curr->cast()->targets[i]); } hashName(curr->cast()->default_); PUSH(Switch, condition); PUSH(Switch, value); break; } case Expression::Id::CallId: { HASH_NAME(Call, target); HASH(Call, operands.size()); for (Index i = 0; i < curr->cast()->operands.size(); i++) { PUSH(Call, operands[i]); } break; } case Expression::Id::CallImportId: { HASH_NAME(CallImport, target); HASH(CallImport, operands.size()); for (Index i = 0; i < curr->cast()->operands.size(); i++) { PUSH(CallImport, operands[i]); } break; } case Expression::Id::CallIndirectId: { PUSH(CallIndirect, target); HASH_NAME(CallIndirect, fullType); HASH(CallIndirect, operands.size()); for (Index i = 0; i < curr->cast()->operands.size(); i++) { PUSH(CallIndirect, operands[i]); } break; } case Expression::Id::GetLocalId: { HASH(GetLocal, index); break; } case Expression::Id::SetLocalId: { HASH(SetLocal, index); PUSH(SetLocal, value); break; } case Expression::Id::GetGlobalId: { HASH_NAME(GetGlobal, name); break; } case Expression::Id::SetGlobalId: { HASH_NAME(SetGlobal, name); PUSH(SetGlobal, value); break; } case Expression::Id::LoadId: { HASH(Load, bytes); HASH(Load, signed_); HASH(Load, offset); HASH(Load, align); PUSH(Load, ptr); break; } case Expression::Id::StoreId: { HASH(Store, bytes); HASH(Store, offset); HASH(Store, align); HASH(Store, valueType); PUSH(Store, ptr); PUSH(Store, value); break; } case Expression::Id::ConstId: { HASH(Const, value.type); HASH64(Const, value.getBits()); break; } case Expression::Id::UnaryId: { HASH(Unary, op); PUSH(Unary, value); break; } case Expression::Id::BinaryId: { HASH(Binary, op); PUSH(Binary, left); PUSH(Binary, right); break; } case Expression::Id::SelectId: { PUSH(Select, ifTrue); PUSH(Select, ifFalse); PUSH(Select, condition); break; } case Expression::Id::DropId: { PUSH(Drop, value); break; } case Expression::Id::ReturnId: { PUSH(Return, value); break; } case Expression::Id::HostId: { HASH(Host, op); HASH_NAME(Host, nameOperand); HASH(Host, operands.size()); for (Index i = 0; i < curr->cast()->operands.size(); i++) { PUSH(Host, operands[i]); } break; } case Expression::Id::NopId: { break; } case Expression::Id::UnreachableId: { break; } default: WASM_UNREACHABLE(); } #undef HASH #undef PUSH } return digest; } }; // Adds drop() operations where necessary. This lets you not worry about adding drop when // generating code. struct AutoDrop : public WalkerPass>> { bool isFunctionParallel() override { return true; } Pass* create() override { return new AutoDrop; } void visitBlock(Block* curr) { if (curr->list.size() == 0) return; for (Index i = 0; i < curr->list.size() - 1; i++) { auto* child = curr->list[i]; if (isConcreteWasmType(child->type)) { curr->list[i] = Builder(*getModule()).makeDrop(child); } } auto* last = curr->list.back(); expressionStack.push_back(last); if (isConcreteWasmType(last->type) && !ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) { curr->list.back() = Builder(*getModule()).makeDrop(last); } expressionStack.pop_back(); curr->finalize(); // we may have changed our type } void visitIf(If* curr) { // if without else does not return a value, so the body must be dropped if it is concrete if (!curr->ifFalse && isConcreteWasmType(curr->ifTrue->type)) { curr->ifTrue = Builder(*getModule()).makeDrop(curr->ifTrue); } } void visitFunction(Function* curr) { if (curr->result == none && isConcreteWasmType(curr->body->type)) { curr->body = Builder(*getModule()).makeDrop(curr->body); } } }; // Finalizes a node struct ReFinalize : public WalkerPass>> { void visitBlock(Block *curr) { curr->finalize(); } void visitIf(If *curr) { curr->finalize(); } void visitLoop(Loop *curr) { curr->finalize(); } void visitBreak(Break *curr) { curr->finalize(); } void visitSwitch(Switch *curr) { curr->finalize(); } void visitCall(Call *curr) { curr->finalize(); } void visitCallImport(CallImport *curr) { curr->finalize(); } void visitCallIndirect(CallIndirect *curr) { curr->finalize(); } void visitGetLocal(GetLocal *curr) { curr->finalize(); } void visitSetLocal(SetLocal *curr) { curr->finalize(); } void visitGetGlobal(GetGlobal *curr) { curr->finalize(); } void visitSetGlobal(SetGlobal *curr) { curr->finalize(); } void visitLoad(Load *curr) { curr->finalize(); } void visitStore(Store *curr) { curr->finalize(); } void visitConst(Const *curr) { curr->finalize(); } void visitUnary(Unary *curr) { curr->finalize(); } void visitBinary(Binary *curr) { curr->finalize(); } void visitSelect(Select *curr) { curr->finalize(); } void visitDrop(Drop *curr) { curr->finalize(); } void visitReturn(Return *curr) { curr->finalize(); } void visitHost(Host *curr) { curr->finalize(); } void visitNop(Nop *curr) { curr->finalize(); } void visitUnreachable(Unreachable *curr) { curr->finalize(); } }; } // namespace wasm #endif // wasm_ast_utils_h