diff options
Diffstat (limited to 'src/ast_utils.h')
-rw-r--r-- | src/ast_utils.h | 168 |
1 files changed, 117 insertions, 51 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h index 5a7c40630..aa4120569 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -42,10 +42,18 @@ struct BreakSeeker : public PostWalker<BreakSeeker> { } void visitBreak(Break *curr) { + // ignore an unreachable break + if (curr->condition && curr->condition->type == unreachable) return; + if (curr->value && curr->value->type == unreachable) return; + // check the break if (curr->name == target) noteFound(curr->value); } void visitSwitch(Switch *curr) { + // ignore an unreachable switch + if (curr->condition->type == unreachable) return; + if (curr->value && curr->value->type == unreachable) return; + // check the switch for (auto name : curr->targets) { if (name == target) noteFound(curr->value); } @@ -273,50 +281,6 @@ struct Measurer : public PostWalker<Measurer, UnifiedExpressionVisitor<Measurer> } }; -// Manipulate expressions - -struct ExpressionManipulator { - // Re-use a node's memory. This helps avoid allocation when optimizing. - template<typename InputType, typename OutputType> - static OutputType* convert(InputType *input) { - static_assert(sizeof(OutputType) <= sizeof(InputType), - "Can only convert to a smaller size Expression node"); - input->~InputType(); // arena-allocaed, so no destructor, but avoid UB. - OutputType* output = (OutputType*)(input); - new (output) OutputType; - return output; - } - - // Convenience method for nop, which is a common conversion - template<typename InputType> - static void nop(InputType* target) { - convert<InputType, Nop>(target); - } - - // Convert a node that allocates - template<typename InputType, typename OutputType> - static OutputType* convert(InputType *input, MixedArena& allocator) { - assert(sizeof(OutputType) <= sizeof(InputType)); - input->~InputType(); // arena-allocaed, so no destructor, but avoid UB. - OutputType* output = (OutputType*)(input); - new (output) OutputType(allocator); - return output; - } - - using CustomCopier = std::function<Expression*(Expression*)>; - static Expression* flexibleCopy(Expression* original, Module& wasm, CustomCopier custom); - - static Expression* copy(Expression* original, Module& wasm) { - auto copy = [](Expression* curr) { - return nullptr; - }; - return flexibleCopy(original, wasm, copy); - } - - // Splice an item into the middle of a block's list - static void spliceIntoBlock(Block* block, Index index, Expression* add); -}; - struct ExpressionAnalyzer { // Given a stack of expressions, checks if the topmost is used as a result. // For example, if the parent is a block and the node is before the last position, @@ -357,11 +321,102 @@ struct ExpressionAnalyzer { static uint32_t hash(Expression* curr); }; -// Finalizes a node - +// Re-Finalizes all node types +// This removes "unnecessary' block/if/loop types, i.e., that are added +// specifically, as in +// (block i32 (unreachable)) +// vs +// (block (unreachable)) +// This converts to the latter form. struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new ReFinalize; } + ReFinalize() { name = "refinalize"; } + // block finalization is O(bad) if we do each block by itself, so do it in bulk, + // tracking break value types so we just do a linear pass + + std::map<Name, WasmType> breakValues; + + void visitBlock(Block *curr) { + // do this quickly, without any validation + if (curr->name.is()) { + auto iter = breakValues.find(curr->name); + if (iter != breakValues.end()) { + // there is a break to here + curr->type = iter->second; + return; + } + } + // nothing branches here + if (curr->list.size() > 0) { + // if we have an unreachable child, we are unreachable + // (we don't need to recurse into children, they can't + // break to us) + for (auto* child : curr->list) { + if (child->type == unreachable) { + curr->type = unreachable; + return; + } + } + // children are reachable, so last element determines type + curr->type = curr->list.back()->type; + } else { + curr->type = none; + } + } + void visitIf(If *curr) { curr->finalize(); } + void visitLoop(Loop *curr) { curr->finalize(); } + void visitBreak(Break *curr) { + curr->finalize(); + if (curr->value && curr->value->type == unreachable) { + return; // not an actual break + } + if (curr->condition && curr->condition->type == unreachable) { + return; // not an actual break + } + breakValues[curr->name] = getValueType(curr->value); + } + void visitSwitch(Switch *curr) { + curr->finalize(); + if (curr->condition->type == unreachable || (curr->value && curr->value->type == unreachable)) { + return; // not an actual break + } + auto valueType = getValueType(curr->value); + for (auto target : curr->targets) { + breakValues[target] = valueType; + } + breakValues[curr->default_] = valueType; + } + void visitCall(Call *curr) { curr->finalize(); } + void visitCallImport(CallImport *curr) { curr->finalize(); } + void visitCallIndirect(CallIndirect *curr) { curr->finalize(); } + void visitGetLocal(GetLocal *curr) { curr->finalize(); } + void visitSetLocal(SetLocal *curr) { curr->finalize(); } + void visitGetGlobal(GetGlobal *curr) { curr->finalize(); } + void visitSetGlobal(SetGlobal *curr) { curr->finalize(); } + void visitLoad(Load *curr) { curr->finalize(); } + void visitStore(Store *curr) { curr->finalize(); } + void visitConst(Const *curr) { curr->finalize(); } + void visitUnary(Unary *curr) { curr->finalize(); } + void visitBinary(Binary *curr) { curr->finalize(); } + void visitSelect(Select *curr) { curr->finalize(); } + void visitDrop(Drop *curr) { curr->finalize(); } + void visitReturn(Return *curr) { curr->finalize(); } + void visitHost(Host *curr) { curr->finalize(); } + void visitNop(Nop *curr) { curr->finalize(); } + void visitUnreachable(Unreachable *curr) { curr->finalize(); } + + WasmType getValueType(Expression* value) { + return value && value->type != unreachable ? value->type : none; + } +}; + +// Re-finalize a single node. This is slow, if you want to refinalize +// an entire ast, use ReFinalize +struct ReFinalizeNode : public Visitor<ReFinalizeNode> { void visitBlock(Block *curr) { curr->finalize(); } void visitIf(If *curr) { curr->finalize(); } void visitLoop(Loop *curr) { curr->finalize(); } @@ -385,10 +440,21 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> { void visitHost(Host *curr) { curr->finalize(); } void visitNop(Nop *curr) { curr->finalize(); } void visitUnreachable(Unreachable *curr) { curr->finalize(); } + + // given a stack of nested expressions, update them all from child to parent + static void updateStack(std::vector<Expression*>& expressionStack) { + for (int i = int(expressionStack.size()) - 1; i >= 0; i--) { + auto* curr = expressionStack[i]; + ReFinalizeNode().visit(curr); + } + } }; // Adds drop() operations where necessary. This lets you not worry about adding drop when // generating code. +// This also refinalizes before and after, as dropping can change types, and depends +// on types being cleaned up - no unnecessary block/if/loop types (see refinalize) +// TODO: optimize that, interleave them struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> { bool isFunctionParallel() override { return true; } @@ -410,10 +476,7 @@ struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> { } void reFinalize() { - for (int i = int(expressionStack.size()) - 1; i >= 0; i--) { - auto* curr = expressionStack[i]; - ReFinalize().visit(curr); - } + ReFinalizeNode::updateStack(expressionStack); } void visitBlock(Block* curr) { @@ -442,10 +505,13 @@ struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> { } } - void visitFunction(Function* curr) { + void doWalkFunction(Function* curr) { + ReFinalize().walkFunctionInModule(curr, getModule()); + walk(curr->body); if (curr->result == none && isConcreteWasmType(curr->body->type)) { curr->body = Builder(*getModule()).makeDrop(curr->body); } + ReFinalize().walkFunctionInModule(curr, getModule()); } }; |