summaryrefslogtreecommitdiff
path: root/src/passes/Precompute.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/Precompute.cpp')
-rw-r--r--src/passes/Precompute.cpp148
1 files changed, 139 insertions, 9 deletions
diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp
index c4702fdeb..72292c730 100644
--- a/src/passes/Precompute.cpp
+++ b/src/passes/Precompute.cpp
@@ -23,15 +23,24 @@
#include <wasm-builder.h>
#include <wasm-interpreter.h>
#include <ast_utils.h>
-#include "ast/manipulation.h"
+#include <ast/literal-utils.h>
+#include <ast/local-graph.h>
+#include <ast/manipulation.h>
namespace wasm {
static const Name NONSTANDALONE_FLOW("Binaryen|nonstandalone");
+typedef std::unordered_map<GetLocal*, Literal> GetValues;
+
// Execute an expression by itself. Errors if we hit anything we need anything not in the expression itself standalone.
class StandaloneExpressionRunner : public ExpressionRunner<StandaloneExpressionRunner> {
+ // map gets to constant values, if they are known to be constant
+ GetValues& getValues;
+
public:
+ StandaloneExpressionRunner(GetValues& getValues) : getValues(getValues) {}
+
struct NonstandaloneException {}; // TODO: use a flow with a special name, as this is likely very slow
Flow visitLoop(Loop* curr) {
@@ -50,6 +59,13 @@ public:
return Flow(NONSTANDALONE_FLOW);
}
Flow visitGetLocal(GetLocal *curr) {
+ auto iter = getValues.find(curr);
+ if (iter != getValues.end()) {
+ auto value = iter->second;
+ if (value.isConcrete()) {
+ return Flow(value);
+ }
+ }
return Flow(NONSTANDALONE_FLOW);
}
Flow visitSetLocal(SetLocal *curr) {
@@ -85,17 +101,30 @@ public:
struct Precompute : public WalkerPass<PostWalker<Precompute, UnifiedExpressionVisitor<Precompute>>> {
bool isFunctionParallel() override { return true; }
- Pass* create() override { return new Precompute; }
+ Pass* create() override { return new Precompute(propagate); }
+
+ bool propagate = false;
+
+ Precompute(bool propagate) : propagate(propagate) {}
+
+ GetValues getValues;
+
+ void doWalkFunction(Function* func) {
+ // with extra effort, we can utilize the get-set graph to precompute
+ // things that use locals that are known to be constant. otherwise,
+ // we just look at what is immediately before us
+ if (propagate) {
+ optimizeLocals(func, getModule());
+ }
+ // do the main and final walk over everything
+ WalkerPass<PostWalker<Precompute, UnifiedExpressionVisitor<Precompute>>>::doWalkFunction(func);
+ }
void visitExpression(Expression* curr) {
+ // TODO: if get_local, only replace with a constant if we don't care about size...?
if (curr->is<Const>() || curr->is<Nop>()) return;
// try to evaluate this into a const
- Flow flow;
- try {
- flow = StandaloneExpressionRunner().visit(curr);
- } catch (StandaloneExpressionRunner::NonstandaloneException& e) {
- return;
- }
+ Flow flow = precomputeFlow(curr);
if (flow.breaking()) {
if (flow.breakTo == NONSTANDALONE_FLOW) return;
if (flow.breakTo == RETURN_FLOW) {
@@ -157,10 +186,111 @@ struct Precompute : public WalkerPass<PostWalker<Precompute, UnifiedExpressionVi
// removing breaks can alter types
ReFinalize().walkFunctionInModule(curr, getModule());
}
+
+private:
+ Flow precomputeFlow(Expression* curr) {
+ try {
+ return StandaloneExpressionRunner(getValues).visit(curr);
+ } catch (StandaloneExpressionRunner::NonstandaloneException& e) {
+ return Flow(NONSTANDALONE_FLOW);
+ }
+ }
+
+ Literal precomputeValue(Expression* curr) {
+ Flow flow = precomputeFlow(curr);
+ if (flow.breaking()) {
+ return Literal();
+ }
+ return flow.value;
+ }
+
+ void optimizeLocals(Function* func, Module* module) {
+ // using the graph of get-set interactions, do a constant-propagation type
+ // operation: note which sets are assigned locals, then see if that lets us
+ // compute other sets as locals (since some of the gets they read may be
+ // constant).
+ // compute all dependencies
+ LocalGraph localGraph(func, module);
+ localGraph.computeInfluences();
+ // prepare the work list. we add things here that might change to a constant
+ // initially, that means everything
+ std::unordered_set<Expression*> work;
+ for (auto& pair : localGraph.locations) {
+ auto* curr = pair.first;
+ work.insert(curr);
+ }
+ std::unordered_map<SetLocal*, Literal> setValues; // the constant value, or none if not a constant
+ // propagate constant values
+ while (!work.empty()) {
+ auto iter = work.begin();
+ auto* curr = *iter;
+ work.erase(iter);
+ // see if this set or get is actually a constant value, and if so,
+ // mark it as such and add everything it influences to the work list,
+ // as they may be constant too.
+ if (auto* set = curr->dynCast<SetLocal>()) {
+ if (setValues[set].isConcrete()) continue; // already known constant
+ auto value = setValues[set] = precomputeValue(set->value);
+ if (value.isConcrete()) {
+ for (auto* get : localGraph.setInfluences[set]) {
+ work.insert(get);
+ }
+ }
+ } else {
+ auto* get = curr->cast<GetLocal>();
+ if (getValues[get].isConcrete()) continue; // already known constant
+ // for this get to have constant value, all sets must agree
+ Literal value;
+ bool first = true;
+ for (auto* set : localGraph.getSetses[get]) {
+ Literal curr;
+ if (set == nullptr) {
+ if (getFunction()->isVar(get->index)) {
+ curr = LiteralUtils::makeLiteralZero(getFunction()->getLocalType(get->index));
+ } else {
+ // it's a param, so it's hopeless
+ value = Literal();
+ break;
+ }
+ } else {
+ curr = setValues[set];
+ }
+ if (curr.isNull()) {
+ // not a constant, give up
+ value = Literal();
+ break;
+ }
+ // we found a concrete value. compare with the current one
+ if (first) {
+ value = curr; // this is the first
+ first = false;
+ } else {
+ if (!value.bitwiseEqual(curr)) {
+ // not the same, give up
+ value = Literal();
+ break;
+ }
+ }
+ }
+ // we may have found a value
+ if (value.isConcrete()) {
+ // we did!
+ getValues[get] = value;
+ for (auto* set : localGraph.getInfluences[get]) {
+ work.insert(set);
+ }
+ }
+ }
+ }
+ }
};
Pass *createPrecomputePass() {
- return new Precompute();
+ return new Precompute(false);
+}
+
+Pass *createPrecomputePropagatePass() {
+ return new Precompute(true);
}
} // namespace wasm