summaryrefslogtreecommitdiff
path: root/src/passes/Souperify.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/Souperify.cpp')
-rw-r--r--src/passes/Souperify.cpp193
1 files changed, 132 insertions, 61 deletions
diff --git a/src/passes/Souperify.cpp b/src/passes/Souperify.cpp
index f6700a698..1cc3037fe 100644
--- a/src/passes/Souperify.cpp
+++ b/src/passes/Souperify.cpp
@@ -35,15 +35,15 @@
// directly, without the need for *-propagate techniques.
//
-#include "wasm.h"
-#include "pass.h"
-#include "wasm-builder.h"
+#include "dataflow/graph.h"
+#include "dataflow/node.h"
+#include "dataflow/utils.h"
#include "ir/flat.h"
#include "ir/local-graph.h"
#include "ir/utils.h"
-#include "dataflow/node.h"
-#include "dataflow/graph.h"
-#include "dataflow/utils.h"
+#include "pass.h"
+#include "wasm-builder.h"
+#include "wasm.h"
namespace wasm {
@@ -62,7 +62,8 @@ struct UseFinder {
// (or rather, their values) that contain a get that uses that value.
// There may also be non-set uses of the value, for example in a drop
// or a return. We represent those with a nullptr, meaning "other".
- std::vector<Expression*> getUses(Expression* origin, Graph& graph, LocalGraph& localGraph) {
+ std::vector<Expression*>
+ getUses(Expression* origin, Graph& graph, LocalGraph& localGraph) {
if (debug() >= 2) {
std::cout << "getUses\n" << origin << '\n';
}
@@ -80,9 +81,13 @@ struct UseFinder {
// There may be loops of sets with copies between them.
std::unordered_set<SetLocal*> seenSets;
- void addSetUses(SetLocal* set, Graph& graph, LocalGraph& localGraph, std::vector<Expression*>& ret) {
+ void addSetUses(SetLocal* set,
+ Graph& graph,
+ LocalGraph& localGraph,
+ std::vector<Expression*>& ret) {
// If already handled, nothing to do here.
- if (seenSets.count(set)) return;
+ if (seenSets.count(set))
+ return;
seenSets.insert(set);
// Find all the uses of that set.
auto& gets = localGraph.setInfluences[set];
@@ -165,7 +170,12 @@ struct Trace {
// The local information graph. Used to check if a node has external uses.
LocalGraph& localGraph;
- Trace(Graph& graph, Node* toInfer, std::unordered_set<Node*>& excludeAsChildren, LocalGraph& localGraph) : graph(graph), toInfer(toInfer), excludeAsChildren(excludeAsChildren), localGraph(localGraph) {
+ Trace(Graph& graph,
+ Node* toInfer,
+ std::unordered_set<Node*>& excludeAsChildren,
+ LocalGraph& localGraph)
+ : graph(graph), toInfer(toInfer), excludeAsChildren(excludeAsChildren),
+ localGraph(localGraph) {
if (debug() >= 2) {
std::cout << "\nstart a trace (in " << graph.func->name << ")\n";
}
@@ -180,7 +190,8 @@ struct Trace {
}
// Pull in all the dependencies, starting from the value itself.
add(toInfer, 0);
- if (bad) return;
+ if (bad)
+ return;
// If we are trivial before adding pcs, we are still trivial, and
// can ignore this.
auto sizeBeforePathConditions = nodes.size();
@@ -238,7 +249,8 @@ struct Trace {
// If we've gone too deep, emit a var instead.
// Do the same if this is a node we should exclude from traces.
if (depth >= depthLimit || nodes.size() >= totalLimit ||
- (node != toInfer && excludeAsChildren.find(node) != excludeAsChildren.end())) {
+ (node != toInfer &&
+ excludeAsChildren.find(node) != excludeAsChildren.end())) {
auto type = node->getWasmType();
assert(isConcreteType(type));
auto* var = Node::makeVar(type);
@@ -293,7 +305,8 @@ struct Trace {
bad = true;
return nullptr;
}
- default: WASM_UNREACHABLE();
+ default:
+ WASM_UNREACHABLE();
}
// Assert on no cycles
assert(addedNodes.find(node) == addedNodes.end());
@@ -319,7 +332,9 @@ struct Trace {
// curr is a child of parent, and parent has a Block which we are
// give as 'node'. Add a path condition for reaching the child.
- void addPathTo(Expression* parent, Expression* curr, std::vector<Node*> conditions) {
+ void addPathTo(Expression* parent,
+ Expression* curr,
+ std::vector<Node*> conditions) {
if (auto* iff = parent->dynCast<If>()) {
Index index;
if (curr == iff->ifTrue) {
@@ -340,9 +355,7 @@ struct Trace {
}
}
- bool isBad() {
- return bad;
- }
+ bool isBad() { return bad; }
static bool isTraceable(Node* node) {
if (!node->origin) {
@@ -372,7 +385,8 @@ struct Trace {
}
}
for (auto& node : nodes) {
- if (node == toInfer) continue;
+ if (node == toInfer)
+ continue;
if (auto* origin = node->origin) {
auto uses = UseFinder().getUses(origin, graph, localGraph);
for (auto* use : uses) {
@@ -407,7 +421,8 @@ struct Printer {
std::cout << "\n; start LHS (in " << graph.func->name << ")\n";
// Index the nodes.
for (auto* node : trace.nodes) {
- if (!node->isCond()) { // pcs and blockpcs are not instructions and do not need to be indexed
+ // pcs and blockpcs are not instructions and do not need to be indexed
+ if (!node->isCond()) {
auto index = indexing.size();
indexing[node] = index;
}
@@ -440,7 +455,8 @@ struct Printer {
assert(node);
switch (node->type) {
case Node::Type::Var: {
- std::cout << "%" << indexing[node] << ":" << printType(node->wasmType) << " = var";
+ std::cout << "%" << indexing[node] << ":" << printType(node->wasmType)
+ << " = var";
break; // nothing more to add
}
case Node::Type::Expr: {
@@ -464,18 +480,21 @@ struct Printer {
break;
}
case Node::Type::Cond: {
- std::cout << "blockpc %" << indexing[node->getValue(0)] << ' ' << node->index << ' ';
+ std::cout << "blockpc %" << indexing[node->getValue(0)] << ' '
+ << node->index << ' ';
printInternal(node->getValue(1));
std::cout << " 1:i1";
break;
}
case Node::Type::Block: {
- std::cout << "%" << indexing[node] << " = block " << node->values.size();
+ std::cout << "%" << indexing[node] << " = block "
+ << node->values.size();
break;
}
case Node::Type::Zext: {
auto* child = node->getValue(0);
- std::cout << "%" << indexing[node] << ':' << printType(child->getWasmType());
+ std::cout << "%" << indexing[node] << ':'
+ << printType(child->getWasmType());
std::cout << " = zext ";
printInternal(child);
break;
@@ -484,10 +503,12 @@ struct Printer {
std::cout << "!!!BAD!!!";
WASM_UNREACHABLE();
}
- default: WASM_UNREACHABLE();
+ default:
+ WASM_UNREACHABLE();
}
if (node->isExpr() || node->isPhi()) {
- if (node->origin != trace.toInfer->origin && trace.hasExternalUses.count(node) > 0) {
+ if (node->origin != trace.toInfer->origin &&
+ trace.hasExternalUses.count(node) > 0) {
std::cout << " (hasExternalUses)";
printedHasExternalUses = true;
}
@@ -523,12 +544,19 @@ struct Printer {
} else if (auto* unary = curr->dynCast<Unary>()) {
switch (unary->op) {
case ClzInt32:
- case ClzInt64: std::cout << "ctlz"; break;
+ case ClzInt64:
+ std::cout << "ctlz";
+ break;
case CtzInt32:
- case CtzInt64: std::cout << "cttz"; break;
+ case CtzInt64:
+ std::cout << "cttz";
+ break;
case PopcntInt32:
- case PopcntInt64: std::cout << "ctpop"; break;
- default: WASM_UNREACHABLE();
+ case PopcntInt64:
+ std::cout << "ctpop";
+ break;
+ default:
+ WASM_UNREACHABLE();
}
std::cout << ' ';
auto* value = node->getValue(0);
@@ -536,48 +564,91 @@ struct Printer {
} else if (auto* binary = curr->dynCast<Binary>()) {
switch (binary->op) {
case AddInt32:
- case AddInt64: std::cout << "add"; break;
+ case AddInt64:
+ std::cout << "add";
+ break;
case SubInt32:
- case SubInt64: std::cout << "sub"; break;
+ case SubInt64:
+ std::cout << "sub";
+ break;
case MulInt32:
- case MulInt64: std::cout << "mul"; break;
+ case MulInt64:
+ std::cout << "mul";
+ break;
case DivSInt32:
- case DivSInt64: std::cout << "sdiv"; break;
+ case DivSInt64:
+ std::cout << "sdiv";
+ break;
case DivUInt32:
- case DivUInt64: std::cout << "udiv"; break;
+ case DivUInt64:
+ std::cout << "udiv";
+ break;
case RemSInt32:
- case RemSInt64: std::cout << "srem"; break;
+ case RemSInt64:
+ std::cout << "srem";
+ break;
case RemUInt32:
- case RemUInt64: std::cout << "urem"; break;
+ case RemUInt64:
+ std::cout << "urem";
+ break;
case AndInt32:
- case AndInt64: std::cout << "and"; break;
+ case AndInt64:
+ std::cout << "and";
+ break;
case OrInt32:
- case OrInt64: std::cout << "or"; break;
+ case OrInt64:
+ std::cout << "or";
+ break;
case XorInt32:
- case XorInt64: std::cout << "xor"; break;
+ case XorInt64:
+ std::cout << "xor";
+ break;
case ShlInt32:
- case ShlInt64: std::cout << "shl"; break;
+ case ShlInt64:
+ std::cout << "shl";
+ break;
case ShrUInt32:
- case ShrUInt64: std::cout << "lshr"; break;
+ case ShrUInt64:
+ std::cout << "lshr";
+ break;
case ShrSInt32:
- case ShrSInt64: std::cout << "ashr"; break;
+ case ShrSInt64:
+ std::cout << "ashr";
+ break;
case RotLInt32:
- case RotLInt64: std::cout << "rotl"; break;
+ case RotLInt64:
+ std::cout << "rotl";
+ break;
case RotRInt32:
- case RotRInt64: std::cout << "rotr"; break;
+ case RotRInt64:
+ std::cout << "rotr";
+ break;
case EqInt32:
- case EqInt64: std::cout << "eq"; break;
+ case EqInt64:
+ std::cout << "eq";
+ break;
case NeInt32:
- case NeInt64: std::cout << "ne"; break;
+ case NeInt64:
+ std::cout << "ne";
+ break;
case LtSInt32:
- case LtSInt64: std::cout << "slt"; break;
+ case LtSInt64:
+ std::cout << "slt";
+ break;
case LtUInt32:
- case LtUInt64: std::cout << "ult"; break;
+ case LtUInt64:
+ std::cout << "ult";
+ break;
case LeSInt32:
- case LeSInt64: std::cout << "sle"; break;
+ case LeSInt64:
+ std::cout << "sle";
+ break;
case LeUInt32:
- case LeUInt64: std::cout << "ule"; break;
- default: WASM_UNREACHABLE();
+ case LeUInt64:
+ std::cout << "ule";
+ break;
+ default:
+ WASM_UNREACHABLE();
}
std::cout << ' ';
auto* left = node->getValue(0);
@@ -616,11 +687,13 @@ struct Printer {
}
}
if (allInputsIdentical(node)) {
- std::cout << "^^ suspicious identical inputs! missing optimization in " << graph.func->name << "? ^^\n";
+ std::cout << "^^ suspicious identical inputs! missing optimization in "
+ << graph.func->name << "? ^^\n";
return;
}
if (!node->isPhi() && allInputsConstant(node)) {
- std::cout << "^^ suspicious constant inputs! missing optimization in " << graph.func->name << "? ^^\n";
+ std::cout << "^^ suspicious constant inputs! missing optimization in "
+ << graph.func->name << "? ^^\n";
return;
}
}
@@ -642,7 +715,8 @@ struct Souperify : public WalkerPass<PostWalker<Souperify>> {
// Build the data-flow IR.
DataFlow::Graph graph;
graph.build(func, getModule());
- if (debug() >= 2) dump(graph, std::cout);
+ if (debug() >= 2)
+ dump(graph, std::cout);
// Build the local graph data structure.
LocalGraph localGraph(func);
localGraph.computeInfluences();
@@ -653,7 +727,8 @@ struct Souperify : public WalkerPass<PostWalker<Souperify>> {
auto* node = nodePtr.get();
if (node->origin) {
// TODO: work for identical origins could be saved
- auto uses = DataFlow::UseFinder().getUses(node->origin, graph, localGraph);
+ auto uses =
+ DataFlow::UseFinder().getUses(node->origin, graph, localGraph);
if (debug() >= 2) {
std::cout << "following node has " << uses.size() << " uses\n";
dump(node, std::cout);
@@ -681,12 +756,8 @@ struct Souperify : public WalkerPass<PostWalker<Souperify>> {
}
};
-Pass *createSouperifyPass() {
- return new Souperify(false);
-}
+Pass* createSouperifyPass() { return new Souperify(false); }
-Pass *createSouperifySingleUsePass() {
- return new Souperify(true);
-}
+Pass* createSouperifySingleUsePass() { return new Souperify(true); }
} // namespace wasm