diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/RedundantSetElimination.cpp | 120 |
1 files changed, 92 insertions, 28 deletions
diff --git a/src/passes/RedundantSetElimination.cpp b/src/passes/RedundantSetElimination.cpp index 3d21b7d71..4f82d5181 100644 --- a/src/passes/RedundantSetElimination.cpp +++ b/src/passes/RedundantSetElimination.cpp @@ -39,6 +39,7 @@ #include <ir/properties.h> #include <ir/utils.h> #include <pass.h> +#include <support/small_set.h> #include <support/unique_deferring_queue.h> #include <wasm-builder.h> #include <wasm.h> @@ -54,7 +55,7 @@ namespace { // information in a basic block struct Info { LocalValues start, end; // the local values at the start and end of the block - std::vector<Expression**> setps; + std::vector<Expression**> items; }; struct RedundantSetElimination @@ -74,10 +75,16 @@ struct RedundantSetElimination // cfg traversal work + static void doVisitLocalGet(RedundantSetElimination* self, + Expression** currp) { + if (self->currBasicBlock) { + self->currBasicBlock->contents.items.push_back(currp); + } + } static void doVisitLocalSet(RedundantSetElimination* self, Expression** currp) { if (self->currBasicBlock) { - self->currBasicBlock->contents.setps.push_back(currp); + self->currBasicBlock->contents.items.push_back(currp); } } @@ -98,7 +105,7 @@ struct RedundantSetElimination // flow values across blocks flowValues(func); // remove redundant sets - optimize(); + optimize(func); if (refinalize) { ReFinalize().walkFunctionInModule(func, this->getModule()); @@ -295,12 +302,13 @@ struct RedundantSetElimination // flow values through it, then add those we can reach if they need an // update. auto currValues = curr->contents.start; // we'll modify this as we go - auto& setps = curr->contents.setps; - for (auto** setp : setps) { - auto* set = (*setp)->cast<LocalSet>(); - auto* value = Properties::getFallthrough( - set->value, getPassOptions(), *getModule()); - currValues[set->index] = getValue(value, currValues); + auto& items = curr->contents.items; + for (auto** item : items) { + if (auto* set = (*item)->dynCast<LocalSet>()) { + auto* value = Properties::getFallthrough( + set->value, getPassOptions(), *getModule()); + currValues[set->index] = getValue(value, currValues); + } } if (currValues == curr->contents.end) { // nothing changed, so no more work to do @@ -328,31 +336,87 @@ struct RedundantSetElimination } // optimizing - void optimize() { + void optimize(Function* func) { + // Find which locals are refinable, that is, that when we see a global.get + // of them we may consider switching to another local index that has the + // same value but in a refined type. Computing which locals are relevant for + // that optimization is efficient because it avoids a bunch of work below + // for hashing numbers etc. + std::vector<bool> isRefinable(numLocals, false); + for (Index i = 0; i < numLocals; i++) { + // TODO: we could also note which locals have "maximal" types, where no + // other local is a refinement of them + if (func->getLocalType(i).isRef()) { + isRefinable[i] = true; + } + } + // in each block, run the values through the sets, // and remove redundant sets when we see them for (auto& block : basicBlocks) { auto currValues = block->contents.start; // we'll modify this as we go - auto& setps = block->contents.setps; - for (auto** setp : setps) { - auto* set = (*setp)->cast<LocalSet>(); - auto oldValue = currValues[set->index]; - auto* value = Properties::getFallthrough( - set->value, getPassOptions(), *getModule()); - auto newValue = getValue(value, currValues); - auto index = set->index; - if (newValue == oldValue) { - remove(setp); - continue; // no more work to do + auto& items = block->contents.items; + + // Set up the equivalences at the beginning of the block. We'll update + // them as we go, so we can use them at any point in the middle. This data + // structure maps a value number to the local indexes that have that + // value. + // + // Note that the set here must be ordered to avoid nondeterminism when + // picking between multiple equally-good indexes (we'll pick the first in + // the iteration, which will have the lowest index). + std::unordered_map<Index, SmallSet<Index, 3>> valueToLocals; + assert(currValues.size() == numLocals); + for (Index i = 0; i < numLocals; i++) { + if (isRefinable[i]) { + valueToLocals[currValues[i]].insert(i); + } + } + + for (auto** item : items) { + if (auto* set = (*item)->dynCast<LocalSet>()) { + auto oldValue = currValues[set->index]; + auto* value = Properties::getFallthrough( + set->value, getPassOptions(), *getModule()); + auto newValue = getValue(value, currValues); + auto index = set->index; + if (newValue == oldValue) { + remove(item); + } else { + // update for later steps + currValues[index] = newValue; + if (isRefinable[index]) { + valueToLocals[oldValue].erase(index); + valueToLocals[newValue].insert(index); + } + } + continue; + } + + // For gets, see if there is another index with that value, of a more + // refined type. + auto* get = (*item)->dynCast<LocalGet>(); + if (!isRefinable[get->index]) { + continue; + } + + for (auto i : valueToLocals[getValue(get, currValues)]) { + auto currType = func->getLocalType(get->index); + auto possibleType = func->getLocalType(i); + if (possibleType != currType && + Type::isSubType(possibleType, currType)) { + // We found an improvement! + get->index = i; + get->type = possibleType; + refinalize = true; + } } - // update for later steps - currValues[index] = newValue; } } } - void remove(Expression** setp) { - auto* set = (*setp)->cast<LocalSet>(); + void remove(Expression** item) { + auto* set = (*item)->cast<LocalSet>(); auto* value = set->value; if (!set->isTee()) { auto* drop = ExpressionManipulator::convert<LocalSet, Drop>(set); @@ -373,7 +437,7 @@ struct RedundantSetElimination if (value->type != set->type) { refinalize = true; } - *setp = value; + *item = value; } } @@ -391,8 +455,8 @@ struct RedundantSetElimination std::cout << " start[" << i << "] = " << block->contents.start[i] << '\n'; } - for (auto** setp : block->contents.setps) { - std::cout << " " << *setp << '\n'; + for (auto** item : block->contents.items) { + std::cout << " " << *item << '\n'; } std::cout << "====\n"; } |