diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ast/branch-utils.h | 14 | ||||
-rw-r--r-- | src/ast/type-updating.h | 43 | ||||
-rw-r--r-- | src/ast_utils.h | 60 | ||||
-rw-r--r-- | src/passes/FlattenControlFlow.cpp | 7 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 8 | ||||
-rw-r--r-- | src/passes/RemoveUnusedBrs.cpp | 4 | ||||
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 5 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 20 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 51 |
9 files changed, 123 insertions, 89 deletions
diff --git a/src/ast/branch-utils.h b/src/ast/branch-utils.h index 77ec70753..05ead8571 100644 --- a/src/ast/branch-utils.h +++ b/src/ast/branch-utils.h @@ -103,11 +103,11 @@ inline std::set<Name> getBranchTargets(Expression* ast) { // Finds if there are branches targeting a name. Note that since names are // unique in our IR, we just need to look for the name, and do not need // to analyze scoping. -// By default we ignore untaken branches. You can set named to -// note those as well, then any named branch is noted, even if untaken +// By default we consider untaken branches (so any named use). You can unset named to +// avoid that (and only note branches that are not obviously unreachable) struct BranchSeeker : public PostWalker<BranchSeeker> { Name target; - bool named = false; + bool named = true; Index found; WasmType valueType; @@ -144,16 +144,18 @@ struct BranchSeeker : public PostWalker<BranchSeeker> { if (curr->default_ == target) noteFound(curr->value); } - static bool has(Expression* tree, Name target) { + static bool hasTaken(Expression* tree, Name target) { if (!target.is()) return false; BranchSeeker seeker(target); + seeker.named = false; seeker.walk(tree); return seeker.found > 0; } - static Index count(Expression* tree, Name target) { + static Index countTaken(Expression* tree, Name target) { if (!target.is()) return 0; BranchSeeker seeker(target); + seeker.named = false; seeker.walk(tree); return seeker.found; } @@ -161,7 +163,6 @@ struct BranchSeeker : public PostWalker<BranchSeeker> { static bool hasNamed(Expression* tree, Name target) { if (!target.is()) return false; BranchSeeker seeker(target); - seeker.named = true; seeker.walk(tree); return seeker.found > 0; } @@ -169,7 +170,6 @@ struct BranchSeeker : public PostWalker<BranchSeeker> { static Index countNamed(Expression* tree, Name target) { if (!target.is()) return 0; BranchSeeker seeker(target); - seeker.named = true; seeker.walk(tree); return seeker.found; } diff --git a/src/ast/type-updating.h b/src/ast/type-updating.h index dc0ae0f36..3cf2f2afe 100644 --- a/src/ast/type-updating.h +++ b/src/ast/type-updating.h @@ -132,13 +132,9 @@ struct TypeUpdater : public ExpressionStackWalker<TypeUpdater, UnifiedExpression // adds (or removes) breaks depending on break/switch contents void discoverBreaks(Expression* curr, int change) { if (auto* br = curr->dynCast<Break>()) { - if (BranchUtils::isBranchTaken(br)) { - noteBreakChange(br->name, change, br->value); - } + noteBreakChange(br->name, change, br->value); } else if (auto* sw = curr->dynCast<Switch>()) { - if (BranchUtils::isBranchTaken(sw)) { - applySwitchChanges(sw, change); - } + applySwitchChanges(sw, change); } } @@ -200,34 +196,17 @@ struct TypeUpdater : public ExpressionStackWalker<TypeUpdater, UnifiedExpression auto* child = curr; curr = parents[child]; if (!curr) return; - // if a child of a break/switch is now unreachable, the - // break may no longer be taken. note that if we get here, - // this is an actually new unreachable child of the - // node, so if there is just 1 such child, it is us, and - // we are newly unreachable - if (auto* br = curr->dynCast<Break>()) { - int unreachableChildren = 0; - if (br->value && br->value->type == unreachable) unreachableChildren++; - if (br->condition && br->condition->type == unreachable) unreachableChildren++; - if (unreachableChildren == 1) { - // the break is no longer taken - noteBreakChange(br->name, -1, br->value); - } - } else if (auto* sw = curr->dynCast<Switch>()) { - int unreachableChildren = 0; - if (sw->value && sw->value->type == unreachable) unreachableChildren++; - if (sw->condition->type == unreachable) unreachableChildren++; - if (unreachableChildren == 1) { - applySwitchChanges(sw, -1); - } - } // get ready to apply unreachability to this node if (curr->type == unreachable) { return; // already unreachable, stop here } // most nodes become unreachable if a child is unreachable, - // but exceptions exists + // but exceptions exist if (auto* block = curr->dynCast<Block>()) { + // if the block has a fallthrough, it can keep its type + if (isConcreteWasmType(block->list.back()->type)) { + return; // did not turn + } // if the block has breaks, it can keep its type if (!block->name.is() || blockInfos[block->name].numBreaks == 0) { curr->type = unreachable; @@ -255,7 +234,7 @@ struct TypeUpdater : public ExpressionStackWalker<TypeUpdater, UnifiedExpression return; // nothing concrete to change to unreachable } if (curr->name.is() && blockInfos[curr->name].numBreaks > 0) { - return;// has a break, not unreachable + return; // has a break, not unreachable } // look for a fallthrough makeBlockUnreachableIfNoFallThrough(curr); @@ -265,9 +244,13 @@ struct TypeUpdater : public ExpressionStackWalker<TypeUpdater, UnifiedExpression if (curr->type == unreachable) { return; // no change possible } + if (!curr->list.empty() && + isConcreteWasmType(curr->list.back()->type)) { + return; // should keep type due to fallthrough, even if has an unreachable child + } for (auto* child : curr->list) { if (child->type == unreachable) { - // no fallthrough, this block is now unreachable + // no fallthrough, and an unreachable, => this block is now unreachable changeTypeTo(curr, unreachable); return; } diff --git a/src/ast_utils.h b/src/ast_utils.h index 1f781b87e..852cf89a3 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -62,7 +62,7 @@ struct ExpressionAnalyzer { if (auto* br = curr->dynCast<Break>()) { if (!br->condition) return true; } else if (auto* block = curr->dynCast<Block>()) { - if (block->list.size() > 0 && obviouslyDoesNotFlowOut(block->list.back()) && !BranchUtils::BranchSeeker::has(block, block->name)) return true; + if (block->list.size() > 0 && obviouslyDoesNotFlowOut(block->list.back()) && !BranchUtils::BranchSeeker::hasTaken(block, block->name)) return true; } return false; } @@ -101,49 +101,59 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> { std::map<Name, WasmType> breakValues; void visitBlock(Block *curr) { + if (curr->list.size() == 0) { + curr->type = none; + return; + } // do this quickly, without any validation + auto old = curr->type; + // last element determines type + curr->type = curr->list.back()->type; + // if concrete, it doesn't matter if we have an unreachable child, and we + // don't need to look at breaks + if (isConcreteWasmType(curr->type)) return; + // otherwise, we have no final fallthrough element to determine the type, + // could be determined by breaks if (curr->name.is()) { auto iter = breakValues.find(curr->name); if (iter != breakValues.end()) { // there is a break to here - curr->type = iter->second; + auto type = iter->second; + if (type == unreachable) { + // all we have are breaks with values of type unreachable, and no + // concrete fallthrough either. we must have had an existing type, then + curr->type = old; + assert(isConcreteWasmType(curr->type)); + } else { + curr->type = type; + } return; } } - // nothing branches here - if (curr->list.size() > 0) { - // if we have an unreachable child, we are unreachable - // (we don't need to recurse into children, they can't - // break to us) + if (curr->type == unreachable) return; + // type is none, but we might be unreachable + if (curr->type == none) { for (auto* child : curr->list) { if (child->type == unreachable) { curr->type = unreachable; - return; + break; } } - // children are reachable, so last element determines type - curr->type = curr->list.back()->type; - } else { - curr->type = none; } } void visitIf(If *curr) { curr->finalize(); } void visitLoop(Loop *curr) { curr->finalize(); } void visitBreak(Break *curr) { curr->finalize(); - if (BranchUtils::isBranchTaken(curr)) { - breakValues[curr->name] = getValueType(curr->value); - } + updateBreakValueType(curr->name, getValueType(curr->value)); } void visitSwitch(Switch *curr) { curr->finalize(); - if (BranchUtils::isBranchTaken(curr)) { - auto valueType = getValueType(curr->value); - for (auto target : curr->targets) { - breakValues[target] = valueType; - } - breakValues[curr->default_] = valueType; + auto valueType = getValueType(curr->value); + for (auto target : curr->targets) { + updateBreakValueType(target, valueType); } + updateBreakValueType(curr->default_, valueType); } void visitCall(Call *curr) { curr->finalize(); } void visitCallImport(CallImport *curr) { curr->finalize(); } @@ -167,7 +177,13 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> { void visitUnreachable(Unreachable *curr) { curr->finalize(); } WasmType getValueType(Expression* value) { - return value && value->type != unreachable ? value->type : none; + return value ? value->type : none; + } + + void updateBreakValueType(Name name, WasmType type) { + if (type != unreachable || breakValues.count(name) == 0) { + breakValues[name] = type; + } } }; diff --git a/src/passes/FlattenControlFlow.cpp b/src/passes/FlattenControlFlow.cpp index 3da5809c3..dce8e6345 100644 --- a/src/passes/FlattenControlFlow.cpp +++ b/src/passes/FlattenControlFlow.cpp @@ -61,7 +61,7 @@ #include <wasm.h> #include <pass.h> #include <wasm-builder.h> - +#include <ast_utils.h> namespace wasm { @@ -461,6 +461,11 @@ struct FlattenControlFlow : public WalkerPass<PostWalker<FlattenControlFlow>> { splitter.note(operand); } } + + void visitFunction(Function* curr) { + // removing breaks can alter types + ReFinalize().walkFunctionInModule(curr, getModule()); + } }; Pass *createFlattenControlFlowPass() { diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index f458a58b2..acbf39447 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -397,6 +397,14 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, // Optimizations that don't yet fit in the pattern DSL, but could be eventually maybe Expression* handOptimize(Expression* curr) { + // if this contains dead code, don't bother trying to optimize it, the type + // might change (if might not be unreachable if just one arm is, for example). + // this optimization pass focuses on actually executing code. the only + // exceptions are control flow changes + if (curr->type == unreachable && + !curr->is<Break>() && !curr->is<Switch>() && !curr->is<If>()) { + return nullptr; + } if (auto* binary = curr->dynCast<Binary>()) { if (Properties::isSymmetric(binary)) { // canonicalize a const to the second position diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp index e307ec414..ec7809f48 100644 --- a/src/passes/RemoveUnusedBrs.cpp +++ b/src/passes/RemoveUnusedBrs.cpp @@ -394,7 +394,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> { if (list.size() == 1 && curr->name.is()) { // if this block has just one child, a sub-block, then jumps to the former are jumps to us, really if (auto* child = list[0]->dynCast<Block>()) { - if (child->name.is() && child->name != curr->name) { + // the two blocks must have the same type for us to update the branch, as otherwise + // one block may be unreachable and the other concrete, so one might lack a value + if (child->name.is() && child->name != curr->name && child->type == curr->type) { auto& breaks = breaksToBlock[child]; for (auto* br : breaks) { newNames[br] = curr->name; diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index 2d4a02337..919784cdf 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -372,8 +372,9 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals>> // optimize set_locals from both sides of an if into a return value void optimizeIfReturn(If* iff, Expression** currp, Sinkables& ifTrue) { assert(iff->ifFalse); - // if this if already has a result, we can't do anything - if (isConcreteWasmType(iff->type)) return; + // if this if already has a result, or is unreachable code, we have + // nothing to do + if (iff->type != none) return; // We now have the sinkables from both sides of the if. Sinkables& ifFalse = sinkables; Index sharedIndex = -1; diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index e30661e8f..877232521 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -70,9 +70,7 @@ void WasmValidator::visitBlock(Block *curr) { if (curr->list.size() > 0) { auto backType = curr->list.back()->type; if (!isConcreteWasmType(curr->type)) { - if (isConcreteWasmType(backType)) { - shouldBeTrue(curr->type == unreachable, curr, "block with no value and a last element with a value must be unreachable"); - } + shouldBeFalse(isConcreteWasmType(backType), curr, "if block is not returning a value, final element should not flow out a value"); } else { if (isConcreteWasmType(backType)) { shouldBeEqual(curr->type, backType, curr, "block with value and last element with value must match types"); @@ -118,15 +116,19 @@ void WasmValidator::visitIf(If *curr) { shouldBeEqual(curr->ifFalse->type, unreachable, curr, "unreachable if-else must have unreachable false"); } } + if (isConcreteWasmType(curr->ifTrue->type)) { + shouldBeEqual(curr->type, curr->ifTrue->type, curr, "if type must match concrete ifTrue"); + shouldBeEqualOrFirstIsUnreachable(curr->ifFalse->type, curr->ifTrue->type, curr, "other arm must match concrete ifTrue"); + } + if (isConcreteWasmType(curr->ifFalse->type)) { + shouldBeEqual(curr->type, curr->ifFalse->type, curr, "if type must match concrete ifFalse"); + shouldBeEqualOrFirstIsUnreachable(curr->ifTrue->type, curr->ifFalse->type, curr, "other arm must match concrete ifFalse"); + } } } void WasmValidator::noteBreak(Name name, Expression* value, Expression* curr) { - if (!BranchUtils::isBranchTaken(curr)) { - // if not actually taken, just note the name - namedBreakTargets.insert(name); - return; - } + namedBreakTargets.insert(name); WasmType valueType = none; Index arity = 0; if (value) { @@ -548,7 +550,7 @@ void WasmValidator::visitFunction(Function *curr) { if (returnType != unreachable) { shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function has returns"); } - if (!shouldBeTrue(namedBreakTargets.empty(), curr->body, "all named break targets must exist (even if not taken)") && !quiet) { + if (!shouldBeTrue(namedBreakTargets.empty(), curr->body, "all named break targets must exist") && !quiet) { std::cerr << "(on label " << *namedBreakTargets.begin() << ")\n"; } returnType = unreachable; diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 4594eddd1..a781e0fa3 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -116,13 +116,12 @@ struct TypeSeeker : public PostWalker<TypeSeeker> { } void visitBreak(Break* curr) { - if (curr->name == targetName && BranchUtils::isBranchTaken(curr)) { + if (curr->name == targetName) { types.push_back(curr->value ? curr->value->type : none); } } void visitSwitch(Switch* curr) { - if (!BranchUtils::isBranchTaken(curr)) return; for (auto name : curr->targets) { if (name == targetName) types.push_back(curr->value ? curr->value->type : none); } @@ -174,16 +173,17 @@ static WasmType mergeTypes(std::vector<WasmType>& types) { // and there are no branches to it static void handleUnreachable(Block* block) { if (block->type == unreachable) return; // nothing to do + if (block->list.size() == 0) return; // nothing to do + // if we are concrete, stop - even an unreachable child + // won't change that (since we have a break with a value, + // or the final child flows out a value) + if (isConcreteWasmType(block->type)) return; + // look for an unreachable child for (auto* child : block->list) { if (child->type == unreachable) { // there is an unreachable child, so we are unreachable, unless we have a break - BranchUtils::BranchSeeker seeker(block->name); - Expression* expr = block; - seeker.walk(expr); - if (!seeker.found) { + if (!BranchUtils::BranchSeeker::hasNamed(block, block->name)) { block->type = unreachable; - } else { - block->type = seeker.valueType; } return; } @@ -199,19 +199,27 @@ void Block::finalize(WasmType type_) { void Block::finalize() { if (!name.is()) { - // nothing branches here, so this is easy if (list.size() > 0) { - // if we have an unreachable child, we are unreachable - // (we don't need to recurse into children, they can't - // break to us) + // nothing branches here, so this is easy + // normally the type is the type of the final child + type = list.back()->type; + // and even if we have an unreachable child somewhere, + // we still mark ourselves as having that type, + // (block (result i32) + // (return) + // (i32.const 10) + // ) + if (isConcreteWasmType(type)) return; + // if we are unreachable, we are done + if (type == unreachable) return; + // we may still be unreachable if we have an unreachable + // child for (auto* child : list) { if (child->type == unreachable) { type = unreachable; return; } } - // children are reachable, so last element determines type - type = list.back()->type; } else { type = none; } @@ -231,9 +239,7 @@ void If::finalize(WasmType type_) { } void If::finalize() { - if (condition->type == unreachable) { - type = unreachable; - } else if (ifFalse) { + if (ifFalse) { if (ifTrue->type == ifFalse->type) { type = ifTrue->type; } else if (isConcreteWasmType(ifTrue->type) && ifFalse->type == unreachable) { @@ -246,6 +252,17 @@ void If::finalize() { } else { type = none; // if without else } + // if the arms return a value, leave it even if the condition + // is unreachable, we still mark ourselves as having that type, e.g. + // (if (result i32) + // (unreachable) + // (i32.const 10) + // (i32.const 20 + // ) + // otherwise, if the condition is unreachable, so is the if + if (type == none && condition->type == unreachable) { + type = unreachable; + } } void Loop::finalize(WasmType type_) { |