diff options
-rw-r--r-- | src/ir/abstract.h | 5 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 49 | ||||
-rw-r--r-- | test/passes/O_fast-math.txt | 12 | ||||
-rw-r--r-- | test/passes/O_fast-math.wast | 18 | ||||
-rw-r--r-- | test/passes/optimize-instructions_all-features.txt | 193 | ||||
-rw-r--r-- | test/passes/optimize-instructions_all-features.wast | 126 |
6 files changed, 403 insertions, 0 deletions
diff --git a/src/ir/abstract.h b/src/ir/abstract.h index 2bb764aeb..d1a7ec47d 100644 --- a/src/ir/abstract.h +++ b/src/ir/abstract.h @@ -27,6 +27,7 @@ namespace Abstract { enum Op { // Unary + Abs, Neg, // Binary Add, @@ -91,6 +92,8 @@ inline UnaryOp getUnary(Type type, Op op) { } case Type::f32: { switch (op) { + case Abs: + return AbsFloat32; case Neg: return NegFloat32; default: @@ -100,6 +103,8 @@ inline UnaryOp getUnary(Type type, Op op) { } case Type::f64: { switch (op) { + case Abs: + return AbsFloat64; case Neg: return NegFloat64; default: diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 232faab7a..ad7cc8f8a 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -484,6 +484,25 @@ struct OptimizeInstructions if (auto* ret = optimizeAddedConstants(binary)) { return ret; } + } else if (binary->op == MulFloat32 || binary->op == MulFloat64 || + binary->op == DivFloat32 || binary->op == DivFloat64) { + if (binary->left->type == binary->right->type) { + if (auto* leftUnary = binary->left->dynCast<Unary>()) { + if (leftUnary->op == + Abstract::getUnary(binary->type, Abstract::Abs)) { + if (auto* rightUnary = binary->right->dynCast<Unary>()) { + if (leftUnary->op == rightUnary->op) { // both are abs ops + // abs(x) * abs(y) ==> abs(x * y) + // abs(x) / abs(y) ==> abs(x / y) + binary->left = leftUnary->value; + binary->right = rightUnary->value; + leftUnary->value = binary; + return leftUnary; + } + } + } + } + } } // a bunch of operations on a constant right side can be simplified if (auto* right = binary->right->dynCast<Const>()) { @@ -684,6 +703,36 @@ struct OptimizeInstructions unary->value = makeZeroExt(ext, bits); return unary; } + } else if (unary->op == AbsFloat32 || unary->op == AbsFloat64) { + // abs(-x) ==> abs(x) + if (auto* unaryInner = unary->value->dynCast<Unary>()) { + if (unaryInner->op == + Abstract::getUnary(unaryInner->type, Abstract::Neg)) { + unary->value = unaryInner->value; + return unary; + } + } + // abs(x * x) ==> x * x + // abs(x / x) ==> x / x + if (auto* binary = unary->value->dynCast<Binary>()) { + if ((binary->op == Abstract::getBinary(binary->type, Abstract::Mul) || + binary->op == + Abstract::getBinary(binary->type, Abstract::DivS)) && + ExpressionAnalyzer::equal(binary->left, binary->right)) { + return binary; + } + // abs(0 - x) ==> abs(x), + // only for fast math + if (getPassOptions().fastMath && + binary->op == Abstract::getBinary(binary->type, Abstract::Sub)) { + if (auto* c = binary->left->dynCast<Const>()) { + if (c->value.isZero()) { + unary->value = binary->right; + return unary; + } + } + } + } } if (auto* ret = deduplicateUnary(unary)) { diff --git a/test/passes/O_fast-math.txt b/test/passes/O_fast-math.txt index 9aadd3d90..e3833dbba 100644 --- a/test/passes/O_fast-math.txt +++ b/test/passes/O_fast-math.txt @@ -13,6 +13,8 @@ (export "sub2" (func $2)) (export "mul_neg_one1" (func $9)) (export "mul_neg_one2" (func $10)) + (export "abs_sub_zero1" (func $11)) + (export "abs_sub_zero2" (func $12)) (func $0 (; has Stack IR ;) (result f32) (f32.const -nan:0x23017a) ) @@ -32,4 +34,14 @@ (local.get $0) ) ) + (func $11 (; has Stack IR ;) (param $0 f32) (result f32) + (f32.abs + (local.get $0) + ) + ) + (func $12 (; has Stack IR ;) (param $0 f64) (result f64) + (f64.abs + (local.get $0) + ) + ) ) diff --git a/test/passes/O_fast-math.wast b/test/passes/O_fast-math.wast index ce2cd7b6e..68cf2bcd8 100644 --- a/test/passes/O_fast-math.wast +++ b/test/passes/O_fast-math.wast @@ -66,4 +66,22 @@ (f64.const -1) ) ) + (func "abs_sub_zero1" (param $x f32) (result f32) + ;; abs(0 - x) ==> abs(x) + (f32.abs + (f32.sub + (f32.const 0) + (local.get $x) + ) + ) + ) + (func "abs_sub_zero2" (param $x f64) (result f64) + ;; abs(0 - x) ==> abs(x) + (f64.abs + (f64.sub + (f64.const 0) + (local.get $x) + ) + ) + ) ) diff --git a/test/passes/optimize-instructions_all-features.txt b/test/passes/optimize-instructions_all-features.txt index 6ef58233f..8a1c30e8b 100644 --- a/test/passes/optimize-instructions_all-features.txt +++ b/test/passes/optimize-instructions_all-features.txt @@ -19,6 +19,7 @@ (type $f32_=>_none (func (param f32))) (type $f32_f64_=>_none (func (param f32 f64))) (type $f64_=>_none (func (param f64))) + (type $f64_f64_f32_f32_=>_none (func (param f64 f64 f32 f32))) (type $none_=>_f64 (func (result f64))) (memory $0 0) (export "load-off-2" (func $load-off-2)) @@ -4883,6 +4884,198 @@ ) ) ) + (func $optimize-float-points (param $x0 f64) (param $x1 f64) (param $y0 f32) (param $y1 f32) + (drop + (f64.mul + (local.get $x0) + (local.get $x0) + ) + ) + (drop + (f32.mul + (local.get $y0) + (local.get $y0) + ) + ) + (drop + (f64.mul + (f64.add + (local.get $x0) + (local.get $x1) + ) + (f64.add + (local.get $x0) + (local.get $x1) + ) + ) + ) + (drop + (f64.abs + (f64.mul + (local.get $x0) + (local.get $x1) + ) + ) + ) + (drop + (f32.abs + (f32.mul + (local.get $y1) + (local.get $y0) + ) + ) + ) + (drop + (f64.abs + (f64.mul + (local.get $x0) + (f64.const 0) + ) + ) + ) + (drop + (f32.abs + (f32.mul + (f32.const 0) + (local.get $y0) + ) + ) + ) + (drop + (f64.abs + (f64.mul + (f64.add + (local.get $x0) + (local.get $x1) + ) + (f64.add + (local.get $x0) + (local.get $x0) + ) + ) + ) + ) + (drop + (f64.abs + (local.get $x0) + ) + ) + (drop + (f32.abs + (local.get $y0) + ) + ) + (drop + (f64.abs + (f64.sub + (f64.const 0) + (local.get $x0) + ) + ) + ) + (drop + (f32.abs + (f32.sub + (f32.const 0) + (local.get $y0) + ) + ) + ) + (drop + (f64.div + (local.get $x0) + (local.get $x0) + ) + ) + (drop + (f32.div + (local.get $y0) + (local.get $y0) + ) + ) + (drop + (f64.div + (f64.add + (local.get $x0) + (local.get $x1) + ) + (f64.add + (local.get $x0) + (local.get $x1) + ) + ) + ) + (drop + (f64.abs + (f64.div + (local.get $x0) + (local.get $x1) + ) + ) + ) + (drop + (f32.abs + (f32.div + (local.get $y1) + (local.get $y0) + ) + ) + ) + (drop + (f64.mul + (local.get $x0) + (local.get $x0) + ) + ) + (drop + (f32.mul + (local.get $y0) + (local.get $y0) + ) + ) + (drop + (f64.div + (local.get $x0) + (local.get $x0) + ) + ) + (drop + (f32.div + (local.get $y0) + (local.get $y0) + ) + ) + (drop + (f64.abs + (f64.div + (local.get $x0) + (f64.const 0) + ) + ) + ) + (drop + (f32.abs + (f32.div + (f32.const 0) + (local.get $y0) + ) + ) + ) + (drop + (f64.abs + (f64.div + (f64.add + (local.get $x0) + (local.get $x1) + ) + (f64.add + (local.get $x0) + (local.get $x0) + ) + ) + ) + ) + ) ) (module (type $none_=>_none (func)) diff --git a/test/passes/optimize-instructions_all-features.wast b/test/passes/optimize-instructions_all-features.wast index 60ae8d3c2..e989fda45 100644 --- a/test/passes/optimize-instructions_all-features.wast +++ b/test/passes/optimize-instructions_all-features.wast @@ -5361,6 +5361,132 @@ ) ) ) + (func $optimize-float-points (param $x0 f64) (param $x1 f64) (param $y0 f32) (param $y1 f32) + ;; abs(x) * abs(x) ==> x * x + (drop (f64.mul + (f64.abs (local.get $x0)) + (f64.abs (local.get $x0)) + )) + (drop (f32.mul + (f32.abs (local.get $y0)) + (f32.abs (local.get $y0)) + )) + (drop (f64.mul + (f64.abs (f64.add (local.get $x0) (local.get $x1))) + (f64.abs (f64.add (local.get $x0) (local.get $x1))) + )) + + ;; abs(x) * abs(y) ==> abs(x * y) + (drop (f64.mul + (f64.abs (local.get $x0)) + (f64.abs (local.get $x1)) + )) + (drop (f32.mul + (f32.abs (local.get $y1)) + (f32.abs (local.get $y0)) + )) + + (drop (f64.mul + (f64.abs (local.get $x0)) + (f64.abs (f64.const 0)) ;; skip + )) + (drop (f32.mul + (f32.abs (f32.const 0)) ;; skip + (f32.abs (local.get $y0)) + )) + (drop (f64.mul + (f64.abs (f64.add (local.get $x0) (local.get $x1))) + (f64.abs (f64.add (local.get $x0) (local.get $x0))) + )) + + + ;; abs(-x) ==> abs(x) + (drop (f64.abs + (f64.neg (local.get $x0)) + )) + (drop (f32.abs + (f32.neg (local.get $y0)) + )) + + ;; abs(0 - x) ==> skip for non-fast math + (drop (f64.abs + (f64.sub + (f64.const 0) + (local.get $x0) + ) + )) + (drop (f32.abs + (f32.sub + (f32.const 0) + (local.get $y0) + ) + )) + + ;; abs(x) / abs(x) ==> x / x + (drop (f64.div + (f64.abs (local.get $x0)) + (f64.abs (local.get $x0)) + )) + (drop (f32.div + (f32.abs (local.get $y0)) + (f32.abs (local.get $y0)) + )) + (drop (f64.div + (f64.abs (f64.add (local.get $x0) (local.get $x1))) + (f64.abs (f64.add (local.get $x0) (local.get $x1))) + )) + + ;; abs(x) / abs(y) ==> abs(x / y) + (drop (f64.div + (f64.abs (local.get $x0)) + (f64.abs (local.get $x1)) + )) + (drop (f32.div + (f32.abs (local.get $y1)) + (f32.abs (local.get $y0)) + )) + + ;; abs(x * x) ==> x * x + (drop (f64.abs + (f64.mul + (local.get $x0) + (local.get $x0) + ) + )) + (drop (f32.abs + (f32.mul + (local.get $y0) + (local.get $y0) + ) + )) + + ;; abs(x / x) ==> x / x + (drop (f64.abs + (f64.div + (local.get $x0) + (local.get $x0) + ) + )) + (drop (f32.abs + (f32.div + (local.get $y0) + (local.get $y0) + ) + )) + + (drop (f64.div + (f64.abs (local.get $x0)) + (f64.abs (f64.const 0)) ;; skip + )) + (drop (f32.div + (f32.abs (f32.const 0)) ;; skip + (f32.abs (local.get $y0)) + )) + (drop (f64.div + (f64.abs (f64.add (local.get $x0) (local.get $x1))) + (f64.abs (f64.add (local.get $x0) (local.get $x0))) + )) + ) ) (module (import "env" "memory" (memory $0 (shared 256 256))) |