diff options
Diffstat (limited to 'src/passes/Souperify.cpp')
-rw-r--r-- | src/passes/Souperify.cpp | 193 |
1 files changed, 132 insertions, 61 deletions
diff --git a/src/passes/Souperify.cpp b/src/passes/Souperify.cpp index f6700a698..1cc3037fe 100644 --- a/src/passes/Souperify.cpp +++ b/src/passes/Souperify.cpp @@ -35,15 +35,15 @@ // directly, without the need for *-propagate techniques. // -#include "wasm.h" -#include "pass.h" -#include "wasm-builder.h" +#include "dataflow/graph.h" +#include "dataflow/node.h" +#include "dataflow/utils.h" #include "ir/flat.h" #include "ir/local-graph.h" #include "ir/utils.h" -#include "dataflow/node.h" -#include "dataflow/graph.h" -#include "dataflow/utils.h" +#include "pass.h" +#include "wasm-builder.h" +#include "wasm.h" namespace wasm { @@ -62,7 +62,8 @@ struct UseFinder { // (or rather, their values) that contain a get that uses that value. // There may also be non-set uses of the value, for example in a drop // or a return. We represent those with a nullptr, meaning "other". - std::vector<Expression*> getUses(Expression* origin, Graph& graph, LocalGraph& localGraph) { + std::vector<Expression*> + getUses(Expression* origin, Graph& graph, LocalGraph& localGraph) { if (debug() >= 2) { std::cout << "getUses\n" << origin << '\n'; } @@ -80,9 +81,13 @@ struct UseFinder { // There may be loops of sets with copies between them. std::unordered_set<SetLocal*> seenSets; - void addSetUses(SetLocal* set, Graph& graph, LocalGraph& localGraph, std::vector<Expression*>& ret) { + void addSetUses(SetLocal* set, + Graph& graph, + LocalGraph& localGraph, + std::vector<Expression*>& ret) { // If already handled, nothing to do here. - if (seenSets.count(set)) return; + if (seenSets.count(set)) + return; seenSets.insert(set); // Find all the uses of that set. auto& gets = localGraph.setInfluences[set]; @@ -165,7 +170,12 @@ struct Trace { // The local information graph. Used to check if a node has external uses. LocalGraph& localGraph; - Trace(Graph& graph, Node* toInfer, std::unordered_set<Node*>& excludeAsChildren, LocalGraph& localGraph) : graph(graph), toInfer(toInfer), excludeAsChildren(excludeAsChildren), localGraph(localGraph) { + Trace(Graph& graph, + Node* toInfer, + std::unordered_set<Node*>& excludeAsChildren, + LocalGraph& localGraph) + : graph(graph), toInfer(toInfer), excludeAsChildren(excludeAsChildren), + localGraph(localGraph) { if (debug() >= 2) { std::cout << "\nstart a trace (in " << graph.func->name << ")\n"; } @@ -180,7 +190,8 @@ struct Trace { } // Pull in all the dependencies, starting from the value itself. add(toInfer, 0); - if (bad) return; + if (bad) + return; // If we are trivial before adding pcs, we are still trivial, and // can ignore this. auto sizeBeforePathConditions = nodes.size(); @@ -238,7 +249,8 @@ struct Trace { // If we've gone too deep, emit a var instead. // Do the same if this is a node we should exclude from traces. if (depth >= depthLimit || nodes.size() >= totalLimit || - (node != toInfer && excludeAsChildren.find(node) != excludeAsChildren.end())) { + (node != toInfer && + excludeAsChildren.find(node) != excludeAsChildren.end())) { auto type = node->getWasmType(); assert(isConcreteType(type)); auto* var = Node::makeVar(type); @@ -293,7 +305,8 @@ struct Trace { bad = true; return nullptr; } - default: WASM_UNREACHABLE(); + default: + WASM_UNREACHABLE(); } // Assert on no cycles assert(addedNodes.find(node) == addedNodes.end()); @@ -319,7 +332,9 @@ struct Trace { // curr is a child of parent, and parent has a Block which we are // give as 'node'. Add a path condition for reaching the child. - void addPathTo(Expression* parent, Expression* curr, std::vector<Node*> conditions) { + void addPathTo(Expression* parent, + Expression* curr, + std::vector<Node*> conditions) { if (auto* iff = parent->dynCast<If>()) { Index index; if (curr == iff->ifTrue) { @@ -340,9 +355,7 @@ struct Trace { } } - bool isBad() { - return bad; - } + bool isBad() { return bad; } static bool isTraceable(Node* node) { if (!node->origin) { @@ -372,7 +385,8 @@ struct Trace { } } for (auto& node : nodes) { - if (node == toInfer) continue; + if (node == toInfer) + continue; if (auto* origin = node->origin) { auto uses = UseFinder().getUses(origin, graph, localGraph); for (auto* use : uses) { @@ -407,7 +421,8 @@ struct Printer { std::cout << "\n; start LHS (in " << graph.func->name << ")\n"; // Index the nodes. for (auto* node : trace.nodes) { - if (!node->isCond()) { // pcs and blockpcs are not instructions and do not need to be indexed + // pcs and blockpcs are not instructions and do not need to be indexed + if (!node->isCond()) { auto index = indexing.size(); indexing[node] = index; } @@ -440,7 +455,8 @@ struct Printer { assert(node); switch (node->type) { case Node::Type::Var: { - std::cout << "%" << indexing[node] << ":" << printType(node->wasmType) << " = var"; + std::cout << "%" << indexing[node] << ":" << printType(node->wasmType) + << " = var"; break; // nothing more to add } case Node::Type::Expr: { @@ -464,18 +480,21 @@ struct Printer { break; } case Node::Type::Cond: { - std::cout << "blockpc %" << indexing[node->getValue(0)] << ' ' << node->index << ' '; + std::cout << "blockpc %" << indexing[node->getValue(0)] << ' ' + << node->index << ' '; printInternal(node->getValue(1)); std::cout << " 1:i1"; break; } case Node::Type::Block: { - std::cout << "%" << indexing[node] << " = block " << node->values.size(); + std::cout << "%" << indexing[node] << " = block " + << node->values.size(); break; } case Node::Type::Zext: { auto* child = node->getValue(0); - std::cout << "%" << indexing[node] << ':' << printType(child->getWasmType()); + std::cout << "%" << indexing[node] << ':' + << printType(child->getWasmType()); std::cout << " = zext "; printInternal(child); break; @@ -484,10 +503,12 @@ struct Printer { std::cout << "!!!BAD!!!"; WASM_UNREACHABLE(); } - default: WASM_UNREACHABLE(); + default: + WASM_UNREACHABLE(); } if (node->isExpr() || node->isPhi()) { - if (node->origin != trace.toInfer->origin && trace.hasExternalUses.count(node) > 0) { + if (node->origin != trace.toInfer->origin && + trace.hasExternalUses.count(node) > 0) { std::cout << " (hasExternalUses)"; printedHasExternalUses = true; } @@ -523,12 +544,19 @@ struct Printer { } else if (auto* unary = curr->dynCast<Unary>()) { switch (unary->op) { case ClzInt32: - case ClzInt64: std::cout << "ctlz"; break; + case ClzInt64: + std::cout << "ctlz"; + break; case CtzInt32: - case CtzInt64: std::cout << "cttz"; break; + case CtzInt64: + std::cout << "cttz"; + break; case PopcntInt32: - case PopcntInt64: std::cout << "ctpop"; break; - default: WASM_UNREACHABLE(); + case PopcntInt64: + std::cout << "ctpop"; + break; + default: + WASM_UNREACHABLE(); } std::cout << ' '; auto* value = node->getValue(0); @@ -536,48 +564,91 @@ struct Printer { } else if (auto* binary = curr->dynCast<Binary>()) { switch (binary->op) { case AddInt32: - case AddInt64: std::cout << "add"; break; + case AddInt64: + std::cout << "add"; + break; case SubInt32: - case SubInt64: std::cout << "sub"; break; + case SubInt64: + std::cout << "sub"; + break; case MulInt32: - case MulInt64: std::cout << "mul"; break; + case MulInt64: + std::cout << "mul"; + break; case DivSInt32: - case DivSInt64: std::cout << "sdiv"; break; + case DivSInt64: + std::cout << "sdiv"; + break; case DivUInt32: - case DivUInt64: std::cout << "udiv"; break; + case DivUInt64: + std::cout << "udiv"; + break; case RemSInt32: - case RemSInt64: std::cout << "srem"; break; + case RemSInt64: + std::cout << "srem"; + break; case RemUInt32: - case RemUInt64: std::cout << "urem"; break; + case RemUInt64: + std::cout << "urem"; + break; case AndInt32: - case AndInt64: std::cout << "and"; break; + case AndInt64: + std::cout << "and"; + break; case OrInt32: - case OrInt64: std::cout << "or"; break; + case OrInt64: + std::cout << "or"; + break; case XorInt32: - case XorInt64: std::cout << "xor"; break; + case XorInt64: + std::cout << "xor"; + break; case ShlInt32: - case ShlInt64: std::cout << "shl"; break; + case ShlInt64: + std::cout << "shl"; + break; case ShrUInt32: - case ShrUInt64: std::cout << "lshr"; break; + case ShrUInt64: + std::cout << "lshr"; + break; case ShrSInt32: - case ShrSInt64: std::cout << "ashr"; break; + case ShrSInt64: + std::cout << "ashr"; + break; case RotLInt32: - case RotLInt64: std::cout << "rotl"; break; + case RotLInt64: + std::cout << "rotl"; + break; case RotRInt32: - case RotRInt64: std::cout << "rotr"; break; + case RotRInt64: + std::cout << "rotr"; + break; case EqInt32: - case EqInt64: std::cout << "eq"; break; + case EqInt64: + std::cout << "eq"; + break; case NeInt32: - case NeInt64: std::cout << "ne"; break; + case NeInt64: + std::cout << "ne"; + break; case LtSInt32: - case LtSInt64: std::cout << "slt"; break; + case LtSInt64: + std::cout << "slt"; + break; case LtUInt32: - case LtUInt64: std::cout << "ult"; break; + case LtUInt64: + std::cout << "ult"; + break; case LeSInt32: - case LeSInt64: std::cout << "sle"; break; + case LeSInt64: + std::cout << "sle"; + break; case LeUInt32: - case LeUInt64: std::cout << "ule"; break; - default: WASM_UNREACHABLE(); + case LeUInt64: + std::cout << "ule"; + break; + default: + WASM_UNREACHABLE(); } std::cout << ' '; auto* left = node->getValue(0); @@ -616,11 +687,13 @@ struct Printer { } } if (allInputsIdentical(node)) { - std::cout << "^^ suspicious identical inputs! missing optimization in " << graph.func->name << "? ^^\n"; + std::cout << "^^ suspicious identical inputs! missing optimization in " + << graph.func->name << "? ^^\n"; return; } if (!node->isPhi() && allInputsConstant(node)) { - std::cout << "^^ suspicious constant inputs! missing optimization in " << graph.func->name << "? ^^\n"; + std::cout << "^^ suspicious constant inputs! missing optimization in " + << graph.func->name << "? ^^\n"; return; } } @@ -642,7 +715,8 @@ struct Souperify : public WalkerPass<PostWalker<Souperify>> { // Build the data-flow IR. DataFlow::Graph graph; graph.build(func, getModule()); - if (debug() >= 2) dump(graph, std::cout); + if (debug() >= 2) + dump(graph, std::cout); // Build the local graph data structure. LocalGraph localGraph(func); localGraph.computeInfluences(); @@ -653,7 +727,8 @@ struct Souperify : public WalkerPass<PostWalker<Souperify>> { auto* node = nodePtr.get(); if (node->origin) { // TODO: work for identical origins could be saved - auto uses = DataFlow::UseFinder().getUses(node->origin, graph, localGraph); + auto uses = + DataFlow::UseFinder().getUses(node->origin, graph, localGraph); if (debug() >= 2) { std::cout << "following node has " << uses.size() << " uses\n"; dump(node, std::cout); @@ -681,12 +756,8 @@ struct Souperify : public WalkerPass<PostWalker<Souperify>> { } }; -Pass *createSouperifyPass() { - return new Souperify(false); -} +Pass* createSouperifyPass() { return new Souperify(false); } -Pass *createSouperifySingleUsePass() { - return new Souperify(true); -} +Pass* createSouperifySingleUsePass() { return new Souperify(true); } } // namespace wasm |