diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/wasm/literal.cpp | 134 |
1 files changed, 72 insertions, 62 deletions
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 6d0a0fee7..6613bc1c7 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -837,6 +837,40 @@ Literal Literal::demote() const { return Literal(float(getf64())); } +// Wasm has nondeterministic rules for NaN propagation in some operations. For +// example. f32.neg is deterministic and just flips the sign, even of a NaN, but +// f32.add is nondeterministic, and if one or more of the inputs is a NaN, then +// +// * if all NaNs are canonical NaNs, the output is some arbitrary canonical NaN +// * otherwise the output is some arbitrary arithmetic NaN +// +// (canonical = NaN payload is 1000..000; arithmetic: 1???..???, that is, the +// high bit is 1 and all others can be 0 or 1) +// +// For many things we don't need to care, and can just do a normal C++ add for +// an f32.add, for example - the wasm rules are specified so that things like +// that just work (in order for such math to be fast). However, for our +// optimizer, it is useful to "standardize" NaNs when there is nondeterminism. +// That is, when there are multiple valid outputs, it's nice to emit the same +// one consistently, so that it doesn't look like the optimization changed +// something. In other words, if the valid output of an expression is a set of +// valid NaNs, and after optimization the output is still that same set, then +// the optimization is valid. And if the interpreter picks the same NaN in both +// cases from that identical set then nothing looks wrong to the fuzzer. +template<typename T> static Literal standardizeNaN(T result) { + if (!std::isnan(result)) { + return Literal(result); + } + // Pick a simple canonical payload, and positive. + if (sizeof(T) == 4) { + return Literal(Literal(uint32_t(0x7fc00000u)).reinterpretf32()); + } else if (sizeof(T) == 8) { + return Literal(Literal(uint64_t(0x7ff8000000000000ull)).reinterpretf64()); + } else { + WASM_UNREACHABLE("invalid float"); + } +} + Literal Literal::add(const Literal& other) const { switch (type.getBasic()) { case Type::i32: @@ -844,9 +878,9 @@ Literal Literal::add(const Literal& other) const { case Type::i64: return Literal(uint64_t(i64) + uint64_t(other.i64)); case Type::f32: - return Literal(getf32() + other.getf32()); + return standardizeNaN(getf32() + other.getf32()); case Type::f64: - return Literal(getf64() + other.getf64()); + return standardizeNaN(getf64() + other.getf64()); case Type::v128: case Type::funcref: case Type::externref: @@ -868,9 +902,9 @@ Literal Literal::sub(const Literal& other) const { case Type::i64: return Literal(uint64_t(i64) - uint64_t(other.i64)); case Type::f32: - return Literal(getf32() - other.getf32()); + return standardizeNaN(getf32() - other.getf32()); case Type::f64: - return Literal(getf64() - other.getf64()); + return standardizeNaN(getf64() - other.getf64()); case Type::v128: case Type::funcref: case Type::externref: @@ -963,9 +997,9 @@ Literal Literal::mul(const Literal& other) const { case Type::i64: return Literal(uint64_t(i64) * uint64_t(other.i64)); case Type::f32: - return Literal(getf32() * other.getf32()); + return standardizeNaN(getf32() * other.getf32()); case Type::f64: - return Literal(getf64() * other.getf64()); + return standardizeNaN(getf64() * other.getf64()); case Type::v128: case Type::funcref: case Type::externref: @@ -989,10 +1023,8 @@ Literal Literal::div(const Literal& other) const { case FP_ZERO: switch (std::fpclassify(lhs)) { case FP_NAN: - return Literal(setQuietNaN(lhs)); case FP_ZERO: - return Literal( - std::copysign(std::numeric_limits<float>::quiet_NaN(), sign)); + return standardizeNaN(lhs / rhs); case FP_NORMAL: // fallthrough case FP_SUBNORMAL: // fallthrough case FP_INFINITE: @@ -1005,7 +1037,7 @@ Literal Literal::div(const Literal& other) const { case FP_INFINITE: // fallthrough case FP_NORMAL: // fallthrough case FP_SUBNORMAL: - return Literal(lhs / rhs); + return standardizeNaN(lhs / rhs); default: WASM_UNREACHABLE("invalid fp classification"); } @@ -1017,10 +1049,8 @@ Literal Literal::div(const Literal& other) const { case FP_ZERO: switch (std::fpclassify(lhs)) { case FP_NAN: - return Literal(setQuietNaN(lhs)); case FP_ZERO: - return Literal( - std::copysign(std::numeric_limits<double>::quiet_NaN(), sign)); + return standardizeNaN(lhs / rhs); case FP_NORMAL: // fallthrough case FP_SUBNORMAL: // fallthrough case FP_INFINITE: @@ -1033,7 +1063,7 @@ Literal Literal::div(const Literal& other) const { case FP_INFINITE: // fallthrough case FP_NORMAL: // fallthrough case FP_SUBNORMAL: - return Literal(lhs / rhs); + return standardizeNaN(lhs / rhs); default: WASM_UNREACHABLE("invalid fp classification"); } @@ -1380,39 +1410,29 @@ Literal Literal::min(const Literal& other) const { switch (type.getBasic()) { case Type::f32: { auto l = getf32(), r = other.getf32(); - if (l == r && l == 0) { - return Literal(std::signbit(l) ? l : r); + if (std::isnan(l)) { + return standardizeNaN(l); } - auto result = std::min(l, r); - bool lnan = std::isnan(l), rnan = std::isnan(r); - if (!std::isnan(result) && !lnan && !rnan) { - return Literal(result); + if (std::isnan(r)) { + return standardizeNaN(r); } - if (!lnan && !rnan) { - return Literal((int32_t)0x7fc00000).castToF32(); + if (l == r && l == 0) { + return Literal(std::signbit(l) ? l : r); } - return Literal(lnan ? l : r) - .castToI32() - .or_(Literal(0xc00000)) - .castToF32(); + return Literal(std::min(l, r)); } case Type::f64: { auto l = getf64(), r = other.getf64(); - if (l == r && l == 0) { - return Literal(std::signbit(l) ? l : r); + if (std::isnan(l)) { + return standardizeNaN(l); } - auto result = std::min(l, r); - bool lnan = std::isnan(l), rnan = std::isnan(r); - if (!std::isnan(result) && !lnan && !rnan) { - return Literal(result); + if (std::isnan(r)) { + return standardizeNaN(r); } - if (!lnan && !rnan) { - return Literal((int64_t)0x7ff8000000000000LL).castToF64(); + if (l == r && l == 0) { + return Literal(std::signbit(l) ? l : r); } - return Literal(lnan ? l : r) - .castToI64() - .or_(Literal(int64_t(0x8000000000000LL))) - .castToF64(); + return Literal(std::min(l, r)); } default: WASM_UNREACHABLE("unexpected type"); @@ -1423,39 +1443,29 @@ Literal Literal::max(const Literal& other) const { switch (type.getBasic()) { case Type::f32: { auto l = getf32(), r = other.getf32(); - if (l == r && l == 0) { - return Literal(std::signbit(l) ? r : l); + if (std::isnan(l)) { + return standardizeNaN(l); } - auto result = std::max(l, r); - bool lnan = std::isnan(l), rnan = std::isnan(r); - if (!std::isnan(result) && !lnan && !rnan) { - return Literal(result); + if (std::isnan(r)) { + return standardizeNaN(r); } - if (!lnan && !rnan) { - return Literal((int32_t)0x7fc00000).castToF32(); + if (l == r && l == 0) { + return Literal(std::signbit(l) ? r : l); } - return Literal(lnan ? l : r) - .castToI32() - .or_(Literal(0xc00000)) - .castToF32(); + return Literal(std::max(l, r)); } case Type::f64: { auto l = getf64(), r = other.getf64(); - if (l == r && l == 0) { - return Literal(std::signbit(l) ? r : l); + if (std::isnan(l)) { + return standardizeNaN(l); } - auto result = std::max(l, r); - bool lnan = std::isnan(l), rnan = std::isnan(r); - if (!std::isnan(result) && !lnan && !rnan) { - return Literal(result); + if (std::isnan(r)) { + return standardizeNaN(r); } - if (!lnan && !rnan) { - return Literal((int64_t)0x7ff8000000000000LL).castToF64(); + if (l == r && l == 0) { + return Literal(std::signbit(l) ? r : l); } - return Literal(lnan ? l : r) - .castToI64() - .or_(Literal(int64_t(0x8000000000000LL))) - .castToF64(); + return Literal(std::max(l, r)); } default: WASM_UNREACHABLE("unexpected type"); |