summaryrefslogtreecommitdiff
path: root/src/ast_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast_utils.h')
-rw-r--r--src/ast_utils.h168
1 files changed, 117 insertions, 51 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 5a7c40630..aa4120569 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -42,10 +42,18 @@ struct BreakSeeker : public PostWalker<BreakSeeker> {
}
void visitBreak(Break *curr) {
+ // ignore an unreachable break
+ if (curr->condition && curr->condition->type == unreachable) return;
+ if (curr->value && curr->value->type == unreachable) return;
+ // check the break
if (curr->name == target) noteFound(curr->value);
}
void visitSwitch(Switch *curr) {
+ // ignore an unreachable switch
+ if (curr->condition->type == unreachable) return;
+ if (curr->value && curr->value->type == unreachable) return;
+ // check the switch
for (auto name : curr->targets) {
if (name == target) noteFound(curr->value);
}
@@ -273,50 +281,6 @@ struct Measurer : public PostWalker<Measurer, UnifiedExpressionVisitor<Measurer>
}
};
-// Manipulate expressions
-
-struct ExpressionManipulator {
- // Re-use a node's memory. This helps avoid allocation when optimizing.
- template<typename InputType, typename OutputType>
- static OutputType* convert(InputType *input) {
- static_assert(sizeof(OutputType) <= sizeof(InputType),
- "Can only convert to a smaller size Expression node");
- input->~InputType(); // arena-allocaed, so no destructor, but avoid UB.
- OutputType* output = (OutputType*)(input);
- new (output) OutputType;
- return output;
- }
-
- // Convenience method for nop, which is a common conversion
- template<typename InputType>
- static void nop(InputType* target) {
- convert<InputType, Nop>(target);
- }
-
- // Convert a node that allocates
- template<typename InputType, typename OutputType>
- static OutputType* convert(InputType *input, MixedArena& allocator) {
- assert(sizeof(OutputType) <= sizeof(InputType));
- input->~InputType(); // arena-allocaed, so no destructor, but avoid UB.
- OutputType* output = (OutputType*)(input);
- new (output) OutputType(allocator);
- return output;
- }
-
- using CustomCopier = std::function<Expression*(Expression*)>;
- static Expression* flexibleCopy(Expression* original, Module& wasm, CustomCopier custom);
-
- static Expression* copy(Expression* original, Module& wasm) {
- auto copy = [](Expression* curr) {
- return nullptr;
- };
- return flexibleCopy(original, wasm, copy);
- }
-
- // Splice an item into the middle of a block's list
- static void spliceIntoBlock(Block* block, Index index, Expression* add);
-};
-
struct ExpressionAnalyzer {
// Given a stack of expressions, checks if the topmost is used as a result.
// For example, if the parent is a block and the node is before the last position,
@@ -357,11 +321,102 @@ struct ExpressionAnalyzer {
static uint32_t hash(Expression* curr);
};
-// Finalizes a node
-
+// Re-Finalizes all node types
+// This removes "unnecessary' block/if/loop types, i.e., that are added
+// specifically, as in
+// (block i32 (unreachable))
+// vs
+// (block (unreachable))
+// This converts to the latter form.
struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new ReFinalize; }
+
ReFinalize() { name = "refinalize"; }
+ // block finalization is O(bad) if we do each block by itself, so do it in bulk,
+ // tracking break value types so we just do a linear pass
+
+ std::map<Name, WasmType> breakValues;
+
+ void visitBlock(Block *curr) {
+ // do this quickly, without any validation
+ if (curr->name.is()) {
+ auto iter = breakValues.find(curr->name);
+ if (iter != breakValues.end()) {
+ // there is a break to here
+ curr->type = iter->second;
+ return;
+ }
+ }
+ // nothing branches here
+ if (curr->list.size() > 0) {
+ // if we have an unreachable child, we are unreachable
+ // (we don't need to recurse into children, they can't
+ // break to us)
+ for (auto* child : curr->list) {
+ if (child->type == unreachable) {
+ curr->type = unreachable;
+ return;
+ }
+ }
+ // children are reachable, so last element determines type
+ curr->type = curr->list.back()->type;
+ } else {
+ curr->type = none;
+ }
+ }
+ void visitIf(If *curr) { curr->finalize(); }
+ void visitLoop(Loop *curr) { curr->finalize(); }
+ void visitBreak(Break *curr) {
+ curr->finalize();
+ if (curr->value && curr->value->type == unreachable) {
+ return; // not an actual break
+ }
+ if (curr->condition && curr->condition->type == unreachable) {
+ return; // not an actual break
+ }
+ breakValues[curr->name] = getValueType(curr->value);
+ }
+ void visitSwitch(Switch *curr) {
+ curr->finalize();
+ if (curr->condition->type == unreachable || (curr->value && curr->value->type == unreachable)) {
+ return; // not an actual break
+ }
+ auto valueType = getValueType(curr->value);
+ for (auto target : curr->targets) {
+ breakValues[target] = valueType;
+ }
+ breakValues[curr->default_] = valueType;
+ }
+ void visitCall(Call *curr) { curr->finalize(); }
+ void visitCallImport(CallImport *curr) { curr->finalize(); }
+ void visitCallIndirect(CallIndirect *curr) { curr->finalize(); }
+ void visitGetLocal(GetLocal *curr) { curr->finalize(); }
+ void visitSetLocal(SetLocal *curr) { curr->finalize(); }
+ void visitGetGlobal(GetGlobal *curr) { curr->finalize(); }
+ void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
+ void visitLoad(Load *curr) { curr->finalize(); }
+ void visitStore(Store *curr) { curr->finalize(); }
+ void visitConst(Const *curr) { curr->finalize(); }
+ void visitUnary(Unary *curr) { curr->finalize(); }
+ void visitBinary(Binary *curr) { curr->finalize(); }
+ void visitSelect(Select *curr) { curr->finalize(); }
+ void visitDrop(Drop *curr) { curr->finalize(); }
+ void visitReturn(Return *curr) { curr->finalize(); }
+ void visitHost(Host *curr) { curr->finalize(); }
+ void visitNop(Nop *curr) { curr->finalize(); }
+ void visitUnreachable(Unreachable *curr) { curr->finalize(); }
+
+ WasmType getValueType(Expression* value) {
+ return value && value->type != unreachable ? value->type : none;
+ }
+};
+
+// Re-finalize a single node. This is slow, if you want to refinalize
+// an entire ast, use ReFinalize
+struct ReFinalizeNode : public Visitor<ReFinalizeNode> {
void visitBlock(Block *curr) { curr->finalize(); }
void visitIf(If *curr) { curr->finalize(); }
void visitLoop(Loop *curr) { curr->finalize(); }
@@ -385,10 +440,21 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
void visitHost(Host *curr) { curr->finalize(); }
void visitNop(Nop *curr) { curr->finalize(); }
void visitUnreachable(Unreachable *curr) { curr->finalize(); }
+
+ // given a stack of nested expressions, update them all from child to parent
+ static void updateStack(std::vector<Expression*>& expressionStack) {
+ for (int i = int(expressionStack.size()) - 1; i >= 0; i--) {
+ auto* curr = expressionStack[i];
+ ReFinalizeNode().visit(curr);
+ }
+ }
};
// Adds drop() operations where necessary. This lets you not worry about adding drop when
// generating code.
+// This also refinalizes before and after, as dropping can change types, and depends
+// on types being cleaned up - no unnecessary block/if/loop types (see refinalize)
+// TODO: optimize that, interleave them
struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> {
bool isFunctionParallel() override { return true; }
@@ -410,10 +476,7 @@ struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> {
}
void reFinalize() {
- for (int i = int(expressionStack.size()) - 1; i >= 0; i--) {
- auto* curr = expressionStack[i];
- ReFinalize().visit(curr);
- }
+ ReFinalizeNode::updateStack(expressionStack);
}
void visitBlock(Block* curr) {
@@ -442,10 +505,13 @@ struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> {
}
}
- void visitFunction(Function* curr) {
+ void doWalkFunction(Function* curr) {
+ ReFinalize().walkFunctionInModule(curr, getModule());
+ walk(curr->body);
if (curr->result == none && isConcreteWasmType(curr->body->type)) {
curr->body = Builder(*getModule()).makeDrop(curr->body);
}
+ ReFinalize().walkFunctionInModule(curr, getModule());
}
};