summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ast/count.h50
-rw-r--r--src/passes/SimplifyLocals.cpp26
2 files changed, 57 insertions, 19 deletions
diff --git a/src/ast/count.h b/src/ast/count.h
new file mode 100644
index 000000000..56f281ce6
--- /dev/null
+++ b/src/ast/count.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ast_count_h
+#define wasm_ast_count_h
+
+namespace wasm {
+
+struct GetLocalCounter : public PostWalker<GetLocalCounter, Visitor<GetLocalCounter>> {
+ std::vector<Index> num;
+
+ GetLocalCounter() {}
+ GetLocalCounter(Function* func) {
+ analyze(func, func->body);
+ }
+ GetLocalCounter(Function* func, Expression* ast) {
+ analyze(func, ast);
+ }
+
+ void analyze(Function* func) {
+ analyze(func, func->body);
+ }
+ void analyze(Function* func, Expression* ast) {
+ num.resize(func->getNumLocals());
+ std::fill(num.begin(), num.end(), 0);
+ walk(ast);
+ }
+
+ void visitGetLocal(GetLocal *curr) {
+ num[curr->index]++;
+ }
+};
+
+} // namespace wasm
+
+#endif // wasm_ast_count_h
+
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index 17206af89..7fef53dcc 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -37,19 +37,12 @@
#include <wasm-traversal.h>
#include <pass.h>
#include <ast_utils.h>
+#include <ast/count.h>
namespace wasm {
// Helper classes
-struct GetLocalCounter : public PostWalker<GetLocalCounter, Visitor<GetLocalCounter>> {
- std::vector<Index>* numGetLocals;
-
- void visitGetLocal(GetLocal *curr) {
- (*numGetLocals)[curr->index]++;
- }
-};
-
struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemover>> {
std::vector<Index>* numGetLocals;
@@ -118,7 +111,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
bool firstCycle;
// local => # of get_locals for it
- std::vector<Index> numGetLocals;
+ GetLocalCounter counter;
static void doNoteNonLinear(SimplifyLocals* self, Expression** currp) {
auto* curr = *currp;
@@ -195,7 +188,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
auto* set = (*found->second.item)->cast<SetLocal>();
if (firstCycle) {
// just one get_local of this, so just sink the value
- assert(numGetLocals[curr->index] == 1);
+ assert(counter.num[curr->index] == 1);
replaceCurrent(set->value);
} else {
replaceCurrent(set);
@@ -271,7 +264,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
self->checkInvalidations(effects);
}
- if (set && !set->isTee() && (!self->firstCycle || self->numGetLocals[set->index] == 1)) {
+ if (set && !set->isTee() && (!self->firstCycle || self->counter.num[set->index] == 1)) {
Index index = set->index;
assert(self->sinkables.count(index) == 0);
self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp)));
@@ -422,11 +415,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
void doWalkFunction(Function* func) {
// scan get_locals
- numGetLocals.resize(func->getNumLocals());
- std::fill(numGetLocals.begin(), numGetLocals.end(), 0);
- GetLocalCounter counter;
- counter.numGetLocals = &numGetLocals;
- counter.walkFunction(func);
+ counter.analyze(func);
// multiple passes may be required per function, consider this:
// x = load
// y = store
@@ -479,11 +468,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
// for a local with no remaining gets, in which case, we can
// remove the set.
// First, recount get_locals
- std::fill(numGetLocals.begin(), numGetLocals.end(), 0);
- counter.walkFunction(func);
+ counter.analyze(func);
// Second, remove unneeded sets
SetLocalRemover remover;
- remover.numGetLocals = &numGetLocals;
+ remover.numGetLocals = &counter.num;
remover.walkFunction(func);
}
};