diff options
author | Brendan Dahl <brendan.dahl@gmail.com> | 2024-09-03 12:08:50 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-03 12:08:50 -0700 |
commit | db9ee9434bd74ac8f1637ec109dc52e4b09794a7 (patch) | |
tree | 3a6067d3467a74730ee92914444169f6ae48e29c /src/wasm/literal.cpp | |
parent | b7cdb8c2110dff5a9b096d766dac04cd8ec04cc9 (diff) | |
download | binaryen-db9ee9434bd74ac8f1637ec109dc52e4b09794a7.tar.gz binaryen-db9ee9434bd74ac8f1637ec109dc52e4b09794a7.tar.bz2 binaryen-db9ee9434bd74ac8f1637ec109dc52e4b09794a7.zip |
[FP16] Implement madd and nmadd. (#6878)
Specified at
https://github.com/WebAssembly/half-precision/blob/main/proposals/half-precision/Overview.md
A few notes:
- The F32x4 and F64x2 versions of madd and nmadd are missing spec
tests.
- For madd, the implementation was incorrectly doing `(b*c)+a` where it
should be `(a*b)+c`.
- For nmadd, the implementation was incorrectly doing `(-b*c)+a` where
it should be `-(a*b)+c`.
- There doesn't appear to be a great way to actually implement a fused
nmadd, but the spec allows the double rounded version I added.
Diffstat (limited to 'src/wasm/literal.cpp')
-rw-r--r-- | src/wasm/literal.cpp | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index c76856d15..e332db305 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -1674,23 +1674,25 @@ Literal Literal::copysign(const Literal& other) const { Literal Literal::madd(const Literal& left, const Literal& right) const { switch (type.getBasic()) { case Type::f32: - return Literal(::fmaf(left.getf32(), right.getf32(), getf32())); + return Literal(::fmaf(getf32(), left.getf32(), right.getf32())); break; case Type::f64: - return Literal(::fma(left.getf64(), right.getf64(), getf64())); + return Literal(::fma(getf64(), left.getf64(), right.getf64())); break; default: WASM_UNREACHABLE("unexpected type"); } } +// XXX: This is not an actual fused negated multiply implementation, but +// the relaxed spec allows a double rounding implementation like below. Literal Literal::nmadd(const Literal& left, const Literal& right) const { switch (type.getBasic()) { case Type::f32: - return Literal(::fmaf(-left.getf32(), right.getf32(), getf32())); + return Literal(-(getf32() * left.getf32()) + right.getf32()); break; case Type::f64: - return Literal(::fma(-left.getf64(), right.getf64(), getf64())); + return Literal(-(getf64() * left.getf64()) + right.getf64()); break; default: WASM_UNREACHABLE("unexpected type"); @@ -2749,19 +2751,32 @@ Literal Literal::swizzleI8x16(const Literal& other) const { namespace { template<int Lanes, LaneArray<Lanes> (Literal::*IntoLanes)() const, - Literal (Literal::*TernaryOp)(const Literal&, const Literal&) const> + Literal (Literal::*TernaryOp)(const Literal&, const Literal&) const, + Literal (*Convert)(const Literal&) = passThrough> static Literal ternary(const Literal& a, const Literal& b, const Literal& c) { LaneArray<Lanes> x = (a.*IntoLanes)(); LaneArray<Lanes> y = (b.*IntoLanes)(); LaneArray<Lanes> z = (c.*IntoLanes)(); LaneArray<Lanes> r; for (size_t i = 0; i < Lanes; ++i) { - r[i] = (x[i].*TernaryOp)(y[i], z[i]); + r[i] = Convert((x[i].*TernaryOp)(y[i], z[i])); } return 
Literal(r); } } // namespace +Literal Literal::relaxedMaddF16x8(const Literal& left, + const Literal& right) const { + return ternary<8, &Literal::getLanesF16x8, &Literal::madd, &toFP16>( + *this, left, right); +} + +Literal Literal::relaxedNmaddF16x8(const Literal& left, + const Literal& right) const { + return ternary<8, &Literal::getLanesF16x8, &Literal::nmadd, &toFP16>( + *this, left, right); +} + Literal Literal::relaxedMaddF32x4(const Literal& left, const Literal& right) const { return ternary<4, &Literal::getLanesF32x4, &Literal::madd>( |