summaryrefslogtreecommitdiff
path: root/src/passes/RedundantSetElimination.cpp
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2022-11-04 12:36:35 -0700
committerGitHub <noreply@github.com>2022-11-04 12:36:35 -0700
commit20115c8e2a038646858a0408beca7be319c31035 (patch)
tree68155ae9333c1cfebb207dc2cbd955eaea7c45a2 /src/passes/RedundantSetElimination.cpp
parent6a9a4e7e722e5f578376f3f14a67960539ac4478 (diff)
downloadbinaryen-20115c8e2a038646858a0408beca7be319c31035.tar.gz
binaryen-20115c8e2a038646858a0408beca7be319c31035.tar.bz2
binaryen-20115c8e2a038646858a0408beca7be319c31035.zip
[Wasm GC] RSE: Switch local.get to use a more refined type when possible (#5216)
Similar to #5194 but for RedundantSetElimination. This has similar benefits in terms of using a more refined local in hopes of avoiding casts in followup opts, but unlike SimplifyLocals this will operate across basic blocks. To do this, we need to track not just local.set but also local.get in that pass. Then in each basic block we can track the equivalent locals and pick from them. I see a few dozen casts removed in the J2Wasm binary. Often stuff like this happens: y = cast(x); if (..) { foo(x); // this could use y }
Diffstat (limited to 'src/passes/RedundantSetElimination.cpp')
-rw-r--r--src/passes/RedundantSetElimination.cpp120
1 files changed, 92 insertions, 28 deletions
diff --git a/src/passes/RedundantSetElimination.cpp b/src/passes/RedundantSetElimination.cpp
index 3d21b7d71..4f82d5181 100644
--- a/src/passes/RedundantSetElimination.cpp
+++ b/src/passes/RedundantSetElimination.cpp
@@ -39,6 +39,7 @@
#include <ir/properties.h>
#include <ir/utils.h>
#include <pass.h>
+#include <support/small_set.h>
#include <support/unique_deferring_queue.h>
#include <wasm-builder.h>
#include <wasm.h>
@@ -54,7 +55,7 @@ namespace {
// information in a basic block
struct Info {
LocalValues start, end; // the local values at the start and end of the block
- std::vector<Expression**> setps;
+ std::vector<Expression**> items;
};
struct RedundantSetElimination
@@ -74,10 +75,16 @@ struct RedundantSetElimination
// cfg traversal work
+ static void doVisitLocalGet(RedundantSetElimination* self,
+ Expression** currp) {
+ if (self->currBasicBlock) {
+ self->currBasicBlock->contents.items.push_back(currp);
+ }
+ }
static void doVisitLocalSet(RedundantSetElimination* self,
Expression** currp) {
if (self->currBasicBlock) {
- self->currBasicBlock->contents.setps.push_back(currp);
+ self->currBasicBlock->contents.items.push_back(currp);
}
}
@@ -98,7 +105,7 @@ struct RedundantSetElimination
// flow values across blocks
flowValues(func);
// remove redundant sets
- optimize();
+ optimize(func);
if (refinalize) {
ReFinalize().walkFunctionInModule(func, this->getModule());
@@ -295,12 +302,13 @@ struct RedundantSetElimination
// flow values through it, then add those we can reach if they need an
// update.
auto currValues = curr->contents.start; // we'll modify this as we go
- auto& setps = curr->contents.setps;
- for (auto** setp : setps) {
- auto* set = (*setp)->cast<LocalSet>();
- auto* value = Properties::getFallthrough(
- set->value, getPassOptions(), *getModule());
- currValues[set->index] = getValue(value, currValues);
+ auto& items = curr->contents.items;
+ for (auto** item : items) {
+ if (auto* set = (*item)->dynCast<LocalSet>()) {
+ auto* value = Properties::getFallthrough(
+ set->value, getPassOptions(), *getModule());
+ currValues[set->index] = getValue(value, currValues);
+ }
}
if (currValues == curr->contents.end) {
// nothing changed, so no more work to do
@@ -328,31 +336,87 @@ struct RedundantSetElimination
}
// optimizing
- void optimize() {
+ void optimize(Function* func) {
+ // Find which locals are refinable, that is, that when we see a global.get
+ // of them we may consider switching to another local index that has the
+ // same value but in a refined type. Computing which locals are relevant for
+ // that optimization is efficient because it avoids a bunch of work below
+ // for hashing numbers etc.
+ std::vector<bool> isRefinable(numLocals, false);
+ for (Index i = 0; i < numLocals; i++) {
+ // TODO: we could also note which locals have "maximal" types, where no
+ // other local is a refinement of them
+ if (func->getLocalType(i).isRef()) {
+ isRefinable[i] = true;
+ }
+ }
+
// in each block, run the values through the sets,
// and remove redundant sets when we see them
for (auto& block : basicBlocks) {
auto currValues = block->contents.start; // we'll modify this as we go
- auto& setps = block->contents.setps;
- for (auto** setp : setps) {
- auto* set = (*setp)->cast<LocalSet>();
- auto oldValue = currValues[set->index];
- auto* value = Properties::getFallthrough(
- set->value, getPassOptions(), *getModule());
- auto newValue = getValue(value, currValues);
- auto index = set->index;
- if (newValue == oldValue) {
- remove(setp);
- continue; // no more work to do
+ auto& items = block->contents.items;
+
+ // Set up the equivalences at the beginning of the block. We'll update
+ // them as we go, so we can use them at any point in the middle. This data
+ // structure maps a value number to the local indexes that have that
+ // value.
+ //
+ // Note that the set here must be ordered to avoid nondeterminism when
+ // picking between multiple equally-good indexes (we'll pick the first in
+ // the iteration, which will have the lowest index).
+ std::unordered_map<Index, SmallSet<Index, 3>> valueToLocals;
+ assert(currValues.size() == numLocals);
+ for (Index i = 0; i < numLocals; i++) {
+ if (isRefinable[i]) {
+ valueToLocals[currValues[i]].insert(i);
+ }
+ }
+
+ for (auto** item : items) {
+ if (auto* set = (*item)->dynCast<LocalSet>()) {
+ auto oldValue = currValues[set->index];
+ auto* value = Properties::getFallthrough(
+ set->value, getPassOptions(), *getModule());
+ auto newValue = getValue(value, currValues);
+ auto index = set->index;
+ if (newValue == oldValue) {
+ remove(item);
+ } else {
+ // update for later steps
+ currValues[index] = newValue;
+ if (isRefinable[index]) {
+ valueToLocals[oldValue].erase(index);
+ valueToLocals[newValue].insert(index);
+ }
+ }
+ continue;
+ }
+
+ // For gets, see if there is another index with that value, of a more
+ // refined type.
+ auto* get = (*item)->dynCast<LocalGet>();
+ if (!isRefinable[get->index]) {
+ continue;
+ }
+
+ for (auto i : valueToLocals[getValue(get, currValues)]) {
+ auto currType = func->getLocalType(get->index);
+ auto possibleType = func->getLocalType(i);
+ if (possibleType != currType &&
+ Type::isSubType(possibleType, currType)) {
+ // We found an improvement!
+ get->index = i;
+ get->type = possibleType;
+ refinalize = true;
+ }
}
- // update for later steps
- currValues[index] = newValue;
}
}
}
- void remove(Expression** setp) {
- auto* set = (*setp)->cast<LocalSet>();
+ void remove(Expression** item) {
+ auto* set = (*item)->cast<LocalSet>();
auto* value = set->value;
if (!set->isTee()) {
auto* drop = ExpressionManipulator::convert<LocalSet, Drop>(set);
@@ -373,7 +437,7 @@ struct RedundantSetElimination
if (value->type != set->type) {
refinalize = true;
}
- *setp = value;
+ *item = value;
}
}
@@ -391,8 +455,8 @@ struct RedundantSetElimination
std::cout << " start[" << i << "] = " << block->contents.start[i]
<< '\n';
}
- for (auto** setp : block->contents.setps) {
- std::cout << " " << *setp << '\n';
+ for (auto** item : block->contents.items) {
+ std::cout << " " << *item << '\n';
}
std::cout << "====\n";
}