diff options
author | Alon Zakai <alonzakai@gmail.com> | 2017-09-12 15:09:21 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-12 15:09:21 -0700 |
commit | c6729400f68a346c1d51702946bf6026638782a6 (patch) | |
tree | a564073fbc079a88df43e12f177c8f2753670bfa /src | |
parent | 40f52f2ca41822e9dc47ff57239cdf299f7e1ce5 (diff) | |
download | binaryen-c6729400f68a346c1d51702946bf6026638782a6.tar.gz binaryen-c6729400f68a346c1d51702946bf6026638782a6.tar.bz2 binaryen-c6729400f68a346c1d51702946bf6026638782a6.zip |
precompute-propagate pass (#1179)
Implements #1172: this adds a variant of precompute, "precompute-propagate", which also does constant propagation. Precompute by itself just runs the interpreter on each expression and sees if it is in fact a constant; precompute-propagate also looks at the graph of connections between get and set locals, and propagates those constant values.
This helps with cases as noticed in #1168 - while in most cases LLVM will do this already, it's important when inlining, e.g. inlining of the clamping math functions. This new pass is run when inlining, and otherwise only in -O3/-Oz, as it does increase compilation time noticeably if run on everything (and for almost no benefit if LLVM has run).
Most of the code here is just refactoring out from the ssa pass the get/set graph computation, so it can now be used by both the ssa pass and precompute-propagate.
Diffstat (limited to 'src')
-rw-r--r-- | src/ast/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/ast/LocalGraph.cpp | 260 | ||||
-rw-r--r-- | src/ast/find_all.h | 48 | ||||
-rw-r--r-- | src/ast/local-graph.h | 111 | ||||
-rw-r--r-- | src/literal.h | 6 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 1 | ||||
-rw-r--r-- | src/passes/Precompute.cpp | 148 | ||||
-rw-r--r-- | src/passes/SSAify.cpp | 292 | ||||
-rw-r--r-- | src/passes/pass.cpp | 8 | ||||
-rw-r--r-- | src/passes/passes.h | 1 | ||||
-rw-r--r-- | src/tools/execution-results.h | 14 | ||||
-rw-r--r-- | src/tools/wasm-shell.cpp | 4 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 6 |
13 files changed, 624 insertions, 276 deletions
diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt index e48e84eed..c01deaaaf 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -1,5 +1,6 @@ SET(ast_SOURCES ExpressionAnalyzer.cpp ExpressionManipulator.cpp + LocalGraph.cpp ) ADD_LIBRARY(ast STATIC ${ast_SOURCES}) diff --git a/src/ast/LocalGraph.cpp b/src/ast/LocalGraph.cpp new file mode 100644 index 000000000..c997eff1b --- /dev/null +++ b/src/ast/LocalGraph.cpp @@ -0,0 +1,260 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <iterator> + +#include <wasm-builder.h> +#include <ast/find_all.h> +#include <ast/local-graph.h> + +namespace wasm { + +LocalGraph::LocalGraph(Function* func, Module* module) { + walkFunctionInModule(func, module); +} + +void LocalGraph::computeInfluences() { + for (auto& pair : locations) { + auto* curr = pair.first; + if (auto* set = curr->dynCast<SetLocal>()) { + FindAll<GetLocal> findAll(set->value); + for (auto* get : findAll.list) { + getInfluences[get].insert(set); + } + } else { + auto* get = curr->cast<GetLocal>(); + for (auto* set : getSetses[get]) { + setInfluences[set].insert(get); + } + } + } +} + +void LocalGraph::doWalkFunction(Function* func) { + numLocals = func->getNumLocals(); + if (numLocals == 0) return; // nothing to do + // We begin with each param being assigned from the incoming value, and the zero-init for the locals, + // so the initial state is the identity permutation + currMapping.resize(numLocals); + for (auto& set : currMapping) { + set = { nullptr }; + } + PostWalker<LocalGraph>::walk(func->body); +} + +// control flow + +void LocalGraph::visitBlock(Block* curr) { + if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { + auto& infos = breakMappings[curr->name]; + infos.emplace_back(std::move(currMapping)); + currMapping = std::move(merge(infos)); + breakMappings.erase(curr->name); + } +} + +void LocalGraph::finishIf() { + // that's it for this if, merge + std::vector<Mapping> breaks; + breaks.emplace_back(std::move(currMapping)); + breaks.emplace_back(std::move(mappingStack.back())); + mappingStack.pop_back(); + currMapping = std::move(merge(breaks)); +} + +void LocalGraph::afterIfCondition(LocalGraph* self, Expression** currp) { + self->mappingStack.push_back(self->currMapping); +} +void LocalGraph::afterIfTrue(LocalGraph* self, Expression** currp) { + auto* curr = (*currp)->cast<If>(); + if (curr->ifFalse) { + auto afterCondition = std::move(self->mappingStack.back()); + self->mappingStack.back() = std::move(self->currMapping); + self->currMapping = std::move(afterCondition); + } else { + self->finishIf(); + } +} +void LocalGraph::afterIfFalse(LocalGraph* self, Expression** currp) { + self->finishIf(); +} +void LocalGraph::beforeLoop(LocalGraph* self, Expression** currp) { + // save the state before entering the loop, for calculation later of the merge at the loop top + self->mappingStack.push_back(self->currMapping); + self->loopGetStack.push_back({}); +} +void LocalGraph::visitLoop(Loop* curr) { + if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { + auto& infos = breakMappings[curr->name]; + infos.emplace_back(std::move(mappingStack.back())); + auto before = infos.back(); + auto& merged = merge(infos); + // every local we created a phi for requires us to update get_local operations in + // the loop - the branch back has means that gets in the loop have potentially + // more sets reaching them. + // we can detect this as follows: if a get of oldIndex has the same sets + // as the sets at the entrance to the loop, then it is affected by the loop + // header sets, and we can add to there sets that looped back + auto linkLoopTop = [&](Index i, Sets& getSets) { + auto& beforeSets = before[i]; + if (getSets.size() < beforeSets.size()) { + // the get trivially has fewer sets, so it overrode the loop entry sets + return; + } + std::vector<SetLocal*> intersection; + std::set_intersection(beforeSets.begin(), beforeSets.end(), + getSets.begin(), getSets.end(), + std::back_inserter(intersection)); + if (intersection.size() < beforeSets.size()) { + // the get has not the same sets as in the loop entry + return; + } + // the get has the entry sets, so add any new ones + for (auto* set : merged[i]) { + getSets.insert(set); + } + }; + auto& gets = loopGetStack.back(); + for (auto* get : gets) { + linkLoopTop(get->index, getSetses[get]); + } + // and the same for the loop fallthrough: any local that still has the + // entry sets should also have the loop-back sets as well + for (Index i = 0; i < numLocals; i++) { + linkLoopTop(i, currMapping[i]); + } + // finally, breaks still in flight must be updated too + for (auto& iter : breakMappings) { + auto name = iter.first; + if (name == curr->name) continue; // skip our own (which is still in use) + auto& mappings = iter.second; + for (auto& mapping : mappings) { + for (Index i = 0; i < numLocals; i++) { + linkLoopTop(i, mapping[i]); + } + } + } + // now that we are done with using the mappings, erase our own + breakMappings.erase(curr->name); + } + mappingStack.pop_back(); + loopGetStack.pop_back(); +} +void LocalGraph::visitBreak(Break* curr) { + if (curr->condition) { + breakMappings[curr->name].emplace_back(currMapping); + } else { + breakMappings[curr->name].emplace_back(std::move(currMapping)); + setUnreachable(currMapping); + } +} +void LocalGraph::visitSwitch(Switch* curr) { + std::set<Name> all; + for (auto target : curr->targets) { + all.insert(target); + } + all.insert(curr->default_); + for (auto target : all) { + breakMappings[target].emplace_back(currMapping); + } + setUnreachable(currMapping); +} +void LocalGraph::visitReturn(Return *curr) { + setUnreachable(currMapping); +} +void LocalGraph::visitUnreachable(Unreachable *curr) { + setUnreachable(currMapping); +} + +// local usage + +void LocalGraph::visitGetLocal(GetLocal* curr) { + assert(currMapping.size() == numLocals); + assert(curr->index < numLocals); + for (auto& loopGets : loopGetStack) { + loopGets.push_back(curr); + } + // current sets are our sets + getSetses[curr] = currMapping[curr->index]; + locations[curr] = getCurrentPointer(); +} +void LocalGraph::visitSetLocal(SetLocal* curr) { + assert(currMapping.size() == numLocals); + assert(curr->index < numLocals); + // current sets are just this set + currMapping[curr->index] = { curr }; // TODO optimize? + locations[curr] = getCurrentPointer(); +} + +// traversal + +void LocalGraph::scan(LocalGraph* self, Expression** currp) { + if (auto* iff = (*currp)->dynCast<If>()) { + // if needs special handling + if (iff->ifFalse) { + self->pushTask(LocalGraph::afterIfFalse, currp); + self->pushTask(LocalGraph::scan, &iff->ifFalse); + } + self->pushTask(LocalGraph::afterIfTrue, currp); + self->pushTask(LocalGraph::scan, &iff->ifTrue); + self->pushTask(LocalGraph::afterIfCondition, currp); + self->pushTask(LocalGraph::scan, &iff->condition); + } else { + PostWalker<LocalGraph>::scan(self, currp); + } + + // loops need pre-order visiting too + if ((*currp)->is<Loop>()) { + self->pushTask(LocalGraph::beforeLoop, currp); + } +} + +// helpers + +void LocalGraph::setUnreachable(Mapping& mapping) { + mapping.resize(numLocals); // may have been emptied by a move + mapping[0].clear(); +} + +bool LocalGraph::isUnreachable(Mapping& mapping) { + // we must have some set for each index, if only the zero init, so empty means we emptied it for unreachable code + return mapping[0].empty(); +} + +// merges a bunch of infos into one. +// if we need phis, writes them into the provided vector. the caller should +// ensure those are placed in the right location +LocalGraph::Mapping& LocalGraph::merge(std::vector<Mapping>& mappings) { + assert(mappings.size() > 0); + auto& out = mappings[0]; + if (mappings.size() == 1) { + return out; + } + // merge into the first + for (Index j = 1; j < mappings.size(); j++) { + auto& other = mappings[j]; + for (Index i = 0; i < numLocals; i++) { + auto& outSets = out[i]; + for (auto* set : other[i]) { + outSets.insert(set); + } + } + } + return out; +} + +} // namespace wasm + diff --git a/src/ast/find_all.h b/src/ast/find_all.h new file mode 100644 index 000000000..98fe4c5a7 --- /dev/null +++ b/src/ast/find_all.h @@ -0,0 +1,48 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ast_find_all_h +#define wasm_ast_find_all_h + +#include <wasm-traversal.h> + +namespace wasm { + +// Find all instances of a certain node type + +template<typename T> +struct FindAll { + std::vector<T*> list; + + FindAll(Expression* ast) { + struct Finder : public PostWalker<Finder, UnifiedExpressionVisitor<Finder>> { + std::vector<T*>* list; + void visitExpression(Expression* curr) { + if (curr->is<T>()) { + (*list).push_back(curr->cast<T>()); + } + } + }; + Finder finder; + finder.list = &list; + finder.walk(ast); + } +}; + +} // namespace wasm + +#endif // wasm_ast_find_all_h + diff --git a/src/ast/local-graph.h b/src/ast/local-graph.h new file mode 100644 index 000000000..03915da5e --- /dev/null +++ b/src/ast/local-graph.h @@ -0,0 +1,111 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ast_local_graph_h +#define wasm_ast_local_graph_h + +namespace wasm { + +// +// Finds the connections between get_locals and set_locals, creating +// a graph of those ties. This is useful for "ssa-style" optimization, +// in which you want to know exactly which sets are relevant for a +// a get, so it is as if each get has just one set, logically speaking +// (see the SSA pass for actually creating new local indexes based +// on this). +// +// TODO: the algorithm here is pretty simple, but also pretty slow, +// we should optimize it. e.g. we rely on set_interaction +// here, and worse we only use it to compute the size... +struct LocalGraph : public PostWalker<LocalGraph> { + // main API + + // the constructor computes getSetses, the sets affecting each get + LocalGraph(Function* func, Module* module); + + // the set_locals relevant for an index or a get. + typedef std::set<SetLocal*> Sets; + + // externally useful information + std::map<GetLocal*, Sets> getSetses; // the sets affecting each get. a nullptr set means the initial + // value (0 for a var, the received value for a param) + std::map<Expression*, Expression**> locations; // where each get and set is (for easy replacing) + + // optional computation: compute the influence graphs between sets and gets + // (useful for algorithms that propagate changes) + + std::unordered_map<GetLocal*, std::unordered_set<SetLocal*>> getInfluences; // for each get, the sets whose values are influenced by that get + std::unordered_map<SetLocal*, std::unordered_set<GetLocal*>> setInfluences; // for each set, the gets whose values are influenced by that set + + void computeInfluences(); + +private: + // we map local index => the set_locals for that index. + // a nullptr set means there is a virtual set, from a param + // initial value or the zero init initial value. + typedef std::vector<Sets> Mapping; + + // internal state + Index numLocals; + Mapping currMapping; + std::vector<Mapping> mappingStack; // used in ifs, loops + std::map<Name, std::vector<Mapping>> breakMappings; // break target => infos that reach it + std::vector<std::vector<GetLocal*>> loopGetStack; // stack of loops, all the gets in each, so we can update them for back branches + +public: + void doWalkFunction(Function* func); + + // control flow + + void visitBlock(Block* curr); + + void finishIf(); + + static void afterIfCondition(LocalGraph* self, Expression** currp); + static void afterIfTrue(LocalGraph* self, Expression** currp); + static void afterIfFalse(LocalGraph* self, Expression** currp); + static void beforeLoop(LocalGraph* self, Expression** currp); + void visitLoop(Loop* curr); + void visitBreak(Break* curr); + void visitSwitch(Switch* curr); + void visitReturn(Return *curr); + void visitUnreachable(Unreachable *curr); + + // local usage + + void visitGetLocal(GetLocal* curr); + void visitSetLocal(SetLocal* curr); + + // traversal + + static void scan(LocalGraph* self, Expression** currp); + + // helpers + + void setUnreachable(Mapping& mapping); + + bool isUnreachable(Mapping& mapping); + + // merges a bunch of infos into one. + // if we need phis, writes them into the provided vector. the caller should + // ensure those are placed in the right location + Mapping& merge(std::vector<Mapping>& mappings); +}; + +} // namespace wasm + +#endif // wasm_ast_local_graph_h + diff --git a/src/literal.h b/src/literal.h index 560895e7a..c55d645ad 100644 --- a/src/literal.h +++ b/src/literal.h @@ -44,7 +44,7 @@ private: return val & (sizeof(T) * 8 - 1); } - public: +public: Literal() : type(WasmType::none), i64(0) {} explicit Literal(WasmType type) : type(type), i64(0) {} explicit Literal(int32_t init) : type(WasmType::i32), i32(init) {} @@ -54,6 +54,9 @@ private: explicit Literal(float init) : type(WasmType::f32), i32(bit_cast<int32_t>(init)) {} explicit Literal(double init) : type(WasmType::f64), i64(bit_cast<int64_t>(init)) {} + bool isConcrete() { return type != none; } + bool isNull() { return type == none; } + Literal castToF32(); Literal castToF64(); Literal castToI32(); @@ -76,6 +79,7 @@ private: int64_t getBits() const; bool operator==(const Literal& other) const; bool operator!=(const Literal& other) const; + bool bitwiseEqual(const Literal& other) const; static uint32_t NaNPayload(float f); static uint64_t NaNPayload(double f); diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 192480d26..e5fdcbb9d 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -313,6 +313,7 @@ struct Inlining : public Pass { PassRunner runner(module, parentRunner->options); runner.setIsNested(true); runner.setValidateGlobally(false); // not a full valid module + runner.add("precompute-propagate"); runner.add("remove-unused-brs"); runner.add("remove-unused-names"); runner.add("coalesce-locals"); diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index c4702fdeb..72292c730 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -23,15 +23,24 @@ #include <wasm-builder.h> #include <wasm-interpreter.h> #include <ast_utils.h> -#include "ast/manipulation.h" +#include <ast/literal-utils.h> +#include <ast/local-graph.h> +#include <ast/manipulation.h> namespace wasm { static const Name NONSTANDALONE_FLOW("Binaryen|nonstandalone"); +typedef std::unordered_map<GetLocal*, Literal> GetValues; + // Execute an expression by itself. Errors if we hit anything we need anything not in the expression itself standalone. class StandaloneExpressionRunner : public ExpressionRunner<StandaloneExpressionRunner> { + // map gets to constant values, if they are known to be constant + GetValues& getValues; + public: + StandaloneExpressionRunner(GetValues& getValues) : getValues(getValues) {} + struct NonstandaloneException {}; // TODO: use a flow with a special name, as this is likely very slow Flow visitLoop(Loop* curr) { @@ -50,6 +59,13 @@ public: return Flow(NONSTANDALONE_FLOW); } Flow visitGetLocal(GetLocal *curr) { + auto iter = getValues.find(curr); + if (iter != getValues.end()) { + auto value = iter->second; + if (value.isConcrete()) { + return Flow(value); + } + } return Flow(NONSTANDALONE_FLOW); } Flow visitSetLocal(SetLocal *curr) { @@ -85,17 +101,30 @@ public: struct Precompute : public WalkerPass<PostWalker<Precompute, UnifiedExpressionVisitor<Precompute>>> { bool isFunctionParallel() override { return true; } - Pass* create() override { return new Precompute; } + Pass* create() override { return new Precompute(propagate); } + + bool propagate = false; + + Precompute(bool propagate) : propagate(propagate) {} + + GetValues getValues; + + void doWalkFunction(Function* func) { + // with extra effort, we can utilize the get-set graph to precompute + // things that use locals that are known to be constant. otherwise, + // we just look at what is immediately before us + if (propagate) { + optimizeLocals(func, getModule()); + } + // do the main and final walk over everything + WalkerPass<PostWalker<Precompute, UnifiedExpressionVisitor<Precompute>>>::doWalkFunction(func); + } void visitExpression(Expression* curr) { + // TODO: if get_local, only replace with a constant if we don't care about size...? if (curr->is<Const>() || curr->is<Nop>()) return; // try to evaluate this into a const - Flow flow; - try { - flow = StandaloneExpressionRunner().visit(curr); - } catch (StandaloneExpressionRunner::NonstandaloneException& e) { - return; - } + Flow flow = precomputeFlow(curr); if (flow.breaking()) { if (flow.breakTo == NONSTANDALONE_FLOW) return; if (flow.breakTo == RETURN_FLOW) { @@ -157,10 +186,111 @@ struct Precompute : public WalkerPass<PostWalker<Precompute, UnifiedExpressionVi // removing breaks can alter types ReFinalize().walkFunctionInModule(curr, getModule()); } + +private: + Flow precomputeFlow(Expression* curr) { + try { + return StandaloneExpressionRunner(getValues).visit(curr); + } catch (StandaloneExpressionRunner::NonstandaloneException& e) { + return Flow(NONSTANDALONE_FLOW); + } + } + + Literal precomputeValue(Expression* curr) { + Flow flow = precomputeFlow(curr); + if (flow.breaking()) { + return Literal(); + } + return flow.value; + } + + void optimizeLocals(Function* func, Module* module) { + // using the graph of get-set interactions, do a constant-propagation type + // operation: note which sets are assigned locals, then see if that lets us + // compute other sets as locals (since some of the gets they read may be + // constant). + // compute all dependencies + LocalGraph localGraph(func, module); + localGraph.computeInfluences(); + // prepare the work list. we add things here that might change to a constant + // initially, that means everything + std::unordered_set<Expression*> work; + for (auto& pair : localGraph.locations) { + auto* curr = pair.first; + work.insert(curr); + } + std::unordered_map<SetLocal*, Literal> setValues; // the constant value, or none if not a constant + // propagate constant values + while (!work.empty()) { + auto iter = work.begin(); + auto* curr = *iter; + work.erase(iter); + // see if this set or get is actually a constant value, and if so, + // mark it as such and add everything it influences to the work list, + // as they may be constant too. + if (auto* set = curr->dynCast<SetLocal>()) { + if (setValues[set].isConcrete()) continue; // already known constant + auto value = setValues[set] = precomputeValue(set->value); + if (value.isConcrete()) { + for (auto* get : localGraph.setInfluences[set]) { + work.insert(get); + } + } + } else { + auto* get = curr->cast<GetLocal>(); + if (getValues[get].isConcrete()) continue; // already known constant + // for this get to have constant value, all sets must agree + Literal value; + bool first = true; + for (auto* set : localGraph.getSetses[get]) { + Literal curr; + if (set == nullptr) { + if (getFunction()->isVar(get->index)) { + curr = LiteralUtils::makeLiteralZero(getFunction()->getLocalType(get->index)); + } else { + // it's a param, so it's hopeless + value = Literal(); + break; + } + } else { + curr = setValues[set]; + } + if (curr.isNull()) { + // not a constant, give up + value = Literal(); + break; + } + // we found a concrete value. compare with the current one + if (first) { + value = curr; // this is the first + first = false; + } else { + if (!value.bitwiseEqual(curr)) { + // not the same, give up + value = Literal(); + break; + } + } + } + // we may have found a value + if (value.isConcrete()) { + // we did! + getValues[get] = value; + for (auto* set : localGraph.getInfluences[get]) { + work.insert(set); + } + } + } + } + } }; Pass *createPrecomputePass() { - return new Precompute(); + return new Precompute(false); +} + +Pass *createPrecomputePropagatePass() { + return new Precompute(true); } } // namespace wasm diff --git a/src/passes/SSAify.cpp b/src/passes/SSAify.cpp index 9c9aeb573..a71a75308 100644 --- a/src/passes/SSAify.cpp +++ b/src/passes/SSAify.cpp @@ -35,6 +35,7 @@ #include "wasm-builder.h" #include "support/permutations.h" #include "ast/literal-utils.h" +#include "ast/local-graph.h" namespace wasm { @@ -44,261 +45,39 @@ SetLocal IMPOSSIBLE_SET; // Tracks assignments to locals, assuming single-assignment form, i.e., // each assignment creates a new variable. -struct SSAify : public WalkerPass<PostWalker<SSAify>> { +struct SSAify : public Pass { bool isFunctionParallel() override { return true; } Pass* create() override { return new SSAify; } - // the set_locals relevant for an index or a get. we use - // as set as merges of control flow mean more than 1 may - // be relevant; we create a phi on demand when necessary for those - typedef std::set<SetLocal*> Sets; - - // we map (old local index) => the set_locals for that index. - // a nullptr set means there is a virtual set, from a param - // initial value or the zero init initial value. - typedef std::vector<Sets> Mapping; - - Index numLocals; - Mapping currMapping; - Index nextIndex; - std::vector<Mapping> mappingStack; // used in ifs, loops - std::map<Name, std::vector<Mapping>> breakMappings; // break target => infos that reach it - std::vector<std::vector<GetLocal*>> loopGetStack; // stack of loops, all the gets in each, so we can update them for back branches + Module* module; + Function* func; std::vector<Expression*> functionPrepends; // things we add to the function prologue - std::map<GetLocal*, Sets> getSetses; // the sets for each get - std::map<GetLocal*, Expression**> getLocations; - void doWalkFunction(Function* func) { - numLocals = func->getNumLocals(); - if (numLocals == 0) return; // nothing to do - // We begin with each param being assigned from the incoming value, and the zero-init for the locals, - // so the initial state is the identity permutation - currMapping.resize(numLocals); - for (auto& set : currMapping) { - set = { nullptr }; - } - nextIndex = numLocals; - WalkerPass<PostWalker<SSAify>>::walk(func->body); - // apply - we now know the sets for each get - computeGetsAndPhis(); - // add prepends - if (functionPrepends.size() > 0) { - Builder builder(*getModule()); - auto* block = builder.makeBlock(); - for (auto* pre : functionPrepends) { - block->list.push_back(pre); + void runFunction(PassRunner* runner, Module* module_, Function* func_) override { + module = module_; + func = func_; + LocalGraph graph(func, module); + // create new local indexes, one for each set + createNewIndexes(graph); + // we now know the sets for each get + computeGetsAndPhis(graph); + // add prepends to function + addPrepends(); + } + + void createNewIndexes(LocalGraph& graph) { + for (auto& pair : graph.locations) { + auto* curr = pair.first; + if (auto* set = curr->dynCast<SetLocal>()) { + set->index = addLocal(func->getLocalType(set->index)); } - block->list.push_back(func->body); - block->finalize(func->body->type); - func->body = block; } } - // control flow - - void visitBlock(Block* curr) { - if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { - auto& infos = breakMappings[curr->name]; - infos.emplace_back(std::move(currMapping)); - currMapping = std::move(merge(infos)); - breakMappings.erase(curr->name); - } - } - - void finishIf() { - // that's it for this if, merge - std::vector<Mapping> breaks; - breaks.emplace_back(std::move(currMapping)); - breaks.emplace_back(std::move(mappingStack.back())); - mappingStack.pop_back(); - currMapping = std::move(merge(breaks)); - } - - static void afterIfCondition(SSAify* self, Expression** currp) { - self->mappingStack.push_back(self->currMapping); - } - static void afterIfTrue(SSAify* self, Expression** currp) { - auto* curr = (*currp)->cast<If>(); - if (curr->ifFalse) { - auto afterCondition = std::move(self->mappingStack.back()); - self->mappingStack.back() = std::move(self->currMapping); - self->currMapping = std::move(afterCondition); - } else { - self->finishIf(); - } - } - static void afterIfFalse(SSAify* self, Expression** currp) { - self->finishIf(); - } - static void beforeLoop(SSAify* self, Expression** currp) { - // save the state before entering the loop, for calculation later of the merge at the loop top - self->mappingStack.push_back(self->currMapping); - self->loopGetStack.push_back({}); - } - void visitLoop(Loop* curr) { - if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { - auto& infos = breakMappings[curr->name]; - infos.emplace_back(std::move(mappingStack.back())); - auto before = infos.back(); - auto& merged = merge(infos); - // every local we created a phi for requires us to update get_local operations in - // the loop - the branch back has means that gets in the loop have potentially - // more sets reaching them. - // we can detect this as follows: if a get of oldIndex has the same sets - // as the sets at the entrance to the loop, then it is affected by the loop - // header sets, and we can add to there sets that looped back - auto linkLoopTop = [&](Index i, Sets& getSets) { - auto& beforeSets = before[i]; - if (getSets.size() < beforeSets.size()) { - // the get trivially has fewer sets, so it overrode the loop entry sets - return; - } - std::vector<SetLocal*> intersection; - std::set_intersection(beforeSets.begin(), beforeSets.end(), - getSets.begin(), getSets.end(), - std::back_inserter(intersection)); - if (intersection.size() < beforeSets.size()) { - // the get has not the same sets as in the loop entry - return; - } - // the get has the entry sets, so add any new ones - for (auto* set : merged[i]) { - getSets.insert(set); - } - }; - auto& gets = loopGetStack.back(); - for (auto* get : gets) { - linkLoopTop(get->index, getSetses[get]); - } - // and the same for the loop fallthrough: any local that still has the - // entry sets should also have the loop-back sets as well - for (Index i = 0; i < numLocals; i++) { - linkLoopTop(i, currMapping[i]); - } - // finally, breaks still in flight must be updated too - for (auto& iter : breakMappings) { - auto name = iter.first; - if (name == curr->name) continue; // skip our own (which is still in use) - auto& mappings = iter.second; - for (auto& mapping : mappings) { - for (Index i = 0; i < numLocals; i++) { - linkLoopTop(i, mapping[i]); - } - } - } - // now that we are done with using the mappings, erase our own - breakMappings.erase(curr->name); - } - mappingStack.pop_back(); - loopGetStack.pop_back(); - } - void visitBreak(Break* curr) { - if (curr->condition) { - breakMappings[curr->name].emplace_back(currMapping); - } else { - breakMappings[curr->name].emplace_back(std::move(currMapping)); - setUnreachable(currMapping); - } - } - void visitSwitch(Switch* curr) { - std::set<Name> all; - for (auto target : curr->targets) { - all.insert(target); - } - all.insert(curr->default_); - for (auto target : all) { - breakMappings[target].emplace_back(currMapping); - } - setUnreachable(currMapping); - } - void visitReturn(Return *curr) { - setUnreachable(currMapping); - } - void visitUnreachable(Unreachable *curr) { - setUnreachable(currMapping); - } - - // local usage - - void visitGetLocal(GetLocal* curr) { - assert(currMapping.size() == numLocals); - assert(curr->index < numLocals); - for (auto& loopGets : loopGetStack) { - loopGets.push_back(curr); - } - // current sets are our sets - getSetses[curr] = currMapping[curr->index]; - getLocations[curr] = getCurrentPointer(); - } - void visitSetLocal(SetLocal* curr) { - assert(currMapping.size() == numLocals); - assert(curr->index < numLocals); - // current sets are just this set - currMapping[curr->index] = { curr }; // TODO optimize? - curr->index = addLocal(getFunction()->getLocalType(curr->index)); - } - - // traversal - - static void scan(SSAify* self, Expression** currp) { - if (auto* iff = (*currp)->dynCast<If>()) { - // if needs special handling - if (iff->ifFalse) { - self->pushTask(SSAify::afterIfFalse, currp); - self->pushTask(SSAify::scan, &iff->ifFalse); - } - self->pushTask(SSAify::afterIfTrue, currp); - self->pushTask(SSAify::scan, &iff->ifTrue); - self->pushTask(SSAify::afterIfCondition, currp); - self->pushTask(SSAify::scan, &iff->condition); - } else { - WalkerPass<PostWalker<SSAify>>::scan(self, currp); - } - - // loops need pre-order visiting too - if ((*currp)->is<Loop>()) { - self->pushTask(SSAify::beforeLoop, currp); - } - } - - // helpers - - void setUnreachable(Mapping& mapping) { - mapping.resize(numLocals); // may have been emptied by a move - mapping[0].clear(); - } - - bool isUnreachable(Mapping& mapping) { - // we must have some set for each index, if only the zero init, so empty means we emptied it for unreachable code - return mapping[0].empty(); - } - - // merges a bunch of infos into one. - // if we need phis, writes them into the provided vector. the caller should - // ensure those are placed in the right location - Mapping& merge(std::vector<Mapping>& mappings) { - assert(mappings.size() > 0); - auto& out = mappings[0]; - if (mappings.size() == 1) { - return out; - } - // merge into the first - for (Index j = 1; j < mappings.size(); j++) { - auto& other = mappings[j]; - for (Index i = 0; i < numLocals; i++) { - auto& outSets = out[i]; - for (auto* set : other[i]) { - outSets.insert(set); - } - } - } - return out; - } - // After we traversed it all, we can compute gets and phis - void computeGetsAndPhis() { - for (auto& iter : getSetses) { + void computeGetsAndPhis(LocalGraph& graph) { + for (auto& iter : graph.getSetses) { auto* get = iter.first; auto& sets = iter.second; if (sets.size() == 0) { @@ -312,11 +91,11 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> { get->index = set->index; } else { // no set, assign param or zero - if (getFunction()->isParam(get->index)) { + if (func->isParam(get->index)) { // leave it, it's fine } else { // zero it out - (*getLocations[get]) = LiteralUtils::makeZero(get->type, *getModule()); + (*graph.locations[get]) = LiteralUtils::makeZero(get->type, *module); } } continue; @@ -354,7 +133,7 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> { auto new_ = addLocal(get->type); auto old = get->index; get->index = new_; - Builder builder(*getModule()); + Builder builder(*module); // write to the local in each of our sets for (auto* set : sets) { if (set) { @@ -365,12 +144,12 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> { ); } else { // this is a param or the zero init value. - if (getFunction()->isParam(old)) { + if (func->isParam(old)) { // we add a set with the proper // param value at the beginning of the function auto* set = builder.makeSetLocal( new_, - builder.makeGetLocal(old, getFunction()->getLocalType(old)) + builder.makeGetLocal(old, func->getLocalType(old)) ); functionPrepends.push_back(set); } else { @@ -383,7 +162,20 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> { } Index addLocal(WasmType type) { - return Builder::addVar(getFunction(), type); + return Builder::addVar(func, type); + } + + void addPrepends() { + if (functionPrepends.size() > 0) { + Builder builder(*module); + auto* block = builder.makeBlock(); + for (auto* pre : functionPrepends) { + block->list.push_back(pre); + } + block->list.push_back(func->body); + block->finalize(func->body->type); + func->body = block; + } } }; diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 4ec454a50..a208a03dd 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -89,6 +89,7 @@ void PassRegistry::registerPasses() { registerPass("pick-load-signs", "pick load signs based on their uses", createPickLoadSignsPass); registerPass("post-emscripten", "miscellaneous optimizations for Emscripten-generated code", createPostEmscriptenPass); registerPass("precompute", "computes compile-time evaluatable expressions", createPrecomputePass); + registerPass("precompute-propagate", "computes compile-time evaluatable expressions and propagates them through locals", createPrecomputePropagatePass); registerPass("print", "print in s-expression format", createPrinterPass); registerPass("print-minified", "print in minified s-expression format", createMinifiedPrinterPass); registerPass("print-full", "print in full s-expression format", createFullPrinterPass); @@ -148,7 +149,12 @@ void PassRunner::addDefaultFunctionOptimizationPasses() { add("remove-unused-brs"); // coalesce-locals opens opportunities for optimizations add("merge-blocks"); // clean up remove-unused-brs new blocks add("optimize-instructions"); - add("precompute"); + // if we are willing to work hard, also propagate + if (options.optimizeLevel >= 3 || options.shrinkLevel >= 2) { + add("precompute-propagate"); + } else { + add("precompute"); + } if (options.shrinkLevel >= 2) { add("local-cse"); // TODO: run this early, before first coalesce-locals. right now doing so uncovers some deficiencies we need to fix first add("coalesce-locals"); // just for localCSE diff --git a/src/passes/passes.h b/src/passes/passes.h index 4e039f4bf..a02216083 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -49,6 +49,7 @@ Pass *createOptimizeInstructionsPass(); Pass *createPickLoadSignsPass(); Pass *createPostEmscriptenPass(); Pass *createPrecomputePass(); +Pass *createPrecomputePropagatePass(); Pass *createPrinterPass(); Pass *createPrintCallGraphPass(); Pass *createRelooperJumpThreadingPass(); diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h index 5ed0ff01f..1fee4ffb0 100644 --- a/src/tools/execution-results.h +++ b/src/tools/execution-results.h @@ -23,18 +23,6 @@ namespace wasm { -static bool areBitwiseEqual(Literal a, Literal b) { - if (a == b) return true; - // accept equal nans if equal in all bits - if (a.type != b.type) return false; - if (a.type == f32) { - return a.reinterpreti32() == b.reinterpreti32(); - } else if (a.type == f64) { - return a.reinterpreti64() == b.reinterpreti64(); - } - return false; -} - // gets execution results from a wasm module. this is useful for fuzzing // // we can only get results when there are no imports. we then call each method @@ -86,7 +74,7 @@ struct ExecutionResults { abort(); } std::cout << "[fuzz-exec] comparing " << name << '\n'; - if (!areBitwiseEqual(results[name], other.results[name])) { + if (!results[name].bitwiseEqual(other.results[name])) { std::cout << "not identical!\n"; abort(); } diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index b7fde3af6..cd8c27437 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -201,14 +201,14 @@ static void run_asserts(Name moduleName, size_t* i, bool* checked, Module* wasm, ->dynCast<Const>() ->value; std::cerr << "seen " << result << ", expected " << expected << '\n'; - if (!areBitwiseEqual(expected, result)) { + if (!expected.bitwiseEqual(result)) { std::cout << "unexpected, should be identical\n"; abort(); } } else { Literal expected; std::cerr << "seen " << result << ", expected " << expected << '\n'; - if (!areBitwiseEqual(expected, result)) { + if (!expected.bitwiseEqual(result)) { std::cout << "unexpected, should be identical\n"; abort(); } diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index be1363dde..de89a7ac1 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -93,6 +93,12 @@ bool Literal::operator!=(const Literal& other) const { return !(*this == other); } +bool Literal::bitwiseEqual(const Literal& other) const { + if (type != other.type) return false; + if (type == none) return true; + return getBits() == other.getBits(); +} + uint32_t Literal::NaNPayload(float f) { assert(std::isnan(f) && "expected a NaN"); // SEEEEEEE EFFFFFFF FFFFFFFF FFFFFFFF |