diff options
author | Brendan Dahl <brendan.dahl@gmail.com> | 2024-08-27 11:14:54 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-27 11:14:54 -0700 |
commit | 6c2d0e20906248ab8f8365702b35fd67db29c44f (patch) | |
tree | 17bdd816e894b7872b90e4a81262e3adb65de1af /src/wasm | |
parent | 459bc0797f67cb2a8fd4598bb7143b34036608d9 (diff) | |
download | binaryen-6c2d0e20906248ab8f8365702b35fd67db29c44f.tar.gz binaryen-6c2d0e20906248ab8f8365702b35fd67db29c44f.tar.bz2 binaryen-6c2d0e20906248ab8f8365702b35fd67db29c44f.zip |
[FP16] Implement unary operations. (#6867)
Specified at
https://github.com/WebAssembly/half-precision/blob/main/proposals/half-precision/Overview.md
Diffstat (limited to 'src/wasm')
-rw-r--r-- | src/wasm/literal.cpp | 36 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 28 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 22 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 11 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 7 |
5 files changed, 97 insertions, 7 deletions
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 6a4614a90..1b84ba53c 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -1842,13 +1842,19 @@ Literal Literal::replaceLaneF64x2(const Literal& other, uint8_t index) const { return replace<2, &Literal::getLanesF64x2>(*this, other, index); } +static Literal passThrough(const Literal& literal) { return literal; } +static Literal toFP16(const Literal& literal) { + return literal.convertF32ToF16(); +} + template<int Lanes, LaneArray<Lanes> (Literal::*IntoLanes)() const, - Literal (Literal::*UnaryOp)(void) const> + Literal (Literal::*UnaryOp)(void) const, + Literal (*Convert)(const Literal&) = passThrough> static Literal unary(const Literal& val) { LaneArray<Lanes> lanes = (val.*IntoLanes)(); for (size_t i = 0; i < Lanes; ++i) { - lanes[i] = (lanes[i].*UnaryOp)(); + lanes[i] = Convert((lanes[i].*UnaryOp)()); } return Literal(lanes); } @@ -1885,6 +1891,27 @@ Literal Literal::negI32x4() const { Literal Literal::negI64x2() const { return unary<2, &Literal::getLanesI64x2, &Literal::neg>(*this); } +Literal Literal::absF16x8() const { + return unary<8, &Literal::getLanesF16x8, &Literal::abs, &toFP16>(*this); +} +Literal Literal::negF16x8() const { + return unary<8, &Literal::getLanesF16x8, &Literal::neg, &toFP16>(*this); +} +Literal Literal::sqrtF16x8() const { + return unary<8, &Literal::getLanesF16x8, &Literal::sqrt, &toFP16>(*this); +} +Literal Literal::ceilF16x8() const { + return unary<8, &Literal::getLanesF16x8, &Literal::ceil, &toFP16>(*this); +} +Literal Literal::floorF16x8() const { + return unary<8, &Literal::getLanesF16x8, &Literal::floor, &toFP16>(*this); +} +Literal Literal::truncF16x8() const { + return unary<8, &Literal::getLanesF16x8, &Literal::trunc, &toFP16>(*this); +} +Literal Literal::nearestF16x8() const { + return unary<8, &Literal::getLanesF16x8, &Literal::nearbyint, &toFP16>(*this); +} Literal Literal::absF32x4() const { return unary<4, &Literal::getLanesF32x4, &Literal::abs>(*this); } @@ -2271,11 +2298,6 @@ Literal Literal::geF64x2(const Literal& other) const { other); } -static Literal passThrough(const Literal& literal) { return literal; } -static Literal toFP16(const Literal& literal) { - return literal.convertF32ToF16(); -} - template<int Lanes, LaneArray<Lanes> (Literal::*IntoLanes)() const, Literal (Literal::*BinaryOp)(const Literal&) const, diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 06d17fdfb..16a926182 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -6439,6 +6439,34 @@ bool WasmBinaryReader::maybeVisitSIMDUnary(Expression*& out, uint32_t code) { curr = allocator.alloc<Unary>(); curr->op = BitmaskVecI64x2; break; + case BinaryConsts::F16x8Abs: + curr = allocator.alloc<Unary>(); + curr->op = AbsVecF16x8; + break; + case BinaryConsts::F16x8Neg: + curr = allocator.alloc<Unary>(); + curr->op = NegVecF16x8; + break; + case BinaryConsts::F16x8Sqrt: + curr = allocator.alloc<Unary>(); + curr->op = SqrtVecF16x8; + break; + case BinaryConsts::F16x8Ceil: + curr = allocator.alloc<Unary>(); + curr->op = CeilVecF16x8; + break; + case BinaryConsts::F16x8Floor: + curr = allocator.alloc<Unary>(); + curr->op = FloorVecF16x8; + break; + case BinaryConsts::F16x8Trunc: + curr = allocator.alloc<Unary>(); + curr->op = TruncVecF16x8; + break; + case BinaryConsts::F16x8Nearest: + curr = allocator.alloc<Unary>(); + curr->op = NearestVecF16x8; + break; case BinaryConsts::F32x4Abs: curr = allocator.alloc<Unary>(); curr->op = AbsVecF32x4; diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index b7bfea617..59593ddde 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -1132,6 +1132,28 @@ void BinaryInstWriter::visitUnary(Unary* curr) { o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::I64x2Bitmask); break; + case AbsVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Abs); + break; + case NegVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Neg); + break; + case SqrtVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Sqrt); + break; + case CeilVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Ceil); + break; + case FloorVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Floor); + break; + case TruncVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Trunc); + break; + case NearestVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) + << U32LEB(BinaryConsts::F16x8Nearest); + break; case AbsVecF32x4: o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F32x4Abs); break; diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index f77eeefe7..40726d7cd 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2095,6 +2095,17 @@ void FunctionValidator::visitUnary(Unary* curr) { shouldBeEqual( curr->value->type, Type(Type::f64), curr, "expected f64 splat value"); break; + case AbsVecF16x8: + case NegVecF16x8: + case SqrtVecF16x8: + case CeilVecF16x8: + case FloorVecF16x8: + case TruncVecF16x8: + case NearestVecF16x8: + shouldBeTrue(getModule()->features.hasFP16(), + curr, + "FP16 operations require FP16 [--enable-fp16]"); + [[fallthrough]]; case NotVec128: case PopcntVecI8x16: case AbsVecI8x16: diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index e768f0dc4..98146dfbc 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -652,6 +652,13 @@ void Unary::finalize() { case NegVecI16x8: case NegVecI32x4: case NegVecI64x2: + case AbsVecF16x8: + case NegVecF16x8: + case SqrtVecF16x8: + case CeilVecF16x8: + case FloorVecF16x8: + case TruncVecF16x8: + case NearestVecF16x8: case AbsVecF32x4: case NegVecF32x4: case SqrtVecF32x4: |