diff options
Diffstat (limited to 'src/passes/Souperify.cpp')
-rw-r--r-- | src/passes/Souperify.cpp | 691 |
1 files changed, 691 insertions, 0 deletions
diff --git a/src/passes/Souperify.cpp b/src/passes/Souperify.cpp new file mode 100644 index 000000000..a54146d38 --- /dev/null +++ b/src/passes/Souperify.cpp @@ -0,0 +1,691 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Souperify - convert to Souper IR in text form. +// +// This needs 'flatten' to be run before it, as it assumes the IR is in +// flat form. You may also want to optimize a little, e.g. +// --flatten --simplify-locals-nonesting --reorder-locals +// (as otherwise flattening introduces many copies; we do ignore boring +// copies here, but they end up as identical LHSes). +// +// See https://github.com/google/souper/issues/323 +// +// TODO: +// * pcs and blockpcs for things other than ifs +// * Investigate 'inlining', adding in nodes through calls +// * Consider generalizing DataFlow IR for internal Binaryen use. +// * Automatic conversion of Binaryen IR opts to run on the DataFlow IR. +// This would subsume precompute-propagate, for example. Using DFIR we +// can "expand" the BIR into expressions that BIR opts can handle +// directly, without the need for *-propagate techniques. +// + +#include "wasm.h" +#include "pass.h" +#include "wasm-builder.h" +#include "ir/local-graph.h" +#include "ir/utils.h" +#include "dataflow/node.h" +#include "dataflow/graph.h" +#include "dataflow/utils.h" + +namespace wasm { + +static int debug() { + static char* str = getenv("BINARYEN_DEBUG_SOUPERIFY"); + static int ret = str ? atoi(str) : 0; + return ret; +} + +namespace DataFlow { + +// Internal helper to find all the uses of a set. +struct UseFinder { + // Gets a list of all the uses of an expression. As we are in flat IR, + // the expression must be the value of a set, and we seek the other sets + // (or rather, their values) that contain a get that uses that value. + // There may also be non-set uses of the value, for example in a drop + // or a return. We represent those with a nullptr, meaning "other". + std::vector<Expression*> getUses(Expression* origin, Graph& graph, LocalGraph& localGraph) { + if (debug() >= 2) { + std::cout << "getUses\n" << origin << '\n'; + } + std::vector<Expression*> ret; + auto* set = graph.getSet(origin); + if (!set) { + // If the parent is not a set (a drop, call, return, etc.) then + // it is not something we need to track. + return ret; + } + addSetUses(set, graph, localGraph, ret); + return ret; + } + + // There may be loops of sets with copies between them. + std::unordered_set<SetLocal*> seenSets; + + void addSetUses(SetLocal* set, Graph& graph, LocalGraph& localGraph, std::vector<Expression*>& ret) { + // If already handled, nothing to do here. + if (seenSets.count(set)) return; + seenSets.insert(set); + // Find all the uses of that set. + auto& gets = localGraph.setInfluences[set]; + if (debug() >= 2) { + std::cout << "addSetUses for " << set << ", " << gets.size() << " gets\n"; + } + for (auto* get : gets) { + // Each of these relevant gets is either + // (1) a child of a set, which we can track, or + // (2) not a child of a set, e.g., a call argument or such + auto& sets = localGraph.getInfluences[get]; // TODO: iterator + // In flat IR, each get can influence at most 1 set. + assert(sets.size() <= 1); + if (sets.size() == 0) { + // This get is not the child of a set. Check if it is a drop, + // otherwise it is an actual use, and so an external use. + auto* parent = graph.getParent(get); + if (parent && parent->is<Drop>()) { + // Just ignore it. + } else { + ret.push_back(nullptr); + if (debug() >= 2) { + std::cout << "add nullptr\n"; + } + } + } else { + // This get is the child of a set. + auto* subSet = *sets.begin(); + // If this is a copy, we need to look through it: data-flow IR + // counts actual values, not copies, and in particular we need + // to look through the copies that implement a phi. + if (subSet->value == get) { + // Indeed a copy. + // TODO: this could be optimized and done all at once beforehand. + addSetUses(subSet, graph, localGraph, ret); + } else { + // Not a copy. + auto* value = subSet->value; + ret.push_back(value); + if (debug() >= 2) { + std::cout << "add a value\n" << value << '\n'; + } + } + } + } + } +}; + +// Generates a trace: all the information to generate a Souper LHS +// for a specific set_local whose value we want to infer. +struct Trace { + Graph& graph; + Node* toInfer; + // Nodes we should exclude from being children of traces (but they + // may be the root we try to infer. + std::unordered_set<Node*>& excludeAsChildren; + + // A limit on how deep we go - we don't want to create arbitrarily + // large traces. + size_t depthLimit = 10; + size_t totalLimit = 30; + + bool bad = false; + std::vector<Node*> nodes; + std::unordered_set<Node*> addedNodes; + std::vector<Node*> pathConditions; + // When we need to (like when the depth is too deep), we replace + // expressions with other expressions, and track them here. + std::unordered_map<Node*, std::unique_ptr<Node>> replacements; + // The nodes that have additional external uses (only computed + // for the "work" nodes, not the descriptive nodes arriving for + // path conditions). + std::unordered_set<Node*> hasExternalUses; + // We add path conditions after the "work". We collect them here + // and then go through them at the proper time. + std::vector<Node*> conditionsToAdd; + // Whether we are at the adding-conditions stage (i.e., post + // adding the "work"). + bool addingConditions = false; + // The local information graph. Used to check if a node has external uses. + LocalGraph& localGraph; + + Trace(Graph& graph, Node* toInfer, std::unordered_set<Node*>& excludeAsChildren, LocalGraph& localGraph) : graph(graph), toInfer(toInfer), excludeAsChildren(excludeAsChildren), localGraph(localGraph) { + if (debug() >= 2) { + std::cout << "\nstart a trace (in " << graph.func->name << ")\n"; + } + // Check if there is a depth limit override + auto* depthLimitStr = getenv("BINARYEN_SOUPERIFY_DEPTH_LIMIT"); + if (depthLimitStr) { + depthLimit = atoi(depthLimitStr); + } + auto* totalLimitStr = getenv("BINARYEN_SOUPERIFY_TOTAL_LIMIT"); + if (totalLimitStr) { + totalLimit = atoi(totalLimitStr); + } + // Pull in all the dependencies, starting from the value itself. + add(toInfer, 0); + if (bad) return; + // If we are trivial before adding pcs, we are still trivial, and + // can ignore this. + auto sizeBeforePathConditions = nodes.size(); + // No input is uninteresting + if (sizeBeforePathConditions == 0) { + bad = true; + return; + } + // Just a var is uninteresting. TODO: others too? + if (sizeBeforePathConditions == 1 && nodes[0]->isVar()) { + bad = true; + return; + } + // Before adding the path conditions, we can now compute the + // actual number of uses of "work" nodes, the real computation done + // here and that we hope to replace, as opposed to path condition + // computation which is only descriptive and helps optimization of + // the work. + findExternalUses(); + // We can now add conditions. + addingConditions = true; + for (auto* condition : conditionsToAdd) { + add(condition, 0); + } + // Add in path conditions based on the location of this node: e.g. + // if it is inside an if's true branch, we can add a path-condition + // for that. + auto iter = graph.nodeParentMap.find(toInfer); + if (iter != graph.nodeParentMap.end()) { + addPath(toInfer, iter->second); + } + } + + Node* add(Node* node, size_t depth) { + depth++; + // If replaced, return the replacement. + auto iter = replacements.find(node); + if (iter != replacements.end()) { + return iter->second.get(); + } + // If already added, nothing more to do. + if (addedNodes.find(node) != addedNodes.end()) { + return node; + } + switch (node->type) { + case Node::Type::Var: { + break; // nothing more to add + } + case Node::Type::Expr: { + // If this is a Const, it's not an instruction - nothing to add, + // it's just a value. + if (node->expr->is<Const>()) { + return node; + } + // If we've gone too deep, emit a var instead. + // Do the same if this is a node we should exclude from traces. + if (depth >= depthLimit || nodes.size() >= totalLimit || + (node != toInfer && excludeAsChildren.find(node) != excludeAsChildren.end())) { + auto type = node->getWasmType(); + assert(isConcreteType(type)); + auto* var = Node::makeVar(type); + replacements[node] = std::unique_ptr<Node>(var); + node = var; + break; + } + // Add the dependencies. + assert(!node->expr->is<GetLocal>()); + for (Index i = 0; i < node->values.size(); i++) { + add(node->getValue(i), depth); + } + break; + } + case Node::Type::Phi: { + auto* block = add(node->getValue(0), depth); + assert(block); + auto size = block->values.size(); + // First, add the conditions for the block + for (Index i = 0; i < size; i++) { + // a condition may be bad, but conditions are not necessary - + // we can proceed without the extra condition information + auto* condition = block->getValue(i); + if (!condition->isBad()) { + if (!addingConditions) { + // Too early, queue it for later. + conditionsToAdd.push_back(condition); + } else { + add(condition, depth); + } + } + } + // Then, add the phi values + for (Index i = 1; i < size + 1; i++) { + add(node->getValue(i), depth); + } + break; + } + case Node::Type::Cond: { + add(node->getValue(0), depth); // add the block + add(node->getValue(1), depth); // add the node + break; + } + case Node::Type::Block: { + break; // nothing more to add + } + case Node::Type::Zext: { + add(node->getValue(0), depth); + break; + } + case Node::Type::Bad: { + bad = true; + return nullptr; + } + default: WASM_UNREACHABLE(); + } + // Assert on no cycles + assert(addedNodes.find(node) == addedNodes.end()); + nodes.push_back(node); + addedNodes.insert(node); + return node; + } + + void addPath(Node* node, Expression* curr) { + // We track curr and parent, which are always in the state of parent + // being the parent of curr. + auto* parent = graph.expressionParentMap.at(curr); + while (parent) { + auto iter = graph.expressionConditionMap.find(parent); + if (iter != graph.expressionConditionMap.end()) { + // Given the block, add a proper path-condition + addPathTo(parent, curr, iter->second); + } + curr = parent; + parent = graph.expressionParentMap.at(parent); + } + } + + // curr is a child of parent, and parent has a Block which we are + // give as 'node'. Add a path condition for reaching the child. + void addPathTo(Expression* parent, Expression* curr, std::vector<Node*> conditions) { + if (auto* iff = parent->dynCast<If>()) { + Index index; + if (curr == iff->ifTrue) { + index = 0; + } else if (curr == iff->ifFalse) { + index = 1; + } else { + WASM_UNREACHABLE(); + } + auto* condition = conditions[index]; + // Add the condition itself as an instruction in the trace - + // the pc uses it as its input. + add(condition, 0); + // Add it as a pc, which we will emit directly. + pathConditions.push_back(condition); + } else { + WASM_UNREACHABLE(); + } + } + + bool isBad() { + return bad; + } + + static bool isTraceable(Node* node) { + if (!node->origin) { + // Ignore artificial etc. nodes. + // TODO: perhaps require all the nodes for an origin appear, so we + // don't try to compute an internal part of one, like the + // extra artificial != 0 of a select? + return false; + } + if (node->isExpr()) { + // Consider only the simple computational nodes. + auto* expr = node->expr; + return expr->is<Unary>() || expr->is<Binary>() || expr->is<Select>(); + } + return false; + } + + void findExternalUses() { + // Find all the wasm code represented in this trace. + std::unordered_set<Expression*> origins; + for (auto& node : nodes) { + if (auto* origin = node->origin) { + if (debug() >= 2) { + std::cout << "note origin " << origin << '\n'; + } + origins.insert(origin); + } + } + for (auto& node : nodes) { + if (node == toInfer) continue; + if (auto* origin = node->origin) { + auto uses = UseFinder().getUses(origin, graph, localGraph); + for (auto* use : uses) { + // A non-set use (a drop or return etc.) is definitely external. + // Otherwise, check if internal or external. + if (use == nullptr || origins.count(use) == 0) { + if (debug() >= 2) { + std::cout << "found external use for\n"; + dump(node, std::cout); + std::cout << " due to " << use << '\n'; + } + hasExternalUses.insert(node); + break; + } + } + } + } + } +}; + +// Emits a trace, which is basically a Souper LHS. +struct Printer { + Graph& graph; + Trace& trace; + + // Each Node in a trace has an index, from 0. + std::unordered_map<Node*, Index> indexing; + + bool printedHasExternalUses = false; + + Printer(Graph& graph, Trace& trace) : graph(graph), trace(trace) { + std::cout << "\n; start LHS (in " << graph.func->name << ")\n"; + // Index the nodes. + for (auto* node : trace.nodes) { + if (!node->isCond()) { // pcs and blockpcs are not instructions and do not need to be indexed + auto index = indexing.size(); + indexing[node] = index; + } + } + // Print them out. + for (auto* node : trace.nodes) { + print(node); + } + // Print out pcs. + for (auto* condition : trace.pathConditions) { + printPathCondition(condition); + } + + // Finish up + std::cout << "infer %" << indexing[trace.toInfer] << "\n\n"; + } + + Node* getMaybeReplaced(Node* node) { + auto iter = trace.replacements.find(node); + if (iter != trace.replacements.end()) { + return iter->second.get(); + } + return node; + } + + void print(Node* node) { + // The node may have been replaced during trace building, if so then + // print the proper replacement. + node = getMaybeReplaced(node); + assert(node); + switch (node->type) { + case Node::Type::Var: { + std::cout << "%" << indexing[node] << ":" << printType(node->wasmType) << " = var"; + break; // nothing more to add + } + case Node::Type::Expr: { + if (debug()) { + std::cout << "; "; + WasmPrinter::printExpression(node->expr, std::cout, true); + std::cout << '\n'; + } + std::cout << "%" << indexing[node] << " = "; + printExpression(node); + break; + } + case Node::Type::Phi: { + auto* block = node->getValue(0); + auto size = block->values.size(); + std::cout << "%" << indexing[node] << " = phi %" << indexing[block]; + for (Index i = 1; i < size + 1; i++) { + std::cout << ", "; + printInternal(node->getValue(i)); + } + break; + } + case Node::Type::Cond: { + std::cout << "blockpc %" << indexing[node->getValue(0)] << ' ' << node->index << ' '; + printInternal(node->getValue(1)); + std::cout << " 1:i1"; + break; + } + case Node::Type::Block: { + std::cout << "%" << indexing[node] << " = block " << node->values.size(); + break; + } + case Node::Type::Zext: { + auto* child = node->getValue(0); + std::cout << "%" << indexing[node] << ':' << printType(child->getWasmType()); + std::cout << " = zext "; + printInternal(child); + break; + } + case Node::Type::Bad: { + std::cout << "!!!BAD!!!"; + WASM_UNREACHABLE(); + } + default: WASM_UNREACHABLE(); + } + if (node->isExpr() || node->isPhi()) { + if (node->origin != trace.toInfer->origin && trace.hasExternalUses.count(node) > 0) { + std::cout << " (hasExternalUses)"; + printedHasExternalUses = true; + } + } + std::cout << '\n'; + if (debug() && (node->isExpr() || node->isPhi())) { + warnOnSuspiciousValues(node); + } + } + + void print(Literal value) { + std::cout << value.getInteger() << ':' << printType(value.type); + } + + void printInternal(Node* node) { + node = getMaybeReplaced(node); + assert(node); + if (node->isConst()) { + print(node->expr->cast<Const>()->value); + } else { + std::cout << "%" << indexing[node]; + } + } + + // Emit an expression + + void printExpression(Node* node) { + assert(node->isExpr()); + // TODO use a Visitor here? + auto* curr = node->expr; + if (auto* c = curr->dynCast<Const>()) { + print(c->value); + } else if (auto* unary = curr->dynCast<Unary>()) { + switch (unary->op) { + case ClzInt32: + case ClzInt64: std::cout << "ctlz"; break; + case CtzInt32: + case CtzInt64: std::cout << "cttz"; break; + case PopcntInt32: + case PopcntInt64: std::cout << "ctpop"; break; + default: WASM_UNREACHABLE(); + } + std::cout << ' '; + auto* value = node->getValue(0); + printInternal(value); + } else if (auto* binary = curr->dynCast<Binary>()) { + switch (binary->op) { + case AddInt32: + case AddInt64: std::cout << "add"; break; + case SubInt32: + case SubInt64: std::cout << "sub"; break; + case MulInt32: + case MulInt64: std::cout << "mul"; break; + case DivSInt32: + case DivSInt64: std::cout << "sdiv"; break; + case DivUInt32: + case DivUInt64: std::cout << "udiv"; break; + case RemSInt32: + case RemSInt64: std::cout << "srem"; break; + case RemUInt32: + case RemUInt64: std::cout << "urem"; break; + case AndInt32: + case AndInt64: std::cout << "and"; break; + case OrInt32: + case OrInt64: std::cout << "or"; break; + case XorInt32: + case XorInt64: std::cout << "xor"; break; + case ShlInt32: + case ShlInt64: std::cout << "shl"; break; + case ShrUInt32: + case ShrUInt64: std::cout << "lshr"; break; + case ShrSInt32: + case ShrSInt64: std::cout << "ashr"; break; + case RotLInt32: + case RotLInt64: std::cout << "rotl"; break; + case RotRInt32: + case RotRInt64: std::cout << "rotr"; break; + case EqInt32: + case EqInt64: std::cout << "eq"; break; + case NeInt32: + case NeInt64: std::cout << "ne"; break; + case LtSInt32: + case LtSInt64: std::cout << "slt"; break; + case LtUInt32: + case LtUInt64: std::cout << "ult"; break; + case LeSInt32: + case LeSInt64: std::cout << "sle"; break; + case LeUInt32: + case LeUInt64: std::cout << "ule"; break; + default: WASM_UNREACHABLE(); + } + std::cout << ' '; + auto* left = node->getValue(0); + printInternal(left); + std::cout << ", "; + auto* right = node->getValue(1); + printInternal(right); + } else if (curr->is<Select>()) { + std::cout << "select "; + printInternal(node->getValue(0)); + std::cout << ", "; + printInternal(node->getValue(1)); + std::cout << ", "; + printInternal(node->getValue(2)); + } else { + WASM_UNREACHABLE(); + } + } + + void printPathCondition(Node* condition) { + std::cout << "pc "; + printInternal(condition); + std::cout << " 1:i1\n"; + } + + // Checks if a value looks suspiciously optimizable. + void warnOnSuspiciousValues(Node* node) { + assert(debug()); + // If the node has no uses, it's not interesting enough to be + // suspicious. TODO + // If an input was replaced with a var, then we should not + // look into it, it's not suspiciously trivial. + for (auto* value : node->values) { + if (value != getMaybeReplaced(value)) { + return; + } + } + if (allInputsIdentical(node)) { + std::cout << "^^ suspicious identical inputs! missing optimization in " << graph.func->name << "? ^^\n"; + return; + } + if (!node->isPhi() && allInputsConstant(node)) { + std::cout << "^^ suspicious constant inputs! missing optimization in " << graph.func->name << "? ^^\n"; + return; + } + } +}; + +} // namespace DataFlow + +struct Souperify : public WalkerPass<PostWalker<Souperify>> { + // Not parallel, for now - could parallelize and combine outputs at the end. + // If Souper is thread-safe, we could also run it in parallel. + + bool singleUseOnly; + + Souperify(bool singleUseOnly) : singleUseOnly(singleUseOnly) {} + + void doWalkFunction(Function* func) { + std::cout << "\n; function: " << func->name << '\n'; + // Build the data-flow IR. + DataFlow::Graph graph; + graph.build(func, getModule()); + if (debug() >= 2) dump(graph, std::cout); + // Build the local graph data structure. + LocalGraph localGraph(func); + localGraph.computeInfluences(); + // If we only want single-use nodes, exclude all the others. + std::unordered_set<DataFlow::Node*> excludeAsChildren; + if (singleUseOnly) { + for (auto& nodePtr : graph.nodes) { + auto* node = nodePtr.get(); + if (node->origin) { + // TODO: work for identical origins could be saved + auto uses = DataFlow::UseFinder().getUses(node->origin, graph, localGraph); + if (debug() >= 2) { + std::cout << "following node has " << uses.size() << " uses\n"; + dump(node, std::cout); + } + if (uses.size() > 1) { + excludeAsChildren.insert(node); + } + } + } + } + // Emit possible traces. + for (auto& nodePtr : graph.nodes) { + auto* node = nodePtr.get(); + // Trace + if (DataFlow::Trace::isTraceable(node)) { + DataFlow::Trace trace(graph, node, excludeAsChildren, localGraph); + if (!trace.isBad()) { + DataFlow::Printer printer(graph, trace); + if (singleUseOnly) { + assert(!printer.printedHasExternalUses); + } + } + } + } + } +}; + +Pass *createSouperifyPass() { + return new Souperify(false); +} + +Pass *createSouperifySingleUsePass() { + return new Souperify(true); +} + +} // namespace wasm + |