diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/literal.h | 45 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 6 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 66 |
3 files changed, 55 insertions, 62 deletions
diff --git a/src/literal.h b/src/literal.h index 9d9630ec4..318ab012a 100644 --- a/src/literal.h +++ b/src/literal.h @@ -249,6 +249,28 @@ public: lit.i32 = value | 0x80000000; return lit; } + // Wasm has nondeterministic rules for NaN propagation in some operations. For + // example. f32.neg is deterministic and just flips the sign, even of a NaN, + // but f32.add is nondeterministic, and if one or more of the inputs is a NaN, + // then + // + // * if all NaNs are canonical, the output is some arbitrary canonical NaN + // * otherwise the output is some arbitrary arithmetic NaN + // + // (canonical = NaN payload is 1000..000; arithmetic: 1???..???, that is, the + // high bit is 1 and all others can be 0 or 1) + // + // For many things we don't need to care, and can just do a normal C++ add for + // an f32.add, for example - the wasm rules are specified so that things like + // that just work (in order for such math to be fast). However, for our + // optimizer, it is useful to "standardize" NaNs when there is nondeterminism. + // That is, when there are multiple valid outputs, it's nice to emit the same + // one consistently, so that it doesn't look like the optimization changed + // something. In other words, if the valid output of an expression is a set of + // valid NaNs, and after optimization the output is still that same set, then + // the optimization is valid. And if the interpreter picks the same NaN in + // both cases from that identical set then nothing looks wrong to the fuzzer. + static Literal standardizeNaN(const Literal& input); Literal castToF32(); Literal castToF64(); @@ -706,29 +728,6 @@ struct GCData { GCData(HeapType type, Literals values) : type(type), values(values) {} }; -// Wasm has nondeterministic rules for NaN propagation in some operations. For -// example. f32.neg is deterministic and just flips the sign, even of a NaN, but -// f32.add is nondeterministic, and if one or more of the inputs is a NaN, then -// -// * if all NaNs are canonical NaNs, the output is some arbitrary canonical NaN -// * otherwise the output is some arbitrary arithmetic NaN -// -// (canonical = NaN payload is 1000..000; arithmetic: 1???..???, that is, the -// high bit is 1 and all others can be 0 or 1) -// -// For many things we don't need to care, and can just do a normal C++ add for -// an f32.add, for example - the wasm rules are specified so that things like -// that just work (in order for such math to be fast). However, for our -// optimizer, it is useful to "standardize" NaNs when there is nondeterminism. -// That is, when there are multiple valid outputs, it's nice to emit the same -// one consistently, so that it doesn't look like the optimization changed -// something. In other words, if the valid output of an expression is a set of -// valid NaNs, and after optimization the output is still that same set, then -// the optimization is valid. And if the interpreter picks the same NaN in both -// cases from that identical set then nothing looks wrong to the fuzzer. -Literal standardizeNaN(float result); -Literal standardizeNaN(double result); - } // namespace wasm namespace std { diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index fc034a24f..4ffd6be67 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -3610,11 +3610,7 @@ private: return c; } // propagate NaN of RHS but canonicalize it - if (c->type == Type::f32) { - c->value = standardizeNaN(c->value.getf32()); - } else { - c->value = standardizeNaN(c->value.getf64()); - } + c->value = Literal::standardizeNaN(c->value); return c; } } diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 843f4607f..43f407525 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -217,6 +217,20 @@ Literal Literal::makeNegOne(Type type) { return makeFromInt32(-1, type); } +Literal Literal::standardizeNaN(const Literal& input) { + if (!std::isnan(input.getFloat())) { + return input; + } + // Pick a simple canonical payload, and positive. + if (input.type == Type::f32) { + return Literal(bit_cast<float>(uint32_t(0x7fc00000u))); + } else if (input.type == Type::f64) { + return Literal(bit_cast<double>(uint64_t(0x7ff8000000000000ull))); + } else { + WASM_UNREACHABLE("unexpected type"); + } +} + std::array<uint8_t, 16> Literal::getv128() const { assert(type == Type::v128); std::array<uint8_t, 16> ret; @@ -859,22 +873,6 @@ Literal Literal::demote() const { return Literal(float(getf64())); } -Literal standardizeNaN(float result) { - if (!std::isnan(result)) { - return Literal(result); - } - // Pick a simple canonical payload, and positive. - return Literal(Literal(uint32_t(0x7fc00000u)).reinterpretf32()); -} - -Literal standardizeNaN(double result) { - if (!std::isnan(result)) { - return Literal(result); - } - // Pick a simple canonical payload, and positive. - return Literal(Literal(uint64_t(0x7ff8000000000000ull)).reinterpretf64()); -} - Literal Literal::add(const Literal& other) const { switch (type.getBasic()) { case Type::i32: @@ -882,9 +880,9 @@ Literal Literal::add(const Literal& other) const { case Type::i64: return Literal(uint64_t(i64) + uint64_t(other.i64)); case Type::f32: - return standardizeNaN(getf32() + other.getf32()); + return standardizeNaN(Literal(getf32() + other.getf32())); case Type::f64: - return standardizeNaN(getf64() + other.getf64()); + return standardizeNaN(Literal(getf64() + other.getf64())); case Type::v128: case Type::none: case Type::unreachable: @@ -900,9 +898,9 @@ Literal Literal::sub(const Literal& other) const { case Type::i64: return Literal(uint64_t(i64) - uint64_t(other.i64)); case Type::f32: - return standardizeNaN(getf32() - other.getf32()); + return standardizeNaN(Literal(getf32() - other.getf32())); case Type::f64: - return standardizeNaN(getf64() - other.getf64()); + return standardizeNaN(Literal(getf64() - other.getf64())); case Type::v128: case Type::none: case Type::unreachable: @@ -997,9 +995,9 @@ Literal Literal::mul(const Literal& other) const { case Type::i64: return Literal(uint64_t(i64) * uint64_t(other.i64)); case Type::f32: - return standardizeNaN(getf32() * other.getf32()); + return standardizeNaN(Literal(getf32() * other.getf32())); case Type::f64: - return standardizeNaN(getf64() * other.getf64()); + return standardizeNaN(Literal(getf64() * other.getf64())); case Type::v128: case Type::none: case Type::unreachable: @@ -1018,7 +1016,7 @@ Literal Literal::div(const Literal& other) const { switch (std::fpclassify(lhs)) { case FP_NAN: case FP_ZERO: - return standardizeNaN(lhs / rhs); + return standardizeNaN(Literal(lhs / rhs)); case FP_NORMAL: // fallthrough case FP_SUBNORMAL: // fallthrough case FP_INFINITE: @@ -1031,7 +1029,7 @@ Literal Literal::div(const Literal& other) const { case FP_INFINITE: // fallthrough case FP_NORMAL: // fallthrough case FP_SUBNORMAL: - return standardizeNaN(lhs / rhs); + return standardizeNaN(Literal(lhs / rhs)); default: WASM_UNREACHABLE("invalid fp classification"); } @@ -1044,7 +1042,7 @@ Literal Literal::div(const Literal& other) const { switch (std::fpclassify(lhs)) { case FP_NAN: case FP_ZERO: - return standardizeNaN(lhs / rhs); + return standardizeNaN(Literal(lhs / rhs)); case FP_NORMAL: // fallthrough case FP_SUBNORMAL: // fallthrough case FP_INFINITE: @@ -1057,7 +1055,7 @@ Literal Literal::div(const Literal& other) const { case FP_INFINITE: // fallthrough case FP_NORMAL: // fallthrough case FP_SUBNORMAL: - return standardizeNaN(lhs / rhs); + return standardizeNaN(Literal(lhs / rhs)); default: WASM_UNREACHABLE("invalid fp classification"); } @@ -1393,10 +1391,10 @@ Literal Literal::min(const Literal& other) const { case Type::f32: { auto l = getf32(), r = other.getf32(); if (std::isnan(l)) { - return standardizeNaN(l); + return standardizeNaN(Literal(l)); } if (std::isnan(r)) { - return standardizeNaN(r); + return standardizeNaN(Literal(r)); } if (l == r && l == 0) { return Literal(std::signbit(l) ? l : r); @@ -1406,10 +1404,10 @@ Literal Literal::min(const Literal& other) const { case Type::f64: { auto l = getf64(), r = other.getf64(); if (std::isnan(l)) { - return standardizeNaN(l); + return standardizeNaN(Literal(l)); } if (std::isnan(r)) { - return standardizeNaN(r); + return standardizeNaN(Literal(r)); } if (l == r && l == 0) { return Literal(std::signbit(l) ? l : r); @@ -1426,10 +1424,10 @@ Literal Literal::max(const Literal& other) const { case Type::f32: { auto l = getf32(), r = other.getf32(); if (std::isnan(l)) { - return standardizeNaN(l); + return standardizeNaN(Literal(l)); } if (std::isnan(r)) { - return standardizeNaN(r); + return standardizeNaN(Literal(r)); } if (l == r && l == 0) { return Literal(std::signbit(l) ? r : l); @@ -1439,10 +1437,10 @@ Literal Literal::max(const Literal& other) const { case Type::f64: { auto l = getf64(), r = other.getf64(); if (std::isnan(l)) { - return standardizeNaN(l); + return standardizeNaN(Literal(l)); } if (std::isnan(r)) { - return standardizeNaN(r); + return standardizeNaN(Literal(r)); } if (l == r && l == 0) { return Literal(std::signbit(l) ? r : l); |