summaryrefslogtreecommitdiff
path: root/src/passes/RedundantSetElimination.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/RedundantSetElimination.cpp')
-rw-r--r--src/passes/RedundantSetElimination.cpp374
1 files changed, 374 insertions, 0 deletions
diff --git a/src/passes/RedundantSetElimination.cpp b/src/passes/RedundantSetElimination.cpp
new file mode 100644
index 000000000..a63866111
--- /dev/null
+++ b/src/passes/RedundantSetElimination.cpp
@@ -0,0 +1,374 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Eliminate redundant set_locals: if a local already has a particular
+// value, we don't need to set it again. A common case here is loops
+// that start at zero, since the default value is initialized to
+// zero anyhow.
+//
+// A risk here is that we extend live ranges, e.g. we may use the default
+// value at the very end of a function, keeping that local alive throughout.
+// For that reason it is probably better to run this near the end of
+// optimization, and especially after coalesce-locals. A final vaccum
+// should be done after it, as this pass can leave around drop()s of
+// values no longer necessary.
+//
+// So far this tracks constant values, and for everything else it considers
+// them unique (so each set_local of a non-constant is a unique value, each
+// merge is a unique value, etc.; there is no sophisticated value numbering
+// here).
+//
+
+#include <wasm.h>
+#include <pass.h>
+#include <wasm-builder.h>
+#include <cfg/cfg-traversal.h>
+#include <ir/literal-utils.h>
+#include <ir/utils.h>
+#include <support/unique_deferring_queue.h>
+
+namespace wasm {
+
+// We do a very simple numbering of local values, just a unique
+// number for constants so far, enough to see
+// trivial duplication. LocalValues maps each local index to
+// its current value
+typedef std::vector<Index> LocalValues;
+
+// information in a basic block
+struct Info {
+ LocalValues start, end; // the local values at the start and end of the block
+ std::vector<Expression**> setps;
+};
+
+struct RedundantSetElimination : public WalkerPass<CFGWalker<RedundantSetElimination, Visitor<RedundantSetElimination>, Info>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new RedundantSetElimination(); }
+
+ Index numLocals;
+
+ // cfg traversal work
+
+ static void doVisitSetLocal(RedundantSetElimination* self, Expression** currp) {
+ if (self->currBasicBlock) {
+ self->currBasicBlock->contents.setps.push_back(currp);
+ }
+ }
+
+ // main entry point
+
+ void doWalkFunction(Function* func) {
+ numLocals = func->getNumLocals();
+ // create the CFG by walking the IR
+ CFGWalker<RedundantSetElimination, Visitor<RedundantSetElimination>, Info>::doWalkFunction(func);
+ // flow values across blocks
+ flowValues(func);
+ // remove redundant sets
+ optimize();
+ }
+
+ // numbering
+
+ Index nextValue = 1; // 0 is reserved for the "unseen value"
+ std::unordered_map<Literal, Index> literalValues; // each constant has a value
+ std::unordered_map<Expression*, Index> expressionValues; // each value can have a value
+ std::unordered_map<BasicBlock*, std::unordered_map<Index, Index>> blockMergeValues; // each block has values for each merge
+
+ Index getUnseenValue() { // we haven't seen this location yet
+ return 0;
+ }
+ Index getUniqueValue() {
+#ifdef RSE_DEBUG
+ std::cout << "new unique value " << nextValue << '\n';
+#endif
+ return nextValue++;
+ }
+
+ Index getLiteralValue(Literal lit) {
+ auto iter = literalValues.find(lit);
+ if (iter != literalValues.end()) {
+ return iter->second;
+ }
+#ifdef RSE_DEBUG
+ std::cout << "new literal value for " << lit << '\n';
+#endif
+ return literalValues[lit] = getUniqueValue();
+ }
+
+ Index getExpressionValue(Expression* expr) {
+ auto iter = expressionValues.find(expr);
+ if (iter != expressionValues.end()) {
+ return iter->second;
+ }
+#ifdef RSE_DEBUG
+ std::cout << "new expr value for " << expr << '\n';
+#endif
+ return expressionValues[expr] = getUniqueValue();
+ }
+
+ Index getBlockMergeValue(BasicBlock* block, Index index) {
+ auto& mergeValues = blockMergeValues[block];
+ auto iter = mergeValues.find(index);
+ if (iter != mergeValues.end()) {
+ return iter->second;
+ }
+#ifdef RSE_DEBUG
+ std::cout << "new block-merge value for " << block << " : " << index << '\n';
+#endif
+ return mergeValues[index] = getUniqueValue();
+ }
+
+ bool isBlockMergeValue(BasicBlock* block, Index index, Index value) {
+ auto iter = blockMergeValues.find(block);
+ if (iter == blockMergeValues.end()) return false;
+ auto& mergeValues = iter->second;
+ auto iter2 = mergeValues.find(index);
+ if (iter2 == mergeValues.end()) return false;
+ return value == iter2->second;
+ }
+
+ Index getValue(Expression* value, LocalValues& currValues) {
+ if (auto* c = value->dynCast<Const>()) {
+ // a constant
+ return getLiteralValue(c->value);
+ } else if (auto* get = value->dynCast<GetLocal>()) {
+ // a copy of whatever that was
+ return currValues[get->index];
+ } else {
+ // get the value's own unique value
+ return getExpressionValue(value);
+ }
+ }
+
+ // flowing
+
+ void flowValues(Function* func) {
+ for (auto& block : basicBlocks) {
+ LocalValues& start = block->contents.start;
+ start.resize(numLocals);
+ if (block.get() == entry) {
+ // params are complex values we can't optimize; vars are zeros
+ for (Index i = 0; i < numLocals; i++) {
+ if (func->isParam(i)) {
+#ifdef RSE_DEBUG
+ std::cout << "new param value for " << i << '\n';
+#endif
+ start[i] = getUniqueValue();
+ } else {
+ start[i] = getLiteralValue(LiteralUtils::makeLiteralZero(func->getLocalType(i)));
+ }
+ }
+ } else {
+ // other blocks have all unseen values to begin with
+ for (Index i = 0; i < numLocals; i++) {
+ start[i] = getUnseenValue();
+ }
+ }
+ // the ends all begin unseen
+ LocalValues& end = block->contents.end;
+ end.resize(numLocals);
+ for (Index i = 0; i < numLocals; i++) {
+ end[i] = getUnseenValue();
+ }
+ }
+ // keep working while stuff is flowing. we use a unique deferred queue
+ // which ensures both FIFO and that we don't do needless work - if
+ // A and B reach C, and both queue C, we only want to do C at the latest
+ // time, when we have information from all those reaching it.
+ UniqueDeferredQueue<BasicBlock*> work;
+ work.push(entry);
+ while (!work.empty()) {
+ auto* curr = work.pop();
+#ifdef RSE_DEBUG
+ std::cout << "flow block " << curr << '\n';
+#endif
+ // process a block: first, update its start based on those reaching it
+ if (!curr->in.empty()) {
+ if (curr->in.size() == 1) {
+ // just copy the pred, nothing to merge
+ curr->contents.start = (*curr->in.begin())->contents.end;
+ } else {
+ // perform a merge
+ auto in = curr->in;
+ for (Index i = 0; i < numLocals; i++) {
+ auto old = curr->contents.start[i];
+ // If we already had a merge value here, keep it.
+ // TODO This may have some false positives, as we may e.g. have
+ // a single pred that first gives us x, then later y after
+ // flow led to a merge, and we may see x and y at the same
+ // time due to flow from a successor, and then it looks like
+ // we need a merge but we don't. avoiding that would require
+ // more memory and is probably not worth it, but might be
+ // worth investigating
+ // NB While suboptimal, this simplification provides a simple proof
+ // of convergence. We prove that, in each fixed block+local,
+ // the value number at the end is nondecreasing across
+ // iterations, by induction on the iteration:
+ // * The first iteration is on the entry block. It increases
+ // the value number at the end from 0 (unseen) to something
+ // else (a value number for 0 for locals, a unique value
+ // for params; all >0).
+ // * Induction step: assuming the property holds for all past
+ // iterations, consider the current iteration. Of our
+ // predecessors, those that we iterated on have the property;
+ // those that we haven't will have 0 (unseen).
+ // * If we assign to that local in this block, that will be
+ // the value in the output, forever, and it is greater
+ // than the initial value of 0.
+ // * If we see different values coming in, we create a merge
+ // value number. Its number is higher than everything
+ // else since we give it the next available number, so we
+ // do not decrease in this iteration, and we will output
+ // the same value in the future too (here is where we use
+ // the simplification property).
+ // * Otherwise, we will flow the incoming value through,
+ // and it did not decrease (by induction), so neither do
+ // we.
+ // Finally, given value numbers are nondecreasing, we must
+ // converge since we only keep working as long as we see new
+ // values at the end of a block.
+ //
+ // Not that we don't trust this proof, but the convergence
+ // property (value numbers at block ends do not decrease) is
+ // verified later down.
+ if (isBlockMergeValue(curr, i, old)) {
+ continue;
+ }
+ auto iter = in.begin();
+ auto value = (*iter)->contents.end[i];
+ iter++;
+ while (iter != in.end()) {
+ auto otherValue = (*iter)->contents.end[i];
+ if (value == getUnseenValue()) {
+ value = otherValue;
+ } else if (otherValue == getUnseenValue()) {
+ // nothing to do, other has no information
+ } else if (value != otherValue) {
+ // 2 different values, this is a merged value
+ value = getBlockMergeValue(curr, i);
+ break; // no more work once we see a merge
+ }
+ iter++;
+ }
+ curr->contents.start[i] = value;
+ }
+ }
+ }
+#ifdef RSE_DEBUG
+ dump("start", curr->contents.start);
+#endif
+ // flow values through it, then add those we can reach if they need an update.
+ auto currValues = curr->contents.start; // we'll modify this as we go
+ auto& setps = curr->contents.setps;
+ for (auto** setp : setps) {
+ auto* set = (*setp)->cast<SetLocal>();
+ currValues[set->index] = getValue(set->value, currValues);
+ }
+ if (currValues == curr->contents.end) {
+ // nothing changed, so no more work to do
+ // note that the first iteration this is always not the case,
+ // since end contains unseen (and then the comparison ends on
+ // the first element)
+ continue;
+ }
+ // update the end state and update children
+#ifndef NDEBUG
+ // verify the convergence property mentioned in the NB comment
+ // above: the value numbers at the end must be nondecreasing
+ for (Index i = 0; i < numLocals; i++) {
+ assert(currValues[i] >= curr->contents.end[i]);
+ }
+#endif
+ curr->contents.end.swap(currValues);
+#ifdef RSE_DEBUG
+ dump("end ", curr->contents.end);
+#endif
+ for (auto* next : curr->out) {
+ work.push(next);
+ }
+ }
+ }
+
+ // optimizing
+ void optimize() {
+ // in each block, run the values through the sets,
+ // and remove redundant sets when we see them
+ for (auto& block : basicBlocks) {
+ auto currValues = block->contents.start; // we'll modify this as we go
+ auto& setps = block->contents.setps;
+ for (auto** setp : setps) {
+ auto* set = (*setp)->cast<SetLocal>();
+ auto oldValue = currValues[set->index];
+ auto newValue = getValue(set->value, currValues);
+ auto index = set->index;
+ if (newValue == oldValue) {
+ remove(setp);
+ continue; // no more work to do
+ }
+ // update for later steps
+ currValues[index] = newValue;
+ }
+ }
+ }
+
+ void remove(Expression** setp) {
+ auto* set = (*setp)->cast<SetLocal>();
+ auto* value = set->value;
+ if (!set->isTee()) {
+ auto* drop = ExpressionManipulator::convert<SetLocal, Drop>(set);
+ drop->value = value;
+ drop->finalize();
+ } else {
+ *setp = value;
+ }
+ }
+
+ // debugging
+
+ void dump(BasicBlock* block) {
+ std::cout << "====\n";
+ if (block) {
+ std::cout << "block: " << block << '\n';
+ for (auto* out : block->out) {
+ std::cout << " goes to " << out << '\n';
+ }
+ }
+ for (Index i = 0; i < block->contents.start.size(); i++) {
+ std::cout << " start[" << i << "] = " << block->contents.start[i] << '\n';
+ }
+ for (auto** setp : block->contents.setps) {
+ std::cout << " " << *setp << '\n';
+ }
+ std::cout << "====\n";
+ }
+
+ void dump(const char* desc, LocalValues& values) {
+ std::cout << desc << ": ";
+ for (auto x : values) {
+ std::cout << x << ' ';
+ }
+ std::cout << '\n';
+ }
+};
+
+Pass *createRedundantSetEliminationPass() {
+ return new RedundantSetElimination();
+}
+
+} // namespace wasm
+