diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ast/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/ast/LocalGraph.cpp | 260 | ||||
-rw-r--r-- | src/ast/find_all.h | 48 | ||||
-rw-r--r-- | src/ast/local-graph.h | 111 | ||||
-rw-r--r-- | src/literal.h | 6 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 1 | ||||
-rw-r--r-- | src/passes/Precompute.cpp | 148 | ||||
-rw-r--r-- | src/passes/SSAify.cpp | 292 | ||||
-rw-r--r-- | src/passes/pass.cpp | 8 | ||||
-rw-r--r-- | src/passes/passes.h | 1 | ||||
-rw-r--r-- | src/tools/execution-results.h | 14 | ||||
-rw-r--r-- | src/tools/wasm-shell.cpp | 4 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 6 |
13 files changed, 624 insertions, 276 deletions
diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt index e48e84eed..c01deaaaf 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -1,5 +1,6 @@ SET(ast_SOURCES ExpressionAnalyzer.cpp ExpressionManipulator.cpp + LocalGraph.cpp ) ADD_LIBRARY(ast STATIC ${ast_SOURCES}) diff --git a/src/ast/LocalGraph.cpp b/src/ast/LocalGraph.cpp new file mode 100644 index 000000000..c997eff1b --- /dev/null +++ b/src/ast/LocalGraph.cpp @@ -0,0 +1,260 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <iterator> + +#include <wasm-builder.h> +#include <ast/find_all.h> +#include <ast/local-graph.h> + +namespace wasm { + +LocalGraph::LocalGraph(Function* func, Module* module) { + walkFunctionInModule(func, module); +} + +void LocalGraph::computeInfluences() { + for (auto& pair : locations) { + auto* curr = pair.first; + if (auto* set = curr->dynCast<SetLocal>()) { + FindAll<GetLocal> findAll(set->value); + for (auto* get : findAll.list) { + getInfluences[get].insert(set); + } + } else { + auto* get = curr->cast<GetLocal>(); + for (auto* set : getSetses[get]) { + setInfluences[set].insert(get); + } + } + } +} + +void LocalGraph::doWalkFunction(Function* func) { + numLocals = func->getNumLocals(); + if (numLocals == 0) return; // nothing to do + // We begin with each param being assigned from the incoming value, and the zero-init for the locals, + // so the initial state is the identity permutation + currMapping.resize(numLocals); + for (auto& set : currMapping) { + set = { nullptr }; + } + PostWalker<LocalGraph>::walk(func->body); +} + +// control flow + +void LocalGraph::visitBlock(Block* curr) { + if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { + auto& infos = breakMappings[curr->name]; + infos.emplace_back(std::move(currMapping)); + currMapping = std::move(merge(infos)); + breakMappings.erase(curr->name); + } +} + +void LocalGraph::finishIf() { + // that's it for this if, merge + std::vector<Mapping> breaks; + breaks.emplace_back(std::move(currMapping)); + breaks.emplace_back(std::move(mappingStack.back())); + mappingStack.pop_back(); + currMapping = std::move(merge(breaks)); +} + +void LocalGraph::afterIfCondition(LocalGraph* self, Expression** currp) { + self->mappingStack.push_back(self->currMapping); +} +void LocalGraph::afterIfTrue(LocalGraph* self, Expression** currp) { + auto* curr = (*currp)->cast<If>(); + if (curr->ifFalse) { + auto afterCondition = std::move(self->mappingStack.back()); + self->mappingStack.back() = std::move(self->currMapping); + self->currMapping = std::move(afterCondition); + } else { + self->finishIf(); + } +} +void LocalGraph::afterIfFalse(LocalGraph* self, Expression** currp) { + self->finishIf(); +} +void LocalGraph::beforeLoop(LocalGraph* self, Expression** currp) { + // save the state before entering the loop, for calculation later of the merge at the loop top + self->mappingStack.push_back(self->currMapping); + self->loopGetStack.push_back({}); +} +void LocalGraph::visitLoop(Loop* curr) { + if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { + auto& infos = breakMappings[curr->name]; + infos.emplace_back(std::move(mappingStack.back())); + auto before = infos.back(); + auto& merged = merge(infos); + // every local we created a phi for requires us to update get_local operations in + // the loop - the branch back has means that gets in the loop have potentially + // more sets reaching them. + // we can detect this as follows: if a get of oldIndex has the same sets + // as the sets at the entrance to the loop, then it is affected by the loop + // header sets, and we can add to there sets that looped back + auto linkLoopTop = [&](Index i, Sets& getSets) { + auto& beforeSets = before[i]; + if (getSets.size() < beforeSets.size()) { + // the get trivially has fewer sets, so it overrode the loop entry sets + return; + } + std::vector<SetLocal*> intersection; + std::set_intersection(beforeSets.begin(), beforeSets.end(), + getSets.begin(), getSets.end(), + std::back_inserter(intersection)); + if (intersection.size() < beforeSets.size()) { + // the get has not the same sets as in the loop entry + return; + } + // the get has the entry sets, so add any new ones + for (auto* set : merged[i]) { + getSets.insert(set); + } + }; + auto& gets = loopGetStack.back(); + for (auto* get : gets) { + linkLoopTop(get->index, getSetses[get]); + } + // and the same for the loop fallthrough: any local that still has the + // entry sets should also have the loop-back sets as well + for (Index i = 0; i < numLocals; i++) { + linkLoopTop(i, currMapping[i]); + } + // finally, breaks still in flight must be updated too + for (auto& iter : breakMappings) { + auto name = iter.first; + if (name == curr->name) continue; // skip our own (which is still in use) + auto& mappings = iter.second; + for (auto& mapping : mappings) { + for (Index i = 0; i < numLocals; i++) { + linkLoopTop(i, mapping[i]); + } + } + } + // now that we are done with using the mappings, erase our own + breakMappings.erase(curr->name); + } + mappingStack.pop_back(); + loopGetStack.pop_back(); +} +void LocalGraph::visitBreak(Break* curr) { + if (curr->condition) { + breakMappings[curr->name].emplace_back(currMapping); + } else { + breakMappings[curr->name].emplace_back(std::move(currMapping)); + setUnreachable(currMapping); + } +} +void LocalGraph::visitSwitch(Switch* curr) { + std::set<Name> all; + for (auto target : curr->targets) { + all.insert(target); + } + all.insert(curr->default_); + for (auto target : all) { + breakMappings[target].emplace_back(currMapping); + } + setUnreachable(currMapping); +} +void LocalGraph::visitReturn(Return *curr) { + setUnreachable(currMapping); +} +void LocalGraph::visitUnreachable(Unreachable *curr) { + setUnreachable(currMapping); +} + +// local usage + +void LocalGraph::visitGetLocal(GetLocal* curr) { + assert(currMapping.size() == numLocals); + assert(curr->index < numLocals); + for (auto& loopGets : loopGetStack) { + loopGets.push_back(curr); + } + // current sets are our sets + getSetses[curr] = currMapping[curr->index]; + locations[curr] = getCurrentPointer(); +} +void LocalGraph::visitSetLocal(SetLocal* curr) { + assert(currMapping.size() == numLocals); + assert(curr->index < numLocals); + // current sets are just this set + currMapping[curr->index] = { curr }; // TODO optimize? + locations[curr] = getCurrentPointer(); +} + +// traversal + +void LocalGraph::scan(LocalGraph* self, Expression** currp) { + if (auto* iff = (*currp)->dynCast<If>()) { + // if needs special handling + if (iff->ifFalse) { + self->pushTask(LocalGraph::afterIfFalse, currp); + self->pushTask(LocalGraph::scan, &iff->ifFalse); + } + self->pushTask(LocalGraph::afterIfTrue, currp); + self->pushTask(LocalGraph::scan, &iff->ifTrue); + self->pushTask(LocalGraph::afterIfCondition, currp); + self->pushTask(LocalGraph::scan, &iff->condition); + } else { + PostWalker<LocalGraph>::scan(self, currp); + } + + // loops need pre-order visiting too + if ((*currp)->is<Loop>()) { + self->pushTask(LocalGraph::beforeLoop, currp); + } +} + +// helpers + +void LocalGraph::setUnreachable(Mapping& mapping) { + mapping.resize(numLocals); // may have been emptied by a move + mapping[0].clear(); +} + +bool LocalGraph::isUnreachable(Mapping& mapping) { + // we must have some set for each index, if only the zero init, so empty means we emptied it for unreachable code + return mapping[0].empty(); +} + +// merges a bunch of infos into one. +// if we need phis, writes them into the provided vector. the caller should +// ensure those are placed in the right location +LocalGraph::Mapping& LocalGraph::merge(std::vector<Mapping>& mappings) { + assert(mappings.size() > 0); + auto& out = mappings[0]; + if (mappings.size() == 1) { + return out; + } + // merge into the first + for (Index j = 1; j < mappings.size(); j++) { + auto& other = mappings[j]; + for (Index i = 0; i < numLocals; i++) { + auto& outSets = out[i]; + for (auto* set : other[i]) { + outSets.insert(set); + } + } + } + return out; +} + +} // namespace wasm + diff --git a/src/ast/find_all.h b/src/ast/find_all.h new file mode 100644 index 000000000..98fe4c5a7 --- /dev/null +++ b/src/ast/find_all.h @@ -0,0 +1,48 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ast_find_all_h +#define wasm_ast_find_all_h + +#include <wasm-traversal.h> + +namespace wasm { + +// Find all instances of a certain node type + +template<typename T> +struct FindAll { + std::vector<T*> list; + + FindAll(Expression* ast) { + struct Finder : public PostWalker<Finder, UnifiedExpressionVisitor<Finder>> { + std::vector<T*>* list; + void visitExpression(Expression* curr) { + if (curr->is<T>()) { + (*list).push_back(curr->cast<T>()); + } + } + }; + Finder finder; + finder.list = &list; + finder.walk(ast); + } +}; + +} // namespace wasm + +#endif // wasm_ast_find_all_h + diff --git a/src/ast/local-graph.h b/src/ast/local-graph.h new file mode 100644 index 000000000..03915da5e --- /dev/null +++ b/src/ast/local-graph.h @@ -0,0 +1,111 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ast_local_graph_h +#define wasm_ast_local_graph_h + +namespace wasm { + +// +// Finds the connections between get_locals and set_locals, creating +// a graph of those ties. This is useful for "ssa-style" optimization, +// in which you want to know exactly which sets are relevant for a +// a get, so it is as if each get has just one set, logically speaking +// (see the SSA pass for actually creating new local indexes based +// on this). +// +// TODO: the algorithm here is pretty simple, but also pretty slow, +// we should optimize it. e.g. we rely on set_interaction +// here, and worse we only use it to compute the size... +struct LocalGraph : public PostWalker<LocalGraph> { + // main API + + // the constructor computes getSetses, the sets affecting each get + LocalGraph(Function* func, Module* module); + + // the set_locals relevant for an index or a get. + typedef std::set<SetLocal*> Sets; + + // externally useful information + std::map<GetLocal*, Sets> getSetses; // the sets affecting each get. a nullptr set means the initial + // value (0 for a var, the received value for a param) + std::map<Expression*, Expression**> locations; // where each get and set is (for easy replacing) + + // optional computation: compute the influence graphs between sets and gets + // (useful for algorithms that propagate changes) + + std::unordered_map<GetLocal*, std::unordered_set<SetLocal*>> getInfluences; // for each get, the sets whose values are influenced by that get + std::unordered_map<SetLocal*, std::unordered_set<GetLocal*>> setInfluences; // for each set, the gets whose values are influenced by that set + + void computeInfluences(); + +private: + // we map local index => the set_locals for that index. + // a nullptr set means there is a virtual set, from a param + // initial value or the zero init initial value. + typedef std::vector<Sets> Mapping; + + // internal state + Index numLocals; + Mapping currMapping; + std::vector<Mapping> mappingStack; // used in ifs, loops + std::map<Name, std::vector<Mapping>> breakMappings; // break target => infos that reach it + std::vector<std::vector<GetLocal*>> loopGetStack; // stack of loops, all the gets in each, so we can update them for back branches + +public: + void doWalkFunction(Function* func); + + // control flow + + void visitBlock(Block* curr); + + void finishIf(); + + static void afterIfCondition(LocalGraph* self, Expression** currp); + static void afterIfTrue(LocalGraph* self, Expression** currp); + static void afterIfFalse(LocalGraph* self, Expression** currp); + static void beforeLoop(LocalGraph* self, Expression** currp); + void visitLoop(Loop* curr); + void visitBreak(Break* curr); + void visitSwitch(Switch* curr); + void visitReturn(Return *curr); + void visitUnreachable(Unreachable *curr); + + // local usage + + void visitGetLocal(GetLocal* curr); + void visitSetLocal(SetLocal* curr); + + // traversal + + static void scan(LocalGraph* self, Expression** currp); + + // helpers + + void setUnreachable(Mapping& mapping); + + bool isUnreachable(Mapping& mapping); + + // merges a bunch of infos into one. + // if we need phis, writes them into the provided vector. the caller should + // ensure those are placed in the right location + Mapping& merge(std::vector<Mapping>& mappings); +}; + +} // namespace wasm + +#endif // wasm_ast_local_graph_h + diff --git a/src/literal.h b/src/literal.h index 560895e7a..c55d645ad 100644 --- a/src/literal.h +++ b/src/literal.h @@ -44,7 +44,7 @@ private: return val & (sizeof(T) * 8 - 1); } - public: +public: Literal() : type(WasmType::none), i64(0) {} explicit Literal(WasmType type) : type(type), i64(0) {} explicit Literal(int32_t init) : type(WasmType::i32), i32(init) {} @@ -54,6 +54,9 @@ private: explicit Literal(float init) : type(WasmType::f32), i32(bit_cast<int32_t>(init)) {} explicit Literal(double init) : type(WasmType::f64), i64(bit_cast<int64_t>(init)) {} + bool isConcrete() { return type != none; } + bool isNull() { return type == none; } + Literal castToF32(); Literal castToF64(); Literal castToI32(); @@ -76,6 +79,7 @@ private: int64_t getBits() const; bool operator==(const Literal& other) const; bool operator!=(const Literal& other) const; + bool bitwiseEqual(const Literal& other) const; static uint32_t NaNPayload(float f); static uint64_t NaNPayload(double f); diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 192480d26..e5fdcbb9d 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -313,6 +313,7 @@ struct Inlining : public Pass { PassRunner runner(module, parentRunner->options); runner.setIsNested(true); runner.setValidateGlobally(false); // not a full valid module + runner.add("precompute-propagate"); runner.add("remove-unused-brs"); runner.add("remove-unused-names"); runner.add("coalesce-locals"); diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index c4702fdeb..72292c730 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -23,15 +23,24 @@ #include <wasm-builder.h> #include <wasm-interpreter.h> #include <ast_utils.h> -#include "ast/manipulation.h" +#include <ast/literal-utils.h> +#include <ast/local-graph.h> +#include <ast/manipulation.h> namespace wasm { static const Name NONSTANDALONE_FLOW("Binaryen|nonstandalone"); +typedef std::unordered_map<GetLocal*, Literal> GetValues; + // Execute an expression by itself. Errors if we hit anything we need anything not in the expression itself standalone. class StandaloneExpressionRunner : public ExpressionRunner<StandaloneExpressionRunner> { + // map gets to constant values, if they are known to be constant + GetValues& getValues; + public: + StandaloneExpressionRunner(GetValues& getValues) : getValues(getValues) {} + struct NonstandaloneException {}; // TODO: use a flow with a special name, as this is likely very slow Flow visitLoop(Loop* curr) { @@ -50,6 +59,13 @@ public: return Flow(NONSTANDALONE_FLOW); } Flow visitGetLocal(GetLocal *curr) { + auto iter = getValues.find(curr); + if (iter != getValues.end()) { + auto value = iter->second; + if (value.isConcrete()) { + return Flow(value); + } + } return Flow(NONSTANDALONE_FLOW); } Flow visitSetLocal(SetLocal *curr) { @@ -85,17 +101,30 @@ public: struct Precompute : public WalkerPass<PostWalker<Precompute, UnifiedExpressionVisitor<Precompute>>> { bool isFunctionParallel() override { return true; } - Pass* create() override { return new Precompute; } + Pass* create() override { return new Precompute(propagate); } + + bool propagate = false; + + Precompute(bool propagate) : propagate(propagate) {} + + GetValues getValues; + + void doWalkFunction(Function* func) { + // with extra effort, we can utilize the get-set graph to precompute + // things that use locals that are known to be constant. otherwise, + // we just look at what is immediately before us + if (propagate) { + optimizeLocals(func, getModule()); + } + // do the main and final walk over everything + WalkerPass<PostWalker<Precompute, UnifiedExpressionVisitor<Precompute>>>::doWalkFunction(func); + } void visitExpression(Expression* curr) { + // TODO: if get_local, only replace with a constant if we don't care about size...? if (curr->is<Const>() || curr->is<Nop>()) return; // try to evaluate this into a const - Flow flow; - try { - flow = StandaloneExpressionRunner().visit(curr); - } catch (StandaloneExpressionRunner::NonstandaloneException& e) { - return; - } + Flow flow = precomputeFlow(curr); if (flow.breaking()) { if (flow.breakTo == NONSTANDALONE_FLOW) return; if (flow.breakTo == RETURN_FLOW) { @@ -157,10 +186,111 @@ struct Precompute : public WalkerPass<PostWalker<Precompute, UnifiedExpressionVi // removing breaks can alter types ReFinalize().walkFunctionInModule(curr, getModule()); } + +private: + Flow precomputeFlow(Expression* curr) { + try { + return StandaloneExpressionRunner(getValues).visit(curr); + } catch (StandaloneExpressionRunner::NonstandaloneException& e) { + return Flow(NONSTANDALONE_FLOW); + } + } + + Literal precomputeValue(Expression* curr) { + Flow flow = precomputeFlow(curr); + if (flow.breaking()) { + return Literal(); + } + return flow.value; + } + + void optimizeLocals(Function* func, Module* module) { + // using the graph of get-set interactions, do a constant-propagation type + // operation: note which sets are assigned locals, then see if that lets us + // compute other sets as locals (since some of the gets they read may be + // constant). + // compute all dependencies + LocalGraph localGraph(func, module); + localGraph.computeInfluences(); + // prepare the work list. we add things here that might change to a constant + // initially, that means everything + std::unordered_set<Expression*> work; + for (auto& pair : localGraph.locations) { + auto* curr = pair.first; + work.insert(curr); + } + std::unordered_map<SetLocal*, Literal> setValues; // the constant value, or none if not a constant + // propagate constant values + while (!work.empty()) { + auto iter = work.begin(); + auto* curr = *iter; + work.erase(iter); + // see if this set or get is actually a constant value, and if so, + // mark it as such and add everything it influences to the work list, + // as they may be constant too. + if (auto* set = curr->dynCast<SetLocal>()) { + if (setValues[set].isConcrete()) continue; // already known constant + auto value = setValues[set] = precomputeValue(set->value); + if (value.isConcrete()) { + for (auto* get : localGraph.setInfluences[set]) { + work.insert(get); + } + } + } else { + auto* get = curr->cast<GetLocal>(); + if (getValues[get].isConcrete()) continue; // already known constant + // for this get to have constant value, all sets must agree + Literal value; + bool first = true; + for (auto* set : localGraph.getSetses[get]) { + Literal curr; + if (set == nullptr) { + if (getFunction()->isVar(get->index)) { + curr = LiteralUtils::makeLiteralZero(getFunction()->getLocalType(get->index)); + } else { + // it's a param, so it's hopeless + value = Literal(); + break; + } + } else { + curr = setValues[set]; + } + if (curr.isNull()) { + // not a constant, give up + value = Literal(); + break; + } + // we found a concrete value. compare with the current one + if (first) { + value = curr; // this is the first + first = false; + } else { + if (!value.bitwiseEqual(curr)) { + // not the same, give up + value = Literal(); + break; + } + } + } + // we may have found a value + if (value.isConcrete()) { + // we did! + getValues[get] = value; + for (auto* set : localGraph.getInfluences[get]) { + work.insert(set); + } + } + } + } + } }; Pass *createPrecomputePass() { - return new Precompute(); + return new Precompute(false); +} + +Pass *createPrecomputePropagatePass() { + return new Precompute(true); } } // namespace wasm diff --git a/src/passes/SSAify.cpp b/src/passes/SSAify.cpp index 9c9aeb573..a71a75308 100644 --- a/src/passes/SSAify.cpp +++ b/src/passes/SSAify.cpp @@ -35,6 +35,7 @@ #include "wasm-builder.h" #include "support/permutations.h" #include "ast/literal-utils.h" +#include "ast/local-graph.h" namespace wasm { @@ -44,261 +45,39 @@ SetLocal IMPOSSIBLE_SET; // Tracks assignments to locals, assuming single-assignment form, i.e., // each assignment creates a new variable. -struct SSAify : public WalkerPass<PostWalker<SSAify>> { +struct SSAify : public Pass { bool isFunctionParallel() override { return true; } Pass* create() override { return new SSAify; } - // the set_locals relevant for an index or a get. we use - // as set as merges of control flow mean more than 1 may - // be relevant; we create a phi on demand when necessary for those - typedef std::set<SetLocal*> Sets; - - // we map (old local index) => the set_locals for that index. - // a nullptr set means there is a virtual set, from a param - // initial value or the zero init initial value. - typedef std::vector<Sets> Mapping; - - Index numLocals; - Mapping currMapping; - Index nextIndex; - std::vector<Mapping> mappingStack; // used in ifs, loops - std::map<Name, std::vector<Mapping>> breakMappings; // break target => infos that reach it - std::vector<std::vector<GetLocal*>> loopGetStack; // stack of loops, all the gets in each, so we can update them for back branches + Module* module; + Function* func; std::vector<Expression*> functionPrepends; // things we add to the function prologue - std::map<GetLocal*, Sets> getSetses; // the sets for each get - std::map<GetLocal*, Expression**> getLocations; - void doWalkFunction(Function* func) { - numLocals = func->getNumLocals(); - if (numLocals == 0) return; // nothing to do - // We begin with each param being assigned from the incoming value, and the zero-init for the locals, - // so the initial state is the identity permutation - currMapping.resize(numLocals); - for (auto& set : currMapping) { - set = { nullptr }; - } - nextIndex = numLocals; - WalkerPass<PostWalker<SSAify>>::walk(func->body); - // apply - we now know the sets for each get - computeGetsAndPhis(); - // add prepends - if (functionPrepends.size() > 0) { - Builder builder(*getModule()); - auto* block = builder.makeBlock(); - for (auto* pre : functionPrepends) { - block->list.push_back(pre); + void runFunction(PassRunner* runner, Module* module_, Function* func_) override { + module = module_; + func = func_; + LocalGraph graph(func, module); + // create new local indexes, one for each set + createNewIndexes(graph); + // we now know the sets for each get + computeGetsAndPhis(graph); + // add prepends to function + addPrepends(); + } + + void createNewIndexes(LocalGraph& graph) { + for (auto& pair : graph.locations) { + auto* curr = pair.first; + if (auto* set = curr->dynCast<SetLocal>()) { + set->index = addLocal(func->getLocalType(set->index)); } - block->list.push_back(func->body); - block->finalize(func->body->type); - func->body = block; } } - // control flow - - void visitBlock(Block* curr) { - if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { - auto& infos = breakMappings[curr->name]; - infos.emplace_back(std::move(currMapping)); - currMapping = std::move(merge(infos)); - breakMappings.erase(curr->name); - } - } - - void finishIf() { - // that's it for this if, merge - std::vector<Mapping> breaks; - breaks.emplace_back(std::move(currMapping)); - breaks.emplace_back(std::move(mappingStack.back())); - mappingStack.pop_back(); - currMapping = std::move(merge(breaks)); - } - - static void afterIfCondition(SSAify* self, Expression** currp) { - self->mappingStack.push_back(self->currMapping); - } - static void afterIfTrue(SSAify* self, Expression** currp) { - auto* curr = (*currp)->cast<If>(); - if (curr->ifFalse) { - auto afterCondition = std::move(self->mappingStack.back()); - self->mappingStack.back() = std::move(self->currMapping); - self->currMapping = std::move(afterCondition); - } else { - self->finishIf(); - } - } - static void afterIfFalse(SSAify* self, Expression** currp) { - self->finishIf(); - } - static void beforeLoop(SSAify* self, Expression** currp) { - // save the state before entering the loop, for calculation later of the merge at the loop top - self->mappingStack.push_back(self->currMapping); - self->loopGetStack.push_back({}); - } - void visitLoop(Loop* curr) { - if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { - auto& infos = breakMappings[curr->name]; - infos.emplace_back(std::move(mappingStack.back())); - auto before = infos.back(); - auto& merged = merge(infos); - // every local we created a phi for requires us to update get_local operations in - // the loop - the branch back has means that gets in the loop have potentially - // more sets reaching them. - // we can detect this as follows: if a get of oldIndex has the same sets - // as the sets at the entrance to the loop, then it is affected by the loop - // header sets, and we can add to there sets that looped back - auto linkLoopTop = [&](Index i, Sets& getSets) { - auto& beforeSets = before[i]; - if (getSets.size() < beforeSets.size()) { - // the get trivially has fewer sets, so it overrode the loop entry sets - return; - } - std::vector<SetLocal*> intersection; - std::set_intersection(beforeSets.begin(), beforeSets.end(), - getSets.begin(), getSets.end(), - std::back_inserter(intersection)); - if (intersection.size() < beforeSets.size()) { - // the get has not the same sets as in the loop entry - return; - } - // the get has the entry sets, so add any new ones - for (auto* set : merged[i]) { - getSets.insert(set); - } - }; - auto& gets = loopGetStack.back(); - for (auto* get : gets) { - linkLoopTop(get->index, getSetses[get]); - } - // and the same for the loop fallthrough: any local that still has the - // entry sets should also have the loop-back sets as well - for (Index i = 0; i < numLocals; i++) { - linkLoopTop(i, currMapping[i]); - } - // finally, breaks still in flight must be updated too - for (auto& iter : breakMappings) { - auto name = iter.first; - if (name == curr->name) continue; // skip our own (which is still in use) - auto& mappings = iter.second; - for (auto& mapping : mappings) { - for (Index i = 0; i < numLocals; i++) { - linkLoopTop(i, mapping[i]); - } - } - } - // now that we are done with using the mappings, erase our own - breakMappings.erase(curr->name); - } - mappingStack.pop_back(); - loopGetStack.pop_back(); - } - void visitBreak(Break* curr) { - if (curr->condition) { - breakMappings[curr->name].emplace_back(currMapping); - } else { - breakMappings[curr->name].emplace_back(std::move(currMapping)); - setUnreachable(currMapping); - } - } - void visitSwitch(Switch* curr) { - std::set<Name> all; - for (auto target : curr->targets) { - all.insert(target); - } - all.insert(curr->default_); - for (auto target : all) { - breakMappings[target].emplace_back(currMapping); - } - setUnreachable(currMapping); - } - void visitReturn(Return *curr) { - setUnreachable(currMapping); - } - void visitUnreachable(Unreachable *curr) { - setUnreachable(currMapping); - } - - // local usage - - void visitGetLocal(GetLocal* curr) { - assert(currMapping.size() == numLocals); - assert(curr->index < numLocals); - for (auto& loopGets : loopGetStack) { - loopGets.push_back(curr); - } - // current sets are our sets - getSetses[curr] = currMapping[curr->index]; - getLocations[curr] = getCurrentPointer(); - } - void visitSetLocal(SetLocal* curr) { - assert(currMapping.size() == numLocals); - assert(curr->index < numLocals); - // current sets are just this set - currMapping[curr->index] = { curr }; // TODO optimize? - curr->index = addLocal(getFunction()->getLocalType(curr->index)); - } - - // traversal - - static void scan(SSAify* self, Expression** currp) { - if (auto* iff = (*currp)->dynCast<If>()) { - // if needs special handling - if (iff->ifFalse) { - self->pushTask(SSAify::afterIfFalse, currp); - self->pushTask(SSAify::scan, &iff->ifFalse); - } - self->pushTask(SSAify::afterIfTrue, currp); - self->pushTask(SSAify::scan, &iff->ifTrue); - self->pushTask(SSAify::afterIfCondition, currp); - self->pushTask(SSAify::scan, &iff->condition); - } else { - WalkerPass<PostWalker<SSAify>>::scan(self, currp); - } - - // loops need pre-order visiting too - if ((*currp)->is<Loop>()) { - self->pushTask(SSAify::beforeLoop, currp); - } - } - - // helpers - - void setUnreachable(Mapping& mapping) { - mapping.resize(numLocals); // may have been emptied by a move - mapping[0].clear(); - } - - bool isUnreachable(Mapping& mapping) { - // we must have some set for each index, if only the zero init, so empty means we emptied it for unreachable code - return mapping[0].empty(); - } - - // merges a bunch of infos into one. - // if we need phis, writes them into the provided vector. the caller should - // ensure those are placed in the right location - Mapping& merge(std::vector<Mapping>& mappings) { - assert(mappings.size() > 0); - auto& out = mappings[0]; - if (mappings.size() == 1) { - return out; - } - // merge into the first - for (Index j = 1; j < mappings.size(); j++) { - auto& other = mappings[j]; - for (Index i = 0; i < numLocals; i++) { - auto& outSets = out[i]; - for (auto* set : other[i]) { - outSets.insert(set); - } - } - } - return out; - } - // After we traversed it all, we can compute gets and phis - void computeGetsAndPhis() { - for (auto& iter : getSetses) { + void computeGetsAndPhis(LocalGraph& graph) { + for (auto& iter : graph.getSetses) { auto* get = iter.first; auto& sets = iter.second; if (sets.size() == 0) { @@ -312,11 +91,11 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> { get->index = set->index; } else { // no set, assign param or zero - if (getFunction()->isParam(get->index)) { + if (func->isParam(get->index)) { // leave it, it's fine } else { // zero it out - (*getLocations[get]) = LiteralUtils::makeZero(get->type, *getModule()); + (*graph.locations[get]) = LiteralUtils::makeZero(get->type, *module); } } continue; @@ -354,7 +133,7 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> { auto new_ = addLocal(get->type); auto old = get->index; get->index = new_; - Builder builder(*getModule()); + Builder builder(*module); // write to the local in each of our sets for (auto* set : sets) { if (set) { @@ -365,12 +144,12 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> { ); } else { // this is a param or the zero init value. - if (getFunction()->isParam(old)) { + if (func->isParam(old)) { // we add a set with the proper // param value at the beginning of the function auto* set = builder.makeSetLocal( new_, - builder.makeGetLocal(old, getFunction()->getLocalType(old)) + builder.makeGetLocal(old, func->getLocalType(old)) ); functionPrepends.push_back(set); } else { @@ -383,7 +162,20 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> { } Index addLocal(WasmType type) { - return Builder::addVar(getFunction(), type); + return Builder::addVar(func, type); + } + + void addPrepends() { + if (functionPrepends.size() > 0) { + Builder builder(*module); + auto* block = builder.makeBlock(); + for (auto* pre : functionPrepends) { + block->list.push_back(pre); + } + block->list.push_back(func->body); + block->finalize(func->body->type); + func->body = block; + } } }; diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 4ec454a50..a208a03dd 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -89,6 +89,7 @@ void PassRegistry::registerPasses() { registerPass("pick-load-signs", "pick load signs based on their uses", createPickLoadSignsPass); registerPass("post-emscripten", "miscellaneous optimizations for Emscripten-generated code", createPostEmscriptenPass); registerPass("precompute", "computes compile-time evaluatable expressions", createPrecomputePass); + registerPass("precompute-propagate", "computes compile-time evaluatable expressions and propagates them through locals", createPrecomputePropagatePass); registerPass("print", "print in s-expression format", createPrinterPass); registerPass("print-minified", "print in minified s-expression format", createMinifiedPrinterPass); registerPass("print-full", "print in full s-expression format", createFullPrinterPass); @@ -148,7 +149,12 @@ void PassRunner::addDefaultFunctionOptimizationPasses() { add("remove-unused-brs"); // coalesce-locals opens opportunities for optimizations add("merge-blocks"); // clean up remove-unused-brs new blocks add("optimize-instructions"); - add("precompute"); + // if we are willing to work hard, also propagate + if (options.optimizeLevel >= 3 || options.shrinkLevel >= 2) { + add("precompute-propagate"); + } else { + add("precompute"); + } if (options.shrinkLevel >= 2) { add("local-cse"); // TODO: run this early, before first coalesce-locals. right now doing so uncovers some deficiencies we need to fix first add("coalesce-locals"); // just for localCSE diff --git a/src/passes/passes.h b/src/passes/passes.h index 4e039f4bf..a02216083 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -49,6 +49,7 @@ Pass *createOptimizeInstructionsPass(); Pass *createPickLoadSignsPass(); Pass *createPostEmscriptenPass(); Pass *createPrecomputePass(); +Pass *createPrecomputePropagatePass(); Pass *createPrinterPass(); Pass *createPrintCallGraphPass(); Pass *createRelooperJumpThreadingPass(); diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h index 5ed0ff01f..1fee4ffb0 100644 --- a/src/tools/execution-results.h +++ b/src/tools/execution-results.h @@ -23,18 +23,6 @@ namespace wasm { -static bool areBitwiseEqual(Literal a, Literal b) { - if (a == b) return true; - // accept equal nans if equal in all bits - if (a.type != b.type) return false; - if (a.type == f32) { - return a.reinterpreti32() == b.reinterpreti32(); - } else if (a.type == f64) { - return a.reinterpreti64() == b.reinterpreti64(); - } - return false; -} - // gets execution results from a wasm module. this is useful for fuzzing // // we can only get results when there are no imports. we then call each method @@ -86,7 +74,7 @@ struct ExecutionResults { abort(); } std::cout << "[fuzz-exec] comparing " << name << '\n'; - if (!areBitwiseEqual(results[name], other.results[name])) { + if (!results[name].bitwiseEqual(other.results[name])) { std::cout << "not identical!\n"; abort(); } diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index b7fde3af6..cd8c27437 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -201,14 +201,14 @@ static void run_asserts(Name moduleName, size_t* i, bool* checked, Module* wasm, ->dynCast<Const>() ->value; std::cerr << "seen " << result << ", expected " << expected << '\n'; - if (!areBitwiseEqual(expected, result)) { + if (!expected.bitwiseEqual(result)) { std::cout << "unexpected, should be identical\n"; abort(); } } else { Literal expected; std::cerr << "seen " << result << ", expected " << expected << '\n'; - if (!areBitwiseEqual(expected, result)) { + if (!expected.bitwiseEqual(result)) { std::cout << "unexpected, should be identical\n"; abort(); } diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index be1363dde..de89a7ac1 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -93,6 +93,12 @@ bool Literal::operator!=(const Literal& other) const { return !(*this == other); } +bool Literal::bitwiseEqual(const Literal& other) const { + if (type != other.type) return false; + if (type == none) return true; + return getBits() == other.getBits(); +} + uint32_t Literal::NaNPayload(float f) { assert(std::isnan(f) && "expected a NaN"); // SEEEEEEE EFFFFFFF FFFFFFFF FFFFFFFF |