diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/dataflow/graph.h | 755 | ||||
-rw-r--r-- | src/dataflow/node.h | 210 | ||||
-rw-r--r-- | src/dataflow/users.h | 106 | ||||
-rw-r--r-- | src/dataflow/utils.h | 145 | ||||
-rw-r--r-- | src/ir/utils.h | 14 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/passes/DataFlowOpts.cpp | 250 | ||||
-rw-r--r-- | src/passes/DeadCodeElimination.cpp | 2 | ||||
-rw-r--r-- | src/passes/Souperify.cpp | 691 | ||||
-rw-r--r-- | src/passes/pass.cpp | 9 | ||||
-rw-r--r-- | src/passes/passes.h | 3 | ||||
-rw-r--r-- | src/wasm-traversal.h | 2 |
12 files changed, 2184 insertions, 5 deletions
diff --git a/src/dataflow/graph.h b/src/dataflow/graph.h new file mode 100644 index 000000000..f81b4f989 --- /dev/null +++ b/src/dataflow/graph.h @@ -0,0 +1,755 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// DataFlow IR is an SSA representation. It can be built from the main +// Binaryen IR. +// +// THe main initial use case was an IR that could easily be converted to +// Souper IR, and the design favors that. +// + +#ifndef wasm_dataflow_graph_h +#define wasm_dataflow_graph_h + +#include "wasm.h" +#include "ir/abstract.h" +#include "ir/iteration.h" +#include "ir/literal-utils.h" +#include "dataflow/node.h" + +namespace wasm { + +namespace DataFlow { + +// Main logic to generate IR for a function. This is implemented as a +// visitor on the wasm, where visitors return a Node* that either +// contains the DataFlow IR for that expression, which can be a +// Bad node if not supported, or nullptr if not relevant (we only +// use the return value for internal expressions, that is, the +// value of a set_local or the condition of an if etc). +struct Graph : public UnifiedExpressionVisitor<Graph, Node*> { + // We only need one canonical bad node. It is never modified. + Node bad = Node(Node::Type::Bad); + + // Connects a specific set to the data in its value. + std::unordered_map<SetLocal*, Node*> setNodeMap; + + // Maps a control-flow expression to the conditions for it. 
Currently, + // this maps an if to the conditions for its arms + std::unordered_map<Expression*, std::vector<Node*>> expressionConditionMap; + + // Maps each expression to its control-flow parent (or null if + // there is none). We only map expressions we need to know about, + // which are sets, set values, and control-flow constructs. + std::unordered_map<Expression*, Expression*> expressionParentMap; + + // The same, for nodes. Note that we currently don't know the parents + // of nodes like phis that don't exist in wasm - we need to add more + // stuff to handle that. But we will know the parent of anything using + // that phi and storing to a local, which is probably enough anyhow + // for pc generation. + std::unordered_map<Node*, Expression*> nodeParentMap; + + // All the sets, in order of appearance. + std::vector<SetLocal*> sets; + + // The function being processed. + Function* func; + + // The module we are working in. + Module* module; + + // All of our nodes + std::vector<std::unique_ptr<Node>> nodes; + + // Tracking state during building + + // We need to track the parents of control flow nodes. + Expression* parent = nullptr; + + // Tracks the state of locals in a control flow path: + // locals[i] = the node whose value it contains + // When we are in unreachable code (i.e., a path that does not + // need to be merged in anywhere), we set the length of this + // vector to 0 to indicate that. + typedef std::vector<Node*> Locals; + + // The current local state in the control flow path being emitted. + Locals locals; + + // The local states on branches to a specific target. + std::unordered_map<Name, std::vector<Locals>> breakStates; + + // The local state in a control flow path, including a possible + // condition as well. 
+ struct FlowState { + Locals locals; + Node* condition; + FlowState(Locals locals, Node* condition) : locals(locals), condition(condition) {} + }; + + // API + + void build(Function* funcInit, Module* moduleInit) { + func = funcInit; + module = moduleInit; + + auto numLocals = func->getNumLocals(); + if (numLocals == 0) return; // nothing to do + // Set up initial local state IR. + setInReachable(); + for (Index i = 0; i < numLocals; i++) { + if (!isRelevantType(func->getLocalType(i))) continue; + Node* node; + auto type = func->getLocalType(i); + if (func->isParam(i)) { + node = makeVar(type); + } else { + node = makeZero(type); + } + locals[i] = node; + } + // Process the function body, generating the rest of the IR. + visit(func->body); + } + + // Makes a Var node, representing a value that could be anything. + Node* makeVar(wasm::Type type) { + if (isRelevantType(type)) { + return addNode(Node::makeVar(type)); + } else { + return &bad; + } + } + + // We create one node per constant value + std::unordered_map<Literal, Node*> constantNodes; + + Node* makeConst(Literal value) { + auto iter = constantNodes.find(value); + if (iter!= constantNodes.end()) { + return iter->second; + } + // Create one for this literal. + Builder builder(*module); + auto* c = builder.makeConst(value); + auto* ret = addNode(Node::makeExpr(c, c)); + constantNodes[value] = ret; + return ret; + } + + Node* makeZero(wasm::Type type) { + return makeConst(LiteralUtils::makeLiteralZero(type)); + } + + // Add a new node to our list of owned nodes. + Node* addNode(Node* node) { + nodes.push_back(std::unique_ptr<Node>(node)); + return node; + } + + Node* makeZeroComp(Node* node, bool equal, Expression* origin) { + assert(!node->isBad()); + Builder builder(*module); + auto type = node->getWasmType(); + if (!isConcreteType(type)) return &bad; + auto* zero = makeZero(type); + auto* expr = builder.makeBinary( + Abstract::getBinary(type, equal ? 
Abstract::Eq : Abstract::Ne), + makeUse(node), + makeUse(zero) + ); + auto* check = addNode(Node::makeExpr(expr, origin)); + check->addValue(expandFromI1(node, origin)); + check->addValue(zero); + return check; + } + + void setInUnreachable() { + locals.clear(); + } + + void setInReachable() { + locals.resize(func->getNumLocals()); + } + + bool isInUnreachable() { + return isInUnreachable(locals); + } + + bool isInUnreachable(const Locals& state) { + return state.empty(); + } + + bool isInUnreachable(const FlowState& state) { + return isInUnreachable(state.locals); + } + + // Visiting. + + Node* visitExpression(Expression* curr) { + // Control flow and get/set etc. are special. Aside from them, we just need + // to do something very generic. + if (auto* block = curr->dynCast<Block>()) { + return doVisitBlock(block); + } else if (auto* iff = curr->dynCast<If>()) { + return doVisitIf(iff); + } else if (auto* loop = curr->dynCast<Loop>()) { + return doVisitLoop(loop); + } else if (auto* get = curr->dynCast<GetLocal>()) { + return doVisitGetLocal(get); + } else if (auto* set = curr->dynCast<SetLocal>()) { + return doVisitSetLocal(set); + } else if (auto* br = curr->dynCast<Break>()) { + return doVisitBreak(br); + } else if (auto* sw = curr->dynCast<Switch>()) { + return doVisitSwitch(sw); + } else if (auto* c = curr->dynCast<Const>()) { + return doVisitConst(c); + } else if (auto* unary = curr->dynCast<Unary>()) { + return doVisitUnary(unary); + } else if (auto* binary = curr->dynCast<Binary>()) { + return doVisitBinary(binary); + } else if (auto* select = curr->dynCast<Select>()) { + return doVisitSelect(select); + } else if (auto* unreachable = curr->dynCast<Unreachable>()) { + return doVisitUnreachable(unreachable); + } else if (auto* drop = curr->dynCast<Drop>()) { + return doVisitDrop(drop); + } else { + return doVisitGeneric(curr); + } + } + + Node* doVisitBlock(Block* curr) { + // TODO: handle super-deep nesting + auto* oldParent = parent; + 
+ expressionParentMap[curr] = oldParent; + parent = curr; + for (auto* child : curr->list) { + visit(child); + } + // Merge the outputs + // TODO handle conditions on these breaks + if (curr->name.is()) { + auto iter = breakStates.find(curr->name); + if (iter != breakStates.end()) { + auto& states = iter->second; + // Add the state flowing out + if (!isInUnreachable()) { + states.push_back(locals); + } + mergeBlock(states, locals); + } + } + parent = oldParent; + return &bad; + } + // Visits an if: evaluates the condition, processes each arm starting + // from the same initial local state, and then merges the local states + // flowing out of the arms. Always returns the bad node, as the if + // itself is not used as a value here. + Node* doVisitIf(If* curr) { + auto* oldParent = parent; + expressionParentMap[curr] = oldParent; + parent = curr; + // Set up the condition. + Node* condition = visit(curr->condition); + assert(condition); + // Handle the contents. + auto initialState = locals; + visit(curr->ifTrue); + auto afterIfTrueState = locals; + if (curr->ifFalse) { + locals = initialState; + visit(curr->ifFalse); + auto afterIfFalseState = locals; // TODO: optimize + mergeIf(afterIfTrueState, afterIfFalseState, condition, curr, locals); + } else { + // There is no else arm, so the state flowing out when the condition + // is false is simply the initial state. mergeIf pairs its first + // state with the true condition and its second state with the false + // condition, so the state after the ifTrue arm must come first. + mergeIf(afterIfTrueState, initialState, condition, curr, locals); + } + parent = oldParent; + return &bad; + } + Node* doVisitLoop(Loop* curr) { + // As in Souper's LLVM extractor, we avoid loop phis, as we don't want + // our traces to represent a value that differs across loop iterations. + // For example, + // %b = block + // %x = phi %b, 1, %y + // %y = phi %b, 2, %x + // %z = eq %x %y + // infer %z + // Here %y refers to the previous iteration's %x. + // To do this, we set all locals to a Var at the loop entry, then process + // the inside of the loop. When that is done, we can see if a phi was + // actually needed for each local. If it was, we leave the Var (it + // represents an unknown value; analysis stops there), and if not, we + // can replace the Var with the fixed value. + // TODO: perhaps some more general uses of DataFlow will want loop phis? 
+ // TODO: optimize stuff here + if (isInUnreachable()) { + return &bad; // none of this matters + } + if (!curr->name.is()) { + visit(curr->body); + return &bad; // no phis are possible + } + auto previous = locals; + auto numLocals = func->getNumLocals(); + for (Index i = 0; i < numLocals; i++) { + locals[i] = makeVar(func->getLocalType(i)); + } + auto vars = locals; // all the Vars we just created + // We may need to replace values later - only new nodes added from + // here are relevant. + auto firstNodeFromLoop = nodes.size(); + // Process the loop body. + visit(curr->body); + // Find all incoming paths. + auto& breaks = breakStates[curr->name]; + // Phis are possible, check for them. + for (Index i = 0; i < numLocals; i++) { + if (!isRelevantType(func->getLocalType(i))) continue; + bool needPhi = false; + // We replaced the proper value with a Var. If it's still that + // Var - or it's the original proper value, which can happen with + // constants - on all incoming paths, then a phi is not needed. + auto* var = vars[i]; + auto* proper = previous[i]; + for (auto& other : breaks) { + assert(!isInUnreachable(other)); + auto& curr = *(other[i]); + if (curr != *var && curr != *proper) { + // A phi would be necessary here. + needPhi = true; + break; + } + } + if (needPhi) { + // Nothing to do - leave the Vars, the loop phis are + // unknown values to us. + } else { + // Undo the Var for this local: In every new node added for + // the loop body, replace references to the Var with the + // previous value (the value that is all we need instead of a phi). + for (auto j = firstNodeFromLoop; j < nodes.size(); j++) { + for (auto*& value : nodes[j].get()->values) { + if (value == var) { + value = proper; + } + } + } + // Also undo in the current local state, which is flowing out + // of the loop. 
+ for (auto*& node : locals) { + if (node == var) { + node = proper; + } + } + } + } + return &bad; + } + Node* doVisitBreak(Break* curr) { + if (!isInUnreachable()) { + breakStates[curr->name].push_back(locals); + } + if (!curr->condition) { + setInUnreachable(); + } else { + visit(curr->condition); + } + return &bad; + } + Node* doVisitSwitch(Switch* curr) { + visit(curr->condition); + if (!isInUnreachable()) { + std::unordered_set<Name> targets; + for (auto target : curr->targets) { + targets.insert(target); + } + targets.insert(curr->default_); + for (auto target : targets) { + breakStates[target].push_back(locals); + } + } + setInUnreachable(); + return &bad; + } + Node* doVisitGetLocal(GetLocal* curr) { + if (!isRelevantLocal(curr->index) || isInUnreachable()) { + return &bad; + } + // We now know which IR node this get refers to + auto* node = locals[curr->index]; + return node; + } + Node* doVisitSetLocal(SetLocal* curr) { + if (!isRelevantLocal(curr->index) || isInUnreachable()) { + return &bad; + } + assert(isConcreteType(curr->value->type)); + sets.push_back(curr); + expressionParentMap[curr] = parent; + expressionParentMap[curr->value] = curr; + // Set the current node in the local state. + auto* node = visit(curr->value); + locals[curr->index] = setNodeMap[curr] = node; + // If we created a new node (and not just did a get of a set, which + // passes around an existing node), mark its parent. + if (nodeParentMap.find(node) == nodeParentMap.end()) { + nodeParentMap[node] = curr; + } + return &bad; + } + Node* doVisitConst(Const* curr) { + return makeConst(curr->value); + } + Node* doVisitUnary(Unary* curr) { + // First, check if we support this op. + switch (curr->op) { + case ClzInt32: + case ClzInt64: + case CtzInt32: + case CtzInt64: + case PopcntInt32: + case PopcntInt64: { + // These are ok as-is. + // Check if our child is supported. 
+ auto* value = expandFromI1(visit(curr->value), curr); + if (value->isBad()) return value; + // Great, we are supported! + auto* ret = addNode(Node::makeExpr(curr, curr)); + ret->addValue(value); + return ret; + } + case EqZInt32: + case EqZInt64: { + // These can be implemented using a binary. + // Check if our child is supported. + auto* value = expandFromI1(visit(curr->value), curr); + if (value->isBad()) return value; + // Great, we are supported! + return makeZeroComp(value, true, curr); + } + default: { + // Anything else is an unknown value. + return makeVar(curr->type); + } + } + } + Node* doVisitBinary(Binary *curr) { + // First, check if we support this op. + switch (curr->op) { + case AddInt32: + case AddInt64: + case SubInt32: + case SubInt64: + case MulInt32: + case MulInt64: + case DivSInt32: + case DivSInt64: + case DivUInt32: + case DivUInt64: + case RemSInt32: + case RemSInt64: + case RemUInt32: + case RemUInt64: + case AndInt32: + case AndInt64: + case OrInt32: + case OrInt64: + case XorInt32: + case XorInt64: + case ShlInt32: + case ShlInt64: + case ShrUInt32: + case ShrUInt64: + case ShrSInt32: + case ShrSInt64: + case RotLInt32: + case RotLInt64: + case RotRInt32: + case RotRInt64: + case EqInt32: + case EqInt64: + case NeInt32: + case NeInt64: + case LtSInt32: + case LtSInt64: + case LtUInt32: + case LtUInt64: + case LeSInt32: + case LeSInt64: + case LeUInt32: + case LeUInt64: { + // These are ok as-is. + // Check if our children are supported. + auto* left = expandFromI1(visit(curr->left), curr); + if (left->isBad()) return left; + auto* right = expandFromI1(visit(curr->right), curr); + if (right->isBad()) return right; + // Great, we are supported! 
+ auto* ret = addNode(Node::makeExpr(curr, curr)); + ret->addValue(left); + ret->addValue(right); + return ret; + } + case GtSInt32: + case GtSInt64: + case GeSInt32: + case GeSInt64: + case GtUInt32: + case GtUInt64: + case GeUInt32: + case GeUInt64: { + // These need to be flipped as Souper does not support redundant ops. + // We emit the mirrored comparison with the operands swapped, so + // a > b becomes b < a, and a >= b becomes b <= a. + Builder builder(*module); + BinaryOp opposite; + switch (curr->op) { + case GtSInt32: opposite = LtSInt32; break; + case GtSInt64: opposite = LtSInt64; break; + case GeSInt32: opposite = LeSInt32; break; + case GeSInt64: opposite = LeSInt64; break; + case GtUInt32: opposite = LtUInt32; break; + case GtUInt64: opposite = LtUInt64; break; + case GeUInt32: opposite = LeUInt32; break; + case GeUInt64: opposite = LeUInt64; break; + default: WASM_UNREACHABLE(); + } + auto* ret = visitBinary(builder.makeBinary(opposite, curr->right, curr->left)); + // We just created a new binary node, but we need to set the origin properly + // to the original. If a child was unsupported we received a bad node + // instead of a new one, and that node must be left untouched. + if (!ret->isBad()) { + ret->origin = curr; + } + return ret; + } + default: { + // Anything else is an unknown value. + return makeVar(curr->type); + } + } + } + Node* doVisitSelect(Select* curr) { + auto* ifTrue = expandFromI1(visit(curr->ifTrue), curr); + if (ifTrue->isBad()) return ifTrue; + auto* ifFalse = expandFromI1(visit(curr->ifFalse), curr); + if (ifFalse->isBad()) return ifFalse; + auto* condition = ensureI1(visit(curr->condition), curr); + if (condition->isBad()) return condition; + // Great, we are supported! + auto* ret = addNode(Node::makeExpr(curr, curr)); + ret->addValue(condition); + ret->addValue(ifTrue); + ret->addValue(ifFalse); + return ret; + } + Node* doVisitUnreachable(Unreachable* curr) { + setInUnreachable(); + return &bad; + } + Node* doVisitDrop(Drop* curr) { + visit(curr->value); + // We need to know that the value's parent is a drop, indicating + // the value is not actually used here. 
+ expressionParentMap[curr->value] = curr; + return &bad; + } + Node* doVisitGeneric(Expression* curr) { + // Just need to visit the nodes so we note all the gets + for (auto* child : ChildIterator(curr)) { + visit(child); + } + return makeVar(curr->type); + } + + // Helpers. + + bool isRelevantType(wasm::Type type) { + return isIntegerType(type); + } + + bool isRelevantLocal(Index index) { + return isRelevantType(func->getLocalType(index)); + } + + // Merge local state for an if, also creating a block and conditions. + void mergeIf(Locals& aState, Locals& bState, Node* condition, Expression* expr, Locals& out) { + // Create the conditions (if we can). + Node* ifTrue; + Node* ifFalse; + if (!condition->isBad()) { + // Generate boolean (i1 returning) conditions for the two branches. + auto& conditions = expressionConditionMap[expr]; + ifTrue = ensureI1(condition, nullptr); + conditions.push_back(ifTrue); + ifFalse = makeZeroComp(condition, true, nullptr); + conditions.push_back(ifFalse); + } else { + ifTrue = ifFalse = &bad; + } + // Finally, merge the state with that block. TODO optimize + std::vector<FlowState> states; + if (!isInUnreachable(aState)) { + states.emplace_back(aState, ifTrue); + } + if (!isInUnreachable(bState)) { + states.emplace_back(bState, ifFalse); + } + merge(states, out); + } + + // Merge local state for a block + void mergeBlock(std::vector<Locals>& localses, Locals& out) { + // TODO: conditions + std::vector<FlowState> states; + for (auto& locals : localses) { + states.emplace_back(locals, &bad); + } + merge(states, out); + } + + // Merge local state for multiple control flow paths, creating phis as needed. + void merge(std::vector<FlowState>& states, Locals& out) { + // We should only receive reachable states. + for (auto& state : states) { + assert(!isInUnreachable(state.locals)); + } + Index numStates = states.size(); + if (numStates == 0) { + // We were unreachable, and still are. 
+ assert(isInUnreachable()); + return; + } + // We may have just become reachable, if we were not before. + setInReachable(); + // Just one thing to merge is trivial. + if (numStates == 1) { + out = states[0].locals; + return; + } + // We create a block if we need one. + Index numLocals = func->getNumLocals(); + Node* block = nullptr; + for (Index i = 0; i < numLocals; i++) { + if (!isRelevantType(func->getLocalType(i))) continue; + // Process the inputs. If any is bad, the phi is bad. + bool bad = false; + for (auto& state : states) { + auto* node = state.locals[i]; + if (node->isBad()) { + bad = true; + out[i] = node; + break; + } + } + if (bad) continue; + // Nothing is bad, proceed. + Node* first = nullptr; + for (auto& state : states) { + if (!first) { + first = out[i] = state.locals[i]; + } else if (state.locals[i] != first) { + // We need to actually merge some stuff. + if (!block) { + block = addNode(Node::makeBlock()); + for (Index index = 0; index < numStates; index++) { + auto* condition = states[index].condition; + if (!condition->isBad()) { + condition = addNode(Node::makeCond(block, index, condition)); + } + block->addValue(condition); + } + } + auto* phi = addNode(Node::makePhi(block, i)); + for (auto& state : states) { + auto* value = expandFromI1(state.locals[i], nullptr); + phi->addValue(value); + } + out[i] = phi; + break; + } + } + } + } + + // If the node returns an i1, then we are called from a context that needs + // to use it normally as in wasm - extend it + Node* expandFromI1(Node* node, Expression* origin) { + if (!node->isBad() && node->returnsI1()) { + node = addNode(Node::makeZext(node, origin)); + } + return node; + } + + Node* ensureI1(Node* node, Expression* origin) { + if (!node->isBad() && !node->returnsI1()) { + node = makeZeroComp(node, false, origin); + } + return node; + } + + // Given a node representing something that is set_local'd, return + // the set. 
+ SetLocal* getSet(Node* node) { + auto iter = nodeParentMap.find(node); + if (iter == nodeParentMap.end()) return nullptr; + return iter->second->dynCast<SetLocal>(); + } + + // Given an expression, return the parent if such exists. + Expression* getParent(Expression* curr) { + auto iter = expressionParentMap.find(curr); + if (iter == expressionParentMap.end()) return nullptr; + return iter->second; + } + + // Given an expression, return the set for it if such exists. + SetLocal* getSet(Expression* curr) { + auto* parent = getParent(curr); + return parent ? parent->dynCast<SetLocal>() : nullptr; + } + + // Creates an expression that uses a node. Generally, a node represents + // a value in a local, so we create a get_local for it. + // An Expr node that is not a constant must be the value of a set_local, + // as otherwise there is no local to read it from. + Expression* makeUse(Node* node) { + Builder builder(*module); + if (node->isPhi()) { + // The index is the wasm local that we assign to when implementing + // the phi; get from there. + auto index = node->index; + return builder.makeGetLocal(index, func->getLocalType(index)); + } else if (node->isConst()) { + return builder.makeConst(node->expr->cast<Const>()->value); + } else if (node->isExpr()) { + // Find the set we are a value of. getSet returns null if this node + // was never recorded as the value of a set, so check that before + // reading the local index. + auto* set = getSet(node); + assert(set && "makeUse: Expr node is not the value of a set_local"); + auto index = set->index; + return builder.makeGetLocal(index, func->getLocalType(index)); + } else if (node->isZext()) { + // i1 zexts are a no-op for wasm + return makeUse(node->values[0]); + } else if (node->isVar()) { + // Nothing valid for us to read here. + // FIXME should we have a local index to get? 
+ return Builder(*module).makeConst(LiteralUtils::makeLiteralZero(node->wasmType)); + } else { + WASM_UNREACHABLE(); // TODO + } + } +}; + +} // namespace DataFlow + +} // namespace wasm + +#endif // wasm_dataflow_graph_h diff --git a/src/dataflow/node.h b/src/dataflow/node.h new file mode 100644 index 000000000..d6514588e --- /dev/null +++ b/src/dataflow/node.h @@ -0,0 +1,210 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// DataFlow IR is an SSA representation. It can be built from the main +// Binaryen IR. +// +// THe main initial use case was an IR that could easily be converted to +// Souper IR, and the design favors that. +// + +#ifndef wasm_dataflow_node_h +#define wasm_dataflow_node_h + +#include "wasm.h" + +namespace wasm { + +namespace DataFlow { + +// +// The core IR representation in DataFlow: a Node. +// +// We reuse the Binaryen IR as much as possible: when things are identical between +// the two IRs, we just create an Expr node, which stores the opcode and other +// details, and we can emit them to Souper by reading the Binaryen Expression. +// Other node types here are special things from Souper IR that we can't +// represent that way. +// +// * Souper comparisons return an i1. We extend them immediately if they are +// going to be used as i32s or i64s. 
+// * When we use an Expression node, we just use its immediate fields, like the +// op in a binary, alignment etc. in a load, etc. We don't look into the +// pointers to child nodes. Instead, the DataFlow IR has its own pointers +// directly to DataFlow children. In particular, this means that it's easy +// to create an Expression with the info you need and not care about linking +// it up to other Expressions. +// + +struct Node { + enum Type { + Var, // an unknown variable number (not to be confused with var/param/local in wasm) + Expr, // a value represented by a Binaryen Expression + Phi, // a phi from converging control flow + Cond, // a blockpc, representing one of the branchs for a Block + Block, // a source of phis + Zext, // zero-extend an i1 (from an op where Souper returns i1 but wasm does not, + // and so we need a special way to get back to an i32/i64 if we operate + // on that value instead of just passing it straight to Souper). + Bad // something we can't handle and should ignore + } type; + + Node(Type type) : type(type) {} + + // TODO: the others, if we need them + bool isVar() { return type == Var; } + bool isExpr() { return type == Expr; } + bool isPhi() { return type == Phi; } + bool isCond() { return type == Cond; } + bool isBlock() { return type == Block; } + bool isZext() { return type == Zext; } + bool isBad() { return type == Bad; } + + bool isConst() { return type == Expr && expr->is<Const>(); } + + union { + // For Var + wasm::Type wasmType; + // For Expr + Expression* expr; + // For Phi and Cond (the local index for phi, the block + // index for cond) + Index index; + }; + + // The wasm expression that we originate from (if such exists). A single + // wasm instruction may be turned into multiple dataflow IR nodes, and some + // nodes have no wasm origin (like phis). + Expression* origin = nullptr; + + // Extra list of related nodes. + // For Expr, these are the Nodes for the inputs to the expression (e.g. 
+ // a binary would have 2 in this vector here). + // For Phi, this is the block and then the list of values to pick from. + // For Cond, this is the block and node. + // For Block, this is the list of Conds. Note that that block does not + // depend on them - the Phis do, but we store them in the block so that + // we can avoid duplication. + // For Zext, this is the value we extend. + std::vector<Node*> values; + + // Constructors + static Node* makeVar(wasm::Type wasmType) { + Node* ret = new Node(Var); + ret->wasmType = wasmType; + return ret; + } + static Node* makeExpr(Expression* expr, Expression* origin) { + Node* ret = new Node(Expr); + ret->expr = expr; + ret->origin = origin; + return ret; + } + static Node* makePhi(Node* block, Index index) { + Node* ret = new Node(Phi); + ret->addValue(block); + ret->index = index; + return ret; + } + static Node* makeCond(Node* block, Index index, Node* node) { + Node* ret = new Node(Cond); + ret->addValue(block); + ret->index = index; + ret->addValue(node); + return ret; + } + static Node* makeBlock() { + Node* ret = new Node(Block); + return ret; + } + static Node* makeZext(Node* child, Expression* origin) { + Node* ret = new Node(Zext); + ret->addValue(child); + ret->origin = origin; + return ret; + } + static Node* makeBad() { + Node* ret = new Node(Bad); + return ret; + } + + // Helpers + + void addValue(Node* value) { + values.push_back(value); + } + Node* getValue(Index i) { + return values.at(i); + } + + // Gets the wasm type of the node. If there isn't a valid one, + // return unreachable. 
+ wasm::Type getWasmType() { + switch (type) { + case Var: return wasmType; + case Expr: return expr->type; + case Phi: return getValue(1)->getWasmType(); + case Zext: return getValue(0)->getWasmType(); + case Bad: return unreachable; + default: WASM_UNREACHABLE(); + } + } + + bool operator==(const Node& other) { + if (type != other.type) return false; + switch (type) { + case Var: + case Block: return this == &other; + case Expr: { + if (!ExpressionAnalyzer::equal(expr, other.expr)) { + return false; + } + break; + } + case Cond: if (index != other.index) return false; + default: {} + } + if (values.size() != other.values.size()) return false; + for (Index i = 0; i < values.size(); i++) { + if (*(values[i]) != *(other.values[i])) return false; + } + return true; + } + + bool operator!=(const Node& other) { + return !(*this == other); + } + + // As mentioned above, comparisons return i1. This checks + // if an operation is of that sort. + bool returnsI1() { + if (isExpr()) { + if (auto* binary = expr->dynCast<Binary>()) { + return binary->isRelational(); + } else if (auto* unary = expr->dynCast<Unary>()) { + return unary->isRelational(); + } + } + return false; + } +}; + +} // namespace DataFlow + +} // namespace wasm + +#endif // wasm_dataflow_node diff --git a/src/dataflow/users.h b/src/dataflow/users.h new file mode 100644 index 000000000..ab9a2ccef --- /dev/null +++ b/src/dataflow/users.h @@ -0,0 +1,106 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// DataFlow IR is an SSA representation. It can be built from the main +// Binaryen IR. +// +// THe main initial use case was an IR that could easily be converted to +// Souper IR, and the design favors that. +// + +#ifndef wasm_dataflow_users_h +#define wasm_dataflow_users_h + +#include "dataflow/graph.h" + +namespace wasm { + +namespace DataFlow { + +// Calculates the users of each node. +// users[x] = { y, z, .. } +// where y, z etc. are nodes that use x, that is, x is in their +// values vector. +class Users { + typedef std::unordered_set<DataFlow::Node*> UserSet; + + std::unordered_map<DataFlow::Node*, UserSet> users; + +public: + void build(Graph& graph) { + for (auto& node : graph.nodes) { + for (auto* value : node->values) { + users[value].insert(node.get()); + } + } + } + + UserSet& getUsers(Node* node) { + auto iter = users.find(node); + if (iter == users.end()) { + static UserSet empty; // FIXME thread_local? + return empty; + } + return iter->second; + } + + Index getNumUses(Node* node) { + auto& users = getUsers(node); + // A user may have more than one use + Index numUses = 0; + for (auto* user : users) { +#ifndef NDEBUG + bool found = false; +#endif + for (auto* value : user->values) { + if (value == node) { + numUses++; +#ifndef NDEBUG + found = true; +#endif + } + } + assert(found); + } + return numUses; + } + + // Stops using all the values of this node. Called when a node is being + // removed. + void stopUsingValues(Node* node) { + for (auto* value : node->values) { + auto& users = getUsers(value); + users.erase(node); + } + } + + // Adds a new user to a node. Called when we add or change a value of a node. + void addUser(Node* node, Node* newUser) { + users[node].insert(newUser); + } + + // Remove all uses of a node. Called when a node is being removed. 
+ void removeAllUsesOf(Node* node) { + users.erase(node); + } +}; + +} // namespace DataFlow + +} // namespace wasm + +#endif // wasm_dataflow_users diff --git a/src/dataflow/utils.h b/src/dataflow/utils.h new file mode 100644 index 000000000..5328e6ab7 --- /dev/null +++ b/src/dataflow/utils.h @@ -0,0 +1,145 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// DataFlow IR is an SSA representation. It can be built from the main +// Binaryen IR. +// +// THe main initial use case was an IR that could easily be converted to +// Souper IR, and the design favors that. 
+// + +#ifndef wasm_dataflow_utils_h +#define wasm_dataflow_utils_h + +#include "wasm.h" +#include "wasm-printing.h" +#include "dataflow/node.h" + +namespace wasm { + +namespace DataFlow { + +inline std::ostream& dump(Node* node, std::ostream& o, size_t indent = 0) { + auto doIndent = [&]() { + for (size_t i = 0; i < indent; i++) o << ' '; + }; + doIndent(); + o << '[' << node << ' '; + switch (node->type) { + case Node::Type::Var: o << "var " << printType(node->wasmType) << ' ' << node; break; + case Node::Type::Expr: { + o << "expr "; + WasmPrinter::printExpression(node->expr, o, true); + break; + } + case Node::Type::Phi: o << "phi " << node->index; break; + case Node::Type::Cond: o << "cond " << node->index; break; + case Node::Type::Block: { + // don't print the conds - they would recurse + o << "block (" << node->values.size() << " conds)]\n"; + return o; + } + case Node::Type::Zext: o << "zext"; break; + case Node::Type::Bad: o << "bad"; break; + default: WASM_UNREACHABLE(); + } + if (!node->values.empty()) { + o << '\n'; + for (auto* value : node->values) { + dump(value, o, indent + 1); + } + doIndent(); + } + o << "] (origin: " << (void*)(node->origin) << ")\n"; + return o; +} + +inline std::ostream& dump(Graph& graph, std::ostream& o) { + for (auto& node : graph.nodes) { + o << "NODE " << node.get() << ": "; + dump(node.get(), o); + if (auto* set = graph.getSet(node.get())) { + o << " and that is set to local " << set->index << '\n'; + } + } + return o; +} + +// Checks if the inputs are all identical - something we could +// probably optimize. Returns false if irrelevant. 
+inline bool allInputsIdentical(Node* node) { + switch (node->type) { + case Node::Type::Expr: { + if (node->expr->is<Binary>()) { + return *(node->getValue(0)) == *(node->getValue(1)); + } else if (node->expr->is<Select>()) { + return *(node->getValue(1)) == *(node->getValue(2)); + } + break; + } + case Node::Type::Phi: { + auto* first = node->getValue(1); + // Check if any of the others are not equal + for (Index i = 2; i < node->values.size(); i++) { + auto* curr = node->getValue(i); + if (*first != *curr) { + return false; + } + } + return true; + } + default: {} + } + return false; +} + +// Checks if the inputs are all constant - something we could +// probably optimize. Returns false if irrelevant. +inline bool allInputsConstant(Node* node) { + switch (node->type) { + case Node::Type::Expr: { + if (node->expr->is<Unary>()) { + return node->getValue(0)->isConst(); + } else if (node->expr->is<Binary>()) { + return node->getValue(0)->isConst() && + node->getValue(1)->isConst(); + } else if (node->expr->is<Select>()) { + return node->getValue(0)->isConst() && + node->getValue(1)->isConst() && + node->getValue(2)->isConst(); + } + break; + } + case Node::Type::Phi: { + // Check if any of the others are not equal + for (Index i = 1; i < node->values.size(); i++) { + if (!node->getValue(i)->isConst()) { + return false; + } + } + return true; + } + default: {} + } + return false; +} + +} // namespace DataFlow + +} // namespace wasm + +#endif // wasm_dataflow_utils diff --git a/src/ir/utils.h b/src/ir/utils.h index 92bfcdab3..61d8917be 100644 --- a/src/ir/utils.h +++ b/src/ir/utils.h @@ -58,6 +58,7 @@ struct ExpressionAnalyzer { using ExprComparer = std::function<bool(Expression*, Expression*)>; static bool flexibleEqual(Expression* left, Expression* right, ExprComparer comparer); + // Compares two expressions for equivalence. 
static bool equal(Expression* left, Expression* right) { auto comparer = [](Expression* left, Expression* right) { return false; @@ -65,6 +66,19 @@ struct ExpressionAnalyzer { return flexibleEqual(left, right, comparer); } + // A shallow comparison, ignoring child nodes. + static bool shallowEqual(Expression* left, Expression* right) { + auto comparer = [left, right](Expression* currLeft, Expression* currRight) { + if (currLeft == left && currRight == right) { + // these are the ones we want to compare + return false; + } + // otherwise, don't do the comparison, we don't care + return true; + }; + return flexibleEqual(left, right, comparer); + } + // hash an expression, ignoring superficial details like specific internal names static uint32_t hash(Expression* curr); }; diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 47079e4a0..fc53a194f 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -9,6 +9,7 @@ SET(passes_SOURCES CodePushing.cpp CodeFolding.cpp ConstHoisting.cpp + DataFlowOpts.cpp DeadCodeElimination.cpp DuplicateFunctionElimination.cpp ExtractFunction.cpp @@ -47,6 +48,7 @@ SET(passes_SOURCES TrapMode.cpp SafeHeap.cpp SimplifyLocals.cpp + Souperify.cpp SpillPointers.cpp SSAify.cpp Untee.cpp diff --git a/src/passes/DataFlowOpts.cpp b/src/passes/DataFlowOpts.cpp new file mode 100644 index 000000000..05e975049 --- /dev/null +++ b/src/passes/DataFlowOpts.cpp @@ -0,0 +1,250 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Optimize using the DataFlow SSA IR. +// +// This needs 'flatten' to be run before it, and you should run full +// regular opts afterwards to clean up the flattening. For example, +// you might use it like this: +// +// --flatten --dfo -Os +// + +#include "wasm.h" +#include "pass.h" +#include "wasm-builder.h" +#include "ir/utils.h" +#include "dataflow/node.h" +#include "dataflow/graph.h" +#include "dataflow/users.h" +#include "dataflow/utils.h" + +namespace wasm { + +struct DataFlowOpts : public WalkerPass<PostWalker<DataFlowOpts>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new DataFlowOpts; } + + DataFlow::Users nodeUsers; + + // The optimization work left to do: nodes that we need to look at. + std::unordered_set<DataFlow::Node*> workLeft; + + DataFlow::Graph graph; + + void doWalkFunction(Function* func) { + // Build the data-flow IR. + graph.build(func, getModule()); + nodeUsers.build(graph); + // Propagate optimizations through the graph. + std::unordered_set<DataFlow::Node*> optimized; // which nodes we optimized + for (auto& node : graph.nodes) { + workLeft.insert(node.get()); // we should try to optimize each node + } + while (!workLeft.empty()) { + //std::cout << "\n\ndump before work iter\n"; + //dump(graph, std::cout); + auto iter = workLeft.begin(); + auto* node = *iter; + workLeft.erase(iter); + workOn(node); + } + // After updating the DataFlow IR, we can update the sets in + // the wasm. + // TODO: we also need phis, as a phi can flow directly into say + // a return or a call parameter. 
+ for (auto* set : graph.sets) { + auto* node = graph.setNodeMap[set]; + auto iter = optimized.find(node); + if (iter != optimized.end()) { + assert(node->isExpr()); // this is a set, where the node is defined + set->value = node->expr; + } + } + } + + void workOn(DataFlow::Node* node) { + if (node->isConst()) return; + // If there are no uses, there is no point to work. + if (nodeUsers.getNumUses(node) == 0) return; + // Optimize: Look for nodes that we can easily convert into + // something simpler. + // TODO: we can expressionify and run full normal opts on that, + // then copy the result if it's smaller. + if (node->isPhi() && DataFlow::allInputsIdentical(node)) { + // Note we don't need to check for effects when replacing, as in + // flattened IR expression children are get_locals or consts. + auto* value = node->getValue(1); + if (value->isConst()) { + replaceAllUsesWith(node, value); + } + } else if (node->isExpr() && DataFlow::allInputsConstant(node)) { + assert(!node->isConst()); + // If this is a concrete value (not e.g. an eqz of unreachable), + // it can definitely be precomputed into a constant. + if (isConcreteType(node->expr->type)) { + // This can be precomputed. + // TODO not just all-constant inputs? E.g. i32.mul of 0 and X. + optimizeExprToConstant(node); + } + } + } + + void optimizeExprToConstant(DataFlow::Node* node) { + assert(node->isExpr()); + assert(!node->isConst()); + //std::cout << "will optimize an Expr of all constant inputs. before" << '\n'; + //dump(node, std::cout); + auto* expr = node->expr; + // First, note that some of the expression's children may be + // get_locals that we inferred during SSA analysis as constant. + // We can apply those now. + for (Index i = 0; i < node->values.size(); i++) { + if (node->values[i]->isConst()) { + auto* currp = getIndexPointer(expr, i); + if (!(*currp)->is<Const>()) { + // Directly represent it as a constant. 
+ auto* c = node->values[i]->expr->dynCast<Const>(); + *currp = Builder(*getModule()).makeConst(c->value); + } + } + } + // Now we know that all our DataFlow inputs are constant, and all + // our Binaryen IR representations of them are constant too. RUn + // precompute, which will transform the expression into a constanat. + Module temp; + // XXX we should copy expr here, in principle, and definitely will need to + // when we do arbitrarily regenerated expressions + auto* func = Builder(temp).makeFunction("temp", std::vector<Type>{}, none, std::vector<Type>{}, expr); + PassRunner runner(&temp); + runner.setIsNested(true); + runner.add("precompute"); + runner.runOnFunction(func); + // Get the optimized thing + auto* result = func->body; + // It may not be a constant, e.g. 0 / 0 does not optimize to 0 + if (!result->is<Const>()) return; + // All good, copy it. + node->expr = Builder(*getModule()).makeConst(result->cast<Const>()->value); + assert(node->isConst()); + // We no longer have values, and so do not use anything. + nodeUsers.stopUsingValues(node); + node->values.clear(); + // Our contents changed, update our users. + replaceAllUsesWith(node, node); + } + + // Replaces all uses of a node with another value. This both modifies + // the DataFlow IR to make the other users point to this one, and + // updates the underlying Binaryen IR as well. + // This can be used to "replace" a node with itself, which makes sense + // when the node contents have changed and so the users must be updated. + void replaceAllUsesWith(DataFlow::Node* node, DataFlow::Node* with) { + // Const nodes are trivial to replace, but other stuff is trickier - + // in particular phis. + assert(with->isConst()); // TODO + // All the users should be worked on later, as we will update them. + auto& users = nodeUsers.getUsers(node); + for (auto* user : users) { + // Add the user to the work left to do, as we are modifying it. + workLeft.insert(user); + // `with` is getting another user. 
+ nodeUsers.addUser(with, user); + // Replacing in the DataFlow IR is simple - just replace it, + // in all the indexes it appears. + std::vector<Index> indexes; + for (Index i = 0; i < user->values.size(); i++) { + if (user->values[i] == node) { + user->values[i] = with; + indexes.push_back(i); + } + } + assert(!indexes.empty()); + // Replacing in the Binaryen IR requires more care + switch (user->type) { + case DataFlow::Node::Type::Expr: { + auto* expr = user->expr; + for (auto index : indexes) { + *(getIndexPointer(expr, index)) = graph.makeUse(with); + } + break; + } + case DataFlow::Node::Type::Phi: { + // Nothing to do: a phi is not in the Binaryen IR. + // If the entire phi can become a constant, that will be + // propagated when we process that phi later. + break; + } + case DataFlow::Node::Type::Cond: { + // Nothing to do: a cond is not in the Binaryen IR. + // If the cond input is a constant, that might indicate + // useful optimizations are possible, which perhaps we + // should look into TODO + break; + } + case DataFlow::Node::Type::Zext: { + // Nothing to do: a zext is not in the Binaryen IR. + // If the cond input is a constant, that might indicate + // useful optimizations are possible, which perhaps we + // should look into TODO + break; + } + default: WASM_UNREACHABLE(); + } + } + // No one is a user of this node after we replaced all the uses. + nodeUsers.removeAllUsesOf(node); + } + + // Gets a pointer to the expression pointer in an expression. + // That is, given an index in the values() vector, get an + // Expression** that we can assign to so as to modify it. 
+ Expression** getIndexPointer(Expression* expr, Index index) { + if (auto* unary = expr->dynCast<Unary>()) { + assert(index == 0); + return &unary->value; + } else if (auto* binary = expr->dynCast<Binary>()) { + if (index == 0) { + return &binary->left; + } else if (index == 1) { + return &binary->right; + } else { + WASM_UNREACHABLE(); + } + } else if (auto* select = expr->dynCast<Select>()) { + if (index == 0) { + return &select->condition; + } else if (index == 1) { + return &select->ifTrue; + } else if (index == 2) { + return &select->ifFalse; + } else { + WASM_UNREACHABLE(); + } + } else { + WASM_UNREACHABLE(); + } + } +}; + +Pass *createDataFlowOptsPass() { + return new DataFlowOpts(); +} + +} // namespace wasm + diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index a2f3b895c..1879b1fc8 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -68,7 +68,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination>> void addBreak(Name name) { // we normally have already reduced unreachable code into (unreachable) - // nodes, so we would not get to this function at all anyhow, the breaking + // nodes, so we would not get to this place at all anyhow, the breaking // instruction itself would be removed. However, an exception are things // like (block (result i32) (call $x) (unreachable)) , which has type i32 // despite not being exited. diff --git a/src/passes/Souperify.cpp b/src/passes/Souperify.cpp new file mode 100644 index 000000000..a54146d38 --- /dev/null +++ b/src/passes/Souperify.cpp @@ -0,0 +1,691 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Souperify - convert to Souper IR in text form. +// +// This needs 'flatten' to be run before it, as it assumes the IR is in +// flat form. You may also want to optimize a little, e.g. +// --flatten --simplify-locals-nonesting --reorder-locals +// (as otherwise flattening introduces many copies; we do ignore boring +// copies here, but they end up as identical LHSes). +// +// See https://github.com/google/souper/issues/323 +// +// TODO: +// * pcs and blockpcs for things other than ifs +// * Investigate 'inlining', adding in nodes through calls +// * Consider generalizing DataFlow IR for internal Binaryen use. +// * Automatic conversion of Binaryen IR opts to run on the DataFlow IR. +// This would subsume precompute-propagate, for example. Using DFIR we +// can "expand" the BIR into expressions that BIR opts can handle +// directly, without the need for *-propagate techniques. +// + +#include "wasm.h" +#include "pass.h" +#include "wasm-builder.h" +#include "ir/local-graph.h" +#include "ir/utils.h" +#include "dataflow/node.h" +#include "dataflow/graph.h" +#include "dataflow/utils.h" + +namespace wasm { + +static int debug() { + static char* str = getenv("BINARYEN_DEBUG_SOUPERIFY"); + static int ret = str ? atoi(str) : 0; + return ret; +} + +namespace DataFlow { + +// Internal helper to find all the uses of a set. +struct UseFinder { + // Gets a list of all the uses of an expression. 
As we are in flat IR, + // the expression must be the value of a set, and we seek the other sets + // (or rather, their values) that contain a get that uses that value. + // There may also be non-set uses of the value, for example in a drop + // or a return. We represent those with a nullptr, meaning "other". + std::vector<Expression*> getUses(Expression* origin, Graph& graph, LocalGraph& localGraph) { + if (debug() >= 2) { + std::cout << "getUses\n" << origin << '\n'; + } + std::vector<Expression*> ret; + auto* set = graph.getSet(origin); + if (!set) { + // If the parent is not a set (a drop, call, return, etc.) then + // it is not something we need to track. + return ret; + } + addSetUses(set, graph, localGraph, ret); + return ret; + } + + // There may be loops of sets with copies between them. + std::unordered_set<SetLocal*> seenSets; + + void addSetUses(SetLocal* set, Graph& graph, LocalGraph& localGraph, std::vector<Expression*>& ret) { + // If already handled, nothing to do here. + if (seenSets.count(set)) return; + seenSets.insert(set); + // Find all the uses of that set. + auto& gets = localGraph.setInfluences[set]; + if (debug() >= 2) { + std::cout << "addSetUses for " << set << ", " << gets.size() << " gets\n"; + } + for (auto* get : gets) { + // Each of these relevant gets is either + // (1) a child of a set, which we can track, or + // (2) not a child of a set, e.g., a call argument or such + auto& sets = localGraph.getInfluences[get]; // TODO: iterator + // In flat IR, each get can influence at most 1 set. + assert(sets.size() <= 1); + if (sets.size() == 0) { + // This get is not the child of a set. Check if it is a drop, + // otherwise it is an actual use, and so an external use. + auto* parent = graph.getParent(get); + if (parent && parent->is<Drop>()) { + // Just ignore it. + } else { + ret.push_back(nullptr); + if (debug() >= 2) { + std::cout << "add nullptr\n"; + } + } + } else { + // This get is the child of a set. 
+ auto* subSet = *sets.begin(); + // If this is a copy, we need to look through it: data-flow IR + // counts actual values, not copies, and in particular we need + // to look through the copies that implement a phi. + if (subSet->value == get) { + // Indeed a copy. + // TODO: this could be optimized and done all at once beforehand. + addSetUses(subSet, graph, localGraph, ret); + } else { + // Not a copy. + auto* value = subSet->value; + ret.push_back(value); + if (debug() >= 2) { + std::cout << "add a value\n" << value << '\n'; + } + } + } + } + } +}; + +// Generates a trace: all the information to generate a Souper LHS +// for a specific set_local whose value we want to infer. +struct Trace { + Graph& graph; + Node* toInfer; + // Nodes we should exclude from being children of traces (but they + // may be the root we try to infer. + std::unordered_set<Node*>& excludeAsChildren; + + // A limit on how deep we go - we don't want to create arbitrarily + // large traces. + size_t depthLimit = 10; + size_t totalLimit = 30; + + bool bad = false; + std::vector<Node*> nodes; + std::unordered_set<Node*> addedNodes; + std::vector<Node*> pathConditions; + // When we need to (like when the depth is too deep), we replace + // expressions with other expressions, and track them here. + std::unordered_map<Node*, std::unique_ptr<Node>> replacements; + // The nodes that have additional external uses (only computed + // for the "work" nodes, not the descriptive nodes arriving for + // path conditions). + std::unordered_set<Node*> hasExternalUses; + // We add path conditions after the "work". We collect them here + // and then go through them at the proper time. + std::vector<Node*> conditionsToAdd; + // Whether we are at the adding-conditions stage (i.e., post + // adding the "work"). + bool addingConditions = false; + // The local information graph. Used to check if a node has external uses. 
+ LocalGraph& localGraph; + + Trace(Graph& graph, Node* toInfer, std::unordered_set<Node*>& excludeAsChildren, LocalGraph& localGraph) : graph(graph), toInfer(toInfer), excludeAsChildren(excludeAsChildren), localGraph(localGraph) { + if (debug() >= 2) { + std::cout << "\nstart a trace (in " << graph.func->name << ")\n"; + } + // Check if there is a depth limit override + auto* depthLimitStr = getenv("BINARYEN_SOUPERIFY_DEPTH_LIMIT"); + if (depthLimitStr) { + depthLimit = atoi(depthLimitStr); + } + auto* totalLimitStr = getenv("BINARYEN_SOUPERIFY_TOTAL_LIMIT"); + if (totalLimitStr) { + totalLimit = atoi(totalLimitStr); + } + // Pull in all the dependencies, starting from the value itself. + add(toInfer, 0); + if (bad) return; + // If we are trivial before adding pcs, we are still trivial, and + // can ignore this. + auto sizeBeforePathConditions = nodes.size(); + // No input is uninteresting + if (sizeBeforePathConditions == 0) { + bad = true; + return; + } + // Just a var is uninteresting. TODO: others too? + if (sizeBeforePathConditions == 1 && nodes[0]->isVar()) { + bad = true; + return; + } + // Before adding the path conditions, we can now compute the + // actual number of uses of "work" nodes, the real computation done + // here and that we hope to replace, as opposed to path condition + // computation which is only descriptive and helps optimization of + // the work. + findExternalUses(); + // We can now add conditions. + addingConditions = true; + for (auto* condition : conditionsToAdd) { + add(condition, 0); + } + // Add in path conditions based on the location of this node: e.g. + // if it is inside an if's true branch, we can add a path-condition + // for that. + auto iter = graph.nodeParentMap.find(toInfer); + if (iter != graph.nodeParentMap.end()) { + addPath(toInfer, iter->second); + } + } + + Node* add(Node* node, size_t depth) { + depth++; + // If replaced, return the replacement. 
+ auto iter = replacements.find(node); + if (iter != replacements.end()) { + return iter->second.get(); + } + // If already added, nothing more to do. + if (addedNodes.find(node) != addedNodes.end()) { + return node; + } + switch (node->type) { + case Node::Type::Var: { + break; // nothing more to add + } + case Node::Type::Expr: { + // If this is a Const, it's not an instruction - nothing to add, + // it's just a value. + if (node->expr->is<Const>()) { + return node; + } + // If we've gone too deep, emit a var instead. + // Do the same if this is a node we should exclude from traces. + if (depth >= depthLimit || nodes.size() >= totalLimit || + (node != toInfer && excludeAsChildren.find(node) != excludeAsChildren.end())) { + auto type = node->getWasmType(); + assert(isConcreteType(type)); + auto* var = Node::makeVar(type); + replacements[node] = std::unique_ptr<Node>(var); + node = var; + break; + } + // Add the dependencies. + assert(!node->expr->is<GetLocal>()); + for (Index i = 0; i < node->values.size(); i++) { + add(node->getValue(i), depth); + } + break; + } + case Node::Type::Phi: { + auto* block = add(node->getValue(0), depth); + assert(block); + auto size = block->values.size(); + // First, add the conditions for the block + for (Index i = 0; i < size; i++) { + // a condition may be bad, but conditions are not necessary - + // we can proceed without the extra condition information + auto* condition = block->getValue(i); + if (!condition->isBad()) { + if (!addingConditions) { + // Too early, queue it for later. 
+ conditionsToAdd.push_back(condition); + } else { + add(condition, depth); + } + } + } + // Then, add the phi values + for (Index i = 1; i < size + 1; i++) { + add(node->getValue(i), depth); + } + break; + } + case Node::Type::Cond: { + add(node->getValue(0), depth); // add the block + add(node->getValue(1), depth); // add the node + break; + } + case Node::Type::Block: { + break; // nothing more to add + } + case Node::Type::Zext: { + add(node->getValue(0), depth); + break; + } + case Node::Type::Bad: { + bad = true; + return nullptr; + } + default: WASM_UNREACHABLE(); + } + // Assert on no cycles + assert(addedNodes.find(node) == addedNodes.end()); + nodes.push_back(node); + addedNodes.insert(node); + return node; + } + + void addPath(Node* node, Expression* curr) { + // We track curr and parent, which are always in the state of parent + // being the parent of curr. + auto* parent = graph.expressionParentMap.at(curr); + while (parent) { + auto iter = graph.expressionConditionMap.find(parent); + if (iter != graph.expressionConditionMap.end()) { + // Given the block, add a proper path-condition + addPathTo(parent, curr, iter->second); + } + curr = parent; + parent = graph.expressionParentMap.at(parent); + } + } + + // curr is a child of parent, and parent has a Block which we are + // give as 'node'. Add a path condition for reaching the child. + void addPathTo(Expression* parent, Expression* curr, std::vector<Node*> conditions) { + if (auto* iff = parent->dynCast<If>()) { + Index index; + if (curr == iff->ifTrue) { + index = 0; + } else if (curr == iff->ifFalse) { + index = 1; + } else { + WASM_UNREACHABLE(); + } + auto* condition = conditions[index]; + // Add the condition itself as an instruction in the trace - + // the pc uses it as its input. + add(condition, 0); + // Add it as a pc, which we will emit directly. 
+ pathConditions.push_back(condition); + } else { + WASM_UNREACHABLE(); + } + } + + bool isBad() { + return bad; + } + + static bool isTraceable(Node* node) { + if (!node->origin) { + // Ignore artificial etc. nodes. + // TODO: perhaps require all the nodes for an origin appear, so we + // don't try to compute an internal part of one, like the + // extra artificial != 0 of a select? + return false; + } + if (node->isExpr()) { + // Consider only the simple computational nodes. + auto* expr = node->expr; + return expr->is<Unary>() || expr->is<Binary>() || expr->is<Select>(); + } + return false; + } + + void findExternalUses() { + // Find all the wasm code represented in this trace. + std::unordered_set<Expression*> origins; + for (auto& node : nodes) { + if (auto* origin = node->origin) { + if (debug() >= 2) { + std::cout << "note origin " << origin << '\n'; + } + origins.insert(origin); + } + } + for (auto& node : nodes) { + if (node == toInfer) continue; + if (auto* origin = node->origin) { + auto uses = UseFinder().getUses(origin, graph, localGraph); + for (auto* use : uses) { + // A non-set use (a drop or return etc.) is definitely external. + // Otherwise, check if internal or external. + if (use == nullptr || origins.count(use) == 0) { + if (debug() >= 2) { + std::cout << "found external use for\n"; + dump(node, std::cout); + std::cout << " due to " << use << '\n'; + } + hasExternalUses.insert(node); + break; + } + } + } + } + } +}; + +// Emits a trace, which is basically a Souper LHS. +struct Printer { + Graph& graph; + Trace& trace; + + // Each Node in a trace has an index, from 0. + std::unordered_map<Node*, Index> indexing; + + bool printedHasExternalUses = false; + + Printer(Graph& graph, Trace& trace) : graph(graph), trace(trace) { + std::cout << "\n; start LHS (in " << graph.func->name << ")\n"; + // Index the nodes. 
+ for (auto* node : trace.nodes) { + if (!node->isCond()) { // pcs and blockpcs are not instructions and do not need to be indexed + auto index = indexing.size(); + indexing[node] = index; + } + } + // Print them out. + for (auto* node : trace.nodes) { + print(node); + } + // Print out pcs. + for (auto* condition : trace.pathConditions) { + printPathCondition(condition); + } + + // Finish up + std::cout << "infer %" << indexing[trace.toInfer] << "\n\n"; + } + + Node* getMaybeReplaced(Node* node) { + auto iter = trace.replacements.find(node); + if (iter != trace.replacements.end()) { + return iter->second.get(); + } + return node; + } + + void print(Node* node) { + // The node may have been replaced during trace building, if so then + // print the proper replacement. + node = getMaybeReplaced(node); + assert(node); + switch (node->type) { + case Node::Type::Var: { + std::cout << "%" << indexing[node] << ":" << printType(node->wasmType) << " = var"; + break; // nothing more to add + } + case Node::Type::Expr: { + if (debug()) { + std::cout << "; "; + WasmPrinter::printExpression(node->expr, std::cout, true); + std::cout << '\n'; + } + std::cout << "%" << indexing[node] << " = "; + printExpression(node); + break; + } + case Node::Type::Phi: { + auto* block = node->getValue(0); + auto size = block->values.size(); + std::cout << "%" << indexing[node] << " = phi %" << indexing[block]; + for (Index i = 1; i < size + 1; i++) { + std::cout << ", "; + printInternal(node->getValue(i)); + } + break; + } + case Node::Type::Cond: { + std::cout << "blockpc %" << indexing[node->getValue(0)] << ' ' << node->index << ' '; + printInternal(node->getValue(1)); + std::cout << " 1:i1"; + break; + } + case Node::Type::Block: { + std::cout << "%" << indexing[node] << " = block " << node->values.size(); + break; + } + case Node::Type::Zext: { + auto* child = node->getValue(0); + std::cout << "%" << indexing[node] << ':' << printType(child->getWasmType()); + std::cout << " = zext "; + 
printInternal(child); + break; + } + case Node::Type::Bad: { + std::cout << "!!!BAD!!!"; + WASM_UNREACHABLE(); + } + default: WASM_UNREACHABLE(); + } + if (node->isExpr() || node->isPhi()) { + if (node->origin != trace.toInfer->origin && trace.hasExternalUses.count(node) > 0) { + std::cout << " (hasExternalUses)"; + printedHasExternalUses = true; + } + } + std::cout << '\n'; + if (debug() && (node->isExpr() || node->isPhi())) { + warnOnSuspiciousValues(node); + } + } + + void print(Literal value) { + std::cout << value.getInteger() << ':' << printType(value.type); + } + + void printInternal(Node* node) { + node = getMaybeReplaced(node); + assert(node); + if (node->isConst()) { + print(node->expr->cast<Const>()->value); + } else { + std::cout << "%" << indexing[node]; + } + } + + // Emit an expression + + void printExpression(Node* node) { + assert(node->isExpr()); + // TODO use a Visitor here? + auto* curr = node->expr; + if (auto* c = curr->dynCast<Const>()) { + print(c->value); + } else if (auto* unary = curr->dynCast<Unary>()) { + switch (unary->op) { + case ClzInt32: + case ClzInt64: std::cout << "ctlz"; break; + case CtzInt32: + case CtzInt64: std::cout << "cttz"; break; + case PopcntInt32: + case PopcntInt64: std::cout << "ctpop"; break; + default: WASM_UNREACHABLE(); + } + std::cout << ' '; + auto* value = node->getValue(0); + printInternal(value); + } else if (auto* binary = curr->dynCast<Binary>()) { + switch (binary->op) { + case AddInt32: + case AddInt64: std::cout << "add"; break; + case SubInt32: + case SubInt64: std::cout << "sub"; break; + case MulInt32: + case MulInt64: std::cout << "mul"; break; + case DivSInt32: + case DivSInt64: std::cout << "sdiv"; break; + case DivUInt32: + case DivUInt64: std::cout << "udiv"; break; + case RemSInt32: + case RemSInt64: std::cout << "srem"; break; + case RemUInt32: + case RemUInt64: std::cout << "urem"; break; + case AndInt32: + case AndInt64: std::cout << "and"; break; + case OrInt32: + case OrInt64: std::cout 
<< "or"; break; + case XorInt32: + case XorInt64: std::cout << "xor"; break; + case ShlInt32: + case ShlInt64: std::cout << "shl"; break; + case ShrUInt32: + case ShrUInt64: std::cout << "lshr"; break; + case ShrSInt32: + case ShrSInt64: std::cout << "ashr"; break; + case RotLInt32: + case RotLInt64: std::cout << "rotl"; break; + case RotRInt32: + case RotRInt64: std::cout << "rotr"; break; + case EqInt32: + case EqInt64: std::cout << "eq"; break; + case NeInt32: + case NeInt64: std::cout << "ne"; break; + case LtSInt32: + case LtSInt64: std::cout << "slt"; break; + case LtUInt32: + case LtUInt64: std::cout << "ult"; break; + case LeSInt32: + case LeSInt64: std::cout << "sle"; break; + case LeUInt32: + case LeUInt64: std::cout << "ule"; break; + default: WASM_UNREACHABLE(); + } + std::cout << ' '; + auto* left = node->getValue(0); + printInternal(left); + std::cout << ", "; + auto* right = node->getValue(1); + printInternal(right); + } else if (curr->is<Select>()) { + std::cout << "select "; + printInternal(node->getValue(0)); + std::cout << ", "; + printInternal(node->getValue(1)); + std::cout << ", "; + printInternal(node->getValue(2)); + } else { + WASM_UNREACHABLE(); + } + } + + void printPathCondition(Node* condition) { + std::cout << "pc "; + printInternal(condition); + std::cout << " 1:i1\n"; + } + + // Checks if a value looks suspiciously optimizable. + void warnOnSuspiciousValues(Node* node) { + assert(debug()); + // If the node has no uses, it's not interesting enough to be + // suspicious. TODO + // If an input was replaced with a var, then we should not + // look into it, it's not suspiciously trivial. + for (auto* value : node->values) { + if (value != getMaybeReplaced(value)) { + return; + } + } + if (allInputsIdentical(node)) { + std::cout << "^^ suspicious identical inputs! missing optimization in " << graph.func->name << "? ^^\n"; + return; + } + if (!node->isPhi() && allInputsConstant(node)) { + std::cout << "^^ suspicious constant inputs! 
missing optimization in " << graph.func->name << "? ^^\n"; + return; + } + } +}; + +} // namespace DataFlow + +struct Souperify : public WalkerPass<PostWalker<Souperify>> { + // Not parallel, for now - could parallelize and combine outputs at the end. + // If Souper is thread-safe, we could also run it in parallel. + + bool singleUseOnly; + + Souperify(bool singleUseOnly) : singleUseOnly(singleUseOnly) {} + + void doWalkFunction(Function* func) { + std::cout << "\n; function: " << func->name << '\n'; + // Build the data-flow IR. + DataFlow::Graph graph; + graph.build(func, getModule()); + if (debug() >= 2) dump(graph, std::cout); + // Build the local graph data structure. + LocalGraph localGraph(func); + localGraph.computeInfluences(); + // If we only want single-use nodes, exclude all the others. + std::unordered_set<DataFlow::Node*> excludeAsChildren; + if (singleUseOnly) { + for (auto& nodePtr : graph.nodes) { + auto* node = nodePtr.get(); + if (node->origin) { + // TODO: work for identical origins could be saved + auto uses = DataFlow::UseFinder().getUses(node->origin, graph, localGraph); + if (debug() >= 2) { + std::cout << "following node has " << uses.size() << " uses\n"; + dump(node, std::cout); + } + if (uses.size() > 1) { + excludeAsChildren.insert(node); + } + } + } + } + // Emit possible traces. 
+ for (auto& nodePtr : graph.nodes) { + auto* node = nodePtr.get(); + // Trace + if (DataFlow::Trace::isTraceable(node)) { + DataFlow::Trace trace(graph, node, excludeAsChildren, localGraph); + if (!trace.isBad()) { + DataFlow::Printer printer(graph, trace); + if (singleUseOnly) { + assert(!printer.printedHasExternalUses); + } + } + } + } + } +}; + +Pass *createSouperifyPass() { + return new Souperify(false); +} + +Pass *createSouperifySingleUsePass() { + return new Souperify(true); +} + +} // namespace wasm + diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 97151f847..c8d02f445 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -72,6 +72,7 @@ void PassRegistry::registerPasses() { registerPass("code-folding", "fold code, merging duplicates", createCodeFoldingPass); registerPass("const-hoisting", "hoist repeated constants to a local", createConstHoistingPass); registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass); + registerPass("dfo", "optimizes using the DataFlow SSA IR", createDataFlowOptsPass); registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass); registerPass("extract-function", "leaves just one function (useful for debugging)", createExtractFunctionPass); registerPass("flatten", "flattens out code, removing nesting", createFlattenPass); @@ -117,9 +118,11 @@ void PassRegistry::registerPasses() { registerPass("safe-heap", "instrument loads and stores to check for invalid behavior", createSafeHeapPass); registerPass("simplify-locals", "miscellaneous locals-related optimizations", createSimplifyLocalsPass); registerPass("simplify-locals-nonesting", "miscellaneous locals-related optimizations (no nesting at all; preserves flatness)", createSimplifyLocalsNoNestingPass); - registerPass("simplify-locals-notee", "miscellaneous locals-related optimizations", createSimplifyLocalsNoTeePass); - registerPass("simplify-locals-nostructure", "miscellaneous 
locals-related optimizations", createSimplifyLocalsNoStructurePass); - registerPass("simplify-locals-notee-nostructure", "miscellaneous locals-related optimizations", createSimplifyLocalsNoTeeNoStructurePass); + registerPass("simplify-locals-notee", "miscellaneous locals-related optimizations (no tees)", createSimplifyLocalsNoTeePass); + registerPass("simplify-locals-nostructure", "miscellaneous locals-related optimizations (no structure)", createSimplifyLocalsNoStructurePass); + registerPass("simplify-locals-notee-nostructure", "miscellaneous locals-related optimizations (no tees or structure)", createSimplifyLocalsNoTeeNoStructurePass); + registerPass("souperify", "emit Souper IR in text form", createSouperifyPass); + registerPass("souperify-single-use", "emit Souper IR in text form (single-use nodes only)", createSouperifySingleUsePass); registerPass("spill-pointers", "spill pointers to the C stack (useful for Boehm-style GC)", createSpillPointersPass); registerPass("ssa", "ssa-ify variables so that they have a single assignment", createSSAifyPass); registerPass("trap-mode-clamp", "replace trapping operations with clamping semantics", createTrapModeClamp); diff --git a/src/passes/passes.h b/src/passes/passes.h index 1e26dc777..e853c040a 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -27,6 +27,7 @@ Pass* createCoalesceLocalsWithLearningPass(); Pass* createCodeFoldingPass(); Pass* createCodePushingPass(); Pass* createConstHoistingPass(); +Pass* createDataFlowOptsPass(); Pass* createDeadCodeEliminationPass(); Pass* createDuplicateFunctionEliminationPass(); Pass* createExtractFunctionPass(); @@ -76,6 +77,8 @@ Pass* createSimplifyLocalsNoNestingPass(); Pass* createSimplifyLocalsNoTeePass(); Pass* createSimplifyLocalsNoStructurePass(); Pass* createSimplifyLocalsNoTeeNoStructurePass(); +Pass* createSouperifyPass(); +Pass* createSouperifySingleUsePass(); Pass* createSpillPointersPass(); Pass* createSSAifyPass(); Pass* createTrapModeClamp(); diff --git 
a/src/wasm-traversal.h b/src/wasm-traversal.h index 0c5088917..0c775e872 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -213,7 +213,7 @@ struct OverriddenVisitor { // separate visit* per node template<typename SubType, typename ReturnType = void> -struct UnifiedExpressionVisitor : public Visitor<SubType> { +struct UnifiedExpressionVisitor : public Visitor<SubType, ReturnType> { // called on each node ReturnType visitExpression(Expression* curr) { return ReturnType(); } |