summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/passes/RedundantSetElimination.cpp120
1 files changed, 92 insertions, 28 deletions
diff --git a/src/passes/RedundantSetElimination.cpp b/src/passes/RedundantSetElimination.cpp
index 3d21b7d71..4f82d5181 100644
--- a/src/passes/RedundantSetElimination.cpp
+++ b/src/passes/RedundantSetElimination.cpp
@@ -39,6 +39,7 @@
#include <ir/properties.h>
#include <ir/utils.h>
#include <pass.h>
+#include <support/small_set.h>
#include <support/unique_deferring_queue.h>
#include <wasm-builder.h>
#include <wasm.h>
@@ -54,7 +55,7 @@ namespace {
// information in a basic block
struct Info {
LocalValues start, end; // the local values at the start and end of the block
- std::vector<Expression**> setps;
+ std::vector<Expression**> items;
};
struct RedundantSetElimination
@@ -74,10 +75,16 @@ struct RedundantSetElimination
// cfg traversal work
+ static void doVisitLocalGet(RedundantSetElimination* self,
+ Expression** currp) {
+ if (self->currBasicBlock) { // only record when there is a current block
+ self->currBasicBlock->contents.items.push_back(currp); // appended in visit order
+ }
+ }
static void doVisitLocalSet(RedundantSetElimination* self,
Expression** currp) {
if (self->currBasicBlock) { // only record when there is a current block
- self->currBasicBlock->contents.setps.push_back(currp);
+ self->currBasicBlock->contents.items.push_back(currp); // sets share the list with gets
}
}
@@ -98,7 +105,7 @@ struct RedundantSetElimination
// flow values across blocks
flowValues(func);
// remove redundant sets
- optimize();
+ optimize(func);
if (refinalize) {
ReFinalize().walkFunctionInModule(func, this->getModule());
@@ -295,12 +302,13 @@ struct RedundantSetElimination
// flow values through it, then add those we can reach if they need an
// update.
auto currValues = curr->contents.start; // we'll modify this as we go
- auto& setps = curr->contents.setps;
- for (auto** setp : setps) {
- auto* set = (*setp)->cast<LocalSet>();
- auto* value = Properties::getFallthrough(
- set->value, getPassOptions(), *getModule());
- currValues[set->index] = getValue(value, currValues);
+ auto& items = curr->contents.items;
+ for (auto** item : items) {
+ if (auto* set = (*item)->dynCast<LocalSet>()) {
+ auto* value = Properties::getFallthrough(
+ set->value, getPassOptions(), *getModule());
+ currValues[set->index] = getValue(value, currValues);
+ }
}
if (currValues == curr->contents.end) {
// nothing changed, so no more work to do
@@ -328,31 +336,87 @@ struct RedundantSetElimination
}
// optimizing
- void optimize() {
+ void optimize(Function* func) {
+ // Find which locals are refinable, that is, that when we see a local.get
+ // of them we may consider switching to another local index that has the
+ // same value but in a refined type. Computing which locals are relevant for
+ // that optimization is efficient because it avoids a bunch of work below
+ // for hashing numbers etc.
+ std::vector<bool> isRefinable(numLocals, false);
+ for (Index i = 0; i < numLocals; i++) {
+ // TODO: we could also note which locals have "maximal" types, where no
+ // other local is a refinement of them
+ if (func->getLocalType(i).isRef()) {
+ isRefinable[i] = true;
+ }
+ }
+
// in each block, run the values through the sets,
// and remove redundant sets when we see them
for (auto& block : basicBlocks) {
auto currValues = block->contents.start; // we'll modify this as we go
- auto& setps = block->contents.setps;
- for (auto** setp : setps) {
- auto* set = (*setp)->cast<LocalSet>();
- auto oldValue = currValues[set->index];
- auto* value = Properties::getFallthrough(
- set->value, getPassOptions(), *getModule());
- auto newValue = getValue(value, currValues);
- auto index = set->index;
- if (newValue == oldValue) {
- remove(setp);
- continue; // no more work to do
+ auto& items = block->contents.items;
+
+ // Set up the equivalences at the beginning of the block. We'll update
+ // them as we go, so we can use them at any point in the middle. This data
+ // structure maps a value number to the local indexes that have that
+ // value.
+ //
+ // Note that the set here must be ordered to avoid nondeterminism when
+ // picking between multiple equally-good indexes (we'll pick the first in
+ // the iteration, which will have the lowest index).
+ std::unordered_map<Index, SmallSet<Index, 3>> valueToLocals;
+ assert(currValues.size() == numLocals);
+ for (Index i = 0; i < numLocals; i++) {
+ if (isRefinable[i]) {
+ valueToLocals[currValues[i]].insert(i);
+ }
+ }
+
+ for (auto** item : items) {
+ if (auto* set = (*item)->dynCast<LocalSet>()) {
+ auto oldValue = currValues[set->index];
+ auto* value = Properties::getFallthrough(
+ set->value, getPassOptions(), *getModule());
+ auto newValue = getValue(value, currValues);
+ auto index = set->index;
+ if (newValue == oldValue) {
+ remove(item);
+ } else {
+ // update for later steps
+ currValues[index] = newValue;
+ if (isRefinable[index]) {
+ valueToLocals[oldValue].erase(index);
+ valueToLocals[newValue].insert(index);
+ }
+ }
+ continue;
+ }
+
+ // For gets, see if there is another index with that value, of a more
+ // refined type. Skip the item if it is not a local.get (dynCast is null).
+ auto* get = (*item)->dynCast<LocalGet>();
+ if (!get || !isRefinable[get->index]) {
+ continue;
+ }
+
+ for (auto i : valueToLocals[getValue(get, currValues)]) {
+ auto currType = func->getLocalType(get->index);
+ auto possibleType = func->getLocalType(i);
+ if (possibleType != currType &&
+ Type::isSubType(possibleType, currType)) {
+ // We found an improvement!
+ get->index = i;
+ get->type = possibleType;
+ refinalize = true;
+ }
}
- // update for later steps
- currValues[index] = newValue;
}
}
}
- void remove(Expression** setp) {
- auto* set = (*setp)->cast<LocalSet>();
+ void remove(Expression** item) {
+ auto* set = (*item)->cast<LocalSet>();
auto* value = set->value;
if (!set->isTee()) {
auto* drop = ExpressionManipulator::convert<LocalSet, Drop>(set);
@@ -373,7 +437,7 @@ struct RedundantSetElimination
if (value->type != set->type) {
refinalize = true;
}
- *setp = value;
+ *item = value;
}
}
@@ -391,8 +455,8 @@ struct RedundantSetElimination
std::cout << " start[" << i << "] = " << block->contents.start[i]
<< '\n';
}
- for (auto** setp : block->contents.setps) {
- std::cout << " " << *setp << '\n';
+ for (auto** item : block->contents.items) {
+ std::cout << " " << *item << '\n';
}
std::cout << "====\n";
}