diff options
author | Alon Zakai <alonzakai@gmail.com> | 2016-10-26 10:42:48 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-26 10:42:48 -0700 |
commit | e120601c0b2cbc5722950fb1ce7d0901842f5dff (patch) | |
tree | cffc0124e0fd4fc41ec7d8d8083722c7e66a2d67 | |
parent | c5ab566cc3343d3b9e07eab4855b0dbfb2c81afe (diff) | |
download | binaryen-e120601c0b2cbc5722950fb1ce7d0901842f5dff.tar.gz binaryen-e120601c0b2cbc5722950fb1ce7d0901842f5dff.tar.bz2 binaryen-e120601c0b2cbc5722950fb1ce7d0901842f5dff.zip |
Conditionalize boolean operations based on cost (#805)
When we have expensive | expensive, and both are boolean, then we can execute one of them conditionally if it doesn't have side effects.
-rw-r--r-- | src/ast/cost.h | 249 | ||||
-rw-r--r-- | src/ast_utils.h | 3 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 46 | ||||
-rw-r--r-- | test/passes/optimize-instructions_optimize-level=2.txt | 260 | ||||
-rw-r--r-- | test/passes/optimize-instructions_optimize-level=2.wast | 263 |
5 files changed, 820 insertions, 1 deletions
diff --git a/src/ast/cost.h b/src/ast/cost.h new file mode 100644 index 000000000..151468650 --- /dev/null +++ b/src/ast/cost.h @@ -0,0 +1,249 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ast_cost_h +#define wasm_ast_cost_h + +namespace wasm { + +// Measure the execution cost of an AST. Very handwave-ey + +struct CostAnalyzer : public Visitor<CostAnalyzer, Index> { + CostAnalyzer(Expression *ast) { + assert(ast); + cost = visit(ast); + } + + Index cost; + + Index maybeVisit(Expression* curr) { + return curr ? visit(curr) : 0; + } + + Index visitBlock(Block *curr) { + Index ret = 0; + for (auto* child : curr->list) ret += visit(child); + return ret; + } + Index visitIf(If *curr) { + return 1 + visit(curr->condition) + std::max(visit(curr->ifTrue), maybeVisit(curr->ifFalse)); + } + Index visitLoop(Loop *curr) { + return 5 * visit(curr->body); + } + Index visitBreak(Break *curr) { + return 1 + maybeVisit(curr->value) + maybeVisit(curr->condition); + } + Index visitSwitch(Switch *curr) { + return 2 + visit(curr->condition) + maybeVisit(curr->value); + } + Index visitCall(Call *curr) { + Index ret = 4; + for (auto* child : curr->operands) ret += visit(child); + return ret; + } + Index visitCallImport(CallImport *curr) { + Index ret = 15; + for (auto* child : curr->operands) ret += visit(child); + return ret; + } + Index visitCallIndirect(CallIndirect *curr) { + Index ret = 6 + visit(curr->target); + for (auto* child : curr->operands) ret += visit(child); + return ret; + } + Index visitGetLocal(GetLocal *curr) { + return 0; + } + Index visitSetLocal(SetLocal *curr) { + return 1; + } + Index visitGetGlobal(GetGlobal *curr) { + return 1; + } + Index visitSetGlobal(SetGlobal *curr) { + return 2; + } + Index visitLoad(Load *curr) { + return 1 + visit(curr->ptr); + } + Index visitStore(Store *curr) { + return 2 + visit(curr->ptr) + visit(curr->value); + } + Index visitConst(Const *curr) { + return 1; + } + Index visitUnary(Unary *curr) { + Index ret = 0; + switch (curr->op) { + case ClzInt32: + case CtzInt32: + case PopcntInt32: + case NegFloat32: + case AbsFloat32: + case CeilFloat32: + case FloorFloat32: + case TruncFloat32: + case NearestFloat32: + case ClzInt64: + case CtzInt64: + case PopcntInt64: + case NegFloat64: + case AbsFloat64: + case CeilFloat64: + case FloorFloat64: + case TruncFloat64: + case NearestFloat64: + case EqZInt32: + case EqZInt64: + case ExtendSInt32: + case ExtendUInt32: + case WrapInt64: + case PromoteFloat32: + case DemoteFloat64: + case TruncSFloat32ToInt32: + case TruncUFloat32ToInt32: + case TruncSFloat64ToInt32: + case TruncUFloat64ToInt32: + case ReinterpretFloat32: + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt64: + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt64: + case ReinterpretFloat64: + case ReinterpretInt32: + case ConvertSInt32ToFloat32: + case ConvertUInt32ToFloat32: + case ConvertSInt64ToFloat32: + case ConvertUInt64ToFloat32: + case ReinterpretInt64: + case ConvertSInt32ToFloat64: + case ConvertUInt32ToFloat64: + case ConvertSInt64ToFloat64: + case ConvertUInt64ToFloat64: ret = 1; break; + case SqrtFloat32: + case SqrtFloat64: ret = 2; break; + default: WASM_UNREACHABLE(); + } + return ret + visit(curr->value); + } + Index visitBinary(Binary *curr) { + Index ret = 0; + switch (curr->op) { + case AddInt32: ret = 1; break; + case SubInt32: ret = 1; break; + case MulInt32: ret = 2; break; + case DivSInt32: ret = 3; break; + case DivUInt32: ret = 3; break; + case RemSInt32: ret = 3; break; + case RemUInt32: ret = 3; break; + case AndInt32: ret = 1; break; + case OrInt32: ret = 1; break; + case XorInt32: ret = 1; break; + case ShlInt32: ret = 1; break; + case ShrUInt32: ret = 1; break; + case ShrSInt32: ret = 1; break; + case RotLInt32: ret = 1; break; + case RotRInt32: ret = 1; break; + case AddInt64: ret = 1; break; + case SubInt64: ret = 1; break; + case MulInt64: ret = 2; break; + case DivSInt64: ret = 3; break; + case DivUInt64: ret = 3; break; + case RemSInt64: ret = 3; break; + case RemUInt64: ret = 3; break; + case AndInt64: ret = 1; break; + case OrInt64: ret = 1; break; + case XorInt64: ret = 1; break; + case ShlInt64: ret = 1; break; + case ShrUInt64: ret = 1; break; + case ShrSInt64: ret = 1; break; + case RotLInt64: ret = 1; break; + case RotRInt64: ret = 1; break; + case AddFloat32: ret = 1; break; + case SubFloat32: ret = 1; break; + case MulFloat32: ret = 2; break; + case DivFloat32: ret = 3; break; + case CopySignFloat32: ret = 1; break; + case MinFloat32: ret = 1; break; + case MaxFloat32: ret = 1; break; + case AddFloat64: ret = 1; break; + case SubFloat64: ret = 1; break; + case MulFloat64: ret = 2; break; + case DivFloat64: ret = 3; break; + case CopySignFloat64: ret = 1; break; + case MinFloat64: ret = 1; break; + case MaxFloat64: ret = 1; break; + case LtUInt32: ret = 1; break; + case LtSInt32: ret = 1; break; + case LeUInt32: ret = 1; break; + case LeSInt32: ret = 1; break; + case GtUInt32: ret = 1; break; + case GtSInt32: ret = 1; break; + case GeUInt32: ret = 1; break; + case GeSInt32: ret = 1; break; + case LtUInt64: ret = 1; break; + case LtSInt64: ret = 1; break; + case LeUInt64: ret = 1; break; + case LeSInt64: ret = 1; break; + case GtUInt64: ret = 1; break; + case GtSInt64: ret = 1; break; + case GeUInt64: ret = 1; break; + case GeSInt64: ret = 1; break; + case LtFloat32: ret = 1; break; + case GtFloat32: ret = 1; break; + case LeFloat32: ret = 1; break; + case GeFloat32: ret = 1; break; + case LtFloat64: ret = 1; break; + case GtFloat64: ret = 1; break; + case LeFloat64: ret = 1; break; + case GeFloat64: ret = 1; break; + case EqInt32: ret = 1; break; + case NeInt32: ret = 1; break; + case EqInt64: ret = 1; break; + case NeInt64: ret = 1; break; + case EqFloat32: ret = 1; break; + case NeFloat32: ret = 1; break; + case EqFloat64: ret = 1; break; + case NeFloat64: ret = 1; break; + default: WASM_UNREACHABLE(); + } + return ret + visit(curr->left) + visit(curr->right); + } + Index visitSelect(Select *curr) { + return 2 + visit(curr->condition) + visit(curr->ifTrue) + visit(curr->ifFalse); + } + Index visitDrop(Drop *curr) { + return visit(curr->value); + } + Index visitReturn(Return *curr) { + return maybeVisit(curr->value); + } + Index visitHost(Host *curr) { + return 100; + } + Index visitNop(Nop *curr) { + return 0; + } + Index visitUnreachable(Unreachable *curr) { + return 0; + } +}; + +} // namespace wasm + +#endif // wasm_ast_cost_h + diff --git a/src/ast_utils.h b/src/ast_utils.h index 861a09aaf..9dfacb972 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -209,7 +209,8 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer void visitUnreachable(Unreachable *curr) { branches = true; } }; -// Meausure the size of an AST +// Measure the size of an AST + struct Measurer : public PostWalker<Measurer, UnifiedExpressionVisitor<Measurer>> { Index size = 0; diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index a5c665458..84f126ae9 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -25,6 +25,7 @@ #include <wasm-s-parser.h> #include <support/threads.h> #include <ast_utils.h> +#include <ast/cost.h> #include <ast/properties.h> namespace wasm { @@ -251,6 +252,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } } + return conditionalizeExpensiveOnBitwise(binary); + } else if (binary->op == OrInt32) { + return conditionalizeExpensiveOnBitwise(binary); } } else if (auto* unary = curr->dynCast<Unary>()) { // de-morgan's laws @@ -360,6 +364,48 @@ private: } return boolean; } + + // expensive1 | expensive2 can be turned into expensive1 ? 1 : expensive2, and + // expensive | cheap can be turned into cheap ? 1 : expensive, + // so that we can avoid one expensive computation, if it has no side effects. + Expression* conditionalizeExpensiveOnBitwise(Binary* binary) { + // this operation can increase code size, so don't always do it + auto& options = getPassRunner()->options; + if (options.optimizeLevel < 2 || options.shrinkLevel > 0) return nullptr; + const auto MIN_COST = 7; + assert(binary->op == AndInt32 || binary->op == OrInt32); + if (binary->right->is<Const>()) return nullptr; // trivial + // bitwise logical operator on two non-numerical values, check if they are boolean + auto* left = binary->left; + auto* right = binary->right; + if (!Properties::emitsBoolean(left) || !Properties::emitsBoolean(right)) return nullptr; + auto leftEffects = EffectAnalyzer(left).hasSideEffects(); + auto rightEffects = EffectAnalyzer(right).hasSideEffects(); + if (leftEffects && rightEffects) return nullptr; // both must execute + // canonicalize with side effects, if any, happening on the left + if (rightEffects) { + if (CostAnalyzer(left).cost < MIN_COST) return nullptr; // avoidable code is too cheap + std::swap(left, right); + } else if (leftEffects) { + if (CostAnalyzer(right).cost < MIN_COST) return nullptr; // avoidable code is too cheap + } else { + // no side effects, reorder based on cost estimation + auto leftCost = CostAnalyzer(left).cost; + auto rightCost = CostAnalyzer(right).cost; + if (std::max(leftCost, rightCost) < MIN_COST) return nullptr; // avoidable code is too cheap + // canonicalize with expensive code on the right + if (leftCost > rightCost) { + std::swap(left, right); + } + } + // worth it! perform conditionalization + Builder builder(*getModule()); + if (binary->op == OrInt32) { + return builder.makeIf(left, builder.makeConst(Literal(int32_t(1))), right); + } else { // & + return builder.makeIf(left, right, builder.makeConst(Literal(int32_t(0)))); + } + } }; Pass *createOptimizeInstructionsPass() { diff --git a/test/passes/optimize-instructions_optimize-level=2.txt b/test/passes/optimize-instructions_optimize-level=2.txt new file mode 100644 index 000000000..243421e8c --- /dev/null +++ b/test/passes/optimize-instructions_optimize-level=2.txt @@ -0,0 +1,260 @@ +(module + (type $0 (func (param i32 i32) (result i32))) + (memory $0 0) + (func $conditionals (type $0) (param $0 i32) (param $1 i32) (result i32) + (local $2 i32) + (local $3 i32) + (local $4 i32) + (local $5 i32) + (local $6 i32) + (local $7 i32) + (set_local $0 + (i32.const 0) + ) + (loop $while-in + (set_local $3 + (i32.const 0) + ) + (loop $while-in6 + (set_local $6 + (i32.add + (get_local $0) + (i32.const 1) + ) + ) + (set_local $0 + (if i32 + (if i32 + (i32.rem_s + (i32.add + (i32.mul + (tee_local $7 + (i32.add + (get_local $0) + (i32.const 2) + ) + ) + (get_local $0) + ) + (i32.const 17) + ) + (i32.const 5) + ) + (i32.eqz + (i32.rem_u + (i32.add + (i32.mul + (get_local $0) + (get_local $0) + ) + (i32.const 11) + ) + (i32.const 3) + ) + ) + (i32.const 1) + ) + (get_local $7) + (get_local $6) + ) + ) + (br_if $while-in6 + (i32.lt_s + (tee_local $3 + (i32.add + (get_local $3) + (i32.const 1) + ) + ) + (get_local $4) + ) + ) + ) + (br_if $while-in + (i32.ne + (tee_local $1 + (i32.add + (get_local $1) + (i32.const 1) + ) + ) + (i32.const 27000) + ) + ) + ) + (return + (get_local $5) + ) + ) + (func $side-effect (type $0) (param $0 i32) (param $1 i32) (result i32) + (local $2 i32) + (local $3 i32) + (local $4 i32) + (local $5 i32) + (local $6 i32) + (local $7 i32) + (set_local $0 + (i32.const 0) + ) + (loop $while-in + (set_local $3 + (i32.const 0) + ) + (loop $while-in6 + (set_local $6 + (i32.add + (get_local $0) + (i32.const 1) + ) + ) + (set_local $0 + (if i32 + (i32.or + (i32.eqz + (i32.rem_s + (i32.add + (i32.mul + (tee_local $7 + (i32.add + (get_local $0) + (i32.const 0) + ) + ) + (get_local $0) + ) + (i32.const 17) + ) + (i32.const 5) + ) + ) + (i32.eqz + (i32.rem_u + (i32.add + (i32.mul + (get_local $0) + (get_local $0) + ) + (unreachable) + ) + (i32.const 3) + ) + ) + ) + (get_local $7) + (get_local $6) + ) + ) + (br_if $while-in6 + (i32.lt_s + (tee_local $3 + (i32.add + (get_local $3) + (i32.const 1) + ) + ) + (get_local $4) + ) + ) + ) + (br_if $while-in + (i32.ne + (tee_local $1 + (i32.add + (get_local $1) + (i32.const 1) + ) + ) + (i32.const 27000) + ) + ) + ) + (return + (get_local $5) + ) + ) + (func $flip (type $0) (param $0 i32) (param $1 i32) (result i32) + (local $2 i32) + (local $3 i32) + (local $4 i32) + (local $5 i32) + (local $6 i32) + (local $7 i32) + (set_local $0 + (i32.const 0) + ) + (loop $while-in + (set_local $3 + (i32.const 0) + ) + (loop $while-in6 + (set_local $6 + (i32.add + (get_local $0) + (i32.const 1) + ) + ) + (set_local $0 + (if i32 + (if i32 + (i32.rem_u + (i32.add + (i32.mul + (get_local $0) + (get_local $0) + ) + (i32.const 100) + ) + (i32.const 3) + ) + (i32.eqz + (i32.rem_s + (i32.add + (i32.mul + (i32.eqz + (i32.add + (get_local $0) + (i32.const 0) + ) + ) + (get_local $0) + ) + (i32.const 17) + ) + (i32.const 5) + ) + ) + (i32.const 1) + ) + (get_local $7) + (get_local $6) + ) + ) + (br_if $while-in6 + (i32.lt_s + (tee_local $3 + (i32.add + (get_local $3) + (i32.const 1) + ) + ) + (get_local $4) + ) + ) + ) + (br_if $while-in + (i32.ne + (tee_local $1 + (i32.add + (get_local $1) + (i32.const 1) + ) + ) + (i32.const 27000) + ) + ) + ) + (return + (get_local $5) + ) + ) +) diff --git a/test/passes/optimize-instructions_optimize-level=2.wast b/test/passes/optimize-instructions_optimize-level=2.wast new file mode 100644 index 000000000..7874907f3 --- /dev/null +++ b/test/passes/optimize-instructions_optimize-level=2.wast @@ -0,0 +1,263 @@ +(module + (type $0 (func (param i32 i32) (result i32))) + (memory $0 0) + (func $conditionals (type $0) (param $0 i32) (param $1 i32) (result i32) + (local $2 i32) + (local $3 i32) + (local $4 i32) + (local $5 i32) + (local $6 i32) + (local $7 i32) + (set_local $0 + (i32.const 0) + ) + (loop $while-in + (set_local $3 + (i32.const 0) + ) + (loop $while-in6 + (set_local $6 + (i32.add + (get_local $0) + (i32.const 1) + ) + ) + (set_local $0 + (if i32 + (i32.or ;; this or is very expensive. we should compute one side, then see if we even need the other + (i32.eqz + (i32.rem_s + (i32.add + (i32.mul + (tee_local $7 ;; side effect, so we can't do this one + (i32.add + (get_local $0) + (i32.const 2) + ) + ) + (get_local $0) + ) + (i32.const 17) + ) + (i32.const 5) + ) + ) + (i32.eqz + (i32.rem_u + (i32.add + (i32.mul + (get_local $0) + (get_local $0) + ) + (i32.const 11) + ) + (i32.const 3) + ) + ) + ) + (get_local $7) + (get_local $6) + ) + ) + (br_if $while-in6 + (i32.lt_s + (tee_local $3 + (i32.add + (get_local $3) + (i32.const 1) + ) + ) + (get_local $4) + ) + ) + ) + (br_if $while-in + (i32.ne + (tee_local $1 + (i32.add + (get_local $1) + (i32.const 1) + ) + ) + (i32.const 27000) + ) + ) + ) + (return + (get_local $5) + ) + ) + (func $side-effect (type $0) (param $0 i32) (param $1 i32) (result i32) + (local $2 i32) + (local $3 i32) + (local $4 i32) + (local $5 i32) + (local $6 i32) + (local $7 i32) + (set_local $0 + (i32.const 0) + ) + (loop $while-in + (set_local $3 + (i32.const 0) + ) + (loop $while-in6 + (set_local $6 + (i32.add + (get_local $0) + (i32.const 1) + ) + ) + (set_local $0 + (if i32 + (i32.or ;; this or is very expensive, but has a side effect on both sides + (i32.eqz + (i32.rem_s + (i32.add + (i32.mul + (tee_local $7 + (i32.add + (get_local $0) + (i32.const 0) + ) + ) + (get_local $0) + ) + (i32.const 17) + ) + (i32.const 5) + ) + ) + (i32.eqz + (i32.rem_u + (i32.add + (i32.mul + (get_local $0) + (get_local $0) + ) + (unreachable) + ) + (i32.const 3) + ) + ) + ) + (get_local $7) + (get_local $6) + ) + ) + (br_if $while-in6 + (i32.lt_s + (tee_local $3 + (i32.add + (get_local $3) + (i32.const 1) + ) + ) + (get_local $4) + ) + ) + ) + (br_if $while-in + (i32.ne + (tee_local $1 + (i32.add + (get_local $1) + (i32.const 1) + ) + ) + (i32.const 27000) + ) + ) + ) + (return + (get_local $5) + ) + ) + (func $flip (type $0) (param $0 i32) (param $1 i32) (result i32) + (local $2 i32) + (local $3 i32) + (local $4 i32) + (local $5 i32) + (local $6 i32) + (local $7 i32) + (set_local $0 + (i32.const 0) + ) + (loop $while-in + (set_local $3 + (i32.const 0) + ) + (loop $while-in6 + (set_local $6 + (i32.add + (get_local $0) + (i32.const 1) + ) + ) + (set_local $0 + (if i32 + (i32.or ;; this or is very expensive, and the first side has no side effect + (i32.eqz + (i32.rem_s + (i32.add + (i32.mul + (i32.eqz + (i32.add + (get_local $0) + (i32.const 0) + ) + ) + (get_local $0) + ) + (i32.const 17) + ) + (i32.const 5) + ) + ) + (i32.eqz + (i32.rem_u + (i32.add + (i32.mul + (get_local $0) + (get_local $0) + ) + (i32.const 100) + ) + (i32.const 3) + ) + ) + ) + (get_local $7) + (get_local $6) + ) + ) + (br_if $while-in6 + (i32.lt_s + (tee_local $3 + (i32.add + (get_local $3) + (i32.const 1) + ) + ) + (get_local $4) + ) + ) + ) + (br_if $while-in + (i32.ne + (tee_local $1 + (i32.add + (get_local $1) + (i32.const 1) + ) + ) + (i32.const 27000) + ) + ) + ) + (return + (get_local $5) + ) + ) +) + |