summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- src/ir/abstract.h 5
-rw-r--r-- src/passes/OptimizeInstructions.cpp 49
-rw-r--r-- test/passes/O_fast-math.txt 12
-rw-r--r-- test/passes/O_fast-math.wast 18
-rw-r--r-- test/passes/optimize-instructions_all-features.txt 193
-rw-r--r-- test/passes/optimize-instructions_all-features.wast 126
6 files changed, 403 insertions, 0 deletions
diff --git a/src/ir/abstract.h b/src/ir/abstract.h
index 2bb764aeb..d1a7ec47d 100644
--- a/src/ir/abstract.h
+++ b/src/ir/abstract.h
@@ -27,6 +27,7 @@ namespace Abstract {
enum Op {
// Unary
+ Abs,
Neg,
// Binary
Add,
@@ -91,6 +92,8 @@ inline UnaryOp getUnary(Type type, Op op) {
}
case Type::f32: {
switch (op) {
+ case Abs:
+ return AbsFloat32;
case Neg:
return NegFloat32;
default:
@@ -100,6 +103,8 @@ inline UnaryOp getUnary(Type type, Op op) {
}
case Type::f64: {
switch (op) {
+ case Abs:
+ return AbsFloat64;
case Neg:
return NegFloat64;
default:
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 232faab7a..ad7cc8f8a 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -484,6 +484,25 @@ struct OptimizeInstructions
if (auto* ret = optimizeAddedConstants(binary)) {
return ret;
}
+ } else if (binary->op == MulFloat32 || binary->op == MulFloat64 ||
+ binary->op == DivFloat32 || binary->op == DivFloat64) {
+ if (binary->left->type == binary->right->type) {
+ if (auto* leftUnary = binary->left->dynCast<Unary>()) {
+ if (leftUnary->op ==
+ Abstract::getUnary(binary->type, Abstract::Abs)) {
+ if (auto* rightUnary = binary->right->dynCast<Unary>()) {
+ if (leftUnary->op == rightUnary->op) { // both are abs ops
+ // abs(x) * abs(y) ==> abs(x * y)
+ // abs(x) / abs(y) ==> abs(x / y)
+ binary->left = leftUnary->value;
+ binary->right = rightUnary->value;
+ leftUnary->value = binary;
+ return leftUnary;
+ }
+ }
+ }
+ }
+ }
}
// a bunch of operations on a constant right side can be simplified
if (auto* right = binary->right->dynCast<Const>()) {
@@ -684,6 +703,36 @@ struct OptimizeInstructions
unary->value = makeZeroExt(ext, bits);
return unary;
}
+ } else if (unary->op == AbsFloat32 || unary->op == AbsFloat64) {
+ // abs(-x) ==> abs(x)
+ if (auto* unaryInner = unary->value->dynCast<Unary>()) {
+ if (unaryInner->op ==
+ Abstract::getUnary(unaryInner->type, Abstract::Neg)) {
+ unary->value = unaryInner->value;
+ return unary;
+ }
+ }
+ // abs(x * x) ==> x * x
+ // abs(x / x) ==> x / x (Abstract::DivS resolves to f32.div / f64.div here)
+ if (auto* binary = unary->value->dynCast<Binary>()) {
+ if ((binary->op == Abstract::getBinary(binary->type, Abstract::Mul) ||
+ binary->op ==
+ Abstract::getBinary(binary->type, Abstract::DivS)) &&
+ ExpressionAnalyzer::equal(binary->left, binary->right)) {
+ return binary;
+ }
+ // abs(0 - x) ==> abs(x),
+ // only for fast math
+ if (getPassOptions().fastMath &&
+ binary->op == Abstract::getBinary(binary->type, Abstract::Sub)) {
+ if (auto* c = binary->left->dynCast<Const>()) {
+ if (c->value.isZero()) {
+ unary->value = binary->right;
+ return unary;
+ }
+ }
+ }
+ }
}
if (auto* ret = deduplicateUnary(unary)) {
diff --git a/test/passes/O_fast-math.txt b/test/passes/O_fast-math.txt
index 9aadd3d90..e3833dbba 100644
--- a/test/passes/O_fast-math.txt
+++ b/test/passes/O_fast-math.txt
@@ -13,6 +13,8 @@
(export "sub2" (func $2))
(export "mul_neg_one1" (func $9))
(export "mul_neg_one2" (func $10))
+ (export "abs_sub_zero1" (func $11))
+ (export "abs_sub_zero2" (func $12))
(func $0 (; has Stack IR ;) (result f32)
(f32.const -nan:0x23017a)
)
@@ -32,4 +34,14 @@
(local.get $0)
)
)
+ (func $11 (; has Stack IR ;) (param $0 f32) (result f32)
+ (f32.abs
+ (local.get $0)
+ )
+ )
+ (func $12 (; has Stack IR ;) (param $0 f64) (result f64)
+ (f64.abs
+ (local.get $0)
+ )
+ )
)
diff --git a/test/passes/O_fast-math.wast b/test/passes/O_fast-math.wast
index ce2cd7b6e..68cf2bcd8 100644
--- a/test/passes/O_fast-math.wast
+++ b/test/passes/O_fast-math.wast
@@ -66,4 +66,22 @@
(f64.const -1)
)
)
+ (func "abs_sub_zero1" (param $x f32) (result f32)
+ ;; abs(0 - x) ==> abs(x)
+ (f32.abs
+ (f32.sub
+ (f32.const 0)
+ (local.get $x)
+ )
+ )
+ )
+ (func "abs_sub_zero2" (param $x f64) (result f64)
+ ;; abs(0 - x) ==> abs(x)
+ (f64.abs
+ (f64.sub
+ (f64.const 0)
+ (local.get $x)
+ )
+ )
+ )
)
diff --git a/test/passes/optimize-instructions_all-features.txt b/test/passes/optimize-instructions_all-features.txt
index 6ef58233f..8a1c30e8b 100644
--- a/test/passes/optimize-instructions_all-features.txt
+++ b/test/passes/optimize-instructions_all-features.txt
@@ -19,6 +19,7 @@
(type $f32_=>_none (func (param f32)))
(type $f32_f64_=>_none (func (param f32 f64)))
(type $f64_=>_none (func (param f64)))
+ (type $f64_f64_f32_f32_=>_none (func (param f64 f64 f32 f32)))
(type $none_=>_f64 (func (result f64)))
(memory $0 0)
(export "load-off-2" (func $load-off-2))
@@ -4883,6 +4884,198 @@
)
)
)
+ (func $optimize-float-points (param $x0 f64) (param $x1 f64) (param $y0 f32) (param $y1 f32)
+ (drop
+ (f64.mul
+ (local.get $x0)
+ (local.get $x0)
+ )
+ )
+ (drop
+ (f32.mul
+ (local.get $y0)
+ (local.get $y0)
+ )
+ )
+ (drop
+ (f64.mul
+ (f64.add
+ (local.get $x0)
+ (local.get $x1)
+ )
+ (f64.add
+ (local.get $x0)
+ (local.get $x1)
+ )
+ )
+ )
+ (drop
+ (f64.abs
+ (f64.mul
+ (local.get $x0)
+ (local.get $x1)
+ )
+ )
+ )
+ (drop
+ (f32.abs
+ (f32.mul
+ (local.get $y1)
+ (local.get $y0)
+ )
+ )
+ )
+ (drop
+ (f64.abs
+ (f64.mul
+ (local.get $x0)
+ (f64.const 0)
+ )
+ )
+ )
+ (drop
+ (f32.abs
+ (f32.mul
+ (f32.const 0)
+ (local.get $y0)
+ )
+ )
+ )
+ (drop
+ (f64.abs
+ (f64.mul
+ (f64.add
+ (local.get $x0)
+ (local.get $x1)
+ )
+ (f64.add
+ (local.get $x0)
+ (local.get $x0)
+ )
+ )
+ )
+ )
+ (drop
+ (f64.abs
+ (local.get $x0)
+ )
+ )
+ (drop
+ (f32.abs
+ (local.get $y0)
+ )
+ )
+ (drop
+ (f64.abs
+ (f64.sub
+ (f64.const 0)
+ (local.get $x0)
+ )
+ )
+ )
+ (drop
+ (f32.abs
+ (f32.sub
+ (f32.const 0)
+ (local.get $y0)
+ )
+ )
+ )
+ (drop
+ (f64.div
+ (local.get $x0)
+ (local.get $x0)
+ )
+ )
+ (drop
+ (f32.div
+ (local.get $y0)
+ (local.get $y0)
+ )
+ )
+ (drop
+ (f64.div
+ (f64.add
+ (local.get $x0)
+ (local.get $x1)
+ )
+ (f64.add
+ (local.get $x0)
+ (local.get $x1)
+ )
+ )
+ )
+ (drop
+ (f64.abs
+ (f64.div
+ (local.get $x0)
+ (local.get $x1)
+ )
+ )
+ )
+ (drop
+ (f32.abs
+ (f32.div
+ (local.get $y1)
+ (local.get $y0)
+ )
+ )
+ )
+ (drop
+ (f64.mul
+ (local.get $x0)
+ (local.get $x0)
+ )
+ )
+ (drop
+ (f32.mul
+ (local.get $y0)
+ (local.get $y0)
+ )
+ )
+ (drop
+ (f64.div
+ (local.get $x0)
+ (local.get $x0)
+ )
+ )
+ (drop
+ (f32.div
+ (local.get $y0)
+ (local.get $y0)
+ )
+ )
+ (drop
+ (f64.abs
+ (f64.div
+ (local.get $x0)
+ (f64.const 0)
+ )
+ )
+ )
+ (drop
+ (f32.abs
+ (f32.div
+ (f32.const 0)
+ (local.get $y0)
+ )
+ )
+ )
+ (drop
+ (f64.abs
+ (f64.div
+ (f64.add
+ (local.get $x0)
+ (local.get $x1)
+ )
+ (f64.add
+ (local.get $x0)
+ (local.get $x0)
+ )
+ )
+ )
+ )
+ )
)
(module
(type $none_=>_none (func))
diff --git a/test/passes/optimize-instructions_all-features.wast b/test/passes/optimize-instructions_all-features.wast
index 60ae8d3c2..e989fda45 100644
--- a/test/passes/optimize-instructions_all-features.wast
+++ b/test/passes/optimize-instructions_all-features.wast
@@ -5361,6 +5361,132 @@
)
)
)
+ (func $optimize-float-points (param $x0 f64) (param $x1 f64) (param $y0 f32) (param $y1 f32)
+ ;; abs(x) * abs(x) ==> x * x
+ (drop (f64.mul
+ (f64.abs (local.get $x0))
+ (f64.abs (local.get $x0))
+ ))
+ (drop (f32.mul
+ (f32.abs (local.get $y0))
+ (f32.abs (local.get $y0))
+ ))
+ (drop (f64.mul
+ (f64.abs (f64.add (local.get $x0) (local.get $x1)))
+ (f64.abs (f64.add (local.get $x0) (local.get $x1)))
+ ))
+
+ ;; abs(x) * abs(y) ==> abs(x * y)
+ (drop (f64.mul
+ (f64.abs (local.get $x0))
+ (f64.abs (local.get $x1))
+ ))
+ (drop (f32.mul
+ (f32.abs (local.get $y1))
+ (f32.abs (local.get $y0))
+ ))
+
+ (drop (f64.mul
+ (f64.abs (local.get $x0))
+ (f64.abs (f64.const 0)) ;; skip
+ ))
+ (drop (f32.mul
+ (f32.abs (f32.const 0)) ;; skip
+ (f32.abs (local.get $y0))
+ ))
+ (drop (f64.mul
+ (f64.abs (f64.add (local.get $x0) (local.get $x1)))
+ (f64.abs (f64.add (local.get $x0) (local.get $x0)))
+ ))
+
+
+ ;; abs(-x) ==> abs(x)
+ (drop (f64.abs
+ (f64.neg (local.get $x0))
+ ))
+ (drop (f32.abs
+ (f32.neg (local.get $y0))
+ ))
+
+ ;; abs(0 - x) ==> left unchanged here; only rewritten to abs(x) with fast math
+ (drop (f64.abs
+ (f64.sub
+ (f64.const 0)
+ (local.get $x0)
+ )
+ ))
+ (drop (f32.abs
+ (f32.sub
+ (f32.const 0)
+ (local.get $y0)
+ )
+ ))
+
+ ;; abs(x) / abs(x) ==> x / x
+ (drop (f64.div
+ (f64.abs (local.get $x0))
+ (f64.abs (local.get $x0))
+ ))
+ (drop (f32.div
+ (f32.abs (local.get $y0))
+ (f32.abs (local.get $y0))
+ ))
+ (drop (f64.div
+ (f64.abs (f64.add (local.get $x0) (local.get $x1)))
+ (f64.abs (f64.add (local.get $x0) (local.get $x1)))
+ ))
+
+ ;; abs(x) / abs(y) ==> abs(x / y)
+ (drop (f64.div
+ (f64.abs (local.get $x0))
+ (f64.abs (local.get $x1))
+ ))
+ (drop (f32.div
+ (f32.abs (local.get $y1))
+ (f32.abs (local.get $y0))
+ ))
+
+ ;; abs(x * x) ==> x * x
+ (drop (f64.abs
+ (f64.mul
+ (local.get $x0)
+ (local.get $x0)
+ )
+ ))
+ (drop (f32.abs
+ (f32.mul
+ (local.get $y0)
+ (local.get $y0)
+ )
+ ))
+
+ ;; abs(x / x) ==> x / x
+ (drop (f64.abs
+ (f64.div
+ (local.get $x0)
+ (local.get $x0)
+ )
+ ))
+ (drop (f32.abs
+ (f32.div
+ (local.get $y0)
+ (local.get $y0)
+ )
+ ))
+
+ (drop (f64.div
+ (f64.abs (local.get $x0))
+ (f64.abs (f64.const 0)) ;; skip
+ ))
+ (drop (f32.div
+ (f32.abs (f32.const 0)) ;; skip
+ (f32.abs (local.get $y0))
+ ))
+ (drop (f64.div
+ (f64.abs (f64.add (local.get $x0) (local.get $x1)))
+ (f64.abs (f64.add (local.get $x0) (local.get $x0)))
+ ))
+ )
)
(module
(import "env" "memory" (memory $0 (shared 256 256)))