summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast/CMakeLists.txt1
-rw-r--r--src/ast/LocalGraph.cpp260
-rw-r--r--src/ast/find_all.h48
-rw-r--r--src/ast/local-graph.h111
-rw-r--r--src/literal.h6
-rw-r--r--src/passes/Inlining.cpp1
-rw-r--r--src/passes/Precompute.cpp148
-rw-r--r--src/passes/SSAify.cpp292
-rw-r--r--src/passes/pass.cpp8
-rw-r--r--src/passes/passes.h1
-rw-r--r--src/tools/execution-results.h14
-rw-r--r--src/tools/wasm-shell.cpp4
-rw-r--r--src/wasm/literal.cpp6
13 files changed, 624 insertions, 276 deletions
diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt
index e48e84eed..c01deaaaf 100644
--- a/src/ast/CMakeLists.txt
+++ b/src/ast/CMakeLists.txt
@@ -1,5 +1,6 @@
SET(ast_SOURCES
ExpressionAnalyzer.cpp
ExpressionManipulator.cpp
+ LocalGraph.cpp
)
ADD_LIBRARY(ast STATIC ${ast_SOURCES})
diff --git a/src/ast/LocalGraph.cpp b/src/ast/LocalGraph.cpp
new file mode 100644
index 000000000..c997eff1b
--- /dev/null
+++ b/src/ast/LocalGraph.cpp
@@ -0,0 +1,260 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iterator>
+
+#include <wasm-builder.h>
+#include <ast/find_all.h>
+#include <ast/local-graph.h>
+
+namespace wasm {
+
+LocalGraph::LocalGraph(Function* func, Module* module) {
+ walkFunctionInModule(func, module);
+}
+
+void LocalGraph::computeInfluences() {
+ for (auto& pair : locations) {
+ auto* curr = pair.first;
+ if (auto* set = curr->dynCast<SetLocal>()) {
+ FindAll<GetLocal> findAll(set->value);
+ for (auto* get : findAll.list) {
+ getInfluences[get].insert(set);
+ }
+ } else {
+ auto* get = curr->cast<GetLocal>();
+ for (auto* set : getSetses[get]) {
+ setInfluences[set].insert(get);
+ }
+ }
+ }
+}
+
+void LocalGraph::doWalkFunction(Function* func) {
+ numLocals = func->getNumLocals();
+ if (numLocals == 0) return; // nothing to do
+ // We begin with each param being assigned from the incoming value, and the zero-init for the locals,
+ // so the initial state is the identity permutation
+ currMapping.resize(numLocals);
+ for (auto& set : currMapping) {
+ set = { nullptr };
+ }
+ PostWalker<LocalGraph>::walk(func->body);
+}
+
+// control flow
+
+void LocalGraph::visitBlock(Block* curr) {
+ if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) {
+ auto& infos = breakMappings[curr->name];
+ infos.emplace_back(std::move(currMapping));
+ currMapping = std::move(merge(infos));
+ breakMappings.erase(curr->name);
+ }
+}
+
+void LocalGraph::finishIf() {
+ // that's it for this if, merge
+ std::vector<Mapping> breaks;
+ breaks.emplace_back(std::move(currMapping));
+ breaks.emplace_back(std::move(mappingStack.back()));
+ mappingStack.pop_back();
+ currMapping = std::move(merge(breaks));
+}
+
+void LocalGraph::afterIfCondition(LocalGraph* self, Expression** currp) {
+ self->mappingStack.push_back(self->currMapping);
+}
+void LocalGraph::afterIfTrue(LocalGraph* self, Expression** currp) {
+ auto* curr = (*currp)->cast<If>();
+ if (curr->ifFalse) {
+ auto afterCondition = std::move(self->mappingStack.back());
+ self->mappingStack.back() = std::move(self->currMapping);
+ self->currMapping = std::move(afterCondition);
+ } else {
+ self->finishIf();
+ }
+}
+void LocalGraph::afterIfFalse(LocalGraph* self, Expression** currp) {
+ self->finishIf();
+}
+void LocalGraph::beforeLoop(LocalGraph* self, Expression** currp) {
+ // save the state before entering the loop, for calculation later of the merge at the loop top
+ self->mappingStack.push_back(self->currMapping);
+ self->loopGetStack.push_back({});
+}
+void LocalGraph::visitLoop(Loop* curr) {
+ if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) {
+ auto& infos = breakMappings[curr->name];
+ infos.emplace_back(std::move(mappingStack.back()));
+ auto before = infos.back();
+ auto& merged = merge(infos);
+ // every local we created a phi for requires us to update get_local operations in
+ // the loop - the branch back has means that gets in the loop have potentially
+ // more sets reaching them.
+ // we can detect this as follows: if a get of oldIndex has the same sets
+ // as the sets at the entrance to the loop, then it is affected by the loop
+ // header sets, and we can add to there sets that looped back
+ auto linkLoopTop = [&](Index i, Sets& getSets) {
+ auto& beforeSets = before[i];
+ if (getSets.size() < beforeSets.size()) {
+ // the get trivially has fewer sets, so it overrode the loop entry sets
+ return;
+ }
+ std::vector<SetLocal*> intersection;
+ std::set_intersection(beforeSets.begin(), beforeSets.end(),
+ getSets.begin(), getSets.end(),
+ std::back_inserter(intersection));
+ if (intersection.size() < beforeSets.size()) {
+ // the get has not the same sets as in the loop entry
+ return;
+ }
+ // the get has the entry sets, so add any new ones
+ for (auto* set : merged[i]) {
+ getSets.insert(set);
+ }
+ };
+ auto& gets = loopGetStack.back();
+ for (auto* get : gets) {
+ linkLoopTop(get->index, getSetses[get]);
+ }
+ // and the same for the loop fallthrough: any local that still has the
+ // entry sets should also have the loop-back sets as well
+ for (Index i = 0; i < numLocals; i++) {
+ linkLoopTop(i, currMapping[i]);
+ }
+ // finally, breaks still in flight must be updated too
+ for (auto& iter : breakMappings) {
+ auto name = iter.first;
+ if (name == curr->name) continue; // skip our own (which is still in use)
+ auto& mappings = iter.second;
+ for (auto& mapping : mappings) {
+ for (Index i = 0; i < numLocals; i++) {
+ linkLoopTop(i, mapping[i]);
+ }
+ }
+ }
+ // now that we are done with using the mappings, erase our own
+ breakMappings.erase(curr->name);
+ }
+ mappingStack.pop_back();
+ loopGetStack.pop_back();
+}
+void LocalGraph::visitBreak(Break* curr) {
+ if (curr->condition) {
+ breakMappings[curr->name].emplace_back(currMapping);
+ } else {
+ breakMappings[curr->name].emplace_back(std::move(currMapping));
+ setUnreachable(currMapping);
+ }
+}
+void LocalGraph::visitSwitch(Switch* curr) {
+ std::set<Name> all;
+ for (auto target : curr->targets) {
+ all.insert(target);
+ }
+ all.insert(curr->default_);
+ for (auto target : all) {
+ breakMappings[target].emplace_back(currMapping);
+ }
+ setUnreachable(currMapping);
+}
+void LocalGraph::visitReturn(Return *curr) {
+ setUnreachable(currMapping);
+}
+void LocalGraph::visitUnreachable(Unreachable *curr) {
+ setUnreachable(currMapping);
+}
+
+// local usage
+
+void LocalGraph::visitGetLocal(GetLocal* curr) {
+ assert(currMapping.size() == numLocals);
+ assert(curr->index < numLocals);
+ for (auto& loopGets : loopGetStack) {
+ loopGets.push_back(curr);
+ }
+ // current sets are our sets
+ getSetses[curr] = currMapping[curr->index];
+ locations[curr] = getCurrentPointer();
+}
+void LocalGraph::visitSetLocal(SetLocal* curr) {
+ assert(currMapping.size() == numLocals);
+ assert(curr->index < numLocals);
+ // current sets are just this set
+ currMapping[curr->index] = { curr }; // TODO optimize?
+ locations[curr] = getCurrentPointer();
+}
+
+// traversal
+
+void LocalGraph::scan(LocalGraph* self, Expression** currp) {
+ if (auto* iff = (*currp)->dynCast<If>()) {
+ // if needs special handling
+ if (iff->ifFalse) {
+ self->pushTask(LocalGraph::afterIfFalse, currp);
+ self->pushTask(LocalGraph::scan, &iff->ifFalse);
+ }
+ self->pushTask(LocalGraph::afterIfTrue, currp);
+ self->pushTask(LocalGraph::scan, &iff->ifTrue);
+ self->pushTask(LocalGraph::afterIfCondition, currp);
+ self->pushTask(LocalGraph::scan, &iff->condition);
+ } else {
+ PostWalker<LocalGraph>::scan(self, currp);
+ }
+
+ // loops need pre-order visiting too
+ if ((*currp)->is<Loop>()) {
+ self->pushTask(LocalGraph::beforeLoop, currp);
+ }
+}
+
+// helpers
+
+void LocalGraph::setUnreachable(Mapping& mapping) {
+ mapping.resize(numLocals); // may have been emptied by a move
+ mapping[0].clear();
+}
+
+bool LocalGraph::isUnreachable(Mapping& mapping) {
+ // we must have some set for each index, if only the zero init, so empty means we emptied it for unreachable code
+ return mapping[0].empty();
+}
+
+// merges a bunch of infos into one.
+// if we need phis, writes them into the provided vector. the caller should
+// ensure those are placed in the right location
+LocalGraph::Mapping& LocalGraph::merge(std::vector<Mapping>& mappings) {
+ assert(mappings.size() > 0);
+ auto& out = mappings[0];
+ if (mappings.size() == 1) {
+ return out;
+ }
+ // merge into the first
+ for (Index j = 1; j < mappings.size(); j++) {
+ auto& other = mappings[j];
+ for (Index i = 0; i < numLocals; i++) {
+ auto& outSets = out[i];
+ for (auto* set : other[i]) {
+ outSets.insert(set);
+ }
+ }
+ }
+ return out;
+}
+
+} // namespace wasm
+
diff --git a/src/ast/find_all.h b/src/ast/find_all.h
new file mode 100644
index 000000000..98fe4c5a7
--- /dev/null
+++ b/src/ast/find_all.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ast_find_all_h
+#define wasm_ast_find_all_h
+
+#include <wasm-traversal.h>
+
+namespace wasm {
+
+// Find all instances of a certain node type
+
+template<typename T>
+struct FindAll {
+ std::vector<T*> list;
+
+ FindAll(Expression* ast) {
+ struct Finder : public PostWalker<Finder, UnifiedExpressionVisitor<Finder>> {
+ std::vector<T*>* list;
+ void visitExpression(Expression* curr) {
+ if (curr->is<T>()) {
+ (*list).push_back(curr->cast<T>());
+ }
+ }
+ };
+ Finder finder;
+ finder.list = &list;
+ finder.walk(ast);
+ }
+};
+
+} // namespace wasm
+
+#endif // wasm_ast_find_all_h
+
diff --git a/src/ast/local-graph.h b/src/ast/local-graph.h
new file mode 100644
index 000000000..03915da5e
--- /dev/null
+++ b/src/ast/local-graph.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ast_local_graph_h
+#define wasm_ast_local_graph_h
+
+namespace wasm {
+
+//
+// Finds the connections between get_locals and set_locals, creating
+// a graph of those ties. This is useful for "ssa-style" optimization,
+// in which you want to know exactly which sets are relevant for a
+// a get, so it is as if each get has just one set, logically speaking
+// (see the SSA pass for actually creating new local indexes based
+// on this).
+//
+// TODO: the algorithm here is pretty simple, but also pretty slow,
+// we should optimize it. e.g. we rely on set_interaction
+// here, and worse we only use it to compute the size...
+struct LocalGraph : public PostWalker<LocalGraph> {
+ // main API
+
+ // the constructor computes getSetses, the sets affecting each get
+ LocalGraph(Function* func, Module* module);
+
+ // the set_locals relevant for an index or a get.
+ typedef std::set<SetLocal*> Sets;
+
+ // externally useful information
+ std::map<GetLocal*, Sets> getSetses; // the sets affecting each get. a nullptr set means the initial
+ // value (0 for a var, the received value for a param)
+ std::map<Expression*, Expression**> locations; // where each get and set is (for easy replacing)
+
+ // optional computation: compute the influence graphs between sets and gets
+ // (useful for algorithms that propagate changes)
+
+ std::unordered_map<GetLocal*, std::unordered_set<SetLocal*>> getInfluences; // for each get, the sets whose values are influenced by that get
+ std::unordered_map<SetLocal*, std::unordered_set<GetLocal*>> setInfluences; // for each set, the gets whose values are influenced by that set
+
+ void computeInfluences();
+
+private:
+ // we map local index => the set_locals for that index.
+ // a nullptr set means there is a virtual set, from a param
+ // initial value or the zero init initial value.
+ typedef std::vector<Sets> Mapping;
+
+ // internal state
+ Index numLocals;
+ Mapping currMapping;
+ std::vector<Mapping> mappingStack; // used in ifs, loops
+ std::map<Name, std::vector<Mapping>> breakMappings; // break target => infos that reach it
+ std::vector<std::vector<GetLocal*>> loopGetStack; // stack of loops, all the gets in each, so we can update them for back branches
+
+public:
+ void doWalkFunction(Function* func);
+
+ // control flow
+
+ void visitBlock(Block* curr);
+
+ void finishIf();
+
+ static void afterIfCondition(LocalGraph* self, Expression** currp);
+ static void afterIfTrue(LocalGraph* self, Expression** currp);
+ static void afterIfFalse(LocalGraph* self, Expression** currp);
+ static void beforeLoop(LocalGraph* self, Expression** currp);
+ void visitLoop(Loop* curr);
+ void visitBreak(Break* curr);
+ void visitSwitch(Switch* curr);
+ void visitReturn(Return *curr);
+ void visitUnreachable(Unreachable *curr);
+
+ // local usage
+
+ void visitGetLocal(GetLocal* curr);
+ void visitSetLocal(SetLocal* curr);
+
+ // traversal
+
+ static void scan(LocalGraph* self, Expression** currp);
+
+ // helpers
+
+ void setUnreachable(Mapping& mapping);
+
+ bool isUnreachable(Mapping& mapping);
+
+ // merges a bunch of infos into one.
+ // if we need phis, writes them into the provided vector. the caller should
+ // ensure those are placed in the right location
+ Mapping& merge(std::vector<Mapping>& mappings);
+};
+
+} // namespace wasm
+
+#endif // wasm_ast_local_graph_h
+
diff --git a/src/literal.h b/src/literal.h
index 560895e7a..c55d645ad 100644
--- a/src/literal.h
+++ b/src/literal.h
@@ -44,7 +44,7 @@ private:
return val & (sizeof(T) * 8 - 1);
}
- public:
+public:
Literal() : type(WasmType::none), i64(0) {}
explicit Literal(WasmType type) : type(type), i64(0) {}
explicit Literal(int32_t init) : type(WasmType::i32), i32(init) {}
@@ -54,6 +54,9 @@ private:
explicit Literal(float init) : type(WasmType::f32), i32(bit_cast<int32_t>(init)) {}
explicit Literal(double init) : type(WasmType::f64), i64(bit_cast<int64_t>(init)) {}
+ bool isConcrete() { return type != none; }
+ bool isNull() { return type == none; }
+
Literal castToF32();
Literal castToF64();
Literal castToI32();
@@ -76,6 +79,7 @@ private:
int64_t getBits() const;
bool operator==(const Literal& other) const;
bool operator!=(const Literal& other) const;
+ bool bitwiseEqual(const Literal& other) const;
static uint32_t NaNPayload(float f);
static uint64_t NaNPayload(double f);
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp
index 192480d26..e5fdcbb9d 100644
--- a/src/passes/Inlining.cpp
+++ b/src/passes/Inlining.cpp
@@ -313,6 +313,7 @@ struct Inlining : public Pass {
PassRunner runner(module, parentRunner->options);
runner.setIsNested(true);
runner.setValidateGlobally(false); // not a full valid module
+ runner.add("precompute-propagate");
runner.add("remove-unused-brs");
runner.add("remove-unused-names");
runner.add("coalesce-locals");
diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp
index c4702fdeb..72292c730 100644
--- a/src/passes/Precompute.cpp
+++ b/src/passes/Precompute.cpp
@@ -23,15 +23,24 @@
#include <wasm-builder.h>
#include <wasm-interpreter.h>
#include <ast_utils.h>
-#include "ast/manipulation.h"
+#include <ast/literal-utils.h>
+#include <ast/local-graph.h>
+#include <ast/manipulation.h>
namespace wasm {
static const Name NONSTANDALONE_FLOW("Binaryen|nonstandalone");
+typedef std::unordered_map<GetLocal*, Literal> GetValues;
+
// Execute an expression by itself. Errors if we hit anything we need anything not in the expression itself standalone.
class StandaloneExpressionRunner : public ExpressionRunner<StandaloneExpressionRunner> {
+ // map gets to constant values, if they are known to be constant
+ GetValues& getValues;
+
public:
+ StandaloneExpressionRunner(GetValues& getValues) : getValues(getValues) {}
+
struct NonstandaloneException {}; // TODO: use a flow with a special name, as this is likely very slow
Flow visitLoop(Loop* curr) {
@@ -50,6 +59,13 @@ public:
return Flow(NONSTANDALONE_FLOW);
}
Flow visitGetLocal(GetLocal *curr) {
+ auto iter = getValues.find(curr);
+ if (iter != getValues.end()) {
+ auto value = iter->second;
+ if (value.isConcrete()) {
+ return Flow(value);
+ }
+ }
return Flow(NONSTANDALONE_FLOW);
}
Flow visitSetLocal(SetLocal *curr) {
@@ -85,17 +101,30 @@ public:
struct Precompute : public WalkerPass<PostWalker<Precompute, UnifiedExpressionVisitor<Precompute>>> {
bool isFunctionParallel() override { return true; }
- Pass* create() override { return new Precompute; }
+ Pass* create() override { return new Precompute(propagate); }
+
+ bool propagate = false;
+
+ Precompute(bool propagate) : propagate(propagate) {}
+
+ GetValues getValues;
+
+ void doWalkFunction(Function* func) {
+ // with extra effort, we can utilize the get-set graph to precompute
+ // things that use locals that are known to be constant. otherwise,
+ // we just look at what is immediately before us
+ if (propagate) {
+ optimizeLocals(func, getModule());
+ }
+ // do the main and final walk over everything
+ WalkerPass<PostWalker<Precompute, UnifiedExpressionVisitor<Precompute>>>::doWalkFunction(func);
+ }
void visitExpression(Expression* curr) {
+ // TODO: if get_local, only replace with a constant if we don't care about size...?
if (curr->is<Const>() || curr->is<Nop>()) return;
// try to evaluate this into a const
- Flow flow;
- try {
- flow = StandaloneExpressionRunner().visit(curr);
- } catch (StandaloneExpressionRunner::NonstandaloneException& e) {
- return;
- }
+ Flow flow = precomputeFlow(curr);
if (flow.breaking()) {
if (flow.breakTo == NONSTANDALONE_FLOW) return;
if (flow.breakTo == RETURN_FLOW) {
@@ -157,10 +186,111 @@ struct Precompute : public WalkerPass<PostWalker<Precompute, UnifiedExpressionVi
// removing breaks can alter types
ReFinalize().walkFunctionInModule(curr, getModule());
}
+
+private:
+ Flow precomputeFlow(Expression* curr) {
+ try {
+ return StandaloneExpressionRunner(getValues).visit(curr);
+ } catch (StandaloneExpressionRunner::NonstandaloneException& e) {
+ return Flow(NONSTANDALONE_FLOW);
+ }
+ }
+
+ Literal precomputeValue(Expression* curr) {
+ Flow flow = precomputeFlow(curr);
+ if (flow.breaking()) {
+ return Literal();
+ }
+ return flow.value;
+ }
+
+ void optimizeLocals(Function* func, Module* module) {
+ // using the graph of get-set interactions, do a constant-propagation type
+ // operation: note which sets are assigned locals, then see if that lets us
+ // compute other sets as locals (since some of the gets they read may be
+ // constant).
+ // compute all dependencies
+ LocalGraph localGraph(func, module);
+ localGraph.computeInfluences();
+ // prepare the work list. we add things here that might change to a constant
+ // initially, that means everything
+ std::unordered_set<Expression*> work;
+ for (auto& pair : localGraph.locations) {
+ auto* curr = pair.first;
+ work.insert(curr);
+ }
+ std::unordered_map<SetLocal*, Literal> setValues; // the constant value, or none if not a constant
+ // propagate constant values
+ while (!work.empty()) {
+ auto iter = work.begin();
+ auto* curr = *iter;
+ work.erase(iter);
+ // see if this set or get is actually a constant value, and if so,
+ // mark it as such and add everything it influences to the work list,
+ // as they may be constant too.
+ if (auto* set = curr->dynCast<SetLocal>()) {
+ if (setValues[set].isConcrete()) continue; // already known constant
+ auto value = setValues[set] = precomputeValue(set->value);
+ if (value.isConcrete()) {
+ for (auto* get : localGraph.setInfluences[set]) {
+ work.insert(get);
+ }
+ }
+ } else {
+ auto* get = curr->cast<GetLocal>();
+ if (getValues[get].isConcrete()) continue; // already known constant
+ // for this get to have constant value, all sets must agree
+ Literal value;
+ bool first = true;
+ for (auto* set : localGraph.getSetses[get]) {
+ Literal curr;
+ if (set == nullptr) {
+ if (getFunction()->isVar(get->index)) {
+ curr = LiteralUtils::makeLiteralZero(getFunction()->getLocalType(get->index));
+ } else {
+ // it's a param, so it's hopeless
+ value = Literal();
+ break;
+ }
+ } else {
+ curr = setValues[set];
+ }
+ if (curr.isNull()) {
+ // not a constant, give up
+ value = Literal();
+ break;
+ }
+ // we found a concrete value. compare with the current one
+ if (first) {
+ value = curr; // this is the first
+ first = false;
+ } else {
+ if (!value.bitwiseEqual(curr)) {
+ // not the same, give up
+ value = Literal();
+ break;
+ }
+ }
+ }
+ // we may have found a value
+ if (value.isConcrete()) {
+ // we did!
+ getValues[get] = value;
+ for (auto* set : localGraph.getInfluences[get]) {
+ work.insert(set);
+ }
+ }
+ }
+ }
+ }
};
Pass *createPrecomputePass() {
- return new Precompute();
+ return new Precompute(false);
+}
+
+Pass *createPrecomputePropagatePass() {
+ return new Precompute(true);
}
} // namespace wasm
diff --git a/src/passes/SSAify.cpp b/src/passes/SSAify.cpp
index 9c9aeb573..a71a75308 100644
--- a/src/passes/SSAify.cpp
+++ b/src/passes/SSAify.cpp
@@ -35,6 +35,7 @@
#include "wasm-builder.h"
#include "support/permutations.h"
#include "ast/literal-utils.h"
+#include "ast/local-graph.h"
namespace wasm {
@@ -44,261 +45,39 @@ SetLocal IMPOSSIBLE_SET;
// Tracks assignments to locals, assuming single-assignment form, i.e.,
// each assignment creates a new variable.
-struct SSAify : public WalkerPass<PostWalker<SSAify>> {
+struct SSAify : public Pass {
bool isFunctionParallel() override { return true; }
Pass* create() override { return new SSAify; }
- // the set_locals relevant for an index or a get. we use
- // as set as merges of control flow mean more than 1 may
- // be relevant; we create a phi on demand when necessary for those
- typedef std::set<SetLocal*> Sets;
-
- // we map (old local index) => the set_locals for that index.
- // a nullptr set means there is a virtual set, from a param
- // initial value or the zero init initial value.
- typedef std::vector<Sets> Mapping;
-
- Index numLocals;
- Mapping currMapping;
- Index nextIndex;
- std::vector<Mapping> mappingStack; // used in ifs, loops
- std::map<Name, std::vector<Mapping>> breakMappings; // break target => infos that reach it
- std::vector<std::vector<GetLocal*>> loopGetStack; // stack of loops, all the gets in each, so we can update them for back branches
+ Module* module;
+ Function* func;
std::vector<Expression*> functionPrepends; // things we add to the function prologue
- std::map<GetLocal*, Sets> getSetses; // the sets for each get
- std::map<GetLocal*, Expression**> getLocations;
- void doWalkFunction(Function* func) {
- numLocals = func->getNumLocals();
- if (numLocals == 0) return; // nothing to do
- // We begin with each param being assigned from the incoming value, and the zero-init for the locals,
- // so the initial state is the identity permutation
- currMapping.resize(numLocals);
- for (auto& set : currMapping) {
- set = { nullptr };
- }
- nextIndex = numLocals;
- WalkerPass<PostWalker<SSAify>>::walk(func->body);
- // apply - we now know the sets for each get
- computeGetsAndPhis();
- // add prepends
- if (functionPrepends.size() > 0) {
- Builder builder(*getModule());
- auto* block = builder.makeBlock();
- for (auto* pre : functionPrepends) {
- block->list.push_back(pre);
+ void runFunction(PassRunner* runner, Module* module_, Function* func_) override {
+ module = module_;
+ func = func_;
+ LocalGraph graph(func, module);
+ // create new local indexes, one for each set
+ createNewIndexes(graph);
+ // we now know the sets for each get
+ computeGetsAndPhis(graph);
+ // add prepends to function
+ addPrepends();
+ }
+
+ void createNewIndexes(LocalGraph& graph) {
+ for (auto& pair : graph.locations) {
+ auto* curr = pair.first;
+ if (auto* set = curr->dynCast<SetLocal>()) {
+ set->index = addLocal(func->getLocalType(set->index));
}
- block->list.push_back(func->body);
- block->finalize(func->body->type);
- func->body = block;
}
}
- // control flow
-
- void visitBlock(Block* curr) {
- if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) {
- auto& infos = breakMappings[curr->name];
- infos.emplace_back(std::move(currMapping));
- currMapping = std::move(merge(infos));
- breakMappings.erase(curr->name);
- }
- }
-
- void finishIf() {
- // that's it for this if, merge
- std::vector<Mapping> breaks;
- breaks.emplace_back(std::move(currMapping));
- breaks.emplace_back(std::move(mappingStack.back()));
- mappingStack.pop_back();
- currMapping = std::move(merge(breaks));
- }
-
- static void afterIfCondition(SSAify* self, Expression** currp) {
- self->mappingStack.push_back(self->currMapping);
- }
- static void afterIfTrue(SSAify* self, Expression** currp) {
- auto* curr = (*currp)->cast<If>();
- if (curr->ifFalse) {
- auto afterCondition = std::move(self->mappingStack.back());
- self->mappingStack.back() = std::move(self->currMapping);
- self->currMapping = std::move(afterCondition);
- } else {
- self->finishIf();
- }
- }
- static void afterIfFalse(SSAify* self, Expression** currp) {
- self->finishIf();
- }
- static void beforeLoop(SSAify* self, Expression** currp) {
- // save the state before entering the loop, for calculation later of the merge at the loop top
- self->mappingStack.push_back(self->currMapping);
- self->loopGetStack.push_back({});
- }
- void visitLoop(Loop* curr) {
- if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) {
- auto& infos = breakMappings[curr->name];
- infos.emplace_back(std::move(mappingStack.back()));
- auto before = infos.back();
- auto& merged = merge(infos);
- // every local we created a phi for requires us to update get_local operations in
- // the loop - the branch back has means that gets in the loop have potentially
- // more sets reaching them.
- // we can detect this as follows: if a get of oldIndex has the same sets
- // as the sets at the entrance to the loop, then it is affected by the loop
- // header sets, and we can add to there sets that looped back
- auto linkLoopTop = [&](Index i, Sets& getSets) {
- auto& beforeSets = before[i];
- if (getSets.size() < beforeSets.size()) {
- // the get trivially has fewer sets, so it overrode the loop entry sets
- return;
- }
- std::vector<SetLocal*> intersection;
- std::set_intersection(beforeSets.begin(), beforeSets.end(),
- getSets.begin(), getSets.end(),
- std::back_inserter(intersection));
- if (intersection.size() < beforeSets.size()) {
- // the get has not the same sets as in the loop entry
- return;
- }
- // the get has the entry sets, so add any new ones
- for (auto* set : merged[i]) {
- getSets.insert(set);
- }
- };
- auto& gets = loopGetStack.back();
- for (auto* get : gets) {
- linkLoopTop(get->index, getSetses[get]);
- }
- // and the same for the loop fallthrough: any local that still has the
- // entry sets should also have the loop-back sets as well
- for (Index i = 0; i < numLocals; i++) {
- linkLoopTop(i, currMapping[i]);
- }
- // finally, breaks still in flight must be updated too
- for (auto& iter : breakMappings) {
- auto name = iter.first;
- if (name == curr->name) continue; // skip our own (which is still in use)
- auto& mappings = iter.second;
- for (auto& mapping : mappings) {
- for (Index i = 0; i < numLocals; i++) {
- linkLoopTop(i, mapping[i]);
- }
- }
- }
- // now that we are done with using the mappings, erase our own
- breakMappings.erase(curr->name);
- }
- mappingStack.pop_back();
- loopGetStack.pop_back();
- }
- void visitBreak(Break* curr) {
- if (curr->condition) {
- breakMappings[curr->name].emplace_back(currMapping);
- } else {
- breakMappings[curr->name].emplace_back(std::move(currMapping));
- setUnreachable(currMapping);
- }
- }
- void visitSwitch(Switch* curr) {
- std::set<Name> all;
- for (auto target : curr->targets) {
- all.insert(target);
- }
- all.insert(curr->default_);
- for (auto target : all) {
- breakMappings[target].emplace_back(currMapping);
- }
- setUnreachable(currMapping);
- }
- void visitReturn(Return *curr) {
- setUnreachable(currMapping);
- }
- void visitUnreachable(Unreachable *curr) {
- setUnreachable(currMapping);
- }
-
- // local usage
-
- void visitGetLocal(GetLocal* curr) {
- assert(currMapping.size() == numLocals);
- assert(curr->index < numLocals);
- for (auto& loopGets : loopGetStack) {
- loopGets.push_back(curr);
- }
- // current sets are our sets
- getSetses[curr] = currMapping[curr->index];
- getLocations[curr] = getCurrentPointer();
- }
- void visitSetLocal(SetLocal* curr) {
- assert(currMapping.size() == numLocals);
- assert(curr->index < numLocals);
- // current sets are just this set
- currMapping[curr->index] = { curr }; // TODO optimize?
- curr->index = addLocal(getFunction()->getLocalType(curr->index));
- }
-
- // traversal
-
- static void scan(SSAify* self, Expression** currp) {
- if (auto* iff = (*currp)->dynCast<If>()) {
- // if needs special handling
- if (iff->ifFalse) {
- self->pushTask(SSAify::afterIfFalse, currp);
- self->pushTask(SSAify::scan, &iff->ifFalse);
- }
- self->pushTask(SSAify::afterIfTrue, currp);
- self->pushTask(SSAify::scan, &iff->ifTrue);
- self->pushTask(SSAify::afterIfCondition, currp);
- self->pushTask(SSAify::scan, &iff->condition);
- } else {
- WalkerPass<PostWalker<SSAify>>::scan(self, currp);
- }
-
- // loops need pre-order visiting too
- if ((*currp)->is<Loop>()) {
- self->pushTask(SSAify::beforeLoop, currp);
- }
- }
-
- // helpers
-
- void setUnreachable(Mapping& mapping) {
- mapping.resize(numLocals); // may have been emptied by a move
- mapping[0].clear();
- }
-
- bool isUnreachable(Mapping& mapping) {
- // we must have some set for each index, if only the zero init, so empty means we emptied it for unreachable code
- return mapping[0].empty();
- }
-
- // merges a bunch of infos into one.
- // if we need phis, writes them into the provided vector. the caller should
- // ensure those are placed in the right location
- Mapping& merge(std::vector<Mapping>& mappings) {
- assert(mappings.size() > 0);
- auto& out = mappings[0];
- if (mappings.size() == 1) {
- return out;
- }
- // merge into the first
- for (Index j = 1; j < mappings.size(); j++) {
- auto& other = mappings[j];
- for (Index i = 0; i < numLocals; i++) {
- auto& outSets = out[i];
- for (auto* set : other[i]) {
- outSets.insert(set);
- }
- }
- }
- return out;
- }
-
// After we traversed it all, we can compute gets and phis
- void computeGetsAndPhis() {
- for (auto& iter : getSetses) {
+ void computeGetsAndPhis(LocalGraph& graph) {
+ for (auto& iter : graph.getSetses) {
auto* get = iter.first;
auto& sets = iter.second;
if (sets.size() == 0) {
@@ -312,11 +91,11 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> {
get->index = set->index;
} else {
// no set, assign param or zero
- if (getFunction()->isParam(get->index)) {
+ if (func->isParam(get->index)) {
// leave it, it's fine
} else {
// zero it out
- (*getLocations[get]) = LiteralUtils::makeZero(get->type, *getModule());
+ (*graph.locations[get]) = LiteralUtils::makeZero(get->type, *module);
}
}
continue;
@@ -354,7 +133,7 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> {
auto new_ = addLocal(get->type);
auto old = get->index;
get->index = new_;
- Builder builder(*getModule());
+ Builder builder(*module);
// write to the local in each of our sets
for (auto* set : sets) {
if (set) {
@@ -365,12 +144,12 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> {
);
} else {
// this is a param or the zero init value.
- if (getFunction()->isParam(old)) {
+ if (func->isParam(old)) {
// we add a set with the proper
// param value at the beginning of the function
auto* set = builder.makeSetLocal(
new_,
- builder.makeGetLocal(old, getFunction()->getLocalType(old))
+ builder.makeGetLocal(old, func->getLocalType(old))
);
functionPrepends.push_back(set);
} else {
@@ -383,7 +162,20 @@ struct SSAify : public WalkerPass<PostWalker<SSAify>> {
}
Index addLocal(WasmType type) {
- return Builder::addVar(getFunction(), type);
+ return Builder::addVar(func, type);
+ }
+
+ void addPrepends() {
+ if (functionPrepends.size() > 0) {
+ Builder builder(*module);
+ auto* block = builder.makeBlock();
+ for (auto* pre : functionPrepends) {
+ block->list.push_back(pre);
+ }
+ block->list.push_back(func->body);
+ block->finalize(func->body->type);
+ func->body = block;
+ }
}
};
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 4ec454a50..a208a03dd 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -89,6 +89,7 @@ void PassRegistry::registerPasses() {
registerPass("pick-load-signs", "pick load signs based on their uses", createPickLoadSignsPass);
registerPass("post-emscripten", "miscellaneous optimizations for Emscripten-generated code", createPostEmscriptenPass);
registerPass("precompute", "computes compile-time evaluatable expressions", createPrecomputePass);
+ registerPass("precompute-propagate", "computes compile-time evaluatable expressions and propagates them through locals", createPrecomputePropagatePass);
registerPass("print", "print in s-expression format", createPrinterPass);
registerPass("print-minified", "print in minified s-expression format", createMinifiedPrinterPass);
registerPass("print-full", "print in full s-expression format", createFullPrinterPass);
@@ -148,7 +149,12 @@ void PassRunner::addDefaultFunctionOptimizationPasses() {
add("remove-unused-brs"); // coalesce-locals opens opportunities for optimizations
add("merge-blocks"); // clean up remove-unused-brs new blocks
add("optimize-instructions");
- add("precompute");
+ // if we are willing to work hard, also propagate
+ if (options.optimizeLevel >= 3 || options.shrinkLevel >= 2) {
+ add("precompute-propagate");
+ } else {
+ add("precompute");
+ }
if (options.shrinkLevel >= 2) {
add("local-cse"); // TODO: run this early, before first coalesce-locals. right now doing so uncovers some deficiencies we need to fix first
add("coalesce-locals"); // just for localCSE
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 4e039f4bf..a02216083 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -49,6 +49,7 @@ Pass *createOptimizeInstructionsPass();
Pass *createPickLoadSignsPass();
Pass *createPostEmscriptenPass();
Pass *createPrecomputePass();
+Pass *createPrecomputePropagatePass();
Pass *createPrinterPass();
Pass *createPrintCallGraphPass();
Pass *createRelooperJumpThreadingPass();
diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h
index 5ed0ff01f..1fee4ffb0 100644
--- a/src/tools/execution-results.h
+++ b/src/tools/execution-results.h
@@ -23,18 +23,6 @@
namespace wasm {
-static bool areBitwiseEqual(Literal a, Literal b) {
- if (a == b) return true;
- // accept equal nans if equal in all bits
- if (a.type != b.type) return false;
- if (a.type == f32) {
- return a.reinterpreti32() == b.reinterpreti32();
- } else if (a.type == f64) {
- return a.reinterpreti64() == b.reinterpreti64();
- }
- return false;
-}
-
// gets execution results from a wasm module. this is useful for fuzzing
//
// we can only get results when there are no imports. we then call each method
@@ -86,7 +74,7 @@ struct ExecutionResults {
abort();
}
std::cout << "[fuzz-exec] comparing " << name << '\n';
- if (!areBitwiseEqual(results[name], other.results[name])) {
+ if (!results[name].bitwiseEqual(other.results[name])) {
std::cout << "not identical!\n";
abort();
}
diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp
index b7fde3af6..cd8c27437 100644
--- a/src/tools/wasm-shell.cpp
+++ b/src/tools/wasm-shell.cpp
@@ -201,14 +201,14 @@ static void run_asserts(Name moduleName, size_t* i, bool* checked, Module* wasm,
->dynCast<Const>()
->value;
std::cerr << "seen " << result << ", expected " << expected << '\n';
- if (!areBitwiseEqual(expected, result)) {
+ if (!expected.bitwiseEqual(result)) {
std::cout << "unexpected, should be identical\n";
abort();
}
} else {
Literal expected;
std::cerr << "seen " << result << ", expected " << expected << '\n';
- if (!areBitwiseEqual(expected, result)) {
+ if (!expected.bitwiseEqual(result)) {
std::cout << "unexpected, should be identical\n";
abort();
}
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp
index be1363dde..de89a7ac1 100644
--- a/src/wasm/literal.cpp
+++ b/src/wasm/literal.cpp
@@ -93,6 +93,12 @@ bool Literal::operator!=(const Literal& other) const {
return !(*this == other);
}
+bool Literal::bitwiseEqual(const Literal& other) const {
+ if (type != other.type) return false;
+ if (type == none) return true;
+ return getBits() == other.getBits();
+}
+
uint32_t Literal::NaNPayload(float f) {
assert(std::isnan(f) && "expected a NaN");
// SEEEEEEE EFFFFFFF FFFFFFFF FFFFFFFF